summaryrefslogtreecommitdiff
path: root/aten
diff options
context:
space:
mode:
authorBram Wasti <bwasti@fb.com>2019-04-04 00:24:16 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-04 00:28:13 -0700
commit095f88e0934f5503cb3d01dbb7065d3c0be8e860 (patch)
tree45567056915dd467c8cc4d442d0f1ff2cb4b3c18 /aten
parente5e2110a8ead028c863a7f449273bf6ee90bc423 (diff)
downloadpytorch-095f88e0934f5503cb3d01dbb7065d3c0be8e860.tar.gz
pytorch-095f88e0934f5503cb3d01dbb7065d3c0be8e860.tar.bz2
pytorch-095f88e0934f5503cb3d01dbb7065d3c0be8e860.zip
Fix to handle null strides in DLPack tensor (#18510)
Summary: DLPack can have non-strided tensors, which is represented by a nullptr in the place of dl_tensor.strides. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18510 Differential Revision: D14647328 Pulled By: bwasti fbshipit-source-id: 5364282810a5772cfc2319fc8133fe86fdd84dd1
Diffstat (limited to 'aten')
-rw-r--r--aten/src/ATen/DLConvertor.cpp7
-rw-r--r--aten/src/ATen/test/dlconvertor_test.cpp12
2 files changed, 19 insertions, 0 deletions
diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index e2983ebdff..a40e872724 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -178,6 +178,13 @@ Tensor fromDLPack(const DLManagedTensor* src) {
auto deleter = [src](void* self) {
src->deleter(const_cast<DLManagedTensor*>(src));
};
+ if (!src->dl_tensor.strides) {
+ return at::from_blob(src->dl_tensor.data,
+ IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
+ deleter,
+ at::device(device_type).dtype(stype));
+ }
+
return at::from_blob(
src->dl_tensor.data,
IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
diff --git a/aten/src/ATen/test/dlconvertor_test.cpp b/aten/src/ATen/test/dlconvertor_test.cpp
index e2c7a62403..4e5eb93df4 100644
--- a/aten/src/ATen/test/dlconvertor_test.cpp
+++ b/aten/src/ATen/test/dlconvertor_test.cpp
@@ -18,3 +18,15 @@ TEST(TestDlconvertor, TestDlconvertor) {
ASSERT_TRUE(a.equal(b));
}
+
+TEST(TestDlconvertor, TestDlconvertorNoStrides) {
+ manual_seed(123);
+
+ Tensor a = rand({3, 4});
+ DLManagedTensor* dlMTensor = toDLPack(a);
+ dlMTensor->dl_tensor.strides = nullptr;
+
+ Tensor b = fromDLPack(dlMTensor);
+
+ ASSERT_TRUE(a.equal(b));
+}