From 13df713f07a8f39b81542a7dcda5b79be0b2e8c8 Mon Sep 17 00:00:00 2001 From: Soichiro Morisaki <soichiro.morisaki@ligo.org> Date: Fri, 4 Feb 2022 15:38:51 +0000 Subject: [PATCH] Implement time marginalization into ROQ likelihood --- bilby/gw/likelihood/base.py | 4 +- bilby/gw/likelihood/roq.py | 179 ++++++++++++++++++++++++++--------- test/gw/likelihood_test.py | 182 +++++++++++++++++++++++++++++++++++- 3 files changed, 315 insertions(+), 50 deletions(-) diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index d6c9f6f14..a75dded56 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -378,9 +378,7 @@ class GravitationalWaveTransient(Likelihood): elif self.time_marginalization: if self.jitter_time: self.parameters['geocent_time'] += self.parameters['time_jitter'] - d_inner_h_array = np.zeros( - len(self.interferometers.frequency_array[0:-1]), - dtype=np.complex128) + d_inner_h_array = np.zeros(len(self._times), dtype=np.complex128) elif self.calibration_marginalization: d_inner_h_array = np.zeros(self.number_of_response_curves, dtype=np.complex128) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 5d459082a..ce3e679f5 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -44,6 +44,20 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): A dictionary of priors containing at least the geocent_time prior Warning: when using marginalisation the dict is overwritten which will change the the dict you are passing in. If this behaviour is undesired, pass `priors.copy()`. + time_marginalization: bool, optional + If true, marginalize over time in the likelihood. + The spacing of time samples can be specified through delta_tc. + If using time marginalisation and jitter_time is True a "jitter" + parameter is added to the prior which modifies the position of the + grid of times. + jitter_time: bool, optional + Whether to introduce a `time_jitter` parameter. This avoids either + missing the likelihood peak, or introducing biases in the + reconstructed time posterior due to an insufficient sampling frequency. + Default is False, however using this parameter is strongly encouraged. + delta_tc: float, optional + The spacing of time samples for time marginalization. If not specified, + it is determined based on the signal-to-noise ratio of signal. distance_marginalization_lookup_table: (dict, str), optional If a dict, dictionary containing the lookup_table, distance_array, (distance) prior_array, and reference_distance used to construct @@ -71,18 +85,20 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): weights=None, linear_matrix=None, quadratic_matrix=None, roq_params=None, roq_params_check=True, roq_scale_factor=1, distance_marginalization=False, phase_marginalization=False, + time_marginalization=False, jitter_time=True, delta_tc=None, distance_marginalization_lookup_table=None, reference_frame="sky", time_reference="geocenter" ): + self._delta_tc = delta_tc super(ROQGravitationalWaveTransient, self).__init__( interferometers=interferometers, waveform_generator=waveform_generator, priors=priors, distance_marginalization=distance_marginalization, phase_marginalization=phase_marginalization, - time_marginalization=False, + time_marginalization=time_marginalization, distance_marginalization_lookup_table=distance_marginalization_lookup_table, - jitter_time=False, + jitter_time=jitter_time, reference_frame=reference_frame, time_reference=time_reference ) @@ -117,6 +133,19 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): self.frequency_nodes_quadratic = \ waveform_generator.waveform_arguments['frequency_nodes_quadratic'] + def _setup_time_marginalization(self): + if self._delta_tc is None: + self._delta_tc = self._get_time_resolution() + tcmin = self.priors['geocent_time'].minimum + tcmax = self.priors['geocent_time'].maximum + number_of_time_samples = int(np.ceil((tcmax - tcmin) / self._delta_tc)) + # adjust delta tc so that the last time sample has an equal weight + self._delta_tc = (tcmax - tcmin) / number_of_time_samples + logger.info( + "delta tc for time marginalization = {} seconds.".format(self._delta_tc)) + self._times = tcmin + self._delta_tc / 2. + np.arange(number_of_time_samples) * self._delta_tc + self._beam_pattern_reference_time = (tcmin + tcmax) / 2. + def calculate_snrs(self, waveform_polarizations, interferometer): """ Compute the snrs for ROQ @@ -128,18 +157,21 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): """ - f_plus = interferometer.antenna_response( - self.parameters['ra'], self.parameters['dec'], - self.parameters['geocent_time'], self.parameters['psi'], 'plus') - f_cross = interferometer.antenna_response( - self.parameters['ra'], self.parameters['dec'], - self.parameters['geocent_time'], self.parameters['psi'], 'cross') - - dt = interferometer.time_delay_from_geocenter( - self.parameters['ra'], self.parameters['dec'], - self.parameters['geocent_time']) - dt_geocent = self.parameters['geocent_time'] - interferometer.strain_data.start_time - ifo_time = dt_geocent + dt + if self.time_marginalization: + time_ref = self._beam_pattern_reference_time + else: + time_ref = self.parameters['geocent_time'] + + h_linear = np.zeros(len(self.frequency_nodes_linear), dtype=complex) + h_quadratic = np.zeros(len(self.frequency_nodes_quadratic), dtype=complex) + for mode in waveform_polarizations['linear']: + response = interferometer.antenna_response( + self.parameters['ra'], self.parameters['dec'], + self.parameters['geocent_time'], self.parameters['psi'], + mode + ) + h_linear += waveform_polarizations['linear'][mode] * response + h_quadratic += waveform_polarizations['quadratic'][mode] * response calib_linear = interferometer.calibration_model.get_calibration_factor( self.frequency_nodes_linear, @@ -148,46 +180,56 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): self.frequency_nodes_quadratic, prefix='recalib_{}_'.format(interferometer.name), **self.parameters) - h_plus_linear = f_plus * waveform_polarizations['linear']['plus'] * calib_linear - h_cross_linear = f_cross * waveform_polarizations['linear']['cross'] * calib_linear - h_plus_quadratic = ( - f_plus * waveform_polarizations['quadratic']['plus'] * calib_quadratic - ) - h_cross_quadratic = ( - f_cross * waveform_polarizations['quadratic']['cross'] * calib_quadratic - ) + h_linear *= calib_linear + h_quadratic *= calib_quadratic - indices, in_bounds = self._closest_time_indices( - ifo_time, self.weights['time_samples']) - if not in_bounds: - logger.debug("SNR calculation error: requested time at edge of ROQ time samples") - return self._CalculatedSNRs( - d_inner_h=np.nan_to_num(-np.inf), optimal_snr_squared=0, - complex_matched_filter_snr=np.nan_to_num(-np.inf), - d_inner_h_squared_tc_array=None, - d_inner_h_array=None, - optimal_snr_squared_array=None) + optimal_snr_squared = \ + np.vdot(np.abs(h_quadratic)**2, self.weights[interferometer.name + '_quadratic']) + + dt = interferometer.time_delay_from_geocenter( + self.parameters['ra'], self.parameters['dec'], time_ref) - d_inner_h_tc_array = np.einsum( - 'i,ji->j', np.conjugate(h_plus_linear + h_cross_linear), - self.weights[interferometer.name + '_linear'][indices]) + if not self.time_marginalization: + dt_geocent = self.parameters['geocent_time'] - interferometer.strain_data.start_time + ifo_time = dt_geocent + dt - d_inner_h = self._interp_five_samples( - self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time) + indices, in_bounds = self._closest_time_indices( + ifo_time, self.weights['time_samples']) + if not in_bounds: + logger.debug("SNR calculation error: requested time at edge of ROQ time samples") + return self._CalculatedSNRs( + d_inner_h=np.nan_to_num(-np.inf), optimal_snr_squared=0, + complex_matched_filter_snr=np.nan_to_num(-np.inf), + d_inner_h_squared_tc_array=None, + d_inner_h_array=None, + optimal_snr_squared_array=None) - optimal_snr_squared = \ - np.vdot(np.abs(h_plus_quadratic + h_cross_quadratic)**2, - self.weights[interferometer.name + '_quadratic']) + d_inner_h_tc_array = np.einsum( + 'i,ji->j', np.conjugate(h_linear), + self.weights[interferometer.name + '_linear'][indices]) + + d_inner_h = self._interp_five_samples( + self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time) + + with np.errstate(invalid="ignore"): + complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) + + d_inner_h_array = None - with np.errstate(invalid="ignore"): - complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) - d_inner_h_squared_tc_array = None + else: + ifo_times = self._times - interferometer.strain_data.start_time + dt + if self.jitter_time: + ifo_times += self.parameters['time_jitter'] + d_inner_h_array = self._calculate_d_inner_h_array(ifo_times, h_linear, interferometer.name) + + d_inner_h = 0. + complex_matched_filter_snr = 0. return self._CalculatedSNRs( d_inner_h=d_inner_h, optimal_snr_squared=optimal_snr_squared, complex_matched_filter_snr=complex_matched_filter_snr, - d_inner_h_squared_tc_array=d_inner_h_squared_tc_array, - d_inner_h_array=None, + d_inner_h_squared_tc_array=None, + d_inner_h_array=d_inner_h_array, optimal_snr_squared_array=None) @staticmethod @@ -242,6 +284,53 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): d = (b**3. - b) / 6. return a * values[2] + b * values[3] + c * r1 + d * r2 + def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): + """ + Calculate d_inner_h at regularly-spaced time samples. Each value is + interpolated from the nearest 5 samples with the algorithm explained in + https://dcc.ligo.org/T2100224. + + Parameters + ========== + times: array-like + Regularly-spaced time samples at which d_inner_h are calculated. + h_linear: array-like + Waveforms at linear frequency nodes + ifo_name: str + + Returns + ======= + d_inner_h_array: array-like + """ + roq_time_space = self.weights['time_samples'][1] - self.weights['time_samples'][0] + times_per_roq_time_space = (times - self.weights['time_samples'][0]) / roq_time_space + closest_idxs = np.floor(times_per_roq_time_space).astype(int) + # Get the nearest 5 samples of d_inner_h. Calculate only the required d_inner_h values if the time + # spacing is larger than 5 times the ROQ time spacing. + weights_linear = self.weights[ifo_name + '_linear'] + h_linear_conj = np.conjugate(h_linear) + if (times[1] - times[0]) / roq_time_space > 5: + d_inner_h_m2 = np.dot(weights_linear[closest_idxs - 2], h_linear_conj) + d_inner_h_m1 = np.dot(weights_linear[closest_idxs - 1], h_linear_conj) + d_inner_h_0 = np.dot(weights_linear[closest_idxs], h_linear_conj) + d_inner_h_p1 = np.dot(weights_linear[closest_idxs + 1], h_linear_conj) + d_inner_h_p2 = np.dot(weights_linear[closest_idxs + 2], h_linear_conj) + else: + d_inner_h_at_roq_time_samples = np.dot(weights_linear, h_linear_conj) + d_inner_h_m2 = d_inner_h_at_roq_time_samples[closest_idxs - 2] + d_inner_h_m1 = d_inner_h_at_roq_time_samples[closest_idxs - 1] + d_inner_h_0 = d_inner_h_at_roq_time_samples[closest_idxs] + d_inner_h_p1 = d_inner_h_at_roq_time_samples[closest_idxs + 1] + d_inner_h_p2 = d_inner_h_at_roq_time_samples[closest_idxs + 2] + # quantities required for spline interpolation + b = times_per_roq_time_space - closest_idxs + a = 1. - b + c = (a**3. - a) / 6. + d = (b**3. - b) / 6. + r1 = (-d_inner_h_m2 + 8. * d_inner_h_m1 - 14. * d_inner_h_0 + 8. * d_inner_h_p1 - d_inner_h_p2) / 4. + r2 = d_inner_h_0 - 2. * d_inner_h_p1 + d_inner_h_p2 + return a * d_inner_h_0 + b * d_inner_h_p1 + c * r1 + d * r2 + def perform_roq_params_check(self, ifo=None): """ Perform checking that the prior and data are valid for the ROQ diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index f2c6e7ec5..6e34f5872 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -428,6 +428,9 @@ class TestMarginalizations(unittest.TestCase): The `time_jitter` parameter makes this a weaker dependence during sampling. """ + lookup_phase = "distance_lookup_phase.npz" + lookup_no_phase = "distance_lookup_no_phase.npz" + def setUp(self): np.random.seed(500) self.duration = 4 @@ -482,6 +485,13 @@ class TestMarginalizations(unittest.TestCase): del self.waveform_generator del self.priors + @classmethod + def tearDownClass(cls): + # remove lookup tables so that they are not used accidentally in subsequent tests + for filename in [cls.lookup_phase, cls.lookup_no_phase]: + if os.path.exists(filename): + os.remove(filename) + def get_likelihood( self, time_marginalization=False, @@ -492,9 +502,9 @@ class TestMarginalizations(unittest.TestCase): if priors is None: priors = self.priors.copy() if distance_marginalization and phase_marginalization: - lookup = "distance_lookup_phase.npz" + lookup = TestMarginalizations.lookup_phase elif distance_marginalization: - lookup = "distance_lookup_no_phase.npz" + lookup = TestMarginalizations.lookup_no_phase else: lookup = None like = bilby.gw.likelihood.GravitationalWaveTransient( @@ -644,6 +654,174 @@ class TestMarginalizations(unittest.TestCase): ) +class TestMarginalizationsROQ(TestMarginalizations): + + lookup_phase = "distance_lookup_phase.npz" + lookup_no_phase = "distance_lookup_no_phase.npz" + path_to_roq_weights = "weights.npz" + + def setUp(self): + np.random.seed(500) + self.duration = 4 + self.sampling_frequency = 2048 + self.parameters = dict( + mass_1=31.0, + mass_2=29.0, + a_1=0.4, + a_2=0.3, + tilt_1=0.0, + tilt_2=0.0, + phi_12=1.7, + phi_jl=0.3, + luminosity_distance=4000.0, + theta_jn=0.4, + psi=2.659, + phase=1.3, + geocent_time=1126259642.413, + ra=1.375, + dec=-1.2108, + time_jitter=0, + ) + + self.interferometers = bilby.gw.detector.InterferometerList(["H1"]) + self.interferometers.set_strain_data_from_power_spectral_densities( + sampling_frequency=self.sampling_frequency, + duration=self.duration, + start_time=1126259640, + ) + + waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + duration=self.duration, + sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, + start_time=1126259640, + waveform_arguments=dict( + reference_frequency=20.0, + minimum_frequency=20.0, + approximant="IMRPhenomPv2" + ) + ) + self.interferometers.inject_signal( + parameters=self.parameters, waveform_generator=waveform_generator + ) + + self.priors = bilby.gw.prior.BBHPriorDict() + # prior range should be a part of segment since ROQ likelihood can not + # calculate values at samples close to edges + self.priors["geocent_time"] = bilby.prior.Uniform( + minimum=self.parameters["geocent_time"] - 0.1, + maximum=self.parameters["geocent_time"] + 0.1 + ) + + # Possible locations for the ROQ: in the docker image, local, or on CIT + trial_roq_paths = [ + "/roq_basis", + os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), + "/home/cbc/ROQ_data/IMRPhenomPv2/4s", + ] + roq_dir = None + for path in trial_roq_paths: + if os.path.isdir(path): + roq_dir = path + break + if roq_dir is None: + raise Exception("Unable to load ROQ basis: cannot proceed with tests") + + self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + duration=self.duration, + sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + start_time=1126259640, + waveform_arguments=dict( + reference_frequency=20.0, + minimum_frequency=20.0, + approximant="IMRPhenomPv2", + frequency_nodes_linear=np.load("{}/fnodes_linear.npy".format(roq_dir)), + frequency_nodes_quadratic=np.load("{}/fnodes_quadratic.npy".format(roq_dir)), + ) + ) + self.roq_linear_matrix_file = "{}/B_linear.npy".format(roq_dir) + self.roq_quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) + + @classmethod + def tearDownClass(cls): + for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]: + if os.path.exists(filename): + os.remove(filename) + + def get_likelihood( + self, + time_marginalization=False, + phase_marginalization=False, + distance_marginalization=False, + priors=None + ): + if priors is None: + priors = self.priors.copy() + if distance_marginalization and phase_marginalization: + lookup = TestMarginalizationsROQ.lookup_phase + elif distance_marginalization: + lookup = TestMarginalizationsROQ.lookup_no_phase + else: + lookup = None + kwargs = dict( + interferometers=self.interferometers, + waveform_generator=self.waveform_generator, + distance_marginalization=distance_marginalization, + phase_marginalization=phase_marginalization, + time_marginalization=time_marginalization, + distance_marginalization_lookup_table=lookup, + priors=priors + ) + if os.path.exists(TestMarginalizationsROQ.path_to_roq_weights): + kwargs.update(dict(weights=TestMarginalizationsROQ.path_to_roq_weights)) + like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs) + else: + kwargs.update( + dict( + linear_matrix=self.roq_linear_matrix_file, + quadratic_matrix=self.roq_quadratic_matrix_file + ) + ) + like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs) + like.save_weights(TestMarginalizationsROQ.path_to_roq_weights) + like.parameters = self.parameters.copy() + if time_marginalization: + like.parameters["geocent_time"] = self.interferometers.start_time + return like + + def test_time_marginalisation(self): + self._template( + self.get_likelihood(time_marginalization=True), + self.get_likelihood(), + key="geocent_time", + ) + + def test_time_distance_marginalisation(self): + self._template( + self.get_likelihood(time_marginalization=True, distance_marginalization=True), + self.get_likelihood(distance_marginalization=True), + key="geocent_time", + ) + + def test_time_phase_marginalisation(self): + self._template( + self.get_likelihood(time_marginalization=True, phase_marginalization=True), + self.get_likelihood(phase_marginalization=True), + key="geocent_time", + ) + + def test_time_distance_phase_marginalisation(self): + self._template( + self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True), + self.get_likelihood(phase_marginalization=True, distance_marginalization=True), + key="geocent_time", + ) + + def test_time_marginalisation_partial_segment(self): + pass + + class TestROQLikelihood(unittest.TestCase): def setUp(self): self.duration = 4 -- GitLab