diff --git a/examples/injection_examples/basic_tutorial.py b/examples/injection_examples/basic_tutorial.py index 82dbc13d3214bb225956254ad9cf9054841eb1e6..f9946da0039934f222a9072026a2f773c67d0ea1 100644 --- a/examples/injection_examples/basic_tutorial.py +++ b/examples/injection_examples/basic_tutorial.py @@ -63,6 +63,4 @@ result = tupak.sampler.run_sampler(likelihood=likelihood, priors=priors, sampler # make some plots of the outputs result.plot_corner() -result.plot_walks() -result.plot_distributions() print(result) diff --git a/examples/injection_examples/change_sampled_parameters.py b/examples/injection_examples/change_sampled_parameters.py index d24c7e4c2bd1cb5dd1310c87b300b3c140e65f56..6d7734b51b60d73f1782762904a02ab7550ba039 100644 --- a/examples/injection_examples/change_sampled_parameters.py +++ b/examples/injection_examples/change_sampled_parameters.py @@ -50,6 +50,4 @@ result = tupak.sampler.run_sampler(likelihood=likelihood, priors=priors, sampler injection_parameters=injection_parameters, label='DifferentParameters', outdir=outdir, conversion_function=tupak.conversion.generate_all_bbh_parameters) result.plot_corner() -result.plot_walks() -result.plot_distributions() print(result) diff --git a/examples/injection_examples/create_your_own_source_model.py b/examples/injection_examples/create_your_own_source_model.py index 64b74ded3ec33b7ade097de89daa4150fa0cd8d8..51d548526404169d69c671121c864206de4dbdf9 100644 --- a/examples/injection_examples/create_your_own_source_model.py +++ b/examples/injection_examples/create_your_own_source_model.py @@ -50,7 +50,5 @@ likelihood = tupak.likelihood.GravitationalWaveTransient(IFOs, waveform_generato result = tupak.sampler.run_sampler( likelihood, prior, sampler='dynesty', outdir=outdir, label=label, resume=False, sample='unif', injection_parameters=injection_parameters) -result.plot_walks() -result.plot_distributions() result.plot_corner() print(result) diff --git a/examples/injection_examples/create_your_own_time_domain_source_model.py b/examples/injection_examples/create_your_own_time_domain_source_model.py index eb95e7f379cfd73e30b872a274e3ea4b18ed2f2f..af51c5812b5cfc7f49f57f1fb5c3e689c128e1c4 100644 --- a/examples/injection_examples/create_your_own_time_domain_source_model.py +++ b/examples/injection_examples/create_your_own_time_domain_source_model.py @@ -71,8 +71,5 @@ result = tupak.sampler.run_sampler(likelihood, prior, sampler='dynesty', npoints injection_parameters=injection_parameters, outdir=outdir, label=label) -result.plot_walks() -result.plot_distributions() result.plot_corner() - print(result) diff --git a/examples/injection_examples/how_to_specify_the_prior.py b/examples/injection_examples/how_to_specify_the_prior.py index abd2237330dceea7c8dfea31dcde9a076c43ac1d..7a3a0743e4826a4714507f0bca99cf4eb77ad07d 100644 --- a/examples/injection_examples/how_to_specify_the_prior.py +++ b/examples/injection_examples/how_to_specify_the_prior.py @@ -64,6 +64,4 @@ likelihood = tupak.likelihood.GravitationalWaveTransient(interferometers=IFOs, w result = tupak.sampler.run_sampler(likelihood=likelihood, priors=priors, sampler='dynesty', injection_parameters=injection_parameters, outdir=outdir, label='specify_prior') result.plot_corner() -result.plot_walks() -result.plot_distributions() print(result) diff --git a/examples/injection_examples/marginalized_likelihood.py b/examples/injection_examples/marginalized_likelihood.py index b39f62614b4150d43ea4ca99d7cb6b62bb824ae1..2cd028c550526dbb455d2da63ed146e89f616325 100644 --- a/examples/injection_examples/marginalized_likelihood.py +++ b/examples/injection_examples/marginalized_likelihood.py @@ -46,6 +46,4 @@ likelihood = tupak.likelihood.GravitationalWaveTransient( result = tupak.sampler.run_sampler(likelihood=likelihood, priors=priors, sampler='dynesty', injection_parameters=injection_parameters, outdir=outdir, label='BasicTutorial') result.plot_corner() -result.plot_walks() -result.plot_distributions() print(result) diff --git a/examples/open_data_examples/GW150914.py b/examples/open_data_examples/GW150914.py index e9a66a0784eae92c6e85d74d3943684aa5e89814..56ea1cde47d0274b68c2088ccd79ca571937d5a2 100644 --- a/examples/open_data_examples/GW150914.py +++ b/examples/open_data_examples/GW150914.py @@ -60,6 +60,4 @@ likelihood = tupak.likelihood.GravitationalWaveTransient(interferometers, wavefo result = tupak.sampler.run_sampler(likelihood, prior, sampler='dynesty', outdir=outdir, label=label) result.plot_corner() -result.plot_walks() -result.plot_distributions() print(result) diff --git a/examples/other_examples/hyper_parameter_example.py b/examples/other_examples/hyper_parameter_example.py new file mode 100644 index 0000000000000000000000000000000000000000..936ef26833c2005ba4f614cbc154b0c93c09784e --- /dev/null +++ b/examples/other_examples/hyper_parameter_example.py @@ -0,0 +1,82 @@ +#!/bin/python +""" +An example of how to use tupak to perform paramater estimation for hyperparams +""" +from __future__ import division +import tupak +import numpy as np + +tupak.utils.setup_logger() +outdir = 'outdir' + + +class GaussianLikelihood(tupak.likelihood.Likelihood): + def __init__(self, x, y, waveform_generator): + self.x = x + self.y = y + self.N = len(x) + self.waveform_generator = waveform_generator + self.parameters = waveform_generator.parameters + + def log_likelihood(self): + sigma = 1 + res = self.y - self.waveform_generator.time_domain_strain() + return -0.5 * (np.sum((res / sigma)**2) + + self.N*np.log(2*np.pi*sigma**2)) + + +def model(time, m): + return m * time + + +sampling_frequency = 10 +time_duration = 100 +time = np.arange(0, time_duration, 1/sampling_frequency) + +true_mu_m = 5 +true_sigma_m = 0.1 +sigma = 0.1 +Nevents = 10 +samples = [] + +# Make the sample sets +for i in range(Nevents): + m = np.random.normal(true_mu_m, true_sigma_m) + injection_parameters = dict(m=m) + + N = len(time) + data = model(time, **injection_parameters) + np.random.normal(0, sigma, N) + + waveform_generator = tupak.waveform_generator.WaveformGenerator( + time_duration=time_duration, sampling_frequency=sampling_frequency, + time_domain_source_model=model) + + likelihood = GaussianLikelihood(time, data, waveform_generator) + + priors = dict(m=tupak.prior.Uniform(-10, 10, 'm')) + + result = tupak.sampler.run_sampler( + likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000, + injection_parameters=injection_parameters, outdir=outdir, + verbose=False, label='individual_{}'.format(i), use_ratio=False, + sample='unif') + result.plot_corner() + samples.append(result.samples) + +# Now run the hyperparameter inference +run_prior = tupak.prior.Uniform(minimum=-10, maximum=10, name='mu_m') +hyper_prior = tupak.prior.Gaussian(mu=0, sigma=1, name='hyper') + +hp_likelihood = tupak.likelihood.HyperparameterLikelihood( + samples, hyper_prior, run_prior) + +hp_priors = dict( + mu=tupak.prior.Uniform(-10, 10, 'mu', '$\mu_m$'), + sigma=tupak.prior.Uniform(0, 10, 'sigma', '$\sigma_m$')) + +# And run sampler +result = tupak.sampler.run_sampler( + likelihood=hp_likelihood, priors=hp_priors, sampler='dynesty', + npoints=1000, outdir=outdir, label='hyperparameter', use_ratio=False, + sample='unif', verbose=True) +result.plot_corner(truth=dict(mu=true_mu_m, sigma=true_sigma_m)) diff --git a/requirements.txt b/requirements.txt index f71732b9408956a5cac6709a0079630222d20fde..fe8a629b9a049938e04c73d9c49034b19bffc522 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ future dynesty corner -numpy +numpy>=1.9 matplotlib>=2.0 scipy gwpy diff --git a/test/sampler_tests.py b/test/sampler_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..dcbdfdf255e1c47e84e74230b11c5740e773e5c7 --- /dev/null +++ b/test/sampler_tests.py @@ -0,0 +1,130 @@ +from context import tupak +from tupak import prior +from tupak.result import Result +import unittest +from mock import MagicMock +import numpy as np +import inspect +import os +import copy + + +class TestSampler(unittest.TestCase): + + def setUp(self): + likelihood = tupak.likelihood.Likelihood() + likelihood.parameters = dict(a=1, b=2, c=3) + delta_prior = prior.DeltaFunction(peak=0) + delta_prior.rescale = MagicMock(return_value=prior.DeltaFunction(peak=1)) + delta_prior.prob = MagicMock(return_value=1) + delta_prior.sample = MagicMock(return_value=0) + uniform_prior = prior.Uniform(0, 1) + uniform_prior.rescale = MagicMock(return_value=prior.Uniform(0, 2)) + uniform_prior.prob = MagicMock(return_value=1) + uniform_prior.sample = MagicMock(return_value=0.5) + + priors = dict(a=delta_prior, b='string', c=uniform_prior) + likelihood.log_likelihood_ratio = MagicMock(return_value=1) + likelihood.log_likelihood = MagicMock(return_value=2) + test_directory = 'test_directory' + if os.path.isdir(test_directory): + os.rmdir(test_directory) + self.sampler = tupak.sampler.Sampler(likelihood=likelihood, + priors=priors, + external_sampler='nestle', + outdir=test_directory, + use_ratio=False) + + def tearDown(self): + os.rmdir(self.sampler.outdir) + del self.sampler + + def test_search_parameter_keys(self): + expected_search_parameter_keys = ['c'] + self.assertListEqual(self.sampler.search_parameter_keys, expected_search_parameter_keys) + + def test_fixed_parameter_keys(self): + expected_fixed_parameter_keys = ['a'] + self.assertListEqual(self.sampler.fixed_parameter_keys, expected_fixed_parameter_keys) + + def test_ndim(self): + self.assertEqual(self.sampler.ndim, 1) + + def test_kwargs(self): + self.assertDictEqual(self.sampler.kwargs, {}) + + def test_label(self): + self.assertEqual(self.sampler.label, 'label') + + def test_if_external_sampler_is_module(self): + self.assertTrue(inspect.ismodule(self.sampler.external_sampler)) + + def test_if_external_sampler_has_the_correct_module_name(self): + expected_name = 'nestle' + self.assertEqual(self.sampler.external_sampler.__name__, expected_name) + + def test_external_sampler_raises_if_sampler_not_installed(self): + with self.assertRaises(ImportError): + self.sampler.external_sampler = 'unexpected_sampler' + + def test_setting_custom_sampler(self): + other_sampler = tupak.sampler.Sampler(self.sampler.likelihood, + self.sampler.priors) + self.sampler.external_sampler = other_sampler + self.assertEqual(self.sampler.external_sampler, other_sampler) + + def test_setting_external_sampler_to_something_else_raises_error(self): + with self.assertRaises(TypeError): + self.sampler.external_sampler = object() + + def test_result(self): + expected_result = Result() + expected_result.search_parameter_keys = ['c'] + expected_result.fixed_parameter_keys = ['a'] + expected_result.parameter_labels = ['c'] + expected_result.label = 'label' + expected_result.outdir = 'outdir' + expected_result.kwargs = {} + self.assertDictEqual(self.sampler.result.__dict__, expected_result.__dict__) + + def test_make_outdir_if_no_outdir_exists(self): + self.assertTrue(os.path.isdir(self.sampler.outdir)) + + def test_prior_transform_transforms_search_parameter_keys(self): + self.sampler.prior_transform([0]) + expected_prior = prior.Uniform(0, 1) + self.assertListEqual([self.sampler.priors['c'].minimum, + self.sampler.priors['c'].maximum], + [expected_prior.minimum, + expected_prior.maximum]) + + def test_prior_transform_does_not_transform_fixed_parameter_keys(self): + self.sampler.prior_transform([0]) + self.assertEqual(self.sampler.priors['a'].peak, + prior.DeltaFunction(peak=0).peak) + + def test_log_prior(self): + self.assertEqual(self.sampler.log_prior({1}), 0.0) + + def test_log_likelihood_with_use_ratio(self): + self.sampler.use_ratio = True + self.assertEqual(self.sampler.log_likelihood([0]), 1) + + def test_log_likelihood_without_use_ratio(self): + self.sampler.use_ratio = False + self.assertEqual(self.sampler.log_likelihood([0]), 2) + + def test_log_likelihood_correctly_sets_parameters(self): + expected_dict = dict(a=0, + b=2, + c=0) + _ = self.sampler.log_likelihood([0]) + self.assertDictEqual(self.sampler.likelihood.parameters, expected_dict) + + def test_get_random_draw(self): + self.assertEqual(self.sampler.get_random_draw_from_prior(), np.array([0.5])) + + def test_base_run_sampler(self): + sampler_copy = copy.copy(self.sampler) + self.sampler.run_sampler() + self.assertDictEqual(sampler_copy.__dict__, self.sampler.__dict__) \ No newline at end of file diff --git a/tupak/likelihood.py b/tupak/likelihood.py index 382c1aea8cddabdb010ee507f0855f06154ea749..47e70944d59a2220f19ddf1fcb93ce6322fd21bf 100644 --- a/tupak/likelihood.py +++ b/tupak/likelihood.py @@ -14,6 +14,9 @@ import logging class Likelihood(object): """ Empty likelihood class to be subclassed by other likelihoods """ + def __init__(self, parameters=None): + self.parameters = parameters + def log_likelihood(self): return np.nan @@ -51,16 +54,15 @@ class GravitationalWaveTransient(Likelihood): Returns ------- Likelihood: `tupak.likelihood.Likelihood` - A likehood object, able to compute the likelihood of the data given + A likelihood object, able to compute the likelihood of the data given some model parameters """ def __init__(self, interferometers, waveform_generator, distance_marginalization=False, phase_marginalization=False, prior=None): - # GravitationalWaveTransient.__init__(self, interferometers, waveform_generator) + Likelihood.__init__(self, waveform_generator.parameters) self.interferometers = interferometers self.waveform_generator = waveform_generator - self.parameters = self.waveform_generator.parameters self.non_standard_sampling_parameter_keys = self.waveform_generator.non_standard_sampling_parameter_keys self.distance_marginalization = distance_marginalization self.phase_marginalization = phase_marginalization @@ -153,7 +155,7 @@ class GravitationalWaveTransient(Likelihood): class BasicGravitationalWaveTransient(Likelihood): - """ A basic gravitaitonal wave transient likelihood + """ A basic gravitational wave transient likelihood The simplest frequency-domain gravitational wave transient likelihood. Does not include distance/phase marginalization. @@ -170,11 +172,12 @@ class BasicGravitationalWaveTransient(Likelihood): Returns ------- Likelihood: `tupak.likelihood.Likelihood` - A likehood object, able to compute the likelihood of the data given + A likelihood object, able to compute the likelihood of the data given some model parameters """ def __init__(self, interferometers, waveform_generator): + Likelihood.__init__(self, waveform_generator.parameters) self.interferometers = interferometers self.waveform_generator = waveform_generator @@ -230,3 +233,53 @@ def get_binary_black_hole_likelihood(interferometers): likelihood = tupak.likelihood.GravitationalWaveTransient(interferometers, waveform_generator) return likelihood + +class HyperparameterLikelihood(Likelihood): + """ A likelihood for infering hyperparameter posterior distributions + + See Eq. (1) of https://arxiv.org/abs/1801.02699 for a definition. + + Parameters + ---------- + samples: list + An N-dimensional list of individual sets of samples. Each set may have + a different size. + hyper_prior: `tupak.prior.Prior` + A prior distribution with a `parameters` argument pointing to the + hyperparameters to infer from the samples. These may need to be + initialized to any arbitrary value, but this will not effect the + result. + run_prior: `tupak.prior.Prior` + The prior distribution used in the inidivudal inferences which resulted + in the set of samples. + + """ + + def __init__(self, samples, hyper_prior, run_prior): + Likelihood.__init__(self, parameters=hyper_prior.__dict__) + self.samples = samples + self.hyper_prior = hyper_prior + self.run_prior = run_prior + if hasattr(hyper_prior, 'lnprob') and hasattr(run_prior, 'lnprob'): + logging.info("Using log-probabilities in likelihood") + self.log_likelihood = self.log_likelihood_using_lnprob + else: + logging.info("Using probabilities in likelihood") + self.log_likelihood = self.log_likelihood_using_prob + + def log_likelihood_using_lnprob(self): + L = [] + self.hyper_prior.__dict__.update(self.parameters) + for samp in self.samples: + f = self.hyper_prior.lnprob(samp) - self.run_prior.lnprob(samp) + L.append(logsumexp(f)) + return np.sum(L) + + def log_likelihood_using_prob(self): + L = [] + self.hyper_prior.__dict__.update(self.parameters) + for samp in self.samples: + L.append( + np.sum(self.hyper_prior.prob(samp) / + self.run_prior.prob(samp))) + return np.sum(np.log(L)) diff --git a/tupak/prior.py b/tupak/prior.py index 680c38602176f6c90bb077c24c8222e90104a967..b25c23cc11a8a0b58c65d0c28fa2f573b7eed784 100644 --- a/tupak/prior.py +++ b/tupak/prior.py @@ -173,6 +173,12 @@ class PowerLaw(Prior): return np.nan_to_num(val ** self.alpha * (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) * in_prior + def lnprob(self, val): + in_prior = (val >= self.minimum) & (val <= self.maximum) + normalising = (1+self.alpha)/(self.maximum ** (1 + self.alpha) + - self.minimum ** (1 + self.alpha)) + return self.alpha * np.log(val) * np.log(normalising) * in_prior + class Uniform(PowerLaw): """Uniform prior""" @@ -254,6 +260,9 @@ class Gaussian(Prior): """Return the prior probability of val""" return np.exp(-(self.mu - val)**2 / (2 * self.sigma**2)) / (2 * np.pi)**0.5 / self.sigma + def lnprob(self, val): + return -0.5*((self.mu - val)**2 / self.sigma**2 + np.log(2 * np.pi * self.sigma**2)) + class TruncatedGaussian(Prior): """ diff --git a/tupak/result.py b/tupak/result.py index 07023b03cace71d61d721e7fffcd2817e547c1b0..681eb03bd5cf7a59c4d89011c9286e8d45fa48aa 100644 --- a/tupak/result.py +++ b/tupak/result.py @@ -3,14 +3,7 @@ import os import numpy as np import deepdish import pandas as pd - -try: - from chainconsumer import ChainConsumer -except ImportError: - def ChainConsumer(): - logging.warning( - "You do not have the optional module chainconsumer installed" - " unable to generate a corner plot") +import corner def result_file_name(outdir, label): @@ -34,10 +27,12 @@ def read_in_result(outdir=None, label=None, filename=None): """ if filename is None: filename = result_file_name(outdir, label) + elif (outdir is None or label is None) and filename is None: + raise ValueError("No information given to load file") if os.path.isfile(filename): return Result(deepdish.io.load(filename)) else: - raise ValueError("No information given to load file") + raise ValueError("No result found") class Result(dict): @@ -102,103 +97,78 @@ class Result(dict): .format(k)) return return_list - def plot_corner(self, save=True, **kwargs): - """ Plot a corner-plot using chain-consumer + def plot_corner(self, parameters=None, save=True, dpi=300, **kwargs): + """ Plot a corner-plot using corner + + See https://corner.readthedocs.io/en/latest/ for a detailed API. Parameters ---------- + parameters: list + If given, a list of the parameter names to include save: bool If true, save the image using the given label and outdir + **kwargs: + Other keyword arguments are passed to `corner.corner`. We set some + defaults to improve the basic look and feel, but these can all be + overridden. Returns ------- fig: A matplotlib figure instance + """ - # Set some defaults (unless already set) - kwargs['figsize'] = kwargs.get('figsize', 'GROW') - if save: - filename = '{}/{}_corner.png'.format(self.outdir, self.label) - kwargs['filename'] = kwargs.get('filename', filename) - logging.info('Saving corner plot to {}'.format(kwargs['filename'])) + defaults_kwargs = dict( + bins=50, smooth=0.9, label_kwargs=dict(fontsize=16), + title_kwargs=dict(fontsize=16), color='#0072C1', + truth_color='tab:orange', show_titles=True, + quantiles=[0.025, 0.975], levels=(0.39, 0.8, 0.97), + plot_density=False, plot_datapoints=True, fill_contours=True, + max_n_ticks=3) + + defaults_kwargs.update(kwargs) + kwargs = defaults_kwargs + + if 'truth' in kwargs: + kwargs['truths'] = kwargs.pop('truth') + if getattr(self, 'injection_parameters', None) is not None: - # If no truth argument given, set these to the injection params injection_parameters = [self.injection_parameters[key] for key in self.search_parameter_keys] - kwargs['truth'] = kwargs.get('truth', injection_parameters) - - if type(kwargs.get('truth')) == dict: - old_keys = kwargs['truth'].keys() - new_keys = self.get_latex_labels_from_parameter_keys(old_keys) - for old, new in zip(old_keys, new_keys): - kwargs['truth'][new] = kwargs['truth'].pop(old) - if 'parameters' in kwargs: - kwargs['parameters'] = self.get_latex_labels_from_parameter_keys( - kwargs['parameters']) - - # Check all parameter_labels are a valid string - for i, label in enumerate(self.parameter_labels): - if label is None: - self.parameter_labels[i] = 'Unknown' - c = ChainConsumer() - if c: - c.add_chain(self.samples, parameters=self.parameter_labels, - name=self.label) - fig = c.plotter.plot(**kwargs) - return fig + kwargs['truths'] = kwargs.get('truths', injection_parameters) - def plot_walks(self, save=True, **kwargs): - """ Plot the chain walks using chain-consumer + if parameters is None: + parameters = self.search_parameter_keys - Parameters - ---------- - save: bool - If true, save the image using the given label and outdir + xs = self.posterior[parameters].values + kwargs['labels'] = kwargs.get( + 'labels', self.get_latex_labels_from_parameter_keys( + parameters)) - Returns - ------- - fig: - A matplotlib figure instance - """ + if type(kwargs.get('truths')) == dict: + truths = [kwargs['truths'][k] for k in parameters] + kwargs['truths'] = truths - # Set some defaults (unless already set) - if save: - kwargs['filename'] = '{}/{}_walks.png'.format(self.outdir, self.label) - logging.info('Saving walker plot to {}'.format(kwargs['filename'])) - if getattr(self, 'injection_parameters', None) is not None: - kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys] - c = ChainConsumer() - if c: - c.add_chain(self.samples, parameters=self.parameter_labels) - fig = c.plotter.plot_walks(**kwargs) - return fig + fig = corner.corner(xs, **kwargs) - def plot_distributions(self, save=True, **kwargs): - """ Plot the chain walks using chain-consumer + if save: + filename = '{}/{}_corner.png'.format(self.outdir, self.label) + logging.info('Saving corner plot to {}'.format(filename)) + fig.savefig(filename, dpi=dpi) - Parameters - ---------- - save: bool - If true, save the image using the given label and outdir + return fig - Returns - ------- - fig: - A matplotlib figure instance + def plot_walks(self, save=True, **kwargs): + """ """ + logging.warning("plot_walks deprecated") - # Set some defaults (unless already set) - if save: - kwargs['filename'] = '{}/{}_distributions.png'.format(self.outdir, self.label) - logging.info('Saving distributions plot to {}'.format(kwargs['filename'])) - if getattr(self, 'injection_parameters', None) is not None: - kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys] - c = ChainConsumer() - if c: - c.add_chain(self.samples, parameters=self.parameter_labels) - fig = c.plotter.plot_distributions(**kwargs) - return fig + def plot_distributions(self, save=True, **kwargs): + """ + """ + logging.warning("plot_distributions deprecated") def write_prior_to_file(self, outdir): """ diff --git a/tupak/sampler.py b/tupak/sampler.py index 16f7a34fe8d2ca72c66f1c67b238911224e03bd6..0b6c0d316722affc0552513facb204a9e00e81db 100644 --- a/tupak/sampler.py +++ b/tupak/sampler.py @@ -19,7 +19,7 @@ class Sampler(object): Parameters ---------- - likelihood: likelihood.GravitationalWaveTransient + likelihood: likelihood.Likelihood A object with a log_l method prior: dict The prior to be used in the search. Elements can either be floats @@ -36,8 +36,10 @@ class Sampler(object): """ - def __init__(self, likelihood, priors, external_sampler='nestle', outdir='outdir', label='label', use_ratio=False, - **kwargs): + def __init__( + self, likelihood, priors, external_sampler='nestle', + outdir='outdir', label='label', use_ratio=False, plot=False, + **kwargs): self.likelihood = likelihood self.priors = priors self.label = label @@ -45,6 +47,7 @@ class Sampler(object): self.use_ratio = use_ratio self.external_sampler = external_sampler self.external_sampler_function = None + self.plot = plot self.__search_parameter_keys = [] self.__fixed_parameter_keys = [] @@ -148,10 +151,12 @@ class Sampler(object): except AttributeError as e: logging.warning('Cannot sample from {}, {}'.format(key, e)) try: - self.likelihood.log_likelihood_ratio() - except TypeError: - raise TypeError('GravitationalWaveTransient evaluation failed. Have you definitely specified all the parameters?\n{}'.format( - self.likelihood.parameters)) + self.likelihood.log_likelihood() + except TypeError as e: + raise TypeError( + "Likelihood evaluation failed with message: \n'{}'\n" + "Have you specified all the parameters:\n{}" + .format(e, self.likelihood.parameters)) def prior_transform(self, theta): return [self.priors[key].rescale(t) for key, t in zip(self.__search_parameter_keys, theta)] @@ -325,8 +330,20 @@ class Dynesty(Sampler): out.samples, weights) self.result.logz = out.logz[-1] self.result.logzerr = out.logzerr[-1] + + if self.plot: + self.generate_trace_plots(out) return self.result + def generate_trace_plots(self, dynesty_results): + filename = '{}/{}_trace.png'.format(self.outdir, self.label) + logging.info("Writing trace plot to {}".format(filename)) + from dynesty import plotting as dyplot + fig, axes = dyplot.traceplot(dynesty_results, + labels=self.result.parameter_labels) + fig.tight_layout() + fig.savefig(filename) + def _run_test(self): dynesty = self.external_sampler nested_sampler = dynesty.NestedSampler( @@ -434,7 +451,7 @@ class Ptemcee(Sampler): def run_sampler(likelihood, priors=None, label='label', outdir='outdir', sampler='nestle', use_ratio=True, injection_parameters=None, - conversion_function=None, **kwargs): + conversion_function=None, plot=False, **kwargs): """ The primary interface to easy parameter estimation @@ -459,7 +476,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', injection_parameters: dict A dictionary of injection parameters used in creating the data (if using simulated data). Appended to the result object and saved. - + plot: bool + If true, generate a corner plot and, if applicable diagnostic plots conversion_function: function, optional Function to apply to posterior to generate additional parameters. **kwargs: @@ -482,7 +500,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', if implemented_samplers.__contains__(sampler.title()): sampler_class = globals()[sampler.title()] sampler = sampler_class(likelihood, priors, sampler, outdir=outdir, - label=label, use_ratio=use_ratio, + label=label, use_ratio=use_ratio, plot=plot, **kwargs) if sampler.cached_result: @@ -509,6 +527,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', result.samples_to_data_frame(likelihood=likelihood, priors=priors, conversion_function=conversion_function) result.kwargs = sampler.kwargs result.save_to_file(outdir=outdir, label=label) + if plot: + result.plot_corner() return result else: raise ValueError(