Skip to content
Snippets Groups Projects
Commit b17a48d5 authored by Matthew Pitkin's avatar Matthew Pitkin Committed by Colm Talbot
Browse files

Flatten the list of rescaled values in a ConditionalPriorDict

parent 0b7ea1b7
No related branches found
No related tags found
1 merge request!1112Flatten the list of rescaled values in a ConditionalPriorDict
......@@ -824,6 +824,8 @@ class ConditionalPriorDict(PriorDict):
=======
list: List of floats containing the rescaled sample
"""
from matplotlib.cbook import flatten
keys = list(keys)
theta = list(theta)
self._check_resolved()
......@@ -836,7 +838,7 @@ class ConditionalPriorDict(PriorDict):
theta[index], **self.get_required_variables(key)
)
self[key].least_recently_sampled = result[key]
return [result[key] for key in keys]
return list(flatten([result[key] for key in keys]))
def _update_rescale_keys(self, keys):
if not keys == self._least_recently_rescaled_keys:
......
......@@ -43,7 +43,7 @@ class BaseJointPriorDist(object):
if isinstance(bounds, (list, tuple, np.ndarray)):
if len(bound) != 2:
raise ValueError(
"Bounds must contain an upper and " "lower value."
"Bounds must contain an upper and lower value."
)
else:
if bound[1] <= bound[0]:
......@@ -399,7 +399,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
if len(np.shape(sigmas)) == 1:
sigmas = [sigmas]
elif len(np.shape(sigmas)) == 0:
raise ValueError("Must supply a list of standard " "deviations")
raise ValueError("Must supply a list of standard deviations")
if covs is not None:
if isinstance(covs, np.ndarray):
covs = [covs]
......@@ -421,7 +421,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
"List of correlation coefficients the wrong shape"
)
elif not isinstance(corrcoefs, list):
raise TypeError("Must pass a list of correlation " "coefficients")
raise TypeError("Must pass a list of correlation coefficients")
if weights is not None:
if isinstance(weights, (int, float)):
weights = [weights]
......@@ -489,7 +489,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
if len(self.corrcoefs[-1].shape) != 2:
raise ValueError(
"Correlation coefficient matrix must be a 2d " "array."
"Correlation coefficient matrix must be a 2d array."
)
if (
......@@ -497,16 +497,16 @@ class MultivariateGaussianDist(BaseJointPriorDist):
or self.corrcoefs[-1].shape[0] != self.num_vars
):
raise ValueError(
"Correlation coefficient matrix shape is " "inconsistent"
"Correlation coefficient matrix shape is inconsistent"
)
# check matrix is symmetric
if not np.allclose(self.corrcoefs[-1], self.corrcoefs[-1].T):
raise ValueError("Correlation coefficient matrix is not " "symmetric")
raise ValueError("Correlation coefficient matrix is not symmetric")
# check diagonal is all ones
if not np.all(np.diag(self.corrcoefs[-1]) == 1.0):
raise ValueError("Correlation coefficient matrix is not" "correct")
raise ValueError("Correlation coefficient matrix is not correct")
try:
self.sigmas.append(list(sigmas)) # standard deviations
......@@ -535,13 +535,13 @@ class MultivariateGaussianDist(BaseJointPriorDist):
self.eigvectors.append(evecs)
except Exception as e:
raise RuntimeError(
"Problem getting eigenvalues and vectors: " "{}".format(e)
"Problem getting eigenvalues and vectors: {}".format(e)
)
# check eigenvalues are positive
if np.any(self.eigvalues[-1] <= 0.0):
raise ValueError(
"Correlation coefficient matrix is not positive " "definite"
"Correlation coefficient matrix is not positive definite"
)
self.sqeigvalues.append(np.sqrt(self.eigvalues[-1]))
......
......@@ -320,6 +320,43 @@ class TestConditionalPriorDict(unittest.TestCase):
expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
self.assertListEqual(expected, res)
def test_rescale_with_joint_prior(self):
"""
Add a joint prior into the conditional prior dictionary and check that
the returned list is flat.
"""
# set multivariate Gaussian distribution
names = ["mvgvar_0", "mvgvar_1"]
mu = [[0.79, -0.83]]
cov = [[[0.03, 0.], [0., 0.04]]]
mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov)
priordict = bilby.core.prior.ConditionalPriorDict(
dict(
var_3=self.prior_3,
var_2=self.prior_2,
var_0=self.prior_0,
var_1=self.prior_1,
mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"),
mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"),
)
)
ref_variables = list(self.test_sample.values()) + [0.4, 0.1]
keys = list(self.test_sample.keys()) + names
res = priordict.rescale(keys=keys, theta=ref_variables)
self.assertIsInstance(res, list)
self.assertEqual(np.shape(res), (6,))
self.assertListEqual([isinstance(r, float) for r in res], 6 * [True])
# check conditional values are still as expected
expected = [self.test_sample["var_0"]]
for ii in range(1, 4):
expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
self.assertListEqual(expected, res[0:4])
def test_cdf(self):
"""
Test that the CDF method is the inverse of the rescale method.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment