diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index 6da8c8363c1ea56b92ff3561a1afd3eab40f6dfe..c5d2b46bae88404318398944788497c1697fe5a2 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -53,6 +53,14 @@ class GravitationalWaveTransient(likelihood.Likelihood): This is done analytically using a Bessel function. priors: dict, optional If given, used in the distance and phase marginalization. + distance_marginalization_lookup_table: (dict, str), optional + If a dict, dictionary containing the lookup_table, distance_array, + (distance) prior_array, and reference_distance used to construct + the table. + If a string the name of a file containing these quantities. + The lookup table is stored after construction in either the + provided string or a default location: + '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz' Returns ------- @@ -62,8 +70,10 @@ class GravitationalWaveTransient(likelihood.Likelihood): """ - def __init__(self, interferometers, waveform_generator, time_marginalization=False, distance_marginalization=False, - phase_marginalization=False, priors=None): + def __init__(self, interferometers, waveform_generator, + time_marginalization=False, distance_marginalization=False, + phase_marginalization=False, priors=None, + distance_marginalization_lookup_table=None): self.waveform_generator = waveform_generator likelihood.Likelihood.__init__(self, dict()) @@ -79,7 +89,8 @@ class GravitationalWaveTransient(likelihood.Likelihood): phase_marginalization=self.phase_marginalization, distance_marginalization=self.distance_marginalization, waveform_arguments=waveform_generator.waveform_arguments, - frequency_domain_source_model=str(waveform_generator.frequency_domain_source_model)) + frequency_domain_source_model=str( + waveform_generator.frequency_domain_source_model)) if self.time_marginalization: self._check_prior_is_set(key='geocent_time') @@ -93,10 +104,16 @@ class GravitationalWaveTransient(likelihood.Likelihood): priors['phase'] = float(0) if self.distance_marginalization: + self._lookup_table_filename = None self._check_prior_is_set(key='luminosity_distance') - self._distance_array = np.linspace(self.priors['luminosity_distance'].minimum, - self.priors['luminosity_distance'].maximum, int(1e4)) - self._setup_distance_marginalization() + self._distance_array = np.linspace( + self.priors['luminosity_distance'].minimum, + self.priors['luminosity_distance'].maximum, int(1e4)) + self.distance_prior_array = np.array( + [self.priors['luminosity_distance'].prob(distance) + for distance in self._distance_array]) + self._setup_distance_marginalization( + distance_marginalization_lookup_table) priors['luminosity_distance'] = float(self._ref_dist) def __repr__(self): @@ -292,68 +309,74 @@ class GravitationalWaveTransient(likelihood.Likelihood): return np.hstack((-np.logspace(3, -3, self._dist_margd_loglikelihood_array.shape[1] / 2), np.logspace(-3, 10, self._dist_margd_loglikelihood_array.shape[1] / 2))) - def _setup_distance_marginalization(self): - self._create_lookup_table() + def _setup_distance_marginalization(self, lookup_table=None): + if isinstance(lookup_table, str) or lookup_table is None: + self.cached_lookup_table_filename = lookup_table + lookup_table = self.load_lookup_table( + self.cached_lookup_table_filename) + if isinstance(lookup_table, dict): + if self._test_cached_lookup_table(lookup_table): + self._dist_margd_loglikelihood_array = lookup_table[ + 'lookup_table'] + else: + self._create_lookup_table() + else: + self._create_lookup_table() self._interp_dist_margd_loglikelihood = UnsortedInterp2d( self._rho_mf_ref_array, self._rho_opt_ref_array, self._dist_margd_loglikelihood_array) @property def cached_lookup_table_filename(self): - dmin = self._distance_array[0] - dmax = self._distance_array[-1] - n = len(self._distance_array) - cached_lookup_table_filename = ( - '.distance_marginalization_lookup_dmin{}_dmax{}_n{}_v1.npy' - .format(dmin, dmax, n)) - return cached_lookup_table_filename - - @property - def cached_lookup_table(self): - """ Reads in the cached lookup table - - Returns - ------- - cached_lookup_table: np.ndarray - Columns are _distance_array, distance_prior_array, - dist_marged_log_l_tc_array. This is only returned if the file - exists and the first two columns match the equivalent values - stored on disk. - - """ - - if os.path.exists(self.cached_lookup_table_filename): - loaded_file = np.load(self.cached_lookup_table_filename) + if self._lookup_table_filename is None: + dmin = self._distance_array[0] + dmax = self._distance_array[-1] + n = len(self._distance_array) + self._lookup_table_filename = ( + '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz' + .format(dmin, dmax, n)) + return self._lookup_table_filename + + @cached_lookup_table_filename.setter + def cached_lookup_table_filename(self, filename): + if isinstance(filename, str): + if filename[-4:] != '.npz': + filename += '.npz' + self._lookup_table_filename = filename + + def load_lookup_table(self, filename): + if os.path.exists(filename): + loaded_file = dict(np.load(filename)) if self._test_cached_lookup_table(loaded_file): + logger.info('Loaded distance marginalisation lookup table from ' + '{}.'.format(filename)) return loaded_file + else: + logger.info('Loaded distance marginalisation lookup table does ' + 'not match prior') + return None + elif isinstance(filename, str): + logger.info('Distance marginalisation file {} does not ' + 'exist'.format(filename)) + return None else: return None - @cached_lookup_table.setter - def cached_lookup_table(self, lookup_table): - np.save(self.cached_lookup_table_filename, lookup_table) + def cache_lookup_table(self): + np.savez(self.cached_lookup_table_filename, + distance_array=self._distance_array, + prior_array=self.distance_prior_array, + lookup_table=self._dist_margd_loglikelihood_array, + reference_distance=self._ref_dist) - def _test_cached_lookup_table(self, lookup_table): - cond_a = np.all(self._distance_array == lookup_table[0]) - cond_b = np.all(self.distance_prior_array == lookup_table[1]) - if cond_a and cond_b: - return True + def _test_cached_lookup_table(self, loaded_file): + cond_a = np.all(self._distance_array == loaded_file['distance_array']) + cond_b = np.all(self.distance_prior_array == loaded_file['prior_array']) + cond_c = self._ref_dist == loaded_file['reference_distance'] + return all([cond_a, cond_b, cond_c]) def _create_lookup_table(self): """ Make the lookup table """ - - self.distance_prior_array = np.array( - [self.priors['luminosity_distance'].prob(distance) - for distance in self._distance_array]) - - # Check if a cached lookup table exists in file - cached_lookup_table = self.cached_lookup_table - if cached_lookup_table is not None: - self._dist_margd_loglikelihood_array = cached_lookup_table[-1] - logger.info("Using the cached lookup table {}" - .format(os.path.abspath(self.cached_lookup_table_filename))) - return - logger.info('Building lookup table for distance marginalisation.') self._dist_margd_loglikelihood_array = np.zeros((400, 800)) @@ -370,10 +393,7 @@ class GravitationalWaveTransient(likelihood.Likelihood): log_norm = logsumexp(0. / self._distance_array - 0. / self._distance_array ** 2., b=self.distance_prior_array * self._delta_distance) self._dist_margd_loglikelihood_array -= log_norm - self.cached_lookup_table = np.array([ - self._distance_array, - self.distance_prior_array, - self._dist_margd_loglikelihood_array]) + self.cache_lookup_table() def _setup_phase_marginalization(self): self._bessel_function_interped = interp1d( @@ -513,16 +533,26 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): array, or the array itself. priors: dict, bilby.prior.PriorDict A dictionary of priors containing at least the geocent_time prior + distance_marginalization_lookup_table: (dict, str), optional + If a dict, dictionary containing the lookup_table, distance_array, + (distance) prior_array, and reference_distance used to construct + the table. + If a string the name of a file containing these quantities. + The lookup table is stored after construction in either the + provided string or a default location: + '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz' """ def __init__(self, interferometers, waveform_generator, priors, weights=None, linear_matrix=None, quadratic_matrix=None, - distance_marginalization=False, phase_marginalization=False): + distance_marginalization=False, phase_marginalization=False, + distance_marginalization_lookup_table=None): GravitationalWaveTransient.__init__( self, interferometers=interferometers, waveform_generator=waveform_generator, priors=priors, distance_marginalization=distance_marginalization, - phase_marginalization=phase_marginalization) + phase_marginalization=phase_marginalization, + distance_marginalization_lookup_table=distance_marginalization_lookup_table) self.time_samples = np.arange( self.priors['geocent_time'].minimum - 0.045,