diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index be34de2031e1f4ed0f9a37a257037b64dde89011..7e57c7f1ff27bd86ef8182d7ec7efb86476bf158 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -2434,50 +2434,20 @@ class FermiDirac(Prior):
             return lnp
 
 
-class MultivariateGaussianDist(object):
+class JointPriorDist(object):
 
-    def __init__(self, names, nmodes=1, mus=None, sigmas=None, corrcoefs=None,
-                 covs=None, weights=None, bounds=None):
+    def __init__(self, names, bounds=None):
         """
-        A class defining a multi-variate Gaussian, allowing multiple modes for
-        a Gaussian mixture model.
+        A class defining JointPriorDist that will be overwritten with child
+        classes defining the joint prior distribtuions between given parameters,
 
-        Note: if using a multivariate Gaussian prior, with bounds, this can
-        lead to biases in the marginal likelihood estimate and posterior
-        estimate for nested samplers routines that rely on sampling from a unit
-        hypercube and having a prior transform, e.g., nestle, dynesty and
-        MultiNest.
 
         Parameters
         ----------
         names: list
-            A list of the parameter names in the multivariate Gaussian. The
+            A list of the parameter names in the JointPriorDist. The
             listed parameters must have the same order that they appear in
-            the lists of means, standard deviations, and the correlation
-            coefficient, or covariance, matrices.
-        nmodes: int
-            The number of modes for the mixture model. This defaults to 1,
-            which will be checked against the shape of the other inputs.
-        mus: array_like
-            A list of lists of means of each mode in a multivariate Gaussian
-            mixture model. A single list can be given for a single mode. If
-            this is None then means at zero will be assumed.
-        sigmas: array_like
-            A list of lists of the standard deviations of each mode of the
-            multivariate Gaussian. If supplying a correlation coefficient
-            matrix rather than a covariance matrix these values must be given.
-            If this is None unit variances will be assumed.
-        corrcoefs: array
-            A list of square matrices containing the correlation coefficients
-            of the parameters for each mode. If this is None it will be assumed
-            that the parameters are uncorrelated.
-        covs: array
-            A list of square matrices containing the covariance matrix of the
-            multivariate Gaussian.
-        weights: list
-            A list of weights (relative probabilities) for each mode of the
-            multivariate Gaussian. This will default to equal weights for each
-            mode.
+            the lists of statistical parameters that may be passed in child class
         bounds: list
             A list of bounds on each parameter. The defaults are for bounds at
             +/- infinity.
@@ -2488,7 +2458,7 @@ class MultivariateGaussianDist(object):
         else:
             self.names = names
 
-        self.num_vars = len(self.names)  # the number of parameters
+        self.num_vars = len(self.names)
 
         # set the bounds for each parameter
         if isinstance(bounds, list):
@@ -2513,83 +2483,12 @@ class MultivariateGaussianDist(object):
                                "a prior transform.")
         else:
             bounds = [(-np.inf, np.inf) for _ in self.names]
-
-        # set bounds as dictionary
         self.bounds = {name: val for name, val in zip(self.names, bounds)}
 
-        self.mus = []
-        self.covs = []
-        self.corrcoefs = []
-        self.sigmas = []
-        self.weights = []
-        self.eigvalues = []
-        self.eigvectors = []
-        self.sqeigvalues = []  # square root of the eigenvalues
-        self.mvn = []  # list of multivariate normal distributions
-
         self._current_sample = {}  # initialise empty sample
         self._uncorrelated = None
         self._current_lnprob = None
 
-        # put values in lists if required
-        if nmodes == 1:
-            if mus is not None:
-                if len(np.shape(mus)) == 1:
-                    mus = [mus]
-                elif len(np.shape(mus)) == 0:
-                    raise ValueError("Must supply a list of means")
-            if sigmas is not None:
-                if len(np.shape(sigmas)) == 1:
-                    sigmas = [sigmas]
-                elif len(np.shape(sigmas)) == 0:
-                    raise ValueError("Must supply a list of standard "
-                                     "deviations")
-            if covs is not None:
-                if isinstance(covs, np.ndarray):
-                    covs = [covs]
-                elif isinstance(covs, list):
-                    if len(np.shape(covs)) == 2:
-                        covs = [np.array(covs)]
-                    elif len(np.shape(covs)) != 3:
-                        raise TypeError("List of covariances the wrong shape")
-                else:
-                    raise TypeError("Must pass a list of covariances")
-            if corrcoefs is not None:
-                if isinstance(corrcoefs, np.ndarray):
-                    corrcoefs = [corrcoefs]
-                elif isinstance(corrcoefs, list):
-                    if len(np.shape(corrcoefs)) == 2:
-                        corrcoefs = [np.array(corrcoefs)]
-                    elif len(np.shape(corrcoefs)) != 3:
-                        raise TypeError("List of correlation coefficients the wrong shape")
-                elif not isinstance(corrcoefs, list):
-                    raise TypeError("Must pass a list of correlation "
-                                    "coefficients")
-            if weights is not None:
-                if isinstance(weights, (int, float)):
-                    weights = [weights]
-                elif isinstance(weights, list):
-                    if len(weights) != 1:
-                        raise ValueError("Wrong number of weights given")
-
-        for val in [mus, sigmas, covs, corrcoefs, weights]:
-            if val is not None and not isinstance(val, list):
-                raise TypeError("Value must be a list")
-            else:
-                if val is not None and len(val) != nmodes:
-                    raise ValueError("Wrong number of modes given")
-
-        # add the modes
-        self.nmodes = 0
-        for i in range(nmodes):
-            mu = mus[i] if mus is not None else None
-            sigma = sigmas[i] if sigmas is not None else None
-            corrcoef = corrcoefs[i] if corrcoefs is not None else None
-            cov = covs[i] if covs is not None else None
-            weight = weights[i] if weights is not None else 1.
-
-            self.add_mode(mu, sigma, corrcoef, cov, weight)
-
         # a dictionary of the parameters as requested by the prior
         self.requested_parameters = OrderedDict()
         self.reset_request()
@@ -2626,16 +2525,310 @@ class MultivariateGaussianDist(object):
         Check is all the rescaled parameters have been filled.
         """
 
-        return not np.any([val is None for val in
-                           self.rescale_parameters.values()])
+        return not np.any([val is None for val in
+                           self.rescale_parameters.values()])
+
+    def reset_rescale(self):
+        """
+        Reset the rescaled parameters to None.
+        """
+
+        for name in self.names:
+            self.rescale_parameters[name] = None
+
+    def _get_instantiation_dict(self):
+        subclass_args = infer_args_from_method(self.__init__)
+        property_names = [p for p in dir(self.__class__)
+                          if isinstance(getattr(self.__class__, p), property)]
+        dict_with_properties = self.__dict__.copy()
+        for key in property_names:
+            dict_with_properties[key] = getattr(self, key)
+        instantiation_dict = OrderedDict()
+        for key in subclass_args:
+            if isinstance(dict_with_properties[key], list):
+                value = np.asarray(dict_with_properties[key]).tolist()
+            else:
+                value = dict_with_properties[key]
+            instantiation_dict[key] = value
+        return instantiation_dict
+
+    def __len__(self):
+        return len(self.names)
+
+    def __repr__(self):
+        """Overrides the special method __repr__.
+
+        Returns a representation of this instance that resembles how it is instantiated.
+        Works correctly for all child classes
+
+        Returns
+        -------
+        str: A string representation of this instance
+
+        """
+        dist_name = self.__class__.__name__
+        instantiation_dict = self._get_instantiation_dict()
+        args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
+                          for key in instantiation_dict])
+        return "{}({})".format(dist_name, args)
+
+    def prob(self, samp):
+        """
+        Get the probability of a sample. For bounded priors the
+        probability will not be properly normalised.
+        """
+
+        return np.exp(self.ln_prob(samp))
+
+    def _check_samp(self, value):
+        samp = np.asarray(value)
+        if len(samp.shape) == 1:
+            samp = samp.reshape(1, self.num_vars)
+
+        if len(samp.shape) != 2:
+            raise ValueError("Array is the wrong shape")
+        elif samp.shape[1] != self.num_vars:
+            raise ValueError("Array is the wrong shape")
+
+        # check sample(s) is within bounds
+        outbounds = np.ones(samp.shape[0], dtype=np.bool)
+        for s, bound in zip(samp.T, self.bounds.values()):
+            outbounds = (s < bound[0]) | (s > bound[1])
+            if np.any(outbounds):
+                break
+        return samp, outbounds
+
+    def ln_prob(self, value):
+        """
+        Get the log-probability of a sample. For bounded priors the
+        probability will not be properly normalised.
+
+        Parameters
+        ----------
+        value: array_like
+            A 1d vector of the sample, or 2d array of sample values with shape
+            NxM, where N is the number of samples and M is the number of
+            parameters.
+        """
+
+        samp, outbounds = self._check_samp(value)
+        lnprob = -np.inf * np.ones(samp.shape[0])
+        lnprob = self._ln_prob(samp, lnprob, outbounds)
+        if samp.shape[0] == 1:
+            return lnprob[0]
+        else:
+            return lnprob
+
+    def _ln_prob(self, samp, lnprob, outbounds):
+        '''
+        CHILD CLASS OVERWRITES THIS METHOD AND FILLS IN THIS PART OF ln_prob METHOD to samplethe lnprob for
+        the value of this sample
+        '''
+        return lnprob
+
+    def sample(self, size=1, **kwargs):
+        """
+        Draw, and set, a sample from the Dist
+
+        Parameters
+        ----------
+        size: int
+            number of samples to generate, defualts to 1
+        """
+
+        if size is None:
+            size = 1
+        # samples drawn from unit variance uncorrelated multivariate Gaussian
+        samps = self._draw_samp(size=size, **kwargs)
+        for i, name in enumerate(self.names):
+            if size == 1:
+                self.current_sample[name] = samps[:, i].flatten()[0]
+            else:
+                self.current_sample[name] = samps[:, i].flatten()
+
+    def _draw_samp(self, size, **kwargs):
+        """
+        Draw, and set, a sample from the joint dist (needs to be ovewritten by child class)
+
+        Parameters
+        ----------
+        size: int
+            number of samples to generate, defualts to 1
+        """
+        samps = np.zeros((size, len(self)))
+        """
+        Here is where the subclass where overwrite sampling method
+        """
+        return samps
+
+    def rescale(self, value, **kwargs):
+        """
+        Rescale from a unit hypercube to multivariate Gaussian. Note that no
+        bounds are applied in the rescale function.
+
+        Parameters
+        ----------
+        value: array
+            A 1d vector sample (one for each parameter) drawn from a uniform
+            distribution between 0 and 1, or a 2d NxM array of samples where
+            N is the number of samples and M is the number of parameters.
+        mode: int
+            Specify which mode to sample from. If not set then a mode is
+            chosen randomly based on its weight.
+
+        Returns
+        -------
+        array:
+            An vector sample drawn from the multivariate Gaussian
+            distribution.
+        """
+
+        # pick a mode (with a probability given by their weights)
+
+        samp = np.asarray(value)
+        if len(samp.shape) == 1:
+            samp = samp.reshape(1, self.num_vars)
+
+        if len(samp.shape) != 2:
+            raise ValueError("Array is the wrong shape")
+        elif samp.shape[1] != self.num_vars:
+            raise ValueError("Array is the wrong shape")
+
+        samp = self._rescale(samp, **kwargs)
+        return np.squeeze(samp)
+
+    def _rescale(self, samp, **kwargs):
+        '''
+        needs to be overwritten
+        :param samp:
+        :param kwargs:
+        :return:
+        '''
+        return samp
+
+
+class MultivariateGaussianDist(JointPriorDist):
+
+    def __init__(self, names, nmodes=1, mus=None, sigmas=None, corrcoefs=None,
+                 covs=None, weights=None, bounds=None):
+        """
+        A class defining a multi-variate Gaussian, allowing multiple modes for
+        a Gaussian mixture model.
+
+        Note: if using a multivariate Gaussian prior, with bounds, this can
+        lead to biases in the marginal likelihood estimate and posterior
+        estimate for nested samplers routines that rely on sampling from a unit
+        hypercube and having a prior transform, e.g., nestle, dynesty and
+        MultiNest.
+
+        Parameters
+        ----------
+        names: list
+            A list of the parameter names in the multivariate Gaussian. The
+            listed parameters must have the same order that they appear in
+            the lists of means, standard deviations, and the correlation
+            coefficient, or covariance, matrices.
+        nmodes: int
+            The number of modes for the mixture model. This defaults to 1,
+            which will be checked against the shape of the other inputs.
+        mus: array_like
+            A list of lists of means of each mode in a multivariate Gaussian
+            mixture model. A single list can be given for a single mode. If
+            this is None then means at zero will be assumed.
+        sigmas: array_like
+            A list of lists of the standard deviations of each mode of the
+            multivariate Gaussian. If supplying a correlation coefficient
+            matrix rather than a covariance matrix these values must be given.
+            If this is None unit variances will be assumed.
+        corrcoefs: array
+            A list of square matrices containing the correlation coefficients
+            of the parameters for each mode. If this is None it will be assumed
+            that the parameters are uncorrelated.
+        covs: array
+            A list of square matrices containing the covariance matrix of the
+            multivariate Gaussian.
+        weights: list
+            A list of weights (relative probabilities) for each mode of the
+            multivariate Gaussian. This will default to equal weights for each
+            mode.
+        bounds: list
+            A list of bounds on each parameter. The defaults are for bounds at
+            +/- infinity.
+        """
+        super(MultivariateGaussianDist, self).__init__(names=names, bounds=bounds)
+
+        self.mus = []
+        self.covs = []
+        self.corrcoefs = []
+        self.sigmas = []
+        self.weights = []
+        self.eigvalues = []
+        self.eigvectors = []
+        self.sqeigvalues = []  # square root of the eigenvalues
+        self.mvn = []  # list of multivariate normal distributions
+
+        self._current_sample = {}  # initialise empty sample
+        self._uncorrelated = None
+        self._current_lnprob = None
+
+        # put values in lists if required
+        if nmodes == 1:
+            if mus is not None:
+                if len(np.shape(mus)) == 1:
+                    mus = [mus]
+                elif len(np.shape(mus)) == 0:
+                    raise ValueError("Must supply a list of means")
+            if sigmas is not None:
+                if len(np.shape(sigmas)) == 1:
+                    sigmas = [sigmas]
+                elif len(np.shape(sigmas)) == 0:
+                    raise ValueError("Must supply a list of standard "
+                                     "deviations")
+            if covs is not None:
+                if isinstance(covs, np.ndarray):
+                    covs = [covs]
+                elif isinstance(covs, list):
+                    if len(np.shape(covs)) == 2:
+                        covs = [np.array(covs)]
+                    elif len(np.shape(covs)) != 3:
+                        raise TypeError("List of covariances the wrong shape")
+                else:
+                    raise TypeError("Must pass a list of covariances")
+            if corrcoefs is not None:
+                if isinstance(corrcoefs, np.ndarray):
+                    corrcoefs = [corrcoefs]
+                elif isinstance(corrcoefs, list):
+                    if len(np.shape(corrcoefs)) == 2:
+                        corrcoefs = [np.array(corrcoefs)]
+                    elif len(np.shape(corrcoefs)) != 3:
+                        raise TypeError("List of correlation coefficients the wrong shape")
+                elif not isinstance(corrcoefs, list):
+                    raise TypeError("Must pass a list of correlation "
+                                    "coefficients")
+            if weights is not None:
+                if isinstance(weights, (int, float)):
+                    weights = [weights]
+                elif isinstance(weights, list):
+                    if len(weights) != 1:
+                        raise ValueError("Wrong number of weights given")
+
+        for val in [mus, sigmas, covs, corrcoefs, weights]:
+            if val is not None and not isinstance(val, list):
+                raise TypeError("Value must be a list")
+            else:
+                if val is not None and len(val) != nmodes:
+                    raise ValueError("Wrong number of modes given")
 
-    def reset_rescale(self):
-        """
-        Reset the rescaled parameters to None.
-        """
+        # add the modes
+        self.nmodes = 0
+        for i in range(nmodes):
+            mu = mus[i] if mus is not None else None
+            sigma = sigmas[i] if sigmas is not None else None
+            corrcoef = corrcoefs[i] if corrcoefs is not None else None
+            cov = covs[i] if covs is not None else None
+            weight = weights[i] if weights is not None else 1.
 
-        for name in self.names:
-            self.rescale_parameters[name] = None
+            self.add_mode(mu, sigma, corrcoef, cov, weight)
 
     def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None,
                  weight=1.):
@@ -2744,69 +2937,37 @@ class MultivariateGaussianDist(object):
         self.mvn.append(scipy.stats.multivariate_normal(mean=self.mus[-1],
                                                         cov=self.covs[-1]))
 
-    def rescale(self, value, mode=None):
-        """
-        Rescale from a unit hypercube to multivariate Gaussian. Note that no
-        bounds are applied in the rescale function.
-
-        Parameters
-        ----------
-        value: array
-            A 1d vector sample (one for each parameter) drawn from a uniform
-            distribution between 0 and 1, or a 2d NxM array of samples where
-            N is the number of samples and M is the number of parameters.
-        mode: int
-            Specify which mode to sample from. If not set then a mode is
-            chosen randomly based on its weight.
-
-        Returns
-        -------
-        array:
-            An vector sample drawn from the multivariate Gaussian
-            distribution.
-        """
+    def _rescale(self, samp, **kwargs):
+        try:
+            mode = kwargs['mode']
+        except KeyError:
+            mode = None
 
-        # pick a mode (with a probability given by their weights)
         if mode is None:
             if self.nmodes == 1:
                 mode = 0
             else:
                 mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0]
 
-        samp = np.asarray(value)
-        if len(samp.shape) == 1:
-            samp = samp.reshape(1, self.num_vars)
-
-        if len(samp.shape) != 2:
-            raise ValueError("Array is the wrong shape")
-        elif samp.shape[1] != self.num_vars:
-            raise ValueError("Array is the wrong shape")
-
-        # draw points from unit variance, uncorrelated Gaussian
         samp = erfinv(2. * samp - 1) * 2. ** 0.5
 
         # rotate and scale to the multivariate normal shape
         samp = self.mus[mode] + self.sigmas[mode] * np.einsum('ij,kj->ik',
                                                               samp * self.sqeigvalues[mode],
                                                               self.eigvectors[mode])
+        return samp
 
-        return np.squeeze(samp)
-
-    def sample(self, size=1, mode=None):
-        """
-        Draw, and set, a sample from the multivariate Gaussian.
-
-        Parameters
-        ----------
-        mode: int
-            Specify which mode to sample from. If not set then a mode is
-            chosen randomly based on its weight.
-        """
-
-        if size is None:
-            size = 1
+    def _draw_samp(self, size, **kwargs):
+        try:
+            mode = kwargs['mode']
+        except KeyError:
+            mode = None
 
-        # samples drawn from unit variance uncorrelated multivariate Gaussian
+        if mode is None:
+            if self.nmodes == 1:
+                mode = 0
+            else:
+                mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0]
         samps = np.zeros((size, len(self)))
         for i in range(size):
             inbound = False
@@ -2827,42 +2988,9 @@ class MultivariateGaussianDist(object):
                 if not outbound:
                     inbound = True
 
-        for i, name in enumerate(self.names):
-            if size == 1:
-                self.current_sample[name] = samps[:, i].flatten()[0]
-            else:
-                self.current_sample[name] = samps[:, i].flatten()
-
-    def ln_prob(self, value):
-        """
-        Get the log-probability of a sample. For bounded priors the
-        probability will not be properly normalised.
-
-        Parameters
-        ----------
-        value: array_like
-            A 1d vector of the sample, or 2d array of sample values with shape
-            NxM, where N is the number of samples and M is the number of
-            parameters.
-        """
-
-        samp = np.asarray(value)
-        if len(samp.shape) == 1:
-            samp = samp.reshape(1, self.num_vars)
-
-        if len(samp.shape) != 2:
-            raise ValueError("Array is the wrong shape")
-        elif samp.shape[1] != self.num_vars:
-            raise ValueError("Array is the wrong shape")
-
-        # check sample(s) is within bounds
-        outbounds = np.ones(samp.shape[0], dtype=np.bool)
-        for s, bound in zip(samp.T, self.bounds.values()):
-            outbounds = (s < bound[0]) | (s > bound[1])
-            if np.any(outbounds):
-                break
+        return samps
 
-        lnprob = -np.inf * np.ones(samp.shape[0])
+    def _ln_prob(self, samp, lnprob, outbounds):
         for j in range(samp.shape[0]):
             # loop over the modes and sum the probabilities
             for i in range(self.nmodes):
@@ -2870,55 +2998,7 @@ class MultivariateGaussianDist(object):
 
         # set out-of-bounds values to -inf
         lnprob[outbounds] = -np.inf
-
-        if samp.shape[0] == 1:
-            return lnprob[0]
-        else:
-            return lnprob
-
-    def prob(self, samp):
-        """
-        Get the probability of a sample. For bounded priors the
-        probability will not be properly normalised.
-        """
-
-        return np.exp(self.ln_prob(samp))
-
-    def _get_instantiation_dict(self):
-        subclass_args = infer_args_from_method(self.__init__)
-        property_names = [p for p in dir(self.__class__)
-                          if isinstance(getattr(self.__class__, p), property)]
-        dict_with_properties = self.__dict__.copy()
-        for key in property_names:
-            dict_with_properties[key] = getattr(self, key)
-        instantiation_dict = OrderedDict()
-        for key in subclass_args:
-            if isinstance(dict_with_properties[key], list):
-                value = np.asarray(dict_with_properties[key]).tolist()
-            else:
-                value = dict_with_properties[key]
-            instantiation_dict[key] = value
-        return instantiation_dict
-
-    def __len__(self):
-        return len(self.names)
-
-    def __repr__(self):
-        """Overrides the special method __repr__.
-
-        Returns a representation of this instance that resembles how it is instantiated.
-        Works correctly for all child classes
-
-        Returns
-        -------
-        str: A string representation of this instance
-
-        """
-        dist_name = self.__class__.__name__
-        instantiation_dict = self._get_instantiation_dict()
-        args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
-                          for key in instantiation_dict])
-        return "{}({})".format(dist_name, args)
+        return lnprob
 
     def __eq__(self, other):
         if self.__class__ != other.__class__:
@@ -2956,74 +3036,62 @@ class MultivariateNormalDist(MultivariateGaussianDist):
     """ A synonym for the :class:`~bilby.core.prior.MultivariateGaussianDist` distribution."""
 
 
-class MultivariateGaussian(Prior):
+class JointPrior(Prior):
 
-    def __init__(self, mvg, name=None, latex_label=None, unit=None):
-        """
-        A prior class for a multivariate Gaussian (mixture model) prior.
+    def __init__(self, dist, name=None, latex_label=None, unit=None):
+        if dist.__class__.__bases__ [0] != JointPriorDist:
+            raise TypeError("Must supply a JointPriorDist object instance to be shared by all joint params")
 
-        Parameters
-        ----------
-        mvg: MultivariateGaussianDist
-            A :class:`bilby.core.prior.MultivariateGaussianDist` object defining
-            the multivariate Gaussian distribution. This object is not copied,
-            as it needs to be shared across multiple priors, and as such its
-            contents will be altered by the prior.
-        name: str
-            See superclass
-        latex_label: str
-            See superclass
-        unit: str
-            See superclass
+        if name not in dist.names:
+            raise ValueError("'{}' is not a parameter in the JointPriorDist")
 
-        """
+        self.dist = dist
+        super(JointPrior, self).__init__(name=name, latex_label=latex_label, unit=unit,
+                                                   minimum=dist.bounds[name][0],
+                                                   maximum=dist.bounds[name][1])
+
+    @property
+    def minimum(self):
+        return self._minimum
 
-        if not isinstance(mvg, MultivariateGaussianDist):
-            raise TypeError("Must supply a multivariate Gaussian object")
+    @minimum.setter
+    def minimum(self, minimum):
+        self._minimum = minimum
+        self.dist.bounds[self.name] = (minimum, self.dist.bounds[self.name][1])
 
-        # check name is in the MultivariateGaussianDist class
-        if name not in mvg.names:
-            raise ValueError("'{}' is not a parameter in the multivariate "
-                             "Gaussian")
-        self.mvg = mvg
+    @property
+    def maximum(self):
+        return self._maximum
 
-        super(MultivariateGaussian, self).__init__(name=name, latex_label=latex_label, unit=unit,
-                                                   minimum=mvg.bounds[name][0],
-                                                   maximum=mvg.bounds[name][1])
+    @maximum.setter
+    def maximum(self, maximum):
+        self._maximum = maximum
+        self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum)
 
-    def rescale(self, val, mode=None):
+    def rescale(self, val, **kwargs):
         """
         Scale a unit hypercube sample to the prior.
-
-        Parameters
-        ----------
-        mode: int
-            Specify which mode to sample from. If not set then a mode is
-            chosen randomly based on its weight.
         """
 
-        Prior.test_valid_for_rescaling(val)
+        self.test_valid_for_rescaling(val)
 
         # add parameter value to multivariate Gaussian
-        self.mvg.rescale_parameters[self.name] = val
+        self.dist.rescale_parameters[self.name] = val
 
-        if self.mvg.filled_rescale():
-            values = np.array(list(self.mvg.rescale_parameters.values())).T
-            samples = self.mvg.rescale(values, mode=mode)
-            self.mvg.reset_rescale()
+        if self.dist.filled_rescale():
+            values = np.array(list(self.dist.rescale_parameters.values())).T
+            samples = self.dist.rescale(values, **kwargs)
+            self.dist.reset_rescale()
             return samples
         else:
             return []  # return empty list
 
-    def sample(self, size=1, mode=None):
+    def sample(self, size=1, **kwargs):
         """
         Draw a sample from the prior.
 
         Parameters
         ----------
-        mode: int
-            Specify which mode to sample from. If not set then a mode is
-            chosen randomly based on its weight.
 
         Returns
         -------
@@ -3031,41 +3099,25 @@ class MultivariateGaussian(Prior):
             A sample from the prior paramter.
         """
 
-        if self.name in self.mvg.sampled_parameters:
+        if self.name in self.dist.sampled_parameters:
             logger.warning("You have already drawn a sample from parameter "
                            "'{}'. The same sample will be "
                            "returned".format(self.name))
 
-        if len(self.mvg.current_sample) == 0:
+        if len(self.dist.current_sample) == 0:
             # generate a sample
-            self.mvg.sample(size=size, mode=mode)
+            self.dist.sample(size=size, **kwargs)
 
-        sample = self.mvg.current_sample[self.name]
+        sample = self.dist.current_sample[self.name]
 
-        if self.name not in self.mvg.sampled_parameters:
-            self.mvg.sampled_parameters.append(self.name)
+        if self.name not in self.dist.sampled_parameters:
+            self.dist.sampled_parameters.append(self.name)
 
-        if len(self.mvg.sampled_parameters) == len(self.mvg):
+        if len(self.dist.sampled_parameters) == len(self.dist):
             # reset samples
-            self.mvg.reset_sampled()
-
+            self.dist.reset_sampled()
         return sample
 
