author     gchanan <gregchanan@gmail.com>   2018-01-29 15:09:59 -0500
committer  GitHub <noreply@github.com>      2018-01-29 15:09:59 -0500
commit     260a2461929df8d0d6172aa253fc5d9f9ae10d3b (patch)
tree       8c0aef5c40547912e07d0e9fc8fb130dded8b8de /tools
parent     e93ece90a5c9addf7e1b14e88223b35acd126b68 (diff)
Move repeat autograd to C++. (#4885)
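
The backward of repeat reduces the output gradient back onto the input shape: first sum away any leading dimensions that repeat() prepended, then for each original dimension split the gradient into its repeated chunks and add them together. Below is a minimal standalone sketch of that reduction, mirroring the repeat_backward helper this patch adds to Functions.cpp; the shapes, repeat counts, and the main() wrapper are illustrative assumptions and are not part of the patch.

// Illustrative sketch only -- assumes an ATen build; the values below are made up.
#include <ATen/ATen.h>
#include <iostream>
#include <vector>

int main() {
  // A (2, 3) input repeated with repeats = {2, 1, 2} gives a (2, 2, 6) output.
  at::Tensor self = at::ones({2, 3});
  at::Tensor grad = at::ones({2, 2, 6});          // incoming gradient w.r.t. the output
  std::vector<int64_t> repeats = {2, 1, 2};

  // repeat() can prepend dimensions; sum those away first.
  int64_t num_unsqueezed = grad.dim() - self.dim();
  for (int64_t i = 0; i < num_unsqueezed; ++i) {
    grad = grad.sum(0, false);
  }
  // For each remaining dimension, split the gradient into `repeat` chunks and
  // sum them, collapsing the repeated copies back onto the input shape.
  for (size_t j = num_unsqueezed; j < repeats.size(); ++j) {
    int64_t repeat = repeats[j];
    if (repeat == 1) {
      continue;
    }
    int64_t dim = j - num_unsqueezed;
    auto chunks = grad.chunk(repeat, dim);
    at::Tensor acc = chunks[0];
    for (size_t k = 1; k < chunks.size(); ++k) {
      acc = acc + chunks[k];
    }
    grad = acc;
  }
  // Every element is 4: each input element appears in 2 * 1 * 2 = 4 output copies.
  std::cout << grad << std::endl;
}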
Diffstat (limited to 'tools')
-rw-r--r--   tools/autograd/derivatives.yaml          |  3
-rw-r--r--   tools/autograd/templates/Functions.cpp   | 27
2 files changed, 30 insertions(+), 0 deletions(-)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index b6f98ba40f..ae0c55ec96 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -515,6 +515,9 @@
 - name: renorm(Tensor self, Scalar p, int64_t dim, Scalar maxnorm)
   self: renorm_backward(grad, self, p, dim, maxnorm)
 
+- name: repeat(Tensor self, IntList repeats)
+  self: repeat_backward(grad, self.dim(), repeats)
+
 - name: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale)
   input: RoiPooling2d_backward(input, rois, pooledHeight, pooledWidth, spatialScale, grad, result1)
 
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index e0e806e067..b76f1f695d 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -429,6 +429,33 @@ Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64
   return at::where(mask, grad, grad_norm);
 }
 
+Tensor sum_tensorlist(TensorList tl) {
+  if (tl.size() == 0) {
+    throw std::runtime_error("Can't sum tensorlist of size 0");
+  }
+  Tensor sum = tl[0];
+  for(size_t i = 1; i < tl.size(); ++i) {
+    sum = sum + tl[i];
+  }
+  return sum;
+}
+
+Tensor repeat_backward(Tensor grad, int64_t input_dims, IntList repeats) {
+  int64_t num_unsqueezed = grad.dim() - input_dims;
+  for (int64_t i = 0; i < num_unsqueezed; ++i) {
+    grad = grad.sum(0, false);
+  }
+  for (size_t j = num_unsqueezed; j < repeats.size(); ++j) {
+    int64_t repeat = repeats[j];
+    if (repeat == 1) {
+      continue;
+    }
+    int64_t dim = j - num_unsqueezed;
+    grad = sum_tensorlist(grad.chunk(repeat, dim));
+  }
+  return grad;
+}
+
 Tensor select_backward_scalar(Tensor grad, const Tensor & input, const Tensor & value) {
   auto grad_input = zeros_like(input);
 #ifdef WITH_SCALARS