author    Tongzhou Wang <SsnL@users.noreply.github.com>  2018-04-03 15:53:43 -0400
committer Soumith Chintala <soumith@gmail.com>  2018-04-03 15:53:43 -0400
commit    dfcd90783c37075631e873bed72013fe4917d66f (patch)
tree      f2e5b125cf82e47b8e2a24ad55a2af763587661e /aten/src
parent    14bf37f22e6d3b627470dbdcf092abb2560c398e (diff)
fix sparse embedding backward when input contains only padding_idx (#6211)
Diffstat (limited to 'aten/src')
-rw-r--r--  aten/src/ATen/native/Embedding.cpp  12
1 file changed, 9 insertions(+), 3 deletions(-)
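
For context, a minimal Python-level repro sketch of the failure mode this commit fixes (the sizes and variable names below are illustrative, not taken from the commit): when every input index equals padding_idx, the padding mask removes every row of the gradient, and the sparse backward must still produce a gradient shaped like the weight matrix.

    import torch
    import torch.nn as nn

    # All input indices equal padding_idx, so the padding mask empties
    # the gradient before the sparse gradient tensor is built.
    emb = nn.Embedding(10, 3, padding_idx=0, sparse=True)
    inp = torch.zeros(4, dtype=torch.long)   # every entry is padding_idx
    emb(inp).sum().backward()
    print(emb.weight.grad)                   # empty sparse gradient after the fix
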
diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp
index 7672257e6f..f5fadc0e2a 100644
--- a/aten/src/ATen/native/Embedding.cpp
+++ b/aten/src/ATen/native/Embedding.cpp
@@ -66,13 +66,19 @@ Tensor embedding_sparse_backward(
     grad = grad.index(c);
   }
 
-  int64_t num_features = grad.size(-1);
+  int64_t num_features = grad_.size(-1);
   auto weight_size = std::array<int64_t, 2>{{ num_weights, num_features }};
+  auto& dense_type = grad.type();
+  auto& sparse_type = dense_type.toBackend(grad.is_cuda() ? kSparseCUDA : kSparseCPU);
+
+  // check if all our grad come from padding_idx
+  if (grad.numel() == 0) {
+    return sparse_type.sparse_coo_tensor(indices_.type().tensor(),
+                                         dense_type.tensor(), weight_size);
+  }
 
   auto index = indices.view({1, -1});
   auto values = grad.contiguous().view({-1, num_features});
-
-  auto& sparse_type = grad.type().toBackend(grad.is_cuda() ? kSparseCUDA : kSparseCPU);
   return sparse_type.sparse_coo_tensor(index, values, weight_size);
 }
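
In effect, the early return above builds an empty sparse COO tensor with the full weight shape instead of reshaping an empty gradient (num_features is also now read from the original grad_, since grad may have been emptied by the padding mask). A rough Python analogue of that return value, written against the present-day torch.sparse_coo_tensor API rather than the ATen calls above (sizes are illustrative):

    import torch

    num_weights, num_features = 10, 3
    # Zero non-zero entries: indices shaped (1, 0), values shaped (0, num_features).
    empty_grad = torch.sparse_coo_tensor(
        torch.empty(1, 0, dtype=torch.long),
        torch.empty(0, num_features),
        (num_weights, num_features),
    )
    print(empty_grad)   # sparse tensor of size (10, 3) with nnz == 0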