Skip to content
Snippets Groups Projects
Commit c6a5c272 authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'simplify-logic-of-extra-parameters' into 'master'

Minor change to how non standard parameters are checked

See merge request Monash/tupak!43
parents c90b9651 4a037c5d
No related branches found
No related tags found
1 merge request!43Minor change to how non standard parameters are checked
Pipeline #
......@@ -198,6 +198,7 @@ class TestFillPrior(unittest.TestCase):
def setUp(self):
self.likelihood = Mock()
self.likelihood.parameters = dict(a=0, b=0, c=0, d=0, asdf=0, ra=1)
self.likelihood.non_standard_sampling_parameter_keys = dict(t=8)
self.priors = dict(a=1, b=1.1, c='string', d=tupak.prior.Uniform(0, 1))
self.priors = tupak.prior.fill_priors(self.priors, self.likelihood)
......
......@@ -433,7 +433,7 @@ def parse_keys_to_parameters(keys):
return parameters
def fill_priors(prior, likelihood, parameters=None):
def fill_priors(prior, likelihood):
"""
Fill dictionary of priors based on required parameters of likelihood
......@@ -446,9 +446,9 @@ def fill_priors(prior, likelihood, parameters=None):
dictionary of prior objects and floats
likelihood: tupak.likelihood.GravitationalWaveTransient instance
Used to infer the set of parameters to fill the prior with
parameters: list
list of parameters to be sampled in, this can override the default
priors for the waveform generator
Note: if `likelihood` has `non_standard_sampling_parameter_keys`, then this
will set-up default priors for those as well.
Returns
-------
......@@ -470,8 +470,8 @@ def fill_priors(prior, likelihood, parameters=None):
missing_keys = set(likelihood.parameters) - set(prior.keys())
if parameters is not None:
for parameter in parameters:
if getattr(likelihood, 'non_standard_sampling_parameter_keys', None) is not None:
for parameter in likelihood.non_standard_sampling_parameter_keys:
prior[parameter] = create_default_prior(parameter)
for missing_key in missing_keys:
......
......@@ -444,7 +444,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if priors is None:
priors = dict()
priors = fill_priors(priors, likelihood, parameters=likelihood.non_standard_sampling_parameter_keys)
priors = fill_priors(priors, likelihood)
tupak.prior.write_priors_to_file(priors, outdir)
if implemented_samplers.__contains__(sampler.title()):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment