summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
Diffstat (limited to 'torch')
-rw-r--r--torch/autograd/variable.py18
-rw-r--r--torch/csrc/autograd/functions/jit_closure.cpp73
-rw-r--r--torch/csrc/autograd/functions/jit_closure.h2
-rw-r--r--torch/csrc/autograd/python_function.cpp2
-rw-r--r--torch/csrc/jit/interned_strings.h2
-rw-r--r--torch/csrc/jit/ir.cpp12
-rw-r--r--torch/csrc/jit/passes/common_subexpression_elimination.cpp21
-rw-r--r--torch/csrc/jit/tracer.cpp2
-rw-r--r--torch/csrc/jit/tracer.h3
-rw-r--r--torch/csrc/jit/tracer_state.h3
-rw-r--r--torch/jit/__init__.py1
-rw-r--r--torch/jit/passes/inplace.py2
12 files changed, 98 insertions, 43 deletions
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index b3b5a92bbe..8b2655320c 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -426,15 +426,12 @@ class Variable(_C._VariableBase):
def bernoulli(self):
return Bernoulli.apply(self)
- def __add__(self, other):
- return self.add(other)
- __radd__ = __add__
+ __radd__ = __add__ = _C._VariableBase.add
def __iadd__(self, other):
return self.add_(other)
- def __sub__(self, other):
- return self.sub(other)
+ __sub__ = _C._VariableBase.sub
def __isub__(self, other):
return self.sub_(other)
@@ -442,9 +439,7 @@ class Variable(_C._VariableBase):
def __rsub__(self, other):
return -self + other
- def __mul__(self, other):
- return self.mul(other)
- __rmul__ = __mul__
+ __rmul__ = __mul__ = _C._VariableBase.mul
def __imul__(self, other):
return self.mul_(other)
@@ -454,9 +449,7 @@ class Variable(_C._VariableBase):
return NotImplemented
return self.matmul(other)
- def __div__(self, other):
- return self.div(other)
- __truediv__ = __div__
+ __truediv__ = __div__ = _C._VariableBase.div
def __rdiv__(self, other):
return self.reciprocal() * other
@@ -465,8 +458,7 @@ class Variable(_C._VariableBase):
def __idiv__(self, other):
return self.div_(other)
- def __pow__(self, other):
- return self.pow(other)
+ __pow__ = _C._VariableBase.pow
def __ipow__(self, other):
raise NotImplementedError("in-place pow not implemented")
diff --git a/torch/csrc/autograd/functions/jit_closure.cpp b/torch/csrc/autograd/functions/jit_closure.cpp
index d1d6e51fff..9f8fd92598 100644
--- a/torch/csrc/autograd/functions/jit_closure.cpp
+++ b/torch/csrc/autograd/functions/jit_closure.cpp
@@ -11,6 +11,7 @@
#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/python_function.h"
+#include "torch/csrc/jit/generated/aten_dispatch.h"
#ifdef WITH_CUDA
#include "torch/csrc/jit/fusion_compiler.h"
#endif
@@ -115,20 +116,28 @@ struct EmitNull : public Function {
};
};
-// A hack that will let us implement some of the ops we care
-// about before the major Python -> C++ Function migration
struct LambdaFunction : public Function {
+ LambdaFunction(const jit::TensorOp& op)
+ : LambdaFunction(op.num_inputs, op.op) {
+ this->name_ = op.name;
+ }
+
LambdaFunction(int num_inputs, std::function<variable_list(const variable_list&)> fn)
- : fn(fn) {
+ : fn_(fn) {
this->is_executable = true;
this->num_inputs = num_inputs;
}
- virtual variable_list apply(const variable_list& inputs) {
- return fn(inputs);
+ virtual std::string name() override {
+ return name_.size() == 0 ? "LambdaFunction" : name_;
}
- std::function<variable_list(const variable_list&)> fn;
+ virtual variable_list apply(const variable_list& inputs) override {
+ return fn_(inputs);
+ }
+
+ std::string name_;
+ std::function<variable_list(const variable_list&)> fn_;
};
// Wraps a PythonOp and dispatches calls to Functions implemented in Python
@@ -583,7 +592,7 @@ struct StageClosure {
IR_ELSEIF(Concat)
return std::make_shared<torch::autograd::Cat>(value->i(kaxis));
IR_ELSE()
- throw std::runtime_error(std::string("unrecognized NodeKind: ") + symbolToString(node->kind()));
+ return std::make_shared<LambdaFunction>(getTensorOp(node));
IR_END()
}
@@ -671,7 +680,7 @@ struct StageClosure {
// Roots for a call to the engine. The list contains function in this order:
// [ apply input roots | prev stage input roots | constant factory ]
function_list roots;
- std::vector<VariableFlags> var_flags;
+ std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>> var_flags;
// Output node
std::shared_ptr<Function> output;
@@ -703,15 +712,14 @@ struct MultiStageClosure {
};
AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc)
- : AutogradClosure(desc, 0, {}) {}
+ : AutogradClosure(desc, 0) {}
// TODO: there's a lot processing involved in creating a new AutogradClosure instance,
// so it might be worth to keep a pool of unused instances (or at least their attrs)
// for all stages. We can't save saved_vars and saved_handles, but all callbacks
// can be made reusable.
-AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags &&f)
- : Function(std::move(f))
- , desc(desc)
+AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage)
+ : desc(desc)
, stage(stage) {
auto & stage_desc = desc->stages[stage];
@@ -777,10 +785,10 @@ variable_list AutogradClosure::apply(const variable_list& inputs) {
// Validate inputs
auto num_inputs = inputs.size();
- if (num_inputs != stage_closure.var_flags.size())
+ if (num_inputs != stage_closure.var_flags.first.size())
throw std::runtime_error("AutogradClosure received an incorrect number of inputs");
for (std::size_t i = 0; i < num_inputs; ++i) {
- auto & flags = stage_closure.var_flags[i];
+ auto & flags = stage_closure.var_flags.first[i];
if (!flags.verify(inputs[i]))
throw std::runtime_error("AutogradClosure received inputs with different flags");
}
@@ -797,16 +805,15 @@ variable_list AutogradClosure::apply(const variable_list& inputs) {
auto& engine = python::PythonEngine::getDefaultEngine();
engine.execute(stage_closure.roots, input_leaves, true, pre_callbacks, post_callbacks);
- // See Note [Null-edge pruning]
- auto relevant_inputs = filter(inputs, [](const Variable& var) { return var.defined() && var.requires_grad(); });
- auto result = wrap_outputs(relevant_inputs, std::move(outputs), [this](FunctionFlags f) -> std::shared_ptr<Function> {
+ // Create the backward function lazily
+ auto make_grad_fn = [this]() -> std::shared_ptr<Function> {
if (this->stage == this->desc->stages.size() - 1) {
std::string msg = "JIT closure compiled only for ";
msg += std::to_string(this->stage);
msg += " backwards";
- return std::make_shared<Error>(std::move(msg), std::move(f));
+ return std::make_shared<Error>(std::move(msg));
}
- auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1, std::move(f)));
+ auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1));
// TODO: don't make a full copy of saved_* - copy only the things that bw needs
bw_fn->saved_vars = this->saved_vars;
bw_fn->saved_vars.insert(std::make_move_iterator(this->captured_vars.begin()),
@@ -824,7 +831,33 @@ variable_list AutogradClosure::apply(const variable_list& inputs) {
// was run, so it must have been executable).
bw_fn->is_executable = true;
return bw_fn;
- });
+ };
+
+ // See Note [Null-edge pruning]
+ variable_list result;
+ auto num_outputs = outputs.size();
+ std::shared_ptr<Function> grad_fn;
+ JIT_ASSERT(outputs.size() == stage_closure.var_flags.second.size());
+ for (std::size_t i = 0; i < num_outputs; ++i) {
+ auto & flags = stage_closure.var_flags.second[i];
+ if (flags.requires_grad) {
+ if (!grad_fn) grad_fn = make_grad_fn();
+ result.push_back(make_variable(outputs[i], grad_fn));
+ } else {
+ result.push_back(make_variable(outputs[i], flags.requires_grad, flags.is_volatile));
+ }
+ }
+
+ // If we created grad_fn for any of the outputs, we also need to fill in next_functions
+ if (grad_fn) {
+ for (auto & input : inputs) {
+ if (!input.requires_grad()) continue;
+ grad_fn->next_functions.emplace_back(
+ input.grad_fn() ? input.grad_fn() : input.grad_accumulator(),
+ input.output_nr());
+ }
+ }
+
captured_vars.clear();
captured_handles.clear();
outputs.clear();
diff --git a/torch/csrc/autograd/functions/jit_closure.h b/torch/csrc/autograd/functions/jit_closure.h
index 25f2ca2a33..6d905e63c4 100644
--- a/torch/csrc/autograd/functions/jit_closure.h
+++ b/torch/csrc/autograd/functions/jit_closure.h
@@ -28,7 +28,7 @@ struct AutogradClosure : public Function {
virtual variable_list apply(const variable_list& inputs) override;
private:
- AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags&& f);
+ AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage);
variable_list rewrapInputs(const variable_list& inputs);
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index b6335ed7c5..29a7a41cf4 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -711,7 +711,7 @@ static void _trace_create(PyObject* op_obj, THPFunction* bw_obj,
sel->inferTypeFrom(output.data());
tracer::setValueTrace(tracing_state, output, sel);
}
- this_expr->i_(k__inplace, is_inplace);
+ this_expr->i_(kinplace, is_inplace);
// See definition in function.cpp.
THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")};
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index 08362950e3..f6542c7806 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -64,7 +64,7 @@ _(perm) \
_(shape) \
_(axes) \
_(group) \
-_(__inplace)
+_(inplace)
enum BuiltinSymbol {
#define DEFINE_SYMBOL(s) \
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 0420ad1097..45b9408528 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -177,7 +177,17 @@ void printAttributes(std::ostream & out, Node * n) {
case AttributeKind::t:
{
at::Tensor t = n->t(name);
- if (t.numel() <= max_tensor_display_size) {
+ // 1-elem tensors are usually boxed scalars, so print them like it
+ if (t.numel() == 1) {
+ auto scalar = at::Scalar(t.view({})).local();
+ out << "{";
+ if (scalar.isFloatingPoint()) {
+ out << scalar.toDouble();
+ } else {
+ out << scalar.toLong();
+ }
+ out << "}";
+ } else if (t.numel() <= max_tensor_display_size) {
// TODO: This is awful code. Also it doesn't work on Windows.
std::ostringstream tensor_ss;
tensor_ss << t;
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
index 93d9b0865f..669edbafd9 100644
--- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
@@ -8,6 +8,16 @@
namespace torch { namespace jit {
+namespace {
+
+bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
+ return &lhs.type() == &rhs.type() && lhs.equal(rhs);
+}
+
+bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
+ if (lhs.size() != rhs.size()) return false;
+ return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
+};
// Check whether two nodes have the same attributes in CSE.
@@ -24,6 +34,8 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
auto lnames = lhs->attributeNames();
auto rnames = rhs->attributeNames();
+ std::sort(lnames.begin(), lnames.end());
+ std::sort(rnames.begin(), rnames.end());
if (lnames != rnames) return false;
for (auto name : lnames) {
@@ -40,8 +52,13 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
COMPARE_ATTRIBUTEVALUE(is)
COMPARE_ATTRIBUTEVALUE(s)
COMPARE_ATTRIBUTEVALUE(ss)
+ case AttributeKind::t:
+ if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
+ break;
+ case AttributeKind::ts:
+ if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
default:
- // NB: Comparison of nodes with tensor(s) or graph(s) will return false.
+ // NB: Comparison of nodes with graph(s) will return false.
return false;
}
@@ -92,6 +109,8 @@ struct EqualNodeCSE {
}
};
+} // anonymous namespace
+
// The function implements common subexpression elimination.
// Since the nodes are visited in topological order, one pass is enough.
void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 89f298adc9..f566c9abae 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -56,7 +56,7 @@ struct TraceEval : autograd::Eval {
setValueTrace(tracing_state, input, input_node);
input_node->inferTypeFrom(input.data());
}
- tracing_state->var_flags.at(graph->stage()) = detail::getVarFlags(inputs);
+ tracing_state->var_flags.at(graph->stage()).first = detail::getVarFlags(inputs);
}
void exitTrace(const variable_list& inputs, const variable_list& outputs) {
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index 62e283d727..ce832e302a 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -200,7 +200,7 @@ inline std::shared_ptr<TracingState> enter(std::vector<TraceInput>&& trace_input
}
}
// TODO: this might not work with the way we handle buffers
- state->var_flags[0] = detail::getVarFlags(inputs);
+ state->var_flags[0].first = detail::getVarFlags(inputs);
state->active = true;
state->inputs = inputs;
return state;
@@ -214,6 +214,7 @@ inline void _exit(const std::shared_ptr<TracingState>& state, const variable_lis
state->graph->registerOutput(getValueTrace(state, output, true));
}
state->active = false;
+ state->var_flags[state->graph->stage()].second = detail::getVarFlags(outputs);
}
// Marks a backwards subgraph that should be traced as the next stage.
diff --git a/torch/csrc/jit/tracer_state.h b/torch/csrc/jit/tracer_state.h
index 8b383ca0f7..3c2c32d80b 100644
--- a/torch/csrc/jit/tracer_state.h
+++ b/torch/csrc/jit/tracer_state.h
@@ -64,7 +64,8 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
// TODO: Perhaps, turn this into an owning reference. The buffers
// are persistent, so this won't lead to a leak.
std::unordered_map<void*, Node*> buffer_map;
- std::vector<std::vector<VariableFlags>> var_flags;
+ // A pair of (input_flags, output_flags) for each stage
+ std::vector<std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>> var_flags;
std::vector<function_list> output_edges;
std::mutex mutex;
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 26b85919c6..1d7f442ffc 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -469,7 +469,6 @@ class TraceForKey(object):
# It's important to always run DCE, because backward can create a lot of unnecessary nodes
_run_pass(torch._C._jit_pass_dce, complete_trace)
- _run_pass(torch._C._jit_pass_onnx, complete_trace)
_run_pass(_passes._check_inplace, complete_trace)
if self.optimize:
_run_pass(torch._C._jit_pass_fuse, complete_trace)
diff --git a/torch/jit/passes/inplace.py b/torch/jit/passes/inplace.py
index 83246cd09b..0ea910932d 100644
--- a/torch/jit/passes/inplace.py
+++ b/torch/jit/passes/inplace.py
@@ -7,5 +7,5 @@ def _check_inplace(trace):
graph = trace.graph()
for node in graph.nodes():
if node.kind() == 'PythonOp':
- if node.i('__inplace'):
+ if node.i('inplace'):
raise RuntimeError("inplace {} not supported in the JIT".format(node.pyname()))