Skip to content
Snippets Groups Projects
Commit c29fc37f authored by Colm Talbot's avatar Colm Talbot Committed by Moritz Huebner
Browse files

Add pool to compute_snrs

parent 5e32d717
No related branches found
No related tags found
1 merge request!1013Add pool to compute_snrs
......@@ -829,7 +829,7 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion,
.format(type(output_sample))
)
if likelihood is not None:
compute_snrs(output_sample, likelihood)
compute_snrs(output_sample, likelihood, npool=npool)
for key, func in zip(["mass", "spin", "source frame"], [
generate_mass_parameters, generate_spin_parameters,
generate_source_frame_parameters]):
......@@ -1130,7 +1130,7 @@ def generate_source_frame_parameters(sample):
return output_sample
def compute_snrs(sample, likelihood):
def compute_snrs(sample, likelihood, npool=1):
"""
Compute the optimal and matched filter snrs of all posterior samples
and print it out.
......@@ -1157,36 +1157,49 @@ def compute_snrs(sample, likelihood):
per_detector_snr.optimal_snr_squared.real ** 0.5
else:
from tqdm.auto import tqdm
logger.info(
'Computing SNRs for every sample.')
matched_filter_snrs = {
ifo.name: [] for ifo in likelihood.interferometers}
optimal_snrs = {ifo.name: [] for ifo in likelihood.interferometers}
for ii in tqdm(range(len(sample)), file=sys.stdout):
signal_polarizations =\
likelihood.waveform_generator.frequency_domain_strain(
dict(sample.iloc[ii]))
likelihood.parameters.update(sample.iloc[ii])
for ifo in likelihood.interferometers:
per_detector_snr = likelihood.calculate_snrs(
signal_polarizations, ifo)
matched_filter_snrs[ifo.name].append(
per_detector_snr.complex_matched_filter_snr)
optimal_snrs[ifo.name].append(
per_detector_snr.optimal_snr_squared.real ** 0.5)
logger.info('Computing SNRs for every sample.')
for ifo in likelihood.interferometers:
sample['{}_matched_filter_snr'.format(ifo.name)] =\
matched_filter_snrs[ifo.name]
sample['{}_optimal_snr'.format(ifo.name)] =\
optimal_snrs[ifo.name]
fill_args = [(ii, row, likelihood) for ii, row in sample.iterrows()]
if npool > 1:
pool = multiprocessing.Pool(processes=npool)
logger.info(
"Using a pool with size {} for nsamples={}".format(npool, len(sample))
)
new_samples = np.array(pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout)))
pool.close()
else:
new_samples = np.array([_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)])
for ii, ifo in enumerate(likelihood.interferometers):
matched_filter_snrs = list()
optimal_snrs = list()
mf_key = '{}_matched_filter_snr'.format(ifo.name)
optimal_key = '{}_optimal_snr'.format(ifo.name)
for new_sample in new_samples:
matched_filter_snrs.append(new_sample[ii].complex_matched_filter_snr)
optimal_snrs.append(new_sample[ii].optimal_snr_squared.real ** 0.5)
sample[mf_key] = matched_filter_snrs
sample[optimal_key] = optimal_snrs
else:
logger.debug('Not computing SNRs.')
def _compute_snrs(args):
"""A wrapper of computing the SNRs to enable multiprocessing"""
ii, sample, likelihood = args
sample = dict(sample).copy()
signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(
sample
)
likelihood.parameters.update(sample)
snrs = list()
for ifo in likelihood.interferometers:
snrs.append(likelihood.calculate_snrs(signal_polarizations, ifo))
return snrs
def generate_posterior_samples_from_marginalized_likelihood(
samples, likelihood, npool=1):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment