Skip to content
Snippets Groups Projects
Commit 8b7c5818 authored by Colm Talbot's avatar Colm Talbot Committed by Moritz Huebner
Browse files

Replace bessel interpolant to scipy function

parent 315ac28d
No related branches found
No related tags found
1 merge request!976Replace bessel interpolant to scipy function
......@@ -6,7 +6,7 @@ import copy
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from scipy.special import logsumexp, i0e
from scipy.special import logsumexp
from ..core.likelihood import Likelihood
from ..core.utils import BilbyJsonEncoder, decode_bilby_json
......@@ -18,7 +18,9 @@ from .detector import InterferometerList, get_empty_interferometer, calibration
from .prior import BBHPriorDict, CBCPriorDict, Cosmological
from .source import lal_binary_black_hole
from .utils import (
noise_weighted_inner_product, build_roq_weights, zenith_azimuth_to_ra_dec)
noise_weighted_inner_product, build_roq_weights, zenith_azimuth_to_ra_dec,
ln_i0
)
from .waveform_generator import WaveformGenerator
from collections import namedtuple
......@@ -173,8 +175,6 @@ class GravitationalWaveTransient(Likelihood):
if self.phase_marginalization:
self._check_marginalized_prior_is_set(key='phase')
self._bessel_function_interped = None
self._setup_phase_marginalization()
priors['phase'] = float(0)
self._marginalized_parameters.append('phase')
......@@ -575,8 +575,7 @@ class GravitationalWaveTransient(Likelihood):
time_log_like = self.distance_marginalized_likelihood(
d_inner_h, h_inner_h)
elif self.phase_marginalization:
time_log_like = (self._bessel_function_interped(abs(d_inner_h)) -
h_inner_h.real / 2)
time_log_like = ln_i0(abs(d_inner_h)) - h_inner_h.real / 2
else:
time_log_like = (d_inner_h.real - h_inner_h.real / 2)
......@@ -631,8 +630,9 @@ class GravitationalWaveTransient(Likelihood):
if self.phase_marginalization:
distance_log_like = (
self._bessel_function_interped(abs(d_inner_h_dist)) -
h_inner_h_dist.real / 2)
ln_i0(abs(d_inner_h_dist)) -
h_inner_h_dist.real / 2
)
else:
distance_log_like = (d_inner_h_dist.real - h_inner_h_dist.real / 2)
......@@ -703,7 +703,7 @@ class GravitationalWaveTransient(Likelihood):
d_inner_h_ref, h_inner_h_ref)
def phase_marginalized_likelihood(self, d_inner_h, h_inner_h):
d_inner_h = self._bessel_function_interped(abs(d_inner_h))
d_inner_h = ln_i0(abs(d_inner_h))
if self.calibration_marginalization and self.time_marginalization:
return d_inner_h - np.outer(h_inner_h, np.ones(np.shape(d_inner_h)[1])) / 2
......@@ -922,9 +922,7 @@ class GravitationalWaveTransient(Likelihood):
d_inner_h_array_full = np.outer(self._d_inner_h_ref_array, scaling)
h_inner_h_array_full = np.outer(self._optimal_snr_squared_ref_array, scaling ** 2)
if self.phase_marginalization:
d_inner_h_array_full = self._bessel_function_interped(abs(
d_inner_h_array_full
))
d_inner_h_array_full = ln_i0(abs(d_inner_h_array_full))
prior_term = self.distance_prior_array * self._delta_distance
for ii, optimal_snr_squared_array in tqdm(
enumerate(h_inner_h_array_full), total=len(self._optimal_snr_squared_ref_array)
......@@ -941,11 +939,20 @@ class GravitationalWaveTransient(Likelihood):
self.cache_lookup_table()
def _setup_phase_marginalization(self, min_bound=-5, max_bound=10):
x_values = np.logspace(min_bound, max_bound, int(1e6))
self._bessel_function_interped = interp1d(
x_values, x_values + np.log([i0e(snr) for snr in x_values]),
bounds_error=False, fill_value=(0, np.nan)
logger.warning(
"The _setup_phase_marginalization method is deprecated and will be removed, "
"please update the implementation of phase marginalization "
"to use bilby.gw.utils.ln_i0"
)
@staticmethod
def _bessel_function_interped(xx):
logger.warning(
"The _bessel_function_interped method is deprecated and will be removed, "
"please update the implementation of phase marginalization "
"to use bilby.gw.utils.ln_i0"
)
return ln_i0(xx) + xx
def _setup_time_marginalization(self):
self._delta_tc = 2 / self.waveform_generator.sampling_frequency
......
......@@ -4,6 +4,7 @@ from math import fmod
import numpy as np
from scipy.interpolate import interp1d
from scipy.special import i0e
from ..core.utils import (ra_dec_to_theta_phi,
speed_of_light, logger, run_commandline,
......@@ -987,3 +988,21 @@ def greenwich_mean_sidereal_time(time):
from lal import GreenwichMeanSiderealTime
time = float(time)
return GreenwichMeanSiderealTime(time)
def ln_i0(value):
"""
A numerically stable method to evaluate ln(I_0) a modified Bessel function
of order 0 used in the phase-marginalized likelihood.
Parameters
==========
value: array-like
Value(s) at which to evaluate the function
Returns
=======
array-like:
The natural logarithm of the bessel function
"""
return np.log(i0e(value)) + value
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment