summaryrefslogtreecommitdiff
path: root/aten
diff options
context:
space:
mode:
authorRichard Zou <zou3519@users.noreply.github.com>2018-04-19 18:31:14 -0400
committerSoumith Chintala <soumith@gmail.com>2018-04-19 18:31:14 -0400
commitd1a992a85ebce32b3aaa8efc520c35cbb17d97b0 (patch)
tree4c4d904cac087500eb2817a2015831b88e3a34f4 /aten
parent264ffd143c3694df6f39a16ca0c5df3ba34a0d32 (diff)
downloadpytorch-d1a992a85ebce32b3aaa8efc520c35cbb17d97b0.tar.gz
pytorch-d1a992a85ebce32b3aaa8efc520c35cbb17d97b0.tar.bz2
pytorch-d1a992a85ebce32b3aaa8efc520c35cbb17d97b0.zip
Disallow chunks that are <= in torch.chunk (#6761)
Fixes #6759. Before, `tensor.chunk(0)` would cause a divide by 0. `tensor.chunk(-1)` would throw an error complaining that "split_size needs to be positive". This PR changes it so that the error message makes it clear that `chunks` has to be greater than 0.
Diffstat (limited to 'aten')
-rw-r--r--aten/src/ATen/native/TensorShape.cpp6
1 files changed, 5 insertions, 1 deletions
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 0fc1b2d704..5cda20979b 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -33,7 +33,11 @@ Tensor cat(TensorList tensors, int64_t dim) {
std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
if (self.dim() == 0) {
- throw std::runtime_error("chunk expects at least a 1-dimensional tensor");
+ AT_ERROR("chunk expects at least a 1-dimensional tensor");
+ }
+ if (chunks <= 0) {
+ AT_ERROR("chunk expects `chunks` to be greater than 0, got: %lld",
+ (long long)chunks);
}
int64_t split_size = (self.size(dim) + chunks - 1) / chunks;
// ensure this is dispatched through Tensor/Type, rather than the native function directly.