path: root/torch/csrc
author    Zachary DeVito <zdevito@gmail.com>    2018-04-16 15:19:05 -0700
committer GitHub <noreply@github.com>           2018-04-16 15:19:05 -0700
commit    ee240aa00cdde948382338e908b18a62386a5989 (patch)
tree      d87f3de59d10a18eee9a1adc236d56dcf2675304 /torch/csrc
parent    0e93a2c3342316e30b1befd6aa53ababc828e708 (diff)
Allow script_methods to be defined out of order (#6341)
This modifies the registration process so that all script methods in a ScriptModule are defined at once. Method gains a `method_creator` callback that is invoked the first time the method is called, defining it if it has not already been defined. Recursive cycles among these `method_creator` callbacks are detected and reported as errors. This approach was chosen over first creating all the graphs and then inlining the call sites because it will combine better with type propagation for non-tensor types such as tuples, e.g.
```
a = foo(b)
return bar(*a)
```
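As a rough illustration of the mechanism described above, here is a minimal, self-contained C++ sketch of the deferred-creation idea: a method keeps a creator callback until it is first needed, and a sentinel creator is swapped in while the real one runs so a recursive cycle throws instead of looping forever. `LazyMethod` and its members are simplified stand-ins, not the actual `torch::jit::script` classes.
```cpp
#include <exception>
#include <functional>
#include <iostream>
#include <string>

// Thrown when a method's creator is re-entered, i.e. the definitions form a cycle.
struct RecursiveMethodCallError : std::exception {};

struct LazyMethod {
  std::string name;
  // Holds the deferred definition; cleared once the body has been built.
  std::function<void(LazyMethod&)> creator;

  void ensure_defined() {
    if (!creator) return;                 // already defined
    auto real_creator = creator;
    // Swap in a sentinel so re-entering this method while it is being
    // defined throws instead of recursing forever.
    creator = [](LazyMethod&) { throw RecursiveMethodCallError(); };
    real_creator(*this);
    creator = nullptr;                    // fully defined from now on
  }
};

int main() {
  LazyMethod bar{"bar", [](LazyMethod& m) {
    std::cout << "building body of " << m.name << "\n";
  }};
  // foo is written "first" but calls bar; the call simply forces bar's creator.
  LazyMethod foo{"foo", [&bar](LazyMethod& m) {
    bar.ensure_defined();
    std::cout << "building body of " << m.name << "\n";
  }};
  foo.ensure_defined();                   // defines bar on demand, then foo
}
```
Forcing `foo` works even though `bar`'s body is built later, which is what allows script methods to be written out of order.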
Diffstat (limited to 'torch/csrc')
-rw-r--r--  torch/csrc/jit/script/compiler.cpp  46
-rw-r--r--  torch/csrc/jit/script/compiler.h     3
-rw-r--r--  torch/csrc/jit/script/init.cpp       20
-rw-r--r--  torch/csrc/jit/script/module.cpp     26
-rw-r--r--  torch/csrc/jit/script/module.h       36
5 files changed, 93 insertions, 38 deletions
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 8d32ecd503..4e940ffe53 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -820,9 +820,7 @@ private:
std::shared_ptr<SugaredValue> emitApplyIdent(Ident ident, std::vector<Value*> inputs, List<Attribute> attributes, size_t n_binders) {
auto it = function_table.find(ident.name());
if (it != function_table.end()) {
- if(inputs.size() != it->second.num_inputs())
- throw ErrorReport(ident) << "expected " << it->second.num_inputs() << " but found " << inputs.size();
- return packOutputs(*graph, method.emit_call_to(it->second, inputs));
+ return packOutputs(*graph, method.emit_call_to(ident.range(), it->second, inputs));
} else if (ident.name() == "print") {
if (!attributes.empty())
throw ErrorReport(ident) << "print doesn't accept any keyword arguments";
@@ -1071,35 +1069,51 @@ std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> input
return outputs;
}
-void defineMethodsInModule(Module & m, const std::vector<Def>& definitions, const Resolver& resolver, SugaredValuePtr self) {
+void defineMethodsInModule(Module & m, const std::vector<Def>& definitions, const std::vector<Resolver>& resolvers, SugaredValuePtr self) {
FunctionTable table;
- for(auto def : definitions) {
+ JIT_ASSERT(definitions.size() == resolvers.size());
+ auto resolver_it = resolvers.begin();
+ std::vector<Method*> methods;
+ for(Def def : definitions) {
const std::string& name = def.name().name();
- Method& method = m.create_method(name);
- to_ir(def, table, resolver, self, method);
- auto result = table.emplace(name, method);
- if(!result.second) {
- throw ErrorReport(def) << "duplicate definition of function '" << name << "'";
+ Resolver resolver = *resolver_it++;
+ auto creator = [def, &table, resolver, self](Method& method) {
+ to_ir(def, table, resolver, self, method);
+ };
+ Method& method = m.create_method(name, creator);
+ // if self is defined, then these are methods and do not go into the global namespace
+ // otherwise, they get defined together so we add them to the function table
+ // so the methods can see each other
+ if(!self) {
+ auto result = table.emplace(name, method);
+ if(!result.second) {
+ throw ErrorReport(def) << "duplicate definition of function '" << name << "'";
+ }
}
+ methods.push_back(&method);
+ }
+ for(Method* method : methods) {
+ method->ensure_defined();
}
}
void defineMethodsInModule(Module & m, const std::string& source, const Resolver& resolver, SugaredValuePtr self) {
Parser p(source);
std::vector<Def> definitions;
+ std::vector<Resolver> resolvers;
while (p.lexer().cur().kind != TK_EOF) {
definitions.push_back(Def(p.parseFunction()));
+ resolvers.push_back(resolver);
}
- defineMethodsInModule(m, definitions, resolver, self);
+ defineMethodsInModule(m, definitions, resolvers, self);
}
std::shared_ptr<Graph> compileFunction(Def def, const Resolver& resolver) {
Module m; //note: we don't use 'm' to execute so this setting is unused
- defineMethodsInModule(m, {def}, resolver, nullptr);
+ defineMethodsInModule(m, {def}, {resolver}, nullptr);
return m.get_method(def.name().name()).graph();
}
-
std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(SourceRange loc, Method& m) {
auto & graph = *m.graph();
if(value->type()->kind() == TypeKind::TupleType) {
@@ -1111,6 +1125,12 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(SourceRange loc,
return SugaredValue::asTuple(loc, m);
}
+void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what) {
+ if(expected != actual) {
+ throw ErrorReport(loc) << "expected " << expected << " " << what << " but found " << actual;
+ }
+}
+
} // namespace script
} // namespace jit
} // namespace torch
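The hunk above splits `defineMethodsInModule` into two passes: first every definition is registered with a deferred creator (and, for free functions, entered into the function table), then each one is forced via `ensure_defined`. The sketch below shows why that lets a definition call one that appears later in the source. It uses simplified stand-in types, and the `compile` lambda is a hypothetical placeholder for `to_ir`, which in the real code fills in a Method's Graph.
```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// A definition and the names of the functions it calls.
struct Def { std::string name; std::vector<std::string> calls; };

int main() {
  // 'foo' appears first in the source but calls 'bar', which is defined later.
  std::vector<Def> defs = {{"foo", {"bar"}}, {"bar", {}}};

  std::map<std::string, std::function<void()>> table;   // function_table analogue
  std::map<std::string, bool> done;
  std::function<void(const Def&)> compile = [&](const Def& d) {
    if (done[d.name]) return;            // ensure_defined() is idempotent
    done[d.name] = true;
    for (const auto& callee : d.calls)
      table.at(callee)();                // resolves: pass 1 already registered it
    std::cout << "defined " << d.name << "\n";
  };

  // pass 1: register a deferred creator for every definition
  for (const Def& d : defs)
    table[d.name] = [&compile, d] { compile(d); };
  // pass 2: force each one (the methods.push_back / ensure_defined loop above)
  for (const Def& d : defs)
    table[d.name]();
}
```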
diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h
index a8adfd08a9..3b9889eb44 100644
--- a/torch/csrc/jit/script/compiler.h
+++ b/torch/csrc/jit/script/compiler.h
@@ -116,7 +116,7 @@ using Resolver = std::function<std::shared_ptr<SugaredValue>(const std::string&
void defineMethodsInModule(
Module & m,
const std::vector<Def>& definitions,
- const Resolver& resolver, /* determines how we handle free variables*/
+ const std::vector<Resolver>& resolvers, /* determines how we handle free variables in each definition*/
std::shared_ptr<SugaredValue> self /* if non-null, the first argument to each def, is bound to this value */
);
@@ -128,6 +128,7 @@ std::shared_ptr<Graph> compileFunction(Def def, const Resolver& resolver);
// a SimpleValue, otherwise pack all the values into a Tuple.
std::shared_ptr<SugaredValue> packOutputs(Graph& g, at::ArrayRef<Value*> values);
std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
+void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what);
} // namespace script
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 0c0f23ec68..fad8729b44 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -15,12 +15,6 @@ using ResolutionCallback = std::function<py::function(std::string)>;
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
#endif
-static void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what) {
- if(expected != actual) {
- throw ErrorReport(loc) << "expected " << expected << " " << what << " but found " << actual;
- }
-}
-
static std::string typeString(py::handle h) {
return py::str(h.get_type().attr("__name__"));
}
@@ -172,9 +166,7 @@ struct MethodValue : public SugaredValue {
if(attributes.size() != 0) {
throw ErrorReport(loc) << "not yet implemented - calls to script methods using keyword arguments";
}
- ensureSizeMatches(loc, method.num_inputs(), inputs.size(), "inputs");
- auto outputs = caller.emit_call_to(method, inputs);
- return packOutputs(*caller.graph(), outputs);
+ return packOutputs(*caller.graph(), caller.emit_call_to(loc, method, inputs));
}
private:
std::shared_ptr<Module> module;
@@ -298,11 +290,15 @@ void initJitScriptBindings(PyObject* module) {
auto self = has_self ? std::make_shared<ModuleValue>(m.shared_from_this()) : nullptr;
return defineMethodsInModule(m, script, pythonResolver(rcb), self);
})
- .def("_create_method", [](Module& m, Def def, ResolutionCallback rcb) {
+ .def("_create_methods", [](Module& m, const std::vector<Def>& defs, const std::vector<ResolutionCallback>& rcbs) {
+ std::vector<Resolver> resolvers;
+ for(auto & callback : rcbs) {
+ resolvers.push_back(pythonResolver(callback));
+ }
defineMethodsInModule(
m,
- { def },
- pythonResolver(rcb),
+ defs,
+ resolvers,
std::make_shared<ModuleValue>(m.shared_from_this()));
})
.def("_get_method",
diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp
index 13bc04d20b..da72b283ef 100644
--- a/torch/csrc/jit/script/module.cpp
+++ b/torch/csrc/jit/script/module.cpp
@@ -1,12 +1,25 @@
#include "torch/csrc/jit/script/module.h"
#include "torch/csrc/jit/script/compiler.h"
+#include "torch/csrc/jit/script/error_report.h"
namespace torch { namespace jit { namespace script {
-std::vector<Value*> Method::emit_call_to(Method & callee, ArrayRef<Value*> inputs) {
+
+struct RecursiveMethodCallError : public std::exception {};
+void placeholderCreator(Method&) {
+ throw RecursiveMethodCallError();
+}
+
+std::vector<Value*> Method::emit_call_to(SourceRange loc, Method & callee, ArrayRef<Value*> inputs) {
JIT_ASSERT(!executor);
+ try {
+ callee.ensure_defined();
+ } catch (RecursiveMethodCallError&) {
+ throw ErrorReport(loc) << " method '" << callee.name()
+ << "' is called recursively involving this call site. Recursive calls are not supported";
+ }
auto fn = callee.graph();
- JIT_ASSERT(inputs.size() == callee.num_inputs());
+ ensureSizeMatches(loc, callee.num_inputs(), inputs.size(), "inputs");
std::vector<Value*> all_inputs = inputs;
// parameters to callee method (which become parameters to _this_ method
// if they were not already)
@@ -16,4 +29,13 @@ std::vector<Value*> Method::emit_call_to(Method & callee, ArrayRef<Value*> input
return inlineCallTo(*graph(), *callee.graph(), all_inputs);
}
+void Method::ensure_defined() {
+ if(method_creator) {
+ auto creator = method_creator;
+ method_creator = placeholderCreator;
+ creator(*this);
+ method_creator = nullptr;
+ }
+}
+
}}}
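The caller-side half of the change is in `emit_call_to` above: the callee is forced before inlining, a re-entrant definition surfaces as a recursion error at the call site, and the old `JIT_ASSERT` on arity becomes a located `ensureSizeMatches` error. A hedged sketch of that behaviour, with `Loc` and `Callee` as simplified stand-ins for `SourceRange` and `script::Method`:
```cpp
#include <exception>
#include <stdexcept>
#include <string>
#include <vector>

struct RecursiveMethodCallError : std::exception {};
struct Loc { std::string file; int line; };

struct Callee {
  std::string name;
  size_t num_inputs;
  bool in_definition;                     // true while its creator is running
  void ensure_defined() {
    if (in_definition) throw RecursiveMethodCallError();
    // ... otherwise run the deferred creator if it has not run yet ...
  }
};

// Force the callee before inlining, and report cycles and arity mismatches
// against the call site's location instead of asserting.
void emit_call_to(const Loc& loc, Callee& callee, const std::vector<int>& inputs) {
  std::string at = loc.file + ":" + std::to_string(loc.line) + ": ";
  try {
    callee.ensure_defined();
  } catch (RecursiveMethodCallError&) {
    throw std::runtime_error(at + "method '" + callee.name +
        "' is called recursively involving this call site");
  }
  if (inputs.size() != callee.num_inputs)  // ensureSizeMatches analogue
    throw std::runtime_error(at + "expected " + std::to_string(callee.num_inputs) +
        " inputs but found " + std::to_string(inputs.size()));
  // ... append member parameters and inline the callee's graph here ...
}

int main() {
  Callee bar{"bar", /*num_inputs=*/2, /*in_definition=*/false};
  emit_call_to({"model.py", 10}, bar, {1, 2});   // ok: arity matches
}
```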
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index f07007415b..7903e7a569 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -3,6 +3,7 @@
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/autograd/variable.h"
#include <ATen/optional.h>
+#include <functional>
// This file contains classes which assist in desugaring Python style
// modules and their methods into flattened graphs which don't have any
@@ -19,24 +20,28 @@ namespace torch { namespace jit { namespace script {
// Note: because Method/Module are exposed to python these
// classes use python method naming conventions
+struct SourceRange;
+
struct Method {
Method(std::string name, bool optimize,
std::shared_ptr<Graph> graph,
- std::vector<at::Tensor*> initial_members)
+ std::vector<at::Tensor*> initial_members,
+ std::function<void(Method&)> method_creator)
: name_(std::move(name))
, graph_(std::move(graph))
, optimize(optimize)
- , member_inputs(std::move(initial_members)) {
+ , member_inputs(std::move(initial_members))
+ , method_creator(method_creator) {
JIT_ASSERT(graph_->inputs().size() >= member_inputs.size());
int i = graph_->inputs().size() - member_inputs.size();
for(at::Tensor* member : member_inputs) {
member_input_index[member] = i++;
}
}
-
+
variable_tensor_list run(variable_tensor_list && inputs) {
std::call_once(executor_init, [&]{
- executor = GraphExecutor(graph_, optimize);
+ executor = GraphExecutor(graph(), optimize);
});
for(auto tp : member_inputs) {
inputs.push_back(*tp);
@@ -54,10 +59,13 @@ struct Method {
// adding any extra parameters necessary to do this call
// defined here to keep details of member_input handling confined to this class
- std::vector<Value*> emit_call_to(Method & callee, ArrayRef<Value*> inputs);
+ std::vector<Value*> emit_call_to(SourceRange loc, Method & callee, ArrayRef<Value*> inputs);
+ // if this isn't yet defined, run its method_creator function
+ void ensure_defined();
+
size_t num_inputs() const {
- return graph_->inputs().size() - member_inputs.size();
+ return graph()->inputs().size() - member_inputs.size();
}
Value * get_or_add_parameter(at::Tensor* slot) {
auto it = member_input_index.find(slot);
@@ -94,6 +102,11 @@ private:
// std::vector<at::Tensor*> member_outputs;
std::once_flag executor_init;
+
+ // an optional function that actually creates the method when emit_call_to(this,...)
+ // is first called.
+ // this is used by the compiler so that it can construct methods out of order
+ std::function<void(Method&)> method_creator;
};
struct Module;
@@ -205,11 +218,14 @@ struct Module : public std::enable_shared_from_this<Module> {
modules.insert(name, {name, std::move(module)});
}
- Method& create_method(const std::string & name, std::shared_ptr<Graph> graph = nullptr, std::vector<at::Tensor*> member_inputs = {}) {
- if(!graph)
- graph = std::make_shared<Graph>();
- std::unique_ptr<Method> method(new Method(name, optimize, std::move(graph), std::move(member_inputs)));
+ Method& create_method(const std::string & name, std::shared_ptr<Graph> graph, std::vector<at::Tensor*> member_inputs) {
+ JIT_ASSERT(graph);
+ std::unique_ptr<Method> method(new Method(name, optimize, std::move(graph), std::move(member_inputs), nullptr));
+ return *methods.insert(name, std::move(method));
+ }
+ Method& create_method(const std::string & name, std::function<void(Method&)> creator) {
+ std::unique_ptr<Method> method(new Method(name, optimize, std::make_shared<Graph>(), {}, creator));
return *methods.insert(name, std::move(method));
}