path: root/torch
author    Morgan Funtowicz <morgan.funtowicz@naverlabs.com>    2019-02-28 13:27:27 -0800
committer    Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-02-28 13:36:17 -0800
commit    c596683309d9f50406d395f2b93b7c12b0e68fee (patch)
tree    0e0c17469cff9ff5dc39ce4f3af9d925549d4d4e /torch
parent    9cbd7a18f5af36699809fa19f98b2f2a24b6f6e1 (diff)
Rely on numel() == 1 to check if distribution parameters are scalar. (#17503)
Summary:
As discussed in #16952, this PR improves `__repr__` for distributions whose parameters are `torch.Tensor`s with only one element. Currently, `__repr__()` relies on `dim() == 0`, which leads to the following behaviour:

```
>>> torch.distributions.Normal(torch.tensor([1.0]), torch.tensor([0.1]))
Normal(loc: torch.Size([1]), scale: torch.Size([1]))
```

With this PR, the output looks like the following:

```
>>> torch.distributions.Normal(torch.tensor([1.0]), torch.tensor([0.1]))
Normal(loc: 1.0, scale: 0.10000000149011612)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17503

Differential Revision: D14245439

Pulled By: soumith

fbshipit-source-id: a440998905fd60cf2ac9a94f75706021dd9ce5bf
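For context, a minimal sketch (not part of the patch) of why `dim()` and `numel()` disagree on one-element tensors that carry an explicit dimension:

```python
import torch

# A one-element tensor with an explicit dimension: dim() == 1 but numel() == 1,
# so the old dim()-based check treated it as non-scalar and printed its size.
t = torch.tensor([1.0])
print(t.dim())    # 1
print(t.numel())  # 1

# A true zero-dimensional scalar: both checks agree.
s = torch.tensor(1.0)
print(s.dim())    # 0
print(s.numel())  # 1
```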
Diffstat (limited to 'torch')
-rw-r--r--    torch/distributions/distribution.py    2
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index 2c2733a42d..d1e3a39247 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -262,6 +262,6 @@ class Distribution(object):
     def __repr__(self):
         param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
         args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p]
-                                if self.__dict__[p].dim() == 0
+                                if self.__dict__[p].numel() == 1
                                 else self.__dict__[p].size()) for p in param_names])
         return self.__class__.__name__ + '(' + args_string + ')'
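Below is a minimal sketch (not from the patch; `_param_summary` is a hypothetical helper) of the selection logic this diff changes: a parameter with exactly one element is shown by value, anything larger by its size.

```python
import torch

def _param_summary(value):
    # Mirror of the patched condition: report the value for one-element
    # parameters, and the shape for everything else.
    return value if value.numel() == 1 else value.size()

print(_param_summary(torch.tensor([1.0])))       # one element  -> the tensor itself
print(_param_summary(torch.tensor([1.0, 2.0])))  # two elements -> torch.Size([2])
```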