diff options
author | Morgan Funtowicz <morgan.funtowicz@naverlabs.com> | 2019-02-28 13:27:27 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-28 13:36:17 -0800 |
commit | c596683309d9f50406d395f2b93b7c12b0e68fee (patch) | |
tree | 0e0c17469cff9ff5dc39ce4f3af9d925549d4d4e /torch | |
parent | 9cbd7a18f5af36699809fa19f98b2f2a24b6f6e1 (diff) | |
download | pytorch-c596683309d9f50406d395f2b93b7c12b0e68fee.tar.gz pytorch-c596683309d9f50406d395f2b93b7c12b0e68fee.tar.bz2 pytorch-c596683309d9f50406d395f2b93b7c12b0e68fee.zip |
Rely on numel() == 1 to check if distribution parameters are scalar. (#17503)
Summary:
As discussed here #16952, this PR aims at improving the __repr__ for distribution when the provided parameters are torch.Tensor with only one element.
Currently, __repr__() relies on dim() == 0 leading to the following behaviour :
```
>>> torch.distributions.Normal(torch.tensor([1.0]), torch.tensor([0.1]))
Normal(loc: torch.Size([1]), scale: torch.Size([1]))
```
With this PR, the output looks like the following:
```
>>> torch.distributions.Normal(torch.tensor([1.0]), torch.tensor([0.1]))
Normal(loc: 1.0, scale: 0.10000000149011612)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17503
Differential Revision: D14245439
Pulled By: soumith
fbshipit-source-id: a440998905fd60cf2ac9a94f75706021dd9ce5bf
Diffstat (limited to 'torch')
-rw-r--r-- | torch/distributions/distribution.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index 2c2733a42d..d1e3a39247 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -262,6 +262,6 @@ class Distribution(object): def __repr__(self): param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] - if self.__dict__[p].dim() == 0 + if self.__dict__[p].numel() == 1 else self.__dict__[p].size()) for p in param_names]) return self.__class__.__name__ + '(' + args_string + ')' |