diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py index 6485ce1a00bfddb52e0d5e11d608ba20044a6ef8..3decaf74a443d8668212e6a32c9cb4a44850f1a0 100644 --- a/bilby/bilby_mcmc/sampler.py +++ b/bilby/bilby_mcmc/sampler.py @@ -1076,6 +1076,9 @@ class BilbyMCMCSampler(object): Eindex=0, use_ratio=False, ): + from ..core.sampler.base_sampler import _sampling_convenience_dump + + self._sampling_helper = _sampling_convenience_dump self.beta = beta self.Tindex = Tindex self.Eindex = Eindex diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 107822828c52ed32111a9462c1d8f4325b143719..c4dcc36827fd844853095dcbbcafd69f2585c40a 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -371,7 +371,7 @@ class DirichletElement(ConditionalBeta): self._required_variables = [ label + str(ii) for ii in range(order) ] - self.__class__.__name__ = 'Dirichlet' + self.__class__.__name__ = 'DirichletElement' def dirichlet_condition(self, reference_parms, **kwargs): remaining = 1 - sum( diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 24ae47ba406f2d46e10d7c781c1dcc1f5083f45f..46e9127103f95e68430896503680809a451b2398 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -1381,9 +1381,14 @@ def compute_snrs(sample, likelihood, npool=1): from tqdm.auto import tqdm logger.info('Computing SNRs for every sample.') - fill_args = [(ii, row, likelihood) for ii, row in sample.iterrows()] + fill_args = [(ii, row) for ii, row in sample.iterrows()] if npool > 1: - pool = multiprocessing.Pool(processes=npool) + from ..core.sampler.base_sampler import _initialize_global_variables + pool = multiprocessing.Pool( + processes=npool, + initializer=_initialize_global_variables, + initargs=(likelihood, None, None, False), + ) logger.info( "Using a pool with size {} for nsamples={}".format(npool, len(sample)) ) @@ -1391,6 +1396,8 @@ def compute_snrs(sample, likelihood, npool=1): pool.close() pool.join() else: + from ..core.sampler.base_sampler import _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)] for ii, ifo in enumerate(likelihood.interferometers): @@ -1411,7 +1418,9 @@ def compute_snrs(sample, likelihood, npool=1): def _compute_snrs(args): """A wrapper of computing the SNRs to enable multiprocessing""" - ii, sample, likelihood = args + from ..core.sampler.base_sampler import _sampling_convenience_dump + likelihood = _sampling_convenience_dump.likelihood + ii, sample = args sample = dict(sample).copy() likelihood.parameters.update(sample) signal_polarizations = likelihood.waveform_generator.frequency_domain_strain( @@ -1494,15 +1503,22 @@ def generate_posterior_samples_from_marginalized_likelihood( # Set up the multiprocessing if npool > 1: - pool = multiprocessing.Pool(processes=npool) + from ..core.sampler.base_sampler import _initialize_global_variables + pool = multiprocessing.Pool( + processes=npool, + initializer=_initialize_global_variables, + initargs=(likelihood, None, None, False), + ) logger.info( "Using a pool with size {} for nsamples={}" .format(npool, len(samples)) ) else: + from ..core.sampler.base_sampler import _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood pool = None - fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()] + fill_args = [(ii, row) for ii, row in samples.iterrows()] ii = 0 pbar = tqdm(total=len(samples), file=sys.stdout) while ii < len(samples): @@ -1561,9 +1577,11 @@ def generate_sky_frame_parameters(samples, likelihood): def fill_sample(args): - ii, sample, likelihood = args + from ..core.sampler.base_sampler import _sampling_convenience_dump + likelihood = _sampling_convenience_dump.likelihood + ii, sample = args marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) sample = dict(sample).copy() likelihood.parameters.update(dict(sample).copy()) new_sample = likelihood.generate_posterior_sample_from_marginalized_likelihood() - return (new_sample[key] for key in marginalized_parameters) + return tuple((new_sample[key] for key in marginalized_parameters)) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 42cfd83140b040ec5867e9088a391079cf7a9330..f7ba3b1db5da686e693dfcde267d7ae5c6e418fe 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -371,8 +371,10 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): else: time_ref = self.parameters['geocent_time'] - size_linear = len(self.waveform_generator.waveform_arguments['frequency_nodes_linear']) - size_quadratic = len(self.waveform_generator.waveform_arguments['frequency_nodes_quadratic']) + frequency_nodes_linear = self.waveform_generator.waveform_arguments['frequency_nodes_linear'] + frequency_nodes_quadratic = self.waveform_generator.waveform_arguments['frequency_nodes_quadratic'] + size_linear = len(frequency_nodes_linear) + size_quadratic = len(frequency_nodes_quadratic) h_linear = np.zeros(size_linear, dtype=complex) h_quadratic = np.zeros(size_quadratic, dtype=complex) for mode in waveform_polarizations['linear']: @@ -385,9 +387,9 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): h_quadratic += waveform_polarizations['quadratic'][mode] * response calib_linear = interferometer.calibration_model.get_calibration_factor( - size_linear, prefix='recalib_{}_'.format(interferometer.name), **self.parameters) + frequency_nodes_linear, prefix='recalib_{}_'.format(interferometer.name), **self.parameters) calib_quadratic = interferometer.calibration_model.get_calibration_factor( - size_quadratic, prefix='recalib_{}_'.format(interferometer.name), **self.parameters) + frequency_nodes_quadratic, prefix='recalib_{}_'.format(interferometer.name), **self.parameters) h_linear *= calib_linear h_quadratic *= calib_quadratic diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 4af73bdaa0c7f4f01f66b38089eb22e7fdbd73bb..94a869936fc6cd71cc3647e51a5a372e08ed4544 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -1,7 +1,10 @@ +import os +import shutil import unittest from unittest import mock import numpy as np +import pandas as pd import bilby @@ -408,5 +411,33 @@ class TestConditionalPriorDict(unittest.TestCase): print(res) +class TestDirichletPrior(unittest.TestCase): + + def setUp(self): + self.priors = bilby.core.prior.DirichletPriorDict(5) + + def tearDown(self): + if os.path.isdir("priors"): + shutil.rmtree("priors") + + def test_samples_sum_to_less_than_one(self): + """ + Test that the samples sum to less than one as required for the + Dirichlet distribution. + """ + samples = pd.DataFrame(self.priors.sample(10000)).values + self.assertLess(max(np.sum(samples, axis=1)), 1) + + def test_read_write_file(self): + self.priors.to_file(outdir="priors", label="test") + test = bilby.core.prior.PriorDict(filename="priors/test.prior") + self.assertEqual(self.priors, test) + + def test_read_write_json(self): + self.priors.to_json(outdir="priors", label="test") + test = bilby.core.prior.PriorDict.from_json(filename="priors/test_prior.json") + self.assertEqual(self.priors, test) + + if __name__ == "__main__": unittest.main() diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 3489dab6872591e68f491ab7b9493457f34732d8..eeef6c65c420f740739577fef883e6912e928f63 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -1048,8 +1048,8 @@ class TestROQLikelihoodHDF5(unittest.TestCase): """ - _path_to_basis = "/roq_basis/basis.hdf5" - _path_to_basis_mb = "/roq_basis/basis_multiband.hdf5" + _path_to_basis = "/roq_basis/basis_addcal.hdf5" + _path_to_basis_mb = "/roq_basis/basis_multiband_addcal.hdf5" def setUp(self): self.minimum_frequency = 20 @@ -1152,11 +1152,12 @@ class TestROQLikelihoodHDF5(unittest.TestCase): product( [_path_to_basis, _path_to_basis_mb], [_path_to_basis, _path_to_basis_mb], - [(8, 9), (8, 10.5), (8, 11.5), (8, 12.5), (8, 14)], - [1, 2] + [(8, 9), (8, 14)], + [1, 2], + [False, True] ) ) - def test_likelihood_accuracy(self, basis_linear, basis_quadratic, mc_range, roq_scale_factor): + def test_likelihood_accuracy(self, basis_linear, basis_quadratic, mc_range, roq_scale_factor, add_cal_errors): "Compare with log likelihood ratios computed by the non-ROQ likelihood" self.minimum_frequency *= roq_scale_factor self.sampling_frequency *= roq_scale_factor @@ -1177,6 +1178,25 @@ class TestROQLikelihoodHDF5(unittest.TestCase): duration=self.duration, start_time=self.injection_parameters["geocent_time"] - self.duration + 1 ) + + if add_cal_errors: + spline_calibration_nodes = 10 + np.random.seed(170817) + for ifo in interferometers: + prefix = f"recalib_{ifo.name}_" + ifo.calibration_model = bilby.gw.calibration.CubicSpline( + prefix=prefix, + minimum_frequency=ifo.minimum_frequency, + maximum_frequency=ifo.maximum_frequency, + n_points=spline_calibration_nodes + ) + for i in range(spline_calibration_nodes): + # 5% in amplitude, 5deg in phase + self.injection_parameters[f"{prefix}amplitude_{i}"] = \ + np.random.normal(loc=0, scale=0.05) + self.injection_parameters[f"{prefix}phase_{i}"] = \ + np.random.normal(loc=0, scale=5 * np.pi / 180) + waveform_generator = bilby.gw.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, @@ -1214,9 +1234,9 @@ class TestROQLikelihoodHDF5(unittest.TestCase): # The maximum error of log likelihood ratio. It is set to be larger for roq_scale_factor=1 as the injected SNR # is higher. if roq_scale_factor == 1: - max_llr_error = 1e-1 + max_llr_error = 5e-1 elif roq_scale_factor == 2: - max_llr_error = 1e-2 + max_llr_error = 5e-2 else: raise for mc in np.linspace(self.priors["chirp_mass"].minimum, self.priors["chirp_mass"].maximum, 11): @@ -1238,8 +1258,8 @@ class TestCreateROQLikelihood(unittest.TestCase): """ - _path_to_basis = "/roq_basis/basis.hdf5" - _path_to_basis_mb = "/roq_basis/basis_multiband.hdf5" + _path_to_basis = "/roq_basis/basis_addcal.hdf5" + _path_to_basis_mb = "/roq_basis/basis_multiband_addcal.hdf5" @parameterized.expand(product([_path_to_basis, _path_to_basis_mb], [_path_to_basis, _path_to_basis_mb])) def test_from_hdf5(self, basis_linear, basis_quadratic): @@ -1525,9 +1545,9 @@ class TestInOutROQWeights(unittest.TestCase): ) if multiband: - path_to_basis = "/roq_basis/basis_multiband.hdf5" + path_to_basis = "/roq_basis/basis_multiband_addcal.hdf5" else: - path_to_basis = "/roq_basis/basis.hdf5" + path_to_basis = "/roq_basis/basis_addcal.hdf5" return bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=interferometers, priors=priors,