Skip to content
Snippets Groups Projects
Commit 0ac57e3d authored by Colm Talbot's avatar Colm Talbot Committed by Moritz Huebner
Browse files

update cpnest syntax

parent 4c358a29
No related branches found
No related tags found
No related merge requests found
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np import numpy as np
from pandas import DataFrame from pandas import DataFrame
from ..utils import logger from ..utils import logger, check_directory_exists_and_if_not_mkdir
from .base_sampler import NestedSampler from .base_sampler import NestedSampler
...@@ -15,27 +15,33 @@ class Cpnest(NestedSampler): ...@@ -15,27 +15,33 @@ class Cpnest(NestedSampler):
Keyword Arguments Keyword Arguments
----------------- -----------------
npoints: int nlive: int
The number of live points, note this can also equivalently be given as The number of live points, note this can also equivalently be given as
one of [nlive, nlives, n_live_points] one of [npoints, nlives, n_live_points]
seed: int (1234) seed: int (1234)
Initialised random seed Initialised random seed
Nthreads: int, (1) nthreads: int, (1)
Number of threads to use Number of threads to use
maxmcmc: int (1000) maxmcmc: int (1000)
The maximum number of MCMC steps to take The maximum number of MCMC steps to take
verbose: Bool verbose: Bool (True)
If true, print information information about the convergence during If true, print information information about the convergence during
resume: Bool (False)
Whether or not to resume from a previous run
output: str
Where to write the CPNest, by default this is
{self.outdir}/cpnest_{self.label}/
""" """
default_kwargs = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000, default_kwargs = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000,
Poolsize=100, seed=None, balance_samplers=True) seed=None, poolsize=100, nhamiltonian=0, resume=False,
output=None)
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
if 'Nlive' not in kwargs: if 'nlive' not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
kwargs['Nlive'] = kwargs.pop(equiv) kwargs['nlive'] = kwargs.pop(equiv)
if 'seed' not in kwargs: if 'seed' not in kwargs:
logger.warning('No seed provided, cpnest will use 1234.') logger.warning('No seed provided, cpnest will use 1234.')
...@@ -69,15 +75,25 @@ class Cpnest(NestedSampler): ...@@ -69,15 +75,25 @@ class Cpnest(NestedSampler):
bounds = [[self.priors[key].minimum, self.priors[key].maximum] bounds = [[self.priors[key].minimum, self.priors[key].maximum]
for key in self.search_parameter_keys] for key in self.search_parameter_keys]
model = Model(self.search_parameter_keys, bounds) model = Model(self.search_parameter_keys, bounds)
out = CPNest(model, output=self.outdir, **self.kwargs) out = CPNest(model, **self.kwargs)
out.run() out.run()
if self.plot: if self.plot:
out.plot() out.plot()
# Since the output is not just samples, but log_likelihood as well, self.result.posterior = DataFrame(out.posterior_samples)
# we turn this into a dataframe here. The index [0] here may be wrong
self.result.posterior = DataFrame(out.posterior_samples[0])
self.result.log_evidence = out.NS.state.logZ self.result.log_evidence = out.NS.state.logZ
self.result.log_evidence_err = np.nan self.result.log_evidence_err = np.nan
return self.result return self.result
def _verify_kwargs_against_default_kwargs(self):
"""
Set the directory where the output will be written.
"""
if not self.kwargs['output']:
self.kwargs['output'] = \
'{}/cpnest_{}/'.format(self.outdir, self.label)
if self.kwargs['output'].endswith('/') is False:
self.kwargs['output'] = '{}/'.format(self.kwargs['output'])
check_directory_exists_and_if_not_mkdir(self.kwargs['output'])
NestedSampler._verify_kwargs_against_default_kwargs(self)
...@@ -121,16 +121,18 @@ class TestCPNest(unittest.TestCase): ...@@ -121,16 +121,18 @@ class TestCPNest(unittest.TestCase):
del self.sampler del self.sampler
def test_default_kwargs(self): def test_default_kwargs(self):
expected = dict(verbose=1, Nthreads=1, Nlive=500, maxmcmc=1000, expected = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000,
Poolsize=100, seed=None, balance_samplers=True) seed=None, poolsize=100, nhamiltonian=0, resume=False,
output='outdir/cpnest_label/')
self.assertDictEqual(expected, self.sampler.kwargs) self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs(self): def test_translate_kwargs(self):
expected = dict(verbose=1, Nthreads=1, Nlive=250, maxmcmc=1000, expected = dict(verbose=1, nthreads=1, nlive=250, maxmcmc=1000,
Poolsize=100, seed=None, balance_samplers=True) seed=None, poolsize=100, nhamiltonian=0, resume=False,
output='outdir/cpnest_label/')
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy() new_kwargs = self.sampler.kwargs.copy()
del new_kwargs['Nlive'] del new_kwargs['nlive']
new_kwargs[equiv] = 250 new_kwargs[equiv] = 250
self.sampler.kwargs = new_kwargs self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs) self.assertDictEqual(expected, self.sampler.kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment