diff options
author | Elias Ellison <eellison@fb.com> | 2019-01-30 13:48:36 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-30 14:20:56 -0800 |
commit | 18659e13367dd1ff37af9d79b954a035f7230582 (patch) | |
tree | 2f275e053ef6dba56ad04ae6d2163c84632d791e /aten | |
parent | 22e9c3055a802e228624190c7e69e083b92359c3 (diff) | |
download | pytorch-18659e13367dd1ff37af9d79b954a035f7230582.tar.gz pytorch-18659e13367dd1ff37af9d79b954a035f7230582.tar.bz2 pytorch-18659e13367dd1ff37af9d79b954a035f7230582.zip |
Allow generic containers as module inputs (#16482)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/16326
Previously we didn't handle module inputs which included Generic Lists. When checking whether a generic list if a subvalue of the input arg type, I currently recurse on every element of the list. This shouldn't be too slow since the innermost list will be specialized and we won't have to check it's elements.
E.g. Tensor[][] -> GenericList [TensorList ].
The error message could be improved, but extracting the complete type of nested lists would have to deal with unifying types across lists / empty lists & typevars so I'm going to save that for a follow up PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16482
Differential Revision: D13882582
Pulled By: eellison
fbshipit-source-id: 3609bc572f0ee9ebf20a77ea5ebc8fa3b165e24b
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/core/jit_type.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/type.cpp | 32 |
2 files changed, 34 insertions, 0 deletions
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 000c6fac19..49918f8bd1 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -917,6 +917,8 @@ template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::o template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); } CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value); +CAFFE2_API TypePtr attemptToRecoverType(const IValue& input_ivalue); +CAFFE2_API bool isSubvalueOf(const IValue& input_ivalue, TypePtr type); using TypeEnv = std::unordered_map<std::string, TypePtr>; struct MatchTypeReturn { diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 9c69a49714..12765a89ee 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -149,6 +149,38 @@ TypePtr incompleteInferTypeFrom(const IValue& value) { AT_ERROR("Type cannot be accurately recovered from this IValue."); } +// This attempts to recover the type from an IValue, including nested Generic +// Lists. It only examines the first element of each generic container, +// and if a generic container is empty returns typevar as the base element. +// XXX: only used for better error messages, should not be used elsewhere +TypePtr attemptToRecoverType(const IValue& input_ivalue) { + if (input_ivalue.isGenericList()) { + auto& ivalue_list = input_ivalue.toGenericListRef(); + if (ivalue_list.size() == 0) { + return ListType::create(VarType::create("t")); + } + return ListType::create(attemptToRecoverType(ivalue_list[0])); + } + return incompleteInferTypeFrom(input_ivalue); +} + +// Checks if input_ivalue is a subvalue of type. +bool isSubvalueOf(const IValue& ivalue, TypePtr type) { + if (ivalue.isGenericList()) { + auto list_type = type->cast<ListType>(); + if (!list_type) { + return false; + } + auto& ivalue_list = ivalue.toGenericListRef(); + auto element_type = list_type->getElementType(); + return std::all_of(ivalue_list.begin(), ivalue_list.end(), [&](const IValue& list_elem) { + return isSubvalueOf(list_elem, element_type); + }); + } + return incompleteInferTypeFrom(ivalue)->isSubtypeOf(type); +} + + c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) { //cases that t1 == t2, or t1 is a type refinement of t2 and vice versa if (t1->isSubtypeOf(t2)) { |