diff options
author | James Sun <jamessun@fb.com> | 2018-12-17 20:28:00 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-12-17 20:34:26 -0800 |
commit | e37a22128eca7ccac6e289659587a9e1bfe6d242 (patch) | |
tree | 4969b8549da2555b4c8d009e665151605c64a386 | |
parent | bd958cde685c2de67ecf691934470ef3c289e00d (diff) | |
download | pytorch-e37a22128eca7ccac6e289659587a9e1bfe6d242.tar.gz pytorch-e37a22128eca7ccac6e289659587a9e1bfe6d242.tar.bz2 pytorch-e37a22128eca7ccac6e289659587a9e1bfe6d242.zip |
Allow tracing with fork/wait (#15184)
Summary:
There is still limitation on this: if a script module is somewhere
in the trace, the inputs/outputs can only be tensors or tuples of
tensors.
resolves #15052
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15184
Differential Revision: D13457691
Pulled By: highker
fbshipit-source-id: 8fe46afc41357a0eb8eadd83f687b31d074deb0e
-rw-r--r-- | aten/src/ATen/core/jit_type.h | 3 | ||||
-rw-r--r-- | test/test_jit.py | 53 | ||||
-rw-r--r-- | torch/csrc/jit/graph_executor.cpp | 11 | ||||
-rw-r--r-- | torch/csrc/jit/tracer.cpp | 27 | ||||
-rw-r--r-- | torch/csrc/jit/tracer.h | 14 |
5 files changed, 99 insertions, 9 deletions
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 9971551890..c057963af3 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -532,6 +532,9 @@ struct CAFFE2_API FutureType : public SingleElementType<TypeKind::FutureType, Fu ss << "Future[" << getElementType()->python_str() << "]"; return ss.str(); } + TypePtr createWithContained(std::vector<TypePtr> contained_types) const override { + return create(contained_types.at(0)); + } private: FutureType(TypePtr elem) : SingleElementType(elem) {} }; diff --git a/test/test_jit.py b/test/test_jit.py index 3e072de08e..86d3410354 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -11222,6 +11222,59 @@ class TestAsync(JitTestCase): self.assertEqual(y2, foo2(x1, x2)) self.assertEqual(y3, foo3(x1, x2, x3)) + def test_async_script_trace(self): + class Traced(nn.Module): + def __init__(self): + super(Traced, self).__init__() + + def forward(self, x): + return tuple([torch.neg(x), x]) + + class Module(torch.jit.ScriptModule): + def __init__(self): + super(Module, self).__init__(False) + x = torch.rand(3, 3) + self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) + + @torch.jit.script_method + def forward(self, x): + # type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor] + future1 = torch.jit._fork(self.traced, x) + future2 = torch.jit._fork(torch.neg, x) + + tensor_tuple = torch.jit._wait(future1) + tensor_single = torch.jit._wait(future2) + + tensor_list = [] + tensor_list.append(tensor_tuple[0]) + tensor_list.append(tensor_single) + + # return a nested structure of tensors + return (tensor_list, tensor_tuple, tensor_tuple[1]) + + class Tuple(nn.Module): + def __init__(self): + super(Tuple, self).__init__() + self.module = Module() + + def forward(self, x): + z = torch.neg(x) + y = self.module(x) + list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]] + return tuple(list) + + x = torch.rand(3, 3) + module = torch.jit.trace(Tuple(), (x), _force_outplace=True) + + # Make sure we have forks + self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2) + # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs + self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1) + self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True) + + y = torch.neg(x) + self.assertEqual(module(x), tuple([y, y, y, y, x, x])) + for test in autograd_method_tests: add_autograd_test(*test) diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 07b3019d14..cbdf89366f 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -513,7 +513,11 @@ private: // NB: we could just run the fallback in here and call it a day, but that would loose all // the control flow information we have in the graph. Thus, we run the fallback to // get the correct output values, but we will override the tracing states later. - getOrCompileFallback().run(stack); + { + // No need to trace a script module. + ResourceGuard guard(tracer::pauseTracing()); + getOrCompileFallback().run(stack); + } // Traces always have types propagated through them, so we make sure to // also propagate types through the graph we are inserting here. @@ -527,10 +531,7 @@ private: auto outputs = last(stack, num_outputs); for (size_t i = 0; i < outputs.size(); ++i) { - // We can't attach tracing states to scalars, so we have to skip them here - // TODO: Should we reinterpret them as scalar tensors instead? - if (!outputs[i].isTensor()) continue; - tracer::setValueTrace(outputs[i].toTensor(), output_values[i]); + tracer::setValueTrace(outputs[i], output_values[i]); } } diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index 91b333c332..f86ae6dd99 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -37,6 +37,33 @@ thread_local std::shared_ptr<TracingState> tracing_state; } // namespace detail +void setValueTrace(const IValue &v, Value *value) { + if (v.isTensor()) { + auto var = v.toTensor(); + JIT_ASSERT(var.defined()); + getTracingState()->value_map[var] = value; + } else if (v.isTensorList()) { + auto& outputs = v.toTensorList()->elements(); + auto graph = getTracingState()->graph; + Node * unpack_node = graph->appendNode(graph->create(prim::ListUnpack, {value}, outputs.size())); + for (size_t i = 0; i < outputs.size(); ++i) { + setValueTrace(outputs[i], unpack_node->outputs()[i]); + } + } else if (v.isTuple()) { + auto& outputs = v.toTuple()->elements(); + auto graph = getTracingState()->graph; + Node * unpack_node = graph->appendNode(graph->create(prim::TupleUnpack, {value}, outputs.size())); + for (size_t i = 0; i < outputs.size(); ++i) { + setValueTrace(outputs[i], unpack_node->outputs()[i]); + } + } else { + std::ostringstream os; + os << "Tracer cannot set value trace for type " << v.tagKind() << ". " + << "Supported types are tensor, tensor list, and tuple of tensors."; + throw std::runtime_error(os.str()); + } +} + void addInputs(Node *n, const char * name, int64_t value) { using ArgumentStash = jit::tracer::ArgumentStash; if (ArgumentStash::hasValue(name)) { diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index 34b285ff6d..691a1d90f5 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -32,16 +32,22 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*)); // Having finished adding a new 'node' to the graph IR 'setValueTrace' associates // this node with an output variable, so that further operations involving this // variable know which node in the IR to reference. -inline void setValueTrace(const Variable& var, Value *value) { - JIT_ASSERT(var.defined()); - getTracingState()->value_map[var] = value; -} +TORCH_API void setValueTrace(const IValue& v, Value* value); inline void delValueTrace(const Variable& var) { JIT_ASSERT(var.defined()); getTracingState()->value_map.erase(var); } +inline std::function<void()> pauseTracing() { + std::shared_ptr<tracer::TracingState> state = getTracingState(); + tracer::setTracingState(nullptr); + + return [state]() { + tracer::setTracingState(state); + }; +} + // Given a variable 'var', return the 'node' which represents the instruction // which computes the value of this variable in the IR. // Here, we interpret untraced variables as constants that are just embedded |