Commit 29817d24 authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner

Handle priors as part of the result object and several other minor fixes

parent 988c3589
......@@ -25,6 +25,7 @@ Changes currently on master, but not under a tag.
- Fix interpretation of kwargs for dynesty
- PowerSpectralDensity structure modified
- Fixed bug in get_open_data
- .prior files are no longer created. The prior is stored in the result object.
### Removed
- Removes the "--detectors" command line argument (not a general CLI requirement)
......@@ -7,6 +7,7 @@ from scipy.special import erf, erfinv
import scipy.stats
import os
from collections import OrderedDict
from future.utils import iteritems
from tupak.core.utils import logger
from tupak.core import utils
......@@ -26,17 +27,17 @@ class PriorSet(OrderedDict):
if isinstance(dictionary, dict):
elif type(dictionary) is str:
logger.debug('Argument "dictionary" is a string.' +
' Assuming it is intended as a file name.')
elif type(filename) is str:
elif dictionary is not None:
raise ValueError("PriorSet input dictionay not understood")
raise ValueError("PriorSet input dictionary not understood")
def write_to_file(self, outdir, label):
def to_file(self, outdir, label):
""" Write the prior distribution to file.
......@@ -55,7 +56,7 @@ class PriorSet(OrderedDict):
"{} = {}\n".format(key, self[key]))
def read_in_file(self, filename):
def from_file(self, filename):
""" Reads in a prior from a file specification
......@@ -75,6 +76,20 @@ class PriorSet(OrderedDict):
prior[key] = eval(val)
def from_dictionary(self, dictionary):
for key, val in iteritems(dictionary):
if isinstance(val, str):
prior = eval(val)
if isinstance(prior, Prior):
val = prior
except (NameError, SyntaxError, TypeError):
"Failed to load dictionary value {} correctlty"
self[key] = val
def convert_floats_to_delta_functions(self):
""" Convert all float parameters to delta functions """
for key in self:
......@@ -10,7 +10,7 @@ from collections import OrderedDict
from tupak.core import utils
from tupak.core.utils import logger
from tupak.core.prior import DeltaFunction
from tupak.core.prior import PriorSet, DeltaFunction
def result_file_name(outdir, label):
......@@ -70,12 +70,19 @@ class Result(dict):
A dictionary containing values to be set in this instance
# Set some defaults
self.outdir = '.'
self.label = 'no_name'
if type(dictionary) is dict:
for key in dictionary:
val = self._standardise_a_string(dictionary[key])
setattr(self, key, val)
if getattr(self, 'priors', None) is not None:
self.priors = PriorSet(self.priors)
def __add__(self, other):
matches = ['sampler', 'search_parameter_keys']
for match in matches:
......@@ -171,8 +178,14 @@ class Result(dict):
os.rename(file_name, file_name + '.old')
logger.debug("Saving result to {}".format(file_name))
# Convert the prior to a string representation for saving on disk
dictionary = dict(self)
if dictionary.get('priors', False):
dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors}
try:, dict(self)), dictionary)
except Exception as e:
logger.error("\n\n Saving the data has failed with the "
"following message:\n {} \n\n".format(e))
......@@ -270,8 +283,8 @@ class Result(dict):
string = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
return string.format(fmt(median), fmt(lower), fmt(upper))
def plot_corner(self, parameters=None, priors=None, titles=True, save=True,
filename=None, dpi=300, **kwargs):
def plot_corner(self, parameters=None, priors=False, titles=True,
save=True, filename=None, dpi=300, **kwargs):
""" Plot a corner-plot using corner
See for a detailed API.
......@@ -280,9 +293,10 @@ class Result(dict):
parameters: list, optional
If given, a list of the parameter names to include
priors: tupak.core.prior.PriorSet
If given, add the prior probability density functions to the
one-dimensional marginal distributions
priors: {bool (False), tupak.core.prior.PriorSet}
If true, add the stored prior probability density functions to the
one-dimensional marginal distributions. If instead a PriorSet
is provided, this will be plotted.
titles: bool
If true, add 1D titles of the median and (by default 1-sigma)
error bars. To change the error bars, pass in the quantiles kwarg.
......@@ -363,11 +377,17 @@ class Result(dict):
# Add priors to the 1D plots
if priors is not None:
if priors is True:
priors = getattr(self, 'priors', False)
if isinstance(priors, dict):
for i, par in enumerate(parameters):
ax = axes[i + i * len(parameters)]
theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300)
ax.plot(theta, priors[par].prob(theta), color='C2')
elif priors in [False, None]:
raise ValueError('Input priors={} not understood'.format(priors))
if save:
if filename is None:
......@@ -108,9 +108,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
priors.fill_priors(likelihood, default_priors_file=default_priors_file)
if save:
priors.write_to_file(outdir, label)
if isinstance(sampler, Sampler):
elif isinstance(sampler, str):
......@@ -148,6 +145,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if type(meta_data) == dict:
result.priors = priors
end_time =
result.sampling_time = (end_time - start_time).total_seconds()'Sampling time: {}'.format(end_time - start_time))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment