Skip to content
Snippets Groups Projects
Commit 159c2141 authored by Colm Talbot's avatar Colm Talbot
Browse files

modify fill_prior and logic for testing overlapping parameters

parent 0adf8200
No related branches found
No related tags found
1 merge request!31Change sampled parameters
......@@ -306,7 +306,7 @@ class Interped(Prior):
self.xx = np.linspace(self.minimum, self.maximum, len(xx))
self.yy = all_interpolated(self.xx)
if np.trapz(self.yy, self.xx) != 1:
logging.info('Supplied PDF is not normalised, normalising.')
logging.info('Supplied PDF for {} is not normalised, normalising.'.format(self.name))
self.yy /= np.trapz(self.yy, self.xx)
self.YY = cumtrapz(self.yy, self.xx, initial=0)
# Need last element of cumulative distribution to be exactly one.
......@@ -432,7 +432,7 @@ def parse_keys_to_parameters(keys):
return parameters
def fill_priors(prior, likelihood):
def fill_priors(prior, likelihood, parameters=None):
"""
Fill dictionary of priors based on required parameters of likelihood
......@@ -445,6 +445,9 @@ def fill_priors(prior, likelihood):
dictionary of prior objects and floats
likelihood: tupak.likelihood.Likelihood instance
Used to infer the set of parameters to fill the prior with
parameters: list
list of parameters to be sampled in, this can override the default
priors for the waveform generator
Returns
-------
......@@ -466,6 +469,10 @@ def fill_priors(prior, likelihood):
missing_keys = set(likelihood.parameters) - set(prior.keys())
if parameters is not None:
for parameter in parameters:
prior[parameter] = create_default_prior(parameter)
for missing_key in missing_keys:
default_prior = create_default_prior(missing_key)
if default_prior is None:
......@@ -475,7 +482,11 @@ def fill_priors(prior, likelihood):
" not be sampled and may cause an error."
.format(missing_key, set_val))
else:
prior[missing_key] = default_prior
if not test_redundancy(missing_key, prior):
prior[missing_key] = default_prior
for key in prior:
test_redundancy(key, prior)
return prior
......@@ -499,12 +510,13 @@ def test_redundancy(key, prior):
redundant = False
mass_parameters = {'mass_1', 'mass_2', 'chirp_mass', 'total_mass', 'mass_ratio', 'symmetric_mass_ratio'}
spin_magnitude_parameters = {'a_1', 'a_2'}
spin_tilt_parameters = {'tilt_1', 'tilt_2', 'cos_tilt_1', 'cos_tilt_2'}
spin_tilt_1_parameters = {'tilt_1', 'cos_tilt_1'}
spin_tilt_2_parameters = {'tilt_2', 'cos_tilt_2'}
spin_azimuth_parameters = {'phi_1', 'phi_2', 'phi_12', 'phi_jl'}
inclination_parameters = {'iota', 'cos_iota'}
distance_parameters = {'luminosity_distance', 'comoving_distance', 'redshift'}
for parameter_set in [mass_parameters, spin_magnitude_parameters, spin_tilt_parameters, spin_azimuth_parameters]:
for parameter_set in [mass_parameters, spin_magnitude_parameters, spin_azimuth_parameters]:
if key in parameter_set:
if len(parameter_set.intersection(prior.keys())) > 2:
redundant = True
......@@ -514,7 +526,7 @@ def test_redundancy(key, prior):
elif len(parameter_set.intersection(prior.keys())) == 2:
redundant = True
break
for parameter_set in [inclination_parameters, distance_parameters]:
for parameter_set in [inclination_parameters, distance_parameters, spin_tilt_1_parameters, spin_tilt_2_parameters]:
if key in parameter_set:
if len(parameter_set.intersection(prior.keys())) > 1:
redundant = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment