summaryrefslogtreecommitdiff
path: root/torch/distributions
diff options
context:
space:
mode:
authorJeff Smith <jeffksmith@fb.com>2018-09-28 07:09:31 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-09-28 07:27:20 -0700
commitd291cf7de6487ea351ba015555ebb2dd2c660370 (patch)
tree9e31451a48f55af765fad41d5b6ce7c38fcb2c9e /torch/distributions
parent04c0971679374148aa4105e2e998de6478eca1eb (diff)
downloadpytorch-d291cf7de6487ea351ba015555ebb2dd2c660370.tar.gz
pytorch-d291cf7de6487ea351ba015555ebb2dd2c660370.tar.bz2
pytorch-d291cf7de6487ea351ba015555ebb2dd2c660370.zip
Ensuring positive definite matrix before constructing (#12102)
Summary: Ensuring positive definite matrix in Multivariate Normal Distribution Pull Request resolved: https://github.com/pytorch/pytorch/pull/12102 Reviewed By: ezyang, Balandat Differential Revision: D10052091 Pulled By: jeffreyksmithjr fbshipit-source-id: 276cfc6995f6a217a5ad9eac299445ff1b67a65f
Diffstat (limited to 'torch/distributions')
-rw-r--r--torch/distributions/multivariate_normal.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py
index 014a07e53c..345fe35cee 100644
--- a/torch/distributions/multivariate_normal.py
+++ b/torch/distributions/multivariate_normal.py
@@ -125,27 +125,29 @@ class MultivariateNormal(Distribution):
if scale_tril.dim() < 2:
raise ValueError("scale_tril matrix must be at least two-dimensional, "
"with optional leading batch dimensions")
- self._unbroadcasted_scale_tril = scale_tril
self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_)
elif covariance_matrix is not None:
if covariance_matrix.dim() < 2:
raise ValueError("covariance_matrix must be at least two-dimensional, "
"with optional leading batch dimensions")
- self._unbroadcasted_scale_tril = _batch_potrf_lower(covariance_matrix)
self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_)
else:
if precision_matrix.dim() < 2:
raise ValueError("precision_matrix must be at least two-dimensional, "
"with optional leading batch dimensions")
- covariance_matrix = _batch_inverse(precision_matrix)
- self._unbroadcasted_scale_tril = _batch_potrf_lower(covariance_matrix)
- self.covariance_matrix, self.precision_matrix, loc_ = torch.broadcast_tensors(
- covariance_matrix, precision_matrix, loc_)
+ self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_)
self.loc = loc_[..., 0] # drop rightmost dim
batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
+ if scale_tril is not None:
+ self._unbroadcasted_scale_tril = scale_tril
+ else:
+ if precision_matrix is not None:
+ self.covariance_matrix = _batch_inverse(precision_matrix).expand_as(loc_)
+ self._unbroadcasted_scale_tril = _batch_potrf_lower(self.covariance_matrix)
+
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(MultivariateNormal, _instance)
batch_shape = torch.Size(batch_shape)