diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index e9c5d2529cafe893a238669bc2da33393302ad01..27122cdf8dda1257fd35be4d2dd27741671af08e 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -40,7 +40,7 @@ class Cpnest(NestedSampler): """ default_kwargs = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000, seed=None, poolsize=100, nhamiltonian=0, resume=True, - output=None, proposals=None) + output=None, proposals=None, n_periodic_checkpoint=None) def _translate_kwargs(self, kwargs): if 'nlive' not in kwargs: @@ -84,17 +84,22 @@ class Cpnest(NestedSampler): self._resolve_proposal_functions() model = Model(self.search_parameter_keys, self.priors) - try: - out = CPNest(model, **self.kwargs) - except TypeError as e: - if 'proposals' in self.kwargs.keys(): - logger.warning('YOU ARE TRYING TO USE PROPOSALS IN A VERSION OF CPNEST THAT DOES' - 'NOT ACCEPT CUSTOM PROPOSALS. SAMPLING WILL COMMENCE WITH THE DEFAULT' - 'PROPOSALS.') - del self.kwargs['proposals'] + out = None + remove_kwargs = ["proposals", "n_periodic_checkpoint"] + while out is None: + try: out = CPNest(model, **self.kwargs) - else: - raise TypeError(e) + except TypeError as e: + if len(remove_kwargs) > 0: + kwarg = remove_kwargs.pop(0) + else: + raise TypeError("Unable to initialise cpnest sampler") + logger.info( + "CPNest init. failed with error {}, please update" + .format(e)) + logger.info( + "Attempting to rerun with kwarg {} removed".format(kwarg)) + self.kwargs.pop(kwarg) out.run() if self.plot: diff --git a/test/sampler_test.py b/test/sampler_test.py index 9fc9060179f24b2cf2e8b0a54fb9275b3d26edfa..9657cb2864a4efba10dd8b370d775f3bb292536c 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -112,13 +112,15 @@ class TestCPNest(unittest.TestCase): def test_default_kwargs(self): expected = dict(verbose=1, nthreads=1, nlive=500, maxmcmc=1000, seed=None, poolsize=100, nhamiltonian=0, resume=True, - output='outdir/cpnest_label/', proposals=None) + output='outdir/cpnest_label/', proposals=None, + n_periodic_checkpoint=None) self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs(self): expected = dict(verbose=1, nthreads=1, nlive=250, maxmcmc=1000, seed=None, poolsize=100, nhamiltonian=0, resume=True, - output='outdir/cpnest_label/', proposals=None) + output='outdir/cpnest_label/', proposals=None, + n_periodic_checkpoint=None) for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy() del new_kwargs['nlive']