Skip to content
Snippets Groups Projects
Commit 29817d24 authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner
Browse files

Handle priors as part of the result object and several other minor fixes

parent 988c3589
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ Changes currently on master, but not under a tag.
- Fix interpretation of kwargs for dynesty
- PowerSpectralDensity structure modified
- Fixed bug in get_open_data
- .prior files are no longer created. The prior is stored in the result object.
### Removed
- Removes the "--detectors" command line argument (not a general CLI requirement)
......
......@@ -7,6 +7,7 @@ from scipy.special import erf, erfinv
import scipy.stats
import os
from collections import OrderedDict
from future.utils import iteritems
from tupak.core.utils import logger
from tupak.core import utils
......@@ -26,17 +27,17 @@ class PriorSet(OrderedDict):
"""
OrderedDict.__init__(self)
if isinstance(dictionary, dict):
self.update(dictionary)
self.from_dictionary(dictionary)
elif type(dictionary) is str:
logger.debug('Argument "dictionary" is a string.' +
' Assuming it is intended as a file name.')
self.read_in_file(dictionary)
self.from_file(dictionary)
elif type(filename) is str:
self.read_in_file(filename)
self.from_file(filename)
elif dictionary is not None:
raise ValueError("PriorSet input dictionay not understood")
raise ValueError("PriorSet input dictionary not understood")
def write_to_file(self, outdir, label):
def to_file(self, outdir, label):
""" Write the prior distribution to file.
Parameters
......@@ -55,7 +56,7 @@ class PriorSet(OrderedDict):
outfile.write(
"{} = {}\n".format(key, self[key]))
def read_in_file(self, filename):
def from_file(self, filename):
""" Reads in a prior from a file specification
Parameters
......@@ -75,6 +76,20 @@ class PriorSet(OrderedDict):
prior[key] = eval(val)
self.update(prior)
def from_dictionary(self, dictionary):
for key, val in iteritems(dictionary):
if isinstance(val, str):
try:
prior = eval(val)
if isinstance(prior, Prior):
val = prior
except (NameError, SyntaxError, TypeError):
logger.debug(
"Failed to load dictionary value {} correctlty"
.format(key))
pass
self[key] = val
def convert_floats_to_delta_functions(self):
""" Convert all float parameters to delta functions """
for key in self:
......
......@@ -10,7 +10,7 @@ from collections import OrderedDict
from tupak.core import utils
from tupak.core.utils import logger
from tupak.core.prior import DeltaFunction
from tupak.core.prior import PriorSet, DeltaFunction
def result_file_name(outdir, label):
......@@ -70,12 +70,19 @@ class Result(dict):
A dictionary containing values to be set in this instance
"""
# Set some defaults
self.outdir = '.'
self.label = 'no_name'
dict.__init__(self)
if type(dictionary) is dict:
for key in dictionary:
val = self._standardise_a_string(dictionary[key])
setattr(self, key, val)
if getattr(self, 'priors', None) is not None:
self.priors = PriorSet(self.priors)
def __add__(self, other):
matches = ['sampler', 'search_parameter_keys']
for match in matches:
......@@ -171,8 +178,14 @@ class Result(dict):
os.rename(file_name, file_name + '.old')
logger.debug("Saving result to {}".format(file_name))
# Convert the prior to a string representation for saving on disk
dictionary = dict(self)
if dictionary.get('priors', False):
dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors}
try:
deepdish.io.save(file_name, dict(self))
deepdish.io.save(file_name, dictionary)
except Exception as e:
logger.error("\n\n Saving the data has failed with the "
"following message:\n {} \n\n".format(e))
......@@ -270,8 +283,8 @@ class Result(dict):
string = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
return string.format(fmt(median), fmt(lower), fmt(upper))
def plot_corner(self, parameters=None, priors=None, titles=True, save=True,
filename=None, dpi=300, **kwargs):
def plot_corner(self, parameters=None, priors=False, titles=True,
save=True, filename=None, dpi=300, **kwargs):
""" Plot a corner-plot using corner
See https://corner.readthedocs.io/en/latest/ for a detailed API.
......@@ -280,9 +293,10 @@ class Result(dict):
----------
parameters: list, optional
If given, a list of the parameter names to include
priors: tupak.core.prior.PriorSet
If given, add the prior probability density functions to the
one-dimensional marginal distributions
priors: {bool (False), tupak.core.prior.PriorSet}
If true, add the stored prior probability density functions to the
one-dimensional marginal distributions. If instead a PriorSet
is provided, this will be plotted.
titles: bool
If true, add 1D titles of the median and (by default 1-sigma)
error bars. To change the error bars, pass in the quantiles kwarg.
......@@ -363,11 +377,17 @@ class Result(dict):
**kwargs['title_kwargs'])
# Add priors to the 1D plots
if priors is not None:
if priors is True:
priors = getattr(self, 'priors', False)
if isinstance(priors, dict):
for i, par in enumerate(parameters):
ax = axes[i + i * len(parameters)]
theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300)
ax.plot(theta, priors[par].prob(theta), color='C2')
elif priors in [False, None]:
pass
else:
raise ValueError('Input priors={} not understood'.format(priors))
if save:
if filename is None:
......
......@@ -108,9 +108,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
priors.fill_priors(likelihood, default_priors_file=default_priors_file)
if save:
priors.write_to_file(outdir, label)
if isinstance(sampler, Sampler):
pass
elif isinstance(sampler, str):
......@@ -148,6 +145,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if type(meta_data) == dict:
result.update(meta_data)
result.priors = priors
end_time = datetime.datetime.now()
result.sampling_time = (end_time - start_time).total_seconds()
logger.info('Sampling time: {}'.format(end_time - start_time))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment