path: root/torch/autograd/variable.py
author     gchanan <gregchanan@gmail.com>    2018-01-12 14:26:38 -0500
committer  GitHub <noreply@github.com>       2018-01-12 14:26:38 -0500
commit     eb857ec36760eac9db02f9d0cd6426a1415f3718 (patch)
tree       4795d67ecbd97bbf0eb902dcc01efb4865533976 /torch/autograd/variable.py
parent     a14dd69be825158680c9b7ca213ac451ace9fdf6 (diff)
Introduce a (non-public) autograd scalar method and improve printing (#4586)
* Specialize Variable printing and always print the device for GPU tensors/Variables.
* Introduce a (non-public) _scalar_sum() method for autograd scalar testing.
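For context, a hedged sketch of the repr output this change produces (the variable names and sample sizes below are illustrative, not from the commit):

    import torch
    from torch.autograd import Variable

    v = Variable(torch.randn(2, 3))          # CPU Variable: footer has no device suffix
    print(v)                                 # footer: [torch.FloatTensor of size (2,3)]

    if torch.cuda.is_available():
        g = Variable(torch.randn(5).cuda())  # GPU Variable: device is now always printed
        print(g)                             # footer: [torch.cuda.FloatTensor of size (5,) (GPU 0)]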
Diffstat (limited to 'torch/autograd/variable.py')
-rw-r--r--  torch/autograd/variable.py | 20 +++++++++++++++++++-
1 file changed, 19 insertions(+), 1 deletion(-)
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index 4c95bc8760..ce11c38cc2 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -70,7 +70,25 @@ class Variable(_C._VariableBase):
         self.requires_grad, _, self._backward_hooks = state
 
     def __repr__(self):
-        return 'Variable containing:' + self.data.__repr__()
+        strt = 'Variable containing:' + torch._tensor_str._str(self.data, False)
+        # let's make our own Variable-specific footer
+        size_str = '(' + ','.join(str(size) for size in self.size()) + (',)' if len(self.size()) == 1 else ')')
+        device_str = '' if not self.is_cuda else \
+            ' (GPU {})'.format(self.get_device())
+        strt += '[{} of size {}{}]\n'.format(torch.typename(self.data),
+                                             size_str, device_str)
+
+        # All strings are unicode in Python 3, while we have to encode unicode
+        # strings in Python 2. If we can't, let Python decide the best
+        # characters to replace unicode characters with.
+        if sys.version_info > (3,):
+            return strt
+        else:
+            if hasattr(sys.stdout, 'encoding'):
+                return strt.encode(
+                    sys.stdout.encoding or 'UTF-8', 'replace')
+            else:
+                return strt.encode('UTF-8', 'replace')
 
     def backward(self, gradient=None, retain_graph=None, create_graph=False):
         """Computes the gradient of current variable w.r.t. graph leaves.