diff options
author | Michael Suo <suo@fb.com> | 2019-02-05 20:37:30 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-05 20:43:30 -0800 |
commit | 43f4c862380a44b8820f3533285bdee209b5f0c9 (patch) | |
tree | 5ca49fd63d0fbf424def60283786bb1307f6d74f /torch | |
parent | c1dff549da73d0526cc41ec9882a44c7d8003f6b (diff) | |
download | pytorch-43f4c862380a44b8820f3533285bdee209b5f0c9.tar.gz pytorch-43f4c862380a44b8820f3533285bdee209b5f0c9.tar.bz2 pytorch-43f4c862380a44b8820f3533285bdee209b5f0c9.zip |
Fix alias analysis for fork/wait (#16671)
Summary:
(review top commit only).
As expected, fork/wait introduces some corner cases into the alias analysis. The comments inline should describe the changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16671
Differential Revision: D13963219
Pulled By: suo
fbshipit-source-id: 2bec6fc03a4989cf309fbb9473f3f2ffe2c31431
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/passes/alias_analysis.cpp | 91 | ||||
-rw-r--r-- | torch/csrc/jit/passes/alias_analysis.h | 4 | ||||
-rw-r--r-- | torch/csrc/jit/passes/utils/alias_tracker.cpp | 12 | ||||
-rw-r--r-- | torch/csrc/jit/passes/utils/alias_tracker.h | 5 | ||||
-rw-r--r-- | torch/csrc/jit/register_prim_ops.cpp | 356 |
5 files changed, 294 insertions, 174 deletions
diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index f5cd1ccf13..717a96b247 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -11,6 +11,7 @@ bool shouldAnnotate(const TypePtr& type) { type->kind() == TypeKind::ListType || type->kind() == TypeKind::TupleType || type->kind() == TypeKind::DictType || type->kind() == TypeKind::VarType || + type->kind() == TypeKind::FutureType || (type->kind() == TypeKind::OptionalType && shouldAnnotate(type->cast<OptionalType>()->getElementType())); } @@ -110,6 +111,12 @@ bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const { return aliasTracker_->mayAlias(a, b); } +void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const { + for (auto node : b->nodes()) { + getWritesImpl(node, ret, recurseBlocks); + } +} + void AliasDb::getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks) const { for (const auto input : n->inputs()) { if (writesTo(n, input)) { @@ -124,14 +131,20 @@ void AliasDb::getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks) const { if (recurseBlocks) { for (auto block : n->blocks()) { - for (auto node : block->nodes()) { - getWritesImpl(node, ret, recurseBlocks); - } + getWritesImpl(block, ret, recurseBlocks); } } } -ValueSet AliasDb::getWrites(Node* n, bool recurseBlocks) const { +// Get all writes by all nodes in a block, recursively exploring sub-blocks +ValueSet AliasDb::getWrites(Block* b) const { + ValueSet writes; + getWritesImpl(b, writes, /*recurseBlocks=*/true); + return writes; +} + +std::unordered_set<const Value*> AliasDb::getWrites(Node* n, bool recurseBlocks) + const { ValueSet writes; getWritesImpl(n, writes, recurseBlocks); return writes; @@ -263,6 +276,10 @@ void AliasDb::analyzeImpl(Node* node) { case prim::FusionGroup: case prim::DifferentiableGraph: return analyzeSubgraph(node); + case prim::fork: + return analyzeFork(node); + case aten::wait: + return analyzeWait(node); case prim::Constant: case prim::DictConstruct: case prim::ListConstruct: @@ -485,6 +502,72 @@ void AliasDb::analyzeChunk(Node* node) { } } +// Propagate aliasing and write information from the subgraph outputs to the +// outputs of the corresponding aten::wait() calls, since that's where the +// values will eventually emerge. +void AliasDb::analyzeFork(Node* node) { + const auto subgraph = node->g(attr::Subgraph).get(); + subgraphToOwner_.insert({subgraph, node}); + + const auto subgraphBlock = subgraph->block(); + mapAliases(subgraphBlock->inputs(), node->inputs()); + analyze(subgraphBlock); + + // Give the future that the fork emits a fresh value + for (const auto output : node->outputs()) { + giveFreshAlias(output); + } +} + +void AliasDb::analyzeWait(Node* node) { + const auto fut = node->input(); + AT_ASSERT(fut->type()->kind() == TypeKind::FutureType); + + if (aliasTracker_->isWildcard(fut)) { + for (const auto output : node->outputs()) { + aliasTracker_->setWildcard(output); + } + return; + } + + const auto originFuts = aliasTracker_->getMemoryLocations(fut); + for (const auto originFut : originFuts) { + const auto subgraphNode = originFut->node(); + + const auto subgraph = subgraphNode->g(attr::Subgraph).get(); + const auto subgraphWrites = getWrites(subgraph->block()); + + // Retrieve aliasing info from the subgraph + mapAliases(node->outputs(), subgraph->outputs()); + + // Propagate write information to the `wait` node. + // + // We need to do this for all writes in the entire subgraph, so that we + // disallow reorders past a call to "aten::wait". + // + // Consider the following Fork where the subgraph writes to %a: + // + // %c : Future[Tensor] = prim::Fork(%a, %b) <-- writes to %a + // ... + // aten::wait(%c) + // aten::use(%a) <-- we can't move this node before the `wait` safely! + // + // Say we define the "live interval" of a fork the interval between the + // `fork` and its first corresponding `wait` (inclusive). + // + // Any writes in the subgraph can happen at any point in the live interval, + // so it's not safe to re-order any reads to those memory locations from + // outside the live interval to inside. + // + // In reality, any reads *inside* the live interval are undefined behavior, + // since the writes may or may not have been executed yet. But we'll let + // users do that and shoot themselves in the foot for now. + for (const auto write : subgraphWrites) { + aliasTracker_->registerWrite(write, node); + } + } +} + // BroadcastingChunk: all inputs are broadcasted, and then individually chunked. // This is an intermediate node used only in the graph fuser. void AliasDb::analyzeBroadcastingChunk(Node* node) { diff --git a/torch/csrc/jit/passes/alias_analysis.h b/torch/csrc/jit/passes/alias_analysis.h index a6589a77bd..5532a00592 100644 --- a/torch/csrc/jit/passes/alias_analysis.h +++ b/torch/csrc/jit/passes/alias_analysis.h @@ -80,6 +80,8 @@ class AliasDb { void move(Node* toMove, Node* movePoint, MoveSide moveSide); bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const; + ValueSet getWrites(Block* b) const; + void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const; void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const; // Get all the values that `n` reads from. @@ -110,6 +112,8 @@ class AliasDb { void analyzeExtractor(Node* node); void analyzeChunk(Node* node); void analyzeBroadcastingChunk(Node* node); + void analyzeFork(Node* node); + void analyzeWait(Node* node); void makeAliasOf(const Value* value, const Value* to); void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from); diff --git a/torch/csrc/jit/passes/utils/alias_tracker.cpp b/torch/csrc/jit/passes/utils/alias_tracker.cpp index 3c3cbef3c9..a231f011cf 100644 --- a/torch/csrc/jit/passes/utils/alias_tracker.cpp +++ b/torch/csrc/jit/passes/utils/alias_tracker.cpp @@ -237,5 +237,17 @@ bool AliasTracker::Element::bfs(Fn fn, BfsDirection dir) const { } return false; } + +ValueSet AliasTracker::getMemoryLocations(const Value* v) const { + ValueSet ret; + if (!map_.count(v)) { + return ret; + } + + for (const auto el : map_.at(v)->getMemoryLocations()) { + ret.insert(el->value); + } + return ret; +} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/utils/alias_tracker.h b/torch/csrc/jit/passes/utils/alias_tracker.h index 131c2b8059..afeae89a8d 100644 --- a/torch/csrc/jit/passes/utils/alias_tracker.h +++ b/torch/csrc/jit/passes/utils/alias_tracker.h @@ -51,6 +51,11 @@ class AliasTracker { return wildcardWriters_; } + // Get the values that represent the memory locations that `v` may point to. + // Return values are guaranteed to be "fresh" tensors--they do not point to + // anything else. + ValueSet getMemoryLocations(const Value* v) const; + // Do `a` and `b` potentially share a memory location? bool mayAlias(const Value* a, const Value* b) const; diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 6faa9bd45e..14b3c1d3f0 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1,3 +1,4 @@ +#include <aten/src/ATen/Context.h> #include <torch/csrc/autograd/edge.h> #include <torch/csrc/autograd/function.h> #include <torch/csrc/autograd/generated/variable_factories.h> @@ -9,12 +10,11 @@ #include <torch/csrc/jit/ir.h> #include <torch/csrc/jit/operator.h> #include <torch/csrc/jit/script/jit_exception.h> -#include <aten/src/ATen/Context.h> #include <ATen/ExpandUtils.h> -#include <ATen/core/thread_pool.h> -#include <ATen/core/ivalue.h> #include <ATen/WrapDimUtils.h> +#include <ATen/core/ivalue.h> +#include <ATen/core/thread_pool.h> #include <c10/util/SmallVector.h> #include <exception> @@ -66,9 +66,8 @@ template <typename dtype> // int64_t, bool, double Operation listConstruct(int64_t num_inputs) { return [=](Stack& stack) { auto inputs = peekSlice(stack, 0, num_inputs, num_inputs); - std::vector<dtype> vals = fmap(inputs, [](const IValue& v) { - return v.to<dtype>(); - }); + std::vector<dtype> vals = + fmap(inputs, [](const IValue& v) { return v.to<dtype>(); }); drop(stack, num_inputs); push(stack, std::move(vals)); return 0; @@ -84,14 +83,18 @@ static int64_t floordiv(int64_t a, int64_t b) { return a / b; } else { // in python division rounds down, it doesnt not truncate like in c++ - auto r = lldiv(a, b); + auto r = lldiv(a, b); return (r.rem) ? r.quot - 1 : r.quot; } } // reference function THPVariable_to in python_variable_methods.cpp -static at::Tensor to_dispatch(at::Tensor self, c10::optional<at::Device> device, - c10::optional<at::ScalarType> scalarType, bool non_blocking, bool copy) { +static at::Tensor to_dispatch( + at::Tensor self, + c10::optional<at::Device> device, + c10::optional<at::ScalarType> scalarType, + bool non_blocking, + bool copy) { if (device && device->is_cuda()) { at::globalContext().lazyInitCUDA(); } @@ -134,7 +137,7 @@ RegisterOperators reg({ return [](Stack& stack) { int64_t i; pop(stack, i); - push(stack, (bool) i); + push(stack, (bool)i); return 0; }; }), @@ -144,7 +147,7 @@ RegisterOperators reg({ return [](Stack& stack) { double d; pop(stack, d); - push(stack, (bool) d); + push(stack, (bool)d); return 0; }; }), @@ -237,7 +240,7 @@ RegisterOperators reg({ return [](Stack& stack) { bool b; pop(stack, b); - push(stack, (float) b); + push(stack, (float)b); return 0; }; }), @@ -247,7 +250,7 @@ RegisterOperators reg({ return [](Stack& stack) { bool b; pop(stack, b); - push(stack, (int) b); + push(stack, (int)b); return 0; }; }), @@ -284,10 +287,14 @@ RegisterOperators reg({ bool non_blocking; bool copy; pop(stack, non_blocking, copy); - c10::optional<at::ScalarType> scalarType = pop(stack).toOptional<at::ScalarType>(); - c10::optional<c10::Device> device = pop(stack).toOptional<c10::Device>(); + c10::optional<at::ScalarType> scalarType = + pop(stack).toOptional<at::ScalarType>(); + c10::optional<c10::Device> device = + pop(stack).toOptional<c10::Device>(); at::Tensor self = pop(stack).toTensor(); - push(stack, to_dispatch(self, device, scalarType, non_blocking, copy)); + push( + stack, + to_dispatch(self, device, scalarType, non_blocking, copy)); return 0; }; }), @@ -298,10 +305,13 @@ RegisterOperators reg({ bool non_blocking; bool copy; pop(stack, non_blocking, copy); - c10::optional<at::ScalarType> scalarType = pop(stack).toOptional<at::ScalarType>(); + c10::optional<at::ScalarType> scalarType = + pop(stack).toOptional<at::ScalarType>(); c10::optional<c10::Device> device = c10::nullopt; at::Tensor self = pop(stack).toTensor(); - push(stack, to_dispatch(self, device, scalarType, non_blocking, copy)); + push( + stack, + to_dispatch(self, device, scalarType, non_blocking, copy)); return 0; }; }), @@ -315,7 +325,9 @@ RegisterOperators reg({ pop(stack, self, non_blocking, copy); c10::optional<c10::Device> device = c10::nullopt; c10::optional<at::ScalarType> scalarType = c10::nullopt; - push(stack, to_dispatch(self, device, scalarType, non_blocking, copy)); + push( + stack, + to_dispatch(self, device, scalarType, non_blocking, copy)); return 0; }; }), @@ -459,7 +471,8 @@ RegisterOperators reg({ std::vector<int64_t> last_shape = shape; int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size()); AT_CHECK( - dim < (int64_t)regular_shape.size(), "Dimension out of range for chunk"); + dim < (int64_t)regular_shape.size(), + "Dimension out of range for chunk"); int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks; regular_shape[dim] = split_size; if (shape[dim] % chunks == 0) { @@ -772,9 +785,9 @@ RegisterOperators reg({ [](const Node* node) -> Operation { const auto num_inputs = node->inputs().size(); ListTypePtr lt = node->output()->type()->expect<ListType>(); - if(IntType::get() == lt->getElementType()) { + if (IntType::get() == lt->getElementType()) { return listConstruct<int64_t>(num_inputs); - } else if(FloatType::get() == lt->getElementType()) { + } else if (FloatType::get() == lt->getElementType()) { return listConstruct<double>(num_inputs); } else if (lt->getElementType() == BoolType::get()) { return listConstruct<bool>(num_inputs); @@ -805,24 +818,24 @@ RegisterOperators reg({ } }), Operator( - prim::DictConstruct, - [](const Node* node) -> Operation { - const auto num_inputs = node->inputs().size(); - if (num_inputs % 2 != 0) { - throw std::runtime_error("DictConstruct must have an even number of inputs"); - } - return [=](Stack& stack) { - c10::ivalue::DictUnorderedMap<IValue, IValue> vals; - for (size_t i = 0; i < num_inputs; i += 2) { - auto val = pop(stack); - auto key = pop(stack); - vals[key] = val; + prim::DictConstruct, + [](const Node* node) -> Operation { + const auto num_inputs = node->inputs().size(); + if (num_inputs % 2 != 0) { + throw std::runtime_error( + "DictConstruct must have an even number of inputs"); } - push(stack, std::move(vals)); - return 0; - }; - } - ), + return [=](Stack& stack) { + c10::ivalue::DictUnorderedMap<IValue, IValue> vals; + for (size_t i = 0; i < num_inputs; i += 2) { + auto val = pop(stack); + auto key = pop(stack); + vals[key] = val; + } + push(stack, std::move(vals)); + return 0; + }; + }), Operator( "aten::_unwrap_optional(t(a)? optional) -> t(a)", [](const Node* node) -> Operation { @@ -833,9 +846,9 @@ RegisterOperators reg({ return 0; }; }), - // This op can be removed in preprocessing before being run in the interpreter - // (but is currently not removed), even when it is removed it needs to remain - // a registered op so that constant prop can run. + // This op can be removed in preprocessing before being run in the + // interpreter (but is currently not removed), even when it is removed it + // needs to remain a registered op so that constant prop can run. Operator("prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)", noop), Operator( prim::fork, @@ -1004,7 +1017,7 @@ Operation listSelect(const Node* node) { } // needs specialization because cannot return a pointer to a bool in an array -template<> +template <> Operation listSelect<Shared<BoolList>>(const Node* node) { return [=](Stack& stack) { Shared<BoolList> list; @@ -1017,7 +1030,6 @@ Operation listSelect<Shared<BoolList>>(const Node* node) { }; } - template <typename T> Operation listLen(const Node* node) { return [=](Stack& stack) { @@ -1029,7 +1041,6 @@ Operation listLen(const Node* node) { }; } - template <typename T> Operation listEq(const Node* node) { return [=](Stack& stack) { @@ -1177,8 +1188,7 @@ Operation listSetItem(const Node* node) { }; } - -template<> +template <> Operation listSetItem<Shared<BoolList>, bool>(const Node* node) { return [](Stack& stack) { Shared<BoolList> list; @@ -1239,123 +1249,126 @@ int dictIndex(Stack& stack) { return 0; } - RegisterOperators reg2({ - #define DEFINE_STRING_OP(op_name, string_op, result) \ - Operator(#op_name "(str a, str b) ->" #result, [](const Node* node) { \ - return [=](Stack& stack) { \ - auto b = pop(stack).toStringRef(); \ - auto a = pop(stack).toStringRef(); \ - push(stack, string_op); \ - return 0; \ - }; \ - }) - - DEFINE_STRING_OP(aten::eq, a == b, bool), - DEFINE_STRING_OP(aten::ne, a != b, bool), - DEFINE_STRING_OP(aten::add, a + b, str), - #undef DEFINE_STRING_OP - - // tensor length op (size of 1st dimension) - Operator( - "aten::len(Tensor t) -> int", - [](Stack& stack) { - at::Tensor t = pop(stack).toTensor(); - if (t.dim() == 0) { - AT_ERROR("len() of a 0-d tensor"); - } - push(stack, t.sizes()[0]); - return 0; - }), - Operator( - "aten::append(Tensor[](a!) self, Tensor(c) el) -> Tensor[](a!)", - listAppend<Shared<TensorList>, at::Tensor>), - Operator( - "aten::select(Tensor[](a) list, int idx) -> Tensor(*)", - listSelect<Shared<TensorList>>), - Operator( - "aten::_set_item(Tensor[](a!) l, int idx, Tensor el) -> Tensor[](a!)", - listSetItem<Shared<TensorList>, at::Tensor>), - - // Mutable ops for lists containing immutable types. - #define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \ - Operator( \ - "aten::select(" decl_type "[] a, int b) -> " decl_type, \ - listSelect<Shared<c_type>>), \ - Operator( \ - "aten::append(" decl_type "[](a!) self, " decl_type \ - " el) -> " decl_type "[](a!)", \ - listAppend<Shared<c_type>, c_type::ElemType>), \ - Operator( \ - "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \ - " el) -> " decl_type "[](a!)", \ - listSetItem<Shared<c_type>, c_type::ElemType>) - - CREATE_IMMUTABLE_LIST_OPS("int", IntList), - CREATE_IMMUTABLE_LIST_OPS("float", DoubleList), - CREATE_IMMUTABLE_LIST_OPS("t", GenericList), - CREATE_IMMUTABLE_LIST_OPS("bool", BoolList), - - #define CREATE_LIST_OPS(decl_type, c_type) \ - Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \ - Operator( \ - "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \ - "[]", \ - listAdd<Shared<c_type>, c_type::ElemType>), \ - Operator( \ - "aten::slice(" decl_type \ - "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \ - "[]", \ - listSlice<Shared<c_type>, c_type::ElemType>), \ - Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList) - - CREATE_LIST_OPS("int", IntList), - CREATE_LIST_OPS("float", DoubleList), - CREATE_LIST_OPS("Tensor", TensorList), - CREATE_LIST_OPS("t", GenericList), - #undef CREATE_LIST_OPS - - Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>), - Operator( - "aten::eq(float[] a, float[] b) -> bool", - listEq<Shared<DoubleList>>), - Operator( - "aten::eq(Tensor[] a, Tensor[] b) -> bool", - listEq<Shared<TensorList>>), - Operator( - "aten::eq(bool[] a, bool[] b) -> bool", - listEq<Shared<BoolList>>), - Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>), - Operator( - "aten::ne(float[] a, float[] b) -> bool", - listNe<Shared<DoubleList>>), - Operator( - "aten::ne(Tensor[] a, Tensor[] b) -> bool", - listNe<Shared<TensorList>>), - Operator( - "aten::ne(bool[] a, bool[] b) -> bool", - listNe<Shared<BoolList>>), - - - #define CREATE_COPY_OP(other_type, c_type) \ - Operator( \ - "aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \ - [](const Node* node) { \ - return [=](Stack& stack) { \ - at::Tensor t; \ - c_type other; \ - pop(stack, t, other); \ - std::move(t) = other; /* NOLINT(bugprone-use-after-move) */ \ - push(stack, std::move(t)); /* NOLINT(bugprone-use-after-move) */ \ - return 0; \ - }; \ - }) - - CREATE_COPY_OP(Tensor, at::Tensor), - CREATE_COPY_OP(int, int64_t), - CREATE_COPY_OP(float, double), - #undef CREATE_COPY_OP +#define DEFINE_STRING_OP(op_name, string_op, result) \ + Operator(#op_name "(str a, str b) ->" #result, [](const Node* node) { \ + return [=](Stack& stack) { \ + auto b = pop(stack).toStringRef(); \ + auto a = pop(stack).toStringRef(); \ + push(stack, string_op); \ + return 0; \ + }; \ + }) + + DEFINE_STRING_OP(aten::eq, a == b, bool), + DEFINE_STRING_OP(aten::ne, a != b, bool), + DEFINE_STRING_OP(aten::add, a + b, str), +#undef DEFINE_STRING_OP + + // tensor length op (size of 1st dimension) + Operator( + "aten::len(Tensor t) -> int", + [](Stack& stack) { + at::Tensor t = pop(stack).toTensor(); + if (t.dim() == 0) { + AT_ERROR("len() of a 0-d tensor"); + } + push(stack, t.sizes()[0]); + return 0; + }), +// Mutable ops for lists containing mutable types. +#define CREATE_MUTABLE_LIST_OPS(decl_type, c_type) \ + Operator( \ + "aten::select(" decl_type "[](a) list, int idx) -> " decl_type "(*)", \ + listSelect<Shared<c_type>>), \ + Operator( \ + "aten::append( " decl_type "[](a!) self, " decl_type \ + "(c) el) -> " decl_type "[](a!)", \ + listAppend<Shared<c_type>, c_type::ElemType>), \ + Operator( \ + "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \ + " el) -> " decl_type "[](a!)", \ + listSetItem<Shared<c_type>, c_type::ElemType>) + + CREATE_MUTABLE_LIST_OPS("Tensor", TensorList), + +// Mutable ops for lists containing immutable types. +#define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \ + Operator( \ + "aten::select(" decl_type "[] a, int b) -> " decl_type, \ + listSelect<Shared<c_type>>), \ + Operator( \ + "aten::append(" decl_type "[](a!) self, " decl_type \ + " el) -> " decl_type "[](a!)", \ + listAppend<Shared<c_type>, c_type::ElemType>), \ + Operator( \ + "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \ + " el) -> " decl_type "[](a!)", \ + listSetItem<Shared<c_type>, c_type::ElemType>) + + CREATE_IMMUTABLE_LIST_OPS("int", IntList), + CREATE_IMMUTABLE_LIST_OPS("float", DoubleList), + CREATE_IMMUTABLE_LIST_OPS("bool", BoolList), + + // NOTE: this must be after the other list specializations so that operator + // resolution doesn't pick this up first + CREATE_MUTABLE_LIST_OPS("t", GenericList), + +#define CREATE_LIST_OPS(decl_type, c_type) \ + Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \ + Operator( \ + "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \ + "[]", \ + listAdd<Shared<c_type>, c_type::ElemType>), \ + Operator( \ + "aten::slice(" decl_type \ + "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \ + "[]", \ + listSlice<Shared<c_type>, c_type::ElemType>), \ + Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList) + + CREATE_LIST_OPS("int", IntList), + CREATE_LIST_OPS("float", DoubleList), + CREATE_LIST_OPS("Tensor", TensorList), + CREATE_LIST_OPS("t", GenericList), +#undef CREATE_LIST_OPS + + Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>), + Operator( + "aten::eq(float[] a, float[] b) -> bool", + listEq<Shared<DoubleList>>), + Operator( + "aten::eq(Tensor[] a, Tensor[] b) -> bool", + listEq<Shared<TensorList>>), + Operator("aten::eq(bool[] a, bool[] b) -> bool", listEq<Shared<BoolList>>), + Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>), + Operator( + "aten::ne(float[] a, float[] b) -> bool", + listNe<Shared<DoubleList>>), + Operator( + "aten::ne(Tensor[] a, Tensor[] b) -> bool", + listNe<Shared<TensorList>>), + Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe<Shared<BoolList>>), + +#define CREATE_COPY_OP(other_type, c_type) \ + Operator( \ + "aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \ + [](const Node* node) { \ + return [=](Stack& stack) { \ + at::Tensor t; \ + c_type other; \ + pop(stack, t, other); \ + std::move(t) = other; /* NOLINT(bugprone-use-after-move) */ \ + push(stack, std::move(t)); /* NOLINT(bugprone-use-after-move) */ \ + return 0; \ + }; \ + }) + + CREATE_COPY_OP(Tensor, at::Tensor), + CREATE_COPY_OP(int, int64_t), + CREATE_COPY_OP(float, double), +#undef CREATE_COPY_OP DEFINE_BINARY_OP(aten::add, a + b), DEFINE_BINARY_OP(aten::sub, a - b), @@ -1510,18 +1523,21 @@ RegisterOperators reg2({ return 0; }; }), - #define CREATE_DICT_OPS(key_type) \ - Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \ - Operator( \ - "aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \ - dictKeys), \ - Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \ - Operator("prim::DictIndex(Dict(" key_type ", t) self, " key_type " key) -> t", dictIndex) +#define CREATE_DICT_OPS(key_type) \ + Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \ + Operator( \ + "aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \ + dictKeys), \ + Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \ + Operator( \ + "prim::DictIndex(Dict(" key_type ", t) self, " key_type \ + " key) -> t", \ + dictIndex) CREATE_DICT_OPS("str"), CREATE_DICT_OPS("int"), CREATE_DICT_OPS("float"), - #undef CREATE_DICT_OPS +#undef CREATE_DICT_OPS }); // reference: _output_size in torch/nn/functional.py |