diff options
author | Bram Wasti <bwasti@fb.com> | 2019-04-04 00:24:16 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-04 00:28:13 -0700 |
commit | 095f88e0934f5503cb3d01dbb7065d3c0be8e860 (patch) | |
tree | 45567056915dd467c8cc4d442d0f1ff2cb4b3c18 /aten | |
parent | e5e2110a8ead028c863a7f449273bf6ee90bc423 (diff) | |
download | pytorch-095f88e0934f5503cb3d01dbb7065d3c0be8e860.tar.gz pytorch-095f88e0934f5503cb3d01dbb7065d3c0be8e860.tar.bz2 pytorch-095f88e0934f5503cb3d01dbb7065d3c0be8e860.zip |
Fix to handle null strides in DLPack tensor (#18510)
Summary:
DLPack can have non-strided tensors, which is represented by a nullptr in the place of dl_tensor.strides.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18510
Differential Revision: D14647328
Pulled By: bwasti
fbshipit-source-id: 5364282810a5772cfc2319fc8133fe86fdd84dd1
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/DLConvertor.cpp | 7 | ||||
-rw-r--r-- | aten/src/ATen/test/dlconvertor_test.cpp | 12 |
2 files changed, 19 insertions, 0 deletions
diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index e2983ebdff..a40e872724 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -178,6 +178,13 @@ Tensor fromDLPack(const DLManagedTensor* src) { auto deleter = [src](void* self) { src->deleter(const_cast<DLManagedTensor*>(src)); }; + if (!src->dl_tensor.strides) { + return at::from_blob(src->dl_tensor.data, + IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), + deleter, + at::device(device_type).dtype(stype)); + } + return at::from_blob( src->dl_tensor.data, IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), diff --git a/aten/src/ATen/test/dlconvertor_test.cpp b/aten/src/ATen/test/dlconvertor_test.cpp index e2c7a62403..4e5eb93df4 100644 --- a/aten/src/ATen/test/dlconvertor_test.cpp +++ b/aten/src/ATen/test/dlconvertor_test.cpp @@ -18,3 +18,15 @@ TEST(TestDlconvertor, TestDlconvertor) { ASSERT_TRUE(a.equal(b)); } + +TEST(TestDlconvertor, TestDlconvertorNoStrides) { + manual_seed(123); + + Tensor a = rand({3, 4}); + DLManagedTensor* dlMTensor = toDLPack(a); + dlMTensor->dl_tensor.strides = nullptr; + + Tensor b = fromDLPack(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} |