summaryrefslogtreecommitdiff
path: root/torch/autograd/variable.py
diff options
context:
space:
mode:
authorgchanan <gregchanan@gmail.com>2017-10-30 19:46:05 -0400
committerSam Gross <colesbury@gmail.com>2017-10-30 19:46:05 -0400
commit3e6e81da460f6f90d65c0def58a3538cfe106cda (patch)
treee1ba02c1a25b3f4ecf1dc6875a2e203db801f7c8 /torch/autograd/variable.py
parent8cd0df020ca6a251ab566e1c5a84627cbc70d483 (diff)
downloadpytorch-3e6e81da460f6f90d65c0def58a3538cfe106cda.tar.gz
pytorch-3e6e81da460f6f90d65c0def58a3538cfe106cda.tar.bz2
pytorch-3e6e81da460f6f90d65c0def58a3538cfe106cda.zip
Dispatch trivial variable operators to C++ aten functions. (#3372)
Implement __comparison_ops__ by calling the VariableBase methods.
Diffstat (limited to 'torch/autograd/variable.py')
-rw-r--r--torch/autograd/variable.py51
1 files changed, 8 insertions, 43 deletions
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index 8b2655320c..fcf211816f 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -426,38 +426,18 @@ class Variable(_C._VariableBase):
def bernoulli(self):
return Bernoulli.apply(self)
- __radd__ = __add__ = _C._VariableBase.add
-
- def __iadd__(self, other):
- return self.add_(other)
-
- __sub__ = _C._VariableBase.sub
-
- def __isub__(self, other):
- return self.sub_(other)
-
def __rsub__(self, other):
return -self + other
- __rmul__ = __mul__ = _C._VariableBase.mul
-
- def __imul__(self, other):
- return self.mul_(other)
-
def __matmul__(self, other):
if not isinstance(other, Variable):
return NotImplemented
return self.matmul(other)
- __truediv__ = __div__ = _C._VariableBase.div
-
def __rdiv__(self, other):
return self.reciprocal() * other
__rtruediv__ = __rdiv__
- def __idiv__(self, other):
- return self.div_(other)
-
__pow__ = _C._VariableBase.pow
def __ipow__(self, other):
@@ -466,8 +446,14 @@ class Variable(_C._VariableBase):
def __rpow__(self, other):
return PowConstant.apply(other, self)
- def __neg__(self):
- return Negate.apply(self)
+ __neg__ = _C._VariableBase.neg
+
+ __eq__ = _C._VariableBase.eq
+ __ne__ = _C._VariableBase.ne
+ __lt__ = _C._VariableBase.lt
+ __le__ = _C._VariableBase.le
+ __gt__ = _C._VariableBase.gt
+ __ge__ = _C._VariableBase.ge
def __len__(self):
return len(self.data)
@@ -481,27 +467,6 @@ class Variable(_C._VariableBase):
# map will interleave them.)
return iter(imap(lambda i: self[i], range(self.size(0))))
- def __mod__(self, other):
- return self.remainder(other)
-
- def __eq__(self, other):
- return self.eq(other)
-
- def __ne__(self, other):
- return self.ne(other)
-
- def __lt__(self, other):
- return self.lt(other)
-
- def __le__(self, other):
- return self.le(other)
-
- def __gt__(self, other):
- return self.gt(other)
-
- def __ge__(self, other):
- return self.ge(other)
-
def __hash__(self):
return id(self)