diff --git a/bilby/core/result.py b/bilby/core/result.py index 8236f24b87cbc98728ad005f94f411f8e9814aa2..cefef401d1b7d2468abb1609dbec50e3b134b1e9 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -21,8 +21,9 @@ from .utils import ( logger, infer_parameters_from_function, check_directory_exists_and_if_not_mkdir, latex_plot_format, safe_save_figure, + BilbyJsonEncoder, load_json, + move_old_file, get_version_information ) -from .utils import BilbyJsonEncoder, load_json, move_old_file from .prior import Prior, PriorDict, DeltaFunction @@ -1237,7 +1238,7 @@ class Result(object): return self._kde def posterior_probability(self, sample): - """ Calculate the posterior probabily for a new sample + """ Calculate the posterior probability for a new sample This queries a Kernel Density Estimate of the posterior to calculate the posterior probability density for the new sample. @@ -1315,6 +1316,75 @@ class Result(object): return weights + def to_arviz(self, prior=None): + """ Convert the Result object to an ArviZ InferenceData object. + + Parameters + ---------- + prior: int + If a positive integer is given then that number of prior + samples will be drawn and stored in the ArviZ InferenceData + object. + + Returns + ------- + azdata: InferenceData + The ArviZ InferenceData object. + """ + + try: + import arviz as az + except ImportError: + logger.debug( + "ArviZ is not installed, so cannot convert to InferenceData" + ) + + posdict = {} + for key in self.posterior: + posdict[key] = self.posterior[key].values + + if "log_likelihood" in posdict: + loglikedict = { + "log_likelihood": posdict.pop("log_likelihood") + } + else: + if self.log_likelihood_evaluations is not None: + loglikedict = { + "log_likelihood": self.log_likelihood_evaluations + } + else: + loglikedict = None + + priorsamples = None + if prior is not None: + if self.priors is None: + logger.warning( + "No priors are in the Result object, so prior samples " + "will not be included in the output." + ) + else: + priorsamples = self.priors.sample(size=prior) + + azdata = az.from_dict( + posterior=posdict, + log_likelihood=loglikedict, + prior=priorsamples, + ) + + # add attributes + version = { + "inference_library": "bilby: {}".format(self.sampler), + "inference_library_version": get_version_information() + } + + azdata.posterior.attrs.update(version) + if "log_likelihood" in azdata._groups: + azdata.log_likelihood.attrs.update(version) + if "prior" in azdata._groups: + azdata.prior.attrs.update(version) + + return azdata + class ResultList(list): diff --git a/test/result_test.py b/test/result_test.py index 991456928c91324bd5c25d6ec4356a77c4c88a6b..b49f73435e86ada39fbeeee567d5c57e3561c069 100644 --- a/test/result_test.py +++ b/test/result_test.py @@ -505,6 +505,43 @@ class TestResult(unittest.TestCase): ) ) + def test_to_arviz(self): + with self.assertRaises(TypeError): + self.result.to_arviz(prior=dict()) + + Nprior = 100 + + log_likelihood = np.random.rand(len(self.result.posterior)) + self.result.log_likelihood_evaluations = log_likelihood + + az = self.result.to_arviz(prior=Nprior) + + self.assertTrue("x" in az.posterior and "y" in az.posterior) + for var in ["x", "y"]: + self.assertTrue(np.array_equal(az.posterior[var].values.squeeze(), + self.result.posterior[var].values)) + self.assertTrue(len(az.prior[var][0]) == Nprior) + + self.assertTrue(np.array_equal(az.log_likelihood["log_likelihood"].values.squeeze(), + log_likelihood)) + + self.assertTrue( + az.posterior.attrs["inference_library"] == "bilby: {}".format( + self.result.sampler + ) + ) + self.assertTrue( + az.posterior.attrs["inference_library_version"] + == bilby.utils.get_version_information() + ) + + # add log likelihood to samples and extract from there + del az + self.result.posterior["log_likelihood"] = log_likelihood + az = self.result.to_arviz() + self.assertTrue(np.array_equal(az.log_likelihood["log_likelihood"].values.squeeze(), + log_likelihood)) + class TestResultListError(unittest.TestCase): def setUp(self):