Skip to content
Snippets Groups Projects
Commit a1afc5e9 authored by Matthew David Pitkin's avatar Matthew David Pitkin Committed by Colm Talbot
Browse files

Resolve "Allow results to be returned as an ArviZ InferenceData object"

parent dded5f6c
Branches release/2.0.x
No related tags found
No related merge requests found
......@@ -21,8 +21,9 @@ from .utils import (
logger, infer_parameters_from_function,
check_directory_exists_and_if_not_mkdir,
latex_plot_format, safe_save_figure,
BilbyJsonEncoder, load_json,
move_old_file, get_version_information
)
from .utils import BilbyJsonEncoder, load_json, move_old_file
from .prior import Prior, PriorDict, DeltaFunction
......@@ -1237,7 +1238,7 @@ class Result(object):
return self._kde
def posterior_probability(self, sample):
""" Calculate the posterior probabily for a new sample
""" Calculate the posterior probability for a new sample
This queries a Kernel Density Estimate of the posterior to calculate
the posterior probability density for the new sample.
......@@ -1315,6 +1316,75 @@ class Result(object):
return weights
def to_arviz(self, prior=None):
""" Convert the Result object to an ArviZ InferenceData object.
Parameters
----------
prior: int
If a positive integer is given then that number of prior
samples will be drawn and stored in the ArviZ InferenceData
object.
Returns
-------
azdata: InferenceData
The ArviZ InferenceData object.
"""
try:
import arviz as az
except ImportError:
logger.debug(
"ArviZ is not installed, so cannot convert to InferenceData"
)
posdict = {}
for key in self.posterior:
posdict[key] = self.posterior[key].values
if "log_likelihood" in posdict:
loglikedict = {
"log_likelihood": posdict.pop("log_likelihood")
}
else:
if self.log_likelihood_evaluations is not None:
loglikedict = {
"log_likelihood": self.log_likelihood_evaluations
}
else:
loglikedict = None
priorsamples = None
if prior is not None:
if self.priors is None:
logger.warning(
"No priors are in the Result object, so prior samples "
"will not be included in the output."
)
else:
priorsamples = self.priors.sample(size=prior)
azdata = az.from_dict(
posterior=posdict,
log_likelihood=loglikedict,
prior=priorsamples,
)
# add attributes
version = {
"inference_library": "bilby: {}".format(self.sampler),
"inference_library_version": get_version_information()
}
azdata.posterior.attrs.update(version)
if "log_likelihood" in azdata._groups:
azdata.log_likelihood.attrs.update(version)
if "prior" in azdata._groups:
azdata.prior.attrs.update(version)
return azdata
class ResultList(list):
......
......@@ -505,6 +505,43 @@ class TestResult(unittest.TestCase):
)
)
def test_to_arviz(self):
with self.assertRaises(TypeError):
self.result.to_arviz(prior=dict())
Nprior = 100
log_likelihood = np.random.rand(len(self.result.posterior))
self.result.log_likelihood_evaluations = log_likelihood
az = self.result.to_arviz(prior=Nprior)
self.assertTrue("x" in az.posterior and "y" in az.posterior)
for var in ["x", "y"]:
self.assertTrue(np.array_equal(az.posterior[var].values.squeeze(),
self.result.posterior[var].values))
self.assertTrue(len(az.prior[var][0]) == Nprior)
self.assertTrue(np.array_equal(az.log_likelihood["log_likelihood"].values.squeeze(),
log_likelihood))
self.assertTrue(
az.posterior.attrs["inference_library"] == "bilby: {}".format(
self.result.sampler
)
)
self.assertTrue(
az.posterior.attrs["inference_library_version"]
== bilby.utils.get_version_information()
)
# add log likelihood to samples and extract from there
del az
self.result.posterior["log_likelihood"] = log_likelihood
az = self.result.to_arviz()
self.assertTrue(np.array_equal(az.log_likelihood["log_likelihood"].values.squeeze(),
log_likelihood))
class TestResultListError(unittest.TestCase):
def setUp(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment