Skip to content
Snippets Groups Projects
Commit cf446ec8 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improve plot_corner

- Add ability to pass in parameters as either keys or latex labels
- Update truth to not overwrite the given kwarg, but preserve "plot
  injections by default" behaviour.
- Add function to convert a set of parameter_keys into the latex labels
  used by CC
parent d8b3eedd
No related branches found
No related tags found
No related merge requests found
......@@ -44,6 +44,19 @@ class Result(dict):
"\n\n Saving the data has failed with the following message:\n {} \n\n"
.format(e))
def get_latex_labels_from_parameter_keys(self, keys):
return_list = []
for k in keys:
if k in self.search_parameter_keys:
idx = self.search_parameter_keys.index(k)
return_list.append(self.parameter_labels[idx])
elif k in self.parameter_labels:
return_list.append(k)
else:
raise ValueError('key {} not a parameter label or latex label'
.format(k))
return return_list
def plot_corner(self, save=True, **kwargs):
""" Plot a corner-plot using chain-consumer
......@@ -61,12 +74,27 @@ class Result(dict):
# Set some defaults (unless already set)
kwargs['figsize'] = kwargs.get('figsize', 'GROW')
if save:
kwargs['filename'] = '{}/{}_corner.png'.format(self.outdir, self.label)
filename = '{}/{}_corner.png'.format(self.outdir, self.label)
kwargs['filename'] = kwargs.get('filename', filename)
logging.info('Saving corner plot to {}'.format(kwargs['filename']))
if self.injection_parameters is not None:
kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys]
# If no truth argument given, set these to the injection params
injection_parameters = [self.injection_parameters[key]
for key in self.search_parameter_keys]
kwargs['truth'] = kwargs.get('truth', injection_parameters)
if type(kwargs.get('truth')) == dict:
old_keys = kwargs['truth'].keys()
new_keys = self.get_latex_labels_from_parameter_keys(old_keys)
for old, new in zip(old_keys, new_keys):
kwargs['truth'][new] = kwargs['truth'].pop(old)
if 'parameters' in kwargs:
kwargs['parameters'] = self.get_latex_labels_from_parameter_keys(
kwargs['parameters'])
c = ChainConsumer()
c.add_chain(self.samples, parameters=self.parameter_labels)
c.add_chain(self.samples, parameters=self.parameter_labels,
name=self.label)
fig = c.plotter.plot(**kwargs)
return fig
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment