diff options
author | gchanan <gregchanan@gmail.com> | 2017-10-30 19:46:05 -0400 |
---|---|---|
committer | Sam Gross <colesbury@gmail.com> | 2017-10-30 19:46:05 -0400 |
commit | 3e6e81da460f6f90d65c0def58a3538cfe106cda (patch) | |
tree | e1ba02c1a25b3f4ecf1dc6875a2e203db801f7c8 /torch/autograd/variable.py | |
parent | 8cd0df020ca6a251ab566e1c5a84627cbc70d483 (diff) | |
download | pytorch-3e6e81da460f6f90d65c0def58a3538cfe106cda.tar.gz pytorch-3e6e81da460f6f90d65c0def58a3538cfe106cda.tar.bz2 pytorch-3e6e81da460f6f90d65c0def58a3538cfe106cda.zip |
Dispatch trivial variable operators to C++ aten functions. (#3372)
Implement __comparison_ops__ by calling the VariableBase methods.
Diffstat (limited to 'torch/autograd/variable.py')
-rw-r--r-- | torch/autograd/variable.py | 51 |
1 files changed, 8 insertions, 43 deletions
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py index 8b2655320c..fcf211816f 100644 --- a/torch/autograd/variable.py +++ b/torch/autograd/variable.py @@ -426,38 +426,18 @@ class Variable(_C._VariableBase): def bernoulli(self): return Bernoulli.apply(self) - __radd__ = __add__ = _C._VariableBase.add - - def __iadd__(self, other): - return self.add_(other) - - __sub__ = _C._VariableBase.sub - - def __isub__(self, other): - return self.sub_(other) - def __rsub__(self, other): return -self + other - __rmul__ = __mul__ = _C._VariableBase.mul - - def __imul__(self, other): - return self.mul_(other) - def __matmul__(self, other): if not isinstance(other, Variable): return NotImplemented return self.matmul(other) - __truediv__ = __div__ = _C._VariableBase.div - def __rdiv__(self, other): return self.reciprocal() * other __rtruediv__ = __rdiv__ - def __idiv__(self, other): - return self.div_(other) - __pow__ = _C._VariableBase.pow def __ipow__(self, other): @@ -466,8 +446,14 @@ class Variable(_C._VariableBase): def __rpow__(self, other): return PowConstant.apply(other, self) - def __neg__(self): - return Negate.apply(self) + __neg__ = _C._VariableBase.neg + + __eq__ = _C._VariableBase.eq + __ne__ = _C._VariableBase.ne + __lt__ = _C._VariableBase.lt + __le__ = _C._VariableBase.le + __gt__ = _C._VariableBase.gt + __ge__ = _C._VariableBase.ge def __len__(self): return len(self.data) @@ -481,27 +467,6 @@ class Variable(_C._VariableBase): # map will interleave them.) return iter(imap(lambda i: self[i], range(self.size(0)))) - def __mod__(self, other): - return self.remainder(other) - - def __eq__(self, other): - return self.eq(other) - - def __ne__(self, other): - return self.ne(other) - - def __lt__(self, other): - return self.lt(other) - - def __le__(self, other): - return self.le(other) - - def __gt__(self, other): - return self.gt(other) - - def __ge__(self, other): - return self.ge(other) - def __hash__(self): return id(self) |