diff --git a/examples/other_examples/linear_regression_pymc3_custom_likelihood.py b/examples/other_examples/linear_regression_pymc3_custom_likelihood.py
index 9535d58d92b71a3152fd39421c55d49938e94157..3b3ab8af138ead0dc1428e1737eb46ed1ea255ca 100644
--- a/examples/other_examples/linear_regression_pymc3_custom_likelihood.py
+++ b/examples/other_examples/linear_regression_pymc3_custom_likelihood.py
@@ -109,11 +109,37 @@ class GaussianLikelihoodPyMC3(tupak.Likelihood):
 # the time, data and signal model
 likelihood = GaussianLikelihoodPyMC3(time, data, sigma, model)
 
+
+# Define a custom prior for one of the parameter for use with PyMC3
+class PriorPyMC3(tupak.core.prior.Prior):
+    def __init__(self, minimum, maximum, name=None, latex_label=None):
+        """
+        Uniform prior with bounds (should be equivalent to tupak.prior.Uniform)
+        """
+
+        tupak.core.prior.Prior.__init__(self, name, latex_label,
+                                        minimum=minimum,
+                                        maximum=maximum)
+
+    def ln_prob(self, sampler=None):
+        """
+        Change ln_prob method to take in a Sampler and return a PyMC3
+        distribution.
+        """
+
+        from tupak.core.sampler import Pymc3
+
+        if not isinstance(sampler, Pymc3):
+            raise ValueError("Sampler is not a tupak Pymc3 sampler object")
+
+        return pm.Uniform(self.name, lower=self.minimum,
+                          upper=self.maximum)
+
 # From hereon, the syntax is exactly equivalent to other tupak examples
 # We make a prior
 priors = {}
 priors['m'] = tupak.core.prior.Uniform(0, 5, 'm')
-priors['c'] = tupak.core.prior.Uniform(-2, 2, 'c')
+priors['c'] = PriorPyMC3(-2, 2, 'c')
 
 # And run sampler
 result = tupak.run_sampler(
diff --git a/tupak/core/prior.py b/tupak/core/prior.py
index 31f04c93ffca06ef0c4051daf4998ac7d44ecc3f..0a415160120948616eddafe3a98621e9685e30db 100644
--- a/tupak/core/prior.py
+++ b/tupak/core/prior.py
@@ -467,30 +467,6 @@ class Prior(object):
             label = self.name
         return label
 
-    def pymc3_prior(self, sampler):
-        """
-        'pymc3_prior' A user defined PyMC3 prior.
-
-        This should be overwritten by each subclass if needed.
-
-        Parameters
-        ----------
-        sampler: `tupak.core.sampler.Sampler`
-            A Sampler class
-
-        Returns
-        -------
-        None
-
-        """
-
-        from tupak.core.sampler import Sampler
-
-        if not isinstance(sampler, Sampler):
-            raise ValueError("'sampler' is not a Sampler class")
-
-        return None
-
 
 class DeltaFunction(Prior):
 
diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py
index 045d1f66776928d6cec5a9e318085031300d22a8..eca15e6c87f9e1470fc657ea59e7127c9a59497a 100644
--- a/tupak/core/sampler.py
+++ b/tupak/core/sampler.py
@@ -957,6 +957,46 @@ class Pymc3(Sampler):
         """
         pass
 
+    def _initialise_parameters(self):
+        """
+        Change `_initialise_parameters()`, so that it does call the `sample`
+        method in the Prior class.
+
+        """
+
+        self.__search_parameter_keys = []
+        self.__fixed_parameter_keys = []
+
+        for key in self.priors:
+            if isinstance(self.priors[key], Prior) \
+                    and self.priors[key].is_fixed is False:
+                self.__search_parameter_keys.append(key)
+            elif isinstance(self.priors[key], Prior) \
+                    and self.priors[key].is_fixed is True:
+                self.__fixed_parameter_keys.append(key)
+
+        logger.info("Search parameters:")
+        for key in self.__search_parameter_keys:
+            logger.info('  {} = {}'.format(key, self.priors[key]))
+        for key in self.__fixed_parameter_keys:
+            logger.info('  {} = {}'.format(key, self.priors[key].peak))
+
+    def _initialise_result(self):
+        """
+        Initialise results within Pymc3 subclass.
+        """
+        result = Result()
+        result.sampler = self.__class__.__name__.lower()
+        result.search_parameter_keys = self.__search_parameter_keys
+        result.fixed_parameter_keys = self.__fixed_parameter_keys
+        result.parameter_labels = [
+            self.priors[k].latex_label for k in
+            self.__search_parameter_keys]
+        result.label = self.label
+        result.outdir = self.outdir
+        result.kwargs = self.kwargs
+        return result
+
     @property
     def kwargs(self):
         """ Ensures that proper keyword arguments are used for the Pymc3 sampler.
@@ -1253,10 +1293,15 @@ class Pymc3(Sampler):
         # set the parameter prior distributions (in the model context manager)
         with self.pymc3_model:
             for key in self.priors:
-                # if the prior contains a pymc3_prior method use that otherwise try
-                # and find the PyMC3 distribution
-                if self.priors[key].pymc3_prior(self) is not None:
-                    self.pymc3_priors[key] = self.priors[key].pymc3_prior(self)
+                # if the prior contains ln_prob method that takes a 'sampler' argument
+                # then try using that
+                lnprobargs = inspect.getargspec(self.priors[key].ln_prob).args
+                if 'sampler' in lnprobargs:
+                    try:
+                        self.pymc3_priors[key] = self.priors[key].ln_prob(sampler=self)
+                    except RuntimeError:
+                        raise RuntimeError(("Problem setting PyMC3 prior for ",
+                            "'{}'".format(key)))
                 else:
                     # use Prior distribution name
                     distname = self.priors[key].__class__.__name__