diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py
index ec6c2abaaf424f8904149cebdde60095262fe1a0..555e321496cc1273db669daf43f21e86ea485a11 100644
--- a/bilby/bilby_mcmc/sampler.py
+++ b/bilby/bilby_mcmc/sampler.py
@@ -1075,6 +1075,9 @@ class BilbyMCMCSampler(object):
         Eindex=0,
         use_ratio=False,
     ):
+        from ..core.sampler.base_sampler import _sampling_convenience_dump
+
+        self._sampling_helper = _sampling_convenience_dump
         self.beta = beta
         self.Tindex = Tindex
         self.Eindex = Eindex
diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 24ae47ba406f2d46e10d7c781c1dcc1f5083f45f..5955aa0f7d58906a11f226438a57b41e5462bb01 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -1381,9 +1381,14 @@ def compute_snrs(sample, likelihood, npool=1):
             from tqdm.auto import tqdm
 
             logger.info('Computing SNRs for every sample.')
-            fill_args = [(ii, row, likelihood) for ii, row in sample.iterrows()]
+            fill_args = [(ii, row) for ii, row in sample.iterrows()]
             if npool > 1:
-                pool = multiprocessing.Pool(processes=npool)
+                from ..core.sampler.base_sampler import _initialize_global_variables
+                pool = multiprocessing.Pool(
+                    processes=npool,
+                    initializer=_initialize_global_variables,
+                    initargs=(likelihood, None, None, False),
+                )
                 logger.info(
                     "Using a pool with size {} for nsamples={}".format(npool, len(sample))
                 )
@@ -1391,6 +1396,8 @@
                 pool.close()
                 pool.join()
             else:
+                from ..core.sampler.base_sampler import _sampling_convenience_dump
+                _sampling_convenience_dump.likelihood = likelihood
                 new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)]
 
             for ii, ifo in enumerate(likelihood.interferometers):
@@ -1411,7 +1418,9 @@
 
 def _compute_snrs(args):
     """A wrapper of computing the SNRs to enable multiprocessing"""
-    ii, sample, likelihood = args
+    from ..core.sampler.base_sampler import _sampling_convenience_dump
+    likelihood = _sampling_convenience_dump.likelihood
+    ii, sample = args
     sample = dict(sample).copy()
     likelihood.parameters.update(sample)
     signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(
@@ -1494,15 +1503,22 @@ generate_posterior_samples_from_marginalized_likelihood(
 
     # Set up the multiprocessing
     if npool > 1:
-        pool = multiprocessing.Pool(processes=npool)
+        from ..core.sampler.base_sampler import _initialize_global_variables
+        pool = multiprocessing.Pool(
+            processes=npool,
+            initializer=_initialize_global_variables,
+            initargs=(likelihood, None, None, False),
+        )
         logger.info(
             "Using a pool with size {} for nsamples={}"
             .format(npool, len(samples))
         )
     else:
+        from ..core.sampler.base_sampler import _sampling_convenience_dump
+        _sampling_convenience_dump.likelihood = likelihood
         pool = None
 
-    fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()]
+    fill_args = [(ii, row) for ii, row in samples.iterrows()]
     ii = 0
     pbar = tqdm(total=len(samples), file=sys.stdout)
     while ii < len(samples):
@@ -1561,7 +1577,9 @@ def generate_sky_frame_parameters(samples, likelihood):
 
 
 def fill_sample(args):
-    ii, sample, likelihood = args
+    from ..core.sampler.base_sampler import _sampling_convenience_dump
+    likelihood = _sampling_convenience_dump.likelihood
+    ii, sample = args
     marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
     sample = dict(sample).copy()
     likelihood.parameters.update(dict(sample).copy())