-rw-r--r--  torch/autograd/__init__.py               |  3
-rw-r--r--  torch/autograd/_functions/__init__.py    |  1
-rw-r--r--  torch/autograd/_functions/stochastic.py  | 38
-rw-r--r--  torch/autograd/stochastic_function.py    | 44
-rw-r--r--  torch/autograd/variable.py               | 10
-rw-r--r--  torch/csrc/autograd/engine.cpp           | 84
-rw-r--r--  torch/csrc/autograd/engine.h             |  8
-rw-r--r--  torch/csrc/autograd/function.h           |  3
-rw-r--r--  torch/csrc/autograd/init.cpp             |  3
-rw-r--r--  torch/csrc/autograd/python_function.cpp  |  2
-rw-r--r--  torch/csrc/autograd/python_function.h    |  1
11 files changed, 29 insertions, 168 deletions
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 5cc7502021..81a1e79755 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -9,11 +9,10 @@ import warnings
from .variable import Variable
from .function import Function, NestedIOFunction
-from .stochastic_function import StochasticFunction
from .gradcheck import gradcheck
from . import profiler
-__all__ = ['Variable', 'Function', 'StochasticFunction', 'backward']
+__all__ = ['Variable', 'Function', 'backward']
def _make_grads(outputs, grads, user_create_graph):
diff --git a/torch/autograd/_functions/__init__.py b/torch/autograd/_functions/__init__.py
index 35429e34ac..2d0ab21bd2 100644
--- a/torch/autograd/_functions/__init__.py
+++ b/torch/autograd/_functions/__init__.py
@@ -4,6 +4,5 @@ from .pointwise import *
from .reduce import *
from .linalg import *
from .blas import *
-from .stochastic import *
from .compare import *
from .initializers import *
diff --git a/torch/autograd/_functions/stochastic.py b/torch/autograd/_functions/stochastic.py
deleted file mode 100644
index 6ef22a4c35..0000000000
--- a/torch/autograd/_functions/stochastic.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-from ..function import Function
-
-
-class Categorical(Function):
- @staticmethod
- def forward(ctx, probs, num_samples, with_replacement):
- samples = probs.multinomial(num_samples, with_replacement)
- ctx.mark_non_differentiable(samples)
- return samples
-
- @staticmethod
- def backward(ctx, grad_output):
- return None, None, None
-
-
-class Bernoulli(Function):
- @staticmethod
- def forward(ctx, probs):
- samples = probs.new().resize_as_(probs).bernoulli_(probs)
- ctx.mark_non_differentiable(samples)
- return samples
-
- @staticmethod
- def backward(ctx, grad_output):
- return None
-
-
-class Normal(Function):
- @staticmethod
- def forward(ctx, means, stddevs=None):
- samples = torch.normal(means, stddevs)
- ctx.mark_non_differentiable(samples)
- return samples
-
- @staticmethod
- def backward(ctx, grad_output):
- return None, None
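The three classes deleted above share one pattern: sample inside forward() and call ctx.mark_non_differentiable() so autograd treats the output as a constant. That pattern remains valid for ordinary custom Functions after this change; below is a minimal sketch (a hypothetical Sample function, not part of this patch) assuming the same static-method Function API used above.

    import torch
    from torch.autograd import Function, Variable

    class Sample(Function):
        @staticmethod
        def forward(ctx, probs):
            # Draw Bernoulli samples and tell autograd the result is not differentiable.
            samples = probs.new().resize_as_(probs).bernoulli_(probs)
            ctx.mark_non_differentiable(samples)
            return samples

        @staticmethod
        def backward(ctx, grad_output):
            # No gradient flows back to probs through the sampled output.
            return None

    probs = Variable(torch.Tensor([0.1, 0.9]))
    out = Sample.apply(probs)  # out.requires_grad is False
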
diff --git a/torch/autograd/stochastic_function.py b/torch/autograd/stochastic_function.py
deleted file mode 100644
index 84b835cd0b..0000000000
--- a/torch/autograd/stochastic_function.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import torch
-from numbers import Number
-from .function import Function
-
-_NOT_PROVIDED = object()
-
-
-class StochasticFunction(Function):
-
- def __init__(self):
- self.reward = _NOT_PROVIDED
-
- def _do_backward(self, grad_output, retain_variables):
- if self.reward is _NOT_PROVIDED:
- raise RuntimeError("differentiating stochastic functions requires "
- "providing a reward")
- result = super(StochasticFunction, self)._do_backward((self.reward,), retain_variables)
- if not retain_variables:
- self.reward = None
- return result
-
- def _do_forward(self, *inputs):
- result = super(StochasticFunction, self)._do_forward(*inputs)
- # save output type and size, to check the type of reward
- assert isinstance(result, torch.autograd.Variable), \
- "stochastic functions support only a single output at the moment"
- self.reward_info = (type(inputs[0].data), result.size())
- return result
-
- __call__ = _do_forward
-
- def _reinforce(self, reward):
- is_number = isinstance(reward, Number)
- if not is_number and type(reward) != self.reward_info[0]:
- raise TypeError("mismatch between reward and output type: got {}, "
- "but expected {}".format(torch.typename(reward),
- torch.typename(self.reward_info[0])))
- if not is_number and reward.size() != self.reward_info[1]:
- raise ValueError("got reward of size {}, but expected a tensor of size {}".format(
- 'x'.join(map(str, reward.size())),
- 'x'.join(map(str, self.reward_info[1]))))
- if self.reward is not _NOT_PROVIDED:
- raise RuntimeError("you can only reinforce a stochastic Function once")
- self.reward = reward
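For context, the class removed above implemented the reward-driven backward protocol: _do_backward() refused to run until a reward had been stored via _reinforce(), and that reward then took the place of grad_output. A rough sketch of the usage pattern this supported (the user-facing Variable.reinforce entry point is not shown in this diff, so take the exact call as an assumption):

    probs = Variable(torch.Tensor([0.25, 0.25, 0.5]), requires_grad=True)
    action = probs.multinomial()            # previously produced a stochastic node
    action.reinforce(torch.Tensor([1.0]))   # stored the reward consumed by _do_backward()
    # Calling backward() without supplying a reward raised:
    # "differentiating stochastic functions requires providing a reward"
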
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index 256fa1a39f..d8378320db 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -374,10 +374,10 @@ class Variable(_C._VariableBase):
return self.expand(tensor.size())
def multinomial(self, num_samples=1, replacement=False):
- return Categorical.apply(self, num_samples, replacement)
+ return Variable(torch.multinomial(self.data, num_samples, replacement))
def bernoulli(self):
- return Bernoulli.apply(self)
+ return Variable(torch.bernoulli(self.data))
def __rsub__(self, other):
return -self + other
@@ -432,7 +432,11 @@ class Variable(_C._VariableBase):
class _torch(object):
@staticmethod
def normal(means, std=1):
- return Normal.apply(means, std)
+ if isinstance(means, Variable):
+ means = means.data
+ if isinstance(std, Variable):
+ std = std.data
+ return Variable(torch.normal(means, std))
for method in dir(Variable):
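With this change the sampling methods above no longer create graph nodes: they sample from the wrapped tensor and return a fresh Variable that is detached from the graph. A minimal usage sketch, assuming the post-change methods shown in this hunk:

    import torch
    from torch.autograd import Variable

    probs = Variable(torch.Tensor([0.2, 0.5, 0.3]), requires_grad=True)
    idx = probs.multinomial(2)   # sampled indices, detached from the graph
    mask = probs.bernoulli()     # element-wise Bernoulli draws, also detached
    assert not idx.requires_grad and not mask.requires_grad
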
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 602653d54a..7b89de9e82 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -64,7 +64,6 @@ struct GraphTask {
std::atomic_bool has_error;
std::atomic<uint64_t> outstanding_tasks;
bool keep_graph;
- bool has_any_work;
std::mutex mutex;
// Notified when a task finishes executing. Check outstanding_tasks to see
@@ -82,7 +81,6 @@ struct GraphTask {
, has_error(false)
, outstanding_tasks(0)
, keep_graph(keep_graph)
- , has_any_work(false)
, mutex()
, not_done()
, pre_callbacks(pre_callbacks)
@@ -234,22 +232,17 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
}
int num_outputs = outputs.size();
+ if (num_outputs == 0) return; // Don't even acquire the mutex
+ std::lock_guard<std::mutex> lock(task.base->mutex);
for (int i = 0; i < num_outputs; ++i) {
auto& output = outputs[i];
auto& next_fn = fn.next_functions[i].first;
int input_nr = fn.next_functions[i].second;
- if (!next_fn) {
+ if (!next_fn || !next_fn->is_executable) {
continue;
}
- // Stochastic functions are placed in the ready queue by
- // compute_dependencies, so we have to skip them here.
- if (next_fn->is_stochastic || !next_fn->is_executable) {
- continue;
- }
-
- std::lock_guard<std::mutex> lock(task.base->mutex);
// Check if the next function is ready to be computed
bool is_ready = false;
auto& dependencies = task.base->dependencies;
@@ -287,49 +280,24 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
}
}
-/** Finds all stochastic functions and appends them to the queue */
-auto Engine::find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task) -> void {
- std::unordered_set<Function*> seen {graph_root};
- function_queue search_queue {graph_root};
- while (search_queue.size() > 0) {
- auto fn = search_queue.back(); search_queue.pop_back();
- for (auto& next_fn_pair : fn->next_functions) {
- auto& next_fn = next_fn_pair.first;
- Function* next_ptr = next_fn.get();
- if (!next_ptr) continue;
- if (next_ptr->is_stochastic && next_ptr->is_executable && seen.count(next_ptr) == 0) {
- ready_queue(-1).push_front(FunctionTask(&task, next_fn, InputBuffer(0)));
- queue.push_back(next_ptr);
- task.has_any_work = true;
- }
- if (seen.count(next_ptr) == 0) {
- seen.insert(next_ptr);
- search_queue.push_back(next_ptr);
- }
- }
- }
-}
-
-/** Computes the number of dependencies for each function which requires grad */
-auto Engine::compute_dependencies(function_queue queue, GraphTask& task) -> void {
+/* Computes the number of dependencies for each function which requires grad */
+auto Engine::compute_dependencies(Function* root, GraphTask& task) -> void {
// Just to make sure that they will never be added to the queue again
- std::unordered_set<Function*> seen(queue.begin(), queue.end());
+ std::unordered_set<Function*> seen;
+ std::vector<Function*> queue { root };
// Queue contains all nodes that will start propagating gradients.
// We no longer have to expand functions that don't require grad.
auto& dependencies = task.dependencies;
while (queue.size() > 0) {
- auto fn = std::move(queue.back()); queue.pop_back();
- for (auto& next_fn_pair : fn->next_functions) {
- Function* next_ptr = next_fn_pair.first.get();
- if (!next_ptr) continue;
- if (!next_ptr->is_executable) continue;
- if (next_ptr->is_stochastic) continue; // Stochastic nodes were in the queue already
+ auto fn = queue.back(); queue.pop_back();
+ for (auto& edge : fn->next_functions) {
+ Function* next_ptr = edge.first.get();
+ if (!next_ptr || !next_ptr->is_executable) continue;
dependencies[next_ptr] += 1;
- if (seen.count(next_ptr) == 0) {
- seen.insert(next_ptr);
- queue.push_back(next_ptr);
- }
+ bool inserted;
+ std::tie(std::ignore, inserted) = seen.insert(next_ptr);
+ if (inserted) queue.push_back(next_ptr);
}
}
}
@@ -360,30 +328,18 @@ auto Engine::execute(const function_list& input_roots,
ClearCallbacks _cb_guard(final_callbacks, post_callbacks_lock);
GraphTask graph_task(keep_graph, pre_callbacks, post_callbacks);
-
std::unique_lock<std::mutex> lock(graph_task.mutex);
- auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs);
- function_queue roots;
- for (auto entry : input_roots) {
- if (entry.first->is_executable) {
- graph_task.has_any_work = true;
- roots.push_back(graph_root.get());
- ready_queue(-1).push_front(FunctionTask(&graph_task, graph_root, InputBuffer(0)));
- break;
- }
- }
-
- // Search the graph and find all stochastic functions. Append them to the queue.
- find_stochastic_functions(roots, graph_root.get(), graph_task);
-
- if (!graph_task.has_any_work) {
+ auto is_executable = [](const edge_type& e) { return e.first->is_executable; };
+ if (!std::any_of(input_roots.begin(), input_roots.end(), is_executable)) {
throw std::runtime_error(
"there are no graph nodes that require computing gradients");
}
- // Now compute the dependencies for all executable functions
- compute_dependencies(std::move(roots), graph_task);
+ // Now compute the dependencies for all executable functions and queue the root
+ auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs);
+ compute_dependencies(graph_root.get(), graph_task);
+ ready_queue(-1).push_front(FunctionTask(&graph_task, std::move(graph_root), InputBuffer(0)));
// Not a worker
if (worker_device == NO_DEVICE) {
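The reworked compute_dependencies above is now a single walk from graph_root: it counts, for every executable function, how many incoming edges will deliver gradients, and evaluate_function later decrements those counts and schedules a function once its count reaches zero. A Python sketch of the counting step, using illustrative names rather than the engine's actual API:

    def compute_dependencies(root):
        # For each executable function reachable from root, count how many
        # edges will deliver a gradient into it.
        dependencies = {}
        seen = set()
        queue = [root]
        while queue:
            fn = queue.pop()
            for next_fn, _input_nr in fn.next_functions:
                if next_fn is None or not next_fn.is_executable:
                    continue
                dependencies[next_fn] = dependencies.get(next_fn, 0) + 1
                if next_fn not in seen:
                    seen.add(next_fn)
                    queue.append(next_fn)
        return dependencies
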
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h
index ff1ae45cbd..9d1c3aaea5 100644
--- a/torch/csrc/autograd/engine.h
+++ b/torch/csrc/autograd/engine.h
@@ -27,7 +27,6 @@ struct Engine {
virtual ~Engine();
using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;
- using function_queue = std::vector<Function*>;
using dependencies_type = std::unordered_map<Function*, int>;
using pre_callback_type = std::function<bool (Function*, variable_list&)>;
@@ -47,12 +46,7 @@ struct Engine {
void queue_callback(std::function<void()> callback);
protected:
- function_queue find_roots(
- const function_list& roots,
- variable_list& inputs,
- GraphTask& task);
- void find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task);
- void compute_dependencies(function_queue queue, GraphTask& task);
+ void compute_dependencies(Function* root, GraphTask& task);
void evaluate_function(FunctionTask& task);
ReadyQueue& ready_queue(int device);
void start_threads();
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 27690ad885..40f8d6ee43 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -56,7 +56,6 @@ struct Function : std::enable_shared_from_this<Function> {
: num_inputs(0)
, next_functions()
, is_executable(false)
- , is_stochastic(false)
, pre_hooks()
, post_hooks()
, pyobj(nullptr)
@@ -66,7 +65,6 @@ struct Function : std::enable_shared_from_this<Function> {
: num_inputs(0)
, next_functions(std::move(flags.next_functions))
, is_executable(flags.is_executable)
- , is_stochastic(false)
, pre_hooks()
, post_hooks()
, pyobj(nullptr)
@@ -163,7 +161,6 @@ struct Function : std::enable_shared_from_this<Function> {
int num_inputs;
function_list next_functions;
bool is_executable;
- bool is_stochastic;
std::vector<std::shared_ptr<FunctionPreHook>> pre_hooks;
std::vector<std::shared_ptr<FunctionPostHook>> post_hooks;
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 868852fa1b..7919d33ef6 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -21,13 +21,10 @@ PyObject * THPAutograd_initExtension(PyObject *_unused)
THPUtils_assert_PyImport("torch.nn._functions.thnn", thnn_functions);
THPBatchNormBackwardBackwardFunction = PyObject_GetAttrString(thnn_functions,(char*)"batchnorm_double_backwards_fn");
- THPStochasticFunctionClass = PyMapping_GetItemString(autograd_dict,(char*)"StochasticFunction");
THPUtils_assert(THPVariableClass, "couldn't find Variable class in "
"torch.autograd module");
THPUtils_assert(THPFunctionClass, "couldn't find Function class in "
"torch.autograd module");
- THPUtils_assert(THPStochasticFunctionClass, "couldn't find "
- "StochasticFunction class in torch.autograd module");
auto m = py::handle(autograd_module).cast<py::module>();
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 1d56bfd200..fbf39690db 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -30,7 +30,6 @@ using namespace torch::jit;
using at::Tensor;
PyObject *THPFunctionClass = NULL;
-PyObject *THPStochasticFunctionClass = NULL;
PyObject *THPBatchNormBackwardBackwardFunction = NULL;
#define THPFunction_assert(condition, ...) \
@@ -301,7 +300,6 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
new (&self->saved_variables) std::vector<SavedVariable>();
new (&self->is_variable_input) std::vector<bool>();
self->cdata.num_inputs = -1;
- self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass);
return obj;
}
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 837b0addcb..0894aa6d01 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -98,7 +98,6 @@ struct THPFunction {
bool THPFunction_initModule(PyObject *module);
extern PyTypeObject THPFunctionType;
extern PyObject *THPFunctionClass;
-extern PyObject *THPStochasticFunctionClass;
extern PyObject *THPBatchNormBackwardBackwardFunction; // Temporarily here until we move it to C++
// XXX: this function requires the GIL (it can have side effects).