-rw-r--r--  torch/autograd/__init__.py                |   3
-rw-r--r--  torch/autograd/_functions/__init__.py     |   1
-rw-r--r--  torch/autograd/_functions/stochastic.py   |  38
-rw-r--r--  torch/autograd/stochastic_function.py     |  44
-rw-r--r--  torch/autograd/variable.py                |  10
-rw-r--r--  torch/csrc/autograd/engine.cpp            |  84
-rw-r--r--  torch/csrc/autograd/engine.h              |   8
-rw-r--r--  torch/csrc/autograd/function.h            |   3
-rw-r--r--  torch/csrc/autograd/init.cpp              |   3
-rw-r--r--  torch/csrc/autograd/python_function.cpp   |   2
-rw-r--r--  torch/csrc/autograd/python_function.h     |   1
11 files changed, 29 insertions, 168 deletions
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 5cc7502021..81a1e79755 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -9,11 +9,10 @@ import warnings
 
 from .variable import Variable
 from .function import Function, NestedIOFunction
-from .stochastic_function import StochasticFunction
 from .gradcheck import gradcheck
 from . import profiler
 
-__all__ = ['Variable', 'Function', 'StochasticFunction', 'backward']
+__all__ = ['Variable', 'Function', 'backward']
 
 
 def _make_grads(outputs, grads, user_create_graph):
diff --git a/torch/autograd/_functions/__init__.py b/torch/autograd/_functions/__init__.py
index 35429e34ac..2d0ab21bd2 100644
--- a/torch/autograd/_functions/__init__.py
+++ b/torch/autograd/_functions/__init__.py
@@ -4,6 +4,5 @@ from .pointwise import *
 from .reduce import *
 from .linalg import *
 from .blas import *
-from .stochastic import *
 from .compare import *
 from .initializers import *
diff --git a/torch/autograd/_functions/stochastic.py b/torch/autograd/_functions/stochastic.py
deleted file mode 100644
index 6ef22a4c35..0000000000
--- a/torch/autograd/_functions/stochastic.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-from ..function import Function
-
-
-class Categorical(Function):
-    @staticmethod
-    def forward(ctx, probs, num_samples, with_replacement):
-        samples = probs.multinomial(num_samples, with_replacement)
-        ctx.mark_non_differentiable(samples)
-        return samples
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return None, None, None
-
-
-class Bernoulli(Function):
-    @staticmethod
-    def forward(ctx, probs):
-        samples = probs.new().resize_as_(probs).bernoulli_(probs)
-        ctx.mark_non_differentiable(samples)
-        return samples
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return None
-
-
-class Normal(Function):
-    @staticmethod
-    def forward(ctx, means, stddevs=None):
-        samples = torch.normal(means, stddevs)
-        ctx.mark_non_differentiable(samples)
-        return samples
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return None, None
diff --git a/torch/autograd/stochastic_function.py b/torch/autograd/stochastic_function.py
deleted file mode 100644
index 84b835cd0b..0000000000
--- a/torch/autograd/stochastic_function.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import torch
-from numbers import Number
-from .function import Function
-
-_NOT_PROVIDED = object()
-
-
-class StochasticFunction(Function):
-
-    def __init__(self):
-        self.reward = _NOT_PROVIDED
-
-    def _do_backward(self, grad_output, retain_variables):
-        if self.reward is _NOT_PROVIDED:
-            raise RuntimeError("differentiating stochastic functions requires "
-                               "providing a reward")
-        result = super(StochasticFunction, self)._do_backward((self.reward,), retain_variables)
-        if not retain_variables:
-            self.reward = None
-        return result
-
-    def _do_forward(self, *inputs):
-        result = super(StochasticFunction, self)._do_forward(*inputs)
-        # save output type and size, to check the type of reward
-        assert isinstance(result, torch.autograd.Variable), \
-            "stochastic functions support only a single output at the moment"
-        self.reward_info = (type(inputs[0].data), result.size())
-        return result
-
-    __call__ = _do_forward
-
-    def _reinforce(self, reward):
-        is_number = isinstance(reward, Number)
-        if not is_number and type(reward) != self.reward_info[0]:
-            raise TypeError("mismatch between reward and output type: got {}, "
-                            "but expected {}".format(torch.typename(reward),
-                                                     torch.typename(self.reward_info[0])))
-        if not is_number and reward.size() != self.reward_info[1]:
-            raise ValueError("got reward of size {}, but expected a tensor of size {}".format(
-                'x'.join(map(str, reward.size())),
-                'x'.join(map(str, self.reward_info[1]))))
-        if self.reward is not _NOT_PROVIDED:
-            raise RuntimeError("you can only reinforce a stochastic Function once")
-        self.reward = reward
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index 256fa1a39f..d8378320db 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -374,10 +374,10 @@ class Variable(_C._VariableBase):
         return self.expand(tensor.size())
 
     def multinomial(self, num_samples=1, replacement=False):
-        return Categorical.apply(self, num_samples, replacement)
+        return Variable(torch.multinomial(self.data, num_samples, replacement))
 
     def bernoulli(self):
-        return Bernoulli.apply(self)
+        return Variable(torch.bernoulli(self.data))
 
     def __rsub__(self, other):
         return -self + other
@@ -432,7 +432,11 @@ class Variable(_C._VariableBase):
     class _torch(object):
         @staticmethod
         def normal(means, std=1):
-            return Normal.apply(means, std)
+            if isinstance(means, Variable):
+                means = means.data
+            if isinstance(std, Variable):
+                std = std.data
+            return Variable(torch.normal(means, std))
 
 
 for method in dir(Variable):
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 602653d54a..7b89de9e82 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -64,7 +64,6 @@ struct GraphTask {
   std::atomic_bool has_error;
   std::atomic<uint64_t> outstanding_tasks;
   bool keep_graph;
-  bool has_any_work;
 
   std::mutex mutex;
   // Notified when a task finishes executing. Check outstanding_tasks to see
@@ -82,7 +81,6 @@ struct GraphTask {
     , has_error(false)
    , outstanding_tasks(0)
     , keep_graph(keep_graph)
-    , has_any_work(false)
     , mutex()
     , not_done()
     , pre_callbacks(pre_callbacks)
@@ -234,22 +232,17 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
   }
 
   int num_outputs = outputs.size();
+  if (num_outputs == 0) return; // Don't even acquire the mutex
+  std::lock_guard<std::mutex> lock(task.base->mutex);
   for (int i = 0; i < num_outputs; ++i) {
     auto& output = outputs[i];
     auto& next_fn = fn.next_functions[i].first;
     int input_nr = fn.next_functions[i].second;
 
-    if (!next_fn) {
+    if (!next_fn || !next_fn->is_executable) {
       continue;
     }
 
-    // Stochastic functions are placed in the ready queue by
-    // compute_dependencies, so we have to skip them here.
-    if (next_fn->is_stochastic || !next_fn->is_executable) {
-      continue;
-    }
-
-    std::lock_guard<std::mutex> lock(task.base->mutex);
     // Check if the next function is ready to be computed
     bool is_ready = false;
     auto& dependencies = task.base->dependencies;
@@ -287,49 +280,24 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
   }
 }
 
-/** Finds all stochastic functions and appends them to the queue */
-auto Engine::find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task) -> void {
-  std::unordered_set<Function*> seen {graph_root};
-  function_queue search_queue {graph_root};
-  while (search_queue.size() > 0) {
-    auto fn = search_queue.back(); search_queue.pop_back();
-    for (auto& next_fn_pair : fn->next_functions) {
-      auto& next_fn = next_fn_pair.first;
-      Function* next_ptr = next_fn.get();
-      if (!next_ptr) continue;
-      if (next_ptr->is_stochastic && next_ptr->is_executable && seen.count(next_ptr) == 0) {
-        ready_queue(-1).push_front(FunctionTask(&task, next_fn, InputBuffer(0)));
-        queue.push_back(next_ptr);
-        task.has_any_work = true;
-      }
-      if (seen.count(next_ptr) == 0) {
-        seen.insert(next_ptr);
-        search_queue.push_back(next_ptr);
-      }
-    }
-  }
-}
-
-/** Computes the number of dependencies for each function which requires grad */
-auto Engine::compute_dependencies(function_queue queue, GraphTask& task) -> void {
+/* Computes the number of dependencies for each function which requires grad */
+auto Engine::compute_dependencies(Function* root, GraphTask& task) -> void {
   // Just to make sure that they will never be added to the queue again
-  std::unordered_set<Function*> seen(queue.begin(), queue.end());
+  std::unordered_set<Function*> seen;
+  std::vector<Function*> queue { root };
 
   // Queue contains all nodes that will start propagating gradients.
   // We no longer have to expand functions that don't require grad.
   auto& dependencies = task.dependencies;
   while (queue.size() > 0) {
-    auto fn = std::move(queue.back()); queue.pop_back();
-    for (auto& next_fn_pair : fn->next_functions) {
-      Function* next_ptr = next_fn_pair.first.get();
-      if (!next_ptr) continue;
-      if (!next_ptr->is_executable) continue;
-      if (next_ptr->is_stochastic) continue; // Stochastic nodes were in the queue already
+    auto fn = queue.back(); queue.pop_back();
+    for (auto& edge : fn->next_functions) {
+      Function* next_ptr = edge.first.get();
+      if (!next_ptr || !next_ptr->is_executable) continue;
       dependencies[next_ptr] += 1;
-      if (seen.count(next_ptr) == 0) {
-        seen.insert(next_ptr);
-        queue.push_back(next_ptr);
-      }
+      bool inserted;
+      std::tie(std::ignore, inserted) = seen.insert(next_ptr);
+      if (inserted) queue.push_back(next_ptr);
     }
   }
 }
@@ -360,30 +328,18 @@ auto Engine::execute(const function_list& input_roots,
 
   ClearCallbacks _cb_guard(final_callbacks, post_callbacks_lock);
 
   GraphTask graph_task(keep_graph, pre_callbacks, post_callbacks);
 
-  std::unique_lock<std::mutex> lock(graph_task.mutex);
-  auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs);
-  function_queue roots;
-  for (auto entry : input_roots) {
-    if (entry.first->is_executable) {
-      graph_task.has_any_work = true;
-      roots.push_back(graph_root.get());
-      ready_queue(-1).push_front(FunctionTask(&graph_task, graph_root, InputBuffer(0)));
-      break;
-    }
-  }
-
-  // Search the graph and find all stochastic functions. Append them to the queue.
-  find_stochastic_functions(roots, graph_root.get(), graph_task);
-
-  if (!graph_task.has_any_work) {
+  auto is_executable = [](const edge_type& e) { return e.first->is_executable; };
+  if (!std::any_of(input_roots.begin(), input_roots.end(), is_executable)) {
     throw std::runtime_error(
       "there are no graph nodes that require computing gradients");
   }
 
-  // Now compute the dependencies for all executable functions
-  compute_dependencies(std::move(roots), graph_task);
+  // Now compute the dependencies for all executable functions and queue the root
+  auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs);
+  compute_dependencies(graph_root.get(), graph_task);
+  ready_queue(-1).push_front(FunctionTask(&graph_task, std::move(graph_root), InputBuffer(0)));
 
   // Not a worker
   if (worker_device == NO_DEVICE) {
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h
index ff1ae45cbd..9d1c3aaea5 100644
--- a/torch/csrc/autograd/engine.h
+++ b/torch/csrc/autograd/engine.h
@@ -27,7 +27,6 @@ struct Engine {
   virtual ~Engine();
 
   using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;
-  using function_queue = std::vector<Function*>;
   using dependencies_type = std::unordered_map<Function*, int>;
   using pre_callback_type = std::function<bool (Function*, variable_list&)>;
@@ -47,12 +46,7 @@ struct Engine {
   void queue_callback(std::function<void()> callback);
 
 protected:
-  function_queue find_roots(
-      const function_list& roots,
-      variable_list& inputs,
-      GraphTask& task);
-  void find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task);
-  void compute_dependencies(function_queue queue, GraphTask& task);
+  void compute_dependencies(Function* root, GraphTask& task);
   void evaluate_function(FunctionTask& task);
   ReadyQueue& ready_queue(int device);
   void start_threads();
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 27690ad885..40f8d6ee43 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -56,7 +56,6 @@ struct Function : std::enable_shared_from_this<Function> {
     : num_inputs(0)
     , next_functions()
     , is_executable(false)
-    , is_stochastic(false)
     , pre_hooks()
     , post_hooks()
     , pyobj(nullptr)
@@ -66,7 +65,6 @@ struct Function : std::enable_shared_from_this<Function> {
     : num_inputs(0)
     , next_functions(std::move(flags.next_functions))
     , is_executable(flags.is_executable)
-    , is_stochastic(false)
     , pre_hooks()
     , post_hooks()
     , pyobj(nullptr)
@@ -163,7 +161,6 @@ struct Function : std::enable_shared_from_this<Function> {
   int num_inputs;
   function_list next_functions;
   bool is_executable;
-  bool is_stochastic;
   std::vector<std::shared_ptr<FunctionPreHook>> pre_hooks;
   std::vector<std::shared_ptr<FunctionPostHook>> post_hooks;
 
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 868852fa1b..7919d33ef6 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -21,13 +21,10 @@ PyObject * THPAutograd_initExtension(PyObject *_unused)
 
   THPUtils_assert_PyImport("torch.nn._functions.thnn", thnn_functions);
   THPBatchNormBackwardBackwardFunction = PyObject_GetAttrString(thnn_functions,(char*)"batchnorm_double_backwards_fn");
-  THPStochasticFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"StochasticFunction");
 
   THPUtils_assert(THPVariableClass, "couldn't find Variable class in "
           "torch.autograd module");
   THPUtils_assert(THPFunctionClass, "couldn't find Function class in "
           "torch.autograd module");
-  THPUtils_assert(THPStochasticFunctionClass, "couldn't find "
-          "StochasticFunction class in torch.autograd module");
 
   auto m = py::handle(autograd_module).cast<py::module>();
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 1d56bfd200..fbf39690db 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -30,7 +30,6 @@ using namespace torch::jit;
 using at::Tensor;
 
 PyObject *THPFunctionClass = NULL;
-PyObject *THPStochasticFunctionClass = NULL;
 PyObject *THPBatchNormBackwardBackwardFunction = NULL;
 
 #define THPFunction_assert(condition, ...) \
@@ -301,7 +300,6 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
   new (&self->saved_variables) std::vector<SavedVariable>();
   new (&self->is_variable_input) std::vector<bool>();
   self->cdata.num_inputs = -1;
-  self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass);
   return obj;
 }
 
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 837b0addcb..0894aa6d01 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -98,7 +98,6 @@ struct THPFunction {
 bool THPFunction_initModule(PyObject *module);
 extern PyTypeObject THPFunctionType;
 extern PyObject *THPFunctionClass;
-extern PyObject *THPStochasticFunctionClass;
 extern PyObject *THPBatchNormBackwardBackwardFunction;
 // Temporarily here until we move it to C++
 // XXX: this function requires the GIL (it can have side effects).
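The user-visible effect of the variable.py change is that the sampling methods on Variable now dispatch straight to the corresponding torch.* call on .data and wrap the result in a fresh, non-differentiable Variable, instead of routing through the removed Categorical/Bernoulli/Normal stochastic functions (torch.normal gets the same treatment via Variable._torch). A minimal sketch of the new behaviour, assuming the pre-0.4 Variable API; the variable names are illustrative only:

    import torch
    from torch.autograd import Variable

    probs = Variable(torch.Tensor([0.2, 0.3, 0.5]), requires_grad=True)

    # multinomial/bernoulli now sample from .data and return plain Variables
    # with no grad_fn, so no gradient can flow back into probs through them.
    idx = probs.multinomial(2)    # sampled indices, not differentiable
    mask = probs.bernoulli()      # 0/1 samples, not differentiable

    # The StochasticFunction/reward machinery is gone, so reward-based backward
    # through these samples is no longer supported; an explicit REINFORCE-style
    # surrogate loss has to be written by hand instead.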
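On the engine side, the old two-pass setup (seed the ready queue with the roots and any stochastic nodes, then count dependencies over that seed set) collapses into a single traversal from the graph root. The sketch below is a hypothetical Python rendering of the new Engine::compute_dependencies loop, not code from the patch; it assumes a node object exposing the next_functions and is_executable fields used in the C++ above:

    def compute_dependencies(root):
        """Count, for every executable node reachable from root, how many
        edges point at it -- i.e. how many gradients it must wait for."""
        dependencies = {}
        seen = {root}
        queue = [root]
        while queue:
            fn = queue.pop()
            for next_fn, _input_nr in fn.next_functions:
                if next_fn is None or not next_fn.is_executable:
                    continue
                dependencies[next_fn] = dependencies.get(next_fn, 0) + 1
                if next_fn not in seen:
                    seen.add(next_fn)
                    queue.append(next_fn)
        return dependencies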