summary refs log tree commit diff
path: root/torch/csrc/autograd/engine.h
diff options
context:
space:
mode:
author    Sam Gross <sgross@fb.com>    2017-02-28 12:20:25 -0800
committer Sam Gross <sgross@fb.com>    2017-03-03 11:26:00 -0800
commit 34ce58c909543632a0ba41791c9a83f29580b23b (patch)
tree   dbc2f5ccb61ae6e160d2984525891d5c99f23ca4 /torch/csrc/autograd/engine.h
parent c238ee368165b5cfd0ff59a5d6479cf6393c719b (diff)
download pytorch-34ce58c909543632a0ba41791c9a83f29580b23b.tar.gz
pytorch-34ce58c909543632a0ba41791c9a83f29580b23b.tar.bz2
pytorch-34ce58c909543632a0ba41791c9a83f29580b23b.zip
Parallelize backwards
Diffstat (limited to 'torch/csrc/autograd/engine.h')
-rw-r--r--  torch/csrc/autograd/engine.h | 27
1 file changed, 22 insertions(+), 5 deletions(-)
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h
index 94ce0cf2d2..269bacc85d 100644
--- a/torch/csrc/autograd/engine.h
+++ b/torch/csrc/autograd/engine.h
@@ -14,22 +14,39 @@
namespace torch { namespace autograd {
+struct ReadyQueue;
+struct FunctionTask;
+struct BackwardTask;
+
struct Engine {
+ Engine();
+ virtual ~Engine();
+
using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, GradBuffer>>;
using function_queue = std::vector<Function*>;
using dependencies_type = std::unordered_map<Function*, int>;
// Given a list of output variables and their gradients, computes the
// gradients of "root" variables by backpropagation.
- static void backward(
+ void backward(
const variable_list& variables,
tensor_list& grad_variables,
bool retain_variables);
-private:
- static dependencies_type compute_dependencies(
- function_queue queue,
- ready_queue_type& ready);
+protected:
+ function_queue find_creators(
+ const variable_list& variables,
+ tensor_list& grad_variables,
+ BackwardTask& task);
+ void find_stochastic_functions(function_queue& queue, BackwardTask& task);
+ void compute_dependencies(function_queue queue, BackwardTask& task);
+ void evaluate_function(FunctionTask& task);
+ ReadyQueue& ready_queue(int device);
+ void start_threads();
+ virtual void thread_main(ReadyQueue& queue);
+ virtual void thread_on_exception(FunctionTask& task, std::exception& e);
+
+ std::vector<std::unique_ptr<ReadyQueue>> ready_queues;
};
}} // namespace torch::autograd