diff options
author | gchanan <gregchanan@gmail.com> | 2018-01-12 14:26:38 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-12 14:26:38 -0500 |
commit | eb857ec36760eac9db02f9d0cd6426a1415f3718 (patch) | |
tree | 4795d67ecbd97bbf0eb902dcc01efb4865533976 /torch/autograd/variable.py | |
parent | a14dd69be825158680c9b7ca213ac451ace9fdf6 (diff) | |
download | pytorch-eb857ec36760eac9db02f9d0cd6426a1415f3718.tar.gz pytorch-eb857ec36760eac9db02f9d0cd6426a1415f3718.tar.bz2 pytorch-eb857ec36760eac9db02f9d0cd6426a1415f3718.zip |
Introduce a (non-public) autograd scalar method and improve printing (#4586)
* Specialize Variable pinting and always print device for GPU tensors/Variables.
* Introduce a (non-public) _scalar_sum() method for autograd scalar testing.
Diffstat (limited to 'torch/autograd/variable.py')
-rw-r--r-- | torch/autograd/variable.py | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py index 4c95bc8760..ce11c38cc2 100644 --- a/torch/autograd/variable.py +++ b/torch/autograd/variable.py @@ -70,7 +70,25 @@ class Variable(_C._VariableBase): self.requires_grad, _, self._backward_hooks = state def __repr__(self): - return 'Variable containing:' + self.data.__repr__() + strt = 'Variable containing:' + torch._tensor_str._str(self.data, False) + # let's make our own Variable-specific footer + size_str = '(' + ','.join(str(size) for size in self.size()) + (',)' if len(self.size()) == 1 else ')') + device_str = '' if not self.is_cuda else \ + ' (GPU {})'.format(self.get_device()) + strt += '[{} of size {}{}]\n'.format(torch.typename(self.data), + size_str, device_str) + + # All strings are unicode in Python 3, while we have to encode unicode + # strings in Python2. If we can't, let python decide the best + # characters to replace unicode characters with. + if sys.version_info > (3,): + return strt + else: + if hasattr(sys.stdout, 'encoding'): + return strt.encode( + sys.stdout.encoding or 'UTF-8', 'replace') + else: + return strt.encode('UTF-8', 'replace') def backward(self, gradient=None, retain_graph=None, create_graph=False): """Computes the gradient of current variable w.r.t. graph leaves. |