author     James Sun <jamessun@fb.com>  2018-12-17 20:28:00 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-12-17 20:34:26 -0800
commit     e37a22128eca7ccac6e289659587a9e1bfe6d242 (patch)
tree       4969b8549da2555b4c8d009e665151605c64a386
parent     bd958cde685c2de67ecf691934470ef3c289e00d (diff)
Allow tracing with fork/wait (#15184)
Summary: There is still a limitation: if a script module appears anywhere in the trace, its inputs/outputs can only be tensors or tuples of tensors. Resolves #15052
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15184
Differential Revision: D13457691
Pulled By: highker
fbshipit-source-id: 8fe46afc41357a0eb8eadd83f687b31d074deb0e
-rw-r--r--  aten/src/ATen/core/jit_type.h      |  3
-rw-r--r--  test/test_jit.py                   | 53
-rw-r--r--  torch/csrc/jit/graph_executor.cpp  | 11
-rw-r--r--  torch/csrc/jit/tracer.cpp          | 27
-rw-r--r--  torch/csrc/jit/tracer.h            | 14
5 files changed, 99 insertions, 9 deletions
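
The sketch below is not part of the commit; it illustrates the user-facing pattern this change enables, closely following the new test in test_jit.py: a plain nn.Module is traced while it calls into a script module that forks and waits on work. The class names Forker and Outer are illustrative only; torch.jit._fork and torch.jit._wait are the private entry points the test itself uses at this point in time, and the forked callable's output is kept to a plain tensor to stay within the limitation noted in the summary.

import torch
import torch.nn as nn

# Illustrative script module: fork a function and wait for its result. Its
# output is a plain tensor, which stays within the tensor/tuple limitation.
class Forker(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # type: (Tensor) -> Tensor
        fut = torch.jit._fork(torch.neg, x)
        return torch.jit._wait(fut)

# Plain nn.Module wrapper that we trace; the script module's graph, fork node
# included, ends up inlined into the resulting trace.
class Outer(nn.Module):
    def __init__(self):
        super(Outer, self).__init__()
        self.forker = Forker()

    def forward(self, x):
        return self.forker(x) + x

x = torch.rand(3, 3)
traced = torch.jit.trace(Outer(), (x,))
print(traced.graph)  # expected to show a prim::fork node, as the test asserts
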
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 9971551890..c057963af3 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -532,6 +532,9 @@ struct CAFFE2_API FutureType : public SingleElementType<TypeKind::FutureType, Fu
ss << "Future[" << getElementType()->python_str() << "]";
return ss.str();
}
+ TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
+ return create(contained_types.at(0));
+ }
private:
FutureType(TypePtr elem) : SingleElementType(elem) {}
};
diff --git a/test/test_jit.py b/test/test_jit.py
index 3e072de08e..86d3410354 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -11222,6 +11222,59 @@ class TestAsync(JitTestCase):
self.assertEqual(y2, foo2(x1, x2))
self.assertEqual(y3, foo3(x1, x2, x3))
+ def test_async_script_trace(self):
+ class Traced(nn.Module):
+ def __init__(self):
+ super(Traced, self).__init__()
+
+ def forward(self, x):
+ return tuple([torch.neg(x), x])
+
+ class Module(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Module, self).__init__(False)
+ x = torch.rand(3, 3)
+ self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
+
+ @torch.jit.script_method
+ def forward(self, x):
+ # type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]
+ future1 = torch.jit._fork(self.traced, x)
+ future2 = torch.jit._fork(torch.neg, x)
+
+ tensor_tuple = torch.jit._wait(future1)
+ tensor_single = torch.jit._wait(future2)
+
+ tensor_list = []
+ tensor_list.append(tensor_tuple[0])
+ tensor_list.append(tensor_single)
+
+ # return a nested structure of tensors
+ return (tensor_list, tensor_tuple, tensor_tuple[1])
+
+ class Tuple(nn.Module):
+ def __init__(self):
+ super(Tuple, self).__init__()
+ self.module = Module()
+
+ def forward(self, x):
+ z = torch.neg(x)
+ y = self.module(x)
+ list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
+ return tuple(list)
+
+ x = torch.rand(3, 3)
+ module = torch.jit.trace(Tuple(), (x), _force_outplace=True)
+
+ # Make sure we have forks
+ self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
+ # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
+ self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1)
+ self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True)
+
+ y = torch.neg(x)
+ self.assertEqual(module(x), tuple([y, y, y, y, x, x]))
+
for test in autograd_method_tests:
add_autograd_test(*test)
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 07b3019d14..cbdf89366f 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -513,7 +513,11 @@ private:
// NB: we could just run the fallback in here and call it a day, but that would lose all
// the control flow information we have in the graph. Thus, we run the fallback to
// get the correct output values, but we will override the tracing states later.
- getOrCompileFallback().run(stack);
+ {
+ // No need to trace a script module.
+ ResourceGuard guard(tracer::pauseTracing());
+ getOrCompileFallback().run(stack);
+ }
// Traces always have types propagated through them, so we make sure to
// also propagate types through the graph we are inserting here.
@@ -527,10 +531,7 @@ private:
auto outputs = last(stack, num_outputs);
for (size_t i = 0; i < outputs.size(); ++i) {
- // We can't attach tracing states to scalars, so we have to skip them here
- // TODO: Should we reinterpret them as scalar tensors instead?
- if (!outputs[i].isTensor()) continue;
- tracer::setValueTrace(outputs[i].toTensor(), output_values[i]);
+ tracer::setValueTrace(outputs[i], output_values[i]);
}
}
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 91b333c332..f86ae6dd99 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -37,6 +37,33 @@ thread_local std::shared_ptr<TracingState> tracing_state;
} // namespace detail
+void setValueTrace(const IValue &v, Value *value) {
+ if (v.isTensor()) {
+ auto var = v.toTensor();
+ JIT_ASSERT(var.defined());
+ getTracingState()->value_map[var] = value;
+ } else if (v.isTensorList()) {
+ auto& outputs = v.toTensorList()->elements();
+ auto graph = getTracingState()->graph;
+ Node * unpack_node = graph->appendNode(graph->create(prim::ListUnpack, {value}, outputs.size()));
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ setValueTrace(outputs[i], unpack_node->outputs()[i]);
+ }
+ } else if (v.isTuple()) {
+ auto& outputs = v.toTuple()->elements();
+ auto graph = getTracingState()->graph;
+ Node * unpack_node = graph->appendNode(graph->create(prim::TupleUnpack, {value}, outputs.size()));
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ setValueTrace(outputs[i], unpack_node->outputs()[i]);
+ }
+ } else {
+ std::ostringstream os;
+ os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
+ << "Supported types are tensor, tensor list, and tuple of tensors.";
+ throw std::runtime_error(os.str());
+ }
+}
+
void addInputs(Node *n, const char * name, int64_t value) {
using ArgumentStash = jit::tracer::ArgumentStash;
if (ArgumentStash::hasValue(name)) {
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index 34b285ff6d..691a1d90f5 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -32,16 +32,22 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*));
// Having finished adding a new 'node' to the graph IR 'setValueTrace' associates
// this node with an output variable, so that further operations involving this
// variable know which node in the IR to reference.
-inline void setValueTrace(const Variable& var, Value *value) {
- JIT_ASSERT(var.defined());
- getTracingState()->value_map[var] = value;
-}
+TORCH_API void setValueTrace(const IValue& v, Value* value);
inline void delValueTrace(const Variable& var) {
JIT_ASSERT(var.defined());
getTracingState()->value_map.erase(var);
}
+inline std::function<void()> pauseTracing() {
+ std::shared_ptr<tracer::TracingState> state = getTracingState();
+ tracer::setTracingState(nullptr);
+
+ return [state]() {
+ tracer::setTracingState(state);
+ };
+}
+
// Given a variable 'var', return the 'node' which represents the instruction
// which computes the value of this variable in the IR.
// Here, we interpret untraced variables as constants that are just embedded
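
To close, a second small sketch (again not part of the commit, names PairMaker and Consumer are illustrative) exercising the nested-output path added to tracer.cpp: a script module that forks work and returns a tuple of tensors, which the new setValueTrace overload records by unpacking (via prim::TupleUnpack) instead of skipping non-tensor outputs.

import torch
import torch.nn as nn

# Illustrative script module returning a tuple of tensors; with this commit the
# tracer can record such outputs by unpacking the tuple into its elements.
class PairMaker(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # type: (Tensor) -> Tuple[Tensor, Tensor]
        fut = torch.jit._fork(torch.neg, x)
        return torch.jit._wait(fut), x + 1

class Consumer(nn.Module):
    def __init__(self):
        super(Consumer, self).__init__()
        self.pair = PairMaker()

    def forward(self, x):
        # The tuple elements get value traces, so later traced ops can use them.
        a, b = self.pair(x)
        return a + b

x = torch.rand(3, 3)
traced = torch.jit.trace(Consumer(), (x,))
print(traced(x))  # should match Consumer()(x)
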