diff --git a/bilby/core/result.py b/bilby/core/result.py index 82eedcc80339d67c4d9d26b1147a75a12ba94ae2..8b4d97dda30734a85eef823b55fa0787f77cc9fb 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -844,6 +844,12 @@ class Result(object): python module, see https://corner.readthedocs.io for more information. + Truth-lines can be passed in in several ways. Either as the values + of the parameters dict, or a list via the `truths` kwarg. If + injection_parameters where given to run_sampler, these will auto- + matically be added to the plot. This behaviour can be stopped by + adding truths=False. + Returns ------- fig: @@ -879,7 +885,7 @@ class Result(object): # Handle if truths was passed in if 'truth' in kwargs: kwargs['truths'] = kwargs.pop('truth') - if kwargs.get('truths'): + if "truths" in kwargs: truths = kwargs.get('truths') if isinstance(parameters, list) and isinstance(truths, list): if len(parameters) != len(truths): @@ -887,6 +893,10 @@ class Result(object): "Length of parameters and truths don't match") elif isinstance(truths, dict) and parameters is None: parameters = kwargs.pop('truths') + elif isinstance(truths, bool): + pass + elif truths is None: + kwargs["truths"] = False else: raise ValueError( "Combination of parameters and truths not understood") @@ -895,7 +905,8 @@ class Result(object): # but do not overwrite input parameters (or truths) cond1 = getattr(self, 'injection_parameters', None) is not None cond2 = parameters is None - if cond1 and cond2: + cond3 = bool(kwargs.get("truths", True)) + if cond1 and cond2 and cond3: parameters = {key: self.injection_parameters[key] for key in self.search_parameter_keys} @@ -918,6 +929,10 @@ class Result(object): # This prevents ValueErrors being raised for parameters with no range kwargs['range'] = kwargs.get('range', [1] * len(plot_parameter_keys)) + # Remove truths if it is a bool + if isinstance(kwargs.get('truths'), bool): + kwargs.pop('truths') + # Create the data array to plot and pass everything to corner xs = self.posterior[plot_parameter_keys].values fig = corner.corner(xs, **kwargs) diff --git a/test/result_test.py b/test/result_test.py index 32a7e34ca4ea3f5ad234fbb28ca2bb9aa04e4c73..c49f2ad182a568eb2816cbb6b3f12982dfc29397 100644 --- a/test/result_test.py +++ b/test/result_test.py @@ -347,6 +347,9 @@ class TestResult(unittest.TestCase): self.result.plot_corner(parameters=dict(x=1, y=1)) self.result.plot_corner(truths=dict(x=1, y=1)) self.result.plot_corner(truth=dict(x=1, y=1)) + self.result.plot_corner(truths=None) + self.result.plot_corner(truths=False) + self.result.plot_corner(truths=True) with self.assertRaises(ValueError): self.result.plot_corner(truths=dict(x=1, y=1), parameters=dict(x=1, y=1))