Commit 0ac57e3d authored by Colm Talbot's avatar Colm Talbot Committed by Moritz Huebner

update cpnest syntax

parent 4c358a29
from __future__ import absolute_import
import numpy as np
from pandas import DataFrame
from ..utils import logger
from ..utils import logger, check_directory_exists_and_if_not_mkdir
from .base_sampler import NestedSampler
......@@ -15,27 +15,33 @@ class Cpnest(NestedSampler):
Keyword Arguments
-----------------
npoints: int
nlive: int
The number of live points, note this can also equivalently be given as
one of [nlive, nlives, n_live_points]
one of [npoints, nlives, n_live_points]
seed: int (1234)
Initialised random seed
Nthreads: int, (1)
nthreads: int, (1)
Number of threads to use
maxmcmc: int (1000)
The maximum number of MCMC steps to take
verbose: Bool
verbose: Bool (True)
If true, print information information about the convergence during
resume: Bool (False)
Whether or not to resume from a previous run
output: str
Where to write the CPNest, by default this is
{self.outdir}/cpnest_{self.label}/
"""
default_kwargs = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000,
Poolsize=100, seed=None, balance_samplers=True)
default_kwargs = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000,
seed=None, poolsize=100, nhamiltonian=0, resume=False,
output=None)
def _translate_kwargs(self, kwargs):
if 'Nlive' not in kwargs:
if 'nlive' not in kwargs:
for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs:
kwargs['Nlive'] = kwargs.pop(equiv)
kwargs['nlive'] = kwargs.pop(equiv)
if 'seed' not in kwargs:
logger.warning('No seed provided, cpnest will use 1234.')
......@@ -69,15 +75,25 @@ class Cpnest(NestedSampler):
bounds = [[self.priors[key].minimum, self.priors[key].maximum]
for key in self.search_parameter_keys]
model = Model(self.search_parameter_keys, bounds)
out = CPNest(model, output=self.outdir, **self.kwargs)
out = CPNest(model, **self.kwargs)
out.run()
if self.plot:
out.plot()
# Since the output is not just samples, but log_likelihood as well,
# we turn this into a dataframe here. The index [0] here may be wrong
self.result.posterior = DataFrame(out.posterior_samples[0])
self.result.posterior = DataFrame(out.posterior_samples)
self.result.log_evidence = out.NS.state.logZ
self.result.log_evidence_err = np.nan
return self.result
def _verify_kwargs_against_default_kwargs(self):
"""
Set the directory where the output will be written.
"""
if not self.kwargs['output']:
self.kwargs['output'] = \
'{}/cpnest_{}/'.format(self.outdir, self.label)
if self.kwargs['output'].endswith('/') is False:
self.kwargs['output'] = '{}/'.format(self.kwargs['output'])
check_directory_exists_and_if_not_mkdir(self.kwargs['output'])
NestedSampler._verify_kwargs_against_default_kwargs(self)
......@@ -121,16 +121,18 @@ class TestCPNest(unittest.TestCase):
del self.sampler
def test_default_kwargs(self):
expected = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000,
Poolsize=100, seed=None, balance_samplers=True)
expected = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000,
seed=None, poolsize=100, nhamiltonian=0, resume=False,
output='outdir/cpnest_label/')
self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs(self):
expected = dict(verbose=1, Nthreads=1, Nlive=250, maxmcmc=1000,
Poolsize=100, seed=None, balance_samplers=True)
expected = dict(verbose=1, nthreads=1, nlive=250, maxmcmc=1000,
seed=None, poolsize=100, nhamiltonian=0, resume=False,
output='outdir/cpnest_label/')
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
del new_kwargs['Nlive']
del new_kwargs['nlive']
new_kwargs[equiv] = 250
self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment