Skip to content
Snippets Groups Projects

Resolve "Allow results to be returned as an ArviZ InferenceData object"

Merged Matthew David Pitkin requested to merge matthew-pitkin/bilby:arviz into master
All threads resolved!
Files
2
+ 72
2
@@ -21,8 +21,9 @@ from .utils import (
logger, infer_parameters_from_function,
check_directory_exists_and_if_not_mkdir,
latex_plot_format, safe_save_figure,
BilbyJsonEncoder, load_json,
move_old_file, get_version_information
)
from .utils import BilbyJsonEncoder, load_json, move_old_file
from .prior import Prior, PriorDict, DeltaFunction
@@ -1237,7 +1238,7 @@ class Result(object):
return self._kde
def posterior_probability(self, sample):
""" Calculate the posterior probabily for a new sample
""" Calculate the posterior probability for a new sample
This queries a Kernel Density Estimate of the posterior to calculate
the posterior probability density for the new sample.
@@ -1315,6 +1316,75 @@ class Result(object):
return weights
def to_arviz(self, prior=None):
""" Convert the Result object to an ArviZ InferenceData object.
Parameters
----------
prior: int
If a positive integer is given then that number of prior
samples will be drawn and stored in the ArviZ InferenceData
object.
Returns
-------
azdata: InferenceData
The ArviZ InferenceData object.
"""
try:
import arviz as az
except ImportError:
logger.debug(
"ArviZ is not installed, so cannot convert to InferenceData"
)
posdict = {}
for key in self.posterior:
posdict[key] = self.posterior[key].values
if "log_likelihood" in posdict:
loglikedict = {
"log_likelihood": posdict.pop("log_likelihood")
}
else:
if self.log_likelihood_evaluations is not None:
loglikedict = {
"log_likelihood": self.log_likelihood_evaluations
}
else:
loglikedict = None
priorsamples = None
if prior is not None:
if self.priors is None:
logger.warning(
"No priors are in the Result object, so prior samples "
"will not be included in the output."
)
else:
priorsamples = self.priors.sample(size=prior)
azdata = az.from_dict(
posterior=posdict,
log_likelihood=loglikedict,
prior=priorsamples,
)
# add attributes
version = {
"inference_library": "bilby: {}".format(self.sampler),
"inference_library_version": get_version_information()
}
azdata.posterior.attrs.update(version)
if "log_likelihood" in azdata._groups:
azdata.log_likelihood.attrs.update(version)
if "prior" in azdata._groups:
azdata.prior.attrs.update(version)
return azdata
class ResultList(list):
Loading