author    Brennan Vincent <btv@fb.com>  2019-02-05 08:27:04 -0800
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-05 08:32:11 -0800
commit    1ce188c510a160cac37fb08eb45c8ebcd157075a (patch)
tree      12e376790ddb6dfda5506547df1aa01249fecee9 /tools
parent    4047c972669238ac62073a5afe47be14e1cb48be (diff)
logsumexp for multiple dimensions (#16475)
Summary:
Move `logsumexp` and `max_values` to `TensorIterator` and use it to make `logsumexp` work for multiple dimensions.
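For illustration, a minimal sketch of the user-facing change this enables (shapes chosen arbitrarily, not taken from the PR): `torch.logsumexp` can now reduce over a tuple of dimensions, matching the naive exp/sum/log composition.

```python
import torch

x = torch.randn(10, 100, 10)

# Reducing over several dimensions at once now works:
y = torch.logsumexp(x, dim=(0, 2))

# Numerically stable equivalent of log(sum(exp(x))) over those dims:
ref = x.exp().sum(dim=(0, 2)).log()
print(torch.allclose(y, ref, atol=1e-5))  # True, up to float32 rounding
```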
Timings on a tensor of shape `(10,1000000,10)`, for each combination of device (CPU, single-threaded CPU, GPU) and reduced dimension:
**before**
cpu:
208 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
279 ms ± 5.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
199 ms ± 2.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
single-threaded cpu:
1.11 s ± 33.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.25 s ± 25.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.11 s ± 6.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
gpu:
15.4 ms ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
132 ms ± 30.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.6 ms ± 19.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
**after**
cpu:
199 ms ± 8.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
307 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
207 ms ± 7.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
single-threaded cpu:
1.16 s ± 8.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.26 s ± 47.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.13 s ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
gpu:
15.4 ms ± 868 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
132 ms ± 27.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.6 ms ± 21.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
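The "mean ± std. dev. of 7 runs" format matches IPython's `%timeit`; a hypothetical reproduction of one row (the device/dim choices here are assumptions, not taken from the PR):

```python
import torch

x = torch.randn(10, 1000000, 10)

# torch.set_num_threads(1)  # for the single-threaded CPU rows
# x = x.cuda()              # for the GPU rows; CUDA timing also needs
#                           # torch.cuda.synchronize() to be meaningful

# In an IPython session:
%timeit torch.logsumexp(x, dim=1)
```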
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16475
Differential Revision: D13855746
Pulled By: umanwizard
fbshipit-source-id: aaacc0b967c3f89073487e1952ae6f76b7bd7ad3
Diffstat (limited to 'tools')
-rw-r--r-- tools/autograd/derivatives.yaml        | 2 +-
-rw-r--r-- tools/autograd/templates/Functions.cpp | 6 +++---
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 6d7a3524cb..e6086c3faf 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -451,7 +451,7 @@
 - name: log_normal_(Tensor self, double mean, double std, Generator generator)
   self: zeros_like(grad)

-- name: logsumexp(Tensor self, int64_t dim, bool keepdim)
+- name: logsumexp(Tensor self, IntList dim, bool keepdim)
   self: logsumexp_backward(grad, self, result, dim, keepdim)

 - name: lt_(Tensor self, Scalar other)
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 4b1e1d8ba6..69e1edc558 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -429,10 +429,10 @@ Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) {
   return cumsum_backward(x.to(input_dtype), dim);
 }

-Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim, bool keepdim) {
+Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntList dim, bool keepdim) {
   if (!keepdim && self.dim() != 0) {
-    grad = grad.unsqueeze(dim);
-    result = result.unsqueeze(dim);
+    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+    result = unsqueeze_multiple(result, dim, self.sizes().size());
   }
   return grad * (self - result).exp();
 }
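The math behind this backward is unchanged: since d/dx_i logsumexp(x) = exp(x_i - logsumexp(x)), the gradient is `grad * (self - result).exp()`; the only multi-dim complication is re-inserting the reduced size-1 dimensions before broadcasting, which with `keepdim=False` previously took a single `unsqueeze` and now needs one per reduced dim. A rough Python rendering of the new C++ body, for illustration only (the loop stands in for the C++ helper `unsqueeze_multiple`):

```python
import torch

def logsumexp_backward(grad, self, result, dims, keepdim):
    # With keepdim=False, grad and result have lost the reduced dims;
    # restore a size-1 dim at each reduced position (ascending order,
    # with negative dims normalized against self's rank) so that
    # broadcasting against self works.
    if not keepdim and self.dim() != 0:
        for d in sorted(d % self.dim() for d in dims):
            grad = grad.unsqueeze(d)
            result = result.unsqueeze(d)
    # d logsumexp(x) / dx = exp(x - logsumexp(x)), i.e. softmax(x)
    return grad * (self - result).exp()
```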