Skip to content
Snippets Groups Projects
Commit e28ba213 authored by Colm Talbot's avatar Colm Talbot
Browse files

Improve pool post processing

parent 602eaa64
No related branches found
No related tags found
1 merge request!1096Improve pool post processing
...@@ -1075,6 +1075,9 @@ class BilbyMCMCSampler(object): ...@@ -1075,6 +1075,9 @@ class BilbyMCMCSampler(object):
Eindex=0, Eindex=0,
use_ratio=False, use_ratio=False,
): ):
from ..core.sampler.base_sampler import _sampling_convenience_dump
self._sampling_helper = _sampling_convenience_dump
self.beta = beta self.beta = beta
self.Tindex = Tindex self.Tindex = Tindex
self.Eindex = Eindex self.Eindex = Eindex
......
...@@ -1381,9 +1381,14 @@ def compute_snrs(sample, likelihood, npool=1): ...@@ -1381,9 +1381,14 @@ def compute_snrs(sample, likelihood, npool=1):
from tqdm.auto import tqdm from tqdm.auto import tqdm
logger.info('Computing SNRs for every sample.') logger.info('Computing SNRs for every sample.')
fill_args = [(ii, row, likelihood) for ii, row in sample.iterrows()] fill_args = [(ii, row) for ii, row in sample.iterrows()]
if npool > 1: if npool > 1:
pool = multiprocessing.Pool(processes=npool) from ..core.sampler.base_sampler import _initialize_global_variables
pool = multiprocessing.Pool(
processes=npool,
initializer=_initialize_global_variables,
initargs=(likelihood, None, None, False),
)
logger.info( logger.info(
"Using a pool with size {} for nsamples={}".format(npool, len(sample)) "Using a pool with size {} for nsamples={}".format(npool, len(sample))
) )
...@@ -1391,6 +1396,8 @@ def compute_snrs(sample, likelihood, npool=1): ...@@ -1391,6 +1396,8 @@ def compute_snrs(sample, likelihood, npool=1):
pool.close() pool.close()
pool.join() pool.join()
else: else:
from ..core.sampler.base_sampler import _sampling_convenience_dump
_sampling_convenience_dump.likelihood = likelihood
new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)] new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)]
for ii, ifo in enumerate(likelihood.interferometers): for ii, ifo in enumerate(likelihood.interferometers):
...@@ -1411,7 +1418,9 @@ def compute_snrs(sample, likelihood, npool=1): ...@@ -1411,7 +1418,9 @@ def compute_snrs(sample, likelihood, npool=1):
def _compute_snrs(args): def _compute_snrs(args):
"""A wrapper of computing the SNRs to enable multiprocessing""" """A wrapper of computing the SNRs to enable multiprocessing"""
ii, sample, likelihood = args from ..core.sampler.base_sampler import _sampling_convenience_dump
likelihood = _sampling_convenience_dump.likelihood
ii, sample = args
sample = dict(sample).copy() sample = dict(sample).copy()
likelihood.parameters.update(sample) likelihood.parameters.update(sample)
signal_polarizations = likelihood.waveform_generator.frequency_domain_strain( signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(
...@@ -1494,15 +1503,22 @@ def generate_posterior_samples_from_marginalized_likelihood( ...@@ -1494,15 +1503,22 @@ def generate_posterior_samples_from_marginalized_likelihood(
# Set up the multiprocessing # Set up the multiprocessing
if npool > 1: if npool > 1:
pool = multiprocessing.Pool(processes=npool) from ..core.sampler.base_sampler import _initialize_global_variables
pool = multiprocessing.Pool(
processes=npool,
initializer=_initialize_global_variables,
initargs=(likelihood, None, None, False),
)
logger.info( logger.info(
"Using a pool with size {} for nsamples={}" "Using a pool with size {} for nsamples={}"
.format(npool, len(samples)) .format(npool, len(samples))
) )
else: else:
from ..core.sampler.base_sampler import _sampling_convenience_dump
_sampling_convenience_dump.likelihood = likelihood
pool = None pool = None
fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()] fill_args = [(ii, row) for ii, row in samples.iterrows()]
ii = 0 ii = 0
pbar = tqdm(total=len(samples), file=sys.stdout) pbar = tqdm(total=len(samples), file=sys.stdout)
while ii < len(samples): while ii < len(samples):
...@@ -1561,7 +1577,9 @@ def generate_sky_frame_parameters(samples, likelihood): ...@@ -1561,7 +1577,9 @@ def generate_sky_frame_parameters(samples, likelihood):
def fill_sample(args): def fill_sample(args):
ii, sample, likelihood = args from ..core.sampler.base_sampler import _sampling_convenience_dump
likelihood = _sampling_convenience_dump.likelihood
ii, sample = args
marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
sample = dict(sample).copy() sample = dict(sample).copy()
likelihood.parameters.update(dict(sample).copy()) likelihood.parameters.update(dict(sample).copy())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment