From d1a992a85ebce32b3aaa8efc520c35cbb17d97b0 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Thu, 19 Apr 2018 18:31:14 -0400 Subject: Disallow chunks that are <= in torch.chunk (#6761) Fixes #6759. Before, `tensor.chunk(0)` would cause a divide by 0. `tensor.chunk(-1)` would throw an error complaining that "split_size needs to be positive". This PR changes it so that the error message makes it clear that `chunks` has to be greater than 0. --- aten/src/ATen/native/TensorShape.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'aten') diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 0fc1b2d704..5cda20979b 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -33,7 +33,11 @@ Tensor cat(TensorList tensors, int64_t dim) { std::vector chunk(const Tensor& self, int64_t chunks, int64_t dim) { if (self.dim() == 0) { - throw std::runtime_error("chunk expects at least a 1-dimensional tensor"); + AT_ERROR("chunk expects at least a 1-dimensional tensor"); + } + if (chunks <= 0) { + AT_ERROR("chunk expects `chunks` to be greater than 0, got: %lld", + (long long)chunks); } int64_t split_size = (self.size(dim) + chunks - 1) / chunks; // ensure this is dispatched through Tensor/Type, rather than the native function directly. -- cgit v1.2.3