Skip to content
Snippets Groups Projects

Clean up the logic of fill_priors and close #58

Merged Gregory Ashton requested to merge clean-up-logic-up-fill-priors into master
1 file
+ 26
16
Compare changes
  • Side-by-side
  • Inline
+ 26
16
@@ -403,6 +403,8 @@ def create_default_prior(name):
if name in default_priors.keys():
prior = default_priors[name]
else:
logging.info(
"No default prior found for variable {}.".format(name))
prior = None
return prior
@@ -430,40 +432,48 @@ def parse_keys_to_parameters(keys):
def fill_priors(prior, waveform_generator):
"""
Fill dictionary of priors based on required parameters for waveform generator
Fill dictionary of priors based on required parameters of waveform generator
Any floats in prior will be converted to delta function prior. Any
required, non-specified parameters will use the default.
Any floats in prior will be converted to delta function prior.
Any required, non-specified parameters will use the default.
Parameters
----------
prior: dict
dictionary of prior objects and floats
waveform_generator: WaveformGenerator
waveform generator to be used for inference
Returns
-------
prior: dict
The filled prior dictionary
"""
bad_keys = []
for key in prior:
if isinstance(prior[key], Prior):
continue
elif isinstance(prior[key], float) or isinstance(prior[key], int):
prior[key] = DeltaFunction(prior[key])
logging.info("{} converted to delta function prior.".format(key))
logging.info(
"{} converted to delta function prior.".format(key))
else:
logging.warning("{} cannot be converted to delta function prior.".format(key))
logging.warning("If required the default prior will be used.")
bad_keys.append(key)
logging.info(
"{} cannot be converted to delta function prior.".format(key))
missing_keys = set(waveform_generator.parameters) - set(prior.keys())
for missing_key in missing_keys:
prior[missing_key] = create_default_prior(missing_key)
if prior[missing_key] is None:
logging.warning("No default prior found for unspecified variable {}.".format(missing_key))
logging.warning("This variable will NOT be sampled.")
bad_keys.append(missing_key)
for key in bad_keys:
prior.pop(key)
default_prior = create_default_prior(missing_key)
if default_prior is None:
set_val = waveform_generator.parameters[missing_key]
logging.warning(
"Parameter {} has no default prior and is set to {}, this will"
" not be sampled and may cause an error."
.format(missing_key, set_val))
else:
prior[missing_key] = default_prior
return prior
Loading