Skip to content
Snippets Groups Projects
Commit e28ba213 authored by Colm Talbot's avatar Colm Talbot
Browse files

Improve pool post processing

parent 602eaa64
Branches release/2.0.x
No related tags found
1 merge request: !1096 — Improve pool post processing
......@@ -1075,6 +1075,9 @@ class BilbyMCMCSampler(object):
Eindex=0,
use_ratio=False,
):
from ..core.sampler.base_sampler import _sampling_convenience_dump
self._sampling_helper = _sampling_convenience_dump
self.beta = beta
self.Tindex = Tindex
self.Eindex = Eindex
......
......@@ -1381,9 +1381,14 @@ def compute_snrs(sample, likelihood, npool=1):
from tqdm.auto import tqdm
logger.info('Computing SNRs for every sample.')
fill_args = [(ii, row, likelihood) for ii, row in sample.iterrows()]
fill_args = [(ii, row) for ii, row in sample.iterrows()]
if npool > 1:
pool = multiprocessing.Pool(processes=npool)
from ..core.sampler.base_sampler import _initialize_global_variables
pool = multiprocessing.Pool(
processes=npool,
initializer=_initialize_global_variables,
initargs=(likelihood, None, None, False),
)
logger.info(
"Using a pool with size {} for nsamples={}".format(npool, len(sample))
)
......@@ -1391,6 +1396,8 @@ def compute_snrs(sample, likelihood, npool=1):
pool.close()
pool.join()
else:
from ..core.sampler.base_sampler import _sampling_convenience_dump
_sampling_convenience_dump.likelihood = likelihood
new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)]
for ii, ifo in enumerate(likelihood.interferometers):
......@@ -1411,7 +1418,9 @@ def compute_snrs(sample, likelihood, npool=1):
def _compute_snrs(args):
"""A wrapper of computing the SNRs to enable multiprocessing"""
ii, sample, likelihood = args
from ..core.sampler.base_sampler import _sampling_convenience_dump
likelihood = _sampling_convenience_dump.likelihood
ii, sample = args
sample = dict(sample).copy()
likelihood.parameters.update(sample)
signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(
......@@ -1494,15 +1503,22 @@ def generate_posterior_samples_from_marginalized_likelihood(
# Set up the multiprocessing
if npool > 1:
pool = multiprocessing.Pool(processes=npool)
from ..core.sampler.base_sampler import _initialize_global_variables
pool = multiprocessing.Pool(
processes=npool,
initializer=_initialize_global_variables,
initargs=(likelihood, None, None, False),
)
logger.info(
"Using a pool with size {} for nsamples={}"
.format(npool, len(samples))
)
else:
from ..core.sampler.base_sampler import _sampling_convenience_dump
_sampling_convenience_dump.likelihood = likelihood
pool = None
fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()]
fill_args = [(ii, row) for ii, row in samples.iterrows()]
ii = 0
pbar = tqdm(total=len(samples), file=sys.stdout)
while ii < len(samples):
......@@ -1561,7 +1577,9 @@ def generate_sky_frame_parameters(samples, likelihood):
def fill_sample(args):
ii, sample, likelihood = args
from ..core.sampler.base_sampler import _sampling_convenience_dump
likelihood = _sampling_convenience_dump.likelihood
ii, sample = args
marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
sample = dict(sample).copy()
likelihood.parameters.update(dict(sample).copy())
......
0% — Loading, or the content failed to load.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment