diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 63f38004cfd5d3ec174b88008c7d97c54b30368a..c67d0c1cae95260edb3b129d6d079f8bdfd5e647 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -16,7 +16,7 @@ from scipy.special import erf, erfinv, xlogy, log1p,\
 from matplotlib.cbook import flatten
 
 # Keep import bilby statement, it is necessary for some eval() statements
-from .utils import BilbyJsonEncoder, decode_bilby_json
+from .utils import BilbyJsonEncoder, decode_bilby_json, infer_parameters_from_function
 from .utils import (
     check_directory_exists_and_if_not_mkdir,
     infer_args_from_method, logger
@@ -173,7 +173,7 @@ class PriorDict(dict):
                 else:
                     module = __name__
                 cls = getattr(import_module(module), cls, cls)
-                if key.lower() == "conversion_function":
+                if key.lower() in ["conversion_function", "condition_func"]:
                     setattr(self, key, cls)
                 elif (cls.__name__ in ['MultivariateGaussianDist',
                                        'MultivariateNormalDist']):
@@ -331,11 +331,10 @@ class PriorDict(dict):
         self.convert_floats_to_delta_functions()
         samples = dict()
         for key in keys:
-            if isinstance(self[key], Prior):
-                if isinstance(self[key], Constraint):
-                    continue
-                else:
-                    samples[key] = self[key].sample(size=size)
+            if isinstance(self[key], Constraint):
+                continue
+            elif isinstance(self[key], Prior):
+                samples[key] = self[key].sample(size=size)
             else:
                 logger.debug('{} not a known prior.'.format(key))
         return samples
@@ -488,6 +487,201 @@ class PriorSet(PriorDict):
         super(PriorSet, self).__init__(dictionary, filename)
 
 
+class ConditionalPriorDict(PriorDict):
+
+    def __init__(self, dictionary=None, filename=None, conversion_function=None):
+        """
+
+        Parameters
+        ----------
+        dictionary: dict
+            See parent class
+        filename: str
+            See parent class
+        """
+        self._conditional_keys = []
+        self._unconditional_keys = []
+        self._rescale_keys = []
+        self._rescale_indexes = []
+        self._least_recently_rescaled_keys = []
+        super(ConditionalPriorDict, self).__init__(
+            dictionary=dictionary, filename=filename,
+            conversion_function=conversion_function
+        )
+        self._resolved = False
+        self._resolve_conditions()
+
+    def _resolve_conditions(self):
+        """
+        Resolves how priors depend on each other and automatically
+        sorts them into the right order.
+        1. All unconditional priors are put in front in arbitrary order
+        2. We loop through all the unsorted conditional priors to find
+        which one can go next
+        3. We repeat step 2 len(self) number of times to make sure that
+        all conditional priors will be sorted in order
+        4. We set the `self._resolved` flag to True if all conditional
+        priors were added in the right order
+        """
+        self._unconditional_keys = [key for key in self.keys() if not hasattr(self[key], 'condition_func')]
+        conditional_keys_unsorted = [key for key in self.keys() if hasattr(self[key], 'condition_func')]
+        self._conditional_keys = []
+        for _ in range(len(self)):
+            for key in conditional_keys_unsorted[:]:
+                if self._check_conditions_resolved(key, self.sorted_keys):
+                    self._conditional_keys.append(key)
+                    conditional_keys_unsorted.remove(key)
+
+        self._resolved = True
+        if len(conditional_keys_unsorted) != 0:
+            self._resolved = False
+
+    def _check_conditions_resolved(self, key, sampled_keys):
+        """ Checks if all required variables have already been sampled so we can sample this key """
+        conditions_resolved = True
+        for k in self[key].required_variables:
+            if k not in sampled_keys:
+                conditions_resolved = False
+        return conditions_resolved
+
+    def sample_subset(self, keys=iter([]), size=None):
+        self.convert_floats_to_delta_functions()
+        subset_dict = ConditionalPriorDict({key: self[key] for key in keys})
+        if not subset_dict._resolved:
+            raise IllegalConditionsException("The current set of priors contains unresolvable conditions.")
+        samples = dict()
+        for key in subset_dict.sorted_keys:
+            if isinstance(self[key], Constraint):
+                continue
+            elif isinstance(self[key], Prior):
+                try:
+                    samples[key] = subset_dict[key].sample(size=size, **subset_dict.get_required_variables(key))
+                except ValueError:
+                    # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw)
+                    # If that is the case, we sample each sample individually.
+                    required_variables = subset_dict.get_required_variables(key)
+                    samples[key] = np.zeros(size)
+                    for i in range(size):
+                        rvars = {key: value[i] for key, value in required_variables.items()}
+                        samples[key][i] = subset_dict[key].sample(**rvars)
+            else:
+                logger.debug('{} not a known prior.'.format(key))
+        return samples
+
+    def get_required_variables(self, key):
+        """ Returns the required variables to sample a given conditional key.
+
+        Parameters
+        ----------
+        key : str
+            Name of the key that we want to know the required variables for
+
+        Returns
+        ----------
+        dict: key/value pairs of the required variables
+        """
+        return {k: self[k].least_recently_sampled for k in getattr(self[key], 'required_variables', [])}
+
+    def prob(self, sample, **kwargs):
+        """
+
+        Parameters
+        ----------
+        sample: dict
+            Dictionary of the samples of which we want to have the probability of
+        kwargs:
+            The keyword arguments are passed directly to `np.product`
+
+        Returns
+        -------
+        float: Joint probability of all individual sample probabilities
+
+        """
+        self._check_resolved()
+        for key, value in sample.items():
+            self[key].least_recently_sampled = value
+        res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample]
+        return np.product(res, **kwargs)
+
+    def ln_prob(self, sample, axis=None):
+        """
+
+        Parameters
+        ----------
+        sample: dict
+            Dictionary of the samples of which we want to have the log probability of
+        axis: Union[None, int]
+            Axis along which the summation is performed
+
+        Returns
+        -------
+        float: Joint log probability of all the individual sample probabilities
+
+        """
+        self._check_resolved()
+        for key, value in sample.items():
+            self[key].least_recently_sampled = value
+        res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample]
+        return np.sum(res, axis=axis)
+
+    def rescale(self, keys, theta):
+        """Rescale samples from unit cube to prior
+
+        Parameters
+        ----------
+        keys: list
+            List of prior keys to be rescaled
+        theta: list
+            List of randomly drawn values on a unit cube associated with the prior keys
+
+        Returns
+        -------
+        list: List of floats containing the rescaled sample
+        """
+        self._check_resolved()
+        self._update_rescale_keys(keys)
+        result = dict()
+        for key, index in zip(self._rescale_keys, self._rescale_indexes):
+            required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])}
+            result[key] = self[key].rescale(theta[index], **required_variables)
+        return [result[key] for key in keys]
+
+    def _update_rescale_keys(self, keys):
+        if not keys == self._least_recently_rescaled_keys:
+            self._set_rescale_keys_and_indexes(keys)
+            self._least_recently_rescaled_keys = keys
+
+    def _set_rescale_keys_and_indexes(self, keys):
+        unconditional_keys, unconditional_idxs, _ = np.intersect1d(keys, self.unconditional_keys, return_indices=True)
+        conditional_keys, conditional_idxs, _ = np.intersect1d(keys, self.conditional_keys, return_indices=True)
+        self._rescale_keys = np.append(unconditional_keys, conditional_keys)
+        self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs)
+
+    def _check_resolved(self):
+        if not self._resolved:
+            raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
+
+    @property
+    def conditional_keys(self):
+        return self._conditional_keys
+
+    @property
+    def unconditional_keys(self):
+        return self._unconditional_keys
+
+    @property
+    def sorted_keys(self):
+        return self.unconditional_keys + self.conditional_keys
+
+    def __setitem__(self, key, value):
+        super(ConditionalPriorDict, self).__setitem__(key, value)
+        self._resolve_conditions()
+
+    def __delitem__(self, key):
+        super(ConditionalPriorDict, self).__delitem__(key)
+        self._resolve_conditions()
+
+
 def create_default_prior(name, default_priors_file=None):
     """Make a default prior for a parameter with a known name.
 
@@ -548,6 +742,7 @@ class Prior(object):
         self.unit = unit
         self.minimum = minimum
         self.maximum = maximum
+        self.least_recently_sampled = None
         self.boundary = boundary
 
     def __call__(self):
@@ -588,7 +783,8 @@ class Prior(object):
         float: A random number between 0 and 1, rescaled to match the distribution of this Prior
 
         """
-        return self.rescale(np.random.uniform(0, 1, size))
+        self.least_recently_sampled = self.rescale(np.random.uniform(0, 1, size))
+        return self.least_recently_sampled
 
     def rescale(self, val):
         """
@@ -692,7 +888,7 @@ class Prior(object):
 
         """
         prior_name = self.__class__.__name__
-        instantiation_dict = self._get_instantiation_dict()
+        instantiation_dict = self.get_instantiation_dict()
         args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
                           for key in instantiation_dict])
         return "{}({})".format(prior_name, args)
@@ -775,7 +971,7 @@ class Prior(object):
     def maximum(self, maximum):
         self._maximum = maximum
 
-    def _get_instantiation_dict(self):
+    def get_instantiation_dict(self):
         subclass_args = infer_args_from_method(self.__init__)
         property_names = [p for p in dir(self.__class__)
                           if isinstance(getattr(self.__class__, p), property)]
@@ -825,11 +1021,19 @@ class Prior(object):
         kwargs = cls._split_repr(string)
         for key in kwargs:
             val = kwargs[key]
-            if key not in subclass_args:
+            if key not in subclass_args and not hasattr(cls, "reference_params"):
                 raise AttributeError('Unknown argument {} for class {}'.format(
                     key, cls.__name__))
             else:
                 kwargs[key] = cls._parse_argument_string(val)
+            if key in ["condition_func", "conversion_function"] and isinstance(kwargs[key], str):
+                if "." in kwargs[key]:
+                    module = '.'.join(kwargs[key].split('.')[:-1])
+                    name = kwargs[key].split('.')[-1]
+                else:
+                    module = __name__
+                    name = kwargs[key]
+                kwargs[key] = getattr(import_module(module), name)
         return cls(**kwargs)
 
     @classmethod
@@ -2982,7 +3186,7 @@ class MultivariateGaussianDist(object):
 
         return np.exp(self.ln_prob(samp))
 
-    def _get_instantiation_dict(self):
+    def get_instantiation_dict(self):
         subclass_args = infer_args_from_method(self.__init__)
         property_names = [p for p in dir(self.__class__)
                           if isinstance(getattr(self.__class__, p), property)]
@@ -3013,7 +3217,7 @@ class MultivariateGaussianDist(object):
 
         """
         dist_name = self.__class__.__name__
-        instantiation_dict = self._get_instantiation_dict()
+        instantiation_dict = self.get_instantiation_dict()
         args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
                           for key in instantiation_dict])
         return "{}({})".format(dist_name, args)
@@ -3146,7 +3350,7 @@ class MultivariateGaussian(Prior):
         if len(self.mvg.sampled_parameters) == len(self.mvg):
             # reset samples
             self.mvg.reset_sampled()
-
+        self.least_recently_sampled = sample
         return sample
 
     def prob(self, val):
@@ -3237,5 +3441,245 @@ class MultivariateGaussian(Prior):
 
 
 class MultivariateNormal(MultivariateGaussian):
-    """ A synonym for the :class:`bilby.core.prior.MultivariateGaussian`
-        prior distribution."""
+    """A synonym for the :class:`bilby.core.prior.MultivariateGaussian`
+     prior distribution."""
+
+
+def conditional_prior_factory(prior_class):
+    class ConditionalPrior(prior_class):
+        def __init__(self, condition_func, name=None, latex_label=None, unit=None,
+                     boundary=None, **reference_params):
+            """
+
+            Parameters
+            ----------
+            condition_func: func
+                Functional form of the condition for this prior. The first function argument
+                has to be a dictionary for the `reference_params` (see below). The following
+                arguments are the required variables that are required before we can draw this
+                prior.
+                It needs to return a dictionary with the modified values for the
+                `reference_params` that are being used in the next draw.
+                For example if we have a Uniform prior for `x` depending on a different variable `y`
+                `p(x|y)` with the boundaries linearly depending on y, then this
+                could have the following form:
+
+                ```
+                def condition_func(reference_params, y):
+                    return dict(minimum=reference_params['minimum'] + y, maximum=reference_params['maximum'] + y)
+                ```
+            name: str, optional
+               See superclass
+            latex_label: str, optional
+                See superclass
+            unit: str, optional
+                See superclass
+            boundary: str, optional
+                See superclass
+            reference_params:
+                Initial values for attributes such as `minimum`, `maximum`.
+                This differs on the `prior_class`, for example for the Gaussian
+                prior this is `mu` and `sigma`.
+            """
+            if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__):
+                super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label,
+                                                       unit=unit, boundary=boundary, **reference_params)
+            else:
+                super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label,
+                                                       unit=unit, **reference_params)
+
+            self._required_variables = None
+            self.condition_func = condition_func
+            self._reference_params = reference_params
+            self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__)
+
+        def sample(self, size=None, **required_variables):
+            """Draw a sample from the prior
+
+            Parameters
+            ----------
+            size: int or tuple of ints, optional
+                See superclass
+            required_variables:
+                Any required variables that this prior depends on
+
+            Returns
+            -------
+            float: See superclass
+
+            """
+            self.least_recently_sampled = self.rescale(np.random.uniform(0, 1, size), **required_variables)
+            return self.least_recently_sampled
+
+        def rescale(self, val, **required_variables):
+            """
+            'Rescale' a sample from the unit line element to the prior.
+
+            Parameters
+            ----------
+            val: Union[float, int, array_like]
+                See superclass
+            required_variables:
+                Any required variables that this prior depends on
+
+
+            """
+            self.update_conditions(**required_variables)
+            return super(ConditionalPrior, self).rescale(val)
+
+        def prob(self, val, **required_variables):
+            """Return the prior probability of val.
+
+            Parameters
+            ----------
+            val: Union[float, int, array_like]
+                See superclass
+            required_variables:
+                Any required variables that this prior depends on
+
+
+            Returns
+            -------
+            float: Prior probability of val
+            """
+            self.update_conditions(**required_variables)
+            return super(ConditionalPrior, self).prob(val)
+
+        def ln_prob(self, val, **required_variables):
+            self.update_conditions(**required_variables)
+            return super(ConditionalPrior, self).ln_prob(val)
+
+        def update_conditions(self, **required_variables):
+            """
+            This method updates the conditional parameters (depending on the parent class
+            this could be e.g. `minimum`, `maximum`, `mu`, `sigma`, etc.) of this prior
+            class depending on the required variables it depends on.
+
+            If no variables are given, the most recently used conditional parameters are kept
+
+            Parameters
+            ----------
+            required_variables:
+                Any required variables that this prior depends on. If none are given,
+                self.reference_params will be used.
+
+            """
+            if sorted(list(required_variables)) == sorted(self.required_variables):
+                parameters = self.condition_func(self.reference_params, **required_variables)
+                for key, value in parameters.items():
+                    setattr(self, key, value)
+            elif len(required_variables) == 0:
+                return
+            else:
+                raise IllegalRequiredVariablesException("Expected kwargs for {}. Got kwargs for {} instead."
+                                                        .format(self.required_variables,
+                                                                list(required_variables.keys())))
+
+        @property
+        def reference_params(self):
+            """
+            Initial values for attributes such as `minimum`, `maximum`.
+            This depends on the `prior_class`, for example for the Gaussian
+            prior this is `mu` and `sigma`. This is read-only.
+            """
+            return self._reference_params
+
+        @property
+        def condition_func(self):
+            return self._condition_func
+
+        @condition_func.setter
+        def condition_func(self, condition_func):
+            if condition_func is None:
+                self._condition_func = lambda reference_params: reference_params
+            else:
+                self._condition_func = condition_func
+            self._required_variables = infer_parameters_from_function(self.condition_func)
+
+        @property
+        def required_variables(self):
+            """ The required variables to pass into the condition function. """
+            return self._required_variables
+
+        def get_instantiation_dict(self):
+            instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict()
+            for key, value in self.reference_params.items():
+                instantiation_dict[key] = value
+            return instantiation_dict
+
+        def reset_to_reference_parameters(self):
+            """
+            Reset the object attributes to match the original reference parameters
+            """
+            for key, value in self.reference_params.items():
+                setattr(self, key, value)
+
+        def __repr__(self):
+            """Overrides the special method __repr__.
+
+            Returns a representation of this instance that resembles how it is instantiated.
+            Works correctly for all child classes
+
+            Returns
+            -------
+            str: A string representation of this instance
+
+            """
+            prior_name = self.__class__.__name__
+            instantiation_dict = self.get_instantiation_dict()
+            instantiation_dict["condition_func"] = ".".join([
+                instantiation_dict["condition_func"].__module__,
+                instantiation_dict["condition_func"].__name__
+            ])
+            args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
+                              for key in instantiation_dict])
+            return "{}({})".format(prior_name, args)
+
+    return ConditionalPrior
+
+
+ConditionalBasePrior = conditional_prior_factory(Prior)  # Only for testing purposes
+ConditionalUniform = conditional_prior_factory(Uniform)
+ConditionalDeltaFunction = conditional_prior_factory(DeltaFunction)
+ConditionalPowerLaw = conditional_prior_factory(PowerLaw)
+ConditionalGaussian = conditional_prior_factory(Gaussian)
+ConditionalLogUniform = conditional_prior_factory(LogUniform)
+ConditionalSymmetricLogUniform = conditional_prior_factory(SymmetricLogUniform)
+ConditionalCosine = conditional_prior_factory(Cosine)
+ConditionalSine = conditional_prior_factory(Sine)
+ConditionalTruncatedGaussian = conditional_prior_factory(TruncatedGaussian)
+ConditionalHalfGaussian = conditional_prior_factory(HalfGaussian)
+ConditionalLogNormal = conditional_prior_factory(LogNormal)
+ConditionalExponential = conditional_prior_factory(Exponential)
+ConditionalStudentT = conditional_prior_factory(StudentT)
+ConditionalBeta = conditional_prior_factory(Beta)
+ConditionalLogistic = conditional_prior_factory(Logistic)
+ConditionalCauchy = conditional_prior_factory(Cauchy)
+ConditionalGamma = conditional_prior_factory(Gamma)
+ConditionalChiSquared = conditional_prior_factory(ChiSquared)
+ConditionalFermiDirac = conditional_prior_factory(FermiDirac)
+ConditionalInterped = conditional_prior_factory(Interped)
+
+
+class PriorException(Exception):
+    """ General base class for all prior exceptions """
+
+
+class ConditionalPriorException(PriorException):
+    """ General base class for all conditional prior exceptions """
+
+
+class IllegalRequiredVariablesException(ConditionalPriorException):
+    """ Exception class for exceptions relating to handling the required variables. """
+
+
+class PriorDictException(Exception):
+    """ General base class for all prior dict exceptions """
+
+
+class ConditionalPriorDictException(PriorDictException):
+    """ General base class for all conditional prior dict exceptions """
+
+
+class IllegalConditionsException(ConditionalPriorDictException):
+    """ Exception class to handle prior dicts that contain unresolvable conditions. """
diff --git a/bilby/core/result.py b/bilby/core/result.py
index 7101cf99676ab0c4ecebfbfb3649cbb03f7212a3..630830dfc3830de9f6e338e88f469523bfa3f988 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -20,7 +20,7 @@ from . import utils
 from .utils import (logger, infer_parameters_from_function,
                     check_directory_exists_and_if_not_mkdir,)
 from .utils import BilbyJsonEncoder, decode_bilby_json
-from .prior import Prior, PriorDict, DeltaFunction
+from .prior import Prior, PriorDict, DeltaFunction, ConditionalPriorDict
 
 
 def result_file_name(outdir, label, extension='json', gzip=False):
@@ -299,7 +299,10 @@ class Result(object):
     @priors.setter
     def priors(self, priors):
         if isinstance(priors, dict):
-            self._priors = PriorDict(priors)
+            if isinstance(priors, ConditionalPriorDict):
+                self._priors = priors
+            else:
+                self._priors = PriorDict(priors)
             if self.parameter_labels is None:
                 self.parameter_labels = [self.priors[k].latex_label for k in
                                          self.search_parameter_keys]
@@ -307,7 +310,6 @@ class Result(object):
                 self.parameter_labels_with_unit = [
                     self.priors[k].latex_label_with_unit for k in
                     self.search_parameter_keys]
-
         elif priors is None:
             self._priors = priors
             self.parameter_labels = self.search_parameter_keys
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index ab6d62bd4206caf12db8fbeda6b344d4ada02e0f..4b5cc048f484f81ce90cb00c62ac23b1cadef917 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -5,7 +5,7 @@ import numpy as np
 from pandas import DataFrame
 
 from ..utils import logger, command_line_args, Counter
-from ..prior import Prior, PriorDict, DeltaFunction, Constraint
+from ..prior import Prior, PriorDict, ConditionalPriorDict, DeltaFunction, Constraint
 from ..result import Result, read_in_result
 
 
@@ -251,13 +251,19 @@ class Sampler(object):
         AttributeError
             prior can't be sampled.
         """
