diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 9f028c9fc91901344434aaf434c48d190daa50cf..56e6887d188ccb78e16e614799f74cd324a7e8b9 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function import numpy as np from pandas import DataFrame +from distutils.version import LooseVersion from ..utils import logger, get_progress_bar from .base_sampler import MCMCSampler, SamplerError @@ -38,7 +39,7 @@ class Emcee(MCMCSampler): """ default_kwargs = dict(nwalkers=500, a=2, args=[], kwargs={}, - postargs=None, threads=1, pool=None, live_dangerously=False, + postargs=None, pool=None, live_dangerously=False, runtime_sortingfn=None, lnprob0=None, rstate0=None, blobs0=None, iterations=100, thin=1, storechain=True, mh_proposal=None) @@ -62,17 +63,69 @@ class Emcee(MCMCSampler): if 'iterations' not in kwargs: if 'nsteps' in kwargs: kwargs['iterations'] = kwargs.pop('nsteps') + if 'threads' in kwargs: + if kwargs['threads'] != 1: + logger.warning("The 'threads' argument cannot be used for " + "parallelisation. This run will proceed " + "without parallelisation, but consider the use " + "of an appropriate Pool object passed to the " + "'pool' keyword.") + kwargs['threads'] = 1 @property def sampler_function_kwargs(self): + import emcee + keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'mh_proposal'] - return {key: self.kwargs[key] for key in keys} + + # updated function keywords for emcee > v2.2.1 + updatekeys = {'p0': 'initial_state', + 'lnprob0': 'log_prob0', + 'storechain': 'store'} + + function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs} + function_kwargs['p0'] = self.pos0 + + if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'): + if function_kwargs['mh_proposal'] is not None: + logger.warning("The 'mh_proposal' option is no longer used " + "in emcee v{}, and will be ignored.".format(emcee.__version__)) + del function_kwargs['mh_proposal'] + + for key in updatekeys: + if updatekeys[key] not in function_kwargs: + function_kwargs[updatekeys[key]] = function_kwargs.pop(key) + else: + del function_kwargs[key] + + return function_kwargs @property def sampler_init_kwargs(self): - return {key: value - for key, value in self.kwargs.items() - if key not in self.sampler_function_kwargs} + import emcee + + init_kwargs = {key: value + for key, value in self.kwargs.items() + if key not in self.sampler_function_kwargs} + + init_kwargs['lnpostfn'] = self.lnpostfn + init_kwargs['dim'] = self.ndim + + # updated init keywords for emcee > v2.2.1 + updatekeys = {'dim': 'ndim', + 'lnpostfn': 'log_prob_fn'} + + if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'): + for key in updatekeys: + if key in init_kwargs: + init_kwargs[updatekeys[key]] = init_kwargs.pop(key) + + oldfunckeys = ['p0', 'lnprob0', 'storechain', 'mh_proposal'] + for key in oldfunckeys: + if key in init_kwargs: + del init_kwargs[key] + + return init_kwargs @property def nburn(self): @@ -104,12 +157,19 @@ class Emcee(MCMCSampler): def nsteps(self, nsteps): self.kwargs['iterations'] = nsteps + def __getstate__(self): + # In order to be picklable with dill, we need to discard the pool + # object before trying. + d = self.__dict__ + d["_Sampler__kwargs"]["pool"] = None + return d + def run_sampler(self): import emcee tqdm = get_progress_bar() - sampler = emcee.EnsembleSampler(dim=self.ndim, lnpostfn=self.lnpostfn, **self.sampler_init_kwargs) + sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs) self._set_pos0() - for _ in tqdm(sampler.sample(p0=self.pos0, **self.sampler_function_kwargs), + for _ in tqdm(sampler.sample(**self.sampler_function_kwargs), total=self.nsteps): pass self.result.sampler_output = np.nan diff --git a/bilby/gw/detector.py b/bilby/gw/detector.py index 38cda74fe96d52d5d5c52aa49fc10c6d7c801bff..599f492462a458d2393c2f0b59ac84a71a7b41db 100644 --- a/bilby/gw/detector.py +++ b/bilby/gw/detector.py @@ -1304,7 +1304,7 @@ class Interferometer(object): `waveform_generator.frequency_domain_strain()`. If `waveform_generator` is also given, the injection_polarizations will be calculated directly and this argument can be ignored. - waveform_generator: bilby.gw.waveform_generator + waveform_generator: bilby.gw.waveform_generator.WaveformGenerator A WaveformGenerator instance using the source model to inject. If `injection_polarizations` is given, this will be ignored. @@ -2158,7 +2158,7 @@ def get_interferometer_with_fake_noise_and_injection( `waveform_generator.frequency_domain_strain()`. If `waveform_generator` is also given, the injection_polarizations will be calculated directly and this argument can be ignored. - waveform_generator: bilby.gw.waveform_generator + waveform_generator: bilby.gw.waveform_generator.WaveformGenerator A WaveformGenerator instance using the source model to inject. If `injection_polarizations` is given, this will be ignored. sampling_frequency: float diff --git a/test/sampler_test.py b/test/sampler_test.py index cd40784c786189ae8dd640889c9330b8b0ef4dac..0ffbb2bac84f0096fa8dc7e20052b0ad1ee09b78 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -191,7 +191,7 @@ class TestEmcee(unittest.TestCase): def test_default_kwargs(self): expected = dict(nwalkers=500, a=2, args=[], kwargs={}, - postargs=None, threads=1, pool=None, live_dangerously=False, + postargs=None, pool=None, live_dangerously=False, runtime_sortingfn=None, lnprob0=None, rstate0=None, blobs0=None, iterations=100, thin=1, storechain=True, mh_proposal=None ) @@ -199,7 +199,7 @@ class TestEmcee(unittest.TestCase): def test_translate_kwargs(self): expected = dict(nwalkers=100, a=2, args=[], kwargs={}, - postargs=None, threads=1, pool=None, live_dangerously=False, + postargs=None, pool=None, live_dangerously=False, runtime_sortingfn=None, lnprob0=None, rstate0=None, blobs0=None, iterations=100, thin=1, storechain=True, mh_proposal=None) for equiv in bilby.core.sampler.base_sampler.MCMCSampler.nwalkers_equiv_kwargs: