summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/passes/alias_analysis.cpp91
-rw-r--r--torch/csrc/jit/passes/alias_analysis.h4
-rw-r--r--torch/csrc/jit/passes/utils/alias_tracker.cpp12
-rw-r--r--torch/csrc/jit/passes/utils/alias_tracker.h5
-rw-r--r--torch/csrc/jit/register_prim_ops.cpp356
5 files changed, 294 insertions, 174 deletions
diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp
index f5cd1ccf13..717a96b247 100644
--- a/torch/csrc/jit/passes/alias_analysis.cpp
+++ b/torch/csrc/jit/passes/alias_analysis.cpp
@@ -11,6 +11,7 @@ bool shouldAnnotate(const TypePtr& type) {
type->kind() == TypeKind::ListType ||
type->kind() == TypeKind::TupleType ||
type->kind() == TypeKind::DictType || type->kind() == TypeKind::VarType ||
+ type->kind() == TypeKind::FutureType ||
(type->kind() == TypeKind::OptionalType &&
shouldAnnotate(type->cast<OptionalType>()->getElementType()));
}
@@ -110,6 +111,12 @@ bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const {
return aliasTracker_->mayAlias(a, b);
}
+void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const {
+ for (auto node : b->nodes()) {
+ getWritesImpl(node, ret, recurseBlocks);
+ }
+}
+
void AliasDb::getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks) const {
for (const auto input : n->inputs()) {
if (writesTo(n, input)) {
@@ -124,14 +131,20 @@ void AliasDb::getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks) const {
if (recurseBlocks) {
for (auto block : n->blocks()) {
- for (auto node : block->nodes()) {
- getWritesImpl(node, ret, recurseBlocks);
- }
+ getWritesImpl(block, ret, recurseBlocks);
}
}
}
-ValueSet AliasDb::getWrites(Node* n, bool recurseBlocks) const {
+// Get all writes by all nodes in a block, recursively exploring sub-blocks
+ValueSet AliasDb::getWrites(Block* b) const {
+ ValueSet writes;
+ getWritesImpl(b, writes, /*recurseBlocks=*/true);
+ return writes;
+}
+
+std::unordered_set<const Value*> AliasDb::getWrites(Node* n, bool recurseBlocks)
+ const {
ValueSet writes;
getWritesImpl(n, writes, recurseBlocks);
return writes;
@@ -263,6 +276,10 @@ void AliasDb::analyzeImpl(Node* node) {
case prim::FusionGroup:
case prim::DifferentiableGraph:
return analyzeSubgraph(node);
+ case prim::fork:
+ return analyzeFork(node);
+ case aten::wait:
+ return analyzeWait(node);
case prim::Constant:
case prim::DictConstruct:
case prim::ListConstruct:
@@ -485,6 +502,72 @@ void AliasDb::analyzeChunk(Node* node) {
}
}
+// Propagate aliasing and write information from the subgraph outputs to the
+// outputs of the corresponding aten::wait() calls, since that's where the
+// values will eventually emerge.
+void AliasDb::analyzeFork(Node* node) {
+ const auto subgraph = node->g(attr::Subgraph).get();
+ subgraphToOwner_.insert({subgraph, node});
+
+ const auto subgraphBlock = subgraph->block();
+ mapAliases(subgraphBlock->inputs(), node->inputs());
+ analyze(subgraphBlock);
+
+ // Give the future that the fork emits a fresh value
+ for (const auto output : node->outputs()) {
+ giveFreshAlias(output);
+ }
+}
+
+void AliasDb::analyzeWait(Node* node) {
+ const auto fut = node->input();
+ AT_ASSERT(fut->type()->kind() == TypeKind::FutureType);
+
+ if (aliasTracker_->isWildcard(fut)) {
+ for (const auto output : node->outputs()) {
+ aliasTracker_->setWildcard(output);
+ }
+ return;
+ }
+
+ const auto originFuts = aliasTracker_->getMemoryLocations(fut);
+ for (const auto originFut : originFuts) {
+ const auto subgraphNode = originFut->node();
+
+ const auto subgraph = subgraphNode->g(attr::Subgraph).get();
+ const auto subgraphWrites = getWrites(subgraph->block());
+
+ // Retrieve aliasing info from the subgraph
+ mapAliases(node->outputs(), subgraph->outputs());
+
+ // Propagate write information to the `wait` node.
+ //
+ // We need to do this for all writes in the entire subgraph, so that we
+ // disallow reorders past a call to "aten::wait".
+ //
+ // Consider the following Fork where the subgraph writes to %a:
+ //
+ // %c : Future[Tensor] = prim::Fork(%a, %b) <-- writes to %a
+ // ...
+ // aten::wait(%c)
+ // aten::use(%a) <-- we can't move this node before the `wait` safely!
+ //
+ // Say we define the "live interval" of a fork the interval between the
+ // `fork` and its first corresponding `wait` (inclusive).
+ //
+ // Any writes in the subgraph can happen at any point in the live interval,
+ // so it's not safe to re-order any reads to those memory locations from
+ // outside the live interval to inside.
+ //
+ // In reality, any reads *inside* the live interval are undefined behavior,
+ // since the writes may or may not have been executed yet. But we'll let
+ // users do that and shoot themselves in the foot for now.
+ for (const auto write : subgraphWrites) {
+ aliasTracker_->registerWrite(write, node);
+ }
+ }
+}
+
// BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
// This is an intermediate node used only in the graph fuser.
void AliasDb::analyzeBroadcastingChunk(Node* node) {
diff --git a/torch/csrc/jit/passes/alias_analysis.h b/torch/csrc/jit/passes/alias_analysis.h
index a6589a77bd..5532a00592 100644
--- a/torch/csrc/jit/passes/alias_analysis.h
+++ b/torch/csrc/jit/passes/alias_analysis.h
@@ -80,6 +80,8 @@ class AliasDb {
void move(Node* toMove, Node* movePoint, MoveSide moveSide);
bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
+ ValueSet getWrites(Block* b) const;
+ void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
// Get all the values that `n` reads from.
@@ -110,6 +112,8 @@ class AliasDb {
void analyzeExtractor(Node* node);
void analyzeChunk(Node* node);
void analyzeBroadcastingChunk(Node* node);
+ void analyzeFork(Node* node);
+ void analyzeWait(Node* node);
void makeAliasOf(const Value* value, const Value* to);
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
diff --git a/torch/csrc/jit/passes/utils/alias_tracker.cpp b/torch/csrc/jit/passes/utils/alias_tracker.cpp
index 3c3cbef3c9..a231f011cf 100644
--- a/torch/csrc/jit/passes/utils/alias_tracker.cpp
+++ b/torch/csrc/jit/passes/utils/alias_tracker.cpp
@@ -237,5 +237,17 @@ bool AliasTracker::Element::bfs(Fn fn, BfsDirection dir) const {
}
return false;
}
+
+ValueSet AliasTracker::getMemoryLocations(const Value* v) const {
+ ValueSet ret;
+ if (!map_.count(v)) {
+ return ret;
+ }
+
+ for (const auto el : map_.at(v)->getMemoryLocations()) {
+ ret.insert(el->value);
+ }
+ return ret;
+}
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/utils/alias_tracker.h b/torch/csrc/jit/passes/utils/alias_tracker.h
index 131c2b8059..afeae89a8d 100644
--- a/torch/csrc/jit/passes/utils/alias_tracker.h
+++ b/torch/csrc/jit/passes/utils/alias_tracker.h
@@ -51,6 +51,11 @@ class AliasTracker {
return wildcardWriters_;
}
+ // Get the values that represent the memory locations that `v` may point to.
+ // Return values are guaranteed to be "fresh" tensors--they do not point to
+ // anything else.
+ ValueSet getMemoryLocations(const Value* v) const;
+
// Do `a` and `b` potentially share a memory location?
bool mayAlias(const Value* a, const Value* b) const;
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 6faa9bd45e..14b3c1d3f0 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -1,3 +1,4 @@
+#include <aten/src/ATen/Context.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
@@ -9,12 +10,11 @@
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
-#include <aten/src/ATen/Context.h>
#include <ATen/ExpandUtils.h>
-#include <ATen/core/thread_pool.h>
-#include <ATen/core/ivalue.h>
#include <ATen/WrapDimUtils.h>
+#include <ATen/core/ivalue.h>
+#include <ATen/core/thread_pool.h>
#include <c10/util/SmallVector.h>
#include <exception>
@@ -66,9 +66,8 @@ template <typename dtype> // int64_t, bool, double
Operation listConstruct(int64_t num_inputs) {
return [=](Stack& stack) {
auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
- std::vector<dtype> vals = fmap(inputs, [](const IValue& v) {
- return v.to<dtype>();
- });
+ std::vector<dtype> vals =
+ fmap(inputs, [](const IValue& v) { return v.to<dtype>(); });
drop(stack, num_inputs);
push(stack, std::move(vals));
return 0;
@@ -84,14 +83,18 @@ static int64_t floordiv(int64_t a, int64_t b) {
return a / b;
} else {
// in python division rounds down, it doesnt not truncate like in c++
- auto r = lldiv(a, b);
+ auto r = lldiv(a, b);
return (r.rem) ? r.quot - 1 : r.quot;
}
}
// reference function THPVariable_to in python_variable_methods.cpp
-static at::Tensor to_dispatch(at::Tensor self, c10::optional<at::Device> device,
- c10::optional<at::ScalarType> scalarType, bool non_blocking, bool copy) {
+static at::Tensor to_dispatch(
+ at::Tensor self,
+ c10::optional<at::Device> device,
+ c10::optional<at::ScalarType> scalarType,
+ bool non_blocking,
+ bool copy) {
if (device && device->is_cuda()) {
at::globalContext().lazyInitCUDA();
}
@@ -134,7 +137,7 @@ RegisterOperators reg({
return [](Stack& stack) {
int64_t i;
pop(stack, i);
- push(stack, (bool) i);
+ push(stack, (bool)i);
return 0;
};
}),
@@ -144,7 +147,7 @@ RegisterOperators reg({
return [](Stack& stack) {
double d;
pop(stack, d);
- push(stack, (bool) d);
+ push(stack, (bool)d);
return 0;
};
}),
@@ -237,7 +240,7 @@ RegisterOperators reg({
return [](Stack& stack) {
bool b;
pop(stack, b);
- push(stack, (float) b);
+ push(stack, (float)b);
return 0;
};
}),
@@ -247,7 +250,7 @@ RegisterOperators reg({
return [](Stack& stack) {
bool b;
pop(stack, b);
- push(stack, (int) b);
+ push(stack, (int)b);
return 0;
};
}),
@@ -284,10 +287,14 @@ RegisterOperators reg({
bool non_blocking;
bool copy;
pop(stack, non_blocking, copy);
- c10::optional<at::ScalarType> scalarType = pop(stack).toOptional<at::ScalarType>();
- c10::optional<c10::Device> device = pop(stack).toOptional<c10::Device>();
+ c10::optional<at::ScalarType> scalarType =
+ pop(stack).toOptional<at::ScalarType>();
+ c10::optional<c10::Device> device =
+ pop(stack).toOptional<c10::Device>();
at::Tensor self = pop(stack).toTensor();
- push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+ push(
+ stack,
+ to_dispatch(self, device, scalarType, non_blocking, copy));
return 0;
};
}),
@@ -298,10 +305,13 @@ RegisterOperators reg({
bool non_blocking;
bool copy;
pop(stack, non_blocking, copy);
- c10::optional<at::ScalarType> scalarType = pop(stack).toOptional<at::ScalarType>();
+ c10::optional<at::ScalarType> scalarType =
+ pop(stack).toOptional<at::ScalarType>();
c10::optional<c10::Device> device = c10::nullopt;
at::Tensor self = pop(stack).toTensor();
- push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+ push(
+ stack,
+ to_dispatch(self, device, scalarType, non_blocking, copy));
return 0;
};
}),
@@ -315,7 +325,9 @@ RegisterOperators reg({
pop(stack, self, non_blocking, copy);
c10::optional<c10::Device> device = c10::nullopt;
c10::optional<at::ScalarType> scalarType = c10::nullopt;
- push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+ push(
+ stack,
+ to_dispatch(self, device, scalarType, non_blocking, copy));
return 0;
};
}),
@@ -459,7 +471,8 @@ RegisterOperators reg({
std::vector<int64_t> last_shape = shape;
int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size());
AT_CHECK(
- dim < (int64_t)regular_shape.size(), "Dimension out of range for chunk");
+ dim < (int64_t)regular_shape.size(),
+ "Dimension out of range for chunk");
int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks;
regular_shape[dim] = split_size;
if (shape[dim] % chunks == 0) {
@@ -772,9 +785,9 @@ RegisterOperators reg({
[](const Node* node) -> Operation {
const auto num_inputs = node->inputs().size();
ListTypePtr lt = node->output()->type()->expect<ListType>();
- if(IntType::get() == lt->getElementType()) {
+ if (IntType::get() == lt->getElementType()) {
return listConstruct<int64_t>(num_inputs);
- } else if(FloatType::get() == lt->getElementType()) {
+ } else if (FloatType::get() == lt->getElementType()) {
return listConstruct<double>(num_inputs);
} else if (lt->getElementType() == BoolType::get()) {
return listConstruct<bool>(num_inputs);
@@ -805,24 +818,24 @@ RegisterOperators reg({
}
}),
Operator(
- prim::DictConstruct,
- [](const Node* node) -> Operation {
- const auto num_inputs = node->inputs().size();
- if (num_inputs % 2 != 0) {
- throw std::runtime_error("DictConstruct must have an even number of inputs");
- }
- return [=](Stack& stack) {
- c10::ivalue::DictUnorderedMap<IValue, IValue> vals;
- for (size_t i = 0; i < num_inputs; i += 2) {
- auto val = pop(stack);
- auto key = pop(stack);
- vals[key] = val;
+ prim::DictConstruct,
+ [](const Node* node) -> Operation {
+ const auto num_inputs = node->inputs().size();
+ if (num_inputs % 2 != 0) {
+ throw std::runtime_error(
+ "DictConstruct must have an even number of inputs");
}
- push(stack, std::move(vals));
- return 0;
- };
- }
- ),
+ return [=](Stack& stack) {
+ c10::ivalue::DictUnorderedMap<IValue, IValue> vals;
+ for (size_t i = 0; i < num_inputs; i += 2) {
+ auto val = pop(stack);
+ auto key = pop(stack);
+ vals[key] = val;
+ }
+ push(stack, std::move(vals));
+ return 0;
+ };
+ }),
Operator(
"aten::_unwrap_optional(t(a)? optional) -> t(a)",
[](const Node* node) -> Operation {
@@ -833,9 +846,9 @@ RegisterOperators reg({
return 0;
};
}),
- // This op can be removed in preprocessing before being run in the interpreter
- // (but is currently not removed), even when it is removed it needs to remain
- // a registered op so that constant prop can run.
+ // This op can be removed in preprocessing before being run in the
+ // interpreter (but is currently not removed), even when it is removed it
+ // needs to remain a registered op so that constant prop can run.
Operator("prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)", noop),
Operator(
prim::fork,
@@ -1004,7 +1017,7 @@ Operation listSelect(const Node* node) {
}
// needs specialization because cannot return a pointer to a bool in an array
-template<>
+template <>
Operation listSelect<Shared<BoolList>>(const Node* node) {
return [=](Stack& stack) {
Shared<BoolList> list;
@@ -1017,7 +1030,6 @@ Operation listSelect<Shared<BoolList>>(const Node* node) {
};
}
-
template <typename T>
Operation listLen(const Node* node) {
return [=](Stack& stack) {
@@ -1029,7 +1041,6 @@ Operation listLen(const Node* node) {
};
}
-
template <typename T>
Operation listEq(const Node* node) {
return [=](Stack& stack) {
@@ -1177,8 +1188,7 @@ Operation listSetItem(const Node* node) {
};
}
-
-template<>
+template <>
Operation listSetItem<Shared<BoolList>, bool>(const Node* node) {
return [](Stack& stack) {
Shared<BoolList> list;
@@ -1239,123 +1249,126 @@ int dictIndex(Stack& stack) {
return 0;
}
-
RegisterOperators reg2({
- #define DEFINE_STRING_OP(op_name, string_op, result) \
- Operator(#op_name "(str a, str b) ->" #result, [](const Node* node) { \
- return [=](Stack& stack) { \
- auto b = pop(stack).toStringRef(); \
- auto a = pop(stack).toStringRef(); \
- push(stack, string_op); \
- return 0; \
- }; \
- })
-
- DEFINE_STRING_OP(aten::eq, a == b, bool),
- DEFINE_STRING_OP(aten::ne, a != b, bool),
- DEFINE_STRING_OP(aten::add, a + b, str),
- #undef DEFINE_STRING_OP
-
- // tensor length op (size of 1st dimension)
- Operator(
- "aten::len(Tensor t) -> int",
- [](Stack& stack) {
- at::Tensor t = pop(stack).toTensor();
- if (t.dim() == 0) {
- AT_ERROR("len() of a 0-d tensor");
- }
- push(stack, t.sizes()[0]);
- return 0;
- }),
- Operator(
- "aten::append(Tensor[](a!) self, Tensor(c) el) -> Tensor[](a!)",
- listAppend<Shared<TensorList>, at::Tensor>),
- Operator(
- "aten::select(Tensor[](a) list, int idx) -> Tensor(*)",
- listSelect<Shared<TensorList>>),
- Operator(
- "aten::_set_item(Tensor[](a!) l, int idx, Tensor el) -> Tensor[](a!)",
- listSetItem<Shared<TensorList>, at::Tensor>),
-
- // Mutable ops for lists containing immutable types.
- #define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \
- Operator( \
- "aten::select(" decl_type "[] a, int b) -> " decl_type, \
- listSelect<Shared<c_type>>), \
- Operator( \
- "aten::append(" decl_type "[](a!) self, " decl_type \
- " el) -> " decl_type "[](a!)", \
- listAppend<Shared<c_type>, c_type::ElemType>), \
- Operator( \
- "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
- " el) -> " decl_type "[](a!)", \
- listSetItem<Shared<c_type>, c_type::ElemType>)
-
- CREATE_IMMUTABLE_LIST_OPS("int", IntList),
- CREATE_IMMUTABLE_LIST_OPS("float", DoubleList),
- CREATE_IMMUTABLE_LIST_OPS("t", GenericList),
- CREATE_IMMUTABLE_LIST_OPS("bool", BoolList),
-
- #define CREATE_LIST_OPS(decl_type, c_type) \
- Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
- Operator( \
- "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \
- "[]", \
- listAdd<Shared<c_type>, c_type::ElemType>), \
- Operator( \
- "aten::slice(" decl_type \
- "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
- "[]", \
- listSlice<Shared<c_type>, c_type::ElemType>), \
- Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList)
-
- CREATE_LIST_OPS("int", IntList),
- CREATE_LIST_OPS("float", DoubleList),
- CREATE_LIST_OPS("Tensor", TensorList),
- CREATE_LIST_OPS("t", GenericList),
- #undef CREATE_LIST_OPS
-
- Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>),
- Operator(
- "aten::eq(float[] a, float[] b) -> bool",
- listEq<Shared<DoubleList>>),
- Operator(
- "aten::eq(Tensor[] a, Tensor[] b) -> bool",
- listEq<Shared<TensorList>>),
- Operator(
- "aten::eq(bool[] a, bool[] b) -> bool",
- listEq<Shared<BoolList>>),
- Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>),
- Operator(
- "aten::ne(float[] a, float[] b) -> bool",
- listNe<Shared<DoubleList>>),
- Operator(
- "aten::ne(Tensor[] a, Tensor[] b) -> bool",
- listNe<Shared<TensorList>>),
- Operator(
- "aten::ne(bool[] a, bool[] b) -> bool",
- listNe<Shared<BoolList>>),
-
-
- #define CREATE_COPY_OP(other_type, c_type) \
- Operator( \
- "aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \
- [](const Node* node) { \
- return [=](Stack& stack) { \
- at::Tensor t; \
- c_type other; \
- pop(stack, t, other); \
- std::move(t) = other; /* NOLINT(bugprone-use-after-move) */ \
- push(stack, std::move(t)); /* NOLINT(bugprone-use-after-move) */ \
- return 0; \
- }; \
- })
-
- CREATE_COPY_OP(Tensor, at::Tensor),
- CREATE_COPY_OP(int, int64_t),
- CREATE_COPY_OP(float, double),
- #undef CREATE_COPY_OP
+#define DEFINE_STRING_OP(op_name, string_op, result) \
+ Operator(#op_name "(str a, str b) ->" #result, [](const Node* node) { \
+ return [=](Stack& stack) { \
+ auto b = pop(stack).toStringRef(); \
+ auto a = pop(stack).toStringRef(); \
+ push(stack, string_op); \
+ return 0; \
+ }; \
+ })
+
+ DEFINE_STRING_OP(aten::eq, a == b, bool),
+ DEFINE_STRING_OP(aten::ne, a != b, bool),
+ DEFINE_STRING_OP(aten::add, a + b, str),
+#undef DEFINE_STRING_OP
+
+ // tensor length op (size of 1st dimension)
+ Operator(
+ "aten::len(Tensor t) -> int",
+ [](Stack& stack) {
+ at::Tensor t = pop(stack).toTensor();
+ if (t.dim() == 0) {
+ AT_ERROR("len() of a 0-d tensor");
+ }
+ push(stack, t.sizes()[0]);
+ return 0;
+ }),
+// Mutable ops for lists containing mutable types.
+#define CREATE_MUTABLE_LIST_OPS(decl_type, c_type) \
+ Operator( \
+ "aten::select(" decl_type "[](a) list, int idx) -> " decl_type "(*)", \
+ listSelect<Shared<c_type>>), \
+ Operator( \
+ "aten::append( " decl_type "[](a!) self, " decl_type \
+ "(c) el) -> " decl_type "[](a!)", \
+ listAppend<Shared<c_type>, c_type::ElemType>), \
+ Operator( \
+ "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
+ " el) -> " decl_type "[](a!)", \
+ listSetItem<Shared<c_type>, c_type::ElemType>)
+
+ CREATE_MUTABLE_LIST_OPS("Tensor", TensorList),
+
+// Mutable ops for lists containing immutable types.
+#define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \
+ Operator( \
+ "aten::select(" decl_type "[] a, int b) -> " decl_type, \
+ listSelect<Shared<c_type>>), \
+ Operator( \
+ "aten::append(" decl_type "[](a!) self, " decl_type \
+ " el) -> " decl_type "[](a!)", \
+ listAppend<Shared<c_type>, c_type::ElemType>), \
+ Operator( \
+ "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
+ " el) -> " decl_type "[](a!)", \
+ listSetItem<Shared<c_type>, c_type::ElemType>)
+
+ CREATE_IMMUTABLE_LIST_OPS("int", IntList),
+ CREATE_IMMUTABLE_LIST_OPS("float", DoubleList),
+ CREATE_IMMUTABLE_LIST_OPS("bool", BoolList),
+
+ // NOTE: this must be after the other list specializations so that operator
+ // resolution doesn't pick this up first
+ CREATE_MUTABLE_LIST_OPS("t", GenericList),
+
+#define CREATE_LIST_OPS(decl_type, c_type) \
+ Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
+ Operator( \
+ "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \
+ "[]", \
+ listAdd<Shared<c_type>, c_type::ElemType>), \
+ Operator( \
+ "aten::slice(" decl_type \
+ "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
+ "[]", \
+ listSlice<Shared<c_type>, c_type::ElemType>), \
+ Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList)
+
+ CREATE_LIST_OPS("int", IntList),
+ CREATE_LIST_OPS("float", DoubleList),
+ CREATE_LIST_OPS("Tensor", TensorList),
+ CREATE_LIST_OPS("t", GenericList),
+#undef CREATE_LIST_OPS
+
+ Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>),
+ Operator(
+ "aten::eq(float[] a, float[] b) -> bool",
+ listEq<Shared<DoubleList>>),
+ Operator(
+ "aten::eq(Tensor[] a, Tensor[] b) -> bool",
+ listEq<Shared<TensorList>>),
+ Operator("aten::eq(bool[] a, bool[] b) -> bool", listEq<Shared<BoolList>>),
+ Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>),
+ Operator(
+ "aten::ne(float[] a, float[] b) -> bool",
+ listNe<Shared<DoubleList>>),
+ Operator(
+ "aten::ne(Tensor[] a, Tensor[] b) -> bool",
+ listNe<Shared<TensorList>>),
+ Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe<Shared<BoolList>>),
+
+#define CREATE_COPY_OP(other_type, c_type) \
+ Operator( \
+ "aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \
+ [](const Node* node) { \
+ return [=](Stack& stack) { \
+ at::Tensor t; \
+ c_type other; \
+ pop(stack, t, other); \
+ std::move(t) = other; /* NOLINT(bugprone-use-after-move) */ \
+ push(stack, std::move(t)); /* NOLINT(bugprone-use-after-move) */ \
+ return 0; \
+ }; \
+ })
+
+ CREATE_COPY_OP(Tensor, at::Tensor),
+ CREATE_COPY_OP(int, int64_t),
+ CREATE_COPY_OP(float, double),
+#undef CREATE_COPY_OP
DEFINE_BINARY_OP(aten::add, a + b),
DEFINE_BINARY_OP(aten::sub, a - b),
@@ -1510,18 +1523,21 @@ RegisterOperators reg2({
return 0;
};
}),
- #define CREATE_DICT_OPS(key_type) \
- Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
- Operator( \
- "aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \
- dictKeys), \
- Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \
- Operator("prim::DictIndex(Dict(" key_type ", t) self, " key_type " key) -> t", dictIndex)
+#define CREATE_DICT_OPS(key_type) \
+ Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
+ Operator( \
+ "aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \
+ dictKeys), \
+ Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \
+ Operator( \
+ "prim::DictIndex(Dict(" key_type ", t) self, " key_type \
+ " key) -> t", \
+ dictIndex)
CREATE_DICT_OPS("str"),
CREATE_DICT_OPS("int"),
CREATE_DICT_OPS("float"),
- #undef CREATE_DICT_OPS
+#undef CREATE_DICT_OPS
});
// reference: _output_size in torch/nn/functional.py