-rw-r--r--  aten/src/ATen/native/BinaryOps.cpp | 48
1 file changed, 28 insertions(+), 20 deletions(-)
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 318d6e5348..bdd7299c16 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -15,9 +15,6 @@ DEFINE_DISPATCH(div_stub);
 
 Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
   if (other.is_sparse()) {
-    if (!result.defined()) {
-      result = at::empty({0}, self.options());
-    }
     if (self.is_sparse()) {
       at::_sparse_add_out(result, self, other, alpha);
     } else {
@@ -29,13 +26,18 @@ Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
   }
   auto iter = TensorIterator::binary_op(result, self, other);
   add_stub(iter->device_type(), *iter, alpha);
-  result = iter->output();
   return result;
 }
 
 Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) {
   Tensor result;
-  return native::add_out(result, self, other, alpha);
+  if (other.is_sparse()) {
+    result = at::empty({0}, self.options());
+    return native::add_out(result, self, other, alpha);
+  }
+  auto iter = TensorIterator::binary_op(result, self, other);
+  add_stub(iter->device_type(), *iter, alpha);
+  return iter->output();
 }
 
 Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
@@ -44,9 +46,6 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
 
 Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
   if (self.is_sparse()) {
-    if (!result.defined()) {
-      result = at::empty({0}, self.options());
-    }
     if (other.dim() != 0) {
       AT_ERROR("div(): sparse division only supports division by a scalar ",
         "(got shape ", other.sizes(), " for argument 'other')");
@@ -55,13 +54,18 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
   }
   auto iter = TensorIterator::binary_op(result, self, other);
   div_stub(iter->device_type(), *iter);
-  result = iter->output();
   return result;
 }
 
 Tensor div(const Tensor& self, const Tensor& other) {
   Tensor result;
-  return native::div_out(result, self, other);
+  if (self.is_sparse()) {
+    result = at::empty({0}, self.options());
+    return native::div_out(result, self, other);
+  }
+  auto iter = TensorIterator::binary_op(result, self, other);
+  div_stub(iter->device_type(), *iter);
+  return iter->output();
 }
 
 Tensor& div_(Tensor& self, const Tensor& other) {
@@ -70,20 +74,22 @@ Tensor& div_(Tensor& self, const Tensor& other) {
 
 Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
   if (self.is_sparse() || other.is_sparse()) {
-    if (!result.defined()) {
-      result = at::empty({0}, self.options());
-    }
     return at::_sparse_mul_out(result, self, other);
   }
   auto iter = TensorIterator::binary_op(result, self, other);
   mul_stub(iter->device_type(), *iter);
-  result = iter->output();
   return result;
 }
 
 Tensor mul(const Tensor& self, const Tensor& other) {
   Tensor result;
-  return native::mul_out(result, self, other);
+  if (self.is_sparse() || other.is_sparse()) {
+    result = at::empty({0}, self.options());
+    return native::mul_out(result, self, other);
+  }
+  auto iter = TensorIterator::binary_op(result, self, other);
+  mul_stub(iter->device_type(), *iter);
+  return iter->output();
 }
 
 Tensor& mul_(Tensor& self, const Tensor& other) {
@@ -92,9 +98,6 @@ Tensor& mul_(Tensor& self, const Tensor& other) {
 
 Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
   if (other.is_sparse()) {
-    if (!result.defined()) {
-      result = at::empty({0}, self.options());
-    }
     if (!self.sizes().equals(other.sizes())) {
       AT_ERROR("sizes do not match");
     }
@@ -109,13 +112,18 @@ Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
   }
   auto iter = TensorIterator::binary_op(result, self, other);
   sub_stub(iter->device_type(), *iter, alpha);
-  result = iter->output();
   return result;
 }
 
 Tensor sub(const Tensor& self, const Tensor& other, Scalar alpha) {
   Tensor result;
-  return native::sub_out(result, self, other, alpha);
+  if (other.is_sparse()) {
+    result = at::empty({0}, self.options());
+    return native::sub_out(result, self, other, alpha);
+  }
+  auto iter = TensorIterator::binary_op(result, self, other);
+  sub_stub(iter->device_type(), *iter, alpha);
+  return iter->output();
 }
 
 Tensor& sub_(Tensor& self, const Tensor& other, Scalar alpha) {
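
The sketch below is not part of the patch; it is a minimal usage example, assuming the ATen C++ API of this version, showing the two call paths the diff separates: the functional variants (add/div/mul/sub) now drive TensorIterator directly and return its output, while the out variants only allocate through the caller-provided result tensor.

// Minimal usage sketch (illustrative only, not from the patch); assumes ATen
// headers and library of this era are available.
#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::ones({2, 3});
  at::Tensor b = at::full({2, 3}, 2.0);

  // Functional variant: add() builds the TensorIterator itself and returns
  // iter->output() directly, instead of routing an undefined `result`
  // through native::add_out().
  at::Tensor c = at::add(a, b, /*alpha=*/1);

  // Out variant: the caller supplies `result`; add_out() no longer allocates
  // an empty tensor on the sparse path when `result` is undefined.
  at::Tensor out = at::empty({2, 3}, a.options());
  at::add_out(out, a, b, /*alpha=*/1);

  return 0;
}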