Skip to content
Snippets Groups Projects
Commit bad30b7c authored by Rhiannon Udall's avatar Rhiannon Udall Committed by Colm Talbot
Browse files

Add identity conversion and generation functions

parent 13f4c77d
No related branches found
No related tags found
1 merge request!1264Add identity conversion and generation functions
Pipeline #649980 passed
......@@ -2556,3 +2556,59 @@ def fill_sample(args):
likelihood.parameters.update(dict(sample).copy())
new_sample = likelihood.generate_posterior_sample_from_marginalized_likelihood()
return tuple((new_sample[key] for key in marginalized_parameters))
def identity_map_conversion(parameters):
"""An identity map conversion function that makes no changes to the parameters,
but returns the correct signature expected by other conversion functions
(e.g. convert_to_lal_binary_black_hole_parameters)"""
return parameters, []
def identity_map_generation(sample, likelihood=None, priors=None, npool=1):
"""An identity map generation function that handles marginalizations, SNRs, etc. correctly,
but does not attempt e.g. conversions in mass or spins
Parameters
==========
sample: dict or pandas.DataFrame
Samples to fill in with extra parameters, this may be either an
injection or posterior samples.
likelihood: bilby.gw.likelihood.GravitationalWaveTransient, optional
GravitationalWaveTransient used for sampling, used for waveform and
likelihood.interferometers.
priors: dict, optional
Dictionary of prior objects, used to fill in non-sampled parameters.
Returns
=======
"""
output_sample = sample.copy()
output_sample = fill_from_fixed_priors(output_sample, priors)
if likelihood is not None:
compute_per_detector_log_likelihoods(
samples=output_sample, likelihood=likelihood, npool=npool)
marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
if len(marginalized_parameters) > 0:
try:
generate_posterior_samples_from_marginalized_likelihood(
samples=output_sample, likelihood=likelihood, npool=npool)
except MarginalizedLikelihoodReconstructionError as e:
logger.warning(
"Marginalised parameter reconstruction failed with message "
"{}. Some parameters may not have the intended "
"interpretation.".format(e)
)
if ("ra" in output_sample.keys() and "dec" in output_sample.keys() and "psi" in output_sample.keys()):
compute_snrs(output_sample, likelihood, npool=npool)
else:
logger.info(
"Skipping SNR computation since samples have insufficient sky location information"
)
return output_sample
......@@ -171,6 +171,26 @@ class TestBasicConversions(unittest.TestCase):
)
self.assertTrue((self.delta_lambda_tilde - delta_lambda_tilde) < 1e-5)
def test_identity_conversion(self):
original_samples = dict(
mass_1=self.mass_1,
mass_2=self.mass_2,
mass_ratio=self.mass_ratio,
total_mass=self.total_mass,
chirp_mass=self.chirp_mass,
symmetric_mass_ratio=self.symmetric_mass_ratio,
cos_angle=self.cos_angle,
angle=self.angle,
lambda_1=self.lambda_1,
lambda_2=self.lambda_2,
lambda_tilde=self.lambda_tilde,
delta_lambda_tilde=self.delta_lambda_tilde
)
identity_samples, blank_list = conversion.identity_map_conversion(original_samples)
assert blank_list == []
for key, val in identity_samples.items():
assert val == self.__dict__[key]
class TestConvertToLALParams(unittest.TestCase):
def setUp(self):
......@@ -509,6 +529,50 @@ class TestGenerateAllParameters(unittest.TestCase):
for key in extra_expected:
self.assertIn(key, converted)
def test_identity_generation_no_likelihood(self):
test_fixed_prior = bilby.core.prior.PriorDict({
"test_param_a": bilby.core.prior.DeltaFunction(0, name="test_param_a"),
"test_param_b": bilby.core.prior.DeltaFunction(1, name="test_param_b")
}
)
output_sample = conversion.identity_map_generation(self.parameters, priors=test_fixed_prior)
assert output_sample.pop("test_param_a") == 0
assert output_sample.pop("test_param_b") == 1
for key, val in self.parameters.items():
assert output_sample.pop(key) == val
assert output_sample == {}
def test_identity_generation_with_likelihood(self):
priors = bilby.gw.prior.BBHPriorDict()
priors["geocent_time"] = bilby.core.prior.Uniform(0.4, 0.6)
self.parameters["time_jitter"] = 0.0
# Note we do *not* switch to azimuth/zenith, because the identity generation function
# is not intended to be capable of that conversion
ifos = bilby.gw.detector.InterferometerList(["H1"])
ifos.set_strain_data_from_power_spectral_densities(duration=1, sampling_frequency=256)
wfg = bilby.gw.waveform_generator.WaveformGenerator(
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole
)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=ifos,
waveform_generator=wfg,
priors=priors,
phase_marginalization=True,
time_marginalization=True,
reference_frame="sky",
)
output_sample = conversion.identity_map_generation(self.parameters, priors=priors, likelihood=likelihood)
extra_expected = [
"phase",
"geocent_time",
"H1_optimal_snr",
"H1_matched_filter_snr",
]
for key in extra_expected:
self.assertIn(key, output_sample)
for key, val in self.parameters.items():
self.assertTrue(output_sample[key] == val)
class TestDistanceTransformations(unittest.TestCase):
def setUp(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment