| author | Zachary DeVito <zdevito@gmail.com> | 2018-04-16 15:19:05 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-04-16 15:19:05 -0700 |
| commit | ee240aa00cdde948382338e908b18a62386a5989 (patch) | |
| tree | d87f3de59d10a18eee9a1adc236d56dcf2675304 /torch/csrc | |
| parent | 0e93a2c3342316e30b1befd6aa53ababc828e708 (diff) | |
Allow script_methods to be defined out of order (#6341)
This modifies the registration process so that all script methods
in a ScriptModule are defined at once.
Method gains a `method_creator` callback that is invoked the first time the
method is called, defining it if it has not already been defined.
Recursive cycles among these `method_creator` callbacks are detected and reported as errors.
This approach was chosen over first creating all the graphs and then
inlining the call sites because it combines better with type propagation
for non-tensor types such as tuples, e.g.
```
a = foo(b)
return bar(*a)
```
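
For context, here is a minimal, self-contained sketch of the lazy-definition pattern this commit introduces. The names (`LazyMethod`, `RecursiveDefinitionError`) are invented for illustration and are not the actual PyTorch types: each method stores a creator callback, runs it the first time the method is needed, and swaps in a throwing placeholder so recursive cycles are detected instead of looping forever.

```cpp
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>

// Thrown when a creator ends up (transitively) requiring itself.
struct RecursiveDefinitionError : std::exception {};

struct LazyMethod {
  std::string name;
  // Holds the creator until the method is defined; empty afterwards.
  std::function<void(LazyMethod&)> creator;

  void ensure_defined() {
    if (creator) {
      auto c = creator;
      // Install a throwing placeholder so a recursive cycle among creators
      // is reported rather than recursing forever.
      creator = [](LazyMethod&) { throw RecursiveDefinitionError(); };
      c(*this);          // "compile" the method body now
      creator = nullptr; // fully defined from here on
    }
  }
};

int main() {
  LazyMethod bar{"bar", [](LazyMethod& m) {
    std::cout << "defining " << m.name << "\n";
  }};
  // 'foo' can be registered before 'bar' is defined; 'bar' is only
  // compiled when 'foo' first needs to call it.
  LazyMethod foo{"foo", [&bar](LazyMethod& m) {
    std::cout << "defining " << m.name << ", which calls bar\n";
    bar.ensure_defined();
  }};
  foo.ensure_defined();
}
```

In the actual change, `Method::ensure_defined()` plays this role and is invoked from `emit_call_to`, as the diff below shows.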
Diffstat (limited to 'torch/csrc')
-rw-r--r-- | torch/csrc/jit/script/compiler.cpp | 46
-rw-r--r-- | torch/csrc/jit/script/compiler.h | 3
-rw-r--r-- | torch/csrc/jit/script/init.cpp | 20
-rw-r--r-- | torch/csrc/jit/script/module.cpp | 26
-rw-r--r-- | torch/csrc/jit/script/module.h | 36
5 files changed, 93 insertions, 38 deletions
```diff
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 8d32ecd503..4e940ffe53 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -820,9 +820,7 @@ private:
   std::shared_ptr<SugaredValue> emitApplyIdent(Ident ident, std::vector<Value*> inputs, List<Attribute> attributes, size_t n_binders) {
     auto it = function_table.find(ident.name());
     if (it != function_table.end()) {
-      if(inputs.size() != it->second.num_inputs())
-        throw ErrorReport(ident) << "expected " << it->second.num_inputs() << " but found " << inputs.size();
-      return packOutputs(*graph, method.emit_call_to(it->second, inputs));
+      return packOutputs(*graph, method.emit_call_to(ident.range(), it->second, inputs));
     } else if (ident.name() == "print") {
       if (!attributes.empty())
         throw ErrorReport(ident) << "print doesn't accept any keyword arguments";
@@ -1071,35 +1069,51 @@ std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> input
   return outputs;
 }
 
-void defineMethodsInModule(Module & m, const std::vector<Def>& definitions, const Resolver& resolver, SugaredValuePtr self) {
+void defineMethodsInModule(Module & m, const std::vector<Def>& definitions, const std::vector<Resolver>& resolvers, SugaredValuePtr self) {
   FunctionTable table;
-  for(auto def : definitions) {
+  JIT_ASSERT(definitions.size() == resolvers.size());
+  auto resolver_it = resolvers.begin();
+  std::vector<Method*> methods;
+  for(Def def : definitions) {
     const std::string& name = def.name().name();
-    Method& method = m.create_method(name);
-    to_ir(def, table, resolver, self, method);
-    auto result = table.emplace(name, method);
-    if(!result.second) {
-      throw ErrorReport(def) << "duplicate definition of function '" << name << "'";
+    Resolver resolver = *resolver_it++;
+    auto creator = [def, &table, resolver, self](Method& method) {
+      to_ir(def, table, resolver, self, method);
+    };
+    Method& method = m.create_method(name, creator);
+    // if self is defined, then these are methods and do not go into the global namespace
+    // otherwise, they get defined together so we add them to the function table
+    // so the methods can see each other
+    if(!self) {
+      auto result = table.emplace(name, method);
+      if(!result.second) {
+        throw ErrorReport(def) << "duplicate definition of function '" << name << "'";
+      }
     }
+    methods.push_back(&method);
+  }
+  for(Method* method : methods) {
+    method->ensure_defined();
   }
 }
 
 void defineMethodsInModule(Module & m, const std::string& source, const Resolver& resolver, SugaredValuePtr self) {
   Parser p(source);
   std::vector<Def> definitions;
+  std::vector<Resolver> resolvers;
   while (p.lexer().cur().kind != TK_EOF) {
     definitions.push_back(Def(p.parseFunction()));
+    resolvers.push_back(resolver);
   }
-  defineMethodsInModule(m, definitions, resolver, self);
+  defineMethodsInModule(m, definitions, resolvers, self);
 }
 
 std::shared_ptr<Graph> compileFunction(Def def, const Resolver& resolver) {
   Module m;
   //note: we don't use 'm' to execute so this setting is unused
-  defineMethodsInModule(m, {def}, resolver, nullptr);
+  defineMethodsInModule(m, {def}, {resolver}, nullptr);
   return m.get_method(def.name().name()).graph();
 }
-
 std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(SourceRange loc, Method& m) {
   auto & graph = *m.graph();
   if(value->type()->kind() == TypeKind::TupleType) {
@@ -1111,6 +1125,12 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(SourceRange loc,
   return SugaredValue::asTuple(loc, m);
 }
 
+void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what) {
+  if(expected != actual) {
+    throw ErrorReport(loc) << "expected " << expected << " " << what << " but found " << actual;
+  }
+}
+
 } // namespace script
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h
index a8adfd08a9..3b9889eb44 100644
--- a/torch/csrc/jit/script/compiler.h
+++ b/torch/csrc/jit/script/compiler.h
@@ -116,7 +116,7 @@ using Resolver = std::function<std::shared_ptr<SugaredValue>(const std::string&
 void defineMethodsInModule(
   Module & m,
   const std::vector<Def>& definitions,
-  const Resolver& resolver, /* determines how we handle free variables*/
+  const std::vector<Resolver>& resolvers, /* determines how we handle free variables in each definition*/
   std::shared_ptr<SugaredValue> self /* if non-null, the first argument to each def, is bound to this value */
 );
 
@@ -128,6 +128,7 @@ std::shared_ptr<Graph> compileFunction(Def def, const Resolver& resolver);
 // a SimpleValue, otherwise pack all the values into a Tuple.
 std::shared_ptr<SugaredValue> packOutputs(Graph& g, at::ArrayRef<Value*> values);
 std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
+void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what);
 
 } // namespace script
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 0c0f23ec68..fad8729b44 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -15,12 +15,6 @@ using ResolutionCallback = std::function<py::function(std::string)>;
 #define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
 #endif
 
-static void ensureSizeMatches(SourceRange loc, size_t expected, size_t actual, const std::string& what) {
-  if(expected != actual) {
-    throw ErrorReport(loc) << "expected " << expected << " " << what << " but found " << actual;
-  }
-}
-
 static std::string typeString(py::handle h) {
   return py::str(h.get_type().attr("__name__"));
 }
@@ -172,9 +166,7 @@ struct MethodValue : public SugaredValue {
     if(attributes.size() != 0) {
       throw ErrorReport(loc) << "not yet implemented - calls to script methods using keyword arguments";
     }
-    ensureSizeMatches(loc, method.num_inputs(), inputs.size(), "inputs");
-    auto outputs = caller.emit_call_to(method, inputs);
-    return packOutputs(*caller.graph(), outputs);
+    return packOutputs(*caller.graph(), caller.emit_call_to(loc, method, inputs));
   }
 private:
   std::shared_ptr<Module> module;
@@ -298,11 +290,15 @@ void initJitScriptBindings(PyObject* module) {
         auto self = has_self ? std::make_shared<ModuleValue>(m.shared_from_this()) : nullptr;
         return defineMethodsInModule(m, script, pythonResolver(rcb), self);
       })
-      .def("_create_method", [](Module& m, Def def, ResolutionCallback rcb) {
+      .def("_create_methods", [](Module& m, const std::vector<Def>& defs, const std::vector<ResolutionCallback>& rcbs) {
+        std::vector<Resolver> resolvers;
+        for(auto & callback : rcbs) {
+          resolvers.push_back(pythonResolver(callback));
+        }
         defineMethodsInModule(
           m,
-          { def },
-          pythonResolver(rcb),
+          defs,
+          resolvers,
           std::make_shared<ModuleValue>(m.shared_from_this()));
       })
      .def("_get_method",
diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp
index 13bc04d20b..da72b283ef 100644
--- a/torch/csrc/jit/script/module.cpp
+++ b/torch/csrc/jit/script/module.cpp
@@ -1,12 +1,25 @@
 #include "torch/csrc/jit/script/module.h"
 #include "torch/csrc/jit/script/compiler.h"
+#include "torch/csrc/jit/script/error_report.h"
 
 namespace torch { namespace jit { namespace script {
 
-std::vector<Value*> Method::emit_call_to(Method & callee, ArrayRef<Value*> inputs) {
+
+struct RecursiveMethodCallError : public std::exception {};
+void placeholderCreator(Method&) {
+  throw RecursiveMethodCallError();
+}
+
+std::vector<Value*> Method::emit_call_to(SourceRange loc, Method & callee, ArrayRef<Value*> inputs) {
   JIT_ASSERT(!executor);
+  try {
+    callee.ensure_defined();
+  } catch (RecursiveMethodCallError&) {
+    throw ErrorReport(loc) << " method '" << callee.name()
+      << "' is called recursively involving this call site. Recursive calls are not supported";
+  }
   auto fn = callee.graph();
-  JIT_ASSERT(inputs.size() == callee.num_inputs());
+  ensureSizeMatches(loc, callee.num_inputs(), inputs.size(), "inputs");
   std::vector<Value*> all_inputs = inputs;
   // parameters to callee method (which become parameters to _this_ method
   // if they were not already)
@@ -16,4 +29,13 @@ std::vector<Value*> Method::emit_call_to(Method & callee, ArrayRef<Value*> input
   return inlineCallTo(*graph(), *callee.graph(), all_inputs);
 }
 
+void Method::ensure_defined() {
+  if(method_creator) {
+    auto creator = method_creator;
+    method_creator = placeholderCreator;
+    creator(*this);
+    method_creator = nullptr;
+  }
+}
+
 }}}
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index f07007415b..7903e7a569 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -3,6 +3,7 @@
 #include "torch/csrc/jit/graph_executor.h"
 #include "torch/csrc/autograd/variable.h"
 #include <ATen/optional.h>
+#include <functional>
 
 // This file contains classes which assist in desugaring Python style
 // modules and their methods into flattened graphs which don't have any
@@ -19,24 +20,28 @@ namespace torch { namespace jit { namespace script {
 // Note: because Method/Module are exposed to python these
 // classes use python method naming conventions
 
+struct SourceRange;
+
 struct Method {
   Method(std::string name, bool optimize,
          std::shared_ptr<Graph> graph,
-         std::vector<at::Tensor*> initial_members)
+         std::vector<at::Tensor*> initial_members,
+         std::function<void(Method&)> method_creator)
   : name_(std::move(name))
   , graph_(std::move(graph))
   , optimize(optimize)
-  , member_inputs(std::move(initial_members)) {
+  , member_inputs(std::move(initial_members))
+  , method_creator(method_creator) {
     JIT_ASSERT(graph_->inputs().size() >= member_inputs.size());
     int i = graph_->inputs().size() - member_inputs.size();
    for(at::Tensor* member : member_inputs) {
      member_input_index[member] = i++;
    }
  }
-
+
   variable_tensor_list run(variable_tensor_list && inputs) {
     std::call_once(executor_init, [&]{
-      executor = GraphExecutor(graph_, optimize);
+      executor = GraphExecutor(graph(), optimize);
     });
     for(auto tp : member_inputs) {
       inputs.push_back(*tp);
@@ -54,10 +59,13 @@ struct Method {
   // adding any extra parameters necessary to do this call
   // defined here to keep details of member_input handling confined to this class
-  std::vector<Value*> emit_call_to(Method & callee, ArrayRef<Value*> inputs);
+  std::vector<Value*> emit_call_to(SourceRange loc, Method & callee, ArrayRef<Value*> inputs);
+  // if this isn't yet defined, run its method_creator function
+  void ensure_defined();
+
   size_t num_inputs() const {
-    return graph_->inputs().size() - member_inputs.size();
+    return graph()->inputs().size() - member_inputs.size();
   }
   Value * get_or_add_parameter(at::Tensor* slot) {
     auto it = member_input_index.find(slot);
@@ -94,6 +102,11 @@ private:
   // std::vector<at::Tensor*> member_outputs;
 
   std::once_flag executor_init;
+
+  // an optional function that actually creates the method when emit_call_to(this,...)
+  // is first called.
+  // this is used by the compiler so that it can construct methods out of order
+  std::function<void(Method&)> method_creator;
 };
 
 struct Module;
@@ -205,11 +218,14 @@ struct Module : public std::enable_shared_from_this<Module> {
     modules.insert(name, {name, std::move(module)});
   }
 
-  Method& create_method(const std::string & name, std::shared_ptr<Graph> graph = nullptr, std::vector<at::Tensor*> member_inputs = {}) {
-    if(!graph)
-      graph = std::make_shared<Graph>();
-    std::unique_ptr<Method> method(new Method(name, optimize, std::move(graph), std::move(member_inputs)));
+  Method& create_method(const std::string & name, std::shared_ptr<Graph> graph, std::vector<at::Tensor*> member_inputs) {
+    JIT_ASSERT(graph);
+    std::unique_ptr<Method> method(new Method(name, optimize, std::move(graph), std::move(member_inputs), nullptr));
+    return *methods.insert(name, std::move(method));
+  }
+  Method& create_method(const std::string & name, std::function<void(Method&)> creator) {
+    std::unique_ptr<Method> method(new Method(name, optimize, std::make_shared<Graph>(), {}, creator));
     return *methods.insert(name, std::move(method));
   }
```