Skip to content
Snippets Groups Projects
Commit 0ac57e3d authored by Colm Talbot's avatar Colm Talbot Committed by Moritz Huebner
Browse files

update cpnest syntax

parent 4c358a29
No related branches found
No related tags found
No related merge requests found
from __future__ import absolute_import
import numpy as np
from pandas import DataFrame
from ..utils import logger
from ..utils import logger, check_directory_exists_and_if_not_mkdir
from .base_sampler import NestedSampler
......@@ -15,27 +15,33 @@ class Cpnest(NestedSampler):
Keyword Arguments
-----------------
npoints: int
nlive: int
The number of live points, note this can also equivalently be given as
one of [nlive, nlives, n_live_points]
one of [npoints, nlives, n_live_points]
seed: int (1234)
Initialised random seed
Nthreads: int, (1)
nthreads: int, (1)
Number of threads to use
maxmcmc: int (1000)
The maximum number of MCMC steps to take
verbose: Bool
verbose: Bool (True)
If true, print information information about the convergence during
resume: Bool (False)
Whether or not to resume from a previous run
output: str
Where to write the CPNest, by default this is
{self.outdir}/cpnest_{self.label}/
"""
default_kwargs = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000,
Poolsize=100, seed=None, balance_samplers=True)
default_kwargs = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000,
seed=None, poolsize=100, nhamiltonian=0, resume=False,
output=None)
def _translate_kwargs(self, kwargs):
if 'Nlive' not in kwargs:
if 'nlive' not in kwargs:
for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs:
kwargs['Nlive'] = kwargs.pop(equiv)
kwargs['nlive'] = kwargs.pop(equiv)
if 'seed' not in kwargs:
logger.warning('No seed provided, cpnest will use 1234.')
......@@ -69,15 +75,25 @@ class Cpnest(NestedSampler):
bounds = [[self.priors[key].minimum, self.priors[key].maximum]
for key in self.search_parameter_keys]
model = Model(self.search_parameter_keys, bounds)
out = CPNest(model, output=self.outdir, **self.kwargs)
out = CPNest(model, **self.kwargs)
out.run()
if self.plot:
out.plot()
# Since the output is not just samples, but log_likelihood as well,
# we turn this into a dataframe here. The index [0] here may be wrong
self.result.posterior = DataFrame(out.posterior_samples[0])
self.result.posterior = DataFrame(out.posterior_samples)
self.result.log_evidence = out.NS.state.logZ
self.result.log_evidence_err = np.nan
return self.result
def _verify_kwargs_against_default_kwargs(self):
"""
Set the directory where the output will be written.
"""
if not self.kwargs['output']:
self.kwargs['output'] = \
'{}/cpnest_{}/'.format(self.outdir, self.label)
if self.kwargs['output'].endswith('/') is False:
self.kwargs['output'] = '{}/'.format(self.kwargs['output'])
check_directory_exists_and_if_not_mkdir(self.kwargs['output'])
NestedSampler._verify_kwargs_against_default_kwargs(self)
cpnest
cpnest>=0.9.4
dynesty
emcee
nestle
......
......@@ -121,16 +121,18 @@ class TestCPNest(unittest.TestCase):
del self.sampler
def test_default_kwargs(self):
expected = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000,
Poolsize=100, seed=None, balance_samplers=True)
expected = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000,
seed=None, poolsize=100, nhamiltonian=0, resume=False,
output='outdir/cpnest_label/')
self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs(self):
expected = dict(verbose=1, Nthreads=1, Nlive=250, maxmcmc=1000,
Poolsize=100, seed=None, balance_samplers=True)
expected = dict(verbose=1, nthreads=1, nlive=250, maxmcmc=1000,
seed=None, poolsize=100, nhamiltonian=0, resume=False,
output='outdir/cpnest_label/')
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
del new_kwargs['Nlive']
del new_kwargs['nlive']
new_kwargs[equiv] = 250
self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment