diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 9e338a70b9e1a1b658eb4f2ff5d9474dc68a1d8b..98f3db4635d129b7587d6232a877d3c6412dd172 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -95,20 +95,20 @@ class PriorDict(dict):
         check_directory_exists_and_if_not_mkdir(outdir)
         prior_file = os.path.join(outdir, "{}.prior".format(label))
         logger.debug("Writing priors to {}".format(prior_file))
-        mvgs = []
+        joint_dists = []
         with open(prior_file, "w") as outfile:
             for key in self.keys():
-                if isinstance(self[key], MultivariateGaussian):
-                    mvgname = '_'.join(self[key].mvg.names) + '_mvg'
-                    if mvgname not in mvgs:
-                        mvgs.append(mvgname)
+                if JointPrior in self[key].__class__.__mro__:
+                    distname = '_'.join(self[key].dist.names) + '_{}'.format(self[key].dist.distname)
+                    if distname not in joint_dists:
+                        joint_dists.append(distname)
                         outfile.write(
-                            "{} = {}\n".format(mvgname, self[key].mvg))
-                    mvgstr = repr(self[key].mvg)
+                            "{} = {}\n".format(distname, self[key].dist))
+                    diststr = repr(self[key].dist)
                     priorstr = repr(self[key])
                     outfile.write(
-                        "{} = {}\n".format(key, priorstr.replace(mvgstr,
-                                                                 mvgname)))
+                        "{} = {}\n".format(key, priorstr.replace(diststr,
+                                                                 distname)))
                 else:
                     outfile.write(
                         "{} = {}\n".format(key, self[key]))
@@ -2734,61 +2734,31 @@ class FermiDirac(Prior):
             return lnp
 
 
-class MultivariateGaussianDist(object):
+class BaseJointPriorDist(object):
 
-    def __init__(self, names, nmodes=1, mus=None, sigmas=None, corrcoefs=None,
-                 covs=None, weights=None, bounds=None):
+    def __init__(self, names, bounds=None):
         """
-        A class defining a multi-variate Gaussian, allowing multiple modes for
-        a Gaussian mixture model.
+        A class defining JointPriorDist that will be overwritten with child
+        classes defining the joint prior distribtuions between given parameters,
 
-        Note: if using a multivariate Gaussian prior, with bounds, this can
-        lead to biases in the marginal likelihood estimate and posterior
-        estimate for nested samplers routines that rely on sampling from a unit
-        hypercube and having a prior transform, e.g., nestle, dynesty and
-        MultiNest.
 
         Parameters
         ----------
-        names: list
-            A list of the parameter names in the multivariate Gaussian. The
+        names: list (required)
+            A list of the parameter names in the JointPriorDist. The
             listed parameters must have the same order that they appear in
-            the lists of means, standard deviations, and the correlation
-            coefficient, or covariance, matrices.
-        nmodes: int
-            The number of modes for the mixture model. This defaults to 1,
-            which will be checked against the shape of the other inputs.
-        mus: array_like
-            A list of lists of means of each mode in a multivariate Gaussian
-            mixture model. A single list can be given for a single mode. If
-            this is None then means at zero will be assumed.
-        sigmas: array_like
-            A list of lists of the standard deviations of each mode of the
-            multivariate Gaussian. If supplying a correlation coefficient
-            matrix rather than a covariance matrix these values must be given.
-            If this is None unit variances will be assumed.
-        corrcoefs: array
-            A list of square matrices containing the correlation coefficients
-            of the parameters for each mode. If this is None it will be assumed
-            that the parameters are uncorrelated.
-        covs: array
-            A list of square matrices containing the covariance matrix of the
-            multivariate Gaussian.
-        weights: list
-            A list of weights (relative probabilities) for each mode of the
-            multivariate Gaussian. This will default to equal weights for each
-            mode.
-        bounds: list
+            the lists of statistical parameters that may be passed in child class
+        bounds: list (optional)
             A list of bounds on each parameter. The defaults are for bounds at
             +/- infinity.
         """
-
+        self.distname = 'joint_dist'
         if not isinstance(names, list):
             self.names = [names]
         else:
             self.names = names
 
-        self.num_vars = len(self.names)  # the number of parameters
+        self.num_vars = len(self.names)
 
         # set the bounds for each parameter
         if isinstance(bounds, list):
@@ -2813,10 +2783,317 @@ class MultivariateGaussianDist(object):
                                "a prior transform.")
         else:
             bounds = [(-np.inf, np.inf) for _ in self.names]
-
-        # set bounds as dictionary
         self.bounds = {name: val for name, val in zip(self.names, bounds)}
 
+        self._current_sample = {}  # initialise empty sample
+        self._uncorrelated = None
+        self._current_lnprob = None
+
+        # a dictionary of the parameters as requested by the prior
+        self.requested_parameters = dict()
+        self.reset_request()
+
+        # a dictionary of the rescaled parameters
+        self.rescale_parameters = dict()
+        self.reset_rescale()
+
+        # a list of sampled parameters
+        self.reset_sampled()
+
+    def reset_sampled(self):
+        self.sampled_parameters = []
+        self.current_sample = {}
+
+    def filled_request(self):
+        """
+        Check if all requested parameters have been filled.
+        """
+
+        return not np.any([val is None for val in
+                           self.requested_parameters.values()])
+
+    def reset_request(self):
+        """
+        Reset the requested parameters to None.
+        """
+
+        for name in self.names:
+            self.requested_parameters[name] = None
+
+    def filled_rescale(self):
+        """
+        Check if all the rescaled parameters have been filled.
+        """
+
+        return not np.any([val is None for val in
+                           self.rescale_parameters.values()])
+
+    def reset_rescale(self):
+        """
+        Reset the rescaled parameters to None.
+        """
+
+        for name in self.names:
+            self.rescale_parameters[name] = None
+
+    def get_instantiation_dict(self):
+        subclass_args = infer_args_from_method(self.__init__)
+        property_names = [p for p in dir(self.__class__)
+                          if isinstance(getattr(self.__class__, p), property)]
+        dict_with_properties = self.__dict__.copy()
+        for key in property_names:
+            dict_with_properties[key] = getattr(self, key)
+        instantiation_dict = dict()
+        for key in subclass_args:
+            if isinstance(dict_with_properties[key], list):
+                value = np.asarray(dict_with_properties[key]).tolist()
+            else:
+                value = dict_with_properties[key]
+            instantiation_dict[key] = value
+        return instantiation_dict
+
+    def __len__(self):
+        return len(self.names)
+
+    def __repr__(self):
+        """Overrides the special method __repr__.
+
+        Returns a representation of this instance that resembles how it is instantiated.
+        Works correctly for all child classes
+
+        Returns
+        -------
+        str: A string representation of this instance
+
+        """
+        dist_name = self.__class__.__name__
+        instantiation_dict = self.get_instantiation_dict()
+        args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
+                          for key in instantiation_dict])
+        return "{}({})".format(dist_name, args)
+
+    def prob(self, samp):
+        """
+        Get the probability of a sample. For bounded priors the
+        probability will not be properly normalised.
+        """
+
+        return np.exp(self.ln_prob(samp))
+
+    def _check_samp(self, value):
+        """
+        Get the log-probability of a sample. For bounded priors the
+        probability will not be properly normalised.
+
+        Parameters
+        ----------
+        value: array_like
+            A 1d vector of the sample, or 2d array of sample values with shape
+            NxM, where N is the number of samples and M is the number of
+            parameters.
+
+        Returns
+        -------
+        samp: array_like
+            returns the input value as a sample array
+        outbounds: array_like
+            Boolean Array that selects samples in samp that are out of given bounds
+        """
+        samp = np.array(value)
+        if len(samp.shape) == 1:
+            samp = samp.reshape(1, self.num_vars)
+
+        if len(samp.shape) != 2:
+            raise ValueError("Array is the wrong shape")
+        elif samp.shape[1] != self.num_vars:
+            raise ValueError("Array is the wrong shape")
+
+        # check sample(s) is within bounds
+        outbounds = np.ones(samp.shape[0], dtype=np.bool)
+        for s, bound in zip(samp.T, self.bounds.values()):
+            outbounds = (s < bound[0]) | (s > bound[1])
+            if np.any(outbounds):
+                break
+        return samp, outbounds
+
+    def ln_prob(self, value):
+        """
+        Get the log-probability of a sample. For bounded priors the
+        probability will not be properly normalised.
+
+        Parameters
+        ----------
+        value: array_like
+            A 1d vector of the sample, or 2d array of sample values with shape
+            NxM, where N is the number of samples and M is the number of
+            parameters.
+        """
+
+        samp, outbounds = self._check_samp(value)
+        lnprob = -np.inf * np.ones(samp.shape[0])
+        lnprob = self._ln_prob(samp, lnprob, outbounds)
+        if samp.shape[0] == 1:
+            return lnprob[0]
+        else:
+            return lnprob
+
+    def _ln_prob(self, samp, lnprob, outbounds):
+        """
+        Get the log-probability of a sample. For bounded priors the
+        probability will not be properly normalised. **this method needs overwritten by child class**
+
+        Parameters
+        ----------
+        samp: vector
+            sample to evaluate the ln_prob at
+        lnprob: vector
+            of -inf pased in with the same shape as the number of samples
+        outbounds: array_like
+            boolean array showing which samples in lnprob vector are out of the given bounds
+
+        Returns
+        -------
+        lnprob: vector
+            array of lnprob values for each sample given
+        """
+        """
+        Here is where the subclass where overwrite ln_prob method
+        """
+        return lnprob
+
+    def sample(self, size=1, **kwargs):
+        """
+        Draw, and set, a sample from the Dist, accompanying method _sample needs to overwritten
+
+        Parameters
+        ----------
+        size: int
+            number of samples to generate, defualts to 1
+        """
+
+        if size is None:
+            size = 1
+        samps = self._sample(size=size, **kwargs)
+        for i, name in enumerate(self.names):
+            if size == 1:
+                self.current_sample[name] = samps[:, i].flatten()[0]
+            else:
+                self.current_sample[name] = samps[:, i].flatten()
+
+    def _sample(self, size, **kwargs):
+        """
+        Draw, and set, a sample from the joint dist (**needs to be ovewritten by child class**)
+
+        Parameters
+        ----------
+        size: int
+            number of samples to generate, defualts to 1
+        """
+        samps = np.zeros((size, len(self)))
+        """
+        Here is where the subclass where overwrite sampling method
+        """
+        return samps
+
+    def rescale(self, value, **kwargs):
+        """
+        Rescale from a unit hypercube to JointPriorDist. Note that no
+        bounds are applied in the rescale function. (child classes need to
+        overwrite accompanying method _rescale().
+
+        Parameters
+        ----------
+        value: array
+            A 1d vector sample (one for each parameter) drawn from a uniform
+            distribution between 0 and 1, or a 2d NxM array of samples where
+            N is the number of samples and M is the number of parameters.
+        kwargs: dict
+            All keyword args that need to be passed to _rescale method, these keyword
+            args are called in the JointPrior rescale methods for each parameter
+
+        Returns
+        -------
+        array:
+            An vector sample drawn from the multivariate Gaussian
+            distribution.
+        """
+        samp = np.asarray(value)
+        if len(samp.shape) == 1:
+            samp = samp.reshape(1, self.num_vars)
+
+        if len(samp.shape) != 2:
+            raise ValueError("Array is the wrong shape")
+        elif samp.shape[1] != self.num_vars:
+            raise ValueError("Array is the wrong shape")
+
+        samp = self._rescale(samp, **kwargs)
+        return np.squeeze(samp)
+
+    def _rescale(self, samp, **kwargs):
+        """
+        rescale a sample from a unit hypercybe to the joint dist (**needs to be ovewritten by child class**)
+
+        Parameters
+        ----------
+        samp: numpy array
+            this is a vector sample drawn from a uniform distribtuion to be rescaled to the distribution
+        """
+        """
+        Here is where the subclass where overwrite rescale method
+        """
+        return samp
+
+
+class MultivariateGaussianDist(BaseJointPriorDist):
+
+    def __init__(self, names, nmodes=1, mus=None, sigmas=None, corrcoefs=None,
+                 covs=None, weights=None, bounds=None):
+        """
+        A class defining a multi-variate Gaussian, allowing multiple modes for
+        a Gaussian mixture model.
+
+        Note: if using a multivariate Gaussian prior, with bounds, this can
+        lead to biases in the marginal likelihood estimate and posterior
+        estimate for nested samplers routines that rely on sampling from a unit
+        hypercube and having a prior transform, e.g., nestle, dynesty and
+        MultiNest.
+
+        Parameters
+        ----------
+        names: list
+            A list of the parameter names in the multivariate Gaussian. The
+            listed parameters must have the same order that they appear in
+            the lists of means, standard deviations, and the correlation
+            coefficient, or covariance, matrices.
+        nmodes: int
+            The number of modes for the mixture model. This defaults to 1,
+            which will be checked against the shape of the other inputs.
+        mus: array_like
+            A list of lists of means of each mode in a multivariate Gaussian
+            mixture model. A single list can be given for a single mode. If
+            this is None then means at zero will be assumed.
+        sigmas: array_like
+            A list of lists of the standard deviations of each mode of the
+            multivariate Gaussian. If supplying a correlation coefficient
+            matrix rather than a covariance matrix these values must be given.
+            If this is None unit variances will be assumed.
+        corrcoefs: array
+            A list of square matrices containing the correlation coefficients
+            of the parameters for each mode. If this is None it will be assumed
+            that the parameters are uncorrelated.
+        covs: array
+            A list of square matrices containing the covariance matrix of the
+            multivariate Gaussian.
+        weights: list
+            A list of weights (relative probabilities) for each mode of the
+            multivariate Gaussian. This will default to equal weights for each
+            mode.
+        bounds: list
+            A list of bounds on each parameter. The defaults are for bounds at
+            +/- infinity.
+        """
+        super(MultivariateGaussianDist, self).__init__(names=names, bounds=bounds)
+        self.distname = 'mvg'
         self.mus = []
         self.covs = []
         self.corrcoefs = []
@@ -2890,53 +3167,6 @@ class MultivariateGaussianDist(object):
 
             self.add_mode(mu, sigma, corrcoef, cov, weight)
 
-        # a dictionary of the parameters as requested by the prior
-        self.requested_parameters = dict()
-        self.reset_request()
-
-        # a dictionary of the rescaled parameters
-        self.rescale_parameters = dict()
-        self.reset_rescale()
-
-        # a list of sampled parameters
-        self.reset_sampled()
-
-    def reset_sampled(self):
-        self.sampled_parameters = []
-        self.current_sample = {}
-
-    def filled_request(self):
-        """
-        Check if all requested parameters have been filled.
-        """
-
-        return not np.any([val is None for val in
-                           self.requested_parameters.values()])
-
-    def reset_request(self):
-        """
-        Reset the requested parameters to None.
-        """
-
-        for name in self.names:
-            self.requested_parameters[name] = None
-
-    def filled_rescale(self):
-        """
-        Check is all the rescaled parameters have been filled.
-        """
-
-        return not np.any([val is None for val in
-                           self.rescale_parameters.values()])
-
-    def reset_rescale(self):
-        """
-        Reset the rescaled parameters to None.
-        """
-
-        for name in self.names:
-            self.rescale_parameters[name] = None
-
     def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None,
                  weight=1.):
         """
@@ -3044,69 +3274,37 @@ class MultivariateGaussianDist(object):
         self.mvn.append(scipy.stats.multivariate_normal(mean=self.mus[-1],
                                                         cov=self.covs[-1]))
 
-    def rescale(self, value, mode=None):
-        """
-        Rescale from a unit hypercube to multivariate Gaussian. Note that no
-        bounds are applied in the rescale function.
-
-        Parameters
-        ----------
-        value: array
-            A 1d vector sample (one for each parameter) drawn from a uniform
-            distribution between 0 and 1, or a 2d NxM array of samples where
-            N is the number of samples and M is the number of parameters.
-        mode: int
-            Specify which mode to sample from. If not set then a mode is
-            chosen randomly based on its weight.
-
-        Returns
-        -------
-        array:
-            An vector sample drawn from the multivariate Gaussian
-            distribution.
-        """
+    def _rescale(self, samp, **kwargs):
+        try:
+            mode = kwargs['mode']
+        except KeyError:
+            mode = None
 
-        # pick a mode (with a probability given by their weights)
         if mode is None:
             if self.nmodes == 1:
                 mode = 0
             else:
                 mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0]
 
-        samp = np.asarray(value)
-        if len(samp.shape) == 1:
-            samp = samp.reshape(1, self.num_vars)
-
-        if len(samp.shape) != 2:
-            raise ValueError("Array is the wrong shape")
-        elif samp.shape[1] != self.num_vars:
-            raise ValueError("Array is the wrong shape")
-
-        # draw points from unit variance, uncorrelated Gaussian
         samp = erfinv(2. * samp - 1) * 2. ** 0.5
 
         # rotate and scale to the multivariate normal shape
         samp = self.mus[mode] + self.sigmas[mode] * np.einsum('ij,kj->ik',
                                                               samp * self.sqeigvalues[mode],
                                                               self.eigvectors[mode])
+        return samp
 
-        return np.squeeze(samp)
-
-    def sample(self, size=1, mode=None):
-        """
-        Draw, and set, a sample from the multivariate Gaussian.
-
-        Parameters
-        ----------
-        mode: int
-            Specify which mode to sample from. If not set then a mode is
-            chosen randomly based on its weight.
-        """
-
-        if size is None:
-            size = 1
+    def _sample(self, size, **kwargs):
+        try:
+            mode = kwargs['mode']
+        except KeyError:
+            mode = None
 
-        # samples drawn from unit variance uncorrelated multivariate Gaussian
+        if mode is None:
+            if self.nmodes == 1:
+                mode = 0
+            else:
+                mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0]
         samps = np.zeros((size, len(self)))
         for i in range(size):
             inbound = False
@@ -3127,42 +3325,9 @@ class MultivariateGaussianDist(object):
                 if not outbound:
                     inbound = True
 
-        for i, name in enumerate(self.names):
-            if size == 1:
-                self.current_sample[name] = samps[:, i].flatten()[0]
-            else:
-                self.current_sample[name] = samps[:, i].flatten()
-
-    def ln_prob(self, value):
-        """
-        Get the log-probability of a sample. For bounded priors the
-        probability will not be properly normalised.
-
-        Parameters
-        ----------
-        value: array_like
-            A 1d vector of the sample, or 2d array of sample values with shape
-            NxM, where N is the number of samples and M is the number of
-            parameters.
-        """
-
-        samp = np.asarray(value)
-        if len(samp.shape) == 1:
-            samp = samp.reshape(1, self.num_vars)
-
-        if len(samp.shape) != 2:
-            raise ValueError("Array is the wrong shape")
-        elif samp.shape[1] != self.num_vars:
-            raise ValueError("Array is the wrong shape")
-
-        # check sample(s) is within bounds
-        outbounds = np.ones(samp.shape[0], dtype=np.bool)
-        for s, bound in zip(samp.T, self.bounds.values()):
-            outbounds = (s < bound[0]) | (s > bound[1])
-            if np.any(outbounds):
-                break
+        return samps
 
-        lnprob = -np.inf * np.ones(samp.shape[0])
+    def _ln_prob(self, samp, lnprob, outbounds):
         for j in range(samp.shape[0]):
             # loop over the modes and sum the probabilities
             for i in range(self.nmodes):
@@ -3170,55 +3335,7 @@ class MultivariateGaussianDist(object):
 
         # set out-of-bounds values to -inf
         lnprob[outbounds] = -np.inf
-
-        if samp.shape[0] == 1:
-            return lnprob[0]
-        else:
-            return lnprob
-
-    def prob(self, samp):
-        """
-        Get the probability of a sample. For bounded priors the
-        probability will not be properly normalised.
-        """
-
-        return np.exp(self.ln_prob(samp))
-
-    def get_instantiation_dict(self):
-        subclass_args = infer_args_from_method(self.__init__)
-        property_names = [p for p in dir(self.__class__)
-                          if isinstance(getattr(self.__class__, p), property)]
-        dict_with_properties = self.__dict__.copy()
-        for key in property_names:
-            dict_with_properties[key] = getattr(self, key)
-        instantiation_dict = dict()
-        for key in subclass_args:
-            if isinstance(dict_with_properties[key], list):
-                value = np.asarray(dict_with_properties[key]).tolist()
-            else:
-                value = dict_with_properties[key]
-            instantiation_dict[key] = value
-        return instantiation_dict
-
-    def __len__(self):
-        return len(self.names)
-
-    def __repr__(self):
-        """Overrides the special method __repr__.
-
-        Returns a representation of this instance that resembles how it is instantiated.
-        Works correctly for all child classes
-
-        Returns
-        -------
-        str: A string representation of this instance
-
-        """
-        dist_name = self.__class__.__name__
-        instantiation_dict = self.get_instantiation_dict()
-        args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
-                          for key in instantiation_dict])
-        return "{}({})".format(dist_name, args)
+        return lnprob
 
     def __eq__(self, other):
         if self.__class__ != other.__class__:
@@ -3256,132 +3373,136 @@ class MultivariateNormalDist(MultivariateGaussianDist):
     """ A synonym for the :class:`~bilby.core.prior.MultivariateGaussianDist` distribution."""
 
 
-class MultivariateGaussian(Prior):
+class JointPrior(Prior):
 
-    def __init__(self, mvg, name=None, latex_label=None, unit=None):
-        """
-        A prior class for a multivariate Gaussian (mixture model) prior.
+    def __init__(self, dist, name=None, latex_label=None, unit=None):
+        """This defines the single parameter Prior object for parameters that belong to a JointPriorDist
 
         Parameters
         ----------
-        mvg: MultivariateGaussianDist
-            A :class:`bilby.core.prior.MultivariateGaussianDist` object defining
-            the multivariate Gaussian distribution. This object is not copied,
-            as it needs to be shared across multiple priors, and as such its
-            contents will be altered by the prior.
+        dist: ChildClass of BaseJointPriorDist
+            The shared JointPriorDistribution that this parameter belongs to
         name: str
-            See superclass
+            Name of this parameter. Must be contained in dist.names
         latex_label: str
             See superclass
         unit: str
             See superclass
-
         """
+        if BaseJointPriorDist not in dist.__class__.__bases__:
+            raise TypeError("Must supply a JointPriorDist object instance to be shared by all joint params")
+
+        if name not in dist.names:
+            raise ValueError("'{}' is not a parameter in the JointPriorDist")
 
-        if not isinstance(mvg, MultivariateGaussianDist):
-            raise TypeError("Must supply a multivariate Gaussian object")
+        self.dist = dist
+        super(JointPrior, self).__init__(name=name, latex_label=latex_label, unit=unit, minimum=dist.bounds[name][0],
+                                         maximum=dist.bounds[name][1])
 
-        # check name is in the MultivariateGaussianDist class
-        if name not in mvg.names:
-            raise ValueError("'{}' is not a parameter in the multivariate "
-                             "Gaussian")
-        self.mvg = mvg
+    @property
+    def minimum(self):
+        return self._minimum
+
+    @minimum.setter
+    def minimum(self, minimum):
+        self._minimum = minimum
+        self.dist.bounds[self.name] = (minimum, self.dist.bounds[self.name][1])
+
+    @property
+    def maximum(self):
+        return self._maximum
 
-        super(MultivariateGaussian, self).__init__(name=name, latex_label=latex_label, unit=unit,
-                                                   minimum=mvg.bounds[name][0],
-                                                   maximum=mvg.bounds[name][1])
+    @maximum.setter
+    def maximum(self, maximum):
+        self._maximum = maximum
+        self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum)
 
-    def rescale(self, val, mode=None):
+    def rescale(self, val, **kwargs):
         """
         Scale a unit hypercube sample to the prior.
 
         Parameters
         ----------
-        mode: int
-            Specify which mode to sample from. If not set then a mode is
-            chosen randomly based on its weight.
+        val: array_like
+            value drawn from unit hypercube to be rescaled onto the prior
+        kwargs: dict
+            all kwargs passed to the dist.rescale method
+        Returns
+        -------
+        float:
+            A sample from the prior paramter.
         """
 
-        Prior.test_valid_for_rescaling(val)
-
-        # add parameter value to multivariate Gaussian
-        self.mvg.rescale_parameters[self.name] = val
+        self.test_valid_for_rescaling(val)
+        self.dist.rescale_parameters[self.name] = val
 
-        if self.mvg.filled_rescale():
-            values = np.array(list(self.mvg.rescale_parameters.values())).T
-            samples = self.mvg.rescale(values, mode=mode)
-            self.mvg.reset_rescale()
+        if self.dist.filled_rescale():
+            values = np.array(list(self.dist.rescale_parameters.values())).T
+            samples = self.dist.rescale(values, **kwargs)
+            self.dist.reset_rescale()
             return samples
         else:
             return []  # return empty list
 
-    def sample(self, size=1, mode=None):
+    def sample(self, size=1, **kwargs):
         """
         Draw a sample from the prior.
 
         Parameters
         ----------
-        mode: int
-            Specify which mode to sample from. If not set then a mode is
-            chosen randomly based on its weight.
-
+        size: int, float (defaults to 1)
+            number of samples to draw
+        kwargs: dict
+            kwargs passed to the dist.sample method
         Returns
         -------
         float:
             A sample from the prior paramter.
         """
 
-        if self.name in self.mvg.sampled_parameters:
+        if self.name in self.dist.sampled_parameters:
             logger.warning("You have already drawn a sample from parameter "
                            "'{}'. The same sample will be "
                            "returned".format(self.name))
 
-        if len(self.mvg.current_sample) == 0:
+        if len(self.dist.current_sample) == 0:
             # generate a sample
-            self.mvg.sample(size=size, mode=mode)
+            self.dist.sample(size=size, **kwargs)
 
-        sample = self.mvg.current_sample[self.name]
+        sample = self.dist.current_sample[self.name]
 
-        if self.name not in self.mvg.sampled_parameters:
-            self.mvg.sampled_parameters.append(self.name)
+        if self.name not in self.dist.sampled_parameters:
+            self.dist.sampled_parameters.append(self.name)
 
-        if len(self.mvg.sampled_parameters) == len(self.mvg):
+        if len(self.dist.sampled_parameters) == len(self.dist):
             # reset samples
-            self.mvg.reset_sampled()
+            self.dist.reset_sampled()
         self.least_recently_sampled = sample
         return sample
 
-    def prob(self, val):
-        """Return the prior probability of val
+    def ln_prob(self, val):
+        """
+        Return the natural logarithm of the prior probability. Note that this
+        will not be correctly normalised if there are bounds on the
+        distribution.
 
         Parameters
         ----------
-        val: float
-
+        val: array_like
+            value to evaluate the prior log-prob at
         Returns
         -------
         float:
-
-        """
-
-        return np.exp(self.ln_prob(val))
-
-    def ln_prob(self, val):
-        """
-        Return the natural logarithm of the prior probability. Note that this
-        will not be correctly normalised if there are bounds on the
-        distribution.
+            the logp value for the prior at given sample
         """
+        self.dist.requested_parameters[self.name] = val
 
-        # add parameter value to multivariate Gaussian
-        self.mvg.requested_parameters[self.name] = val
-
-        if self.mvg.filled_request():
+        if self.dist.filled_request():
             # all required parameters have been set
-            values = list(self.mvg.requested_parameters.values())
+            values = list(self.dist.requested_parameters.values())
 
             # check for the same number of values for each parameter
-            for i in range(len(self.mvg) - 1):
+            for i in range(len(self.dist) - 1):
                 if (isinstance(values[i], (list, np.ndarray)) or
                         isinstance(values[i + 1], (list, np.ndarray))):
                     if (isinstance(values[i], (list, np.ndarray)) and
@@ -3393,11 +3514,10 @@ class MultivariateGaussian(Prior):
                         raise ValueError("Each parameter must have the same "
                                          "number of requested values.")
 
-            lnp = self.mvg.ln_prob(np.asarray(values).T)
+            lnp = self.dist.ln_prob(np.asarray(values).T)
 
             # reset the requested parameters
-            self.mvg.reset_request()
-
+            self.dist.reset_request()
             return lnp
         else:
             # if not all parameters have been requested yet, just return 0
@@ -3415,32 +3535,33 @@ class MultivariateGaussian(Prior):
                 else:
                     return np.zeros_like(val)
 
-    @property
-    def minimum(self):
-        return self._minimum
+    def prob(self, val):
+        """Return the prior probability of val
 
-    @minimum.setter
-    def minimum(self, minimum):
-        self._minimum = minimum
+        Parameters
+        ----------
+        val: array_like
+            value to evaluate the prior prob at
 
-        # update the bounds in the MultivariateGaussianDist
-        self.mvg.bounds[self.name] = (minimum, self.mvg.bounds[self.name][1])
+        Returns
+        -------
+        float:
+            the p value for the prior at given sample
+        """
 
-    @property
-    def maximum(self):
-        return self._maximum
+        return np.exp(self.ln_prob(val))
 
-    @maximum.setter
-    def maximum(self, maximum):
-        self._maximum = maximum
 
-        # update the bounds in the MultivariateGaussianDist
-        self.mvg.bounds[self.name] = (self.mvg.bounds[self.name][0], maximum)
+class MultivariateGaussian(JointPrior):
+    def __init__(self, dist, name=None, latex_label=None, unit=None):
+        if not isinstance(dist, MultivariateGaussianDist):
+            raise JointPriorDistError("dist object must be instance of MultivariateGaussianDist")
+        super(MultivariateGaussian, self).__init__(dist=dist, name=name, latex_label=latex_label, unit=unit)
 
 
 class MultivariateNormal(MultivariateGaussian):
-    """A synonym for the :class:`bilby.core.prior.MultivariateGaussian`
-     prior distribution."""
+    """ A synonym for the :class:`bilby.core.prior.MultivariateGaussian`
+        prior distribution."""
 
 
 def conditional_prior_factory(prior_class):
@@ -3682,3 +3803,7 @@ class ConditionalPriorDictException(PriorDictException):
 
 class IllegalConditionsException(ConditionalPriorDictException):
     """ Exception class to handle prior dicts that contain unresolvable conditions. """
+
+
+class JointPriorDistError(PriorException):
+    """ Class for Error handling of JointPriorDists for JointPriors """
diff --git a/test/prior_test.py b/test/prior_test.py
index 77c939b8e02399cf34154c254e0f2e3901645826..8b2b9efe9c1f73e67c0ef8dab5256b777bfb35b2 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -5,7 +5,6 @@ from mock import Mock
 import mock
 import numpy as np
 import os
-from collections import OrderedDict
 import scipy.stats as ss
 
 
@@ -197,10 +196,10 @@ class TestPriorClasses(unittest.TestCase):
             bilby.core.prior.Gamma(name='test', unit='unit', k=1, theta=1),
             bilby.core.prior.ChiSquared(name='test', unit='unit', nu=2),
             bilby.gw.prior.AlignedSpin(name='test', unit='unit'),
-            bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testa', unit='unit'),
-            bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testb', unit='unit'),
-            bilby.core.prior.MultivariateNormal(mvg=mvn, name='testa', unit='unit'),
-            bilby.core.prior.MultivariateNormal(mvg=mvn, name='testb', unit='unit'),
+            bilby.core.prior.MultivariateGaussian(dist=mvg, name='testa', unit='unit'),
+            bilby.core.prior.MultivariateGaussian(dist=mvg, name='testb', unit='unit'),
+            bilby.core.prior.MultivariateNormal(dist=mvn, name='testa', unit='unit'),
+            bilby.core.prior.MultivariateNormal(dist=mvn, name='testb', unit='unit'),
             bilby.core.prior.ConditionalDeltaFunction(condition_func=condition_func, name='test', unit='unit', peak=1),
             bilby.core.prior.ConditionalGaussian(condition_func=condition_func, name='test', unit='unit', mu=0, sigma=1),
             bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, name='test', unit='unit', alpha=0, minimum=0, maximum=1),
@@ -230,9 +229,9 @@ class TestPriorClasses(unittest.TestCase):
     def test_minimum_rescaling(self):
         """Test the the rescaling works as expected."""
         for prior in self.priors:
-            if isinstance(prior, bilby.core.prior.MultivariateGaussian):
+            if bilby.core.prior.JointPrior in prior.__class__.__mro__:
                 minimum_sample = prior.rescale(0)
-                if prior.mvg.filled_rescale():
+                if prior.dist.filled_rescale():
                     self.assertAlmostEqual(minimum_sample[0], prior.minimum)
                     self.assertAlmostEqual(minimum_sample[1], prior.minimum)
             else:
@@ -242,9 +241,9 @@ class TestPriorClasses(unittest.TestCase):
     def test_maximum_rescaling(self):
         """Test the the rescaling works as expected."""
         for prior in self.priors:
-            if isinstance(prior, bilby.core.prior.MultivariateGaussian):
+            if bilby.core.prior.JointPrior in prior.__class__.__mro__:
                 maximum_sample = prior.rescale(0)
-                if prior.mvg.filled_rescale():
+                if prior.dist.filled_rescale():
                     self.assertAlmostEqual(maximum_sample[0], prior.maximum)
                     self.assertAlmostEqual(maximum_sample[1], prior.maximum)
             else:
@@ -255,8 +254,8 @@ class TestPriorClasses(unittest.TestCase):
         """Test the the rescaling works as expected."""
         for prior in self.priors:
             many_samples = prior.rescale(np.random.uniform(0, 1, 1000))
-            if isinstance(prior, bilby.core.prior.MultivariateGaussian):
-                if not prior.mvg.filled_rescale():
+            if bilby.core.prior.JointPrior in prior.__class__.__mro__:
+                if not prior.dist.filled_rescale():
                     continue
             self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
 
@@ -304,7 +303,7 @@ class TestPriorClasses(unittest.TestCase):
     def test_prob_and_ln_prob(self):
         for prior in self.priors:
             sample = prior.sample()
-            if not isinstance(prior, bilby.core.prior.MultivariateGaussian):
+            if not bilby.core.prior.JointPrior in prior.__class__.__mro__:
                 # due to the way that the Multivariate Gaussian prior must sequentially call
                 # the prob and ln_prob functions, it must be ignored in this test.
                 self.assertAlmostEqual(np.log(prior.prob(sample)), prior.ln_prob(sample), 12)
@@ -312,7 +311,7 @@ class TestPriorClasses(unittest.TestCase):
     def test_many_prob_and_many_ln_prob(self):
         for prior in self.priors:
             samples = prior.sample(10)
-            if not isinstance(prior, bilby.core.prior.MultivariateGaussian):
+            if not bilby.core.prior.JointPrior in prior.__class__.__mro__:
                 ln_probs = prior.ln_prob(samples)
                 probs = prior.prob(samples)
                 for sample, logp, p in zip(samples, ln_probs, probs):
@@ -323,9 +322,8 @@ class TestPriorClasses(unittest.TestCase):
         domain = np.linspace(0, 1, 100)
         threshold = 1e-9
         for prior in self.priors:
-            if isinstance(prior, (
-                    bilby.core.prior.DeltaFunction,
-                    bilby.core.prior.MultivariateGaussian)):
+            if isinstance(prior, bilby.core.prior.DeltaFunction) or \
+                    bilby.core.prior.JointPrior in prior.__class__.__mro__:
                 continue
             rescaled = prior.rescale(domain)
             max_difference = max(np.abs(domain - prior.cdf(rescaled)))
@@ -490,7 +488,7 @@ class TestPriorClasses(unittest.TestCase):
                 continue
             if isinstance(prior, bilby.core.prior.Cauchy):
                 continue
-            if isinstance(prior, bilby.core.prior.MultivariateGaussian):
+            if bilby.core.prior.JointPrior in prior.__class__.__mro__:
                 continue
             elif isinstance(prior, bilby.core.prior.Gaussian):
                 domain = np.linspace(-1e2, 1e2, 1000)
@@ -1248,10 +1246,10 @@ class TestJsonIO(unittest.TestCase):
             aa=bilby.core.prior.Gamma(name='test', unit='unit', k=1, theta=1),
             ab=bilby.core.prior.ChiSquared(name='test', unit='unit', nu=2),
             ac=bilby.gw.prior.AlignedSpin(name='test', unit='unit'),
-            ad=bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testa', unit='unit'),
-            ae=bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testb', unit='unit'),
-            af=bilby.core.prior.MultivariateNormal(mvg=mvn, name='testa', unit='unit'),
-            ag=bilby.core.prior.MultivariateNormal(mvg=mvn, name='testb', unit='unit')
+            ad=bilby.core.prior.MultivariateGaussian(dist=mvg, name='testa', unit='unit'),
+            ae=bilby.core.prior.MultivariateGaussian(dist=mvg, name='testb', unit='unit'),
+            af=bilby.core.prior.MultivariateNormal(dist=mvn, name='testa', unit='unit'),
+            ag=bilby.core.prior.MultivariateNormal(dist=mvn, name='testb', unit='unit')
         ))
 
     def test_read_write_to_json(self):