diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
index 4f45928ce584d85dd031e4f21130cd4a58832f3c..caa03ede669f3161a39443b934328dceb98d13aa 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc3.py
@@ -425,6 +425,10 @@ class Pymc3(MCMCSampler):
         # set the prior
         self.set_prior()
 
+        # get the step method keyword arguments
+        step_kwargs = self.kwargs.pop('step_kwargs')
+        nuts_kwargs = self.kwargs.pop('nuts_kwargs')
+
         # set the step method
         if isinstance(self.step_method, (dict, OrderedDict)):
             # create list of step methods (any not given will default to NUTS)
@@ -435,19 +439,59 @@ class Pymc3(MCMCSampler):
                     if isinstance(self.step_method[key], list):
                         for sms in self.step_method[key]:
                             curmethod = sms.lower()
-                            self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
+                            args = {}
+                            if curmethod == 'nuts':
+                                if nuts_kwargs is not None:
+                                    args = nuts_kwargs
+                                else:
+                                    args = step_kwargs.get('nuts', {})
+                            else:
+                                args = step_kwargs.get(curmethod, {})
+                            self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
                     else:
                         curmethod = self.step_method[key].lower()
-                        self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
+                        args = {}
+                        if curmethod == 'nuts':
+                            if nuts_kwargs is not None:
+                                args = nuts_kwargs
+                            else:
+                                args = step_kwargs.get('nuts', {})
+                        else:
+                            args = step_kwargs.get(curmethod, {})
+                        self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
         else:
             with self.pymc3_model:
                 # check for a compound step list
                 if isinstance(self.step_method, list):
                     compound = []
                     for sms in self.step_method:
-                        compound.append(pymc3.__dict__[step_methods[sms.lower()]]())
+                        curmethod = sms.lower()
+                        args = {}
+                        if curmethod == 'nuts':
+                            if nuts_kwargs is not None:
+                                args = nuts_kwargs
+                            else:
+                                args = step_kwargs.get('nuts', {})
+                        else:
+                            args = step_kwargs.get(curmethod, {})
+                        compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
                 else:
-                    self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method.lower()]]()
+                    self.kwargs['step'] = None
+                    if self.step_method is not None:
+                        curmethod = self.step_method.lower()
+                        args = {}
+                        if curmethod == 'nuts':
+                            if nuts_kwargs is not None:
+                                args = nuts_kwargs
+                            else:
+                                args = step_kwargs.get('nuts', {})
+                        else:
+                            args = step_kwargs.get(curmethod, {})
+                        self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
+                    else:
+                        # re-add step_kwargs and nuts_kwargs if no step methods are set
+                        self.kwargs['nuts_kwargs'] = nuts_kwargs
+                        self.kwargs['step_kwargs'] = step_kwargs
 
         # if a custom log_likelihood function requires a `sampler` argument
         # then use that log_likelihood function, with the assumption that it