diff --git a/CHANGELOG.md b/CHANGELOG.md index 074fe4176344047fc8719ce4f7114f6386dae378..8d7d5d2b4f8d4efbbd0877ab8293fb3f29ab6855 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Changes currently on master, but not under a tag. - Fix interpretation of kwargs for dynesty - PowerSpectralDensity structure modified - Fixed bug in get_open_data +- .prior files are no longer created. The prior is stored in the result object. ### Removed - Removes the "--detectors" command line argument (not a general CLI requirement) diff --git a/tupak/core/prior.py b/tupak/core/prior.py index b42516afab61b7f8c2c97bce85d72cd05cae7b9d..f1c950f3ba70032e597d27bd87285e70a6ce1de2 100644 --- a/tupak/core/prior.py +++ b/tupak/core/prior.py @@ -7,6 +7,7 @@ from scipy.special import erf, erfinv import scipy.stats import os from collections import OrderedDict +from future.utils import iteritems from tupak.core.utils import logger from tupak.core import utils @@ -26,17 +27,17 @@ class PriorSet(OrderedDict): """ OrderedDict.__init__(self) if isinstance(dictionary, dict): - self.update(dictionary) + self.from_dictionary(dictionary) elif type(dictionary) is str: logger.debug('Argument "dictionary" is a string.' + ' Assuming it is intended as a file name.') - self.read_in_file(dictionary) + self.from_file(dictionary) elif type(filename) is str: - self.read_in_file(filename) + self.from_file(filename) elif dictionary is not None: - raise ValueError("PriorSet input dictionay not understood") + raise ValueError("PriorSet input dictionary not understood") - def write_to_file(self, outdir, label): + def to_file(self, outdir, label): """ Write the prior distribution to file. Parameters @@ -55,7 +56,7 @@ class PriorSet(OrderedDict): outfile.write( "{} = {}\n".format(key, self[key])) - def read_in_file(self, filename): + def from_file(self, filename): """ Reads in a prior from a file specification Parameters @@ -75,6 +76,20 @@ class PriorSet(OrderedDict): prior[key] = eval(val) self.update(prior) + def from_dictionary(self, dictionary): + for key, val in iteritems(dictionary): + if isinstance(val, str): + try: + prior = eval(val) + if isinstance(prior, Prior): + val = prior + except (NameError, SyntaxError, TypeError): + logger.debug( + "Failed to load dictionary value {} correctlty" + .format(key)) + pass + self[key] = val + def convert_floats_to_delta_functions(self): """ Convert all float parameters to delta functions """ for key in self: diff --git a/tupak/core/result.py b/tupak/core/result.py index 8ecf8e1a4b0103e72943ad8ade5074f77f2dca9b..3b14ca1dd61d87c950fc8550938c7df71ea6ebb3 100644 --- a/tupak/core/result.py +++ b/tupak/core/result.py @@ -10,7 +10,7 @@ from collections import OrderedDict from tupak.core import utils from tupak.core.utils import logger -from tupak.core.prior import DeltaFunction +from tupak.core.prior import PriorSet, DeltaFunction def result_file_name(outdir, label): @@ -70,12 +70,19 @@ class Result(dict): A dictionary containing values to be set in this instance """ + # Set some defaults + self.outdir = '.' + self.label = 'no_name' + dict.__init__(self) if type(dictionary) is dict: for key in dictionary: val = self._standardise_a_string(dictionary[key]) setattr(self, key, val) + if getattr(self, 'priors', None) is not None: + self.priors = PriorSet(self.priors) + def __add__(self, other): matches = ['sampler', 'search_parameter_keys'] for match in matches: @@ -171,8 +178,14 @@ class Result(dict): os.rename(file_name, file_name + '.old') logger.debug("Saving result to {}".format(file_name)) + + # Convert the prior to a string representation for saving on disk + dictionary = dict(self) + if dictionary.get('priors', False): + dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors} + try: - deepdish.io.save(file_name, dict(self)) + deepdish.io.save(file_name, dictionary) except Exception as e: logger.error("\n\n Saving the data has failed with the " "following message:\n {} \n\n".format(e)) @@ -270,8 +283,8 @@ class Result(dict): string = r"${{{0}}}_{{-{1}}}^{{+{2}}}$" return string.format(fmt(median), fmt(lower), fmt(upper)) - def plot_corner(self, parameters=None, priors=None, titles=True, save=True, - filename=None, dpi=300, **kwargs): + def plot_corner(self, parameters=None, priors=False, titles=True, + save=True, filename=None, dpi=300, **kwargs): """ Plot a corner-plot using corner See https://corner.readthedocs.io/en/latest/ for a detailed API. @@ -280,9 +293,10 @@ class Result(dict): ---------- parameters: list, optional If given, a list of the parameter names to include - priors: tupak.core.prior.PriorSet - If given, add the prior probability density functions to the - one-dimensional marginal distributions + priors: {bool (False), tupak.core.prior.PriorSet} + If true, add the stored prior probability density functions to the + one-dimensional marginal distributions. If instead a PriorSet + is provided, this will be plotted. titles: bool If true, add 1D titles of the median and (by default 1-sigma) error bars. To change the error bars, pass in the quantiles kwarg. @@ -363,11 +377,17 @@ class Result(dict): **kwargs['title_kwargs']) # Add priors to the 1D plots - if priors is not None: + if priors is True: + priors = getattr(self, 'priors', False) + if isinstance(priors, dict): for i, par in enumerate(parameters): ax = axes[i + i * len(parameters)] theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300) ax.plot(theta, priors[par].prob(theta), color='C2') + elif priors in [False, None]: + pass + else: + raise ValueError('Input priors={} not understood'.format(priors)) if save: if filename is None: diff --git a/tupak/core/sampler/__init__.py b/tupak/core/sampler/__init__.py index 12d961b0817e9760fc8a272fb29d65e48c638c30..d61f399d966dc4007a3dc6da2ad31aba491bc811 100644 --- a/tupak/core/sampler/__init__.py +++ b/tupak/core/sampler/__init__.py @@ -108,9 +108,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', priors.fill_priors(likelihood, default_priors_file=default_priors_file) - if save: - priors.write_to_file(outdir, label) - if isinstance(sampler, Sampler): pass elif isinstance(sampler, str): @@ -148,6 +145,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', if type(meta_data) == dict: result.update(meta_data) + result.priors = priors + end_time = datetime.datetime.now() result.sampling_time = (end_time - start_time).total_seconds() logger.info('Sampling time: {}'.format(end_time - start_time))