diff --git a/tupak/prior.py b/tupak/prior.py index a1c8665acae05f56b7605fbce4357b9cf188da2f..4546075d44ff6958d69629b2e552546c21511ab4 100644 --- a/tupak/prior.py +++ b/tupak/prior.py @@ -433,7 +433,7 @@ def parse_keys_to_parameters(keys): return parameters -def fill_priors(prior, likelihood, parameters=None): +def fill_priors(prior, likelihood): """ Fill dictionary of priors based on required parameters of likelihood @@ -446,9 +446,9 @@ def fill_priors(prior, likelihood, parameters=None): dictionary of prior objects and floats likelihood: tupak.likelihood.GravitationalWaveTransient instance Used to infer the set of parameters to fill the prior with - parameters: list - list of parameters to be sampled in, this can override the default - priors for the waveform generator + + Note: if `likelihood` has `non_standard_sampling_parameter_keys`, then this + will set-up default priors for those as well. Returns ------- @@ -470,8 +470,8 @@ def fill_priors(prior, likelihood, parameters=None): missing_keys = set(likelihood.parameters) - set(prior.keys()) - if parameters is not None: - for parameter in parameters: + if getattr(likelihood, 'non_standard_sampling_parameter_keys', None) is not None: + for parameter in likelihood.non_standard_sampling_parameter_keys: prior[parameter] = create_default_prior(parameter) for missing_key in missing_keys: diff --git a/tupak/sampler.py b/tupak/sampler.py index 805cd87bf10b8e695f66a1d3001af36fb4e60c97..80f87338f07eff9a955e2f9e82c2c58f0110b1d6 100644 --- a/tupak/sampler.py +++ b/tupak/sampler.py @@ -444,7 +444,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', if priors is None: priors = dict() - priors = fill_priors(priors, likelihood, parameters=likelihood.non_standard_sampling_parameter_keys) + priors = fill_priors(priors, likelihood) tupak.prior.write_priors_to_file(priors, outdir) if implemented_samplers.__contains__(sampler.title()):