Commit 4a6817c6 authored by Moritz Huebner's avatar Moritz Huebner

Merge branch 'add-priors-to-result' into 'master'

Handle priors as part of the result object and several other minor fixes

See merge request Monash/tupak!198
parents 988c3589 29817d24
Pipeline #31675 passed with stages
in 4 minutes and 38 seconds
......@@ -25,6 +25,7 @@ Changes currently on master, but not under a tag.
- Fix interpretation of kwargs for dynesty
- PowerSpectralDensity structure modified
- Fixed bug in get_open_data
- .prior files are no longer created. The prior is stored in the result object.
### Removed
- Removes the "--detectors" command line argument (not a general CLI requirement)
......
......@@ -7,6 +7,7 @@ from scipy.special import erf, erfinv
import scipy.stats
import os
from collections import OrderedDict
from future.utils import iteritems
from tupak.core.utils import logger
from tupak.core import utils
......@@ -26,17 +27,17 @@ class PriorSet(OrderedDict):
"""
OrderedDict.__init__(self)
if isinstance(dictionary, dict):
self.update(dictionary)
self.from_dictionary(dictionary)
elif type(dictionary) is str:
logger.debug('Argument "dictionary" is a string.' +
' Assuming it is intended as a file name.')
self.read_in_file(dictionary)
self.from_file(dictionary)
elif type(filename) is str:
self.read_in_file(filename)
self.from_file(filename)
elif dictionary is not None:
raise ValueError("PriorSet input dictionay not understood")
raise ValueError("PriorSet input dictionary not understood")
def write_to_file(self, outdir, label):
def to_file(self, outdir, label):
""" Write the prior distribution to file.
Parameters
......@@ -55,7 +56,7 @@ class PriorSet(OrderedDict):
outfile.write(
"{} = {}\n".format(key, self[key]))
def read_in_file(self, filename):
def from_file(self, filename):
""" Reads in a prior from a file specification
Parameters
......@@ -75,6 +76,20 @@ class PriorSet(OrderedDict):
prior[key] = eval(val)
self.update(prior)
def from_dictionary(self, dictionary):
for key, val in iteritems(dictionary):
if isinstance(val, str):
try:
prior = eval(val)
if isinstance(prior, Prior):
val = prior
except (NameError, SyntaxError, TypeError):
logger.debug(
"Failed to load dictionary value {} correctlty"
.format(key))
pass
self[key] = val
def convert_floats_to_delta_functions(self):
""" Convert all float parameters to delta functions """
for key in self:
......
......@@ -10,7 +10,7 @@ from collections import OrderedDict
from tupak.core import utils
from tupak.core.utils import logger
from tupak.core.prior import DeltaFunction
from tupak.core.prior import PriorSet, DeltaFunction
def result_file_name(outdir, label):
......@@ -70,12 +70,19 @@ class Result(dict):
A dictionary containing values to be set in this instance
"""
# Set some defaults
self.outdir = '.'
self.label = 'no_name'
dict.__init__(self)
if type(dictionary) is dict:
for key in dictionary:
val = self._standardise_a_string(dictionary[key])
setattr(self, key, val)
if getattr(self, 'priors', None) is not None:
self.priors = PriorSet(self.priors)
def __add__(self, other):
matches = ['sampler', 'search_parameter_keys']
for match in matches:
......@@ -171,8 +178,14 @@ class Result(dict):
os.rename(file_name, file_name + '.old')
logger.debug("Saving result to {}".format(file_name))
# Convert the prior to a string representation for saving on disk
dictionary = dict(self)
if dictionary.get('priors', False):
dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors}
try:
deepdish.io.save(file_name, dict(self))
deepdish.io.save(file_name, dictionary)
except Exception as e:
logger.error("\n\n Saving the data has failed with the "
"following message:\n {} \n\n".format(e))
......@@ -270,8 +283,8 @@ class Result(dict):
string = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
return string.format(fmt(median), fmt(lower), fmt(upper))
def plot_corner(self, parameters=None, priors=None, titles=True, save=True,
filename=None, dpi=300, **kwargs):
def plot_corner(self, parameters=None, priors=False, titles=True,
save=True, filename=None, dpi=300, **kwargs):
""" Plot a corner-plot using corner
See https://corner.readthedocs.io/en/latest/ for a detailed API.
......@@ -280,9 +293,10 @@ class Result(dict):
----------
parameters: list, optional
If given, a list of the parameter names to include
priors: tupak.core.prior.PriorSet
If given, add the prior probability density functions to the
one-dimensional marginal distributions
priors: {bool (False), tupak.core.prior.PriorSet}
If true, add the stored prior probability density functions to the
one-dimensional marginal distributions. If instead a PriorSet
is provided, this will be plotted.
titles: bool
If true, add 1D titles of the median and (by default 1-sigma)
error bars. To change the error bars, pass in the quantiles kwarg.
......@@ -363,11 +377,17 @@ class Result(dict):
**kwargs['title_kwargs'])
# Add priors to the 1D plots
if priors is not None:
if priors is True:
priors = getattr(self, 'priors', False)
if isinstance(priors, dict):
for i, par in enumerate(parameters):
ax = axes[i + i * len(parameters)]
theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300)
ax.plot(theta, priors[par].prob(theta), color='C2')
elif priors in [False, None]:
pass
else:
raise ValueError('Input priors={} not understood'.format(priors))
if save:
if filename is None:
......
......@@ -108,9 +108,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
priors.fill_priors(likelihood, default_priors_file=default_priors_file)
if save:
priors.write_to_file(outdir, label)
if isinstance(sampler, Sampler):
pass
elif isinstance(sampler, str):
......@@ -148,6 +145,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if type(meta_data) == dict:
result.update(meta_data)
result.priors = priors
end_time = datetime.datetime.now()
result.sampling_time = (end_time - start_time).total_seconds()
logger.info('Sampling time: {}'.format(end_time - start_time))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment