From 52cdf6fc4b947270a872fca3c506c3d8f448fa8c Mon Sep 17 00:00:00 2001
From: Matthew Pitkin <matthew.pitkin@ligo.org>
Date: Tue, 27 Nov 2018 15:59:44 +0000
Subject: [PATCH] pymc3.py: explicitly set step method arguments  - PyMC3 only
 applies the step_kwargs and nuts_kwargs arguments    to the step methods if
 they are automaticallu allocated. Otherwise,    it assumes you have already
 initialised the method with the    required keywork arguments. This was not
 being done in bilby, so    this patch fixes it and does explicitly initialise
 the step methods    with values passed to step_kwargs or nuts_kwargs as
 required.

---
 bilby/core/sampler/pymc3.py | 52 ++++++++++++++++++++++++++++++++++---
 1 file changed, 48 insertions(+), 4 deletions(-)

diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
index 4f45928c..caa03ede 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc3.py
@@ -425,6 +425,10 @@ class Pymc3(MCMCSampler):
         # set the prior
         self.set_prior()
 
+        # get the step method keyword arguments
+        step_kwargs = self.kwargs.pop('step_kwargs')
+        nuts_kwargs = self.kwargs.pop('nuts_kwargs')
+
         # set the step method
         if isinstance(self.step_method, (dict, OrderedDict)):
             # create list of step methods (any not given will default to NUTS)
@@ -435,19 +439,59 @@ class Pymc3(MCMCSampler):
                     if isinstance(self.step_method[key], list):
                         for sms in self.step_method[key]:
                             curmethod = sms.lower()
-                            self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
+                            args = {}
+                            if curmethod == 'nuts':
+                                if nuts_kwargs is not None:
+                                    args = nuts_kwargs
+                                else:
+                                    args = step_kwargs.get('nuts', {})
+                            else:
+                                args = step_kwargs.get(curmethod, {})
+                            self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
                     else:
                         curmethod = self.step_method[key].lower()
-                        self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
+                        args = {}
+                        if curmethod == 'nuts':
+                            if nuts_kwargs is not None:
+                                args = nuts_kwargs
+                            else:
+                                args = step_kwargs.get('nuts', {})
+                        else:
+                            args = step_kwargs.get(curmethod, {})
+                        self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
         else:
             with self.pymc3_model:
                 # check for a compound step list
                 if isinstance(self.step_method, list):
                     compound = []
                     for sms in self.step_method:
-                        compound.append(pymc3.__dict__[step_methods[sms.lower()]]())
+                        curmethod = sms.lower()
+                        args = {}
+                        if curmethod == 'nuts':
+                            if nuts_kwargs is not None:
+                                args = nuts_kwargs
+                            else:
+                                args = step_kwargs.get('nuts', {})
+                        else:
+                            args = step_kwargs.get(curmethod, {})
+                        compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
                 else:
-                    self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method.lower()]]()
+                    self.kwargs['step'] = None
+                    if self.step_method is not None:
+                        curmethod = self.step_method.lower()
+                        args = {}
+                        if curmethod == 'nuts':
+                            if nuts_kwargs is not None:
+                                args = nuts_kwargs
+                            else:
+                                args = step_kwargs.get('nuts', {})
+                        else:
+                            args = step_kwargs.get(curmethod, {})
+                        self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
+                    else:
+                        # re-add step_kwargs and nuts_kwargs if no step methods are set
+                        self.kwargs['nuts_kwargs'] = nuts_kwargs
+                        self.kwargs['step_kwargs'] = step_kwargs
 
         # if a custom log_likelihood function requires a `sampler` argument
         # then use that log_likelihood function, with the assumption that it
-- 
GitLab