diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index 4f45928ce584d85dd031e4f21130cd4a58832f3c..caa03ede669f3161a39443b934328dceb98d13aa 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -425,6 +425,10 @@ class Pymc3(MCMCSampler): # set the prior self.set_prior() + # get the step method keyword arguments + step_kwargs = self.kwargs.pop('step_kwargs') + nuts_kwargs = self.kwargs.pop('nuts_kwargs') + # set the step method if isinstance(self.step_method, (dict, OrderedDict)): # create list of step methods (any not given will default to NUTS) @@ -435,19 +439,59 @@ class Pymc3(MCMCSampler): if isinstance(self.step_method[key], list): for sms in self.step_method[key]: curmethod = sms.lower() - self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]])) + args = {} + if curmethod == 'nuts': + if nuts_kwargs is not None: + args = nuts_kwargs + else: + args = step_kwargs.get('nuts', {}) + else: + args = step_kwargs.get(curmethod, {}) + self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args)) else: curmethod = self.step_method[key].lower() - self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]])) + args = {} + if curmethod == 'nuts': + if nuts_kwargs is not None: + args = nuts_kwargs + else: + args = step_kwargs.get('nuts', {}) + else: + args = step_kwargs.get(curmethod, {}) + self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args)) else: with self.pymc3_model: # check for a compound step list if isinstance(self.step_method, list): compound = [] for sms in self.step_method: - compound.append(pymc3.__dict__[step_methods[sms.lower()]]()) + curmethod = sms.lower() + args = {} + if curmethod == 'nuts': + if nuts_kwargs is not None: + args = nuts_kwargs + else: + args = step_kwargs.get('nuts', {}) + else: + args = step_kwargs.get(curmethod, {}) + compound.append(pymc3.__dict__[step_methods[curmethod]](**args)) else: - self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method.lower()]]() + self.kwargs['step'] = None + if self.step_method is not None: + curmethod = self.step_method.lower() + args = {} + if curmethod == 'nuts': + if nuts_kwargs is not None: + args = nuts_kwargs + else: + args = step_kwargs.get('nuts', {}) + else: + args = step_kwargs.get(curmethod, {}) + self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args) + else: + # re-add step_kwargs and nuts_kwargs if no step methods are set + self.kwargs['nuts_kwargs'] = nuts_kwargs + self.kwargs['step_kwargs'] = step_kwargs # if a custom log_likelihood function requires a `sampler` argument # then use that log_likelihood function, with the assumption that it