-    def prob(self, val):
-        """Return the prior probability of val
-
-        Parameters
-        ----------
-        val: float
-
-        Returns
-        -------
-        float:
-
-        """
-
-        return np.exp(self.ln_prob(val))
-
     def ln_prob(self, val):
         """
         Return the natural logarithm of the prior probability. Note that this
@@ -3074,14 +3126,14 @@ class MultivariateGaussian(Prior):
         """
 
         # add parameter value to multivariate Gaussian
-        self.mvg.requested_parameters[self.name] = val
+        self.dist.requested_parameters[self.name] = val
 
-        if self.mvg.filled_request():
+        if self.dist.filled_request():
             # all required parameters have been set
-            values = list(self.mvg.requested_parameters.values())
+            values = list(self.dist.requested_parameters.values())
 
             # check for the same number of values for each parameter
-            for i in range(len(self.mvg) - 1):
+            for i in range(len(self.dist) - 1):
                 if (isinstance(values[i], (list, np.ndarray)) or
                         isinstance(values[i + 1], (list, np.ndarray))):
                     if (isinstance(values[i], (list, np.ndarray)) and
@@ -3093,11 +3145,10 @@ class MultivariateGaussian(Prior):
                         raise ValueError("Each parameter must have the same "
                                          "number of requested values.")
 
-            lnp = self.mvg.ln_prob(np.asarray(values).T)
+            lnp = self.dist.ln_prob(np.asarray(values).T)
 
             # reset the requested parameters
-            self.mvg.reset_request()
-
+            self.dist.reset_request()
             return lnp
         else:
             # if not all parameters have been requested yet, just return 0
@@ -3115,27 +3166,27 @@ class MultivariateGaussian(Prior):
                 else:
                     return np.zeros_like(val)
 
-    @property
-    def minimum(self):
-        return self._minimum
+    def prob(self, val):
+        """Return the prior probability of val
 
-    @minimum.setter
-    def minimum(self, minimum):
-        self._minimum = minimum
+        Parameters
+        ----------
+        val: float
 
-        # update the bounds in the MultivariateGaussianDist
-        self.mvg.bounds[self.name] = (minimum, self.mvg.bounds[self.name][1])
+        Returns
+        -------
+        float:
 
-    @property
-    def maximum(self):
-        return self._maximum
+        """
+
+        return np.exp(self.ln_prob(val))
 
-    @maximum.setter
-    def maximum(self, maximum):
-        self._maximum = maximum
 
-        # update the bounds in the MultivariateGaussianDist
-        self.mvg.bounds[self.name] = (self.mvg.bounds[self.name][0], maximum)
+class MultivariateGaussian(JointPrior):
+    """
+    A synonmy class for MultiVariateNormal / deprecated, now use JointPrior with the dist object controlling
+    the type of joint prior
+    """
 
 
 class MultivariateNormal(MultivariateGaussian):
diff --git a/examples/core_examples/multivariate_gaussian_prior.py b/examples/core_examples/multivariate_gaussian_prior.py
index 53d8f94e47ac41833bd536d2f596f3a78cc76e8d..e362d88c8e31162ebd6a3c08120f9e51a71c36db 100644
--- a/examples/core_examples/multivariate_gaussian_prior.py
+++ b/examples/core_examples/multivariate_gaussian_prior.py
@@ -50,8 +50,8 @@ mvg = bilby.core.prior.MultivariateGaussianDist(names, nmodes=2, mus=mus,
                                                 corrcoefs=corrcoefs,
                                                 sigmas=sigmas, weights=weights)
 priors = dict()
-priors['m'] = bilby.core.prior.MultivariateGaussian(mvg, 'm')
-priors['c'] = bilby.core.prior.MultivariateGaussian(mvg, 'c')
+priors['m'] = bilby.core.prior.JointPrior(mvg, 'm')
+priors['c'] = bilby.core.prior.JointPrior(mvg, 'c')
 
 result = bilby.run_sampler(
     likelihood=likelihood, priors=priors, sampler='dynesty', nlive=4000,