diff options
author | Richard Zou <zou3519@users.noreply.github.com> | 2018-04-19 18:31:14 -0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2018-04-19 18:31:14 -0400 |
commit | d1a992a85ebce32b3aaa8efc520c35cbb17d97b0 (patch) | |
tree | 4c4d904cac087500eb2817a2015831b88e3a34f4 /aten | |
parent | 264ffd143c3694df6f39a16ca0c5df3ba34a0d32 (diff) | |
download | pytorch-d1a992a85ebce32b3aaa8efc520c35cbb17d97b0.tar.gz pytorch-d1a992a85ebce32b3aaa8efc520c35cbb17d97b0.tar.bz2 pytorch-d1a992a85ebce32b3aaa8efc520c35cbb17d97b0.zip |
Disallow chunks that are <= in torch.chunk (#6761)
Fixes #6759.
Before, `tensor.chunk(0)` would cause a divide by 0.
`tensor.chunk(-1)` would throw an error complaining that "split_size
needs to be positive".
This PR changes it so that the error message makes it clear that
`chunks` has to be greater than 0.
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/native/TensorShape.cpp | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 0fc1b2d704..5cda20979b 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -33,7 +33,11 @@ Tensor cat(TensorList tensors, int64_t dim) { std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) { if (self.dim() == 0) { - throw std::runtime_error("chunk expects at least a 1-dimensional tensor"); + AT_ERROR("chunk expects at least a 1-dimensional tensor"); + } + if (chunks <= 0) { + AT_ERROR("chunk expects `chunks` to be greater than 0, got: %lld", + (long long)chunks); } int64_t split_size = (self.size(dim) + chunks - 1) / chunks; // ensure this is dispatched through Tensor/Type, rather than the native function directly. |