summaryrefslogtreecommitdiff
path: root/tools/jit/templates
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2017-11-24 14:42:29 -0500
committerAdam Paszke <adam.paszke@gmail.com>2017-11-28 09:52:49 +0100
commite66c592d10ca042f707312c161772d61ad4fc2bc (patch)
tree556899b21edb45a82c6a781abc83c4db9ddacced /tools/jit/templates
parent00fe1f7cc88aff585197f397a51268e25ea169af (diff)
downloadpytorch-e66c592d10ca042f707312c161772d61ad4fc2bc.tar.gz
pytorch-e66c592d10ca042f707312c161772d61ad4fc2bc.tar.bz2
pytorch-e66c592d10ca042f707312c161772d61ad4fc2bc.zip
Handle ops with multiple inputs in aten_dispatch.cpp
Diffstat (limited to 'tools/jit/templates')
-rw-r--r--tools/jit/templates/aten_dispatch.cpp12
1 files changed, 11 insertions, 1 deletions
diff --git a/tools/jit/templates/aten_dispatch.cpp b/tools/jit/templates/aten_dispatch.cpp
index 4a9f4515a5..1c7c55ed4d 100644
--- a/tools/jit/templates/aten_dispatch.cpp
+++ b/tools/jit/templates/aten_dispatch.cpp
@@ -35,6 +35,12 @@ void pack_list(std::vector<Tensor> & outputs, std::tuple<Tensor, Tensor, Tensor>
outputs.push_back(std::get<2>(v));
}
+// A list of functions taking TensorList arguments (where we can't use
+// the number of inputs to choose an overload).
+std::unordered_set<Symbol> tensor_vararg_fns = {
+ kcat,
+};
+
template<size_t N>
std::array<bool, N> as_bool_array(const std::vector<int64_t>& vec) {
std::array<bool, N> res;
@@ -49,7 +55,11 @@ std::unordered_map<std::string, operator_constructor> constructors = {
std::string getDescriptor(jit::Node* n) {
std::stringstream s;
- s << symbolToString(n->kind()) << "-" << n->inputs().size();
+ s << symbolToString(n->kind());
+ if (tensor_vararg_fns.count(n->kind()) == 0)
+ s << "-" << n->inputs().size();
+ else
+ s << "-*";
std::vector<const char*> attr_names = fmap(n->attributeNames(), &symbolToString);
std::sort(attr_names.begin(), attr_names.end(), [](const char *a, const char *b) {
return std::strcmp(a, b) < 0;