diff options
author | David Riazati <davidriazati@fb.com> | 2019-02-05 13:48:52 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-05 13:55:25 -0800 |
commit | e2d3a3fd6a248a788e3d548bb1caff9019c585ef (patch) | |
tree | 80a3aa8a400d9fb84a6687329baa5f9bbdad3566 /torch | |
parent | 0ceef3c9f686c654917e7608c4bcca0e85548f2d (diff) | |
download | pytorch-e2d3a3fd6a248a788e3d548bb1caff9019c585ef.tar.gz pytorch-e2d3a3fd6a248a788e3d548bb1caff9019c585ef.tar.bz2 pytorch-e2d3a3fd6a248a788e3d548bb1caff9019c585ef.zip |
dict values(), keys(), and len() (#16629)
Summary:
Adds some operations for dicts to match Python and tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16629
Differential Revision: D13961144
Pulled By: driazati
fbshipit-source-id: b31f27a4320ff62cd118b508fb0a13056535dc7c
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/operator.cpp | 9 | ||||
-rw-r--r-- | torch/csrc/jit/register_prim_ops.cpp | 79 | ||||
-rw-r--r-- | torch/csrc/jit/script/compiler.cpp | 3 |
3 files changed, 74 insertions, 17 deletions
diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index b825a07977..8971c64ac2 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -160,6 +160,15 @@ struct SchemaParser { L.next(); value = DynamicType::get(); alias_info = parseAliasAnnotation(); + } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") { + L.next(); + L.expect('('); + auto key_type = parseType().first; + L.expect(','); + auto value_type = parseType().first; + alias_info = parseAliasAnnotation(); + L.expect(')'); + value = DictType::create(key_type, value_type); } else { auto value_alias = parseBaseType(); value = value_alias.first; diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index db4592774b..c984354584 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -13,6 +13,7 @@ #include <ATen/ExpandUtils.h> #include <ATen/core/thread_pool.h> +#include <ATen/core/ivalue.h> #include <ATen/WrapDimUtils.h> #include <c10/util/SmallVector.h> @@ -823,21 +824,6 @@ RegisterOperators reg({ } ), Operator( - prim::DictIndex, - [](const Node* node) { - return [=](Stack& stack) { - auto index = pop(stack); - auto dict = pop(stack).toGenericDict(); - const auto& elems = dict->elements(); - auto value = elems.find(index); - if (value == elems.end()) { - AT_ERROR("KeyError: '", index, "'"); - } - push(stack, value->second); - return 0; - }; - }), - Operator( "aten::_unwrap_optional(t(a)? optional) -> t(a)", [](const Node* node) -> Operation { return [=](Stack& stack) { @@ -1110,6 +1096,14 @@ Operation listNe<Shared<TensorList>>(const Node* node) { }; } +Operation listList(const Node* node) { + return [=](Stack& stack) { + // Intentional no-op, needed to match Python semantics for list(iterable), + // but in JIT these will already be lists + return 0; + }; +} + template <class TList, class TElement> Operation listAdd(const Node* node) { return [=](Stack& stack) { @@ -1205,6 +1199,46 @@ Operation listSetItem<Shared<BoolList>, bool>(const Node* node) { }; } +int dictLen(Stack& stack) { + auto dict = pop(stack).toGenericDictRef(); + push(stack, int64_t(dict.size())); + return 0; +} + +int dictKeys(Stack& stack) { + auto dict = pop(stack).toGenericDictRef(); + std::vector<IValue> keys; + keys.reserve(dict.size()); + for (auto item : dict) { + keys.push_back(item.first); + } + push(stack, IValue(keys)); + return 0; +} + +int dictValues(Stack& stack) { + auto dict = pop(stack).toGenericDictRef(); + std::vector<IValue> values; + values.reserve(dict.size()); + for (auto item : dict) { + values.push_back(item.second); + } + push(stack, IValue(values)); + return 0; +} + +int dictIndex(Stack& stack) { + auto index = pop(stack); + auto dict = pop(stack).toGenericDict(); + const auto& elems = dict->elements(); + auto value = elems.find(index); + if (value == elems.end()) { + AT_ERROR("KeyError: '", index, "'"); + } + push(stack, value->second); + return 0; +} + RegisterOperators reg2({ @@ -1273,7 +1307,8 @@ RegisterOperators reg2({ "aten::slice(" decl_type \ "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \ "[]", \ - listSlice<Shared<c_type>, c_type::ElemType>) + listSlice<Shared<c_type>, c_type::ElemType>), \ + Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList) CREATE_LIST_OPS("int", IntList), CREATE_LIST_OPS("float", DoubleList), @@ -1475,6 +1510,18 @@ RegisterOperators reg2({ return 0; }; }), + #define CREATE_DICT_OPS(key_type) \ + Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \ + Operator( \ + "aten::keys(Dict(" key_type ", t) self) -> " key_type "[]", \ + dictKeys), \ + Operator("aten::values(Dict(" key_type ", t) self) -> t[]", dictValues), \ + Operator("prim::DictIndex(Dict(" key_type ", t) self, " key_type " key) -> t", dictIndex) + + CREATE_DICT_OPS("str"), + CREATE_DICT_OPS("int"), + CREATE_DICT_OPS("float"), + #undef CREATE_DICT_OPS }); // reference: _output_size in torch/nn/functional.py diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index a96690d673..0e5ff51d01 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -399,6 +399,7 @@ struct Environment { {"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)}, {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)}, {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)}, + {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)}, }; auto it = globals.find(ident); if (it != globals.end()) @@ -2230,7 +2231,7 @@ struct to_ir { auto value_trees = dl.value_inputs().tree()->trees(); AT_ASSERT(key_trees.size() == value_trees.size()); std::vector<Value*> keys, values; - for(size_t i = 0; i < keys.size(); ++i) { + for(size_t i = 0; i < key_trees.size(); ++i) { keys.push_back(emitExpr(Expr(key_trees[i]))); values.push_back(emitExpr(Expr(value_trees[i]))); } |