summaryrefslogtreecommitdiff
path: root/aten
diff options
context:
space:
mode:
authorElias Ellison <eellison@fb.com>2019-01-30 13:48:36 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-30 14:20:56 -0800
commit18659e13367dd1ff37af9d79b954a035f7230582 (patch)
tree2f275e053ef6dba56ad04ae6d2163c84632d791e /aten
parent22e9c3055a802e228624190c7e69e083b92359c3 (diff)
downloadpytorch-18659e13367dd1ff37af9d79b954a035f7230582.tar.gz
pytorch-18659e13367dd1ff37af9d79b954a035f7230582.tar.bz2
pytorch-18659e13367dd1ff37af9d79b954a035f7230582.zip
Allow generic containers as module inputs (#16482)
Summary: Fixes https://github.com/pytorch/pytorch/issues/16326 Previously we didn't handle module inputs which included Generic Lists. When checking whether a generic list if a subvalue of the input arg type, I currently recurse on every element of the list. This shouldn't be too slow since the innermost list will be specialized and we won't have to check it's elements. E.g. Tensor[][] -> GenericList [TensorList ]. The error message could be improved, but extracting the complete type of nested lists would have to deal with unifying types across lists / empty lists & typevars so I'm going to save that for a follow up PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16482 Differential Revision: D13882582 Pulled By: eellison fbshipit-source-id: 3609bc572f0ee9ebf20a77ea5ebc8fa3b165e24b
Diffstat (limited to 'aten')
-rw-r--r--aten/src/ATen/core/jit_type.h2
-rw-r--r--aten/src/ATen/core/type.cpp32
2 files changed, 34 insertions, 0 deletions
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 000c6fac19..49918f8bd1 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -917,6 +917,8 @@ template<> inline TypePtr getTypePtr<std::vector<double>>() { return ListType::o
template<> inline TypePtr getTypePtr<std::vector<int64_t>>() { return ListType::ofInts(); }
CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value);
+CAFFE2_API TypePtr attemptToRecoverType(const IValue& input_ivalue);
+CAFFE2_API bool isSubvalueOf(const IValue& input_ivalue, TypePtr type);
using TypeEnv = std::unordered_map<std::string, TypePtr>;
struct MatchTypeReturn {
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index 9c69a49714..12765a89ee 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -149,6 +149,38 @@ TypePtr incompleteInferTypeFrom(const IValue& value) {
AT_ERROR("Type cannot be accurately recovered from this IValue.");
}
+// This attempts to recover the type from an IValue, including nested Generic
+// Lists. It only examines the first element of each generic container,
+// and if a generic container is empty returns typevar as the base element.
+// XXX: only used for better error messages, should not be used elsewhere
+TypePtr attemptToRecoverType(const IValue& input_ivalue) {
+ if (input_ivalue.isGenericList()) {
+ auto& ivalue_list = input_ivalue.toGenericListRef();
+ if (ivalue_list.size() == 0) {
+ return ListType::create(VarType::create("t"));
+ }
+ return ListType::create(attemptToRecoverType(ivalue_list[0]));
+ }
+ return incompleteInferTypeFrom(input_ivalue);
+}
+
+// Checks if input_ivalue is a subvalue of type.
+bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
+ if (ivalue.isGenericList()) {
+ auto list_type = type->cast<ListType>();
+ if (!list_type) {
+ return false;
+ }
+ auto& ivalue_list = ivalue.toGenericListRef();
+ auto element_type = list_type->getElementType();
+ return std::all_of(ivalue_list.begin(), ivalue_list.end(), [&](const IValue& list_elem) {
+ return isSubvalueOf(list_elem, element_type);
+ });
+ }
+ return incompleteInferTypeFrom(ivalue)->isSubtypeOf(type);
+}
+
+
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
//cases that t1 == t2, or t1 is a type refinement of t2 and vice versa
if (t1->isSubtypeOf(t2)) {