diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 85e8ac4ef75c8808fd20f1b175fd5214364deb1a..d25ca6487560bbb946b8930ba8c625e9d0f2d766 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -824,6 +824,8 @@ class ConditionalPriorDict(PriorDict): ======= list: List of floats containing the rescaled sample """ + from matplotlib.cbook import flatten + keys = list(keys) theta = list(theta) self._check_resolved() @@ -836,7 +838,7 @@ class ConditionalPriorDict(PriorDict): theta[index], **self.get_required_variables(key) ) self[key].least_recently_sampled = result[key] - return [result[key] for key in keys] + return list(flatten([result[key] for key in keys])) def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 742e7920b067877ce4d8aa3d2aa0f862db1dee4f..9dc35f40dd57b15a4b14c77f3395b8d1839ccac0 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -43,7 +43,7 @@ class BaseJointPriorDist(object): if isinstance(bounds, (list, tuple, np.ndarray)): if len(bound) != 2: raise ValueError( - "Bounds must contain an upper and " "lower value." + "Bounds must contain an upper and lower value." ) else: if bound[1] <= bound[0]: @@ -399,7 +399,7 @@ class MultivariateGaussianDist(BaseJointPriorDist): if len(np.shape(sigmas)) == 1: sigmas = [sigmas] elif len(np.shape(sigmas)) == 0: - raise ValueError("Must supply a list of standard " "deviations") + raise ValueError("Must supply a list of standard deviations") if covs is not None: if isinstance(covs, np.ndarray): covs = [covs] @@ -421,7 +421,7 @@ class MultivariateGaussianDist(BaseJointPriorDist): "List of correlation coefficients the wrong shape" ) elif not isinstance(corrcoefs, list): - raise TypeError("Must pass a list of correlation " "coefficients") + raise TypeError("Must pass a list of correlation coefficients") if weights is not None: if isinstance(weights, (int, float)): weights = [weights] @@ -489,7 +489,7 @@ class MultivariateGaussianDist(BaseJointPriorDist): if len(self.corrcoefs[-1].shape) != 2: raise ValueError( - "Correlation coefficient matrix must be a 2d " "array." + "Correlation coefficient matrix must be a 2d array." ) if ( @@ -497,16 +497,16 @@ class MultivariateGaussianDist(BaseJointPriorDist): or self.corrcoefs[-1].shape[0] != self.num_vars ): raise ValueError( - "Correlation coefficient matrix shape is " "inconsistent" + "Correlation coefficient matrix shape is inconsistent" ) # check matrix is symmetric if not np.allclose(self.corrcoefs[-1], self.corrcoefs[-1].T): - raise ValueError("Correlation coefficient matrix is not " "symmetric") + raise ValueError("Correlation coefficient matrix is not symmetric") # check diagonal is all ones if not np.all(np.diag(self.corrcoefs[-1]) == 1.0): - raise ValueError("Correlation coefficient matrix is not" "correct") + raise ValueError("Correlation coefficient matrix is not correct") try: self.sigmas.append(list(sigmas)) # standard deviations @@ -535,13 +535,13 @@ class MultivariateGaussianDist(BaseJointPriorDist): self.eigvectors.append(evecs) except Exception as e: raise RuntimeError( - "Problem getting eigenvalues and vectors: " "{}".format(e) + "Problem getting eigenvalues and vectors: {}".format(e) ) # check eigenvalues are positive if np.any(self.eigvalues[-1] <= 0.0): raise ValueError( - "Correlation coefficient matrix is not positive " "definite" + "Correlation coefficient matrix is not positive definite" ) self.sqeigvalues.append(np.sqrt(self.eigvalues[-1])) diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 3f7e8873169572ab0d6a096c6ddc8ad4f67078d2..4af73bdaa0c7f4f01f66b38089eb22e7fdbd73bb 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -320,6 +320,43 @@ class TestConditionalPriorDict(unittest.TestCase): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) self.assertListEqual(expected, res) + def test_rescale_with_joint_prior(self): + """ + Add a joint prior into the conditional prior dictionary and check that + the returned list is flat. + """ + + # set multivariate Gaussian distribution + names = ["mvgvar_0", "mvgvar_1"] + mu = [[0.79, -0.83]] + cov = [[[0.03, 0.], [0., 0.04]]] + mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) + + priordict = bilby.core.prior.ConditionalPriorDict( + dict( + var_3=self.prior_3, + var_2=self.prior_2, + var_0=self.prior_0, + var_1=self.prior_1, + mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"), + mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"), + ) + ) + + ref_variables = list(self.test_sample.values()) + [0.4, 0.1] + keys = list(self.test_sample.keys()) + names + res = priordict.rescale(keys=keys, theta=ref_variables) + + self.assertIsInstance(res, list) + self.assertEqual(np.shape(res), (6,)) + self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) + + # check conditional values are still as expected + expected = [self.test_sample["var_0"]] + for ii in range(1, 4): + expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) + self.assertListEqual(expected, res[0:4]) + def test_cdf(self): """ Test that the CDF method is the inverse of the rescale method.