diff options
Diffstat (limited to 'tools')
-rw-r--r-- | tools/autograd/derivatives.yaml | 2 | ||||
-rw-r--r-- | tools/autograd/templates/Functions.cpp | 6 |
2 files changed, 4 insertions, 4 deletions
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 6d7a3524cb..e6086c3faf 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -451,7 +451,7 @@ - name: log_normal_(Tensor self, double mean, double std, Generator generator) self: zeros_like(grad) -- name: logsumexp(Tensor self, int64_t dim, bool keepdim) +- name: logsumexp(Tensor self, IntList dim, bool keepdim) self: logsumexp_backward(grad, self, result, dim, keepdim) - name: lt_(Tensor self, Scalar other) diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 4b1e1d8ba6..69e1edc558 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -429,10 +429,10 @@ Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) { return cumsum_backward(x.to(input_dtype), dim); } -Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim, bool keepdim) { +Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntList dim, bool keepdim) { if (!keepdim && self.dim() != 0) { - grad = grad.unsqueeze(dim); - result = result.unsqueeze(dim); + grad = unsqueeze_multiple(grad, dim, self.sizes().size()); + result = unsqueeze_multiple(result, dim, self.sizes().size()); } return grad * (self - result).exp(); } |