From 52cdf6fc4b947270a872fca3c506c3d8f448fa8c Mon Sep 17 00:00:00 2001 From: Matthew Pitkin <matthew.pitkin@ligo.org> Date: Tue, 27 Nov 2018 15:59:44 +0000 Subject: [PATCH] pymc3.py: explicitly set step method arguments - PyMC3 only applies the step_kwargs and nuts_kwargs arguments to the step methods if they are automaticallu allocated. Otherwise, it assumes you have already initialised the method with the required keywork arguments. This was not being done in bilby, so this patch fixes it and does explicitly initialise the step methods with values passed to step_kwargs or nuts_kwargs as required. --- bilby/core/sampler/pymc3.py | 52 ++++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index 4f45928c..caa03ede 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -425,6 +425,10 @@ class Pymc3(MCMCSampler): # set the prior self.set_prior() + # get the step method keyword arguments + step_kwargs = self.kwargs.pop('step_kwargs') + nuts_kwargs = self.kwargs.pop('nuts_kwargs') + # set the step method if isinstance(self.step_method, (dict, OrderedDict)): # create list of step methods (any not given will default to NUTS) @@ -435,19 +439,59 @@ class Pymc3(MCMCSampler): if isinstance(self.step_method[key], list): for sms in self.step_method[key]: curmethod = sms.lower() - self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]])) + args = {} + if curmethod == 'nuts': + if nuts_kwargs is not None: + args = nuts_kwargs + else: + args = step_kwargs.get('nuts', {}) + else: + args = step_kwargs.get(curmethod, {}) + self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args)) else: curmethod = self.step_method[key].lower() - self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]])) + args = {} + if curmethod == 'nuts': + if nuts_kwargs is not None: + args = nuts_kwargs + else: + args = step_kwargs.get('nuts', {}) + else: + args = step_kwargs.get(curmethod, {}) + self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args)) else: with self.pymc3_model: # check for a compound step list if isinstance(self.step_method, list): compound = [] for sms in self.step_method: - compound.append(pymc3.__dict__[step_methods[sms.lower()]]()) + curmethod = sms.lower() + args = {} + if curmethod == 'nuts': + if nuts_kwargs is not None: + args = nuts_kwargs + else: + args = step_kwargs.get('nuts', {}) + else: + args = step_kwargs.get(curmethod, {}) + compound.append(pymc3.__dict__[step_methods[curmethod]](**args)) else: - self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method.lower()]]() + self.kwargs['step'] = None + if self.step_method is not None: + curmethod = self.step_method.lower() + args = {} + if curmethod == 'nuts': + if nuts_kwargs is not None: + args = nuts_kwargs + else: + args = step_kwargs.get('nuts', {}) + else: + args = step_kwargs.get(curmethod, {}) + self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args) + else: + # re-add step_kwargs and nuts_kwargs if no step methods are set + self.kwargs['nuts_kwargs'] = nuts_kwargs + self.kwargs['step_kwargs'] = step_kwargs # if a custom log_likelihood function requires a `sampler` argument # then use that log_likelihood function, with the assumption that it -- GitLab