From b7d0696756f4485f8d1f0fa3747867acc1324c71 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Wed, 3 Apr 2019 15:05:32 +1100 Subject: [PATCH] Minor bug fixing --- bilby/core/sampler/emcee.py | 5 ++++- bilby/core/sampler/ptemcee.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 2f76850fe..993ab4836 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -274,6 +274,8 @@ class Emcee(MCMCSampler): if hasattr(self, '_sampler'): pass elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file): + logger.info("Resuming run from checkpoint file {}" + .format(self.checkpoint_info.sampler_file)) with open(self.checkpoint_info.sampler_file, 'rb') as f: self._sampler = pickle.load(f) self._set_pos0_for_resume() @@ -335,13 +337,14 @@ class Emcee(MCMCSampler): iterations = sampler_function_kwargs.pop('iterations') iterations -= self._previous_iterations - print('pos0', self.pos0) sampler_function_kwargs['p0'] = self.pos0 + # main iteration loop for sample in tqdm( self.sampler.sample(iterations=iterations, **sampler_function_kwargs), total=iterations): self.write_chains_to_file(sample) + self.checkpoint() self.result.sampler_output = np.nan blobs_flat = np.array(self.sampler.blobs).reshape((-1, 2)) diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 9bcd9e7f2..37eb22201 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -42,7 +42,7 @@ class Ptemcee(Emcee): label=label, use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification, nburn=nburn, burn_in_fraction=burn_in_fraction, - burn_in_act=burn_in_act, resume=True, **kwargs) + burn_in_act=burn_in_act, resume=resume, **kwargs) @property def sampler_function_kwargs(self): @@ -116,6 +116,7 @@ class Ptemcee(Emcee): **sampler_function_kwargs), total=iterations): self.write_chains_to_file(pos, loglike, logpost) + self.checkpoint() self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim))) self.result.sampler_output = np.nan -- GitLab