summaryrefslogtreecommitdiff
path: root/aten
diff options
context:
space:
mode:
Diffstat (limited to 'aten')
-rw-r--r--aten/src/ATen/native/TensorShape.cpp6
1 files changed, 5 insertions, 1 deletions
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 0fc1b2d704..5cda20979b 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -33,7 +33,11 @@ Tensor cat(TensorList tensors, int64_t dim) {
std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
if (self.dim() == 0) {
- throw std::runtime_error("chunk expects at least a 1-dimensional tensor");
+ AT_ERROR("chunk expects at least a 1-dimensional tensor");
+ }
+ if (chunks <= 0) {
+ AT_ERROR("chunk expects `chunks` to be greater than 0, got: %lld",
+ (long long)chunks);
}
int64_t split_size = (self.size(dim) + chunks - 1) / chunks;
// ensure this is dispatched through Tensor/Type, rather than the native function directly.