diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2017-10-25 21:26:13 +0200 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-10-27 02:40:09 +0530 |
commit | fa0f3cf98a57543a5a362ecde5133c42bc9443df (patch) | |
tree | aaae5d207d518912524182273092ca9b085465de /torch | |
parent | 61afb0d519eeb165a80f43c285ab2c528fc10879 (diff) | |
download | pytorch-fa0f3cf98a57543a5a362ecde5133c42bc9443df.tar.gz pytorch-fa0f3cf98a57543a5a362ecde5133c42bc9443df.tar.bz2 pytorch-fa0f3cf98a57543a5a362ecde5133c42bc9443df.zip |
Re-enable and fix most JIT tests
Diffstat (limited to 'torch')
-rw-r--r-- | torch/autograd/variable.py | 18 | ||||
-rw-r--r-- | torch/csrc/autograd/functions/jit_closure.cpp | 73 | ||||
-rw-r--r-- | torch/csrc/autograd/functions/jit_closure.h | 2 | ||||
-rw-r--r-- | torch/csrc/autograd/python_function.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/jit/interned_strings.h | 2 | ||||
-rw-r--r-- | torch/csrc/jit/ir.cpp | 12 | ||||
-rw-r--r-- | torch/csrc/jit/passes/common_subexpression_elimination.cpp | 21 | ||||
-rw-r--r-- | torch/csrc/jit/tracer.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/jit/tracer.h | 3 | ||||
-rw-r--r-- | torch/csrc/jit/tracer_state.h | 3 | ||||
-rw-r--r-- | torch/jit/__init__.py | 1 | ||||
-rw-r--r-- | torch/jit/passes/inplace.py | 2 |
12 files changed, 98 insertions, 43 deletions
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py index b3b5a92bbe..8b2655320c 100644 --- a/torch/autograd/variable.py +++ b/torch/autograd/variable.py @@ -426,15 +426,12 @@ class Variable(_C._VariableBase): def bernoulli(self): return Bernoulli.apply(self) - def __add__(self, other): - return self.add(other) - __radd__ = __add__ + __radd__ = __add__ = _C._VariableBase.add def __iadd__(self, other): return self.add_(other) - def __sub__(self, other): - return self.sub(other) + __sub__ = _C._VariableBase.sub def __isub__(self, other): return self.sub_(other) @@ -442,9 +439,7 @@ class Variable(_C._VariableBase): def __rsub__(self, other): return -self + other - def __mul__(self, other): - return self.mul(other) - __rmul__ = __mul__ + __rmul__ = __mul__ = _C._VariableBase.mul def __imul__(self, other): return self.mul_(other) @@ -454,9 +449,7 @@ class Variable(_C._VariableBase): return NotImplemented return self.matmul(other) - def __div__(self, other): - return self.div(other) - __truediv__ = __div__ + __truediv__ = __div__ = _C._VariableBase.div def __rdiv__(self, other): return self.reciprocal() * other @@ -465,8 +458,7 @@ class Variable(_C._VariableBase): def __idiv__(self, other): return self.div_(other) - def __pow__(self, other): - return self.pow(other) + __pow__ = _C._VariableBase.pow def __ipow__(self, other): raise NotImplementedError("in-place pow not implemented") diff --git a/torch/csrc/autograd/functions/jit_closure.cpp b/torch/csrc/autograd/functions/jit_closure.cpp index d1d6e51fff..9f8fd92598 100644 --- a/torch/csrc/autograd/functions/jit_closure.cpp +++ b/torch/csrc/autograd/functions/jit_closure.cpp @@ -11,6 +11,7 @@ #include "torch/csrc/autograd/python_engine.h" #include "torch/csrc/autograd/python_variable.h" #include "torch/csrc/autograd/python_function.h" +#include "torch/csrc/jit/generated/aten_dispatch.h" #ifdef WITH_CUDA #include "torch/csrc/jit/fusion_compiler.h" #endif @@ -115,20 +116,28 @@ struct EmitNull : public Function { }; }; -// A hack that will let us implement some of the ops we care -// about before the major Python -> C++ Function migration struct LambdaFunction : public Function { + LambdaFunction(const jit::TensorOp& op) + : LambdaFunction(op.num_inputs, op.op) { + this->name_ = op.name; + } + LambdaFunction(int num_inputs, std::function<variable_list(const variable_list&)> fn) - : fn(fn) { + : fn_(fn) { this->is_executable = true; this->num_inputs = num_inputs; } - virtual variable_list apply(const variable_list& inputs) { - return fn(inputs); + virtual std::string name() override { + return name_.size() == 0 ? "LambdaFunction" : name_; } - std::function<variable_list(const variable_list&)> fn; + virtual variable_list apply(const variable_list& inputs) override { + return fn_(inputs); + } + + std::string name_; + std::function<variable_list(const variable_list&)> fn_; }; // Wraps a PythonOp and dispatches calls to Functions implemented in Python @@ -583,7 +592,7 @@ struct StageClosure { IR_ELSEIF(Concat) return std::make_shared<torch::autograd::Cat>(value->i(kaxis)); IR_ELSE() - throw std::runtime_error(std::string("unrecognized NodeKind: ") + symbolToString(node->kind())); + return std::make_shared<LambdaFunction>(getTensorOp(node)); IR_END() } @@ -671,7 +680,7 @@ struct StageClosure { // Roots for a call to the engine. The list contains function in this order: // [ apply input roots | prev stage input roots | constant factory ] function_list roots; - std::vector<VariableFlags> var_flags; + std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>> var_flags; // Output node std::shared_ptr<Function> output; @@ -703,15 +712,14 @@ struct MultiStageClosure { }; AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc) - : AutogradClosure(desc, 0, {}) {} + : AutogradClosure(desc, 0) {} // TODO: there's a lot processing involved in creating a new AutogradClosure instance, // so it might be worth to keep a pool of unused instances (or at least their attrs) // for all stages. We can't save saved_vars and saved_handles, but all callbacks // can be made reusable. -AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags &&f) - : Function(std::move(f)) - , desc(desc) +AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage) + : desc(desc) , stage(stage) { auto & stage_desc = desc->stages[stage]; @@ -777,10 +785,10 @@ variable_list AutogradClosure::apply(const variable_list& inputs) { // Validate inputs auto num_inputs = inputs.size(); - if (num_inputs != stage_closure.var_flags.size()) + if (num_inputs != stage_closure.var_flags.first.size()) throw std::runtime_error("AutogradClosure received an incorrect number of inputs"); for (std::size_t i = 0; i < num_inputs; ++i) { - auto & flags = stage_closure.var_flags[i]; + auto & flags = stage_closure.var_flags.first[i]; if (!flags.verify(inputs[i])) throw std::runtime_error("AutogradClosure received inputs with different flags"); } @@ -797,16 +805,15 @@ variable_list AutogradClosure::apply(const variable_list& inputs) { auto& engine = python::PythonEngine::getDefaultEngine(); engine.execute(stage_closure.roots, input_leaves, true, pre_callbacks, post_callbacks); - // See Note [Null-edge pruning] - auto relevant_inputs = filter(inputs, [](const Variable& var) { return var.defined() && var.requires_grad(); }); - auto result = wrap_outputs(relevant_inputs, std::move(outputs), [this](FunctionFlags f) -> std::shared_ptr<Function> { + // Create the backward function lazily + auto make_grad_fn = [this]() -> std::shared_ptr<Function> { if (this->stage == this->desc->stages.size() - 1) { std::string msg = "JIT closure compiled only for "; msg += std::to_string(this->stage); msg += " backwards"; - return std::make_shared<Error>(std::move(msg), std::move(f)); + return std::make_shared<Error>(std::move(msg)); } - auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1, std::move(f))); + auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1)); // TODO: don't make a full copy of saved_* - copy only the things that bw needs bw_fn->saved_vars = this->saved_vars; bw_fn->saved_vars.insert(std::make_move_iterator(this->captured_vars.begin()), @@ -824,7 +831,33 @@ variable_list AutogradClosure::apply(const variable_list& inputs) { // was run, so it must have been executable). bw_fn->is_executable = true; return bw_fn; - }); + }; + + // See Note [Null-edge pruning] + variable_list result; + auto num_outputs = outputs.size(); + std::shared_ptr<Function> grad_fn; + JIT_ASSERT(outputs.size() == stage_closure.var_flags.second.size()); + for (std::size_t i = 0; i < num_outputs; ++i) { + auto & flags = stage_closure.var_flags.second[i]; + if (flags.requires_grad) { + if (!grad_fn) grad_fn = make_grad_fn(); + result.push_back(make_variable(outputs[i], grad_fn)); + } else { + result.push_back(make_variable(outputs[i], flags.requires_grad, flags.is_volatile)); + } + } + + // If we created grad_fn for any of the outputs, we also need to fill in next_functions + if (grad_fn) { + for (auto & input : inputs) { + if (!input.requires_grad()) continue; + grad_fn->next_functions.emplace_back( + input.grad_fn() ? input.grad_fn() : input.grad_accumulator(), + input.output_nr()); + } + } + captured_vars.clear(); captured_handles.clear(); outputs.clear(); diff --git a/torch/csrc/autograd/functions/jit_closure.h b/torch/csrc/autograd/functions/jit_closure.h index 25f2ca2a33..6d905e63c4 100644 --- a/torch/csrc/autograd/functions/jit_closure.h +++ b/torch/csrc/autograd/functions/jit_closure.h @@ -28,7 +28,7 @@ struct AutogradClosure : public Function { virtual variable_list apply(const variable_list& inputs) override; private: - AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags&& f); + AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage); variable_list rewrapInputs(const variable_list& inputs); diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index b6335ed7c5..29a7a41cf4 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -711,7 +711,7 @@ static void _trace_create(PyObject* op_obj, THPFunction* bw_obj, sel->inferTypeFrom(output.data()); tracer::setValueTrace(tracing_state, output, sel); } - this_expr->i_(k__inplace, is_inplace); + this_expr->i_(kinplace, is_inplace); // See definition in function.cpp. THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")}; diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h index 08362950e3..f6542c7806 100644 --- a/torch/csrc/jit/interned_strings.h +++ b/torch/csrc/jit/interned_strings.h @@ -64,7 +64,7 @@ _(perm) \ _(shape) \ _(axes) \ _(group) \ -_(__inplace) +_(inplace) enum BuiltinSymbol { #define DEFINE_SYMBOL(s) \ diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 0420ad1097..45b9408528 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -177,7 +177,17 @@ void printAttributes(std::ostream & out, Node * n) { case AttributeKind::t: { at::Tensor t = n->t(name); - if (t.numel() <= max_tensor_display_size) { + // 1-elem tensors are usually boxed scalars, so print them like it + if (t.numel() == 1) { + auto scalar = at::Scalar(t.view({})).local(); + out << "{"; + if (scalar.isFloatingPoint()) { + out << scalar.toDouble(); + } else { + out << scalar.toLong(); + } + out << "}"; + } else if (t.numel() <= max_tensor_display_size) { // TODO: This is awful code. Also it doesn't work on Windows. std::ostringstream tensor_ss; tensor_ss << t; diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp index 93d9b0865f..669edbafd9 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp +++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp @@ -8,6 +8,16 @@ namespace torch { namespace jit { +namespace { + +bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) { + return &lhs.type() == &rhs.type() && lhs.equal(rhs); +} + +bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) { + if (lhs.size() != rhs.size()) return false; + return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual); +}; // Check whether two nodes have the same attributes in CSE. @@ -24,6 +34,8 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) { auto lnames = lhs->attributeNames(); auto rnames = rhs->attributeNames(); + std::sort(lnames.begin(), lnames.end()); + std::sort(rnames.begin(), rnames.end()); if (lnames != rnames) return false; for (auto name : lnames) { @@ -40,8 +52,13 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) { COMPARE_ATTRIBUTEVALUE(is) COMPARE_ATTRIBUTEVALUE(s) COMPARE_ATTRIBUTEVALUE(ss) + case AttributeKind::t: + if (!tensorEqual(lhs->t(name), rhs->t(name))) return false; + break; + case AttributeKind::ts: + if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false; default: - // NB: Comparison of nodes with tensor(s) or graph(s) will return false. + // NB: Comparison of nodes with graph(s) will return false. return false; } @@ -92,6 +109,8 @@ struct EqualNodeCSE { } }; +} // anonymous namespace + // The function implements common subexpression elimination. // Since the nodes are visited in topological order, one pass is enough. void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) { diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index 89f298adc9..f566c9abae 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -56,7 +56,7 @@ struct TraceEval : autograd::Eval { setValueTrace(tracing_state, input, input_node); input_node->inferTypeFrom(input.data()); } - tracing_state->var_flags.at(graph->stage()) = detail::getVarFlags(inputs); + tracing_state->var_flags.at(graph->stage()).first = detail::getVarFlags(inputs); } void exitTrace(const variable_list& inputs, const variable_list& outputs) { diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index 62e283d727..ce832e302a 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -200,7 +200,7 @@ inline std::shared_ptr<TracingState> enter(std::vector<TraceInput>&& trace_input } } // TODO: this might not work with the way we handle buffers - state->var_flags[0] = detail::getVarFlags(inputs); + state->var_flags[0].first = detail::getVarFlags(inputs); state->active = true; state->inputs = inputs; return state; @@ -214,6 +214,7 @@ inline void _exit(const std::shared_ptr<TracingState>& state, const variable_lis state->graph->registerOutput(getValueTrace(state, output, true)); } state->active = false; + state->var_flags[state->graph->stage()].second = detail::getVarFlags(outputs); } // Marks a backwards subgraph that should be traced as the next stage. diff --git a/torch/csrc/jit/tracer_state.h b/torch/csrc/jit/tracer_state.h index 8b383ca0f7..3c2c32d80b 100644 --- a/torch/csrc/jit/tracer_state.h +++ b/torch/csrc/jit/tracer_state.h @@ -64,7 +64,8 @@ struct TracingState : public std::enable_shared_from_this<TracingState> { // TODO: Perhaps, turn this into an owning reference. The buffers // are persistent, so this won't lead to a leak. std::unordered_map<void*, Node*> buffer_map; - std::vector<std::vector<VariableFlags>> var_flags; + // A pair of (input_flags, output_flags) for each stage + std::vector<std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>> var_flags; std::vector<function_list> output_edges; std::mutex mutex; diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 26b85919c6..1d7f442ffc 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -469,7 +469,6 @@ class TraceForKey(object): # It's important to always run DCE, because backward can create a lot of unnecessary nodes _run_pass(torch._C._jit_pass_dce, complete_trace) - _run_pass(torch._C._jit_pass_onnx, complete_trace) _run_pass(_passes._check_inplace, complete_trace) if self.optimize: _run_pass(torch._C._jit_pass_fuse, complete_trace) diff --git a/torch/jit/passes/inplace.py b/torch/jit/passes/inplace.py index 83246cd09b..0ea910932d 100644 --- a/torch/jit/passes/inplace.py +++ b/torch/jit/passes/inplace.py @@ -7,5 +7,5 @@ def _check_inplace(trace): graph = trace.graph() for node in graph.nodes(): if node.kind() == 'PythonOp': - if node.i('__inplace'): + if node.i('inplace'): raise RuntimeError("inplace {} not supported in the JIT".format(node.pyname())) |