diff options
author | Jeff Smith <jeffksmith@fb.com> | 2018-09-28 07:09:31 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-09-28 07:27:20 -0700 |
commit | d291cf7de6487ea351ba015555ebb2dd2c660370 (patch) | |
tree | 9e31451a48f55af765fad41d5b6ce7c38fcb2c9e /torch/distributions | |
parent | 04c0971679374148aa4105e2e998de6478eca1eb (diff) | |
download | pytorch-d291cf7de6487ea351ba015555ebb2dd2c660370.tar.gz pytorch-d291cf7de6487ea351ba015555ebb2dd2c660370.tar.bz2 pytorch-d291cf7de6487ea351ba015555ebb2dd2c660370.zip |
Ensuring positive definite matrix before constructing (#12102)
Summary:
Ensuring positive definite matrix in Multivariate Normal Distribution
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12102
Reviewed By: ezyang, Balandat
Differential Revision: D10052091
Pulled By: jeffreyksmithjr
fbshipit-source-id: 276cfc6995f6a217a5ad9eac299445ff1b67a65f
Diffstat (limited to 'torch/distributions')
-rw-r--r-- | torch/distributions/multivariate_normal.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 014a07e53c..345fe35cee 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -125,27 +125,29 @@ class MultivariateNormal(Distribution): if scale_tril.dim() < 2: raise ValueError("scale_tril matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self._unbroadcasted_scale_tril = scale_tril self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_) elif covariance_matrix is not None: if covariance_matrix.dim() < 2: raise ValueError("covariance_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self._unbroadcasted_scale_tril = _batch_potrf_lower(covariance_matrix) self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_) else: if precision_matrix.dim() < 2: raise ValueError("precision_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - covariance_matrix = _batch_inverse(precision_matrix) - self._unbroadcasted_scale_tril = _batch_potrf_lower(covariance_matrix) - self.covariance_matrix, self.precision_matrix, loc_ = torch.broadcast_tensors( - covariance_matrix, precision_matrix, loc_) + self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_) self.loc = loc_[..., 0] # drop rightmost dim batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args) + if scale_tril is not None: + self._unbroadcasted_scale_tril = scale_tril + else: + if precision_matrix is not None: + self.covariance_matrix = _batch_inverse(precision_matrix).expand_as(loc_) + self._unbroadcasted_scale_tril = _batch_potrf_lower(self.covariance_matrix) + def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(MultivariateNormal, _instance) batch_shape = torch.Size(batch_shape) |