diff options
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/native/TensorShape.cpp | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 0fc1b2d704..5cda20979b 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -33,7 +33,11 @@ Tensor cat(TensorList tensors, int64_t dim) { std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) { if (self.dim() == 0) { - throw std::runtime_error("chunk expects at least a 1-dimensional tensor"); + AT_ERROR("chunk expects at least a 1-dimensional tensor"); + } + if (chunks <= 0) { + AT_ERROR("chunk expects `chunks` to be greater than 0, got: %lld", + (long long)chunks); } int64_t split_size = (self.size(dim) + chunks - 1) / chunks; // ensure this is dispatched through Tensor/Type, rather than the native function directly. |