From c29fc37fd173793f9b50581e1392f4cb6e506c9b Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Mon, 6 Sep 2021 04:10:52 +0000 Subject: [PATCH] Add pool to compute_snrs --- bilby/gw/conversion.py | 65 +++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index bec64ef87..77772ad64 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -829,7 +829,7 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion, .format(type(output_sample)) ) if likelihood is not None: - compute_snrs(output_sample, likelihood) + compute_snrs(output_sample, likelihood, npool=npool) for key, func in zip(["mass", "spin", "source frame"], [ generate_mass_parameters, generate_spin_parameters, generate_source_frame_parameters]): @@ -1130,7 +1130,7 @@ def generate_source_frame_parameters(sample): return output_sample -def compute_snrs(sample, likelihood): +def compute_snrs(sample, likelihood, npool=1): """ Compute the optimal and matched filter snrs of all posterior samples and print it out. @@ -1157,36 +1157,49 @@ def compute_snrs(sample, likelihood): per_detector_snr.optimal_snr_squared.real ** 0.5 else: from tqdm.auto import tqdm - logger.info( - 'Computing SNRs for every sample.') - - matched_filter_snrs = { - ifo.name: [] for ifo in likelihood.interferometers} - optimal_snrs = {ifo.name: [] for ifo in likelihood.interferometers} - for ii in tqdm(range(len(sample)), file=sys.stdout): - signal_polarizations =\ - likelihood.waveform_generator.frequency_domain_strain( - dict(sample.iloc[ii])) - likelihood.parameters.update(sample.iloc[ii]) - for ifo in likelihood.interferometers: - per_detector_snr = likelihood.calculate_snrs( - signal_polarizations, ifo) - - matched_filter_snrs[ifo.name].append( - per_detector_snr.complex_matched_filter_snr) - optimal_snrs[ifo.name].append( - per_detector_snr.optimal_snr_squared.real ** 0.5) + logger.info('Computing SNRs for every sample.') - for ifo in likelihood.interferometers: - sample['{}_matched_filter_snr'.format(ifo.name)] =\ - matched_filter_snrs[ifo.name] - sample['{}_optimal_snr'.format(ifo.name)] =\ - optimal_snrs[ifo.name] + fill_args = [(ii, row, likelihood) for ii, row in sample.iterrows()] + if npool > 1: + pool = multiprocessing.Pool(processes=npool) + logger.info( + "Using a pool with size {} for nsamples={}".format(npool, len(sample)) + ) + new_samples = np.array(pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout))) + pool.close() + else: + new_samples = np.array([_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)]) + + for ii, ifo in enumerate(likelihood.interferometers): + matched_filter_snrs = list() + optimal_snrs = list() + mf_key = '{}_matched_filter_snr'.format(ifo.name) + optimal_key = '{}_optimal_snr'.format(ifo.name) + for new_sample in new_samples: + matched_filter_snrs.append(new_sample[ii].complex_matched_filter_snr) + optimal_snrs.append(new_sample[ii].optimal_snr_squared.real ** 0.5) + + sample[mf_key] = matched_filter_snrs + sample[optimal_key] = optimal_snrs else: logger.debug('Not computing SNRs.') +def _compute_snrs(args): + """A wrapper of computing the SNRs to enable multiprocessing""" + ii, sample, likelihood = args + sample = dict(sample).copy() + signal_polarizations = likelihood.waveform_generator.frequency_domain_strain( + sample + ) + likelihood.parameters.update(sample) + snrs = list() + for ifo in likelihood.interferometers: + snrs.append(likelihood.calculate_snrs(signal_polarizations, ifo)) + return snrs + + def generate_posterior_samples_from_marginalized_likelihood( samples, likelihood, npool=1): """ -- GitLab