Commit 9cffdefd authored by Gregory Ashton's avatar Gregory Ashton

Resolve "`plot_corner` should have a straight forward way to not include truths in injection runs."

parent 8a1662db
......@@ -844,6 +844,12 @@ class Result(object):
python module, see https://corner.readthedocs.io for more
information.
Truth-lines can be passed in in several ways. Either as the values
of the parameters dict, or a list via the `truths` kwarg. If
injection_parameters where given to run_sampler, these will auto-
matically be added to the plot. This behaviour can be stopped by
adding truths=False.
Returns
-------
fig:
......@@ -879,7 +885,7 @@ class Result(object):
# Handle if truths was passed in
if 'truth' in kwargs:
kwargs['truths'] = kwargs.pop('truth')
if kwargs.get('truths'):
if "truths" in kwargs:
truths = kwargs.get('truths')
if isinstance(parameters, list) and isinstance(truths, list):
if len(parameters) != len(truths):
......@@ -887,6 +893,10 @@ class Result(object):
"Length of parameters and truths don't match")
elif isinstance(truths, dict) and parameters is None:
parameters = kwargs.pop('truths')
elif isinstance(truths, bool):
pass
elif truths is None:
kwargs["truths"] = False
else:
raise ValueError(
"Combination of parameters and truths not understood")
......@@ -895,7 +905,8 @@ class Result(object):
# but do not overwrite input parameters (or truths)
cond1 = getattr(self, 'injection_parameters', None) is not None
cond2 = parameters is None
if cond1 and cond2:
cond3 = bool(kwargs.get("truths", True))
if cond1 and cond2 and cond3:
parameters = {key: self.injection_parameters[key] for key in
self.search_parameter_keys}
......@@ -918,6 +929,10 @@ class Result(object):
# This prevents ValueErrors being raised for parameters with no range
kwargs['range'] = kwargs.get('range', [1] * len(plot_parameter_keys))
# Remove truths if it is a bool
if isinstance(kwargs.get('truths'), bool):
kwargs.pop('truths')
# Create the data array to plot and pass everything to corner
xs = self.posterior[plot_parameter_keys].values
fig = corner.corner(xs, **kwargs)
......
......@@ -347,6 +347,9 @@ class TestResult(unittest.TestCase):
self.result.plot_corner(parameters=dict(x=1, y=1))
self.result.plot_corner(truths=dict(x=1, y=1))
self.result.plot_corner(truth=dict(x=1, y=1))
self.result.plot_corner(truths=None)
self.result.plot_corner(truths=False)
self.result.plot_corner(truths=True)
with self.assertRaises(ValueError):
self.result.plot_corner(truths=dict(x=1, y=1),
parameters=dict(x=1, y=1))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment