Diffstat (limited to 'tools')
-rw-r--r--  tools/autograd/derivatives.yaml        | 2 +-
-rw-r--r--  tools/autograd/templates/Functions.cpp | 6 +++---
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 6d7a3524cb..e6086c3faf 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -451,7 +451,7 @@
 - name: log_normal_(Tensor self, double mean, double std, Generator generator)
   self: zeros_like(grad)

-- name: logsumexp(Tensor self, int64_t dim, bool keepdim)
+- name: logsumexp(Tensor self, IntList dim, bool keepdim)
   self: logsumexp_backward(grad, self, result, dim, keepdim)

 - name: lt_(Tensor self, Scalar other)
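For context: the schema change above widens dim from a single int64_t to an IntList, so logsumexp can reduce over several dimensions in one call. A minimal usage sketch through the ATen C++ API (the shapes, values, and the main wrapper are illustrative, not part of this patch):

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::randn({2, 3, 4});
  // Reduce over dims 0 and 2 at once; with keepdim=false the result shape is {3}.
  at::Tensor y = at::logsumexp(x, /*dim=*/{0, 2}, /*keepdim=*/false);
  return 0;
}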
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 4b1e1d8ba6..69e1edc558 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -429,10 +429,10 @@ Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) {
   return cumsum_backward(x.to(input_dtype), dim);
 }

-Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim, bool keepdim) {
+Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntList dim, bool keepdim) {
   if (!keepdim && self.dim() != 0) {
-    grad = grad.unsqueeze(dim);
-    result = result.unsqueeze(dim);
+    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+    result = unsqueeze_multiple(result, dim, self.sizes().size());
   }
   return grad * (self - result).exp();
 }
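The gradient formula itself is unchanged: since d/dx_i logsumexp(x) = exp(x_i - logsumexp(x)), the backward pass is grad * (self - result).exp(); the only difference is that grad and result must first be broadcast back to self's shape by re-inserting every reduced axis rather than just one. The definition of unsqueeze_multiple is not part of this hunk; a minimal sketch of what it plausibly does (a hypothetical stand-in, assuming dim entries are already wrapped to non-negative indices):

#include <ATen/ATen.h>
#include <vector>

// Hypothetical stand-in for the helper called above: re-insert a size-1
// axis at every reduced index so the result has n_dims dimensions again.
static at::Tensor unsqueeze_multiple(const at::Tensor& t, at::IntList dim, size_t n_dims) {
  std::vector<bool> to_unsqueeze(n_dims, false);
  for (int64_t d : dim) {
    to_unsqueeze[d] = true;  // assumes 0 <= d < n_dims
  }
  at::Tensor res = t;
  for (size_t i = 0; i < n_dims; ++i) {
    if (to_unsqueeze[i]) {
      res = res.unsqueeze(i);
    }
  }
  return res;
}

Unsqueezing in increasing index order keeps later indices valid, since each unsqueeze shifts the remaining axes right by one.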