author     Adam Paszke <adam.paszke@gmail.com>   2018-02-14 16:59:48 +0100
committer  GitHub <noreply@github.com>           2018-02-14 16:59:48 +0100
commit     8910dd5a81cdddb2612ac1c9b93a202fcf3a6e60 (patch)
tree       6cf5d39cca17fecae252be3a0bee629443a78834 /torch
parent     318ae2085a37852a640bd4664b4f485c92b0cbe5 (diff)
Fix GraphExecutor and add more AD formulas (#5215)
Diffstat (limited to 'torch')
-rw-r--r--  torch/csrc/jit/autodiff.cpp                          | 87
-rw-r--r--  torch/csrc/jit/autodiff.h                            |  3
-rw-r--r--  torch/csrc/jit/graph_executor.cpp                    | 10
-rw-r--r--  torch/csrc/jit/init.cpp                              |  1
-rw-r--r--  torch/csrc/jit/interned_strings.h                    |  8
-rw-r--r--  torch/csrc/jit/passes/create_autodiff_subgraphs.cpp  |  5
-rw-r--r--  torch/csrc/jit/passes/peephole.cpp                   | 35
-rw-r--r--  torch/csrc/jit/passes/shape_analysis.cpp             | 67
-rw-r--r--  torch/csrc/jit/symbolic_variable.h                   | 49
-rw-r--r--  torch/csrc/jit/type.h                                |  8
10 files changed, 243 insertions, 30 deletions
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 8ea46dc407..377ed243e9 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -5,20 +5,30 @@
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/utils/auto_gpu.h"
+#include <algorithm>
+
namespace torch { namespace jit {
using value_map = std::unordered_map<Value*, Value*>;
using value_set = std::unordered_set<Value*>;
+// TODO: unsqueeze!
std::unordered_set<Symbol> differentiable_kinds = {
kadd, ksub, kmul, kConstant, kReplaceIfUndef,
- ksigmoid, ktanh, kmm, kchunk, ksplit, kt
+ ksigmoid, ktanh, kmm, kchunk, ksplit, kt, kneg,
+ kunsqueeze
};
bool isDifferentiable(Node * n) {
return differentiable_kinds.count(n->kind()) > 0;
}
+bool isDifferentiable(Graph & g) {
+ return std::all_of(g.nodes().begin(), g.nodes().end(),
+ static_cast<bool(*)(Node*)>(isDifferentiable));
+}
+
+
static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_values) {
const auto build_sym_grad = [node](const std::vector<SymbolicVariable>& grads) -> std::vector<SymbolicVariable> {
auto inputs = fmap<SymbolicVariable>(node->inputs());
@@ -55,7 +65,13 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
return {SymbolicVariable::cat(grads, node->i(kdim))};
case kt:
return {grads.at(0).t()};
- case kmm:
+ case kneg:
+ return {-grads.at(0)};
+ case kview:
+ return {grads.at(0).view(inputs.at(0).sizes())};
+ case kunsqueeze:
+ return {grads.at(0).squeeze(node->i(kdim))};
+ case kmm: {
SymbolicVariable dmat1, dmat2;
if (inputs.at(0).value()->hasType()) {
auto type = inputs.at(0).value()->type()->expect<TensorType>();
@@ -80,10 +96,77 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
dmat2 = grads.at(0).mm(inputs.at(1).t());
}
return {dmat1, dmat2};
+ }
+ case kexpand: {
+ const auto& input_sizes = inputs.at(0).sizes();
+ if (input_sizes.size() == 0)
+ return {grads.at(0).sum()};
+ auto grad_sizes = node->is(ksize);
+ auto grad = grads.at(0);
+ while (grad_sizes.size() > input_sizes.size()) {
+ grad = grad.sum(0, false);
+ grad_sizes.erase(grad_sizes.begin());
+ }
+ for (size_t i = 0; i < input_sizes.size(); ++i) {
+ if (input_sizes[i] == 1 && grad_sizes[i] > 1) {
+ grad = grad.sum(i, true);
+ }
+ }
+ return {grad};
+ }
+ case ksqueeze: {
+ const auto& sizes = inputs.at(0).sizes();
+ if (node->hasAttribute(kdim)) {
+ int dim = node->i(kdim);
+ return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim)};
+ } else {
+ std::vector<size_t> squeezed_dims;
+ for (size_t i = 0; i < sizes.size(); ++i) {
+ if (sizes[i] != 1) continue;
+ squeezed_dims.push_back(i);
+ }
+ SymbolicVariable returned_grad = grads.at(0);
+ for (auto it = squeezed_dims.rbegin(); it != squeezed_dims.rend(); ++it)
+ returned_grad = returned_grad.unsqueeze(*it);
+ return {returned_grad};
+ }
+ }
+ case kcat: {
+ int dim = node->i(kdim);
+ const auto& first_sizes = inputs.at(0).sizes();
+ const auto has_first_sizes = [&first_sizes](SymbolicVariable var) {
+ return var.sizes() == first_sizes;
+ };
+ // NB: this is a specialization for the common case where all inputs are
+ // of equal sizes. We can use a single split operation to handle that.
+ if (std::all_of(inputs.begin(), inputs.end(), has_first_sizes)) {
+ return grads.at(0).chunk(inputs.size(), dim);
+ } else {
+ size_t offset = 0;
+ auto grad = grads.at(0);
+ std::vector<SymbolicVariable> returned_grads;
+ for (auto input : inputs) {
+ returned_grads.push_back(grad.narrow(dim, offset, input.sizes()[dim]));
+ offset += input.sizes()[dim];
+ }
+ return returned_grads;
+ }
+ }
}
throw std::runtime_error(std::string("don't support differentiation of `") +
node->kind().toString() + "`");
};
+ const auto has_type = [](Value *v) { return v->hasType(); };
+ if (!isDifferentiable(node)) {
+ throw std::runtime_error(std::string("differentiation of ") + node->kind().toString() + " "
+ "is not supported, or it is missing necessary type information");
+ }
+ if (!std::all_of(node->inputs().begin(), node->inputs().end(), has_type) ||
+ !std::all_of(node->outputs().begin(), node->outputs().end(), has_type)) {
+ throw std::runtime_error("differentiate should be called with a graph where every value "
+ "has a type registered");
+
+ }
auto sym_grads = build_sym_grad(fmap<SymbolicVariable>(grad_values));
return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
}
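
The kexpand case above undoes broadcasting by summing the gradient: first over the leading dimensions that expand added (keepdim=false), then over dimensions the input held at size 1 (keepdim=true). A standalone sketch of that size bookkeeping, using plain std::vector<int64_t> and printouts in place of the JIT's SymbolicVariable::sum calls; the shapes are made up for illustration:

// Standalone illustration of the kexpand backward size bookkeeping:
// leading dimensions added by expand are summed away (keepdim=false),
// then dimensions the input held at size 1 are reduced with keepdim=true.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> input_sizes{4, 1, 5};       // shape before expand (made up)
  std::vector<int64_t> grad_sizes{2, 3, 4, 6, 5};  // shape of the incoming gradient

  while (grad_sizes.size() > input_sizes.size()) {
    std::cout << "sum over dim 0, keepdim=false\n";
    grad_sizes.erase(grad_sizes.begin());
  }
  for (size_t i = 0; i < input_sizes.size(); ++i) {
    if (input_sizes[i] == 1 && grad_sizes[i] > 1) {
      std::cout << "sum over dim " << i << ", keepdim=true\n";
      grad_sizes[i] = 1;
    }
  }
  std::cout << "reduced gradient shape:";
  for (int64_t s : grad_sizes) std::cout << ' ' << s;
  std::cout << '\n';  // prints: 4 1 5
}
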
diff --git a/torch/csrc/jit/autodiff.h b/torch/csrc/jit/autodiff.h
index f9ba7e3d4d..bc09bfd793 100644
--- a/torch/csrc/jit/autodiff.h
+++ b/torch/csrc/jit/autodiff.h
@@ -79,10 +79,13 @@ struct Gradient {
// - Interpret df
// - Wrap outputs of df into Variables (that don't require grad)
};
+// XXX: When calling this function, graph should have complete type information.
+// Use the shape analysis pass to fill in the gaps if it doesn't.
Gradient differentiate(std::shared_ptr<Graph>& graph, const std::vector<bool>& requires_grad);
// can we take a derivative of this node symbolically?
bool isDifferentiable(Node * n);
+bool isDifferentiable(Graph & g);
bool isZero(Value * v);
}}
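
autodiff.cpp implements the new isDifferentiable(Graph&) overload with std::all_of plus a static_cast to bool(*)(Node*); the cast is what selects one function out of the overload set, since an overloaded name cannot be handed to an algorithm directly. A toy demonstration of the same idiom (the Node, Graph and isSupported names here are stand-ins, not the JIT types):

// Why the static_cast is needed: an overload set cannot be passed to
// std::all_of as-is, so one overload is picked with a function-pointer cast.
#include <algorithm>
#include <iostream>
#include <vector>

struct Node { bool ok; };
struct Graph { std::vector<Node*> nodes; };

bool isSupported(Node* n) { return n->ok; }
bool isSupported(Graph& g) {
  return std::all_of(g.nodes.begin(), g.nodes.end(),
                     static_cast<bool (*)(Node*)>(isSupported));
}

int main() {
  Node a{true}, b{false};
  Graph g{{&a, &b}};
  std::cout << std::boolalpha << isSupported(g) << '\n';  // false
}
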
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 91ac44f8e4..54ce6ed0b7 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -203,14 +203,6 @@ private:
}
return false;
}
- static bool isDifferentiable(Graph & g) {
- for(auto n : g.nodes()) {
- if(!jit::isDifferentiable(n))
- return false;
- }
- return true;
- }
-
void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables) {
// these optimizations must run in the presence of variables
@@ -297,7 +289,7 @@ private:
// 0 + a -> a
void propagateZeros(Graph & g) {
for(auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
- if(it->kind() == kadd && at::Scalar(it->t(kalpha)).toDouble() == 1.0) {
+ if(it->kind() == kadd && it->inputs().size() == 2 && at::Scalar(it->t(kalpha)).toDouble() == 1.0) {
if(isZero(it->inputs()[0])) {
it->output()->replaceAllUsesWith(it->inputs()[1]);
it.destroyCurrent();
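
The propagateZeros change only applies the 0 + a -> a rewrite when the add has exactly two inputs and alpha == 1, leaving other add forms untouched. A toy sketch of that guarded rewrite on a made-up Expr struct (none of these names come from the JIT):

// Toy "0 + a -> a" rewrite, guarded by an operand-count check that mirrors
// the added inputs().size() == 2 condition. Expr and knownZero are stand-ins.
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Expr {
  std::string op;                 // e.g. "add"
  std::vector<std::string> args;  // operand names
  double alpha = 1.0;             // scale applied to the second operand
};

std::string simplify(const Expr& e, const std::set<std::string>& knownZero) {
  if (e.op == "add" && e.args.size() == 2 && e.alpha == 1.0) {
    if (knownZero.count(e.args[0])) return e.args[1];
    if (knownZero.count(e.args[1])) return e.args[0];
  }
  return "";  // no simplification applies
}

int main() {
  std::set<std::string> zeros{"z"};
  Expr two{"add", {"z", "a"}, 1.0};
  Expr one{"add", {"a"}, 1.0};  // one-operand form: left alone
  std::cout << "add(z, a) -> " << simplify(two, zeros) << '\n';                       // a
  std::cout << "add(a)    -> " << (simplify(one, zeros).empty() ? "unchanged" : "?") << '\n';
}
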
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 403eb6f10f..6c2fa17854 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -63,6 +63,7 @@ GraphExecutor createExecutorByTracing(py::function func, std::vector<tracer::Tra
}
tracer::exit(outputs);
auto graph = enter_info.first->graph;
+ EliminateDeadCode(graph);
return createExecutorByGraph(std::move(graph), optimize);
}
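
createExecutorByTracing now runs EliminateDeadCode on the traced graph before an executor is built. As a rough standalone sketch of what such a pass does, the toy reverse sweep below keeps only nodes that transitively feed the graph outputs; the Node struct and name-based liveness are illustrative stand-ins, not the JIT implementation:

// Toy reverse-sweep dead code elimination over a straight-line node list.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {
  std::string output;               // name this node defines
  std::vector<std::string> inputs;  // names this node consumes
};

// Keep graph outputs and anything they transitively depend on.
std::vector<Node> eliminateDeadCode(std::vector<Node> nodes,
                                    const std::unordered_set<std::string>& graphOutputs) {
  std::unordered_set<std::string> live(graphOutputs);
  std::vector<Node> kept;
  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
    if (live.count(it->output)) {
      live.insert(it->inputs.begin(), it->inputs.end());
      kept.push_back(*it);
    }
  }
  return {kept.rbegin(), kept.rend()};  // restore original order
}

int main() {
  std::vector<Node> nodes{{"t1", {"x"}}, {"t2", {"x"}}, {"y", {"t1"}}};
  for (const auto& n : eliminateDeadCode(nodes, {"y"}))
    std::cout << n.output << '\n';  // t1, y  (t2 is dead and dropped)
}
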
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index 933f499a0e..b2cedcca00 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -73,6 +73,7 @@ _(dilation) \
_(broadcast) \
_(axis) \
_(size) \
+_(sizes) \
_(dim) \
_(perm) \
_(shape) \
@@ -129,6 +130,13 @@ _(sqrt) \
_(sub) \
_(tan) \
_(trunc) \
+_(squeeze) \
+_(unsqueeze) \
+_(view) \
+_(narrow) \
+_(sum) \
+_(length) \
+_(keepdim) \
_(zeros) \
_(zeros_like) \
_(exponent) \
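
interned_strings.h is an X-macro list, so appending _(squeeze), _(sum) and the other entries is enough to generate both the ksqueeze/ksum-style symbols and their string names. A minimal sketch of that X-macro pattern, assuming the usual enum-plus-string-table expansion rather than the header's exact machinery:

// Minimal X-macro sketch: one list generates both an enum and a string table.
#include <iostream>

#define FORALL_SYMBOLS(_) \
  _(squeeze)              \
  _(unsqueeze)            \
  _(view)                 \
  _(narrow)               \
  _(sum)

enum Symbol {
#define DEFINE_ENUM(s) k##s,
  FORALL_SYMBOLS(DEFINE_ENUM)
#undef DEFINE_ENUM
};

const char* toString(Symbol s) {
  switch (s) {
#define DEFINE_CASE(n) case k##n: return #n;
    FORALL_SYMBOLS(DEFINE_CASE)
#undef DEFINE_CASE
  }
  return "?";
}

int main() { std::cout << toString(ksum) << '\n'; }  // prints "sum"
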
diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
index 04bc957042..a09433a058 100644
--- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
+++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
@@ -28,7 +28,7 @@ void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
if(value_map.count(v) > 0) {
return value_map[v];
}
- Value * nv = new_graph->addInput();
+ Value * nv = new_graph->addInput()->setType(v->typeOption());
group_node->addInput(v);
value_map[v] = nv;
return nv;
@@ -54,7 +54,7 @@ void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
}
if(to_replace.size() > 0) {
new_graph->registerOutput(new_output);
- Value * external_output = group_node->addOutput();
+ Value * external_output = group_node->addOutput()->setType(old_output->typeOption());
for(auto u : to_replace) {
u.user->replaceInput(u.offset, external_output);
}
@@ -66,6 +66,7 @@ void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
for(size_t i = nodes.size(); i > 0; --i) {
nodes[i - 1]->destroy();
}
+ JIT_ASSERT(isDifferentiable(*new_graph));
}
}
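
The two setType calls make sure that the inputs and outputs created while folding nodes into an autodiff subgraph keep the type information of the values they mirror, which the new JIT_ASSERT(isDifferentiable(*new_graph)) and the type checks in differentiate rely on. A toy sketch of carrying an optional type across a graph boundary (std::optional<std::string> stands in for Value::typeOption(); none of this is JIT code):

// Toy sketch: when a value is lifted into a nested graph, copy its
// (possibly missing) type so shape-dependent passes keep working.
#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct Value {
  std::string name;
  std::optional<std::string> type;  // stand-in for Value::typeOption()
};

struct Graph {
  std::vector<Value> inputs;
  Value& addInput(const Value& outer) {
    // The important part: propagate the type, not just create a blank input.
    inputs.push_back(Value{outer.name, outer.type});
    return inputs.back();
  }
};

int main() {
  Value outer{"x", std::string("Float(2, 3)")};
  Graph subgraph;
  Value& inner = subgraph.addInput(outer);
  std::cout << inner.name << ": " << inner.type.value_or("<untyped>") << '\n';
}
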
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index 1debc86572..2678930727 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -6,20 +6,39 @@ namespace torch { namespace jit {
// catch peephole optimizations you might be interested in doing.
//
// Right now, it does:
-// - Redundant 'expand' elimination
+// - Eliminate no-op 'expand' nodes
+// - Simplify x.t().t() to x
//
// TODO: Decide what kind of fixed point strategy we will have
void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
auto* n = *it;
- // eliminate redundant expand
- if (n->kind() == kexpand) {
- if (n->is(ksize) == n->input()->type()->expect<TensorType>()->sizes()) {
- n->output()->replaceAllUsesWith(n->input());
- it.destroyCurrent();
- continue;
- }
+ switch (n->kind()) {
+ case kexpand:
+ // Eliminate redundant expand
+ if (!n->input()->hasType()) break;
+ if (n->is(ksize) == n->input()->type()->expect<TensorType>()->sizes()) {
+ n->output()->replaceAllUsesWith(n->input());
+ it.destroyCurrent();
+ }
+ break;
+ case kt:
+ // x.t().t() == x
+ auto input_node = n->input()->node();
+ if (input_node->kind() == kt) {
+ n->output()->replaceAllUsesWith(input_node->input());
+ it.destroyCurrent();
+ // The previous transpose might be unnecessary now.
+ if (input_node->output()->uses().size() == 0) {
+ if (*it == input_node) {
+ it.destroyCurrent();
+ } else {
+ input_node->destroy();
+ }
+ }
+ }
+ break;
}
}
}
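
The new kt case rewrites x.t().t() to x and then deletes the inner transpose once it has no remaining uses, taking care not to invalidate the loop iterator. A much-simplified standalone sketch of the cancellation itself, on a plain list of op names instead of JIT nodes (use counts and iterator handling are omitted):

// Toy double-transpose cancellation: adjacent "t","t" pairs collapse,
// mirroring the x.t().t() == x rewrite.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> cancelDoubleTranspose(const std::vector<std::string>& ops) {
  std::vector<std::string> out;
  for (const auto& op : ops) {
    if (op == "t" && !out.empty() && out.back() == "t") {
      out.pop_back();  // t(t(x)) == x: drop both transposes
    } else {
      out.push_back(op);
    }
  }
  return out;
}

int main() {
  std::vector<std::string> chain{"relu", "t", "t", "sigmoid", "t"};
  for (const auto& op : cancelDoubleTranspose(chain)) std::cout << op << ' ';
  std::cout << '\n';  // relu sigmoid t
}
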
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 3574fbfaf5..f2d404a85b 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -46,7 +46,7 @@ void PropagateShapeOnNode(Node * node) {
if(!input->hasType()) {
return SetUnknownType(node);
}
- if(TensorType * t = input->type()->expect<TensorType>()) {
+ if(TensorType * t = input->type()->cast<TensorType>()) {
types.push_back(t);
} else {
return SetUnknownType(node);
@@ -60,6 +60,70 @@ void PropagateShapeOnNode(Node * node) {
case kneg: {
node->output()->setType(types[0]->contiguous());
} break;
+ case kmm: {
+ auto lhs_type = types.at(0);
+ auto rhs_type = types.at(1);
+ node->output()->setType(std::make_shared<TensorType>(
+ lhs_type->scalarType(), lhs_type->device(),
+ at::IntList{lhs_type->sizes()[0], rhs_type->sizes()[1]}));
+ } break;
+ case kt: {
+ auto tp = types.at(0);
+ auto sizes = tp->sizes();
+ auto strides = tp->strides();
+ std::swap(sizes.at(0), sizes.at(1));
+ std::swap(strides.at(0), strides.at(1));
+ node->output()->setType(tp->withSizesStrides(sizes, strides));
+ } break;
+ case knarrow: {
+ auto tp = types.at(0);
+ auto sizes = tp->sizes();
+ int64_t dim = node->i(kdim);
+ int64_t length = node->i(klength);
+ sizes.at(dim) = length;
+ node->output()->setType(tp->withSizes(sizes));
+ } break;
+ case ksum: {
+ if (node->hasAttribute(kdim)) {
+ auto tp = types.at(0);
+ auto sizes = tp->sizes();
+ int64_t dim = node->i(kdim);
+ JIT_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
+ if (node->i(kkeepdim)) {
+ sizes[dim] = 1;
+ } else {
+ sizes.erase(sizes.begin() + dim);
+ }
+ node->output()->setType(tp->withSizes(sizes));
+ } else {
+ node->output()->setType(types.at(0)->withSizes({}));
+ }
+ } break;
+ case ksqueeze: {
+ auto tp = types.at(0);
+ auto sizes = tp->sizes();
+ auto strides = tp->strides();
+ int64_t dim = node->i(kdim);
+ JIT_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
+ if (sizes[dim] == 1) {
+ sizes.erase(sizes.begin() + dim);
+ strides.erase(strides.begin() + dim);
+ }
+ node->output()->setType(tp->withSizesStrides(sizes, strides));
+ } break;
+ case kunsqueeze: {
+ auto tp = types.at(0);
+ auto sizes = tp->sizes();
+ auto strides = tp->strides();
+ int64_t dim = node->i(kdim);
+ JIT_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size());
+ sizes.insert(sizes.begin() + dim, 1);
+ strides.insert(strides.begin() + dim, 1);
+ node->output()->setType(tp->withSizesStrides(sizes, strides));
+ } break;
+ case kview: {
+ node->output()->setType(types.at(0)->withSizes(node->is(ksizes)));
+ } break;
case kReplaceIfUndef: {
// If types[0] has a type, then it is not defined, and the type will
// get set to types[0] because that will be the value propagated.
@@ -73,7 +137,6 @@ void PropagateShapeOnNode(Node * node) {
node->output()->setType(nullptr);
} break;
default: {
-
auto op = getTensorOp(node);
std::vector<at::Tensor> inputs;
for(auto & type : types) {
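
Each new case in PropagateShapeOnNode is plain size arithmetic on the input's TensorType. A standalone sketch of the same rules for mm, narrow, sum, squeeze and unsqueeze, operating on bare std::vector<int64_t> values (made-up helper names, not the JIT API):

// Standalone size arithmetic mirroring the shape propagation rules above.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Sizes = std::vector<int64_t>;

Sizes mm(const Sizes& lhs, const Sizes& rhs) { return {lhs.at(0), rhs.at(1)}; }

Sizes narrow(Sizes s, size_t dim, int64_t length) { s.at(dim) = length; return s; }

Sizes sum(Sizes s, size_t dim, bool keepdim) {
  if (keepdim) s.at(dim) = 1; else s.erase(s.begin() + dim);
  return s;
}

Sizes squeeze(Sizes s, size_t dim) {
  if (s.at(dim) == 1) s.erase(s.begin() + dim);  // only size-1 dims disappear
  return s;
}

Sizes unsqueeze(Sizes s, size_t dim) {
  assert(dim <= s.size());
  s.insert(s.begin() + dim, 1);
  return s;
}

void print(const Sizes& s) {
  std::cout << '[';
  for (size_t i = 0; i < s.size(); ++i) std::cout << (i ? ", " : "") << s[i];
  std::cout << "]\n";
}

int main() {
  print(mm({2, 3}, {3, 5}));        // [2, 5]
  print(narrow({4, 7}, 1, 3));      // [4, 3]
  print(sum({2, 3, 4}, 1, false));  // [2, 4]
  print(squeeze({2, 1, 4}, 1));     // [2, 4]
  print(unsqueeze({2, 4}, 0));      // [1, 2, 4]
}
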
diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h
index cb6af330b6..fb7aa9bef5 100644
--- a/torch/csrc/jit/symbolic_variable.h
+++ b/torch/csrc/jit/symbolic_variable.h
@@ -18,6 +18,9 @@ struct SymbolicVariable {
static SymbolicVariable asNewInput(Graph & g, TypePtr type) {
return g.addInput()->setType(std::move(type));
}
+ const std::vector<int64_t>& sizes() {
+ return v->type()->expect<TensorType>()->sizes();
+ }
void addAsOutput() {
v->owningGraph()->registerOutput(v);
}
@@ -77,11 +80,12 @@ struct SymbolicVariable {
return create(kneg, {*this})[0].typeLike(*this);
}
SymbolicVariable mm(const SymbolicVariable rhs) const {
- // TODO: set types
- return create(s("mm"), {*this, rhs})[0];
+ auto r = create(s("mm"), {*this, rhs})[0];
+ return r;
}
SymbolicVariable t() const {
- return create(s("t"), {*this})[0];
+ auto r = create(s("t"), {*this})[0];
+ return r;
}
SymbolicVariable sigmoid() const {
return create(ksigmoid, {*this})[0].typeLike(*this);
@@ -93,7 +97,15 @@ struct SymbolicVariable {
Node * n;
auto r = create(s("chunk"), { *this }, chunks, &n);
n->i_(s("chunks"), chunks)
- ->i_(s("dim"), dim);
+ ->i_(s("dim"), dim);
+ return r;
+ }
+ SymbolicVariable narrow(int dim, int64_t start, int64_t length) const {
+ Node * n;
+ auto r = create(s("narrow"), { *this }, 1, &n)[0];
+ n->i_(s("dim"), dim)
+ ->i_(s("start"), start)
+ ->i_(s("length"), length);
return r;
}
static SymbolicVariable cat(ArrayRef<SymbolicVariable> inputs, int32_t dim) {
@@ -102,6 +114,35 @@ struct SymbolicVariable {
n->i_(kdim, dim);
return r;
}
+ SymbolicVariable sum() const {
+ auto r = create(s("sum"), {*this})[0];
+ return r;
+ }
+ SymbolicVariable sum(int dim, bool keepdim) const {
+ Node * n;
+ auto r = create(s("sum"), {*this}, 1, &n)[0];
+ n->i_(s("dim"), dim)
+ ->i_(s("keepdim"), keepdim);
+ return r;
+ }
+ SymbolicVariable squeeze(int dim) const {
+ Node * n;
+ auto r = create(s("squeeze"), {*this}, 1, &n)[0];
+ n->i_(s("dim"), dim);
+ return r;
+ }
+ SymbolicVariable unsqueeze(int dim) const {
+ Node * n;
+ auto r = create(s("unsqueeze"), {*this}, 1, &n)[0];
+ n->i_(s("dim"), dim);
+ return r;
+ }
+ SymbolicVariable view(std::vector<std::int64_t> sizes) const {
+ Node *n;
+ auto r = create(kview, {*this}, 1, &n)[0];
+ n->is_(s("size"), std::move(sizes));
+ return r;
+ }
Value * value() const {
return v;
}
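
The new SymbolicVariable helpers all follow the same pattern: create a node of the right kind, then chain i_/is_ attribute setters, which return the node so calls like n->i_(s("dim"), dim)->i_(s("start"), start) compose. A toy version of that fluent attribute-setter idiom (the Node struct below is a stand-in, not the JIT class):

// Toy version of the chained attribute-setter style used by
// SymbolicVariable::narrow/sum/view above.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

struct Node {
  std::string kind;
  std::map<std::string, int64_t> int_attrs;
  // Setter returns *this so calls can be chained.
  Node& i_(const std::string& name, int64_t value) {
    int_attrs[name] = value;
    return *this;
  }
};

int main() {
  Node n{"narrow"};
  n.i_("dim", 1).i_("start", 0).i_("length", 3);
  for (const auto& [k, v] : n.int_attrs) std::cout << k << " = " << v << '\n';
}
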
diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h
index 6bd09cfecc..d9e26b15a5 100644
--- a/torch/csrc/jit/type.h
+++ b/torch/csrc/jit/type.h
@@ -64,6 +64,8 @@ struct TensorType : public Type {
, device_(tensor.type().is_cuda() ? tensor.get_device() : -1)
, sizes_(tensor.sizes())
, strides_(tensor.strides()) {}
+ TensorType(at::ScalarType scalar_type, int device, at::IntList sizes)
+ : TensorType(scalar_type, device, sizes, TensorType::contiguousStridesOf(sizes)) {}
TensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides)
: Type(TypeKind::TensorType)
, scalar_type_(scalar_type)
@@ -84,16 +86,16 @@ struct TensorType : public Type {
}
TypePtr withSizes(at::IntList sizes) const {
- return withSizesStrides(sizes, contiguousStridesOf(sizes));
+ return withSizesStrides(sizes, TensorType::contiguousStridesOf(sizes));
}
TypePtr contiguous() const {
auto t = std::make_shared<TensorType>(*this);
- t->strides_ = contiguousStridesOf(sizes_);
+ t->strides_ = TensorType::contiguousStridesOf(sizes_);
return t;
}
private:
- std::vector<int64_t> contiguousStridesOf(at::IntList sizes) const {
+ static std::vector<int64_t> contiguousStridesOf(at::IntList sizes) {
std::vector<int64_t> strides(sizes.size());
strides.back() = 1;
for(std::size_t i = strides.size() - 1; i > 0; i--) {