23 files changed, 1527 insertions, 528 deletions
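A note on the schema edited below, for orientation only: each entry in tools/autograd/derivatives.yaml pairs an ATen function signature with a C++ gradient expression for every differentiable argument, where grad denotes the incoming gradient (inputs[0] in the generated Function::apply). The following minimal sketch of the format simply reproduces entries that appear verbatim in the diff below; nothing in it is new to this commit.

- name: mm(Tensor self, Tensor mat2)
  self: grad.mm(mat2.t())
  mat2: self.t().mm(grad)

Entries marked "# fallthrough" dispatch straight to the base type with no backward graph, and formulas written as not_implemented("...") throw at runtime via the not_implemented() helper added to Functions.cpp in this diff:

- name: numel # fallthrough

- name: qr(Tensor self)
  self: not_implemented("qr")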
diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 55a0d73414..92416812c4 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -857,6 +857,11 @@ Non-linear activation functions .. autofunction:: rrelu +:hidden:`glu` +~~~~~~~~~~~~~~~ + +.. autofunction:: glu + :hidden:`logsigmoid` ~~~~~~~~~~~~~~~~~~~~ @@ -416,6 +416,7 @@ main_sources = [ "torch/csrc/autograd/generated/Functions.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp", "torch/csrc/autograd/generated/python_functions.cpp", + "torch/csrc/autograd/generated/python_nn_functions.cpp", "torch/csrc/autograd/functions/batch_normalization.cpp", "torch/csrc/autograd/functions/convolution.cpp", "torch/csrc/autograd/functions/softmax.cpp", diff --git a/test/common_nn.py b/test/common_nn.py index cae61dcb9d..e6b6469800 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -324,6 +324,7 @@ criterion_tests = [ module_name='NLLLoss2d', input_size=(2, 3, 5, 5), target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(), + check_no_size_average=True, ), dict( module_name='NLLLoss2d', @@ -356,6 +357,7 @@ criterion_tests = [ module_name='MultiLabelMarginLoss', input_size=(5, 10), target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(), + check_no_size_average=True, check_gradgrad=False, ), dict( diff --git a/test/test_autograd.py b/test/test_autograd.py index ea69d45f2a..1c338b0112 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -1785,11 +1785,11 @@ method_tests = [ ('fmod', (S, S, S), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor'), ('fmod', (S,), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_lhs'), ('fmod', (S, S, S), (Variable(torch.rand(S) + 1.5, requires_grad=False),), 'tensor_broadcast_rhs'), - ('fmod', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broacast_all'), + ('fmod', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_all'), ('remainder', (S, S, S), (1.5,)), ('remainder', (S, S, S), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor'), ('remainder', (S,), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_lhs'), - ('remainder', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broacast_all'), + ('remainder', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_all'), ('lerp', (S, S, S), ((S, S, S), 0.4)), ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs'), ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs'), @@ -2212,7 +2212,9 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, for test in method_tests: name, self_size, args = test[:3] - basic_test_name = 'test_' + name + ('_' + test[3] if len(test) >= 4 else '') + basic_test_name = 'test_' + name + if len(test) >= 4 and test[3] != '': + basic_test_name += '_' + test[3] dim_args_idx = test[4] if len(test) == 5 else [] diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 45b4ae1cce..2eb8ee0822 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1,16 +1,4 @@ # Defines derivative formulas and Python signatures of methods on Variable - -- name: __and__ -- name: __iand__ -- name: __ilshift__ -- name: __ior__ -- name: __irshift__ -- name: __ixor__ -- name: __lshift__ -- name: __or__ -- name: __rshift__ -- name: __xor__ - - name: abs(Tensor self) self: grad * self.sign() @@ -25,14 +13,14 @@ other: grad * alpha - name: 
addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) - self: grad * alpha - batch1: grad.bmm(batch2.transpose(1, 2)) * beta - batch2: batch1.transpose(1, 2).bmm(grad) * beta + self: grad * beta + batch1: grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2)) * alpha + batch2: batch1.transpose(1, 2).bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })) * alpha - name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) - self: grad * value - tensor1: grad / tensor2 * value - tensor2: -grad * tensor1 / (tensor2 * tensor2) + self: grad + tensor1: grad * value / tensor2 + tensor2: -grad * value * tensor1 / (tensor2 * tensor2) - name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) self: grad @@ -58,6 +46,8 @@ - name: any # fallthrough +- name: arange # fallthrough + - name: asin(Tensor self) self: grad * (-self * self + 1).sqrt().reciprocal() @@ -77,23 +67,23 @@ self: grad.bmm(mat2.transpose(1, 2)) mat2: self.transpose(1, 2).bmm(grad) -- name: btrifact -- name: btrisolve +- name: btrifact(Tensor self, Tensor info, bool pivot) + self: not_implemented("btrifact") + +- name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots) + self: not_implemented("btrisolve") - name: cat(TensorList tensors, int64_t dim=0) - tensors: cat_tensors_backward(grad, tensors.sizes(dim), dim) + tensors: cat_tensors_backward(grad, to_arg_sizes(tensors, dim), dim) -- name: cauchy +- name: cauchy # TODO: reinforce - name: ceil(Tensor self) - self: zeros_like(self) + self: zeros_like(grad) - name: clone(Tensor self) self: grad -- name: contiguous(Tensor self) - self: grad - - name: cos(Tensor self) self: grad * -self.sin() @@ -104,7 +94,7 @@ self: other.cross(grad, dim) other: grad.cross(self, dim) -- name: cumprod +- name: cumprod # complicated - name: cumsum(Tensor self, int64_t dim) self: cumsum_backward(grad, dim) @@ -129,75 +119,197 @@ self: grad * tensor tensor: grad * self -- name: eig -- name: eq -- name: equal +- name: eig(Tensor self, bool eigenvectors) + self: not_implemented("eig") + +- name: eq(Tensor self, Scalar value) + self: zeros_like(self) + +- name: eq(Tensor self, Tensor other) + self: zeros_like(self) + other: zeros_like(other) + +- name: equal # fallthrough + +- name: erf(Tensor self) + self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad + +- name: erfinv(Tensor self) + self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad - name: exp(Tensor self) - self: grad * output + self: grad * result - name: expand(Tensor self, IntList size) self: reduce_to(grad, self.sizes()) + __view__: True -- name: eye -- name: fill -- name: floor -- name: fmod -- name: frac -- name: gather -- name: ge -- name: gels -- name: geometric -- name: geqrf +- name: eye # fallthrough + +- name: fill(Tensor self, Scalar value) # FIXME + +- name: floor(Tensor self) + self: zeros_like(grad) + +- name: fmod(Tensor self, Scalar value) + self: grad + +- name: fmod(Tensor self, Tensor other) + self: grad + other: 'not_implemented("fmod: other")' + +- name: frac(Tensor self) + self: grad + +- name: gather(Tensor self, int64_t dim, Tensor index) + self: grad.type().zeros(self.sizes()).scatter_add_(dim, index, grad) + +- name: ge(Tensor self, Scalar value) + self: zeros_like(self) + +- name: ge(Tensor self, Tensor other) + self: zeros_like(self) + other: zeros_like(other) + +- name: gels(Tensor self, Tensor A) + self: not_implemented("gels") + A: not_implemented("gels") + +- name: 
geometric(Tensor self, double p, Generator Generator) + self: zeros_like(grad) + +- name: geqrf(Tensor self) + self: not_implemented("geqrf") - name: ger(Tensor self, Tensor vec2) self: grad.mv(vec2) vec2: grad.t().mv(self) -- name: gesv +- name: gesv(Tensor self, Tensor A) + self: std::get<0>(gesv(grad, A.t())) + A: -at::mm(std::get<0>(gesv(grad, A.t())), solution.t()) - name: get_device # fallthrough -- name: gt -- name: histc -- name: index_add -- name: index_copy -- name: index_fill -- name: index_select -- name: inverse +- name: gt(Tensor self, Scalar value) + self: zeros_like(self) + +- name: gt(Tensor self, Tensor other) + self: zeros_like(self) + other: zeros_like(other) + +- name: histc(Tensor self, int64_t bins, Scalar min, Scalar max) + self: not_implemented("histc") + +- name: index_add(Tensor self, int64_t dim, Tensor index, Tensor source) + self: grad + source: grad.index_select(dim, index) + +- name: index_copy(Tensor self, int64_t dim, Tensor index, Tensor source) + self: grad.clone().index_fill_(dim, index, 0) + source: grad.index_select(dim, index) + +- name: index_fill(Tensor self, int64_t dim, Tensor index, Scalar value) + self: grad.clone().index_fill_(dim, index, 0) -- name: is_contiguous +- name: index_select(Tensor self, int64_t dim, Tensor index) + self: grad.type().zeros(self.sizes()).index_add_(dim, index, grad) + __view__: True + +- name: inverse(Tensor self) + self: -at::mm(output.t(), at::mm(grad, output.t())) + +- name: is_contiguous # fallthrough - name: is_same_size # fallthrough - name: is_set_to # fallthrough -- name: kthvalue -- name: le -- name: lerp -- name: lgamma -- name: linspace +- name: kthvalue(Tensor self, int64_t k, int64_t dim, bool keepdim) + self: select_backward(grad, dim, indices, self.sizes(), keepdim) + +- name: le(Tensor self, Scalar value) + self: zeros_like(self) + +- name: le(Tensor self, Tensor other) + self: zeros_like(self) + other: zeros_like(other) + +- name: lerp(Tensor self, Tensor end, Scalar weight) + self: grad * (1 - weight.toDouble()) + end: grad * weight + +- name: lgamma(Tensor self) + self: not_implemented("lgamma") + +- name: linspace(Scalar start, Scalar end, int64_t steps) - name: log(Tensor self) self: grad.div(self) -- name: log1p -- name: log_normal -- name: logspace -- name: lt -- name: masked_fill -- name: masked_scatter -- name: masked_select -- name: max -- name: mean -- name: median -- name: min +- name: log1p(Tensor self) + self: grad / (self + 1) + +- name: log_normal(Tensor self, double mean, double std, Generator generator) + self: zeros_like(grad) + +- name: logspace # fallthrough + +- name: lt(Tensor self, Scalar value) + self: zeros_like(self) + +- name: lt(Tensor self, Tensor other) + self: zeros_like(self) + other: zeros_like(other) + +- name: masked_fill(Tensor self, Tensor mask, Scalar value) + self: grad.clone().masked_fill_(mask, 0) + +- name: masked_scatter(Tensor self, Tensor mask, Tensor source) + self: grad.clone().masked_fill_(mask, 0) + source: masked_scatter_backward(grad, mask, source.sizes()) + +- name: masked_select(Tensor self, Tensor mask) + self: zeros_like(self).masked_scatter_(mask, grad) + +- name: max(Tensor self, int64_t dim, bool keepdim) + self: select_backward(grad, dim, max_indices, self.sizes(), keepdim) + +- name: max(Tensor self) + self: select_backward_scalar(grad, self, result) + +- name: max(Tensor self, Tensor other) + self: grad.clone().masked_fill_(self <= other, 0) + other: grad.clone().masked_fill_(self > other, 0) + +- name: mean(Tensor self, int64_t dim, bool 
keepdim) + self: sum_backward(grad, self.sizes(), dim, keepdim) / self.size(dim) + +- name: mean(Tensor self) + self: grad.expand(self.sizes()) / self.numel() + +- name: median(Tensor self) + self: select_backward_scalar(grad, self, result) + +- name: median(Tensor self, int64_t dim, bool keepdim) + self: select_backward(grad, dim, indices, self.sizes(), keepdim) + +- name: min(Tensor self, int64_t dim, bool keepdim) + self: select_backward(grad, dim, min_indices, self.sizes(), keepdim) + +- name: min(Tensor self) + self: select_backward_scalar(grad, self, result) + +- name: min(Tensor self, Tensor other) + self: grad.clone().masked_fill_(self >= other, 0) + other: grad.clone().masked_fill_(self < other, 0) - name: mm(Tensor self, Tensor mat2) self: grad.mm(mat2.t()) mat2: self.t().mm(grad) -- name: mode +- name: mode(Tensor self, int64_t dim, bool keepdim) + self: select_backward(grad, dim, indices, self.sizes(), keepdim) - name: mul(Tensor self, Scalar value) self: grad * value @@ -206,7 +318,7 @@ self: grad * other other: grad * self -- name: multinomial +- name: multinomial # TODO: reinforce - name: mv(Tensor self, Tensor vec) self: grad.ger(vec) @@ -214,16 +326,23 @@ - name: narrow(Tensor self, int64_t dimension, int64_t start, int64_t length) self: grad._unnarrow(dimension, start, self.size(dimension)) + __view__: True - name: _unnarrow(Tensor self, int64_t dimension, int64_t offset, int64_t dimSize) self: grad.narrow(dimension, offset, self.size(dimension)) -- name: ne +- name: ne(Tensor self, Scalar value) + self: zeros_like(self) + +- name: ne(Tensor self, Tensor other) + self: zeros_like(self) + other: zeros_like(other) - name: neg(Tensor self) self: grad.neg() -- name: nonzero +- name: nonzero(Tensor self) + self: zeros_like(grad) - name: norm(Tensor self, Scalar p=2) self: norm_backward(grad, self, p) @@ -231,13 +350,27 @@ - name: norm(Tensor self, Scalar p, int64_t dim, bool keepdim=False) self: norm_backward(grad, self, p, dim, keepdim) -- name: numel -- name: ones -- name: orgqr -- name: ormqr -- name: potrf -- name: potri -- name: potrs +- name: numel # fallthrough +- name: ones # fallthrough + +- name: orgqr(Tensor self, Tensor input2) + self: not_implemented("orgqr") + input2: not_implemented("orgqr") + +- name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left, bool transpose) + self: not_implemented("ormqr") + input2: not_implemented("ormqr") + input3: not_implemented("ormqr") + +- name: potrf(Tensor self, bool upper) + self: potrf_backward(grad, upper, output) + +- name: potri(Tensor self, bool upper) + self: not_implemented("potri") + +- name: potrs(Tensor self, Tensor input2, bool upper) + self: not_implemented("potri") + input2: not_implemented("potri") - name: pow(Tensor self, Scalar exponent) self: grad * exponent * self.pow(exponent.toDouble() - 1) @@ -246,28 +379,56 @@ self: grad * exponent * self.pow(exponent - 1) exponent: grad * self.pow(exponent) * self.log() -- name: prod -- name: pstrf -- name: qr -- name: rand -- name: randn -- name: randperm -- name: range +# TODO: complicated +# - name: prod(Tensor self, int64_t dim, bool keepdim) + +# - name: prod(Tensor self) + +- name: pstrf(Tensor self, bool upper, Scalar tol) + self: not_implemented("pstrf") + +- name: qr(Tensor self) + self: not_implemented("qr") + +- name: rand # fallthrough +- name: randn # fallthrough +- name: randperm # fallthrough +- name: range # fallthrough - name: reciprocal(Tensor self) self: grad / -(self * self) -- name: remainder -- name: renorm -- name: resize -- name: 
resize_as -- name: round -- name: rsqrt -- name: scatter -- name: scatter_add -- name: select -- name: set -- name: sigmoid +- name: remainder(Tensor self, Scalar value) + self: grad + +- name: remainder(Tensor self, Tensor other) + self: grad + +- name: renorm # TODO! + +- name: round(Tensor self) + self: zeros_like(grad) + +- name: rsqrt(Tensor self) + self: -0.5 * grad * result.pow(3) + +- name: scatter(Tensor self, int64_t dim, Tensor index, Tensor src) + self: grad.clone().scatter_(dim, index, 0) + src: grad.gather(dim, index) + +- name: scatter(Tensor self, int64_t dim, Tensor index, Scalar value) + self: grad.clone().scatter_(dim, index, 0) + +- name: scatter_add(Tensor self, int64_t dim, Tensor index, Tensor src) + self: grad + src: grad.gather(dim, index) + +- name: select # TODO: ATen definition conflicts with PyTorch + +- name: set # TODO + +- name: sigmoid(Tensor self) + self: _sigmoid_backward(grad, result) - name: sign(Tensor self) self: zeros_like(grad) @@ -278,17 +439,21 @@ - name: sinh(Tensor self) self: grad * self.cosh() -- name: size -- name: sort +- name: size # fallthrough + +- name: sort(Tensor self, int64_t dim, bool descending) + self: select_backward(grad, dim, indices, self.sizes(), true) - name: sqrt(Tensor self) self: grad * self.pow(-0.5) / 2 - name: squeeze(Tensor self) self: unsqueeze_to(grad, self.sizes()); + __view__: True - name: squeeze(Tensor self, int64_t dim) - self: grad.unsqueeze(dim) + self: maybe_unsqueeze(grad, dim, self.size(dim)) + __view__: True - name: std @@ -309,22 +474,33 @@ - name: sum(Tensor self, int64_t dim, bool keepdim=False) self: sum_backward(grad, self.sizes(), dim, keepdim) -- name: svd -- name: symeig +- name: svd(Tensor self, bool some) + self: not_implemented("svd") + +- name: symeig(Tensor self, bool eigenvectors, bool upper) + self: not_implemented("symeig") - name: t(Tensor self) self: grad.t() + __view__: True + +- name: tan(Tensor self) + self: grad / self.cos().pow(2) -- name: tan -- name: tanh +- name: tanh(Tensor self) + self: _tanh_backward(grad, result) - name: tensor # fallthrough -- name: topk -- name: trace +- name: topk(Tensor self, int64_t k, int64_t dim, bool largest, bool sorted) + self: select_backward(grad, dim, indices, self.sizes(), true) + +- name: trace(Tensor self) + self: trace_backward(grad, self.sizes()) - name: transpose(Tensor self, int64_t dim0, int64_t dim1) self: grad.transpose(dim0, dim1) + __view__: True - name: tril(Tensor self, int64_t diagonal=0) self: grad.tril(diagonal) @@ -332,18 +508,129 @@ - name: triu(Tensor self, int64_t diagonal=0) self: grad.triu(diagonal) -- name: trtrs -- name: trunc -- name: unfold -- name: uniform +- name: trtrs(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) + self: not_implemented("trtrs") + +- name: trunc(Tensor self) + self: zeros_like(grad) + +- name: unfold(Tensor self, int64_t dimension, int64_t size, int64_t step) + self: unfold_backward(grad, self.sizes(), dimension, size, step) + +- name: uniform # fallthrough - name: unsqueeze(Tensor self, int64_t dim) self: grad.squeeze(dim) + __view__: True -- name: var +- name: var # TODO - name: view(Tensor self, IntList size) self: grad.contiguous().view(self.sizes()) + __view__: True + +- name: zero(Tensor self) + self: zeros_like(grad) + +- name: zeros # fallthrough + +# NN double backwards support + +- name: avg_pool2d_backward(Tensor grad_output, Tensor input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) + grad_output: avg_pool2d(grad, 
kernel_size, stride, padding, ceil_mode, count_include_pad) + input: zeros_like(input) + +- name: avg_pool3d_backward(Tensor grad_output, Tensor input, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) + grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad) + input: zeros_like(input) + +- name: elu_backward(Tensor grad_output, Tensor input, Scalar alpha, bool inplace, Tensor output) + grad_output: elu_backward(grad, input, alpha, inplace, output) + input: grad * grad_input * (input < 0).toType(grad.type()) + +- name: glu_backward(Tensor grad_output, Tensor input, int64_t dim) + grad_output: glu_double_backward_grad_output(grad, input, dim) + input: glu_double_backward(grad, grad_output, input, dim) + +- name: hardshrink_backward(Tensor grad_output, Tensor input, Scalar lambd) + grad_output: hardshrink_backward(grad, input, lambd) + input: zeros_like(grad) + +- name: hardtanh_backward(Tensor grad_output, Tensor input, Scalar min_val, Scalar max_val, bool inplace) + grad_output: hardtanh_backward(grad, input, min_val, max_val, false) + input: zeros_like(grad) + +- name: kl_div_backward(Tensor input, Tensor target, bool size_average) + input: zeros_like(grad) + +- name: l1_loss_backward(Tensor input, Tensor target, bool size_average) + input: zeros_like(grad) + +- name: log_sigmoid_backward(Tensor grad_output, Tensor input, Tensor buffer) + grad_output: log_sigmoid_backward(grad, input, buffer) + input: log_sigmoid_double_backward(grad * grad_output, input) + +- name: log_softmax_backward(Tensor grad_output, Tensor input, int dim, Tensor output) + grad_output: grad - (grad * output.exp()).sum(dim, true) + input: log_softmax_double_backward(grad, grad_output, dim, output) + +- name: leaky_relu_backward(Tensor grad_output, Tensor input, Scalar negative_slope, bool inplace) + grad_output: leaky_relu_backward(grad, input, negative_slope, false) + input: zeros_like(grad) + +- name: max_pool2d_backward(Tensor grad_output, Tensor input, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode, Tensor indices) + grad_output: max_pool2d_double_backward(grad, indices); + input: zeros_like(input) + +- name: max_unpool2d_backward(Tensor grad_output, Tensor input, Tensor indices, IntList output_size) + grad_output: max_unpool2d(grad, indices, output_size) + input: zeros_like(input) + +- name: mse_loss_backward(Tensor grad_output, Tensor input, Tensor target, bool size_average, bool reduce) + grad_output: mse_loss_double_backward_grad_output(grad, grad_output, input, target, size_average, reduce) + input: mse_loss_double_backward(grad * grad_output, input, size_average, reduce) + +- name: nll_loss_backward(Tensor input, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, Tensor total_weight) + input: zeros_like(grad) + +- name: nll_loss2d_backward(Tensor input, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, Tensor total_weight) + input: zeros_like(grad) + +- name: prelu_backward(Tensor grad_output, Tensor input, Tensor weight, std::array<bool, 2> output_mask) + grad_output: zeros_like(grad_output) + input: zeros_like(input) + weight: zeros_like(weight) + +- name: rrelu_backward(Tensor grad_output, Tensor input, Scalar lower, Scalar upper, bool training, bool inplace, Tensor noise) + grad_output: rrelu_backward(grad, input, lower, upper, training, false, noise) + input: zeros_like(grad) + +- name: smooth_l1_loss_backward(Tensor input, Tensor target, bool 
size_average) + input: smooth_l1_loss_double_backward(grad, input, target, size_average) + +- name: softplus_backward(Tensor grad_output, Tensor input, Scalar beta, Scalar threshold, Tensor output) + grad_output: softplus_backward(grad, input, beta, threshold, output) + input: softplus_double_backward(grad * grad_output, input, beta, threshold) + +- name: softmax_backward(Tensor grad_output, Tensor input, int dim, Tensor output) + grad_output: softmax_backward(grad, input, dim, output) + input: softmax_double_backward(grad, grad_output, dim, output) + +- name: soft_margin_loss_backward(Tensor input, Tensor target, bool size_average) + input: soft_margin_loss_double_backward(grad, input, target, size_average) + +- name: softshrink_backward(Tensor grad_output, Tensor input, Scalar lambd) + grad_output: softshrink_backward(grad, input, lambd) + input: zeros_like(grad) + +- name: threshold_backward(Tensor grad_output, Tensor input, Scalar threshold, Scalar value, bool inplace) + grad_output: threshold_backward(grad, input, threshold, value, false) + input: zeros_like(grad) + +- name: _sigmoid_backward(Tensor grad_output, Tensor output) + grad_output: _sigmoid_backward(grad, output) + output: grad * grad_output * (-2 * output + 1) -- name: zero -- name: zeros +- name: _tanh_backward(Tensor grad_output, Tensor output) + grad_output: _tanh_backward(grad, output) + output: -2 * output * grad * grad_output diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 56e320c8e5..46821d6152 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -50,7 +50,8 @@ UNPACK_SELF = "auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;" def create_python_bindings( - python_functions, py_methods, py_method_defs, py_method_dispatch): + python_functions, py_methods, py_method_defs, py_method_dispatch, + is_class): """python_variable_methods.cpp Generates Python bindings to Variable methods @@ -61,6 +62,7 @@ def create_python_bindings( 'Generator *': 'generator', 'Storage &': 'storage', 'int64_t': 'toInt64', + 'int': 'toInt64', 'bool': 'toBool', 'double': 'toDouble', } @@ -157,7 +159,7 @@ def create_python_bindings( tmpl = PY_VARIABLE_METHOD_VARARGS env['flags'] = 'METH_VARARGS | METH_KEYWORDS' - if not is_method: + if is_class and not is_method: env['flags'] += ' | METH_STATIC' py_methods.append(tmpl.substitute(env)) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index bbeaea02bc..a349a2c4ce 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -3,8 +3,7 @@ import copy import os import re import yaml -from collections import OrderedDict, defaultdict -from itertools import groupby +from collections import defaultdict from tools.shared.module_loader import import_module from .nested_dict import nested_dict @@ -35,13 +34,16 @@ METHOD_DEFINITION_FALLTHROUGH = CodeTemplate("""\ return baseType->${method_prefix}${api_name}(${unpacked_args});""") METHOD_DEFINITION_FALLTHROUGH_VARIABLE = CodeTemplate("""\ -return make_variable(baseType->${method_prefix}${api_name}(${unpacked_args}));""") +return as_variable(baseType->${method_prefix}${api_name}(${unpacked_args}));""") -UNWRAP_TENSOR = CodeTemplate("""\ -auto& ${arg_name}_ = checked_unpack(${arg_name}, "${arg_name}", ${arg_pos});""") +METHOD_DEFINITION_FALLTHROUGH_INPLACE = CodeTemplate("""\ +baseType->${method_prefix}${api_name}(${unpacked_args}); +increment_version(self); +return self; +""") 
-UNWRAP_TENSORLIST = CodeTemplate("""\ -auto ${arg_name}_ = checked_unpack(${arg_name}, "${arg_name}", ${arg_pos});""") +UNPACK_TENSOR = CodeTemplate("""\ +auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""") FUNCTION_DECLARATION = CodeTemplate("""\ struct ${op} : public Function { @@ -57,7 +59,7 @@ struct ${op} : public Function { FUNCTION_DEFINITION = CodeTemplate("""\ variable_list ${op}::apply(const variable_list& inputs) { - variable_list grad_inputs(${num_inputs}); + variable_list grad_inputs{${num_inputs}}; ${body} return grad_inputs; } @@ -68,10 +70,18 @@ static PyTypeObject ${op}Class; addClass<${op}>(${op}Class, "${op}"); """) - DERIVATIVE_TENSOR = CodeTemplate("""\ -if (should_compute_output(${i})) { - grad_inputs[${i}] = ${derivative}; +if (should_compute_output(${idx})) { + grad_inputs[${idx}] = ${derivative}; +} +""") + +DERIVATIVE_MULTI = CodeTemplate("""\ +if (should_compute_output({ ${idxs} })) { + auto output_mask = std::array<bool, ${n}>{ + ${masks} + }; + std::tie(${grad_inputs}) = ${derivative}; } """) @@ -81,50 +91,52 @@ if (should_compute_any_outputs()) { } """) -METHOD_DEFINITION_FLAGS_TENSORS = CodeTemplate("""\ -auto flags = Function::flags({ ${tensor_args} }); -""") - -METHOD_DEFINITION_FLAGS_TENSORLIST = CodeTemplate("""\ -auto flags = Function::flags( ${tensorlist_args}); -""") - METHOD_DEFINITION_DERIVATIVE = CodeTemplate("""\ -${flags_def} +auto flags = Function::flags({ ${tensor_args} }); auto grad_fn = std::make_shared<${op}>(); -if (flags.is_executable) { - ${save_variables} -} -auto output = as_variable(baseType->${method_prefix}${api_name}(${unpacked_args})); -${save_output} -wrap_output(*output.get(), std::move(flags), std::move(grad_fn)); +${buffers} +${save_inputs} +auto ret = as_variable(baseType->${method_prefix}${base_name}(${unpacked_args})); +${version_counter} +wrap_output(ret, std::move(flags), grad_fn); +${save_outputs} return ${return_value}; """) METHOD_DEFINITION_INPLACE = CodeTemplate("""\ auto& pImpl = static_cast<VariableImpl&>(*self.get()); check_inplace(pImpl); -${flags_def} +auto flags = Function::flags({ ${tensor_args} }); auto grad_fn = std::make_shared<${op}>(); -if (flags.is_executable) { - ${save_variables} -} -baseType->${method_prefix}${api_name}(${unpacked_args}); +${save_inputs} +baseType->${method_prefix}${base_name}(${unpacked_args}); (*pImpl.version_counter)++; -${save_output} -wrap_output(pImpl, std::move(flags), std::move(grad_fn)); +wrap_output(self, std::move(flags), grad_fn); +${save_outputs} return ${return_value}; """) -SAVE_OUTPUT = CodeTemplate("""\ -if (flags.is_executable) { - grad_fn->output_ = SavedVariable(${return_name}, grad_fn.get()); +METHOD_DEFINITION_NOT_DIFFERENTIABLE = CodeTemplate("""\ +auto flags = Function::flags({ ${tensor_args} }); +auto grad_fn = std::make_shared<Error>("${api_name} is not differentiable"); +auto ret = as_variable(baseType->${method_prefix}${api_name}(${unpacked_args})); +wrap_output(ret, std::move(flags), std::move(grad_fn)); +return ret; +""") + +CONDITIONAL = CodeTemplate("""\ +if (${cond}) { + ${statements} } """) FUNCTION_PROTOTYPE = CodeTemplate("""\ ${name}(${typed_args})""") +BUFFER_DECLARATION = CodeTemplate("""\ +auto ${name} = tensor(); +auto& ${name}_ = static_cast<VariableImpl*>(${name}.get())->data;""") + GENERATED_COMMENT = CodeTemplate("""\ generated from tools/autograd/templates/${filename}""") @@ -136,6 +148,9 @@ FUNCTIONS_H = CodeTemplate.from_file(template_path + '/Functions.h') FUNCTIONS_CPP = 
CodeTemplate.from_file(template_path + '/Functions.cpp') PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp') PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h') +PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp') +PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h') +PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h') PY_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_functions.h') PY_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_functions.cpp') @@ -146,9 +161,20 @@ deprecated_path = os.path.join(os.path.dirname(__file__), 'deprecated.yaml') # base at::Type FALLTHROUGH_RETURN_TYPES = {'int64_t', 'void*', 'bool', 'IntList'} FALLTHROUGH_FUNCTIONS = { - 'eye', 'linspace', 'logspace', 'tensor', 'ones', 'ones_like', 'rand', - 'randn' 'randperm', 'range', 'tensor', 'zeros', 'zeros_like', + 'arange', 'eye', 'linspace', 'logspace', 'tensor', 'ones', 'ones_like', + 'rand', 'randn', 'randperm', 'range', 'tensor', 'uniform', 'zeros', + 'zeros_like', 'set_', + # these are only implemented on integral types + '__and__', '__iand__', '__ilshift__', '__ior__', '__irshift__', '__ixor__', + '__lshift__', '__or__', '__rshift__', '__xor__', } +MANUAL_IMPLEMENTATIONS = { + 'contiguous', 'resize_', 'resize_as_' +} + +# Matches "foo" in "foo, bar" but not "foobar". Used to search for the +# occurence of a parameter in the derivative formula +IDENT_REGEX = r'(^|\W){}($|\W)' def format_return_type(returns): @@ -168,147 +194,222 @@ def write(dirname, name, template, env): f.write(template.substitute(env)) -def load_derivatives(path): - with open(path, 'r') as f: - definitions = yaml.load(f, Loader=Loader) +def saved_variables(formula, args): + # find which arguments need to be saved + saved = [] + + for arg in args: + if 'name' not in arg: + # some returned arguments do not have names + continue + name = arg['name'] + + def replace_sizes(m): + res = name + '_sizes' + saved.append({'name': res, 'type': 'IntList'}) + return res + + def replace_zeros(m): + r = name + '_info' + saved.append({'name': r, 'type': 'TypeAndSize'}) + return r + '.zeros()' + + def replace_size_n(m): + res = name + '_argsize_{}'.format(*m.groups()) + saved.append({'name': res, 'type': 'int64_t'}) + return res + + def replace_to_arg_sizes(m): + res = name + '_argsizes_{}'.format(*m.groups()) + saved.append({'name': res, 'type': 'IntList'}) + return res + + # replace self.sizes() with self_sizes + formula = re.sub(r'{}.sizes\(\)'.format(name), replace_sizes, formula) + # replace zeros_like(self) with self_info + formula = re.sub(r'zeros_like\({}\)'.format(name), replace_zeros, formula) + # replace self.size(2) with self_size_2 + formula = re.sub(r'{}.size\((\w+)\)'.format(name), replace_size_n, formula) + # replace to_arg_sizes(self, 2) with self_argsizes_2 + formula = re.sub(r'to_arg_sizes\({}, (\w+)\)'.format(name), replace_to_arg_sizes, formula) + + if re.search(IDENT_REGEX.format(name), formula): + arg = copy.deepcopy(arg) + arg['type'] = arg['type'].replace('const ', '').replace(' &', '') + saved.append(arg) + return formula, saved + + +def create_derivative(declaration, formula, output_indices, var_names): + returns = [r for r in declaration['returns'] if r.get('name') != 'self'] + arguments = declaration['arguments'] + if any(arg['name'] == 'inplace' for arg in arguments): + for arg in arguments: + if 
arg['name'] == 'input': + returns += [arg] + arguments = [arg for arg in arguments if arg['name'] != 'input'] + formula, saved_inputs = saved_variables(formula, arguments) + formula, saved_outputs = saved_variables(formula, returns) + + return { + 'formula': formula, + 'output_indices': output_indices, + 'saved_inputs': saved_inputs, + 'saved_outputs': saved_outputs, + 'var_names': var_names, + } + + +def create_autograd_function(name, derivatives, num_inputs, buffers=None): + return { + 'name': name, + 'op': to_camel_case(name) + 'Backward', + 'num_inputs': num_inputs, + 'derivatives': derivatives, + 'buffers': [] if buffers is None else buffers, + 'saved_inputs': all_saved_variables(derivatives, 'saved_inputs'), + 'saved_outputs': all_saved_variables(derivatives, 'saved_outputs'), + } + + +def all_saved_variables(derivatives, key): + seen = set() + saved = [] + for d in derivatives: + for saved_arg in d[key]: + if saved_arg['name'] in seen: + continue + seen.add(saved_arg['name']) + saved.append(saved_arg) + return saved - # Matches "foo" in "foo, bar" but not "foobar". The name is substituted for - # the {} characters. - name_regex = r'(^|\W){}($|\W)' - def split_name_params(prototype): - name, params = re.match('(\w+)\((.*)\)', prototype).groups() - return name, params.split(', ') +def to_camel_case(name): + return ''.join([p.title() for p in name.split('_')]) - def get_signature(option): - arguments = option['python_arguments'] - arg_types = [arg['type'] for arg in arguments] - if option['aten'] is not None: - call_args = split_name_params(option['aten'])[1] - arg_indices = {arg['name']: i for i, arg in enumerate(arguments)} - def get_type(arg_name): - if arg_name not in arg_indices: - # if the name is not an argument, assume it's a literal - # number, with type 'Scalar' - return 'Scalar' - return arg_types[arg_indices[arg_name]] +def split_name_params(prototype): + name, params = re.match('(\w+)\((.*)\)', prototype).groups() + return name, params.split(', ') - arg_types = [get_type(arg_name) for arg_name in call_args] - return '{}({})'.format(option['name'], ', '.join(arg_types)) + +def load_derivatives(path, declarations_by_signature): + with open(path, 'r') as f: + definitions = yaml.load(f, Loader=Loader) + + def canonical_declaration(declarations, name): + for declaration in declarations: + if declaration['name'] == name: + return declaration + # some functions only have in-place variants + assert name + '_' == declarations[0]['name'] + return declarations[0] # Parse each entry from derivatives.yaml - options = [] + autograd_functions = [] for defn in definitions: - option = {} if '(' not in defn['name']: continue + name, params = split_name_params(defn['name']) - num_tensor_inputs = 0 - option['name'] = name - option['aten'] = defn.get('aten') - option['python_arguments'] = [] - option['prototype'] = defn['name'] # with default - option['fallthrough'] = defn.get('fallthrough', False) - option['op'] = name[0].upper() + name[1:] + 'Backward' - - arg_sizes_found = [] + param_types = [p.split(' ')[0] for p in params if p != '*'] + signature = '{}({})'.format(name, ', '.join(param_types)) + + declarations = declarations_by_signature[signature] + if len(declarations) == 0: + raise RuntimeError('no ATen declaration found for: {}'.format(signature)) + canonical = canonical_declaration(declarations, name) + + num_inputs = 0 derivatives = [] - for param in params: - if param == '' or param == '*': + for arg in canonical['arguments']: + if arg['name'] not in defn: continue - arg = {} - 
arg['type'], name = param.split(' ') - if '=' in name: - name, default = name.split('=') - arg['optional'] = True - arg['default'] = default - arg['name'] = name - option['python_arguments'].append(arg) - - if name in defn: - saved = [] - formula = defn[name] - for arg in option['python_arguments']: - size_str = arg['name'] + '.sizes()' - if size_str in formula: - sizes_name = arg['name'] + '_sizes' - formula = formula.replace(size_str, sizes_name) - - # If x is a TensorList, turn x.sizes(y) into x_argsizes_y - def argsizes_repl(matchobj): - if arg['type'] != 'TensorList': - raise RuntimeError("sizes(argument) only supported on TensorList") - argsizes_name = arg['name'] + "_argsizes_" + matchobj.group(1) - arg_sizes_found.append(argsizes_name + ".size()") - return argsizes_name - formula = re.sub(arg['name'] + r".sizes\((\w+)\)", argsizes_repl, formula) - - # If x is a Tensor, turn x.size(y) into x_argsize_y - def argsize_repl(matchobj): - if arg['type'] != 'Tensor': - raise RuntimeError("size(argument) only supported on Tensor") - argsize_name = arg['name'] + "_argsize_" + matchobj.group(1) - return argsize_name - formula = re.sub(arg['name'] + r".size\((\w+)\)", argsize_repl, formula) - - derivatives.append(formula) - arg['derivative'] = formula - if arg['type'] != "TensorList": - num_tensor_inputs += 1 - - if arg_sizes_found: - option['num_inputs'] = ("+".join(arg_sizes_found) + - "" if num_tensor_inputs == 0 else " + " + str(num_tensor_inputs)) - else: - option['num_inputs'] = str(num_tensor_inputs) + formula = defn[arg['name']] + if arg['type'] == 'TensorList': + num_inputs = '' + output_indices = '*' + else: + output_indices = [num_inputs] + num_inputs += 1 + derivatives.append(create_derivative(canonical, formula, output_indices, [arg['name']])) - if option['aten'] is not None: - option['call_args'] = split_name_params(option['aten'])[1] - else: - option['call_args'] = [arg['name'] for arg in option['python_arguments']] - option['signature'] = get_signature(option) + func = create_autograd_function(name, derivatives, num_inputs) + func['__view__'] = defn.get('__view__', False) + autograd_functions.append(func) + for declaration in declarations: + declaration['derivative'] = func + + return autograd_functions + + +def ensure_unique_names(autograd_functions): + # de-duplicate operation names + functions_by_name = defaultdict(list) + for func in autograd_functions: + functions_by_name[func['op']].append(func) + for op in functions_by_name.keys(): + overloads = functions_by_name[op] + if len(overloads) > 1: + for i, func in enumerate(overloads): + func['op'] += str(i) + + +def preprocess_nn_functions(declarations): + declarations_by_name = defaultdict(list) + for d in declarations: + declarations_by_name[d['name']].append(d) + + autograd_functions = [] + for declaration in declarations: + name = declaration['name'] + if name == 'batch_norm' or 'conv' in name: + continue + + fwd_name = name + '_forward' + if fwd_name not in declarations_by_name: + continue + declaration['base_name'] = fwd_name + + fwd = declarations_by_name[fwd_name][0] - saved = [] - for arg in option['python_arguments']: + input_num = 0 + bwd_name = name + '_backward' + assert len(declarations_by_name[bwd_name]) == 1 + bwd = declarations_by_name[bwd_name][0] + + def actual(arg): name = arg['name'] - sizes_name = name + '_sizes' - if any(re.search(name_regex.format(name), f) for f in derivatives): - saved.append(arg) - if any(sizes_name in f for f in derivatives): - saved.append({ - 'name': sizes_name, - 'type': 
'IntList', - }) - for f in derivatives: - for match_name in re.findall(r"{}_argsize_\w+".format(name), f): - saved.append({ - 'name': match_name, - 'type': 'int64_t', - }) - for match_name in re.findall(r"{}_argsizes_\w+".format(name), f): - saved.append({ - 'name': match_name, - 'type': 'IntList', - }) - option['saved'] = saved - option['save_output'] = any(re.search(name_regex.format('output'), f) for f in derivatives) - - options.append(option) - - options = sorted(options, key=lambda o: o['name']) - for name, overloads in groupby(options, lambda o: o['name']): - overloads = list(overloads) - for i, option in enumerate(overloads): - name = option['name'] - option['op'] = name[0].upper() + name[1:] + 'Backward' - if len(overloads) > 1: - option['op'] += str(i) - - return options - - -def create_autograd_functions(top_env, declarations): + return name if name != 'inplace' else 'false' + + actuals = [actual(arg) for arg in bwd['arguments']] + formula = '{}({})'.format(bwd_name, ', '.join(actuals)) + formula = formula.replace('grad_output', 'grad') + if not re.search(IDENT_REGEX.format('grad'), formula): + formula = '({}).mul_(grad)'.format(formula) + + # we are computing the derivatives w.r.t these variables + var_names = [] + for ret in bwd['returns']: + assert ret['name'].startswith('grad_') + var_names.append(ret['name'][5:]) # remove grad_ prefix + output_indices = list(range(len(var_names))) + derivatives = [create_derivative(fwd, formula, output_indices, var_names)] + input_num += len(output_indices) + + # find arguments to foo_forward() call which don't exist in foo() + # these are buffers which have to be saved for the backwards call + args_by_name = {arg['name']: arg for arg in declaration['arguments']} + buffers = [arg['name'] for arg in fwd['arguments'] + if arg['name'] not in args_by_name] + + func = create_autograd_function(name, derivatives, input_num, buffers) + declaration['derivative'] = func + autograd_functions.append(func) + return autograd_functions + + +def create_autograd_functions(top_env, autogen_functions): """Functions.h and Functions.cpp body These contain the auto-generated subclasses of torch::autograd::Function @@ -318,78 +419,75 @@ def create_autograd_functions(top_env, declarations): function_declarations = top_env['autograd_function_declarations'] py_function_initializers = top_env['py_function_initializers'] - def process_function(op): + def process_function(func): + env = {} saved_variables = [] release_variables = [] - for arg in op['saved']: + unpack = [] + + def save_arg(arg, is_output): name = arg['name'] - if arg['type'] == 'Tensor': + if arg['type'] == 'Tensor' or (arg['type'] == 'Scalar' and is_output): saved_variables.append('SavedVariable {}_;'.format(name)) release_variables.append('{}_.data.reset();'.format(name)) + ptr = 'shared_from_this()' if is_output else '' + unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr)) elif arg['type'] == 'IntList': saved_variables.append('std::vector<int64_t> {};'.format(name)) else: saved_variables.append('{} {};'.format(arg['type'], name)) - if op['save_output']: - saved_variables.append('SavedVariable output_;') - op['saved_variables'] = saved_variables - op['release_variables'] = release_variables + + for arg in func['saved_inputs']: + save_arg(arg, is_output=False) + for arg in func['saved_outputs']: + save_arg(arg, is_output=True) + env['saved_variables'] = saved_variables + env['release_variables'] = release_variables + + def uses_grad(func): + for derivative in func['derivatives']: + 
formula = derivative['formula'] + if re.search(IDENT_REGEX.format('grad'), formula): + return True + return False body = [] - body.append('auto& grad = inputs[0];') - - def unpack_args(): - unpack = [] - for arg in op['saved']: - if arg['type'] == 'Tensor': - name = arg['name'] - unpack.append('auto {} = {}_.unpack();'.format(name, name)) - if op['save_output']: - unpack.append('auto output = output_.unpack(shared_from_this());') - return unpack - - body.extend(unpack_args()) - - i = 0 - added_derivative_tensor = False - added_derivative_tensorlist = False - for arg in op['python_arguments']: - derivative = arg.get('derivative') - if derivative is None: - continue - if arg['type'] == 'TensorList': - if added_derivative_tensor: - raise RuntimeError("derivatives don't support specifying both a TensorList " - "and non-TensorList derivative yet") - added_derivative_tensorlist = True - body.append(DERIVATIVE_TENSORLIST.substitute({ - 'i': i, - 'derivative': derivative, - })) + if uses_grad(func): + body.append('auto& grad = inputs[0];') + + def emit_derivative(derivative): + formula = derivative['formula'] + idxs = derivative['output_indices'] + if idxs == '*': + return DERIVATIVE_TENSORLIST.substitute(derivative=formula) + elif len(idxs) == 1: + return DERIVATIVE_TENSOR.substitute(idx=idxs[0], derivative=formula) else: - if added_derivative_tensorlist: - raise RuntimeError("derivatives don't support specifying both a TensorList " - "and non-TensorList derivative yet") - added_derivative_tensor = True - body.append(DERIVATIVE_TENSOR.substitute({ - 'i': i, - 'derivative': derivative, - })) - i += 1 + grad_inputs = ', '.join(['grad_inputs[{}]'.format(i) for i in idxs]) + masks = ['should_compute_output({}),'.format(i) for i in idxs] + return DERIVATIVE_MULTI.substitute( + idxs=idxs, derivative=formula, grad_inputs=grad_inputs, + masks=masks, n=len(idxs)) + + body.extend(unpack) + for derivative in func['derivatives']: + body.append(emit_derivative(derivative)) - op['body'] = body - function_declarations.append(FUNCTION_DECLARATION.substitute(op)) - function_definitions.append(FUNCTION_DEFINITION.substitute(op)) - py_function_initializers.append(PY_FUNCTION_DEFINITION.substitute(op)) + env['body'] = body + env = nested_dict(env, func) + function_declarations.append(FUNCTION_DECLARATION.substitute(env)) + function_definitions.append(FUNCTION_DEFINITION.substitute(env)) + py_function_initializers.append(PY_FUNCTION_DEFINITION.substitute(env)) - for option in declarations: - process_function(option) + for func in autogen_functions: + process_function(func) def is_implemented(option): return (option['return_type'] in FALLTHROUGH_RETURN_TYPES or option['name'] in FALLTHROUGH_FUNCTIONS or + option['name'].endswith('_backward') or option.get('derivative') is not None) @@ -404,44 +502,105 @@ def create_variable_type(top_env, aten_declarations): type_declarations = top_env['type_derived_method_declarations'] type_definitions = top_env['type_derived_method_definitions'] - def save_variables(option, derivative): + def skip_function(name): + return (name.endswith('_out') or name.endswith('_forward')) + + def differentiable_args(declaration, autograd_function): + names = set(name for d in autograd_function['derivatives'] for name in d['var_names']) + args = [arg for arg in declaration['arguments'] if arg['name'] in names] + if len(args) != len(names): + missing = names - set(arg['name'] for arg in args) + raise RuntimeError('Missing arguments for derivatives: {}'.format(missing)) + return args + + def 
save_variables(option, saved_variables, is_output): # assign the saved variables to the generated grad_fn stmts = [] - for arg in derivative['saved']: + for arg in saved_variables: name = arg['name'] expr = arg['name'] + if is_output and not option['inplace']: + if len(option['returns']) > 1: + # unpack multiple outputs + return_names = [r['name'] for r in option['returns']] + idx = return_names.index(name) + stmts.append('auto& {} = std::get<{}>(ret);'.format(name, idx)) + elif name != 'input': + stmts.append('auto& {} = ret;'.format(name)) if '_sizes' in name: expr = name.replace('_sizes', '.sizes()') + elif name.endswith('_info'): + expr = name.replace('_info', '') elif '_argsize_' in name: - # turn x_argsizes_y into to_arg_sizes(x, y) + # turn x_argsize_y into x.size(y) expr = re.sub(r"(\w+)_argsize_(\w+)", r"\1.size(\2)", name) elif '_argsizes_' in name: # turn x_argsizes_y into to_arg_sizes(x, y) expr = re.sub(r"(\w+)_argsizes_(\w+)", r"to_arg_sizes(\1, \2)", name) - elif arg['type'] == 'Tensor': + elif arg['type'] == 'Tensor' or (is_output and arg['type'] == 'Scalar'): name += '_' var = arg['name'] if var == 'self' and option['inplace']: var = 'self.clone()' - expr = 'SavedVariable({}, nullptr)'.format(var) + assert not is_output + if option['inplace'] and is_output: + var = 'self' + ptr = 'grad_fn.get()' if is_output else 'nullptr' + expr = 'SavedVariable({}, {})'.format(var, ptr) stmts.append('grad_fn->{} = {};'.format(name, expr)) + if len(stmts) > 0: + return CONDITIONAL.substitute( + cond='flags.is_executable', + statements=stmts) return stmts + def requires_unpack(arg): + return 'Tensor' in arg['dynamic_type'] + + def get_suffix(dynamic_type, is_nullable): + if is_nullable: + assert dynamic_type == 'Tensor' + return '_opt' + elif dynamic_type == 'IndexTensor': + return '_long' + elif dynamic_type == 'BoolTensor': + return '_byte' + else: + return '' + def unpack_args(env, option): body = [] unpacked_args = [] for i, arg in enumerate(option['arguments']): - if arg['dynamic_type'] == 'Tensor': - body.append(UNWRAP_TENSOR.substitute(arg_name=arg['name'], arg_pos=i)) - unpacked_args.append(arg['name'] + '_') - elif arg['dynamic_type'] == 'TensorList': - body.append(UNWRAP_TENSORLIST.substitute(arg_name=arg['name'], arg_pos=i)) - unpacked_args.append(arg['name'] + '_') - else: + if not requires_unpack(arg): unpacked_args.append(arg['name']) + continue + + dynamic_type = arg['dynamic_type'] + is_nullable = arg.get('is_nullable', False) + ref = (not is_nullable) and dynamic_type != 'TensorList' + suffix = get_suffix(dynamic_type, is_nullable) + + body.append(UNPACK_TENSOR.substitute( + arg_name=arg['name'], + arg_pos=i, + suffix=suffix, + ref='&' if ref else '', + )) + unpacked_args.append(arg['name'] + '_') + + if option.get('derivative') is not None: + for arg in option['derivative'].get('buffers', []): + unpacked_args.append(arg + '_') env['unpacked_args'] = unpacked_args return body + def emit_buffers(buffers): + res = [] + for name in buffers: + res.append(BUFFER_DECLARATION.substitute(name=name)) + return res + def emit_body(env, option): if not is_implemented(option): return METHOD_DEFINITION_NYI.substitute(option) @@ -453,55 +612,61 @@ def create_variable_type(top_env, aten_declarations): if option['return_type'] in FALLTHROUGH_RETURN_TYPES: body.extend(METHOD_DEFINITION_FALLTHROUGH.substitute(combined).split('\n')) return body - elif option['derivative'] is None: - assert option['name'] in FALLTHROUGH_FUNCTIONS - 
body.extend(METHOD_DEFINITION_FALLTHROUGH_VARIABLE.substitute(combined).split('\n')) + elif option['name'] in FALLTHROUGH_FUNCTIONS: + tmpl = (METHOD_DEFINITION_FALLTHROUGH_INPLACE if option['inplace'] + else METHOD_DEFINITION_FALLTHROUGH_VARIABLE) + body.extend(tmpl.substitute(combined).split('\n')) + return body + elif option.get('derivative') is None: + assert option['name'].endswith('_backward'), option['name'] + body.extend(METHOD_DEFINITION_NOT_DIFFERENTIABLE.substitute(combined).split('\n')) return body - if combined['tensorlist_args']: - flags_def = METHOD_DEFINITION_FLAGS_TENSORLIST.substitute(combined) - if combined['tensor_args']: - raise RuntimeError("both tensorlist_args and tensor_args not currently supported") - else: - flags_def = METHOD_DEFINITION_FLAGS_TENSORS.substitute(combined) if option['inplace']: - body.extend(METHOD_DEFINITION_INPLACE.substitute(combined, flags_def=flags_def).split('\n')) + body.extend(METHOD_DEFINITION_INPLACE.substitute(combined).split('\n')) else: - body.extend(METHOD_DEFINITION_DERIVATIVE.substitute(combined, flags_def=flags_def).split('\n')) + body.extend(METHOD_DEFINITION_DERIVATIVE.substitute(combined).split('\n')) return body - def process_function(option): - env = {} + def process_function(declaration): + if skip_function(declaration['name']): + return - if option['inplace']: + env = { + 'version_counter': [], + } + + if declaration['inplace']: env['return_value'] = 'self' - return_name = 'self' else: - if option['return_type'] == 'Scalar': - env['return_value'] = 'Scalar(output)' - else: - env['return_value'] = 'Tensor(std::move(output))' - return_name = 'output' + env['return_value'] = '{}(std::move(ret))'.format(declaration['return_type']) + + if declaration.get('derivative') is not None: + func = declaration['derivative'] + env['op'] = func['op'] + env['buffers'] = emit_buffers(func.get('buffers', [])) + env['save_inputs'] = save_variables(declaration, func['saved_inputs'], False) + env['save_outputs'] = save_variables(declaration, func['saved_outputs'], True) + dargs = differentiable_args(declaration, func) + env['tensor_args'] = [arg['name'] for arg in dargs] + if any(arg['name'] == 'inplace' for arg in declaration['arguments']): + env['version_counter'].append('if (inplace) increment_version(input);') + if func.get('__view__', False): + env['version_counter'].append('take_version_counter(ret, self);') - if option.get('derivative') is not None: - derivative = option['derivative'] - env['op'] = derivative['op'] - env['save_variables'] = save_variables(option, derivative) - env['save_output'] = SAVE_OUTPUT.substitute(return_name=return_name) if derivative['save_output'] else '' - env['tensor_args'] = [arg['name'] for arg in option['arguments'] - if arg['dynamic_type'] == 'Tensor'] - env['tensorlist_args'] = [arg['name'] for arg in option['arguments'] - if arg['dynamic_type'] == 'TensorList'] + else: + env['tensor_args'] = [arg['name'] for arg in declaration['arguments'] + if arg['simple_type'] in {'Tensor', 'TensorList'}] - env['type_definition_body'] = emit_body(env, option) + env['type_definition_body'] = emit_body(env, declaration) - combined = nested_dict(env, option) + combined = nested_dict(env, declaration) type_declarations.append(METHOD_DECLARATION.substitute(combined)) - if option['name'] != 'resize_': + if declaration['name'] not in MANUAL_IMPLEMENTATIONS: type_definitions.append(METHOD_DEFINITION.substitute(combined)) - for function in aten_declarations: - process_function(function) + for declaration in aten_declarations: + 
process_function(declaration) def load_aten_declarations(path): @@ -523,6 +688,14 @@ def load_aten_declarations(path): declaration['api_name'] = declaration['name'] declaration['return_type'] = format_return_type(declaration['returns']) + declaration['base_name'] = declaration['name'] + + # if the return value is missing a name, call it 'output' + for ret in declaration['returns']: + if 'name' not in ret: + assert len(declaration['returns']) == 1 + ret['name'] = 'result' + # Compute the Python function prototype for argument parsing typed_args = [] positional = True @@ -530,10 +703,13 @@ def load_aten_declarations(path): if arg.get('kwarg_only', False) and positional: typed_args.append('*') positional = False - param = arg['simple_type'] + ' ' + arg['name'] + typename = arg['simple_type'] + if arg.get('size') is not None: + typename = '{}[{}]'.format(typename, arg['size']) + param = typename + ' ' + arg['name'] if arg.get('default') is not None: default = arg['default'] - if default == 'nullptr': + if default == 'nullptr' or default == '{}': default = 'None' param += '=' + str(default) typed_args.append(param) @@ -546,46 +722,64 @@ def load_aten_declarations(path): def load_deprecated_signatures(declarations_by_signature): + with open(deprecated_path, 'r') as f: + deprecated_defs = yaml.load(f, Loader=Loader) declarations = [] - for deprecated in load_derivatives(deprecated_path): - declaration = declarations_by_signature[deprecated['signature']][0] - declaration = copy.deepcopy(declaration) - declaration['deprecated'] = True - args_by_name = {arg['name']: arg for arg in declaration['arguments']} - declaration['arguments'] = [ - args_by_name[arg['name']] for arg in deprecated['python_arguments']] - declaration['call_args'] = deprecated['call_args'] - declaration['prototype'] = deprecated['prototype'] - declarations.append(declaration) + + def get_signature(name, params, call_args): + # create a mapping of parameter name to parameter type + types = dict([param.split(' ')[::-1] for param in params]) + # if the name in the call is not in the parameter list, assume it's + # a literal Scalar + rearranged_types = [types.get(arg, 'Scalar') for arg in call_args] + return '{}({})'.format(name, ', '.join(rearranged_types)) + + for deprecated in deprecated_defs: + prototype = deprecated['name'] + call_args = split_name_params(deprecated['aten'])[1] + name, params = split_name_params(prototype) + signature = get_signature(name, params, call_args) + + for declaration in declarations_by_signature[signature]: + declaration = copy.deepcopy(declaration) + declaration['deprecated'] = True + declaration['call_args'] = call_args + if declaration['inplace']: + declaration['prototype'] = prototype.replace(name, name + '_') + else: + declaration['prototype'] = prototype + + args_by_name = {arg['name']: arg for arg in declaration['arguments']} + declaration['arguments'] = [] + for arg in params: + _, arg_name = arg.split(' ') + declaration['arguments'].append(args_by_name[arg_name]) + declarations.append(declaration) return declarations def gen_variable_type(declarations, out): aten_decls = load_aten_declarations(declarations) - derivatives = load_derivatives(derivatives_path) def by_name(option): return option['name'] - def by_aten_name(option): - return option.get('aten_name', option['name']) + def group_declarations_by_signature(): + d = defaultdict(list) + for declaration in aten_decls: + name = declaration['name'] + base_name = name[:-1] if declaration['inplace'] else name + simple_types = 
[arg['simple_type'] for arg in declaration['arguments']] + signature = '{}({})'.format(base_name, ', '.join(simple_types)) + d[signature].append(declaration) + return d - aten_decls = sorted(aten_decls, key=by_name) - derivatives = sorted(derivatives, key=by_aten_name) + declarations_by_signature = group_declarations_by_signature() - derivatives_by_signature = {d['signature']: d for d in derivatives} - options_by_name = OrderedDict([(k, list(g)) for k, g in groupby(aten_decls, by_name)]) - options_by_signature = defaultdict(list) - - for declaration in aten_decls: - name = declaration['name'] - base_name = name[:-1] if declaration['inplace'] else name - simple_types = [arg['simple_type'] for arg in declaration['arguments']] - signature = '{}({})'.format(base_name, ', '.join(simple_types)) - options_by_signature[signature].append(declaration) - - derivative = derivatives_by_signature.get(signature) - declaration['derivative'] = derivative + th_autograd_funcs = load_derivatives(derivatives_path, declarations_by_signature) + nn_autograd_funcs = preprocess_nn_functions(aten_decls) + all_autograd_functions = th_autograd_funcs + nn_autograd_funcs + ensure_unique_names(all_autograd_functions) def should_generate_python_binding(declaration): name = declaration['name'] @@ -597,6 +791,9 @@ def gen_variable_type(declarations, out): if name in ['size', 'stride']: return False + if name.endswith('_backward'): + return False + # we don't currently support functions which are only defined on Type # such as zeros(), randn(), etc. method_of = declaration['method_of'] @@ -605,15 +802,19 @@ def gen_variable_type(declarations, out): return True - python_functions = defaultdict(list) + py_variable_methods = defaultdict(list) + py_nn_functions = defaultdict(list) for declaration in aten_decls: name = declaration['name'] if not should_generate_python_binding(declaration): continue - python_functions[name].append(declaration) + if declaration['mode'] == 'NN': + py_nn_functions[name].append(declaration) + else: + py_variable_methods[name].append(declaration) - for declaration in load_deprecated_signatures(options_by_signature): - python_functions[declaration['name']].append(declaration) + for declaration in load_deprecated_signatures(declarations_by_signature): + py_variable_methods[declaration['name']].append(declaration) env = { 'autograd_function_declarations': [], @@ -624,17 +825,28 @@ def gen_variable_type(declarations, out): 'py_method_defs': [], 'py_method_dispatch': [], 'py_function_initializers': [], + 'py_nn_functions': [], + 'py_nn_function_defs': [], + 'py_nn_function_dispatch': [], } - create_autograd_functions(env, derivatives) + create_autograd_functions(env, all_autograd_functions) create_variable_type(env, aten_decls) from .gen_python_functions import create_python_bindings create_python_bindings( - python_functions, + py_variable_methods, env['py_methods'], env['py_method_defs'], - env['py_method_dispatch']) + env['py_method_dispatch'], + is_class=True) + + create_python_bindings( + py_nn_functions, + env['py_nn_functions'], + env['py_nn_function_defs'], + env['py_nn_function_dispatch'], + is_class=False) write(out, 'VariableType.h', VARIABLE_TYPE_H, env) write(out, 'VariableType.cpp', VARIABLE_TYPE_CPP, env) @@ -642,6 +854,9 @@ def gen_variable_type(declarations, out): write(out, 'Functions.cpp', FUNCTIONS_CPP, env) write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env) write(out, 'python_variable_methods_dispatch.h', PY_VARIABLE_DISPATCH_H, env) + write(out, 
'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env) + write(out, 'python_nn_functions.h', PY_NN_FUNCTIONS_H, env) + write(out, 'python_nn_functions_dispatch.h', PY_NN_DISPATCH_H, env) write(out, 'python_functions.h', PY_FUNCTIONS_H, env) write(out, 'python_functions.cpp', PY_FUNCTIONS_CPP, env) diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index a734013a9d..71fc176543 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -6,7 +6,14 @@ using at::Tensor; using at::Scalar; using at::IntList; -namespace torch { namespace autograd { +namespace torch { namespace autograd { namespace generated { + +namespace { + +Tensor not_implemented(const char* name) { + throw std::runtime_error( + std::string("the derivative for '") + name + "' is not implemented"); +} Tensor maybe_multiply(const Tensor & t, const Scalar & s) { bool is_one = false; @@ -25,20 +32,45 @@ Tensor maybe_multiply(const Tensor & t, const Scalar & s) { Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p_) { auto p = p_.toDouble(); + auto norm = self.norm(p_); + + if (norm.toDouble() == 0.0) { + // handle case at 0 where we return a subgradient containing 0 + return zeros_like(self); + } + if (p == 2.0) { - return self * (grad / self.norm(2)); + return self * (grad / norm); } else { auto pow_ = self.abs().pow(p - 2); - auto scale_v = grad / self.norm(p).toTensor().pow(p - 1); + auto scale_v = grad / norm.toTensor().pow(p - 1); return self * pow_ * scale_v; } } -Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p, int64_t dim, bool keepdim) { - throw std::runtime_error("norm_backward(dim): NYI"); +Tensor norm_backward(Tensor grad, const Tensor & self, const Scalar & p_, int64_t dim, bool keepdim) { + if (!keepdim && self.dim() > 1) { + grad = grad.unsqueeze(dim); + } + auto p = p_.toDouble(); + auto norm = self.norm(p, dim, true); + Tensor grad_input; + if (p == 2.0) { + grad_input = self * (grad / norm); + } else { + auto pow_ = self.abs().pow(p - 2); + auto scale_v = grad / norm.pow(p - 1); + grad_input = self * pow_ * scale_v; + } + // handle case at 0 where we return a subgradient containing 0 + grad_input.masked_fill_(norm == 0, 0); + return grad_input; } Tensor reduce_to(const Tensor & grad, IntList sizes) { + if (sizes.size() == 0) { + return grad.sum().toTensor(); + } Tensor result = grad; while (result.dim() > (int64_t)sizes.size()) { result = result.sum(0, false); @@ -52,7 +84,7 @@ Tensor reduce_to(const Tensor & grad, IntList sizes) { } Tensor sum_backward(const Tensor & grad, IntList sizes, int64_t dim, bool keepdim) { - if (!keepdim) { + if (!keepdim && sizes.size() > 1) { return grad.unsqueeze(dim).expand(sizes); } else { return grad.expand(sizes); @@ -78,6 +110,13 @@ Tensor unsqueeze_to(const Tensor & self, IntList sizes) { return result; } +Tensor maybe_unsqueeze(const Tensor & self, int64_t dim, int64_t prev_size) { + if (prev_size == 1) { + return self.unsqueeze(dim); + } + return self; +} + Tensor addmm_self_backward(const Tensor & grad, const Scalar &beta) { return maybe_multiply(grad, beta); } @@ -114,6 +153,202 @@ variable_list cat_tensors_backward(const Tensor & grad, const std::vector<int64_ return grad_inputs; } +Tensor select_backward_scalar(Tensor grad, const Tensor & input, const Tensor & value) { + if (grad.dim() == 1) { + // TODO: remove this once zero-dim tensor work properly in PyTorch + grad = grad.view({}); + } + auto grad_input = zeros_like(input); + 
grad_input.masked_fill_(input == value, Scalar(grad)); + return grad_input; +} + +Tensor select_backward(Tensor grad, int64_t dim, Tensor indices, IntList sizes, bool keepdim) { + if (!keepdim && sizes.size() > 1) { + grad = grad.unsqueeze(dim); + indices = indices.unsqueeze(dim); + } + return grad.type().zeros(sizes).scatter_(dim, indices, grad); +} + +Tensor trace_backward(const Tensor & grad, IntList sizes) { + if (sizes.size() != 2) { + throw std::runtime_error("expected matrix input"); + } + + // TODO: simplify once toScalarType is virtual + auto& long_type = *VariableImpl::getType( + Variable(grad).data().type().toScalarType(at::kLong)); + + auto grad_input = grad.type().zeros(sizes[0] * sizes[1]); + auto indices = long_type.arange(0, grad_input.numel(), sizes[1] + 1); + grad_input.index_fill_(0, indices, Scalar(grad.view({}))); + return grad_input.view(sizes); +} + +Tensor unfold_backward(const Tensor & grad, IntList input_sizes, int64_t dim, int64_t size, int64_t step) { + // TODO: simplify once toScalarType is virtual + auto& long_type = *VariableImpl::getType( + Variable(grad).data().type().toScalarType(at::kLong)); + + int64_t numel = 1; + for (auto size : input_sizes) { + numel *= size; + } + + auto idx = long_type.arange(0, numel).view(input_sizes); + auto idx_unfolded = idx.unfold(dim, size, step).contiguous().view(-1); + auto grad_input = grad.type().zeros({numel}); + grad_input.index_add_(0, idx_unfolded, grad.contiguous().view(-1)); + return grad_input.view(input_sizes); +} + +Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntList sizes) { + int64_t numel = 1; + for (auto size : sizes) { + numel *= size; + } + auto mask_selected = grad.masked_select(mask); + auto diff_nelem = numel - mask_selected.numel(); + if (diff_nelem > 0) { + // because mask_selected returns a 1-d tensor with size of masked elements that are 1, + // we need to fill out the rest with zeros then reshape back to tensor2's size. + auto zeros_fillin = grad.type().zeros({diff_nelem}); + mask_selected = at::cat({mask_selected, zeros_fillin}, 0); + } + return mask_selected.view(sizes); +} + +Tensor potrf_backward(Tensor grad, bool upper, Tensor L) { + // cf. 
Iain Murray (2016); arXiv 1602.07527 + if (upper) { + L = L.t(); + grad = grad.t(); + } + + auto phi = [](const Tensor & A) -> Tensor { + auto B = A.tril(); + B = B - 0.5 * at::diag(at::diag(B)); + return B; + }; + + // make sure not to double-count variation, since + // only half of output matrix is unique + auto Lbar = grad.tril(); + + auto P = phi(at::mm(L.t(), Lbar)); + Tensor S; + std::tie(S, std::ignore) = at::gesv(P + P.t(), L.t()); + std::tie(S, std::ignore) = at::gesv(S.t(), L.t()); + S = phi(S); + return S; +} + +Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) { + auto& gO = grad_output; + auto input_size = input.size(dim) / 2; + auto first_half = input.narrow(dim, 0, input_size); + auto second_half = input.narrow(dim, input_size, input_size); + auto sig_second_half = second_half.sigmoid(); + auto one_sub_sig_second_half = 1 - sig_second_half; + auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half; + + auto ggI_first_half = grad.narrow(dim, 0, input_size); + auto ggI_second_half = grad.narrow(dim, input_size, input_size); + auto ggI_second_half_times_first_half = ggI_second_half * first_half; + + auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig; + auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig; + auto gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig; + return at::cat({gI_first_half, gI_second_half}, dim); +} + +Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) { + if (dim < 0) dim += input.dim(); + auto sizes = std::vector<int64_t>{input.sizes()}; + sizes[dim] /= 2; + auto tmp = grad * glu_backward(input.type().ones(sizes), input, dim); + return tmp.narrow(dim, 0, sizes[dim]) + tmp.narrow(dim, sizes[dim], sizes[dim]); +} + +Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) { + auto z = input.sigmoid(); + return grad * (z - 1) * z; +} + +Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) { + auto gO = grad_output; + auto ggI = grad; + + auto ggI_output = ggI * output; + auto ggI_out_sum = ggI_output.sum(dim, true); + auto ggI_out_sum_output = ggI_out_sum * output; + auto gO_out_sum = (gO * output).sum(dim, true); + + // gI calculation + auto gI_t0 = ggI_output * (gO - gO_out_sum); + auto gI_t1 = output * ((ggI_output * gO).sum(dim, true).sub_(gO_out_sum * ggI_out_sum)); + auto gI_t2 = ggI_out_sum_output * gO; + auto gI_t3 = ggI_out_sum_output * gO_out_sum; + return gI_t0 - gI_t1 - gI_t2 + gI_t3; +} + +Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) { + auto z = output.exp(); + return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad); +} + +Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, bool size_average) { + auto d = (input - target).abs(); + auto grad_input = grad * (d < 1).toType(grad.type()); + if (size_average) { + grad_input /= input.numel(); + } + return grad_input; +} + +Tensor max_pool2d_double_backward(const Tensor & grad, const Tensor & indices) { + // fold the first two dims together and the last two together + auto fold = [](const Tensor & t) -> Tensor { + auto sizes = t.sizes(); + return t.contiguous().view({sizes[0] * sizes[1], sizes[2] * sizes[3]}); + }; + return fold(grad).gather(1, 
fold(indices)).view(indices.sizes()); +} + +Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, bool size_average, bool reduce) { + auto grad_input = 2 * grad; + if (size_average && reduce) { + grad_input /= input.numel(); + } + return grad_input; +} + +Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, bool size_average, bool reduce) { + if (!reduce) { + return mse_loss_backward(grad, input, target, size_average, reduce); + } + auto r = mse_loss_backward(ones_like(grad_output), input, target, size_average, true); + return (r * grad).sum().toTensor().view({1}); +} + +Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, bool size_average) { + auto z = (input * -target).exp(); + auto zplus1 = z + 1; + auto grad_input = grad * (target * target) * z / (zplus1 * zplus1); + if (size_average) { + grad_input /= input.numel(); + } + return grad_input; +} + +Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, Scalar beta, Scalar threshold) { + auto x = (input * beta); + return _sigmoid_backward(grad, x.sigmoid()) * (x < threshold).toType(grad.type()) * beta; +} + +} + ${autograd_function_definitions} -}} // namespace torch::autograd +}}} // namespace torch::autograd::generated diff --git a/tools/autograd/templates/Functions.h b/tools/autograd/templates/Functions.h index 1ce1b5fd76..aa672be130 100644 --- a/tools/autograd/templates/Functions.h +++ b/tools/autograd/templates/Functions.h @@ -8,15 +8,30 @@ #include "torch/csrc/autograd/variable.h" #include "torch/csrc/autograd/saved_variable.h" -namespace torch { namespace autograd { +namespace torch { namespace autograd { namespace generated { using at::Scalar; using at::Tensor; using at::IntList; +using at::Type; + +struct TypeAndSize { + TypeAndSize() : type(nullptr) {} + /* implicit */ + TypeAndSize(const Tensor & t) + : sizes(t.sizes()) + , type(&t.type()) {} + + Tensor zeros() { return type->zeros(sizes); } + +private: + std::vector<int64_t> sizes; + Type* type; +}; // avoid multiply if scalar is 1.
inline Tensor maybe_multiply(const Tensor & t, const Scalar & s); ${autograd_function_declarations} -}} // namespace torch::autograd +}}} // namespace torch::autograd::generated diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 5708ef39c7..df0631e86c 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -7,6 +7,7 @@ #include "torch/csrc/autograd/saved_variable.h" #include "torch/csrc/autograd/generated/Functions.h" #include "torch/csrc/autograd/functions/tensor.h" +#include "torch/csrc/autograd/functions/basic_ops.h" #include <initializer_list> #include <iostream> @@ -19,12 +20,14 @@ #endif using namespace at; +using namespace torch::autograd::generated; namespace torch { namespace autograd { VariableType::VariableType(Context* context, Type* baseType) : Type(context) , baseType(baseType) { + str = std::string("Variable[") + baseType->toString() + "]"; } ScalarType VariableType::scalarType() const { @@ -54,7 +57,7 @@ std::unique_ptr<Generator> VariableType::generator() const { } const char * VariableType::toString() const { - return VariableType::typeString(); + return str.c_str(); } size_t VariableType::elementSizeInBytes() const { return baseType->elementSizeInBytes(); @@ -67,44 +70,88 @@ const char * VariableType::typeString() { return "VariableType"; } -Tensor & VariableType::checked_unpack(const Tensor & t, const char * name, int pos) const -{ - if(!t.defined()) { - runtime_error("Expected a Tensor of type %s but found an undefined Tensor for argument #%d '%s'", - toString(), pos, name); - } - if (&t.type() == this) { - return static_cast<VariableImpl*>(t.pImpl)->data; - } - runtime_error("Expected object of type %s but found type %s for argument #%d '%s'", - toString(),t.type().toString(), pos, name); -} - -std::vector<at::Tensor> VariableType::checked_unpack(const at::TensorList &tl, const char *name, int pos) const { - std::vector<at::Tensor> ret(tl.size()); - for (size_t i = 0; i < tl.size(); ++i) { - const auto &t = tl[i]; - if(!t.defined()) { - runtime_error("Expected a Tensor of type %s but found an undefined Tensor at position #%d " - "for iterable argument #%d '%s'", - toString(), i, pos, name); - } - if (&t.type() == this) { - ret[i] = static_cast<VariableImpl*>(t.pImpl)->data; - } else { - runtime_error("Expected object of type %s but found type %s at position #%d " - "for iterable argument #%d '%s'", - toString(),t.type().toString(), i, pos, name); - } +Variable & VariableType::checked_cast(const Type & type, const Tensor & t, const char * name, int pos) { + if(!t.defined()) { + runtime_error("Expected a Tensor of type %s but found an undefined Tensor for argument #%d '%s'", + type.toString(), pos, name); } - return ret; + if (&t.type() != &type) { + runtime_error("Expected object of type %s but found type %s for argument #%d '%s'", + type.toString(), t.type().toString(), pos, name); + } + return static_cast<Variable&>(const_cast<Tensor&>(t)); +} + +Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) const { + return checked_cast(*this, t, name, pos).data(); +} + +Tensor & VariableType::unpack_long(const Tensor & t, const char * name, int pos) const { + auto& type = *VariableImpl::getType(baseType->toScalarType(kLong)); + return checked_cast(type, t, name, pos).data(); +} + +Tensor & VariableType::unpack_byte(const Tensor & t, const char * name, int pos) const { + auto& type = *VariableImpl::getType(baseType->toScalarType(kByte)); + return 
checked_cast(type, t, name, pos).data(); +} + +Tensor & VariableType::unpack_any(const Tensor & t, const char * name, int pos) const { + if (!t.defined()) { + runtime_error("Expected a Tensor of type Variable but found an undefined Tensor for argument #%d '%s'", + pos, name); + } + auto scalarType = t.type().scalarType(); + auto& type = *VariableImpl::getType(baseType->toScalarType(scalarType)); + return checked_cast(type, t, name, pos).data(); } +Tensor VariableType::unpack_opt(const Tensor & t, const char * name, int pos) const { + if(!t.defined()) { + return Tensor(); + } + return unpack(t, name, pos); +} + +std::vector<at::Tensor> VariableType::unpack(const at::TensorList &tl, const char *name, int pos) const { + std::vector<at::Tensor> ret(tl.size()); + for (size_t i = 0; i < tl.size(); ++i) { + const auto &t = tl[i]; + if (!t.defined()) { + runtime_error("Expected a Tensor of type %s but found an undefined Tensor at position #%d " + "for iterable argument #%d '%s'", + toString(), i, pos, name); + } + if (&t.type() == this) { + ret[i] = static_cast<VariableImpl*>(t.pImpl)->data; + } else { + runtime_error("Expected object of type %s but found type %s at position #%d " + "for iterable argument #%d '%s'", + toString(),t.type().toString(), i, pos, name); + } + } + return ret; +} Variable VariableType::as_variable(Tensor tensor) const { return make_variable(std::move(tensor)); } +std::tuple<Variable, Variable> +VariableType::as_variable(std::tuple<Tensor, Tensor> tensors) const { + return std::make_tuple<>( + make_variable(std::move(std::get<0>(tensors))), + make_variable(std::move(std::get<1>(tensors)))); +} + +std::tuple<Variable, Variable, Variable> +VariableType::as_variable(std::tuple<Tensor, Tensor, Tensor> tensors) const { + return std::make_tuple<>( + make_variable(std::move(std::get<0>(tensors))), + make_variable(std::move(std::get<1>(tensors))), + make_variable(std::move(std::get<2>(tensors)))); +} + Variable VariableType::as_variable(const Scalar & scalar) const { auto tensor = scalar.toTensor(); if (&tensor.type() != baseType) { @@ -113,7 +160,7 @@ Variable VariableType::as_variable(const Scalar & scalar) const { return make_variable(std::move(tensor)); } -void check_inplace(const VariableImpl& pImpl) { +static void check_inplace(const VariableImpl& pImpl) { auto& version_counter = *pImpl.version_counter; if (pImpl.requires_grad && !pImpl.grad_fn) { at::runtime_error( @@ -127,7 +174,7 @@ void check_inplace(const VariableImpl& pImpl) { } } -void wrap_output(VariableImpl& pImpl, FunctionFlags flags, std::shared_ptr<Function> grad_fn) { +static void wrap_output(VariableImpl& pImpl, FunctionFlags flags, std::shared_ptr<Function> grad_fn) { // Hooks up the grad_fn and sets the flags of the function output. This only // supports a single differentiable output. 
pImpl.requires_grad = flags.is_executable; @@ -139,19 +186,59 @@ void wrap_output(VariableImpl& pImpl, FunctionFlags flags, std::shared_ptr<Funct } } +static void wrap_output(Tensor& t, FunctionFlags flags, std::shared_ptr<Function> grad_fn) { + auto pImpl = static_cast<VariableImpl*>(t.get()); + wrap_output(*pImpl, std::move(flags), std::move(grad_fn)); +} + +static void wrap_output(std::tuple<Variable, Variable>& t, FunctionFlags flags, std::shared_ptr<Function> grad_fn) { + wrap_output(std::get<0>(t), std::move(flags), std::move(grad_fn)); +} + +static void wrap_output(std::tuple<Variable, Variable, Variable>& t, FunctionFlags flags, std::shared_ptr<Function> grad_fn) { + wrap_output(std::get<0>(t), std::move(flags), std::move(grad_fn)); +} + +static void increment_version(const Tensor & t) { + auto pImpl = static_cast<VariableImpl*>(t.get()); + (*pImpl->version_counter)++; +} + +static void take_version_counter(Tensor & dst, const Tensor & src) { + // replaces the version counter in dst with the one in src + // call when dst is a view of src + auto srcImpl = static_cast<VariableImpl*>(src.get()); + auto dstImpl = static_cast<VariableImpl*>(dst.get()); + dstImpl->version_counter->join_with(*srcImpl->version_counter); +} + +static bool isFloatingPoint(ScalarType s) { + return s == kFloat || s == kDouble || s == kHalf; +} + void VariableType::s_copy(const Tensor & src, Tensor & dst) const { - auto& src_ = checked_unpack(src, "src", 0); - auto& dst_ = checked_unpack(dst, "dst", 1); + // TODO: once copy is exposed in Declarations.yaml we may be able to bind + // it automatically + auto& src_ = unpack_any(src, "src", 0); + auto& dst_ = unpack(dst, "dst", 1); auto& pImpl = static_cast<VariableImpl&>(*dst.get()); check_inplace(pImpl); auto flags = Function::flags({ src }); baseType->s_copy(src_, dst_); (*pImpl.version_counter)++; - wrap_output(pImpl, std::move(flags), std::make_shared<Identity>()); + if (isFloatingPoint(dst.type().scalarType())) { + if (isFloatingPoint(src.type().scalarType())) { + // TODO: handle type conversions + wrap_output(pImpl, std::move(flags), std::make_shared<Identity>()); + } else { + // TODO: handle type conversions + wrap_output(pImpl, std::move(flags), std::make_shared<Identity>()); + } + } } Tensor & VariableType::m_resize_(Tensor & self, IntList size) const { - auto& self_ = checked_unpack(self, "self", 0); + auto& self_ = unpack(self, "self", 0); auto& pImpl = static_cast<VariableImpl&>(*self.get()); check_inplace(pImpl); if (pImpl.grad_fn) { @@ -164,6 +251,18 @@ Tensor & VariableType::m_resize_(Tensor & self, IntList size) const { return self; } +Tensor & VariableType::m_resize_as_(Tensor & self, const Tensor & the_template) const { + return m_resize_(self, the_template.sizes()); +} + +Tensor VariableType::m_contiguous(const Tensor & self) const { + unpack(self, "self", 0); + if (self.is_contiguous()) { + return self; + } + return self.clone(); +} + std::vector<int64_t> to_arg_sizes(TensorList tensors, int64_t dim) { std::vector<int64_t> arg_sizes(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h index ecda7e89b1..2a1b93a655 100644 --- a/tools/autograd/templates/VariableType.h +++ b/tools/autograd/templates/VariableType.h @@ -3,6 +3,7 @@ // ${generated_comment} #include <ATen/ATen.h> +#include <string> namespace torch { namespace autograd { @@ -38,13 +39,23 @@ struct VariableType : public at::Type { ${type_derived_method_declarations} private: - 
at::Tensor & checked_unpack(const Tensor & t, const char * name, int pos) const; - std::vector<at::Tensor> checked_unpack(const at::TensorList &tl, const char *name, int pos) const; - Variable as_variable(Tensor tensor) const; + // checks that t is actually a Variable with the given expected_type + static Variable & checked_cast(const Type & expected_type, const Tensor & t, const char * name, int pos); + at::Tensor & unpack(const Tensor & t, const char * name, int pos) const; + at::Tensor & unpack_long(const Tensor & t, const char * name, int pos) const; + at::Tensor & unpack_byte(const Tensor & t, const char * name, int pos) const; + at::Tensor & unpack_any(const Tensor & t, const char * name, int pos) const; + at::Tensor unpack_opt(const Tensor & t, const char * name, int pos) const; + std::vector<at::Tensor> unpack(const at::TensorList &tl, const char *name, int pos) const; + Variable as_variable(const Scalar & scalar) const; + Variable as_variable(Tensor tensor) const; + std::tuple<Variable, Variable> as_variable(std::tuple<Tensor, Tensor> tensor) const; + std::tuple<Variable, Variable, Variable> as_variable(std::tuple<Tensor, Tensor, Tensor> tensor) const; private: at::Type* baseType; + std::string str; }; }} // namespace torch::autograd diff --git a/tools/autograd/templates/python_functions.cpp b/tools/autograd/templates/python_functions.cpp index 1bc2019c2e..b0f5efa891 100644 --- a/tools/autograd/templates/python_functions.cpp +++ b/tools/autograd/templates/python_functions.cpp @@ -8,7 +8,7 @@ #include "Functions.h" #include "torch/csrc/autograd/python_cpp_function.h" -namespace torch { namespace autograd { +namespace torch { namespace autograd { namespace generated { template<typename C> static void addClass(PyTypeObject& type, const char* name, @@ -23,4 +23,4 @@ void initialize_autogenerated_functions() { ${py_function_initializers} } -}} // namespace torch::autograd +}}} // namespace torch::autograd::generated diff --git a/tools/autograd/templates/python_functions.h b/tools/autograd/templates/python_functions.h index af28d20e92..c434e57b32 100644 --- a/tools/autograd/templates/python_functions.h +++ b/tools/autograd/templates/python_functions.h @@ -4,8 +4,8 @@ // Python bindings for automatically generated autograd functions -namespace torch { namespace autograd { +namespace torch { namespace autograd { namespace generated { void initialize_autogenerated_functions(); -}} // namespace torch::autograd +}}} // namespace torch::autograd::generated diff --git a/tools/autograd/templates/python_nn_functions.cpp b/tools/autograd/templates/python_nn_functions.cpp new file mode 100644 index 0000000000..0cafff1073 --- /dev/null +++ b/tools/autograd/templates/python_nn_functions.cpp @@ -0,0 +1,46 @@ +#include "python_nn_functions.h" + +// ${generated_comment} + +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/utils/python_arg_parser.h" + +#include "python_nn_functions_dispatch.h" + +using at::Tensor; +using at::Scalar; +using namespace torch::autograd::utils; + +namespace torch { namespace autograd { + +${py_nn_functions} + +static PyMethodDef nn_functions[] = { + ${py_nn_function_defs} + {NULL} +}; + +void initNNFunctions(PyObject* module) { +#if PY_MAJOR_VERSION == 2 + PyObject* nn = Py_InitModule("torch._C._nn", nn_functions); +#else + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._nn", + NULL, + -1, + nn_functions + }; + PyObject* nn = 
PyModule_Create(&def); +#endif + if (!nn) { + throw python_error(); + } + if (PyModule_AddObject(module, "_nn", nn) != 0) { + throw python_error(); + } +} + +}} // namespace torch::autograd diff --git a/tools/autograd/templates/python_nn_functions.h b/tools/autograd/templates/python_nn_functions.h new file mode 100644 index 0000000000..f84147e688 --- /dev/null +++ b/tools/autograd/templates/python_nn_functions.h @@ -0,0 +1,7 @@ +#include <Python.h> + +namespace torch { namespace autograd { + +void initNNFunctions(PyObject* module); + +}} // namespace torch::autograd diff --git a/tools/autograd/templates/python_nn_functions_dispatch.h b/tools/autograd/templates/python_nn_functions_dispatch.h new file mode 100644 index 0000000000..8b844265c9 --- /dev/null +++ b/tools/autograd/templates/python_nn_functions_dispatch.h @@ -0,0 +1,19 @@ +#pragma once + +// ${generated_comment} + +#include <ATen/ATen.h> +#include "torch/csrc/utils/auto_gil.h" +#include "torch/csrc/utils/auto_gpu.h" + +// Contains inline wrappers around ATen functions which release the GIL and +// switch to the correct CUDA device. + +namespace torch { namespace autograd { + +using namespace at; +using at::Generator; + +${py_nn_function_dispatch} + +}} // namespace torch::autograd diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 7ea3099b2e..1d027d60da 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -2,8 +2,9 @@ #include <Python.h> -#include "torch/csrc/autograd/python_variable.h" #include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" #include "torch/csrc/utils/python_arg_parser.h" #include "torch/csrc/utils/python_numbers.h" @@ -11,50 +12,10 @@ using at::Tensor; using at::Scalar; +using namespace torch::autograd::utils; namespace torch { namespace autograd { -namespace { - -inline PyObject* wrap(Tensor tensor) { - return THPVariable_Wrap(Variable(std::move(tensor))); -} - -inline PyObject* wrap(std::tuple<Tensor, Tensor> tensors) { - auto tuple = THPObjectPtr{PyTuple_New(2)}; - if (!tuple) return NULL; - PyTuple_SET_ITEM(tuple.get(), 0, wrap(std::move(std::get<0>(tensors)))); - PyTuple_SET_ITEM(tuple.get(), 1, wrap(std::move(std::get<1>(tensors)))); - return tuple.release(); -} - -inline PyObject* wrap(std::tuple<Tensor, Tensor, Tensor> tensors) { - auto tuple = THPObjectPtr{PyTuple_New(3)}; - if (!tuple) return NULL; - PyTuple_SET_ITEM(tuple.get(), 0, wrap(std::move(std::get<0>(tensors)))); - PyTuple_SET_ITEM(tuple.get(), 1, wrap(std::move(std::get<1>(tensors)))); - PyTuple_SET_ITEM(tuple.get(), 2, wrap(std::move(std::get<2>(tensors)))); - return tuple.release(); -} - -inline PyObject* wrap(bool value) { - if (value) { - Py_RETURN_TRUE; - } else { - Py_RETURN_FALSE; - } -} - -inline PyObject* wrap(int64_t value) { - return THPUtils_packInt64(value); -} - -inline PyObject* wrap(Scalar scalar) { - return wrap(scalar.toTensor()); -} - -} // anonymous namespace - ${py_methods} PyMethodDef variable_methods[] = { diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 4becd015bb..0efac9082b 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -11,6 +11,7 @@ #include <ATen/DLConvertor.h> #include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/autograd/generated/python_nn_functions.h" #include "torch/csrc/utils/python_strings.h" #include "torch/csrc/jit/python_tracer.h" 
#include "torch/csrc/jit/init.h" @@ -474,7 +475,8 @@ PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args) "don't know how to add docstring to type '%s'", Py_TYPE(obj)->tp_name); } - Py_RETURN_NONE; + Py_INCREF(obj); + return obj; } @@ -774,19 +776,11 @@ static std::vector<PyMethodDef> methods; PyMethodDef* THDPModule_methods(); #endif -#if PY_MAJOR_VERSION == 2 -PyMODINIT_FUNC init_C() -#else -PyMODINIT_FUNC PyInit__C() -#endif -{ +static PyObject* initModule() { + HANDLE_TH_ERRORS THInferNumThreads(); -#if PY_MAJOR_VERSION == 2 -#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error"); return;} -#else #define ASSERT_TRUE(cmd) if (!(cmd)) return NULL -#endif THPUtils_addPyMethodDefs(methods, TorchMethods); #ifdef WITH_CUDA @@ -820,6 +814,7 @@ PyMODINIT_FUNC PyInit__C() ASSERT_TRUE(THPEngine_initModule(module)); torch::autograd::initAutogradClosureBindings(module); torch::jit::initJITBindings(module); + torch::autograd::initNNFunctions(module); ASSERT_TRUE(THPDoubleStorage_init(module)); ASSERT_TRUE(THPFloatStorage_init(module)); ASSERT_TRUE(THPHalfStorage_init(module)); @@ -920,11 +915,22 @@ PyMODINIT_FUNC PyInit__C() ASSERT_TRUE(PyModule_AddObject(module, "default_generator", (PyObject*)THPDefaultGenerator) == 0); #ifdef WITH_NUMPY - import_array(); + if (_import_array() < 0) return NULL; #endif + return module; + END_HANDLE_TH_ERRORS +} + #if PY_MAJOR_VERSION == 2 +PyMODINIT_FUNC init_C() #else - return module; +PyMODINIT_FUNC PyInit__C() +#endif +{ +#if PY_MAJOR_VERSION == 2 + initModule(); +#else + return initModule(); #endif } diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 3b32763e41..ec2e937727 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -121,6 +121,12 @@ struct Function : std::enable_shared_from_this<Function> { return false; } + inline bool should_compute_output(std::initializer_list<int> idxs) const { + return std::any_of(idxs.begin(), idxs.end(), [this](int i) { + return should_compute_output(i); + }); + } + inline void set_flags(FunctionFlags&& flags) { is_executable = flags.is_executable; next_functions = std::move(flags.next_functions); diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp index 82ece0359a..a791f3ce90 100644 --- a/torch/csrc/autograd/functions/init.cpp +++ b/torch/csrc/autograd/functions/init.cpp @@ -312,7 +312,7 @@ bool THPAutograd_initFunctions(PyObject* _unused) static PyTypeObject AutogradClosureClass; addClass<AutogradClosure, NoCtor>(module, AutogradClosureClass, "AutogradClosure"); - initialize_autogenerated_functions(); + generated::initialize_autogenerated_functions(); THPObjectPtr parent(PyImport_ImportModule("torch._C")); if (!parent) return false; diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h new file mode 100644 index 0000000000..8933ba1635 --- /dev/null +++ b/torch/csrc/autograd/utils/wrap_outputs.h @@ -0,0 +1,53 @@ +#pragma once + +// Wrap tensor operation outputs as PyObject* + +#include <ATen/ATen.h> +#include <Python.h> +#include <tuple> + +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/variable.h" +#include "torch/csrc/utils/python_numbers.h" + +namespace torch { namespace autograd { namespace utils { + +inline PyObject* wrap(at::Tensor tensor) { + return THPVariable_Wrap(Variable(std::move(tensor))); +} + +inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor> tensors) { + auto r = 
THPObjectPtr{PyTuple_New(2)}; + if (!r) throw python_error(); + PyTuple_SET_ITEM(r.get(), 0, wrap(std::get<0>(tensors))); + PyTuple_SET_ITEM(r.get(), 1, wrap(std::get<1>(tensors))); + return r.release(); +} + +inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor> tensors) { + auto r = THPObjectPtr{PyTuple_New(3)}; + if (!r) throw python_error(); + PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors)))); + PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors)))); + PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors)))); + return r.release(); +} + +inline PyObject* wrap(bool value) { + if (value) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +inline PyObject* wrap(int64_t value) { + return THPUtils_packInt64(value); +} + +inline PyObject* wrap(at::Scalar scalar) { + return wrap(scalar.toTensor()); +} + + +}}} // namespace torch::autograd::utils diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 11ac56c62b..3af3627442 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -27,6 +27,7 @@ static std::unordered_map<std::string, ParameterType> type_map = { FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) : optional(false) , keyword_only(keyword_only) + , size(0) , default_scalar(0) { auto space = fmt.find(' '); @@ -35,6 +36,13 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) } auto type_str = fmt.substr(0, space); + auto bracket = type_str.find('['); + if (bracket != std::string::npos) { + auto size_str = type_str.substr(bracket + 1, type_str.length() - bracket - 2); + size = atoi(size_str.c_str()); + type_str = type_str.substr(0, bracket); + } + auto name_str = fmt.substr(space + 1); type_ = type_map[type_str]; @@ -55,12 +63,20 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) bool FunctionParameter::check(PyObject* obj) { switch (type_) { - case ParameterType::TENSOR: return THPVariable_Check(obj); + case ParameterType::TENSOR: { + return THPVariable_Check(obj) || (optional && obj == Py_None); + } case ParameterType::SCALAR: return THPUtils_checkDouble(obj); case ParameterType::INT64: return THPUtils_checkLong(obj); case ParameterType::DOUBLE: return THPUtils_checkDouble(obj); case ParameterType::TENSOR_LIST: return PyTuple_Check(obj) || PyList_Check(obj); - case ParameterType::INT_LIST: return PyTuple_Check(obj) || PyList_Check(obj); + case ParameterType::INT_LIST: { + if (PyTuple_Check(obj) || PyList_Check(obj)) { + return true; + } + // if a size is specified (e.g. 
IntList[2]) we also allow passing a single int + return size > 0 && THPUtils_checkLong(obj); + } case ParameterType::GENERATOR: return false; case ParameterType::BOOL: return PyBool_Check(obj); case ParameterType::STORAGE: return false; @@ -97,6 +113,10 @@ void FunctionParameter::set_default_str(const std::string& str) { default_double = atof(str.c_str()); } else if (type_ == ParameterType::SCALAR) { default_scalar = Scalar(atof(str.c_str())); + } else if (type_ == ParameterType::INT_LIST) { + if (str != "None") { + default_intlist.assign(size, std::stoi(str)); + } } } diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index ec5f915421..37dcc39d8e 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -108,11 +108,13 @@ struct FunctionParameter { ParameterType type_; bool optional; bool keyword_only; + int size; std::string name; // having this as a raw PyObject * will presumably leak it, but these are only held by static objects // anyway, and Py_Finalize can already be called when this is destructed. PyObject *python_name; at::Scalar default_scalar; + std::vector<int64_t> default_intlist; union { bool default_bool; int64_t default_int; @@ -121,7 +123,7 @@ struct FunctionParameter { }; inline at::Tensor PythonArgs::tensor(int i) { - if (!args[i]) return at::Tensor(); + if (!args[i] || args[i] == Py_None) return at::Tensor(); if (!THPVariable_Check(args[i])) { type_error("expected Variable as argument %d, but got %s", i, THPUtils_typename(args[i])); } @@ -154,10 +156,14 @@ inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) { } inline std::vector<int64_t> PythonArgs::intlist(int i) { - if (!args[i]) return std::vector<int64_t>(); + if (!args[i]) return signature.params[i].default_intlist; PyObject* arg = args[i]; + auto size = signature.params[i].size; + if (size > 0 && THPUtils_checkLong(arg)) { + return std::vector<int64_t>(size, THPUtils_unpackLong(arg)); + } auto tuple = PyTuple_Check(arg); - auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); + size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); std::vector<int64_t> res(size); for (int idx = 0; idx < size; idx++) { PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); @@ -188,6 +194,7 @@ inline bool PythonArgs::toBool(int i) { } inline at::Generator* PythonArgs::generator(int i) { + if (!args[i]) return nullptr; throw std::runtime_error("PythonArgs::generator not implemented"); } |
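The IntList[n] size annotation handled in python_arg_parser above has two effects: a scalar default in the signature is expanded into n copies (default_intlist), and a caller may pass a single integer that is broadcast to n entries, so a size-2 parameter given the value 3 behaves like (3, 3). Below is a pure-Python sketch of that check-and-expand behaviour, assuming the same semantics as FunctionParameter::check and PythonArgs::intlist; the helper name parse_intlist is mine.

    def parse_intlist(arg, size=0, default=None):
        # `arg is None` stands for "argument not supplied"; the C++ then falls
        # back to the parameter's default_intlist
        if arg is None:
            return list(default) if default is not None else []
        # a bare int is accepted only when a size was declared (e.g. IntList[2])
        if isinstance(arg, int):
            if size <= 0:
                raise TypeError('expected a sequence of ints')
            return [arg] * size
        return [int(v) for v in arg]

    print(parse_intlist(3, size=2))                     # [3, 3]
    print(parse_intlist((4, 1), size=2))                # [4, 1]
    print(parse_intlist(None, size=2, default=[0, 0]))  # [0, 0]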
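Stepping back to the code-generation changes in tools/autograd/gen_variable_type.py: group_declarations_by_signature keys each ATen declaration by 'base_name(simple argument types)', and load_deprecated_signatures has to rebuild exactly that key from a deprecated prototype plus its 'aten' call expression so the lookup in declarations_by_signature succeeds. The sketch below walks through that mapping in isolation; the deprecated entry and the simplified split_name_params helper are hypothetical stand-ins, not taken from deprecated.yaml or from this patch.

    def split_name_params(prototype):
        # simplified stand-in: "f(Tensor a, Scalar b)" -> ("f", ["Tensor a", "Scalar b"])
        name, rest = prototype.split('(', 1)
        rest = rest.rstrip(')')
        return name, rest.split(', ') if rest else []

    def get_signature(name, params, call_args):
        # parameter name -> type; call arguments that are not declared
        # parameters are treated as Scalar literals
        types = dict(param.split(' ')[::-1] for param in params)
        return '{}({})'.format(name, ', '.join(types.get(arg, 'Scalar') for arg in call_args))

    # illustrative deprecated entry (not from deprecated.yaml)
    deprecated = {
        'name': 'addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)',
        'aten': 'addmm(self, mat1, mat2, beta, 1)',
    }
    name, params = split_name_params(deprecated['name'])
    call_args = split_name_params(deprecated['aten'])[1]
    print(get_signature(name, params, call_args))
    # -> addmm(Tensor, Tensor, Tensor, Scalar, Scalar)

The rearranged call is the whole point of the deprecated path: the old Python signature keeps beta in front, while the signature key is computed from the argument order of the underlying ATen call.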
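The two norm_backward overloads added to Functions.cpp implement the usual p-norm gradient, d||x||_p / dx_i = x_i * |x_i|^(p-2) / ||x||_p^(p-1) (which reduces to x / ||x|| for p = 2), and return a subgradient of zero wherever the norm itself is zero. The NumPy check below is my own illustration of the dim-wise case, not part of the patch; it mirrors the divide-then-mask structure of the C++ and compares the result against central finite differences.

    import numpy as np

    def norm_backward(grad, x, p, dim, keepdim):
        # mirrors norm_backward(grad, self, p, dim, keepdim) from Functions.cpp
        if not keepdim and x.ndim > 1:
            grad = np.expand_dims(grad, dim)
        norm = np.sum(np.abs(x) ** p, axis=dim, keepdims=True) ** (1.0 / p)
        if p == 2.0:
            grad_input = x * (grad / norm)
        else:
            grad_input = x * np.abs(x) ** (p - 2) * (grad / norm ** (p - 1))
        # subgradient containing 0 where the norm is exactly 0
        return np.where(norm == 0, 0.0, grad_input)

    rng = np.random.RandomState(0)
    x = rng.randn(3, 4)
    p, dim = 3.0, 1
    upstream = rng.randn(3)

    analytic = norm_backward(upstream, x, p, dim, keepdim=False)

    def loss(t):
        return np.sum(upstream * np.sum(np.abs(t) ** p, axis=dim) ** (1.0 / p))

    eps, numeric = 1e-6, np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        numeric[idx] = (loss(xp) - loss(xm)) / (2 * eps)

    assert np.allclose(analytic, numeric, atol=1e-5)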
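potrf_backward above is the Cholesky reverse-mode rule from the Iain Murray note cited in the comment (arXiv:1602.07527): with A = L L^T and Lbar the incoming gradient, it forms P = phi(L^T Lbar), where phi keeps the lower triangle and halves the diagonal, and then applies two triangular solves. The NumPy translation below is only a readable paraphrase of that C++, with np.linalg.solve standing in for at::gesv; it is not part of the patch.

    import numpy as np

    def phi(A):
        B = np.tril(A)
        return B - 0.5 * np.diag(np.diag(B))

    def potrf_backward(grad, L, upper=False):
        # mirrors potrf_backward in Functions.cpp (lower-triangular convention)
        if upper:
            L, grad = L.T, grad.T
        Lbar = np.tril(grad)                 # only half of the output matrix is unique
        P = phi(L.T @ Lbar)
        S = np.linalg.solve(L.T, P + P.T)    # gesv(P + P.t(), L.t())
        S = np.linalg.solve(L.T, S.T)        # gesv(S.t(), L.t())
        return phi(S)

    M = np.random.randn(4, 4)
    A = M @ M.T + 4 * np.eye(4)              # positive definite input
    L = np.linalg.cholesky(A)
    print(potrf_backward(np.random.randn(4, 4), L).shape)   # (4, 4)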
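For reading glu_double_backward above: glu splits its input in half along dim into (a, b) and returns a * sigmoid(b), so the first-order gradients are gO * sigmoid(b) with respect to a and gO * a * sigmoid(b) * (1 - sigmoid(b)) with respect to b; second_order_sh in the C++ is just the derivative of sigmoid(b) * (1 - sigmoid(b)) taken once more with respect to b. The first-order sketch below is my own NumPy illustration; the actual glu_backward lives in the NN backend, not in this file.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def glu(x, dim=-1):
        a, b = np.split(x, 2, axis=dim)
        return a * sigmoid(b)

    def glu_backward(grad_output, x, dim=-1):
        a, b = np.split(x, 2, axis=dim)
        sb = sigmoid(b)
        grad_a = grad_output * sb                      # d(a * sigmoid(b)) / da
        grad_b = grad_output * a * sb * (1.0 - sb)     # d(a * sigmoid(b)) / db
        return np.concatenate([grad_a, grad_b], axis=dim)

    x = np.random.randn(4, 6)
    go = np.random.randn(4, 3)
    print(glu(x).shape, glu_backward(go, x).shape)     # (4, 3) (4, 6)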
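log_sigmoid_double_backward encodes the closed form d^2/dx^2 log(sigmoid(x)) = -sigmoid(x) * (1 - sigmoid(x)) = (sigmoid(x) - 1) * sigmoid(x), hence the grad * (z - 1) * z above. A quick numerical spot check, my own and not part of the patch:

    import numpy as np

    def logsigmoid(x):
        return -np.log1p(np.exp(-x))

    x, eps = 0.7, 1e-4
    numeric = (logsigmoid(x + eps) - 2 * logsigmoid(x) + logsigmoid(x - eps)) / eps ** 2
    z = 1.0 / (1.0 + np.exp(-x))
    print(numeric, (z - 1) * z)   # both approximately -0.22 for x = 0.7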
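smooth_l1_loss_double_backward relies on the element-wise second derivative of the smooth-L1 loss (0.5 * d^2 for |d| < 1, |d| - 0.5 otherwise, with d = input - target) being 1 inside the quadratic region and 0 outside it, scaled by 1/numel when size_average is set. A small NumPy check of that diagonal Hessian, assuming the standard smooth-L1 definition:

    import numpy as np

    def smooth_l1(x, t):
        d = np.abs(x - t)
        return np.where(d < 1, 0.5 * d ** 2, d - 0.5).sum()

    x, t = np.random.randn(5), np.random.randn(5)
    eps = 1e-4
    hess_diag = np.array([
        (smooth_l1(x + eps * e, t) - 2 * smooth_l1(x, t) + smooth_l1(x - eps * e, t)) / eps ** 2
        for e in np.eye(5)
    ])
    print(np.round(hess_diag, 3))                # entries close to 1.0 or 0.0
    print((np.abs(x - t) < 1).astype(float))     # matches, except right at the kink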