| author | Tongzhou Wang <SsnL@users.noreply.github.com> | 2018-04-03 15:53:43 -0400 |
|---|---|---|
| committer | Soumith Chintala <soumith@gmail.com> | 2018-04-03 15:53:43 -0400 |
| commit | dfcd90783c37075631e873bed72013fe4917d66f | |
| tree | f2e5b125cf82e47b8e2a24ad55a2af763587661e /aten/src | |
| parent | 14bf37f22e6d3b627470dbdcf092abb2560c398e | |
fix sparse embedding backward when input contains only padding_idx (#6211)
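For context, a minimal Python reproduction of the case this commit fixes, written against the current `torch.nn` API (a sketch, not code from the commit; the sizes are illustrative). Per the commit title, sparse embedding backward previously failed when every input index equaled `padding_idx`:

```python
import torch
import torch.nn as nn

# Embedding with a sparse gradient and a designated padding index.
emb = nn.Embedding(num_embeddings=10, embedding_dim=3, padding_idx=0, sparse=True)

# Input consisting solely of padding_idx: no embedding row receives a
# real gradient, so the backward pass has nothing to scatter.
inp = torch.zeros(2, 4, dtype=torch.long)

emb(inp).sum().backward()   # the case this commit fixes
print(emb.weight.grad)      # an empty (nnz == 0) sparse gradient
```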
Diffstat (limited to 'aten/src')
| -rw-r--r-- | aten/src/ATen/native/Embedding.cpp | 12 |

1 file changed, 9 insertions, 3 deletions
```diff
diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp
index 7672257e6f..f5fadc0e2a 100644
--- a/aten/src/ATen/native/Embedding.cpp
+++ b/aten/src/ATen/native/Embedding.cpp
@@ -66,13 +66,19 @@ Tensor embedding_sparse_backward(
     grad = grad.index(c);
   }
 
-  int64_t num_features = grad.size(-1);
+  int64_t num_features = grad_.size(-1);
   auto weight_size = std::array<int64_t, 2>{{ num_weights, num_features }};
+  auto& dense_type = grad.type();
+  auto& sparse_type = dense_type.toBackend(grad.is_cuda() ? kSparseCUDA : kSparseCPU);
+
+  // check if all our grad come from padding_idx
+  if (grad.numel() == 0) {
+    return sparse_type.sparse_coo_tensor(indices_.type().tensor(),
+                                         dense_type.tensor(), weight_size);
+  }
 
   auto index = indices.view({1, -1});
   auto values = grad.contiguous().view({-1, num_features});
-
-  auto& sparse_type = grad.type().toBackend(grad.is_cuda() ? kSparseCUDA : kSparseCPU);
   return sparse_type.sparse_coo_tensor(index, values, weight_size);
 }
```
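The key change: `num_features` is now read from the unfiltered gradient `grad_`, because `grad` can become empty once the rows belonging to `padding_idx` are masked out by `grad.index(c)`, and an empty `grad` now takes an early return that builds a sparse gradient with zero non-zero entries. A sketch of that empty-gradient construction using today's Python API (the sizes are hypothetical; the commit derives them from `num_weights` and `grad_.size(-1)`):

```python
import torch

# Hypothetical sizes for illustration only.
num_weights, num_features = 10, 3

# Empty indices (1 sparse dim, 0 nnz) and empty values (0 x num_features),
# mirroring the early-return path when every input index was padding_idx.
indices = torch.empty(1, 0, dtype=torch.long)
values = torch.empty(0, num_features)
grad_weight = torch.sparse_coo_tensor(indices, values, (num_weights, num_features))
print(grad_weight)  # sparse COO tensor with nnz == 0
```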