diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2017-11-24 14:42:29 -0500 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2017-11-28 09:52:49 +0100 |
commit | e66c592d10ca042f707312c161772d61ad4fc2bc (patch) | |
tree | 556899b21edb45a82c6a781abc83c4db9ddacced /tools/jit/templates | |
parent | 00fe1f7cc88aff585197f397a51268e25ea169af (diff) | |
download | pytorch-e66c592d10ca042f707312c161772d61ad4fc2bc.tar.gz pytorch-e66c592d10ca042f707312c161772d61ad4fc2bc.tar.bz2 pytorch-e66c592d10ca042f707312c161772d61ad4fc2bc.zip |
Handle ops with multiple inputs in aten_dispatch.cpp
Diffstat (limited to 'tools/jit/templates')
-rw-r--r-- | tools/jit/templates/aten_dispatch.cpp | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/tools/jit/templates/aten_dispatch.cpp b/tools/jit/templates/aten_dispatch.cpp index 4a9f4515a5..1c7c55ed4d 100644 --- a/tools/jit/templates/aten_dispatch.cpp +++ b/tools/jit/templates/aten_dispatch.cpp @@ -35,6 +35,12 @@ void pack_list(std::vector<Tensor> & outputs, std::tuple<Tensor, Tensor, Tensor> outputs.push_back(std::get<2>(v)); } +// A list of functions taking TensorList arguments (where we can't use +// the number of inputs to choose an overload). +std::unordered_set<Symbol> tensor_vararg_fns = { + kcat, +}; + template<size_t N> std::array<bool, N> as_bool_array(const std::vector<int64_t>& vec) { std::array<bool, N> res; @@ -49,7 +55,11 @@ std::unordered_map<std::string, operator_constructor> constructors = { std::string getDescriptor(jit::Node* n) { std::stringstream s; - s << symbolToString(n->kind()) << "-" << n->inputs().size(); + s << symbolToString(n->kind()); + if (tensor_vararg_fns.count(n->kind()) == 0) + s << "-" << n->inputs().size(); + else + s << "-*"; std::vector<const char*> attr_names = fmap(n->attributeNames(), &symbolToString); std::sort(attr_names.begin(), attr_names.end(), [](const char *a, const char *b) { return std::strcmp(a, b) < 0; |