summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--aten/src/ATen/native/BinaryOps.cpp48
1 files changed, 28 insertions, 20 deletions
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 318d6e5348..bdd7299c16 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -15,9 +15,6 @@ DEFINE_DISPATCH(div_stub);
Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
if (other.is_sparse()) {
- if (!result.defined()) {
- result = at::empty({0}, self.options());
- }
if (self.is_sparse()) {
at::_sparse_add_out(result, self, other, alpha);
} else {
@@ -29,13 +26,18 @@ Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
}
auto iter = TensorIterator::binary_op(result, self, other);
add_stub(iter->device_type(), *iter, alpha);
- result = iter->output();
return result;
}
Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) {
Tensor result;
- return native::add_out(result, self, other, alpha);
+ if (other.is_sparse()) {
+ result = at::empty({0}, self.options());
+ return native::add_out(result, self, other, alpha);
+ }
+ auto iter = TensorIterator::binary_op(result, self, other);
+ add_stub(iter->device_type(), *iter, alpha);
+ return iter->output();
}
Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
@@ -44,9 +46,6 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
if (self.is_sparse()) {
- if (!result.defined()) {
- result = at::empty({0}, self.options());
- }
if (other.dim() != 0) {
AT_ERROR("div(): sparse division only supports division by a scalar ",
"(got shape ", other.sizes(), " for argument 'other')");
@@ -55,13 +54,18 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
}
auto iter = TensorIterator::binary_op(result, self, other);
div_stub(iter->device_type(), *iter);
- result = iter->output();
return result;
}
Tensor div(const Tensor& self, const Tensor& other) {
Tensor result;
- return native::div_out(result, self, other);
+ if (self.is_sparse()) {
+ result = at::empty({0}, self.options());
+ return native::div_out(result, self, other);
+ }
+ auto iter = TensorIterator::binary_op(result, self, other);
+ div_stub(iter->device_type(), *iter);
+ return iter->output();
}
Tensor& div_(Tensor& self, const Tensor& other) {
@@ -70,20 +74,22 @@ Tensor& div_(Tensor& self, const Tensor& other) {
Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
if (self.is_sparse() || other.is_sparse()) {
- if (!result.defined()) {
- result = at::empty({0}, self.options());
- }
return at::_sparse_mul_out(result, self, other);
}
auto iter = TensorIterator::binary_op(result, self, other);
mul_stub(iter->device_type(), *iter);
- result = iter->output();
return result;
}
Tensor mul(const Tensor& self, const Tensor& other) {
Tensor result;
- return native::mul_out(result, self, other);
+ if (self.is_sparse() || other.is_sparse()) {
+ result = at::empty({0}, self.options());
+ return native::mul_out(result, self, other);
+ }
+ auto iter = TensorIterator::binary_op(result, self, other);
+ mul_stub(iter->device_type(), *iter);
+ return iter->output();
}
Tensor& mul_(Tensor& self, const Tensor& other) {
@@ -92,9 +98,6 @@ Tensor& mul_(Tensor& self, const Tensor& other) {
Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
if (other.is_sparse()) {
- if (!result.defined()) {
- result = at::empty({0}, self.options());
- }
if (!self.sizes().equals(other.sizes())) {
AT_ERROR("sizes do not match");
}
@@ -109,13 +112,18 @@ Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
}
auto iter = TensorIterator::binary_op(result, self, other);
sub_stub(iter->device_type(), *iter, alpha);
- result = iter->output();
return result;
}
Tensor sub(const Tensor& self, const Tensor& other, Scalar alpha) {
Tensor result;
- return native::sub_out(result, self, other, alpha);
+ if (other.is_sparse()) {
+ result = at::empty({0}, self.options());
+ return native::sub_out(result, self, other, alpha);
+ }
+ auto iter = TensorIterator::binary_op(result, self, other);
+ sub_stub(iter->device_type(), *iter, alpha);
+ return iter->output();
}
Tensor& sub_(Tensor& self, const Tensor& other, Scalar alpha) {