Commit cf446ec8 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improve plot_corner

- Add ability to pass in parameters as either keys or latex labels
- Update truth to not overwrite the given kwarg, but preserve "plot
  injections by default" behaviour.
- Add function to convert a set of parameter_keys into the latex labels
  used by CC
parent d8b3eedd
...@@ -44,6 +44,19 @@ class Result(dict): ...@@ -44,6 +44,19 @@ class Result(dict):
"\n\n Saving the data has failed with the following message:\n {} \n\n" "\n\n Saving the data has failed with the following message:\n {} \n\n"
.format(e)) .format(e))
def get_latex_labels_from_parameter_keys(self, keys):
return_list = []
for k in keys:
if k in self.search_parameter_keys:
idx = self.search_parameter_keys.index(k)
elif k in self.parameter_labels:
raise ValueError('key {} not a parameter label or latex label'
return return_list
def plot_corner(self, save=True, **kwargs): def plot_corner(self, save=True, **kwargs):
""" Plot a corner-plot using chain-consumer """ Plot a corner-plot using chain-consumer
...@@ -61,12 +74,27 @@ class Result(dict): ...@@ -61,12 +74,27 @@ class Result(dict):
# Set some defaults (unless already set) # Set some defaults (unless already set)
kwargs['figsize'] = kwargs.get('figsize', 'GROW') kwargs['figsize'] = kwargs.get('figsize', 'GROW')
if save: if save:
kwargs['filename'] = '{}/{}_corner.png'.format(self.outdir, self.label) filename = '{}/{}_corner.png'.format(self.outdir, self.label)
kwargs['filename'] = kwargs.get('filename', filename)'Saving corner plot to {}'.format(kwargs['filename']))'Saving corner plot to {}'.format(kwargs['filename']))
if self.injection_parameters is not None: if self.injection_parameters is not None:
kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys] # If no truth argument given, set these to the injection params
injection_parameters = [self.injection_parameters[key]
for key in self.search_parameter_keys]
kwargs['truth'] = kwargs.get('truth', injection_parameters)
if type(kwargs.get('truth')) == dict:
old_keys = kwargs['truth'].keys()
new_keys = self.get_latex_labels_from_parameter_keys(old_keys)
for old, new in zip(old_keys, new_keys):
kwargs['truth'][new] = kwargs['truth'].pop(old)
if 'parameters' in kwargs:
kwargs['parameters'] = self.get_latex_labels_from_parameter_keys(
c = ChainConsumer() c = ChainConsumer()
c.add_chain(self.samples, parameters=self.parameter_labels) c.add_chain(self.samples, parameters=self.parameter_labels,
fig = c.plotter.plot(**kwargs) fig = c.plotter.plot(**kwargs)
return fig return fig
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment