author    Adam Paszke <adam.paszke@gmail.com>       2018-02-14 16:59:48 +0100
committer GitHub <noreply@github.com>               2018-02-14 16:59:48 +0100
commit    8910dd5a81cdddb2612ac1c9b93a202fcf3a6e60 (patch)
tree      6cf5d39cca17fecae252be3a0bee629443a78834 /torch
parent    318ae2085a37852a640bd4664b4f485c92b0cbe5 (diff)
Fix GraphExecutor and add more AD formulas (#5215)
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/autodiff.cpp                         | 87
-rw-r--r-- | torch/csrc/jit/autodiff.h                           |  3
-rw-r--r-- | torch/csrc/jit/graph_executor.cpp                   | 10
-rw-r--r-- | torch/csrc/jit/init.cpp                             |  1
-rw-r--r-- | torch/csrc/jit/interned_strings.h                   |  8
-rw-r--r-- | torch/csrc/jit/passes/create_autodiff_subgraphs.cpp |  5
-rw-r--r-- | torch/csrc/jit/passes/peephole.cpp                  | 35
-rw-r--r-- | torch/csrc/jit/passes/shape_analysis.cpp            | 67
-rw-r--r-- | torch/csrc/jit/symbolic_variable.h                  | 49
-rw-r--r-- | torch/csrc/jit/type.h                               |  8
10 files changed, 243 insertions, 30 deletions
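Most of the new code below is the set of symbolic AD formulas added to autodiff.cpp. The `expand` gradient, for example, has to undo broadcasting: it first sums away the leading dimensions that `expand` prepended, then sums (keeping the dimension) wherever the input had size 1. The following standalone C++ sketch is not part of the commit; it mirrors that reduction logic on plain size vectors, with hypothetical shapes chosen only for illustration:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical forward pass: a [3, 1] tensor was expanded to [2, 3, 4].
  std::vector<std::int64_t> input_sizes = {3, 1};
  std::vector<std::int64_t> grad_sizes  = {2, 3, 4};

  // Step 1: sum away the leading dimensions that expand() prepended.
  while (grad_sizes.size() > input_sizes.size()) {
    std::cout << "sum over dim 0, dropping it\n";
    grad_sizes.erase(grad_sizes.begin());
  }
  // Step 2: sum with keepdim=true over dimensions broadcast from size 1.
  for (std::size_t i = 0; i < input_sizes.size(); ++i) {
    if (input_sizes[i] == 1 && grad_sizes[i] > 1) {
      std::cout << "sum over dim " << i << ", keeping it\n";
      grad_sizes[i] = 1;
    }
  }
  // grad_sizes now equals input_sizes, i.e. {3, 1}.
  std::cout << "final grad shape: ";
  for (std::int64_t s : grad_sizes) std::cout << s << " ";
  std::cout << "\n";
  return 0;
}
```

The kexpand case in gradientForNode performs the same two steps, but on SymbolicVariables, so each reduction becomes a sum node emitted into the gradient graph.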
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 8ea46dc407..377ed243e9 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -5,20 +5,30 @@
 #include "torch/csrc/utils/functional.h"
 #include "torch/csrc/utils/auto_gpu.h"
 
+#include <algorithm>
+
 namespace torch { namespace jit {
 
 using value_map = std::unordered_map<Value*, Value*>;
 using value_set = std::unordered_set<Value*>;
 
+// TODO: unsqueeze!
 std::unordered_set<Symbol> differentiable_kinds = {
   kadd, ksub, kmul, kConstant, kReplaceIfUndef,
-  ksigmoid, ktanh, kmm, kchunk, ksplit, kt
+  ksigmoid, ktanh, kmm, kchunk, ksplit, kt, kneg,
+  kunsqueeze
 };
 
 bool isDifferentiable(Node * n) {
   return differentiable_kinds.count(n->kind()) > 0;
 }
 
+bool isDifferentiable(Graph & g) {
+  return std::all_of(g.nodes().begin(), g.nodes().end(),
+                     static_cast<bool(*)(Node*)>(isDifferentiable));
+}
+
+
 static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_values) {
   const auto build_sym_grad = [node](const std::vector<SymbolicVariable>& grads) -> std::vector<SymbolicVariable> {
     auto inputs = fmap<SymbolicVariable>(node->inputs());
@@ -55,7 +65,13 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
         return {SymbolicVariable::cat(grads, node->i(kdim))};
       case kt:
         return {grads.at(0).t()};
-      case kmm:
+      case kneg:
+        return {-grads.at(0)};
+      case kview:
+        return {grads.at(0).view(inputs.at(0).sizes())};
+      case kunsqueeze:
+        return {grads.at(0).squeeze(node->i(kdim))};
+      case kmm: {
         SymbolicVariable dmat1, dmat2;
         if (inputs.at(0).value()->hasType()) {
           auto type = inputs.at(0).value()->type()->expect<TensorType>();
@@ -80,10 +96,77 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
           dmat2 = grads.at(0).mm(inputs.at(1).t());
         }
         return {dmat1, dmat2};
+      }
+      case kexpand: {
+        const auto& input_sizes = inputs.at(0).sizes();
+        if (input_sizes.size() == 0)
+          return {grads.at(0).sum()};
+        auto grad_sizes = node->is(ksize);
+        auto grad = grads.at(0);
+        while (grad_sizes.size() > input_sizes.size()) {
+          grad = grad.sum(0, false);
+          grad_sizes.erase(grad_sizes.begin());
+        }
+        for (size_t i = 0; i < input_sizes.size(); ++i) {
+          if (input_sizes[i] == 1 && grad_sizes[i] > 1) {
+            grad = grad.sum(i, true);
+          }
+        }
+        return {grad};
+      }
+      case ksqueeze: {
+        const auto& sizes = inputs.at(0).sizes();
+        if (node->hasAttribute(kdim)) {
+          int dim = node->i(kdim);
+          return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim)};
+        } else {
+          std::vector<size_t> squeezed_dims;
+          for (size_t i = 0; i < sizes.size(); ++i) {
+            if (sizes[i] != 1) continue;
+            squeezed_dims.push_back(i);
+          }
+          SymbolicVariable returned_grad = grads.at(0);
+          for (auto it = squeezed_dims.rbegin(); it != squeezed_dims.rend(); ++it)
+            returned_grad = returned_grad.unsqueeze(*it);
+          return {returned_grad};
+        }
+      }
+      case kcat: {
+        int dim = node->i(kdim);
+        const auto& first_sizes = inputs.at(0).sizes();
+        const auto has_first_sizes = [&first_sizes](SymbolicVariable var) {
+          return var.sizes() == first_sizes;
+        };
+        // NB: this is a specialization for the common case where all inputs are
+        // of equal sizes. We can use a single split operation to handle that.
+        if (std::all_of(inputs.begin(), inputs.end(), has_first_sizes)) {
+          return grads.at(0).chunk(inputs.size(), dim);
+        } else {
+          size_t offset = 0;
+          auto grad = grads.at(0);
+          std::vector<SymbolicVariable> returned_grads;
+          for (auto input : inputs) {
+            returned_grads.push_back(grad.narrow(dim, offset, input.sizes()[dim]));
+            offset += input.sizes()[dim];
+          }
+          return returned_grads;
+        }
+      }
     }
     throw std::runtime_error(std::string("don't support differentiation of `") +
                              node->kind().toString() + "`");
   };
+  const auto has_type = [](Value *v) { return v->hasType(); };
+  if (!isDifferentiable(node)) {
+    throw std::runtime_error(std::string("differentiation of ") + node->kind().toString() + " "
+                             "is not supported, or it is missing necessary type information");
+  }
+  if (!std::all_of(node->inputs().begin(), node->inputs().end(), has_type) ||
+      !std::all_of(node->outputs().begin(), node->outputs().end(), has_type)) {
+    throw std::runtime_error("differentiate should be called with a graph where every value "
+                             "has a type registered");
+
+  }
   auto sym_grads = build_sym_grad(fmap<SymbolicVariable>(grad_values));
   return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
 }
diff --git a/torch/csrc/jit/autodiff.h b/torch/csrc/jit/autodiff.h
index f9ba7e3d4d..bc09bfd793 100644
--- a/torch/csrc/jit/autodiff.h
+++ b/torch/csrc/jit/autodiff.h
@@ -79,10 +79,13 @@ struct Gradient {
   //   - Interpret df
   //   - Wrap outputs of df into Variables (that don't require grad)
 };
 
+// XXX: When calling this function, graph should have complete type information.
+// Use the shape analysis pass to fill in the gaps if it doesn't.
 Gradient differentiate(std::shared_ptr<Graph>& graph, const std::vector<bool>& requires_grad);
 
 // can we take a derivative of this node symbolically?
 bool isDifferentiable(Node * n);
+bool isDifferentiable(Graph & g);
 bool isZero(Value * v);
 
 }}
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 91ac44f8e4..54ce6ed0b7 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -203,14 +203,6 @@ private:
     }
     return false;
   }
 
-  static bool isDifferentiable(Graph & g) {
-    for(auto n : g.nodes()) {
-      if(!jit::isDifferentiable(n))
-        return false;
-    }
-    return true;
-  }
-
   void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables) {
     // these optimizations must run in the presence of variables
@@ -297,7 +289,7 @@ private:
   // 0 + a -> a
   void propagateZeros(Graph & g) {
     for(auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
-      if(it->kind() == kadd && at::Scalar(it->t(kalpha)).toDouble() == 1.0) {
+      if(it->kind() == kadd && it->inputs().size() == 2 && at::Scalar(it->t(kalpha)).toDouble() == 1.0) {
         if(isZero(it->inputs()[0])) {
           it->output()->replaceAllUsesWith(it->inputs()[1]);
           it.destroyCurrent();
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 403eb6f10f..6c2fa17854 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -63,6 +63,7 @@ GraphExecutor createExecutorByTracing(py::function func, std::vector<tracer::Tra
   }
   tracer::exit(outputs);
   auto graph = enter_info.first->graph;
+  EliminateDeadCode(graph);
   return createExecutorByGraph(std::move(graph), optimize);
 }
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index 933f499a0e..b2cedcca00 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -73,6 +73,7 @@ _(dilation) \
 _(broadcast) \
 _(axis) \
 _(size) \
+_(sizes) \
 _(dim) \
 _(perm) \
 _(shape) \
@@ -129,6 +130,13 @@ _(sqrt) \
 _(sub) \
 _(tan) \
 _(trunc) \
+_(squeeze) \
+_(unsqueeze) \
+_(view) \
+_(narrow) \
+_(sum) \
+_(length) \
+_(keepdim) \
 _(zeros) \
 _(zeros_like) \
 _(exponent) \
diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
index 04bc957042..a09433a058 100644
--- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
+++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
@@ -28,7 +28,7 @@ void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
     if(value_map.count(v) > 0) {
       return value_map[v];
     }
-    Value * nv = new_graph->addInput();
+    Value * nv = new_graph->addInput()->setType(v->typeOption());
     group_node->addInput(v);
     value_map[v] = nv;
     return nv;
@@ -54,7 +54,7 @@ void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
     }
     if(to_replace.size() > 0) {
       new_graph->registerOutput(new_output);
-      Value * external_output = group_node->addOutput();
+      Value * external_output = group_node->addOutput()->setType(old_output->typeOption());
       for(auto u : to_replace) {
         u.user->replaceInput(u.offset, external_output);
       }
@@ -66,6 +66,7 @@ void mergeNodes(Graph & g, Symbol group_node_kind, ArrayRef<Node*> nodes) {
   for(size_t i = nodes.size(); i > 0; --i) {
     nodes[i - 1]->destroy();
   }
+  JIT_ASSERT(isDifferentiable(*new_graph));
 }
 
 }
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index 1debc86572..2678930727 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -6,20 +6,39 @@ namespace torch { namespace jit {
 // catch peephole optimizations you might be interested in doing.
 //
 // Right now, it does:
-// - Redundant 'expand' elimination
+// - Eliminate no-op 'expand' nodes
+// - Simply x.t().t() to x
 //
 // TODO: Decide what kind of fixed point strategy we will have
 void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
   for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
     auto* n = *it;
 
-    // eliminate redundant expand
-    if (n->kind() == kexpand) {
-      if (n->is(ksize) == n->input()->type()->expect<TensorType>()->sizes()) {
-        n->output()->replaceAllUsesWith(n->input());
-        it.destroyCurrent();
-        continue;
-      }
+    switch (n->kind()) {
+      case kexpand:
+        // Eliminate redundant expand
+        if (!n->input()->hasType()) break;
+        if (n->is(ksize) == n->input()->type()->expect<TensorType>()->sizes()) {
+          n->output()->replaceAllUsesWith(n->input());
+          it.destroyCurrent();
+        }
+        break;
+      case kt:
+        // x.t().t() == x
+        auto input_node = n->input()->node();
+        if (input_node->kind() == kt) {
+          n->output()->replaceAllUsesWith(input_node->input());
+          it.destroyCurrent();
+          // The previous transpose might be unnecessary now.
+          if (input_node->output()->uses().size() == 0) {
+            if (*it == input_node) {
+              it.destroyCurrent();
+            } else {
+              input_node->destroy();
+            }
+          }
+        }
+        break;
     }
   }
 }
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 3574fbfaf5..f2d404a85b 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -46,7 +46,7 @@ void PropagateShapeOnNode(Node * node) {
     if(!input->hasType()) {
       return SetUnknownType(node);
     }
-    if(TensorType * t = input->type()->expect<TensorType>()) {
+    if(TensorType * t = input->type()->cast<TensorType>()) {
       types.push_back(t);
     } else {
       return SetUnknownType(node);
@@ -60,6 +60,70 @@ void PropagateShapeOnNode(Node * node) {
     case kneg: {
       node->output()->setType(types[0]->contiguous());
     } break;
+    case kmm: {
+      auto lhs_type = types.at(0);
+      auto rhs_type = types.at(1);
+      node->output()->setType(std::make_shared<TensorType>(
+        lhs_type->scalarType(), lhs_type->device(),
+        at::IntList{lhs_type->sizes()[0], rhs_type->sizes()[1]}));
+    } break;
+    case kt: {
+      auto tp = types.at(0);
+      auto sizes = tp->sizes();
+      auto strides = tp->strides();
+      std::swap(sizes.at(0), sizes.at(1));
+      std::swap(strides.at(0), strides.at(1));
+      node->output()->setType(tp->withSizesStrides(sizes, strides));
+    } break;
+    case knarrow: {
+      auto tp = types.at(0);
+      auto sizes = tp->sizes();
+      int64_t dim = node->i(kdim);
+      int64_t length = node->i(klength);
+      sizes.at(dim) = length;
+      node->output()->setType(tp->withSizes(sizes));
+    } break;
+    case ksum: {
+      if (node->hasAttribute(kdim)) {
+        auto tp = types.at(0);
+        auto sizes = tp->sizes();
+        int64_t dim = node->i(kdim);
+        JIT_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
+        if (node->i(kkeepdim)) {
+          sizes[dim] = 1;
+        } else {
+          sizes.erase(sizes.begin() + dim);
+        }
+        node->output()->setType(tp->withSizes(sizes));
+      } else {
+        node->output()->setType(types.at(0)->withSizes({}));
+      }
+    } break;
+    case ksqueeze: {
+      auto tp = types.at(0);
+      auto sizes = tp->sizes();
+      auto strides = tp->strides();
+      int64_t dim = node->i(kdim);
+      JIT_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
+      if (sizes[dim] == 1) {
+        sizes.erase(sizes.begin() + dim);
+        strides.erase(strides.begin() + dim);
+      }
+      node->output()->setType(tp->withSizesStrides(sizes, strides));
+    } break;
+    case kunsqueeze: {
+      auto tp = types.at(0);
+      auto sizes = tp->sizes();
+      auto strides = tp->strides();
+      int64_t dim = node->i(kdim);
+      JIT_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size());
+      sizes.insert(sizes.begin() + dim, 1);
+      strides.insert(strides.begin() + dim, 1);
+      node->output()->setType(tp->withSizesStrides(sizes, strides));
+    } break;
+    case kview: {
+      node->output()->setType(types.at(0)->withSizes(node->is(ksizes)));
+    } break;
     case kReplaceIfUndef: {
       // If types[0] has a type, then it is not defined, and the type will
       // get set to types[0] because that will be the value propagated.
@@ -73,7 +137,6 @@ void PropagateShapeOnNode(Node * node) {
       node->output()->setType(nullptr);
     } break;
     default: {
-      auto op = getTensorOp(node);
       std::vector<at::Tensor> inputs;
       for(auto & type : types) {
diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h
index cb6af330b6..fb7aa9bef5 100644
--- a/torch/csrc/jit/symbolic_variable.h
+++ b/torch/csrc/jit/symbolic_variable.h
@@ -18,6 +18,9 @@ struct SymbolicVariable {
   static SymbolicVariable asNewInput(Graph & g, TypePtr type) {
     return g.addInput()->setType(std::move(type));
   }
+  const std::vector<int64_t>& sizes() {
+    return v->type()->expect<TensorType>()->sizes();
+  }
   void addAsOutput() {
     v->owningGraph()->registerOutput(v);
   }
@@ -77,11 +80,12 @@ struct SymbolicVariable {
     return create(kneg, {*this})[0].typeLike(*this);
   }
   SymbolicVariable mm(const SymbolicVariable rhs) const {
-    // TODO: set types
-    return create(s("mm"), {*this, rhs})[0];
+    auto r = create(s("mm"), {*this, rhs})[0];
+    return r;
   }
   SymbolicVariable t() const {
-    return create(s("t"), {*this})[0];
+    auto r = create(s("t"), {*this})[0];
+    return r;
   }
   SymbolicVariable sigmoid() const {
     return create(ksigmoid, {*this})[0].typeLike(*this);
@@ -93,7 +97,15 @@ struct SymbolicVariable {
     Node * n;
     auto r = create(s("chunk"), { *this }, chunks, &n);
     n->i_(s("chunks"), chunks)
-    ->i_(s("dim"), dim);
+     ->i_(s("dim"), dim);
+    return r;
+  }
+  SymbolicVariable narrow(int dim, int64_t start, int64_t length) const {
+    Node * n;
+    auto r = create(s("narrow"), { *this }, 1, &n)[0];
+    n->i_(s("dim"), dim)
+     ->i_(s("start"), start)
+     ->i_(s("length"), length);
     return r;
   }
   static SymbolicVariable cat(ArrayRef<SymbolicVariable> inputs, int32_t dim) {
@@ -102,6 +114,35 @@ struct SymbolicVariable {
     n->i_(kdim, dim);
     return r;
   }
+  SymbolicVariable sum() const {
+    auto r = create(s("sum"), {*this})[0];
+    return r;
+  }
+  SymbolicVariable sum(int dim, bool keepdim) const {
+    Node * n;
+    auto r = create(s("sum"), {*this}, 1, &n)[0];
+    n->i_(s("dim"), dim)
+     ->i_(s("keepdim"), keepdim);
+    return r;
+  }
+  SymbolicVariable squeeze(int dim) const {
+    Node * n;
+    auto r = create(s("squeeze"), {*this}, 1, &n)[0];
+    n->i_(s("dim"), dim);
+    return r;
+  }
+  SymbolicVariable unsqueeze(int dim) const {
+    Node * n;
+    auto r = create(s("unsqueeze"), {*this}, 1, &n)[0];
+    n->i_(s("dim"), dim);
+    return r;
+  }
+  SymbolicVariable view(std::vector<std::int64_t> sizes) const {
+    Node *n;
+    auto r = create(kview, {*this}, 1, &n)[0];
+    n->is_(s("size"), std::move(sizes));
+    return r;
+  }
   Value * value() const {
     return v;
   }
diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h
index 6bd09cfecc..d9e26b15a5 100644
--- a/torch/csrc/jit/type.h
+++ b/torch/csrc/jit/type.h
@@ -64,6 +64,8 @@ struct TensorType : public Type {
     , device_(tensor.type().is_cuda() ? tensor.get_device() : -1)
     , sizes_(tensor.sizes())
     , strides_(tensor.strides()) {}
+  TensorType(at::ScalarType scalar_type, int device, at::IntList sizes)
+    : TensorType(scalar_type, device, sizes, TensorType::contiguousStridesOf(sizes)) {}
   TensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides)
     : Type(TypeKind::TensorType)
     , scalar_type_(scalar_type)
@@ -84,16 +86,16 @@ struct TensorType : public Type {
   }
 
   TypePtr withSizes(at::IntList sizes) const {
-    return withSizesStrides(sizes, contiguousStridesOf(sizes));
+    return withSizesStrides(sizes, TensorType::contiguousStridesOf(sizes));
   }
 
   TypePtr contiguous() const {
     auto t = std::make_shared<TensorType>(*this);
-    t->strides_ = contiguousStridesOf(sizes_);
+    t->strides_ = TensorType::contiguousStridesOf(sizes_);
    return t;
   }
 private:
-  std::vector<int64_t> contiguousStridesOf(at::IntList sizes) const {
+  static std::vector<int64_t> contiguousStridesOf(at::IntList sizes) {
    std::vector<int64_t> strides(sizes.size());
    strides.back() = 1;
    for(std::size_t i = strides.size() - 1; i > 0; i--) {
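The page cuts the type.h diff off inside contiguousStridesOf, which this commit turns into a static helper so the new three-argument TensorType constructor can call it. As a reference for what that helper computes (an illustrative standalone sketch with a hypothetical name, not a reconstruction of the truncated lines), row-major contiguous strides can be derived from sizes like this:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Row-major (contiguous) strides: the last dimension has stride 1, and each
// earlier stride is the next stride multiplied by the next size.
static std::vector<std::int64_t> contiguousStrides(const std::vector<std::int64_t>& sizes) {
  std::vector<std::int64_t> strides(sizes.size());
  if (sizes.empty()) return strides;
  strides.back() = 1;
  for (std::size_t i = strides.size() - 1; i > 0; i--) {
    strides[i - 1] = strides[i] * sizes[i];
  }
  return strides;
}

int main() {
  // A contiguous [2, 3, 4] tensor has strides [12, 4, 1].
  for (std::int64_t s : contiguousStrides({2, 3, 4})) {
    std::cout << s << " ";
  }
  std::cout << "\n";
  return 0;
}
```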