diff options
author | Sam Gross <sgross@fb.com> | 2017-02-28 12:20:25 -0800 |
---|---|---|
committer | Sam Gross <sgross@fb.com> | 2017-03-03 11:26:00 -0800 |
commit | 34ce58c909543632a0ba41791c9a83f29580b23b (patch) | |
tree | dbc2f5ccb61ae6e160d2984525891d5c99f23ca4 /torch/csrc/autograd/engine.h | |
parent | c238ee368165b5cfd0ff59a5d6479cf6393c719b (diff) | |
download | pytorch-34ce58c909543632a0ba41791c9a83f29580b23b.tar.gz pytorch-34ce58c909543632a0ba41791c9a83f29580b23b.tar.bz2 pytorch-34ce58c909543632a0ba41791c9a83f29580b23b.zip |
Parallelize backwards
Diffstat (limited to 'torch/csrc/autograd/engine.h')
-rw-r--r-- | torch/csrc/autograd/engine.h | 27 |
1 files changed, 22 insertions, 5 deletions
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 94ce0cf2d2..269bacc85d 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -14,22 +14,39 @@ namespace torch { namespace autograd { +struct ReadyQueue; +struct FunctionTask; +struct BackwardTask; + struct Engine { + Engine(); + virtual ~Engine(); + using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, GradBuffer>>; using function_queue = std::vector<Function*>; using dependencies_type = std::unordered_map<Function*, int>; // Given a list of output variables and their gradients, computes the // gradients of "root" variables by backpropagation. - static void backward( + void backward( const variable_list& variables, tensor_list& grad_variables, bool retain_variables); -private: - static dependencies_type compute_dependencies( - function_queue queue, - ready_queue_type& ready); +protected: + function_queue find_creators( + const variable_list& variables, + tensor_list& grad_variables, + BackwardTask& task); + void find_stochastic_functions(function_queue& queue, BackwardTask& task); + void compute_dependencies(function_queue queue, BackwardTask& task); + void evaluate_function(FunctionTask& task); + ReadyQueue& ready_queue(int device); + void start_threads(); + virtual void thread_main(ReadyQueue& queue); + virtual void thread_on_exception(FunctionTask& task, std::exception& e); + + std::vector<std::unique_ptr<ReadyQueue>> ready_queues; }; }} // namespace torch::autograd |