summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--aten/src/THCS/generic/THCSTensor.cpp11
-rw-r--r--aten/src/THS/generic/THSTensor.cpp49
-rw-r--r--test/test_sparse.py7
3 files changed, 45 insertions, 22 deletions
diff --git a/aten/src/THCS/generic/THCSTensor.cpp b/aten/src/THCS/generic/THCSTensor.cpp
index 208c71c8af..da97263f09 100644
--- a/aten/src/THCS/generic/THCSTensor.cpp
+++ b/aten/src/THCS/generic/THCSTensor.cpp
@@ -146,8 +146,15 @@ THCSTensor *THCSTensor_(newWithTensorAndSize)(THCState *state, THCIndexTensor *i
THCSTensor *self = (THCSTensor *)THAlloc(sizeof(THCSTensor));
THCSTensor_(rawInit)(state, self);
- nDimI = THCIndexTensor_(size)(state, indices, 0);
- nDimV = THCTensor_(nDimension)(state, values) - 1;
+ // TODO: we may need to special case when only one of these are empty.
+ if (THCudaLongTensor_nDimension(state, indices) == 0 && THCTensor_(nDimension)(state, values) == 0
+ && sizes != NULL) {
+ nDimI = 0;
+ nDimV = THLongStorage_size(sizes);
+ } else {
+ nDimI = THCIndexTensor_(size)(state, indices, 0);
+ nDimV = THCTensor_(nDimension)(state, values) - 1;
+ }
if (!sizes) {
// TODO Make it work with N-dimensional values
THArgCheck(nDimV > 0, 3, "size must be provided when nDimV > 0");
diff --git a/aten/src/THS/generic/THSTensor.cpp b/aten/src/THS/generic/THSTensor.cpp
index ed74eb0d8d..e3ee532467 100644
--- a/aten/src/THS/generic/THSTensor.cpp
+++ b/aten/src/THS/generic/THSTensor.cpp
@@ -146,8 +146,14 @@ THSTensor *THSTensor_(newWithTensorAndSize)(THLongTensor *indices, THTensor *val
THSTensor *self = (THSTensor *)THAlloc(sizeof(THSTensor));
THSTensor_(rawInit)(self);
- nDimI = THLongTensor_size(indices, 0);
- nDimV = THTensor_(nDimension)(values) - 1;
+ // TODO: we may need to special case when only one of these are empty.
+ if (THLongTensor_nDimension(indices) == 0 && THTensor_(nDimension)(values) == 0 && sizes != NULL) {
+ nDimI = 0;
+ nDimV = THLongStorage_size(sizes);
+ } else {
+ nDimI = THLongTensor_size(indices, 0);
+ nDimV = THTensor_(nDimension)(values) - 1;
+ }
if (!sizes) {
ignore = THLongTensor_new();
THLongTensor *computed_indices_sizes = THLongTensor_new();
@@ -169,27 +175,30 @@ THSTensor *THSTensor_(newWithTensorAndSize)(THLongTensor *indices, THTensor *val
THArgCheck(THLongStorage_size(sizes) == nDimI + nDimV, 2,
"number of dimensions must be nDimI + nDimV");
- THLongTensor *max_indices = THLongTensor_new();
- ignore = THLongTensor_new();
- THLongTensor_max(max_indices, ignore, indices, 1, 0);
- THLongTensor_free(ignore);
- for (int d = 0; d < nDimI; d++) {
- int64_t max_index_in_dim = THTensor_fastGet1d(max_indices, d);
- int64_t dim_size = sizes->data[d];
- THArgCheck(max_index_in_dim <= dim_size, 2,
- "sizes is inconsistent with indices: for dim %d, size is %lld but found index %lld",
- d, (long long)dim_size, (long long)max_index_in_dim);
- }
- for (int d = 0; d < nDimV; d++) {
- int64_t values_size = THTensor_(size)(values, d + 1);
- int64_t specified_size = sizes->data[nDimI + d];
- THArgCheck(values_size <= specified_size, 2,
- "values and sizes are inconsistent: sizes[%d] is %lld but values.size(%d) is %lld",
- d + nDimI, (long long)specified_size, d + 1, (long long)values_size);
+ // TODO: we may need to special case when only one of these are empty.
+ if (!(THLongTensor_nDimension(indices) == 0 && THTensor_(nDimension)(values) == 0 && sizes != NULL)) {
+ THLongTensor *max_indices = THLongTensor_new();
+ ignore = THLongTensor_new();
+ THLongTensor_max(max_indices, ignore, indices, 1, 0);
+ THLongTensor_free(ignore);
+ for (int d = 0; d < nDimI; d++) {
+ int64_t max_index_in_dim = THTensor_fastGet1d(max_indices, d);
+ int64_t dim_size = sizes->data[d];
+ THArgCheck(max_index_in_dim <= dim_size, 2,
+ "sizes is inconsistent with indices: for dim %d, size is %lld but found index %lld",
+ d, (long long)dim_size, (long long)max_index_in_dim);
+ }
+ for (int d = 0; d < nDimV; d++) {
+ int64_t values_size = THTensor_(size)(values, d + 1);
+ int64_t specified_size = sizes->data[nDimI + d];
+ THArgCheck(values_size <= specified_size, 2,
+ "values and sizes are inconsistent: sizes[%d] is %lld but values.size(%d) is %lld",
+ d + nDimI, (long long)specified_size, d + 1, (long long)values_size);
+ }
+ THLongTensor_free(max_indices);
}
THSTensor_(rawResize)(self, nDimI, nDimV, THLongStorage_data(sizes));
- THLongTensor_free(max_indices);
}
// NB: by default, we do NOT clone indices/values into the sparse tensor.
// Efficient API by default!
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 7279c8e5f4..01e006ae74 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -339,6 +339,13 @@ class TestSparse(TestCase):
y = x.clone()
self.assertTrue(y.is_coalesced())
+ @cuda_only
+ def test_cuda_empty(self):
+ from torch.autograd import Variable
+ x = Variable(torch.sparse.FloatTensor(2, 3, 4))
+ y = x.cuda(0)
+ x.cpu()
+
def test_transpose(self):
x = self._gen_sparse(4, 20, 5)[0]
y = self.safeToDense(x)