From 6e32b9961002f11fb026399a7a4784d1c0b54071 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Sun, 20 May 2018 15:01:49 +1000
Subject: [PATCH] Minor change to how non standard parameters are checked

Previously, `parameters` in `fill_prior` was only ever used to fill in
the non-standard parameters if sampling in something other than the
defaults. This directy checks that. The advantage is a cleaner logic and
we no longer assume the likelihood has a
`non_standard_sampling_parameter_keys` attribute.
---
 tupak/prior.py   | 12 ++++++------
 tupak/sampler.py |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/tupak/prior.py b/tupak/prior.py
index a1c8665ac..4546075d4 100644
--- a/tupak/prior.py
+++ b/tupak/prior.py
@@ -433,7 +433,7 @@ def parse_keys_to_parameters(keys):
     return parameters
 
 
-def fill_priors(prior, likelihood, parameters=None):
+def fill_priors(prior, likelihood):
     """
     Fill dictionary of priors based on required parameters of likelihood
 
@@ -446,9 +446,9 @@ def fill_priors(prior, likelihood, parameters=None):
         dictionary of prior objects and floats
     likelihood: tupak.likelihood.GravitationalWaveTransient instance
         Used to infer the set of parameters to fill the prior with
-    parameters: list
-        list of parameters to be sampled in, this can override the default
-        priors for the waveform generator
+
+    Note: if `likelihood` has `non_standard_sampling_parameter_keys`, then this
+    will set-up default priors for those as well.
 
     Returns
     -------
@@ -470,8 +470,8 @@ def fill_priors(prior, likelihood, parameters=None):
 
     missing_keys = set(likelihood.parameters) - set(prior.keys())
 
-    if parameters is not None:
-        for parameter in parameters:
+    if getattr(likelihood, 'non_standard_sampling_parameter_keys', None) is not None:
+        for parameter in likelihood.non_standard_sampling_parameter_keys:
             prior[parameter] = create_default_prior(parameter)
 
     for missing_key in missing_keys:
diff --git a/tupak/sampler.py b/tupak/sampler.py
index 805cd87bf..80f87338f 100644
--- a/tupak/sampler.py
+++ b/tupak/sampler.py
@@ -444,7 +444,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
 
     if priors is None:
         priors = dict()
-    priors = fill_priors(priors, likelihood, parameters=likelihood.non_standard_sampling_parameter_keys)
+    priors = fill_priors(priors, likelihood)
     tupak.prior.write_priors_to_file(priors, outdir)
 
     if implemented_samplers.__contains__(sampler.title()):
-- 
GitLab