Skip to content
Snippets Groups Projects
Commit 8a791548 authored by Colm Talbot's avatar Colm Talbot Committed by Gregory Ashton
Browse files

Merge branch 'master' into 'distance_marg_cache_v2'

# Conflicts:
#   bilby/gw/likelihood.py
parent c5a2df27
No related branches found
No related tags found
No related merge requests found
......@@ -53,6 +53,14 @@ class GravitationalWaveTransient(likelihood.Likelihood):
This is done analytically using a Bessel function.
priors: dict, optional
If given, used in the distance and phase marginalization.
distance_marginalization_lookup_table: (dict, str), optional
If a dict, dictionary containing the lookup_table, distance_array,
(distance) prior_array, and reference_distance used to construct
the table.
If a string the name of a file containing these quantities.
The lookup table is stored after construction in either the
provided string or a default location:
'.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
Returns
-------
......@@ -62,8 +70,10 @@ class GravitationalWaveTransient(likelihood.Likelihood):
"""
def __init__(self, interferometers, waveform_generator, time_marginalization=False, distance_marginalization=False,
phase_marginalization=False, priors=None):
def __init__(self, interferometers, waveform_generator,
time_marginalization=False, distance_marginalization=False,
phase_marginalization=False, priors=None,
distance_marginalization_lookup_table=None):
self.waveform_generator = waveform_generator
likelihood.Likelihood.__init__(self, dict())
......@@ -79,7 +89,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
phase_marginalization=self.phase_marginalization,
distance_marginalization=self.distance_marginalization,
waveform_arguments=waveform_generator.waveform_arguments,
frequency_domain_source_model=str(waveform_generator.frequency_domain_source_model))
frequency_domain_source_model=str(
waveform_generator.frequency_domain_source_model))
if self.time_marginalization:
self._check_prior_is_set(key='geocent_time')
......@@ -93,10 +104,16 @@ class GravitationalWaveTransient(likelihood.Likelihood):
priors['phase'] = float(0)
if self.distance_marginalization:
self._lookup_table_filename = None
self._check_prior_is_set(key='luminosity_distance')
self._distance_array = np.linspace(self.priors['luminosity_distance'].minimum,
self.priors['luminosity_distance'].maximum, int(1e4))
self._setup_distance_marginalization()
self._distance_array = np.linspace(
self.priors['luminosity_distance'].minimum,
self.priors['luminosity_distance'].maximum, int(1e4))
self.distance_prior_array = np.array(
[self.priors['luminosity_distance'].prob(distance)
for distance in self._distance_array])
self._setup_distance_marginalization(
distance_marginalization_lookup_table)
priors['luminosity_distance'] = float(self._ref_dist)
def __repr__(self):
......@@ -292,68 +309,74 @@ class GravitationalWaveTransient(likelihood.Likelihood):
return np.hstack((-np.logspace(3, -3, self._dist_margd_loglikelihood_array.shape[1] / 2),
np.logspace(-3, 10, self._dist_margd_loglikelihood_array.shape[1] / 2)))
def _setup_distance_marginalization(self):
self._create_lookup_table()
def _setup_distance_marginalization(self, lookup_table=None):
if isinstance(lookup_table, str) or lookup_table is None:
self.cached_lookup_table_filename = lookup_table
lookup_table = self.load_lookup_table(
self.cached_lookup_table_filename)
if isinstance(lookup_table, dict):
if self._test_cached_lookup_table(lookup_table):
self._dist_margd_loglikelihood_array = lookup_table[
'lookup_table']
else:
self._create_lookup_table()
else:
self._create_lookup_table()
self._interp_dist_margd_loglikelihood = UnsortedInterp2d(
self._rho_mf_ref_array, self._rho_opt_ref_array,
self._dist_margd_loglikelihood_array)
@property
def cached_lookup_table_filename(self):
dmin = self._distance_array[0]
dmax = self._distance_array[-1]
n = len(self._distance_array)
cached_lookup_table_filename = (
'.distance_marginalization_lookup_dmin{}_dmax{}_n{}_v1.npy'
.format(dmin, dmax, n))
return cached_lookup_table_filename
@property
def cached_lookup_table(self):
""" Reads in the cached lookup table
Returns
-------
cached_lookup_table: np.ndarray
Columns are _distance_array, distance_prior_array,
dist_marged_log_l_tc_array. This is only returned if the file
exists and the first two columns match the equivalent values
stored on disk.
"""
if os.path.exists(self.cached_lookup_table_filename):
loaded_file = np.load(self.cached_lookup_table_filename)
if self._lookup_table_filename is None:
dmin = self._distance_array[0]
dmax = self._distance_array[-1]
n = len(self._distance_array)
self._lookup_table_filename = (
'.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
.format(dmin, dmax, n))
return self._lookup_table_filename
@cached_lookup_table_filename.setter
def cached_lookup_table_filename(self, filename):
if isinstance(filename, str):
if filename[-4:] != '.npz':
filename += '.npz'
self._lookup_table_filename = filename
def load_lookup_table(self, filename):
if os.path.exists(filename):
loaded_file = dict(np.load(filename))
if self._test_cached_lookup_table(loaded_file):
logger.info('Loaded distance marginalisation lookup table from '
'{}.'.format(filename))
return loaded_file
else:
logger.info('Loaded distance marginalisation lookup table does '
'not match prior')
return None
elif isinstance(filename, str):
logger.info('Distance marginalisation file {} does not '
'exist'.format(filename))
return None
else:
return None
@cached_lookup_table.setter
def cached_lookup_table(self, lookup_table):
np.save(self.cached_lookup_table_filename, lookup_table)
def cache_lookup_table(self):
np.savez(self.cached_lookup_table_filename,
distance_array=self._distance_array,
prior_array=self.distance_prior_array,
lookup_table=self._dist_margd_loglikelihood_array,
reference_distance=self._ref_dist)
def _test_cached_lookup_table(self, lookup_table):
cond_a = np.all(self._distance_array == lookup_table[0])
cond_b = np.all(self.distance_prior_array == lookup_table[1])
if cond_a and cond_b:
return True
def _test_cached_lookup_table(self, loaded_file):
cond_a = np.all(self._distance_array == loaded_file['distance_array'])
cond_b = np.all(self.distance_prior_array == loaded_file['prior_array'])
cond_c = self._ref_dist == loaded_file['reference_distance']
return all([cond_a, cond_b, cond_c])
def _create_lookup_table(self):
""" Make the lookup table """
self.distance_prior_array = np.array(
[self.priors['luminosity_distance'].prob(distance)
for distance in self._distance_array])
# Check if a cached lookup table exists in file
cached_lookup_table = self.cached_lookup_table
if cached_lookup_table is not None:
self._dist_margd_loglikelihood_array = cached_lookup_table[-1]
logger.info("Using the cached lookup table {}"
.format(os.path.abspath(self.cached_lookup_table_filename)))
return
logger.info('Building lookup table for distance marginalisation.')
self._dist_margd_loglikelihood_array = np.zeros((400, 800))
......@@ -370,10 +393,7 @@ class GravitationalWaveTransient(likelihood.Likelihood):
log_norm = logsumexp(0. / self._distance_array - 0. / self._distance_array ** 2.,
b=self.distance_prior_array * self._delta_distance)
self._dist_margd_loglikelihood_array -= log_norm
self.cached_lookup_table = np.array([
self._distance_array,
self.distance_prior_array,
self._dist_margd_loglikelihood_array])
self.cache_lookup_table()
def _setup_phase_marginalization(self):
self._bessel_function_interped = interp1d(
......@@ -513,16 +533,26 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
array, or the array itself.
priors: dict, bilby.prior.PriorDict
A dictionary of priors containing at least the geocent_time prior
distance_marginalization_lookup_table: (dict, str), optional
If a dict, dictionary containing the lookup_table, distance_array,
(distance) prior_array, and reference_distance used to construct
the table.
If a string the name of a file containing these quantities.
The lookup table is stored after construction in either the
provided string or a default location:
'.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
"""
def __init__(self, interferometers, waveform_generator, priors,
weights=None, linear_matrix=None, quadratic_matrix=None,
distance_marginalization=False, phase_marginalization=False):
distance_marginalization=False, phase_marginalization=False,
distance_marginalization_lookup_table=None):
GravitationalWaveTransient.__init__(
self, interferometers=interferometers,
waveform_generator=waveform_generator, priors=priors,
distance_marginalization=distance_marginalization,
phase_marginalization=phase_marginalization)
phase_marginalization=phase_marginalization,
distance_marginalization_lookup_table=distance_marginalization_lookup_table)
self.time_samples = np.arange(
self.priors['geocent_time'].minimum - 0.045,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment