diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
index 824f294ce47349582aa3ec702cc7ff0ff1639781..73f5f3edb13aa9715a67b11b0a3a945667dbfefe 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc3.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, print_function
 
 from collections import OrderedDict
+from distutils.version import StrictVersion
 
 import numpy as np
 
@@ -45,8 +46,6 @@ class Pymc3(MCMCSampler):
         'CategoricalGibbsMetropolis'. Note: you cannot provide a PyMC3 step
         method function itself here as it is outside of the model context
         manager.
-    nuts_kwargs: dict
-        Keyword arguments for the NUTS sampler.
     step_kwargs: dict
         Options for steps methods other than NUTS. The dictionary is keyed on
         lowercase step method names with values being dictionaries of keywords
@@ -56,13 +55,27 @@ class Pymc3(MCMCSampler):
 
     default_kwargs = dict(
         draws=500, step=None, init='auto', n_init=200000, start=None, trace=None, chain_idx=0,
-        chains=2, cores=1, tune=500, nuts_kwargs=None, step_kwargs=None, progressbar=True,
-        model=None, random_seed=None, discard_tuned_samples=True,
-        compute_convergence_checks=True)
+        chains=2, cores=1, tune=500, progressbar=True, model=None, random_seed=None,
+        discard_tuned_samples=True, compute_convergence_checks=True, nuts_kwargs=None,
+        step_kwargs=None,
+    )
+
+    default_nuts_kwargs = dict(
+        target_accept=None, max_treedepth=None, step_scale=None, Emax=None,
+        gamma=None, k=None, t0=None, adapt_step_size=None, early_max_treedepth=None,
+        scaling=None, is_cov=None, potential=None,
+    )
+
+    default_kwargs.update(default_nuts_kwargs)
 
     def __init__(self, likelihood, priors, outdir='outdir', label='label',
                  use_ratio=False, plot=False,
                  skip_import_verification=False, **kwargs):
+        # add default step kwargs
+        _, STEP_METHODS, _ = self._import_external_sampler()
+        self.default_step_kwargs = {m.__name__.lower(): None for m in STEP_METHODS}
+        self.default_kwargs.update(self.default_step_kwargs)
+
         super(Pymc3, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
                                     use_ratio=use_ratio, plot=plot,
                                     skip_import_verification=skip_import_verification, **kwargs)
@@ -454,8 +467,35 @@ class Pymc3(MCMCSampler):
             self.set_likelihood()
 
         # get the step method keyword arguments
-        step_kwargs = self.kwargs.pop('step_kwargs')
-        nuts_kwargs = self.kwargs.pop('nuts_kwargs')
+        step_kwargs = self.kwargs.pop("step_kwargs")
+        if step_kwargs is not None:
+            # remove all individual default step kwargs if passed together using
+            # step_kwargs keywords
+            for key in self.default_step_kwargs:
+                self.kwargs.pop(key)
+        else:
+            # remove any None default step keywords and place others in step_kwargs
+            step_kwargs = {}
+            for key in self.default_step_kwargs:
+                if self.kwargs[key] is None:
+                    self.kwargs.pop(key)
+                else:
+                    step_kwargs[key] = self.kwargs.pop(key)
+
+        nuts_kwargs = self.kwargs.pop("nuts_kwargs")
+        if nuts_kwargs is not None:
+            # remove all individual default nuts kwargs if passed together using
+            # nuts_kwargs keywords
+            for key in self.default_nuts_kwargs:
+                self.kwargs.pop(key)
+        else:
+            # remove any None default nuts keywords and place others in nut_kwargs
+            nuts_kwargs = {}
+            for key in self.default_nuts_kwargs:
+                if self.kwargs[key] is None:
+                    self.kwargs.pop(key)
+                else:
+                    nuts_kwargs[key] = self.kwargs.pop(key)
         methodslist = []
 
         # set the step method
@@ -496,13 +536,19 @@ class Pymc3(MCMCSampler):
                         self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
                     else:
                         # re-add step_kwargs if no step methods are set
-                        self.kwargs['step_kwargs'] = step_kwargs
+                        if len(step_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"):
+                            self.kwargs['step_kwargs'] = step_kwargs
 
         # check whether only NUTS step method has been assigned
         if np.all([sm.lower() == 'nuts' for sm in methodslist]):
             # in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
             self.kwargs['step'] = None
-            self.kwargs['nuts_kwargs'] = nuts_kwargs
+
+            if len(nuts_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"):
+                self.kwargs['nuts_kwargs'] = nuts_kwargs
+            elif len(nuts_kwargs) > 0:
+                # add NUTS kwargs to standard kwargs
+                self.kwargs.update(nuts_kwargs)
 
         with self.pymc3_model:
             # perform the sampling
@@ -561,6 +607,10 @@ class Pymc3(MCMCSampler):
             args = {}
         return args, nuts_kwargs
 
+    def _pymc3_version(self):
+        pymc3, _, _ = self._import_external_sampler()
+        return pymc3.__version__
+
     def set_prior(self):
         """
         Set the PyMC3 prior distributions.
diff --git a/test/sampler_test.py b/test/sampler_test.py
index c3c5e4d415e16a2d9f662b47be5b9d1c20bc8538..f03fc97d00265993dd1375825139e213add14ad7 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -695,14 +695,16 @@ class TestPyMC3(unittest.TestCase):
             chains=2,
             cores=1,
             tune=500,
-            nuts_kwargs=None,
-            step_kwargs=None,
             progressbar=True,
             model=None,
+            nuts_kwargs=None,
+            step_kwargs=None,
             random_seed=None,
             discard_tuned_samples=True,
             compute_convergence_checks=True,
         )
+        expected.update(self.sampler.default_nuts_kwargs)
+        expected.update(self.sampler.default_step_kwargs)
         self.assertDictEqual(expected, self.sampler.kwargs)
 
     def test_translate_kwargs(self):
@@ -717,14 +719,16 @@ class TestPyMC3(unittest.TestCase):
             chains=2,
             cores=1,
             tune=500,
-            nuts_kwargs=None,
-            step_kwargs=None,
             progressbar=True,
             model=None,
+            nuts_kwargs=None,
+            step_kwargs=None,
             random_seed=None,
             discard_tuned_samples=True,
             compute_convergence_checks=True,
         )
+        expected.update(self.sampler.default_nuts_kwargs)
+        expected.update(self.sampler.default_step_kwargs)
         self.sampler.kwargs["draws"] = 123
         for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
             new_kwargs = self.sampler.kwargs.copy()