-rw-r--r-- | aten/src/THCS/generic/THCSTensor.cpp | 11
-rw-r--r-- | aten/src/THS/generic/THSTensor.cpp   | 49
-rw-r--r-- | test/test_sparse.py                  |  7
3 files changed, 45 insertions, 22 deletions
diff --git a/aten/src/THCS/generic/THCSTensor.cpp b/aten/src/THCS/generic/THCSTensor.cpp
index 208c71c8af..da97263f09 100644
--- a/aten/src/THCS/generic/THCSTensor.cpp
+++ b/aten/src/THCS/generic/THCSTensor.cpp
@@ -146,8 +146,15 @@ THCSTensor *THCSTensor_(newWithTensorAndSize)(THCState *state, THCIndexTensor *i
 
   THCSTensor *self = (THCSTensor *)THAlloc(sizeof(THCSTensor));
   THCSTensor_(rawInit)(state, self);
-  nDimI = THCIndexTensor_(size)(state, indices, 0);
-  nDimV = THCTensor_(nDimension)(state, values) - 1;
+  // TODO: we may need to special case when only one of these are empty.
+  if (THCudaLongTensor_nDimension(state, indices) == 0 && THCTensor_(nDimension)(state, values) == 0
+      && sizes != NULL) {
+    nDimI = 0;
+    nDimV = THLongStorage_size(sizes);
+  } else {
+    nDimI = THCIndexTensor_(size)(state, indices, 0);
+    nDimV = THCTensor_(nDimension)(state, values) - 1;
+  }
   if (!sizes) {
     // TODO Make it work with N-dimensional values
     THArgCheck(nDimV > 0, 3, "size must be provided when nDimV > 0");
diff --git a/aten/src/THS/generic/THSTensor.cpp b/aten/src/THS/generic/THSTensor.cpp
index ed74eb0d8d..e3ee532467 100644
--- a/aten/src/THS/generic/THSTensor.cpp
+++ b/aten/src/THS/generic/THSTensor.cpp
@@ -146,8 +146,14 @@ THSTensor *THSTensor_(newWithTensorAndSize)(THLongTensor *indices, THTensor *val
 
   THSTensor *self = (THSTensor *)THAlloc(sizeof(THSTensor));
   THSTensor_(rawInit)(self);
-  nDimI = THLongTensor_size(indices, 0);
-  nDimV = THTensor_(nDimension)(values) - 1;
+  // TODO: we may need to special case when only one of these are empty.
+  if (THLongTensor_nDimension(indices) == 0 && THTensor_(nDimension)(values) == 0 && sizes != NULL) {
+    nDimI = 0;
+    nDimV = THLongStorage_size(sizes);
+  } else {
+    nDimI = THLongTensor_size(indices, 0);
+    nDimV = THTensor_(nDimension)(values) - 1;
+  }
   if (!sizes) {
     ignore = THLongTensor_new();
     THLongTensor *computed_indices_sizes = THLongTensor_new();
@@ -169,27 +175,30 @@ THSTensor *THSTensor_(newWithTensorAndSize)(THLongTensor *indices, THTensor *val
     THArgCheck(THLongStorage_size(sizes) == nDimI + nDimV, 2,
         "number of dimensions must be nDimI + nDimV");
 
-    THLongTensor *max_indices = THLongTensor_new();
-    ignore = THLongTensor_new();
-    THLongTensor_max(max_indices, ignore, indices, 1, 0);
-    THLongTensor_free(ignore);
-    for (int d = 0; d < nDimI; d++) {
-      int64_t max_index_in_dim = THTensor_fastGet1d(max_indices, d);
-      int64_t dim_size = sizes->data[d];
-      THArgCheck(max_index_in_dim <= dim_size, 2,
-          "sizes is inconsistent with indices: for dim %d, size is %lld but found index %lld",
-          d, (long long)dim_size, (long long)max_index_in_dim);
-    }
-    for (int d = 0; d < nDimV; d++) {
-      int64_t values_size = THTensor_(size)(values, d + 1);
-      int64_t specified_size = sizes->data[nDimI + d];
-      THArgCheck(values_size <= specified_size, 2,
-          "values and sizes are inconsistent: sizes[%d] is %lld but values.size(%d) is %lld",
-          d + nDimI, (long long)specified_size, d + 1, (long long)values_size);
+    // TODO: we may need to special case when only one of these are empty.
+    if (!(THLongTensor_nDimension(indices) == 0 && THTensor_(nDimension)(values) == 0 && sizes != NULL)) {
+      THLongTensor *max_indices = THLongTensor_new();
+      ignore = THLongTensor_new();
+      THLongTensor_max(max_indices, ignore, indices, 1, 0);
+      THLongTensor_free(ignore);
+      for (int d = 0; d < nDimI; d++) {
+        int64_t max_index_in_dim = THTensor_fastGet1d(max_indices, d);
+        int64_t dim_size = sizes->data[d];
+        THArgCheck(max_index_in_dim <= dim_size, 2,
+            "sizes is inconsistent with indices: for dim %d, size is %lld but found index %lld",
+            d, (long long)dim_size, (long long)max_index_in_dim);
+      }
+      for (int d = 0; d < nDimV; d++) {
+        int64_t values_size = THTensor_(size)(values, d + 1);
+        int64_t specified_size = sizes->data[nDimI + d];
+        THArgCheck(values_size <= specified_size, 2,
+            "values and sizes are inconsistent: sizes[%d] is %lld but values.size(%d) is %lld",
+            d + nDimI, (long long)specified_size, d + 1, (long long)values_size);
+      }
+      THLongTensor_free(max_indices);
     }
     THSTensor_(rawResize)(self, nDimI, nDimV, THLongStorage_data(sizes));
-    THLongTensor_free(max_indices);
   }
   // NB: by default, we do NOT clone indices/values into the sparse tensor.
   // Efficient API by default!
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 7279c8e5f4..01e006ae74 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -339,6 +339,13 @@ class TestSparse(TestCase):
         y = x.clone()
         self.assertTrue(y.is_coalesced())
 
+    @cuda_only
+    def test_cuda_empty(self):
+        from torch.autograd import Variable
+        x = Variable(torch.sparse.FloatTensor(2, 3, 4))
+        y = x.cuda(0)
+        x.cpu()
+
    def test_transpose(self):
        x = self._gen_sparse(4, 20, 5)[0]
        y = self.safeToDense(x)
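
For context, a minimal sketch of the scenario this change handles, mirroring the new test_cuda_empty test: a sparse tensor built from sizes alone carries empty (0-dimensional) indices and values, so device transfer reaches newWithTensorAndSize with only sizes set, the branch the patch special-cases (nDimI = 0, nDimV = number of sizes). The snippet is illustrative only, not part of the patch; it assumes an old-style build where torch.sparse.FloatTensor and torch.autograd.Variable are available, and it guards on CUDA availability where the test uses @cuda_only.

# Illustrative sketch, following the new test_cuda_empty test above:
# build an empty sparse tensor from sizes alone and move it to CUDA.
import torch
from torch.autograd import Variable

# Only sizes are given, so indices and values start out empty; this is the
# "both empty, sizes provided" case that newWithTensorAndSize now special-cases.
x = Variable(torch.sparse.FloatTensor(2, 3, 4))

if torch.cuda.is_available():
    y = x.cuda(0)  # CPU -> CUDA transfer of the empty sparse tensor
    x.cpu()        # the test also exercises .cpu() on the original tensor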