diff --git a/tupak/prior.py b/tupak/prior.py index bdce5908b9217ee95778bcd15636e359bb56286f..bdcc435c4a47be3622bfbcfdfb3a599f72fc7666 100644 --- a/tupak/prior.py +++ b/tupak/prior.py @@ -403,6 +403,8 @@ def create_default_prior(name): if name in default_priors.keys(): prior = default_priors[name] else: + logging.info( + "No default prior found for variable {}.".format(name)) prior = None return prior @@ -430,40 +432,48 @@ def parse_keys_to_parameters(keys): def fill_priors(prior, waveform_generator): """ - Fill dictionary of priors based on required parameters for waveform generator + Fill dictionary of priors based on required parameters of waveform generator + + Any floats in prior will be converted to delta function prior. Any + required, non-specified parameters will use the default. - Any floats in prior will be converted to delta function prior. - Any required, non-specified parameters will use the default. Parameters ---------- prior: dict dictionary of prior objects and floats waveform_generator: WaveformGenerator waveform generator to be used for inference + + Returns + ------- + prior: dict + The filled prior dictionary + """ - bad_keys = [] + for key in prior: if isinstance(prior[key], Prior): continue elif isinstance(prior[key], float) or isinstance(prior[key], int): prior[key] = DeltaFunction(prior[key]) - logging.info("{} converted to delta function prior.".format(key)) + logging.info( + "{} converted to delta function prior.".format(key)) else: - logging.warning("{} cannot be converted to delta function prior.".format(key)) - logging.warning("If required the default prior will be used.") - bad_keys.append(key) + logging.info( + "{} cannot be converted to delta function prior.".format(key)) missing_keys = set(waveform_generator.parameters) - set(prior.keys()) for missing_key in missing_keys: - prior[missing_key] = create_default_prior(missing_key) - if prior[missing_key] is None: - logging.warning("No default prior found for unspecified variable {}.".format(missing_key)) - logging.warning("This variable will NOT be sampled.") - bad_keys.append(missing_key) - - for key in bad_keys: - prior.pop(key) + default_prior = create_default_prior(missing_key) + if default_prior is None: + set_val = waveform_generator.parameters[missing_key] + logging.warning( + "Parameter {} has no default prior and is set to {}, this will" + " not be sampled and may cause an error." + .format(missing_key, set_val)) + else: + prior[missing_key] = default_prior return prior