diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index 91cf7bc6020016462bbb7368190c1b0ec21934d3..0fcdba59328a62d4ce9bcbb6e87245222f0f6905 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -1,7 +1,7 @@ from __future__ import absolute_import import numpy as np from pandas import DataFrame -from ..utils import logger +from ..utils import logger, check_directory_exists_and_if_not_mkdir from .base_sampler import NestedSampler @@ -15,27 +15,33 @@ class Cpnest(NestedSampler): Keyword Arguments ----------------- - npoints: int + nlive: int The number of live points, note this can also equivalently be given as - one of [nlive, nlives, n_live_points] + one of [npoints, nlives, n_live_points] seed: int (1234) Initialised random seed - Nthreads: int, (1) + nthreads: int, (1) Number of threads to use maxmcmc: int (1000) The maximum number of MCMC steps to take - verbose: Bool + verbose: Bool (True) If true, print information information about the convergence during + resume: Bool (False) + Whether or not to resume from a previous run + output: str + Where to write the CPNest, by default this is + {self.outdir}/cpnest_{self.label}/ """ - default_kwargs = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000, - Poolsize=100, seed=None, balance_samplers=True) + default_kwargs = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000, + seed=None, poolsize=100, nhamiltonian=0, resume=False, + output=None) def _translate_kwargs(self, kwargs): - if 'Nlive' not in kwargs: + if 'nlive' not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: - kwargs['Nlive'] = kwargs.pop(equiv) + kwargs['nlive'] = kwargs.pop(equiv) if 'seed' not in kwargs: logger.warning('No seed provided, cpnest will use 1234.') @@ -69,15 +75,25 @@ class Cpnest(NestedSampler): bounds = [[self.priors[key].minimum, self.priors[key].maximum] for key in self.search_parameter_keys] model = Model(self.search_parameter_keys, bounds) - out = CPNest(model, output=self.outdir, **self.kwargs) + out = CPNest(model, **self.kwargs) out.run() if self.plot: out.plot() - # Since the output is not just samples, but log_likelihood as well, - # we turn this into a dataframe here. The index [0] here may be wrong - self.result.posterior = DataFrame(out.posterior_samples[0]) + self.result.posterior = DataFrame(out.posterior_samples) self.result.log_evidence = out.NS.state.logZ self.result.log_evidence_err = np.nan return self.result + + def _verify_kwargs_against_default_kwargs(self): + """ + Set the directory where the output will be written. + """ + if not self.kwargs['output']: + self.kwargs['output'] = \ + '{}/cpnest_{}/'.format(self.outdir, self.label) + if self.kwargs['output'].endswith('/') is False: + self.kwargs['output'] = '{}/'.format(self.kwargs['output']) + check_directory_exists_and_if_not_mkdir(self.kwargs['output']) + NestedSampler._verify_kwargs_against_default_kwargs(self) diff --git a/sampler_requirements.txt b/sampler_requirements.txt index fbb2c32fdefe1e8999f6ae81decc7e83da9f4f1f..e81411d3b5134d018046dfda621971252e6f3ea7 100644 --- a/sampler_requirements.txt +++ b/sampler_requirements.txt @@ -1,4 +1,4 @@ -cpnest +cpnest>=0.9.4 dynesty emcee nestle diff --git a/test/sampler_test.py b/test/sampler_test.py index 56409d4a27714444d0ffc8323cd56b06910467e4..963813f007d32fa9c66214c36ed9a5f19a1207b0 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -121,16 +121,18 @@ class TestCPNest(unittest.TestCase): del self.sampler def test_default_kwargs(self): - expected = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000, - Poolsize=100, seed=None, balance_samplers=True) + expected = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000, + seed=None, poolsize=100, nhamiltonian=0, resume=False, + output='outdir/cpnest_label/') self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs(self): - expected = dict(verbose=1, Nthreads=1, Nlive=250, maxmcmc=1000, - Poolsize=100, seed=None, balance_samplers=True) + expected = dict(verbose=1, nthreads=1, nlive=250, maxmcmc=1000, + seed=None, poolsize=100, nhamiltonian=0, resume=False, + output='outdir/cpnest_label/') for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy() - del new_kwargs['Nlive'] + del new_kwargs['nlive'] new_kwargs[equiv] = 250 self.sampler.kwargs = new_kwargs self.assertDictEqual(expected, self.sampler.kwargs)