diff --git a/.gitignore b/.gitignore
index 88717818866b233ab49b644f24d116aa61efea91..cac42814dd0a1cf3988e6ccdbe38e5400181155d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,4 @@ MANIFEST
 *.version
 *.ipynb_checkpoints
 outdir/*
-.idea/*
+.idea/*
\ No newline at end of file
diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index 0bc59adff77bed66f25bfd426c590e64c28a6ebb..d5b1ba3aaa2ebf29f3f0a2cfcb77dfa22559b414 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -11,6 +11,7 @@ from .cpnest import Cpnest
 from .dynesty import Dynesty
 from .dynamic_dynesty import DynamicDynesty
 from .emcee import Emcee
+from .kombine import Kombine
 from .nestle import Nestle
 from .polychord import PyPolyChord
 from .ptemcee import Ptemcee
@@ -21,10 +22,10 @@ from .fake_sampler import FakeSampler
 from . import proposal
 
 IMPLEMENTED_SAMPLERS = {
-    'cpnest': Cpnest, 'dynamic_dynesty': DynamicDynesty, 'dynesty': Dynesty, 'emcee': Emcee, 'nestle': Nestle,
-    'ptemcee': Ptemcee,'ptmcmcsampler' : PTMCMCSampler,
+    'cpnest': Cpnest, 'dynesty': Dynesty, 'emcee': Emcee, 'kombine': Kombine,
+    'nestle': Nestle, 'ptemcee': Ptemcee, 'ptmcmcsampler': PTMCMCSampler,
     'pymc3': Pymc3, 'pymultinest': Pymultinest, 'pypolychord': PyPolyChord,
-    'fake_sampler': FakeSampler }
+    'fake_sampler': FakeSampler}
 
 if command_line_args.sampler_help:
     sampler = command_line_args.sampler_help
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 12d795d3362ae0208f75aaab13e1ad4422617551..144bede6d31bf50b40e937f0c584cc87e89b5153 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -67,6 +67,7 @@
             likelihood=likelihood, priors=priors, outdir=outdir, label=label,
             use_ratio=use_ratio, plot=plot,
             skip_import_verification=skip_import_verification, **kwargs)
+        self.emcee = self._check_version()
         self.resume = resume
         self.pos0 = pos0
         self.nburn = nburn
@@ -76,6 +77,14 @@
         signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
         signal.signal(signal.SIGINT, self.checkpoint_and_exit)
 
+    def _check_version(self):
+        import emcee
+        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+            self.prerelease = True
+        else:
+            self.prerelease = False
+        return emcee
+
     def _translate_kwargs(self, kwargs):
         if 'nwalkers' not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
                     kwargs['nwalkers'] = kwargs.pop(equiv)
@@ -292,7 +301,6 @@
         temp_chain_file = chain_file + '.temp'
         if os.path.isfile(chain_file):
             copyfile(chain_file, temp_chain_file)
-
         if self.prerelease:
             points = np.hstack([sample.coords, sample.blobs])
         else:
@@ -362,7 +370,6 @@
 
         self.checkpoint()
         self.result.sampler_output = np.nan
-
         self.calculate_autocorrelation(
             self.sampler.chain.reshape((-1, self.ndim)))
         self.print_nburn_logging_info()
diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd37070511ecf805e563d3d43a71b1a13dbb5e01
--- /dev/null
+++ b/bilby/core/sampler/kombine.py
@@ -0,0 +1,176 @@
+from __future__ import absolute_import, print_function
+from ..utils import logger, get_progress_bar
+import numpy as np
+import os
+from .emcee import Emcee
+from .base_sampler import SamplerError
+
+
+class Kombine(Emcee):
+    """bilby wrapper of kombine (https://github.com/bfarr/kombine)
+
+    All positional and keyword arguments (i.e., the args and kwargs) passed to
+    `run_sampler` will be propagated to `kombine.Sampler`; see the
+    documentation for that class for further help. Under Other Parameters, we
+    list commonly used kwargs and the bilby defaults.
+
+    Other Parameters
+    ----------------
+    nwalkers: int, (500)
+        The number of walkers
+    iterations: int, (500)
+        The number of iterations
+    autoburnin: bool (False)
+        Use `kombine`'s automatic burnin (at your own risk)
+    nburn: int (None)
+        If given, the fixed number of steps to discard as burn-in. These will
+        be discarded from the total number of steps set by `nsteps` and
+        therefore the value must be less than `nsteps`. Else, nburn is
+        estimated from the autocorrelation time
+    burn_in_fraction: float, (0.25)
+        The fraction of steps to discard as burn-in in the event that the
+        autocorrelation time cannot be calculated
+    burn_in_act: float (3.)
+        The number of autocorrelation times to discard as burn-in
+
+
+    """
+
+    default_kwargs = dict(nwalkers=500, args=[], pool=None, transd=False,
+                          lnpost0=None, blob0=None, iterations=500, storechain=True, processes=1, update_interval=None,
+                          kde=None, kde_size=None, spaces=None, freeze_transd=False, test_steps=16, critical_pval=0.05,
+                          max_steps=None, burnin_verbose=False)
+
+    def __init__(self, likelihood, priors, outdir='outdir', label='label',
+                 use_ratio=False, plot=False, skip_import_verification=False,
+                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
+                 burn_in_act=3, autoburnin=False, **kwargs):
+        super(Kombine, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
+                                      use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification,
+                                      pos0=pos0, nburn=nburn, burn_in_fraction=burn_in_fraction,
+                                      burn_in_act=burn_in_act, resume=resume, **kwargs)
+
+        if self.kwargs['nwalkers'] > self.kwargs['iterations']:
+            raise ValueError("Kombine Sampler requires Iterations be > nWalkers")
+        self.autoburnin = autoburnin
+
+    def _check_version(self):
+        # set prerelease to False to prevent checks for newer emcee versions in parent class
+        self.prerelease = False
+
+    def _translate_kwargs(self, kwargs):
+        if 'nwalkers' not in kwargs:
+            for equiv in self.nwalkers_equiv_kwargs:
+                if equiv in kwargs:
+                    kwargs['nwalkers'] = kwargs.pop(equiv)
+        if 'iterations' not in kwargs:
+            if 'nsteps' in kwargs:
+                kwargs['iterations'] = kwargs.pop('nsteps')
+        # make sure processes kwarg is 1
+        if 'processes' in kwargs:
+            if kwargs['processes'] != 1:
+                logger.warning("The 'processes' argument cannot be used for "
+                               "parallelisation. This run will proceed "
+                               "without parallelisation, but consider the use "
+                               "of an appropriate Pool object passed to the "
+                               "'pool' keyword.")
+                kwargs['processes'] = 1
+
+    @property
+    def sampler_function_kwargs(self):
+        keys = ['lnpost0', 'blob0', 'iterations', 'storechain', 'lnprop0', 'update_interval', 'kde',
+                'kde_size', 'spaces', 'freeze_transd']
+        function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
+        function_kwargs['p0'] = self.pos0
+        return function_kwargs
+
+    @property
+    def sampler_burnin_kwargs(self):
+        extra_keys = ['test_steps', 'critical_pval', 'max_steps', 'burnin_verbose']
+        removal_keys = ['iterations', 'spaces', 'freeze_transd']
+        burnin_kwargs = self.sampler_function_kwargs.copy()
+        for key in extra_keys:
+            if key in self.kwargs:
+                burnin_kwargs[key] = self.kwargs[key]
+        if 'burnin_verbose' in burnin_kwargs.keys():
+            burnin_kwargs['verbose'] = burnin_kwargs.pop('burnin_verbose')
+        for key in removal_keys:
+            if key in burnin_kwargs.keys():
+                burnin_kwargs.pop(key)
+        return burnin_kwargs
+
+    @property
+    def sampler_init_kwargs(self):
+        init_kwargs = {key: value
+                       for key, value in self.kwargs.items()
+                       if key not in self.sampler_function_kwargs and key not in self.sampler_burnin_kwargs}
+        init_kwargs.pop("burnin_verbose")
+        init_kwargs['lnpostfn'] = self.lnpostfn
+        init_kwargs['ndim'] = self.ndim
+
+        # have to make sure pool is None so sampler will be pickleable
+        init_kwargs['pool'] = None
+        return init_kwargs
+
+    def _initialise_sampler(self):
+        import kombine
+        self._sampler = kombine.Sampler(**self.sampler_init_kwargs)
+        self._init_chain_file()
+
+    def _set_pos0_for_resume(self):
+        # take last iteration
+        self.pos0 = self.sampler.chain[-1, :, :]
+
+    @property
+    def sampler_chain(self):
+        # remove last iterations when resuming
+        nsteps = self._previous_iterations
+        return self.sampler.chain[:nsteps, :, :]
+
+    def check_resume(self):
+        return self.resume and os.path.isfile(self.checkpoint_info.sampler_file)
+
+    def run_sampler(self):
+        if self.autoburnin:
+            if self.check_resume():
+                logger.info("Resuming with autoburnin=True skips burnin process:")
+            else:
+                logger.info("Running kombine sampler's automatic burnin process")
+                self.sampler.burnin(**self.sampler_burnin_kwargs)
+                self.kwargs["iterations"] += self._previous_iterations
+                self.nburn = self._previous_iterations
+                logger.info("Kombine auto-burnin complete. Removing {} samples from chains".format(self.nburn))
+                self._set_pos0_for_resume()
+
+        tqdm = get_progress_bar()
+        sampler_function_kwargs = self.sampler_function_kwargs
+        iterations = sampler_function_kwargs.pop('iterations')
+        iterations -= self._previous_iterations
+        sampler_function_kwargs['p0'] = self.pos0
+        for sample in tqdm(
+                self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
+                total=iterations):
+            self.write_chains_to_file(sample)
+        self.checkpoint()
+        self.result.sampler_output = np.nan
+        if not self.autoburnin:
+            tmp_chain = self.sampler.chain.copy()
+            self.calculate_autocorrelation(tmp_chain.reshape((-1, self.ndim)))
+            self.print_nburn_logging_info()
+        self.result.nburn = self.nburn
+        if self.result.nburn > self.nsteps:
+            raise SamplerError(
+                "The run has finished, but the chain is not burned in: "
+                "`nburn < nsteps`. Try increasing the number of steps.")
+        tmp_chain = self.sampler.chain[self.nburn:, :, :].copy()
+        self.result.samples = tmp_chain.reshape((-1, self.ndim))
+        blobs = np.array(self.sampler.blobs)
+        blobs_trimmed = blobs[self.nburn:, :, :].reshape((-1, 2))
+        self.calc_likelihood_count()
+        log_likelihoods, log_priors = blobs_trimmed.T
+        self.result.log_likelihood_evaluations = log_likelihoods
+        self.result.log_prior_evaluations = log_priors
+        self.result.walkers = self.sampler.chain.reshape((self.nwalkers, self.nsteps, self.ndim))
+        self.result.log_evidence = np.nan
+        self.result.log_evidence_err = np.nan
+        return self.result
diff --git a/sampler_requirements.txt b/sampler_requirements.txt
index ee82c447b494093f00031f66b7ffd9515b569664..a6196d0e6aa04fecf495c6dcc290c03e27f83652 100644
--- a/sampler_requirements.txt
+++ b/sampler_requirements.txt
@@ -5,4 +5,5 @@ nestle
 ptemcee
 pymc3==3.6; python_version <= '2.7'
 pymc3>=3.6; python_version > '3.4'
-pymultinest
\ No newline at end of file
+pymultinest
+kombine
\ No newline at end of file
diff --git a/test/sampler_test.py b/test/sampler_test.py
index e937a6452d25e4df75162aab7301caafc4951ea8..63237d1c106508d92349a519e51f9d5f0e572645 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -238,6 +238,41 @@
         self.assertDictEqual(expected, self.sampler.kwargs)
 
 
+class TestKombine(unittest.TestCase):
+
+    def setUp(self):
+        self.likelihood = MagicMock()
+        self.priors = dict()
+        self.sampler = bilby.core.sampler.Kombine(self.likelihood, self.priors,
+                                                  outdir='outdir', label='label',
+                                                  use_ratio=False, plot=False,
+                                                  skip_import_verification=True)
+
+    def tearDown(self):
+        del self.likelihood
+        del self.priors
+        del self.sampler
+
+    def test_default_kwargs(self):
+        expected = dict(nwalkers=500, args=[], pool=None, transd=False,
+                        lnpost0=None, blob0=None, iterations=500, storechain=True, processes=1, update_interval=None,
+                        kde=None, kde_size=None, spaces=None, freeze_transd=False, test_steps=16, critical_pval=0.05,
+                        max_steps=None, burnin_verbose=False)
+        self.assertDictEqual(expected, self.sampler.kwargs)
+
+    def test_translate_kwargs(self):
+        expected = dict(nwalkers=400, args=[], pool=None, transd=False,
+                        lnpost0=None, blob0=None, iterations=500, storechain=True, processes=1, update_interval=None,
+                        kde=None, kde_size=None, spaces=None, freeze_transd=False, test_steps=16, critical_pval=0.05,
+                        max_steps=None, burnin_verbose=False)
+        for equiv in bilby.core.sampler.base_sampler.MCMCSampler.nwalkers_equiv_kwargs:
+            new_kwargs = self.sampler.kwargs.copy()
+            del new_kwargs['nwalkers']
+            new_kwargs[equiv] = 400
+            self.sampler.kwargs = new_kwargs
+            self.assertDictEqual(expected, self.sampler.kwargs)
+
+
 class TestNestle(unittest.TestCase):
 
     def setUp(self):
@@ -499,7 +534,12 @@
     def test_run_emcee(self):
         _ = bilby.run_sampler(
             likelihood=self.likelihood, priors=self.priors, sampler='emcee',
-            nsteps=1000, nwalkers=10, save=False)
+            iterations=1000, nwalkers=10, save=False)
+
+    def test_run_kombine(self):
+        _ = bilby.run_sampler(
+            likelihood=self.likelihood, priors=self.priors, sampler='kombine',
+            iterations=2500, nwalkers=100, save=False)
 
     def test_run_nestle(self):
         _ = bilby.run_sampler(
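
Usage note (not part of the patch): with the new 'kombine' entry in IMPLEMENTED_SAMPLERS, the sampler can be selected by name through bilby.run_sampler, exactly as the new test_run_kombine test does. The sketch below assumes the simple linear-model setup used in the standard bilby examples (GaussianLikelihood plus Uniform priors); the model function, label, and the iterations/nwalkers values mirror the test above and are illustrative only. Note that Kombine raises a ValueError unless iterations exceeds nwalkers.

import numpy as np
import bilby


def model(x, m, c):
    # straight-line model; parameter names m and c are inferred by GaussianLikelihood
    return m * x + c


# simulated data under the assumed model (illustrative values)
x = np.linspace(0, 10, 100)
y = model(x, 2.5, 1.0) + np.random.normal(0, 1, len(x))

likelihood = bilby.core.likelihood.GaussianLikelihood(x, y, model, sigma=1)
priors = dict(m=bilby.core.prior.Uniform(0, 5, 'm'),
              c=bilby.core.prior.Uniform(-2, 2, 'c'))

# sampler='kombine' resolves to the Kombine class registered in this patch;
# iterations and nwalkers are forwarded to the underlying kombine.Sampler
result = bilby.run_sampler(
    likelihood=likelihood, priors=priors, sampler='kombine',
    iterations=2500, nwalkers=100, outdir='outdir', label='kombine_example')
result.plot_corner()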