summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorDavid Riazati <davidriazati@fb.com>2019-02-05 13:48:52 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-05 13:55:25 -0800
commite2d3a3fd6a248a788e3d548bb1caff9019c585ef (patch)
tree80a3aa8a400d9fb84a6687329baa5f9bbdad3566 /torch
parent0ceef3c9f686c654917e7608c4bcca0e85548f2d (diff)
downloadpytorch-e2d3a3fd6a248a788e3d548bb1caff9019c585ef.tar.gz
pytorch-e2d3a3fd6a248a788e3d548bb1caff9019c585ef.tar.bz2
pytorch-e2d3a3fd6a248a788e3d548bb1caff9019c585ef.zip
dict values(), keys(), and len() (#16629)
Summary: Adds some operations for dicts to match Python and tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/16629 Differential Revision: D13961144 Pulled By: driazati fbshipit-source-id: b31f27a4320ff62cd118b508fb0a13056535dc7c
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/operator.cpp9
-rw-r--r--torch/csrc/jit/register_prim_ops.cpp79
-rw-r--r--torch/csrc/jit/script/compiler.cpp3
3 files changed, 74 insertions, 17 deletions
diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp
index b825a07977..8971c64ac2 100644
--- a/torch/csrc/jit/operator.cpp
+++ b/torch/csrc/jit/operator.cpp
@@ -160,6 +160,15 @@ struct SchemaParser {
L.next();
value = DynamicType::get();
alias_info = parseAliasAnnotation();
+ } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
+ L.next();
+ L.expect('(');
+ auto key_type = parseType().first;
+ L.expect(',');
+ auto value_type = parseType().first;
+ alias_info = parseAliasAnnotation();
+ L.expect(')');
+ value = DictType::create(key_type, value_type);
} else {
auto value_alias = parseBaseType();
value = value_alias.first;
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index db4592774b..c984354584 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -13,6 +13,7 @@
#include <ATen/ExpandUtils.h>
#include <ATen/core/thread_pool.h>
+#include <ATen/core/ivalue.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/SmallVector.h>
@@ -823,21 +824,6 @@ RegisterOperators reg({
}
),
Operator(
- prim::DictIndex,
- [](const Node* node) {
- return [=](Stack& stack) {
- auto index = pop(stack);
- auto dict = pop(stack).toGenericDict();
- const auto& elems = dict->elements();
- auto value = elems.find(index);
- if (value == elems.end()) {
- AT_ERROR("KeyError: '", index, "'");
- }
- push(stack, value->second);
- return 0;
- };
- }),
- Operator(
"aten::_unwrap_optional(t(a)? optional) -> t(a)",
[](const Node* node) -> Operation {
return [=](Stack& stack) {
@@ -1110,6 +1096,14 @@ Operation listNe<Shared<TensorList>>(const Node* node) {
};
}
+Operation listList(const Node* node) {
+ return [=](Stack& stack) {
+ // Intentional no-op, needed to match Python semantics for list(iterable),
+ // but in JIT these will already be lists
+ return 0;
+ };
+}
+
template <class TList, class TElement>
Operation listAdd(const Node* node) {
return [=](Stack& stack) {
@@ -1205,6 +1199,46 @@ Operation listSetItem<Shared<BoolList>, bool>(const Node* node) {
};
}
+int dictLen(Stack& stack) {
+ auto dict = pop(stack).toGenericDictRef();
+ push(stack, int64_t(dict.size()));
+ return 0;
+}
+
+int dictKeys(Stack& stack) {
+ auto dict = pop(stack).toGenericDictRef();
+ std::vector<IValue> keys;
+ keys.reserve(dict.size());
+ for (auto item : dict) {
+ keys.push_back(item.first);
+ }
+ push(stack, IValue(keys));
+ return 0;
+}
+
+int dictValues(Stack& stack) {
+ auto dict = pop(stack).toGenericDictRef();
+ std::vector<IValue> values;
+ values.reserve(dict.size());
+ for (auto item : dict) {
+ values.push_back(item.second);
+ }
+ push(stack, IValue(values));
+ return 0;
+}
+
+int dictIndex(Stack& stack) {
+ auto index = pop(stack);
+ auto dict = pop(stack).toGenericDict();
+ const auto& elems = dict->elements();
+ auto value = elems.find(index);
+ if (value == elems.end()) {
+ AT_ERROR("KeyError: '", index, "'");
+ }
+ push(stack, value->second);
+ return 0;
+}
+
RegisterOperators reg2({
@@ -1273,7 +1307,8 @@ RegisterOperators reg2({
"aten::slice(" decl_type \
"[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
"[]", \
- listSlice<Shared<c_type>, c_type::ElemType>)
+ listSlice<Shared<c_type>, c_type::ElemType>), \
+ Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList)
CREATE_LIST_OPS("int", IntList),
CREATE_LIST_OPS("float", DoubleList),
@@ -1475,6 +1510,18 @@ RegisterOperators reg2({
return 0;
};
}),
+ #define CREATE_DICT_OPS(key_type) \
+ Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
+ Operator( \
+ "aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \
+ dictKeys), \
+ Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \
+ Operator("prim::DictIndex(Dict(" key_type ", t) self, " key_type " key) -> t", dictIndex)
+
+ CREATE_DICT_OPS("str"),
+ CREATE_DICT_OPS("int"),
+ CREATE_DICT_OPS("float"),
+ #undef CREATE_DICT_OPS
});
// reference: _output_size in torch/nn/functional.py
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index a96690d673..0e5ff51d01 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -399,6 +399,7 @@ struct Environment {
{"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)},
{"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
{"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
+ {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
};
auto it = globals.find(ident);
if (it != globals.end())
@@ -2230,7 +2231,7 @@ struct to_ir {
auto value_trees = dl.value_inputs().tree()->trees();
AT_ASSERT(key_trees.size() == value_trees.size());
std::vector<Value*> keys, values;
- for(size_t i = 0; i < keys.size(); ++i) {
+ for(size_t i = 0; i < key_trees.size(); ++i) {
keys.push_back(emitExpr(Expr(key_trees[i])));
values.push_back(emitExpr(Expr(value_trees[i])));
}