From e28ba213b83177ffa943cb8b354788c60708628e Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Fri, 12 Aug 2022 13:23:17 +0000 Subject: [PATCH] Improve pool post processing --- bilby/bilby_mcmc/sampler.py | 3 +++ bilby/gw/conversion.py | 30 ++++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py index ec6c2abaa..555e32149 100644 --- a/bilby/bilby_mcmc/sampler.py +++ b/bilby/bilby_mcmc/sampler.py @@ -1075,6 +1075,9 @@ class BilbyMCMCSampler(object): Eindex=0, use_ratio=False, ): + from ..core.sampler.base_sampler import _sampling_convenience_dump + + self._sampling_helper = _sampling_convenience_dump self.beta = beta self.Tindex = Tindex self.Eindex = Eindex diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 24ae47ba4..5955aa0f7 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -1381,9 +1381,14 @@ def compute_snrs(sample, likelihood, npool=1): from tqdm.auto import tqdm logger.info('Computing SNRs for every sample.') - fill_args = [(ii, row, likelihood) for ii, row in sample.iterrows()] + fill_args = [(ii, row) for ii, row in sample.iterrows()] if npool > 1: - pool = multiprocessing.Pool(processes=npool) + from ..core.sampler.base_sampler import _initialize_global_variables + pool = multiprocessing.Pool( + processes=npool, + initializer=_initialize_global_variables, + initargs=(likelihood, None, None, False), + ) logger.info( "Using a pool with size {} for nsamples={}".format(npool, len(sample)) ) @@ -1391,6 +1396,8 @@ def compute_snrs(sample, likelihood, npool=1): pool.close() pool.join() else: + from ..core.sampler.base_sampler import _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)] for ii, ifo in enumerate(likelihood.interferometers): @@ -1411,7 +1418,9 @@ def compute_snrs(sample, likelihood, npool=1): def _compute_snrs(args): """A wrapper of computing the SNRs to enable multiprocessing""" - ii, sample, likelihood = args + from ..core.sampler.base_sampler import _sampling_convenience_dump + likelihood = _sampling_convenience_dump.likelihood + ii, sample = args sample = dict(sample).copy() likelihood.parameters.update(sample) signal_polarizations = likelihood.waveform_generator.frequency_domain_strain( @@ -1494,15 +1503,22 @@ def generate_posterior_samples_from_marginalized_likelihood( # Set up the multiprocessing if npool > 1: - pool = multiprocessing.Pool(processes=npool) + from ..core.sampler.base_sampler import _initialize_global_variables + pool = multiprocessing.Pool( + processes=npool, + initializer=_initialize_global_variables, + initargs=(likelihood, None, None, False), + ) logger.info( "Using a pool with size {} for nsamples={}" .format(npool, len(samples)) ) else: + from ..core.sampler.base_sampler import _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood pool = None - fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()] + fill_args = [(ii, row) for ii, row in samples.iterrows()] ii = 0 pbar = tqdm(total=len(samples), file=sys.stdout) while ii < len(samples): @@ -1561,7 +1577,9 @@ def generate_sky_frame_parameters(samples, likelihood): def fill_sample(args): - ii, sample, likelihood = args + from ..core.sampler.base_sampler import _sampling_convenience_dump + likelihood = _sampling_convenience_dump.likelihood + ii, sample = args marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) sample = dict(sample).copy() likelihood.parameters.update(dict(sample).copy()) -- GitLab