-        for key in self.priors:
-            if isinstance(self.priors[key], Constraint):
-                continue
+        if isinstance(self.priors, ConditionalPriorDict):
             try:
-                self.likelihood.parameters[key] = self.priors[key].sample()
+                self.likelihood.parameters = self.priors.sample()
             except AttributeError as e:
-                logger.warning('Cannot sample from {}, {}'.format(key, e))
+                logger.warning('Cannot sample from prior, {}'.format(e))
+        else:
+            for key in self.priors:
+                if isinstance(self.priors[key], Constraint):
+                    continue
+                try:
+                    self.likelihood.parameters[key] = self.priors[key].sample()
+                except AttributeError as e:
+                    logger.warning('Cannot sample from {}, {}'.format(key, e))
 
     def _verify_parameters(self):
         """ Evaluate a set of parameters drawn from the prior
@@ -276,9 +282,13 @@ class Sampler(object):
                 "Your sampling set contains redundant parameters.")
 
         self._check_if_priors_can_be_sampled()
-        try:
+        if isinstance(self.priors, ConditionalPriorDict):
+            theta = self.priors.sample()
+            theta = [theta[key] for key in self._search_parameter_keys]
+        else:
             theta = [self.priors[key].sample()
                      for key in self._search_parameter_keys]
+        try:
             self.log_likelihood(theta)
         except TypeError as e:
             raise TypeError(
@@ -298,8 +308,12 @@ class Sampler(object):
 
         t1 = datetime.datetime.now()
         for _ in range(n_evaluations):
-            theta = [self.priors[key].sample()
-                     for key in self._search_parameter_keys]
+            if isinstance(self.priors, ConditionalPriorDict):
+                theta = self.priors.sample()
+                theta = [theta[key] for key in self._search_parameter_keys]
+            else:
+                theta = [self.priors[key].sample()
+                         for key in self._search_parameter_keys]
             self.log_likelihood(theta)
         total_time = (datetime.datetime.now() - t1).total_seconds()
         self._log_likelihood_eval_time = total_time / n_evaluations
diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py
index be46742558613695d8e2a6ce0fd8a6315b44f989..1ed64e89c23b1588c88083043e3b74d64989ad0f 100644
--- a/bilby/core/sampler/cpnest.py
+++ b/bilby/core/sampler/cpnest.py
@@ -61,9 +61,7 @@ class Cpnest(NestedSampler):
             def __init__(self, names, priors):
                 self.names = names
                 self.priors = priors
-                self.bounds = [
-                    [self.priors[key].minimum, self.priors[key].maximum]
-                    for key in self.names]
+                self._update_bounds()
 
             @staticmethod
             def log_likelihood(x, **kwargs):
@@ -75,10 +73,17 @@ class Cpnest(NestedSampler):
                 theta = [x[n] for n in self.search_parameter_keys]
                 return self.log_prior(theta)
 
+            def _update_bounds(self):
+                self.bounds = [
+                    [self.priors[key].minimum, self.priors[key].maximum]
+                    for key in self.names]
+
             def new_point(self):
                 """Draw a point from the prior"""
+                prior_samples = self.priors.sample()
+                self._update_bounds()
                 point = LivePoint(
-                    self.names, [self.priors[name].sample()
+                    self.names, [prior_samples[name]
                                  for name in self.names])
                 return point
 
diff --git a/bilby/core/utils.py b/bilby/core/utils.py
index 66dc063684170387a856c43c3b179b4b2c14b680..c5f9a9709b8894b3896b9526d0f453ab243de987 100644
--- a/bilby/core/utils.py
+++ b/bilby/core/utils.py
@@ -979,7 +979,7 @@ class BilbyJsonEncoder(json.JSONEncoder):
         if isinstance(obj, (MultivariateGaussianDist, Prior)):
             return {'__prior__': True, '__module__': obj.__module__,
                     '__name__': obj.__class__.__name__,
-                    'kwargs': dict(obj._get_instantiation_dict())}
+                    'kwargs': dict(obj.get_instantiation_dict())}
         try:
             from astropy import cosmology as cosmo, units
             if isinstance(obj, cosmo.FLRW):
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 4a3a902aade3af8b33397ccac84a1fababd7b6c2..5837d623674c48178d3fe74406d86664b89eefff 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -4,8 +4,8 @@ import copy
 import numpy as np
 from scipy.interpolate import InterpolatedUnivariateSpline
 
-from ..core.prior import (PriorDict, Uniform, Prior, DeltaFunction, Gaussian,
-                          Interped, Constraint)
+from ..core.prior import (ConditionalPriorDict, PriorDict, Uniform, Prior, DeltaFunction, Gaussian,
+                          Interped, Constraint, conditional_prior_factory)
 from ..core.utils import infer_args_from_method, logger
 from .conversion import (
     convert_to_lal_binary_black_hole_parameters,
@@ -307,7 +307,7 @@ class AlignedSpin(Interped):
                                           boundary=boundary)
 
 
-class CBCPriorDict(PriorDict):
+class CBCPriorDict(ConditionalPriorDict):
     @property
     def minimum_chirp_mass(self):
         mass_1 = None
@@ -753,3 +753,12 @@ class CalibrationPriorDict(PriorDict):
                                         latex_label=latex_label)
 
         return prior
+
+
+def secondary_mass_condition_function(reference_params, mass_1):
+    return dict(minimum=reference_params['minimum'], maximum=mass_1)
+
+
+ConditionalCosmological = conditional_prior_factory(Cosmological)
+ConditionalUniformComovingVolume = conditional_prior_factory(UniformComovingVolume)
+ConditionalUniformSourceFrame = conditional_prior_factory(UniformSourceFrame)
diff --git a/examples/core_examples/conditional_prior.py b/examples/core_examples/conditional_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..2475e479d160f1fe0088f9142b6fdec31919a9e1
--- /dev/null
+++ b/examples/core_examples/conditional_prior.py
@@ -0,0 +1,44 @@
+import bilby
+import numpy as np
+
+# This tutorial demonstrates how we can sample a prior in the shape of a ball
+# Note that this will not end up sampling uniformly in that space, constraint priors are more suitable for that.
+# This implementation will draw a value for the x-coordinate from p(x), and given that draw a value for the
+# y-coordinate from p(y|x), and given that draw a value for the z-coordinate from p(z|x,y).
+# Only the x-coordinate will end up being uniform for this
+
+
+class ZeroLikelihood(bilby.core.likelihood.Likelihood):
+    """ Flat likelihood. This always returns 0.
+    This way our posterior distribution is exactly the prior distribution."""
+    def log_likelihood(self):
+        return 0
+
+
+def condition_func_y(reference_params, x):
+    """ Condition function for our p(y|x) prior."""
+    radius = 0.5 * (reference_params['maximum'] - reference_params['minimum'])
+    y_max = np.sqrt(radius**2 - x**2)
+    return dict(minimum=-y_max, maximum=y_max)
+
+
+def condition_func_z(reference_params, x, y):
+    """ Condition function for our p(z|x, y) prior."""
+    radius = 0.5 * (reference_params['maximum'] - reference_params['minimum'])
+    z_max = np.sqrt(radius**2 - x**2 - y**2)
+    return dict(minimum=-z_max, maximum=z_max)
+
+
+# Set up the conditional priors and the flat likelihood
+priors = bilby.core.prior.ConditionalPriorDict()
+priors['x'] = bilby.core.prior.Uniform(minimum=-1, maximum=1, latex_label="$x$")
+priors['y'] = bilby.core.prior.ConditionalUniform(condition_func=condition_func_y, minimum=-1,
+                                                  maximum=1, latex_label="$y$")
+priors['z'] = bilby.core.prior.ConditionalUniform(condition_func=condition_func_z, minimum=-1,
+                                                  maximum=1, latex_label="$z$")
+likelihood = ZeroLikelihood(parameters=dict(x=0, y=0, z=0))
+
+# Sample the prior distribution
+res = bilby.run_sampler(likelihood=likelihood, priors=priors, sampler='dynesty', npoints=5000, walks=100,
+                        label='conditional_prior', outdir='outdir', resume=False, clean=True)
+res.plot_corner()
diff --git a/test/prior_test.py b/test/prior_test.py
index 8f58d41e57fcea98a6d19ec9f5cc386e756fdd98..b6f554d44bbd2e91fd4d47ba686ecf3e73dfff95 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -2,6 +2,7 @@ from __future__ import absolute_import, division
 import bilby
 import unittest
 from mock import Mock
+import mock
 import numpy as np
 import os
 from collections import OrderedDict
@@ -162,6 +163,9 @@ class TestPriorClasses(unittest.TestCase):
                                                         covs=np.array([[2., 0.5], [0.5, 2.]]),
                                                         weights=1.)
 
+        def condition_func(reference_params, test_param):
+            return reference_params.copy()
+
         self.priors = [
             bilby.core.prior.DeltaFunction(name='test', unit='unit', peak=1),
             bilby.core.prior.Gaussian(name='test', unit='unit', mu=0, sigma=1),
@@ -196,7 +200,28 @@ class TestPriorClasses(unittest.TestCase):
             bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testa', unit='unit'),
             bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testb', unit='unit'),
             bilby.core.prior.MultivariateNormal(mvg=mvn, name='testa', unit='unit'),
-            bilby.core.prior.MultivariateNormal(mvg=mvn, name='testb', unit='unit')
+            bilby.core.prior.MultivariateNormal(mvg=mvn, name='testb', unit='unit'),
+            bilby.core.prior.ConditionalDeltaFunction(condition_func=condition_func, name='test', unit='unit', peak=1),
+            bilby.core.prior.ConditionalGaussian(condition_func=condition_func, name='test', unit='unit', mu=0, sigma=1),
+            bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, name='test', unit='unit', alpha=0, minimum=0, maximum=1),
+            bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, name='test', unit='unit', alpha=-1, minimum=0.5, maximum=1),
+            bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, name='test', unit='unit', alpha=2, minimum=1, maximum=1e2),
+            bilby.core.prior.ConditionalUniform(condition_func=condition_func, name='test', unit='unit', minimum=0, maximum=1),
+            bilby.core.prior.ConditionalLogUniform(condition_func=condition_func, name='test', unit='unit', minimum=5e0, maximum=1e2),
+            bilby.gw.prior.ConditionalUniformComovingVolume(condition_func=condition_func, name='redshift', minimum=0.1, maximum=1.0),
+            bilby.gw.prior.ConditionalUniformSourceFrame(condition_func=condition_func, name='redshift', minimum=0.1, maximum=1.0),
+            bilby.core.prior.ConditionalSine(condition_func=condition_func, name='test', unit='unit'),
+            bilby.core.prior.ConditionalCosine(condition_func=condition_func, name='test', unit='unit'),
+            bilby.core.prior.ConditionalTruncatedGaussian(condition_func=condition_func, name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1),
+            bilby.core.prior.ConditionalHalfGaussian(condition_func=condition_func, name='test', unit='unit', sigma=1),
+            bilby.core.prior.ConditionalLogNormal(condition_func=condition_func, name='test', unit='unit', mu=0, sigma=1),
+            bilby.core.prior.ConditionalExponential(condition_func=condition_func, name='test', unit='unit', mu=1),
+            bilby.core.prior.ConditionalStudentT(condition_func=condition_func, name='test', unit='unit', df=3, mu=0, scale=1),
+            bilby.core.prior.ConditionalBeta(condition_func=condition_func, name='test', unit='unit', alpha=2.0, beta=2.0),
+            bilby.core.prior.ConditionalLogistic(condition_func=condition_func, name='test', unit='unit', mu=0, scale=1),
+            bilby.core.prior.ConditionalCauchy(condition_func=condition_func, name='test', unit='unit', alpha=0, beta=1),
+            bilby.core.prior.ConditionalGamma(condition_func=condition_func, name='test', unit='unit', k=1, theta=1),
+            bilby.core.prior.ConditionalChiSquared(condition_func=condition_func, name='test', unit='unit', nu=2)
         ]
 
     def tearDown(self):
@@ -240,6 +265,11 @@ class TestPriorClasses(unittest.TestCase):
         for prior in self.priors:
             self.assertRaises(ValueError, lambda: prior.rescale(-1))
 
+    def test_least_recently_sampled(self):
+        for prior in self.priors:
+            least_recently_sampled_expected = prior.sample()
+            self.assertEqual(least_recently_sampled_expected, prior.least_recently_sampled)
+
     def test_sampling_single(self):
         """Test that sampling from the prior always returns values within its domain."""
         for prior in self.priors:
@@ -266,6 +296,11 @@ class TestPriorClasses(unittest.TestCase):
                 outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000)
                 self.assertTrue(all(prior.prob(outside_domain) == 0))
 
+    def test_least_recently_sampled(self):
+        for prior in self.priors:
+            lrs = prior.sample()
+            self.assertEqual(lrs, prior.least_recently_sampled)
+
     def test_prob_and_ln_prob(self):
         for prior in self.priors:
             sample = prior.sample()
@@ -603,6 +638,8 @@ class TestPriorClasses(unittest.TestCase):
                 )
             elif isinstance(prior, bilby.gw.prior.UniformComovingVolume):
                 repr_prior_string = 'bilby.gw.prior.' + repr(prior)
+            elif 'Conditional' in prior.__class__.__name__:
+                continue # This feature does not exist because we cannot recreate the condition function
             else:
                 repr_prior_string = 'bilby.core.prior.' + repr(prior)
             repr_prior = eval(repr_prior_string, None, dict(inf=np.inf))
@@ -906,6 +943,246 @@ class TestCreateDefaultPrior(unittest.TestCase):
         self.assertIsNone(bilby.core.prior.create_default_prior(name='name', default_priors_file=prior_file))
 
 
+class TestConditionalPrior(unittest.TestCase):
+
+    def setUp(self):
+        self.condition_func_call_counter = 0
+
+        def condition_func(reference_parameters, test_variable_1, test_variable_2):
+            self.condition_func_call_counter += 1
+            return {key: value + 1 for key, value in reference_parameters.items()}
+        self.condition_func = condition_func
+        self.minimum = 0
+        self.maximum = 5
+        self.test_variable_1 = 0
+        self.test_variable_2 = 1
+        self.prior = bilby.core.prior.ConditionalBasePrior(condition_func=condition_func,
+                                                           minimum=self.minimum,
+                                                           maximum=self.maximum)
+
+    def tearDown(self):
+        del self.condition_func
+        del self.condition_func_call_counter
+        del self.minimum
+        del self.maximum
+        del self.test_variable_1
+        del self.test_variable_2
+        del self.prior
+
+    def test_reference_params(self):
+        self.assertDictEqual(dict(minimum=self.minimum, maximum=self.maximum), self.prior.reference_params)
+
+    def test_required_variables(self):
+        self.assertListEqual(['test_variable_1', 'test_variable_2'], sorted(self.prior.required_variables))
+
+    def test_required_variables_no_condition_func(self):
+        self.prior = bilby.core.prior.ConditionalBasePrior(condition_func=None,
+                                                           minimum=self.minimum,
+                                                           maximum=self.maximum)
+        self.assertListEqual([], self.prior.required_variables)
+
+    def test_get_instantiation_dict(self):
+        expected = dict(minimum=0, maximum=5, name=None, latex_label=None, unit=None,
+                        boundary=None, condition_func=self.condition_func)
+        actual = self.prior.get_instantiation_dict()
+        for key, value in expected.items():
+            if key == 'condition_func':
+                continue
+            self.assertEqual(value, actual[key])
+
+    def test_update_conditions_correct_variables(self):
+        self.prior.update_conditions(test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2)
+        self.assertEqual(1, self.condition_func_call_counter)
+        self.assertEqual(self.minimum + 1, self.prior.minimum)
+        self.assertEqual(self.maximum + 1, self.prior.maximum)
+
+    def test_update_conditions_no_variables(self):
+        self.prior.update_conditions(test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2)
+        self.prior.update_conditions()
+        self.assertEqual(1, self.condition_func_call_counter)
+        self.assertEqual(self.minimum + 1, self.prior.minimum)
+        self.assertEqual(self.maximum + 1, self.prior.maximum)
+
+    def test_update_conditions_illegal_variables(self):
+        with self.assertRaises(bilby.core.prior.IllegalRequiredVariablesException):
+            self.prior.update_conditions(test_parameter_1=self.test_variable_1)
+
+    def test_sample_calls_update_conditions(self):
+        with mock.patch.object(self.prior, 'update_conditions') as m:
+            self.prior.sample(1,
+                              test_parameter_1=self.test_variable_1,
+                              test_parameter_2=self.test_variable_2)
+            m.assert_called_with(test_parameter_1=self.test_variable_1, test_parameter_2=self.test_variable_2)
+
+    def test_rescale_calls_update_conditions(self):
+        with mock.patch.object(self.prior, 'update_conditions') as m:
+            self.prior.rescale(1, test_parameter_1=self.test_variable_1,
+                               test_parameter_2=self.test_variable_2)
+            m.assert_called_with(test_parameter_1=self.test_variable_1,
+                                 test_parameter_2=self.test_variable_2)
+
+    def test_rescale_prob_update_conditions(self):
+        with mock.patch.object(self.prior, 'update_conditions') as m:
+            self.prior.prob(1, test_parameter_1=self.test_variable_1,
+                            test_parameter_2=self.test_variable_2)
+            m.assert_called_with(test_parameter_1=self.test_variable_1,
+                                 test_parameter_2=self.test_variable_2)
+
+    def test_rescale_ln_prob_update_conditions(self):
+        with mock.patch.object(self.prior, 'update_conditions') as m:
+            self.prior.ln_prob(1, test_parameter_1=self.test_variable_1,
+                               test_parameter_2=self.test_variable_2)
+            calls = [mock.call(test_parameter_1=self.test_variable_1,
+                               test_parameter_2=self.test_variable_2),
+                     mock.call()]
+            m.assert_has_calls(calls)
+
+    def test_reset_to_reference_parameters(self):
+        self.prior.minimum = 10
+        self.prior.maximum = 20
+        self.prior.reset_to_reference_parameters()
+        self.assertEqual(self.prior.reference_params['minimum'], self.prior.minimum)
+        self.assertEqual(self.prior.reference_params['maximum'], self.prior.maximum)
+
+    def test_cond_prior_instantiation_no_boundary_prior(self):
+        prior = bilby.core.prior.ConditionalFermiDirac(condition_func=None, sigma=1, mu=1)
+        self.assertIsNone(prior.boundary)
+
+
+class TestConditionalPriorDict(unittest.TestCase):
+
+    def setUp(self):
+        def condition_func_1(reference_parameters, var_0):
+            return reference_parameters
+
+        def condition_func_2(reference_parameters, var_0, var_1):
+            return reference_parameters
+
+        def condition_func_3(reference_parameters, var_1, var_2):
+            return reference_parameters
+
+        self.minimum = 0
+        self.maximum = 1
+        self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=self.maximum)
+        self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1,
+                                                           minimum=self.minimum, maximum=self.maximum)
+        self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2,
+                                                           minimum=self.minimum, maximum=self.maximum)
+        self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3,
+                                                           minimum=self.minimum, maximum=self.maximum)
+        self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2,
+                                                                             var_0=self.prior_0, var_1=self.prior_1))
+        self.conditional_priors_manually_set_items = bilby.core.prior.ConditionalPriorDict()
+        self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4)
+        for key, value in dict(var_0=self.prior_0, var_1=self.prior_1, var_2=self.prior_2, var_3=self.prior_3).items():
+            self.conditional_priors_manually_set_items[key] = value
+
+    def tearDown(self):
+        del self.minimum
+        del self.maximum
+        del self.prior_0
+        del self.prior_1
+        del self.prior_2
+        del self.prior_3
+        del self.conditional_priors
+        del self.conditional_priors_manually_set_items
+        del self.test_sample
+
+    def test_conditions_resolved_upon_instantiation(self):
+        self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'], self.conditional_priors.sorted_keys)
+
+    def test_conditions_resolved_setting_items(self):
+        self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'],
+                             self.conditional_priors_manually_set_items.sorted_keys)
+
+    def test_unconditional_keys_upon_instantiation(self):
+        self.assertListEqual(['var_0'], self.conditional_priors.unconditional_keys)
+
+    def test_unconditional_keys_setting_items(self):
+        self.assertListEqual(['var_0'], self.conditional_priors_manually_set_items.unconditional_keys)
+
+    def test_conditional_keys_upon_instantiation(self):
+        self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors.conditional_keys)
+
+    def test_conditional_keys_setting_items(self):
+        self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors_manually_set_items.conditional_keys)
+
+    def test_prob(self):
+        self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample))
+
+    def test_prob_illegal_conditions(self):
+        del self.conditional_priors['var_0']
+        with self.assertRaises(bilby.core.prior.IllegalConditionsException):
+            self.conditional_priors.prob(sample=self.test_sample)
+
+    def test_ln_prob(self):
+        self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample))
+
+    def test_ln_prob_illegal_conditions(self):
+        del self.conditional_priors['var_0']
+        with self.assertRaises(bilby.core.prior.IllegalConditionsException):
+            self.conditional_priors.ln_prob(sample=self.test_sample)
+
+    def test_sample_subset_all_keys(self):
+        with mock.patch("numpy.random.uniform") as m:
+            m.return_value = 0.5
+            self.assertDictEqual(dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5),
+                                 self.conditional_priors.sample_subset(keys=['var_0', 'var_1', 'var_2', 'var_3']))
+
+    def test_sample_illegal_subset(self):
+        with mock.patch("numpy.random.uniform") as m:
+            m.return_value = 0.5
+            with self.assertRaises(bilby.core.prior.IllegalConditionsException):
+                self.conditional_priors.sample_subset(keys=['var_1'])
+
+    def test_sample_multiple(self):
+        def condition_func(reference_params, a):
+            return dict(minimum=reference_params['minimum'],
+                        maximum=reference_params['maximum'],
+                        alpha=reference_params['alpha'] * a)
+        priors = bilby.core.prior.ConditionalPriorDict()
+        priors['a'] = bilby.core.prior.Uniform(minimum=0, maximum=1)
+        priors['b'] = bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, minimum=1, maximum=2,
+                                                           alpha=-2)
+        print(priors.sample(2))
+
+    def test_rescale(self):
+
+        def condition_func_1_rescale(reference_parameters, var_0):
+            if var_0 == 0.5:
+                return dict(minimum=reference_parameters['minimum'], maximum=1)
+            return reference_parameters
+
+        def condition_func_2_rescale(reference_parameters, var_0, var_1):
+            if var_0 == 0.5 and var_1 == 0.5:
+                return dict(minimum=reference_parameters['minimum'], maximum=1)
+            return reference_parameters
+
+        def condition_func_3_rescale(reference_parameters, var_1, var_2):
+            if var_1 == 0.5 and var_2 == 0.5:
+                return dict(minimum=reference_parameters['minimum'], maximum=1)
+            return reference_parameters
+
+        self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1)
+        self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1_rescale,
+                                                           minimum=self.minimum, maximum=2)
+        self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2_rescale,
+                                                           minimum=self.minimum, maximum=2)
+        self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3_rescale,
+                                                           minimum=self.minimum, maximum=2)
+        self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2,
+                                                                             var_0=self.prior_0, var_1=self.prior_1))
+        ref_variables = [0.5, 0.5, 0.5, 0.5]
+        res = self.conditional_priors.rescale(keys=list(self.test_sample.keys()),
+                                              theta=ref_variables)
+        self.assertListEqual(ref_variables, res)
+
+    def test_rescale_illegal_conditions(self):
+        del self.conditional_priors['var_0']
+        with self.assertRaises(bilby.core.prior.IllegalConditionsException):
+            self.conditional_priors.rescale(keys=list(self.test_sample.keys()), theta=list(self.test_sample.values()))
+
+
 class TestJsonIO(unittest.TestCase):
 
     def setUp(self):