diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 3a53fe994448499caa40cf14deb8696f2d24baee..6d2e9b418c0e9c415440af6bfad0d2f3f1353abd 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -16,7 +16,8 @@ from .utils import logger, infer_args_from_method, check_directory_exists_and_if
 
 
 class PriorDict(OrderedDict):
-    def __init__(self, dictionary=None, filename=None):
+    def __init__(self, dictionary=None, filename=None,
+                 conversion_function=None):
         """ A set of priors
 
         Parameters
@@ -25,6 +26,9 @@ class PriorDict(OrderedDict):
             If given, a dictionary to generate the prior set.
         filename: str, None
             If given, a file containing the prior to generate the prior set.
+        conversion_function: func
+            Function to convert between sampled parameters and constraints.
+            Default is no conversion.
         """
         OrderedDict.__init__(self)
         if isinstance(dictionary, dict):
@@ -40,6 +44,35 @@ class PriorDict(OrderedDict):
 
         self.convert_floats_to_delta_functions()
 
+        if conversion_function is not None:
+            self.conversion_function = conversion_function
+        else:
+            self.conversion_function = self.default_conversion_function
+
+    def evaluate_constraints(self, sample):
+        out_sample = self.conversion_function(sample)
+        prob = 1
+        for key in self:
+            if isinstance(self[key], Constraint) and key in out_sample:
+                prob *= self[key].prob(out_sample[key])
+        return prob
+
+    def default_conversion_function(self, sample):
+        """
+        Placeholder parameter conversion function.
+
+        Parameters
+        ----------
+        sample: dict
+            Dictionary to convert
+
+        Returns
+        -------
+        sample: dict
+            Same as input
+        """
+        return sample
+
     def to_file(self, outdir, label):
         """ Write the prior distribution to file.
 
@@ -168,7 +201,7 @@ class PriorDict(OrderedDict):
         -------
         dict: Dictionary of the samples
         """
-        return self.sample_subset(keys=self.keys(), size=size)
+        return self.sample_subset_constrained(keys=list(self.keys()), size=size)
 
     def sample_subset(self, keys=iter([]), size=None):
         """Draw samples from the prior set for parameters which are not a DeltaFunction
@@ -188,11 +221,35 @@ class PriorDict(OrderedDict):
         samples = dict()
         for key in keys:
             if isinstance(self[key], Prior):
-                samples[key] = self[key].sample(size=size)
+                if isinstance(self[key], Constraint):
+                    continue
+                else:
+                    samples[key] = self[key].sample(size=size)
             else:
                 logger.debug('{} not a known prior.'.format(key))
         return samples
 
+    def sample_subset_constrained(self, keys=iter([]), size=None):
+        if size is None or size == 1:
+            while True:
+                sample = self.sample_subset(keys=keys, size=size)
+                if self.evaluate_constraints(sample):
+                    return sample
+        else:
+            needed = np.prod(size)
+            all_samples = {key: np.array([]) for key in keys}
+            _first_key = list(all_samples.keys())[0]
+            while len(all_samples[_first_key]) <= needed:
+                samples = self.sample_subset(keys=keys, size=needed)
+                keep = np.array(self.evaluate_constraints(samples), dtype=bool)
+                for key in samples:
+                    all_samples[key] = np.hstack(
+                        [all_samples[key], samples[key][keep].flatten()])
+            all_samples = {key: np.reshape(all_samples[key][:needed], size)
+                           for key in all_samples
+                           if not isinstance(self[key], Constraint)}
+            return all_samples
+
     def prob(self, sample, **kwargs):
         """
 
@@ -208,7 +265,14 @@ class PriorDict(OrderedDict):
         float: Joint probability of all individual sample probabilities
 
         """
-        return np.product([self[key].prob(sample[key]) for key in sample], **kwargs)
+        prob = np.product([self[key].prob(sample[key])
+                           for key in sample], **kwargs)
+        if prob == 0:
+            return 0
+        elif self.evaluate_constraints(sample):
+            return prob
+        else:
+            return 0
 
     def ln_prob(self, sample, axis=None):
         """
@@ -226,8 +290,14 @@ class PriorDict(OrderedDict):
             Joint log probability of all the individual sample probabilities
 
         """
-        return np.sum([self[key].ln_prob(sample[key]) for key in sample],
-                      axis=axis)
+        ln_prob = np.sum([self[key].ln_prob(sample[key])
+                          for key in sample], axis=axis)
+        if np.isinf(ln_prob):
+            return ln_prob
+        elif self.evaluate_constraints(sample):
+            return ln_prob
+        else:
+            return -np.inf
 
     def rescale(self, keys, theta):
         """Rescale samples from unit cube to prior
@@ -259,6 +329,8 @@ class PriorDict(OrderedDict):
         """
         redundant = False
         for key in self:
+            if isinstance(self[key], Constraint):
+                continue
             temp = self.copy()
             del temp[key]
             if temp.test_redundancy(key, disable_logging=True):
@@ -490,7 +562,7 @@ class Prior(object):
         bool: Whether it's fixed or not!
 
         """
-        return isinstance(self, DeltaFunction)
+        return isinstance(self, (Constraint, DeltaFunction))
 
     @property
     def latex_label(self):
@@ -553,6 +625,20 @@ class Prior(object):
         return label
 
 
+class Constraint(Prior):
+
+    def __init__(self, minimum, maximum, name=None, latex_label=None,
+                 unit=None):
+        Prior.__init__(self, minimum=minimum, maximum=maximum, name=name,
+                       latex_label=latex_label, unit=unit)
+
+    def prob(self, val):
+        return (val > self.minimum) & (val < self.maximum)
+
+    def ln_prob(self, val):
+        return np.log((val > self.minimum) & (val < self.maximum))
+
+
 class DeltaFunction(Prior):
 
     def __init__(self, peak, name=None, latex_label=None, unit=None):
diff --git a/bilby/core/result.py b/bilby/core/result.py
index 43c36491e4f09645be08b32858c0047e3b3fdd5d..82fb07792c4ee844a2e6ec5cb67ddbc6c271e4be 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -91,7 +91,8 @@ def read_in_result(filename=None, outdir=None, label=None, extension='json', gzi
 class Result(object):
     def __init__(self, label='no_label', outdir='.', sampler=None,
                  search_parameter_keys=None, fixed_parameter_keys=None,
-                 priors=None, sampler_kwargs=None, injection_parameters=None,
+                 constraint_parameter_keys=None, priors=None,
+                 sampler_kwargs=None, injection_parameters=None,
                  meta_data=None, posterior=None, samples=None,
                  nested_samples=None, log_evidence=np.nan,
                  log_evidence_err=np.nan, log_noise_evidence=np.nan,
@@ -106,9 +107,10 @@ class Result(object):
         ----------
         label, outdir, sampler: str
             The label, output directory, and sampler used
-        search_parameter_keys, fixed_parameter_keys: list
-            Lists of the search and fixed parameter keys. Elemenents of the
-            list should be of type `str` and matchs the keys of the `prior`
+        search_parameter_keys, fixed_parameter_keys, constraint_parameter_keys: list
+            Lists of the search, constraint, and fixed parameter keys.
+            Elements of the list should be of type `str` and match the keys
+            of the `prior`
         priors: dict, bilby.core.prior.PriorDict
             A dictionary of the priors used in the run
         sampler_kwargs: dict
@@ -155,6 +157,7 @@ class Result(object):
         self.sampler = sampler
         self.search_parameter_keys = search_parameter_keys
         self.fixed_parameter_keys = fixed_parameter_keys
+        self.constraint_parameter_keys = constraint_parameter_keys
         self.parameter_labels = parameter_labels
         self.parameter_labels_with_unit = parameter_labels_with_unit
         self.priors = priors
@@ -384,7 +387,8 @@ class Result(object):
             'label', 'outdir', 'sampler', 'log_evidence', 'log_evidence_err',
             'log_noise_evidence', 'log_bayes_factor', 'priors', 'posterior',
             'injection_parameters', 'meta_data', 'search_parameter_keys',
-            'fixed_parameter_keys', 'sampling_time', 'sampler_kwargs',
+            'fixed_parameter_keys', 'constraint_parameter_keys',
+            'sampling_time', 'sampler_kwargs',
             'log_likelihood_evaluations', 'log_prior_evaluations', 'samples',
             'nested_samples', 'walkers', 'nburn', 'parameter_labels',
             'parameter_labels_with_unit', 'version']
@@ -1004,8 +1008,12 @@ class Result(object):
             data_frame['log_likelihood'] = getattr(
                 self, 'log_likelihood_evaluations', np.nan)
             if self.log_prior_evaluations is None:
-                data_frame['log_prior'] = self.priors.ln_prob(
-                    data_frame[self.search_parameter_keys], axis=0)
+                ln_prior = list()
+                for ii in range(len(data_frame)):
+                    ln_prior.append(
+                        self.priors.ln_prob(dict(
+                            data_frame[self.search_parameter_keys].iloc[ii])))
+                data_frame['log_prior'] = np.array(ln_prior)
             else:
                 data_frame['log_prior'] = self.log_prior_evaluations
         if conversion_function is not None:
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index c07f208aec77fa1d27c41501d3c3833776f3d556..daa204293290a26dd2017ac32b52621a80a63a26 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -5,7 +5,7 @@ import numpy as np
 from pandas import DataFrame
 
 from ..utils import logger, command_line_args
-from ..prior import Prior, PriorDict
+from ..prior import Prior, PriorDict, DeltaFunction, Constraint
 from ..result import Result, read_in_result
 
 
@@ -102,8 +102,9 @@ class Sampler(object):
         self.external_sampler_function = None
         self.plot = plot
 
-        self.__search_parameter_keys = []
-        self.__fixed_parameter_keys = []
+        self._search_parameter_keys = list()
+        self._fixed_parameter_keys = list()
+        self._constraint_keys = list()
         self._initialise_parameters()
         self._verify_parameters()
         self._verify_use_ratio()
@@ -118,28 +119,33 @@ class Sampler(object):
     @property
     def search_parameter_keys(self):
         """list: List of parameter keys that are being sampled"""
-        return self.__search_parameter_keys
+        return self._search_parameter_keys
 
     @property
     def fixed_parameter_keys(self):
         """list: List of parameter keys that are not being sampled"""
-        return self.__fixed_parameter_keys
+        return self._fixed_parameter_keys
+
+    @property
+    def constraint_parameter_keys(self):
+        """list: List of parameters providing prior constraints"""
+        return self._constraint_parameter_keys
 
     @property
     def ndim(self):
         """int: Number of dimensions of the search parameter space"""
-        return len(self.__search_parameter_keys)
+        return len(self._search_parameter_keys)
 
     @property
     def kwargs(self):
         """dict: Container for the kwargs. Has more sophisticated logic in subclasses """
-        return self.__kwargs
+        return self._kwargs
 
     @kwargs.setter
     def kwargs(self, kwargs):
-        self.__kwargs = self.default_kwargs.copy()
+        self._kwargs = self.default_kwargs.copy()
         self._translate_kwargs(kwargs)
-        self.__kwargs.update(kwargs)
+        self._kwargs.update(kwargs)
         self._verify_kwargs_against_default_kwargs()
 
     def _translate_kwargs(self, kwargs):
@@ -179,17 +185,17 @@ class Sampler(object):
         for key in self.priors:
             if isinstance(self.priors[key], Prior) \
                     and self.priors[key].is_fixed is False:
-                self.__search_parameter_keys.append(key)
-            elif isinstance(self.priors[key], Prior) \
-                    and self.priors[key].is_fixed is True:
-                self.likelihood.parameters[key] = \
-                    self.priors[key].sample()
-                self.__fixed_parameter_keys.append(key)
+                self._search_parameter_keys.append(key)
+            elif isinstance(self.priors[key], Constraint):
+                self._constraint_keys.append(key)
+            elif isinstance(self.priors[key], DeltaFunction):
+                self.likelihood.parameters[key] = self.priors[key].sample()
+                self._fixed_parameter_keys.append(key)
 
         logger.info("Search parameters:")
-        for key in self.__search_parameter_keys:
+        for key in self._search_parameter_keys + self._constraint_keys:
             logger.info('  {} = {}'.format(key, self.priors[key]))
-        for key in self.__fixed_parameter_keys:
+        for key in self._fixed_parameter_keys:
             logger.info('  {} = {}'.format(key, self.priors[key].peak))
 
     def _initialise_result(self, result_class):
@@ -202,8 +208,9 @@ class Sampler(object):
         result_kwargs = dict(
             label=self.label, outdir=self.outdir,
             sampler=self.__class__.__name__.lower(),
-            search_parameter_keys=self.__search_parameter_keys,
-            fixed_parameter_keys=self.__fixed_parameter_keys,
+            search_parameter_keys=self._search_parameter_keys,
+            fixed_parameter_keys=self._fixed_parameter_keys,
+            constraint_parameter_keys=self._constraint_keys,
             priors=self.priors, meta_data=self.meta_data,
             injection_parameters=self.injection_parameters,
             sampler_kwargs=self.kwargs)
@@ -227,6 +234,8 @@ class Sampler(object):
             prior can't be sampled.
         """
         for key in self.priors:
+            if isinstance(self.priors[key], Constraint):
+                continue
             try:
                 self.likelihood.parameters[key] = self.priors[key].sample()
             except AttributeError as e:
@@ -248,7 +257,9 @@ class Sampler(object):
         self._check_if_priors_can_be_sampled()
         try:
             t1 = datetime.datetime.now()
-            self.likelihood.log_likelihood()
+            theta = [self.priors[key].sample()
+                     for key in self._search_parameter_keys]
+            self.log_likelihood(theta)
             self._log_likelihood_eval_time = (
                 datetime.datetime.now() - t1).total_seconds()
             if self._log_likelihood_eval_time == 0:
@@ -296,7 +307,7 @@ class Sampler(object):
         -------
         list: Properly rescaled sampled values
         """
-        return self.priors.rescale(self.__search_parameter_keys, theta)
+        return self.priors.rescale(self._search_parameter_keys, theta)
 
     def log_prior(self, theta):
         """
@@ -308,11 +319,12 @@ class Sampler(object):
 
         Returns
         -------
-        float: TODO: Fill in proper explanation of what this is.
+        float: Joint ln prior probability of theta
 
         """
-        return self.priors.ln_prob({
-            key: t for key, t in zip(self.__search_parameter_keys, theta)})
+        params = {
+            key: t for key, t in zip(self._search_parameter_keys, theta)}
+        return self.priors.ln_prob(params)
 
     def log_likelihood(self, theta):
         """
@@ -328,8 +340,9 @@ class Sampler(object):
             likelihood.parameter values
 
         """
-        for i, k in enumerate(self.__search_parameter_keys):
-            self.likelihood.parameters[k] = theta[i]
+        params = {
+            key: t for key, t in zip(self._search_parameter_keys, theta)}
+        self.likelihood.parameters.update(params)
         if self.use_ratio:
             return self.likelihood.log_likelihood_ratio()
         else:
@@ -347,7 +360,7 @@ class Sampler(object):
         """
         new_sample = self.priors.sample()
         draw = np.array(list(new_sample[key]
-                             for key in self.__search_parameter_keys))
+                             for key in self._search_parameter_keys))
         self.check_draw(draw)
         return draw
 
@@ -459,6 +472,26 @@ class NestedSampler(Sampler):
             idxs.append(idx[0])
         return unsorted_loglikelihoods[idxs]
 
+    def log_likelihood(self, theta):
+        """
+        Since some nested samplers don't call the log_prior method, evaluate
+        the prior constraint here.
+
+        Parameters
+        theta: array-like
+            Parameter values at which to evaluate likelihood
+
+        Returns
+        -------
+        float: log_likelihood
+        """
+        if self.priors.evaluate_constraints({
+                key: theta[ii] for ii, key in
+                enumerate(self.search_parameter_keys)}):
+            return Sampler.log_likelihood(self, theta)
+        else:
+            return np.nan_to_num(-np.inf)
+
 
 class MCMCSampler(Sampler):
     nwalkers_equiv_kwargs = ['nwalker', 'nwalkers', 'draws']
diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
index 2c5dc5cab870d7bb60ce3d219bd2cf475f7aa43f..359e445f5fe36368a444c68e24e3cf0d9f7ee800 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc3.py
@@ -4,9 +4,8 @@ from collections import OrderedDict
 
 import numpy as np
 
-from ..utils import derivatives, logger, infer_args_from_method
-from ..prior import Prior, DeltaFunction, Sine, Cosine, PowerLaw
-from ..result import Result
+from ..utils import derivatives, infer_args_from_method
+from ..prior import DeltaFunction, Sine, Cosine, PowerLaw
 from .base_sampler import Sampler, MCMCSampler
 from ..likelihood import GaussianLikelihood, PoissonLikelihood, ExponentialLikelihood, \
     StudentTLikelihood
@@ -67,8 +66,8 @@ class Pymc3(MCMCSampler):
         Sampler.__init__(self, likelihood, priors, outdir=outdir, label=label,
                          use_ratio=use_ratio, plot=plot,
                          skip_import_verification=skip_import_verification, **kwargs)
-        self.draws = self.__kwargs['draws']
-        self.chains = self.__kwargs['chains']
+        self.draws = self._kwargs['draws']
+        self.chains = self._kwargs['chains']
 
     @staticmethod
     def _import_external_sampler():
@@ -97,71 +96,6 @@ class Pymc3(MCMCSampler):
         """
         pass
 
-    def _initialise_parameters(self):
-        """
-        Change `_initialise_parameters()`, so that it does call the `sample`
-        method in the Prior class.
-
-        """
-
-        self.__search_parameter_keys = []
-        self.__fixed_parameter_keys = []
-
-        for key in self.priors:
-            if isinstance(self.priors[key], Prior) \
-                    and self.priors[key].is_fixed is False:
-                self.__search_parameter_keys.append(key)
-            elif isinstance(self.priors[key], Prior) \
-                    and self.priors[key].is_fixed is True:
-                self.__fixed_parameter_keys.append(key)
-
-        logger.info("Search parameters:")
-        for key in self.__search_parameter_keys:
-            logger.info('  {} = {}'.format(key, self.priors[key]))
-        for key in self.__fixed_parameter_keys:
-            logger.info('  {} = {}'.format(key, self.priors[key].peak))
-
-    def _initialise_result(self, result_class):
-        """
-        Initialise results within Pymc3 subclass.
-        """
-
-        result_kwargs = dict(
-            label=self.label, outdir=self.outdir,
-            sampler=self.__class__.__name__.lower(),
-            search_parameter_keys=self.__search_parameter_keys,
-            fixed_parameter_keys=self.__fixed_parameter_keys,
-            priors=self.priors, meta_data=self.meta_data,
-            injection_parameters=self.injection_parameters,
-            sampler_kwargs=self.kwargs)
-
-        if result_class is None:
-            result = Result(**result_kwargs)
-        elif issubclass(result_class, Result):
-            result = result_class(**result_kwargs)
-        else:
-            raise ValueError(
-                "Input result_class={} not understood".format(result_class))
-
-        return result
-
-    @property
-    def kwargs(self):
-        """ Ensures that proper keyword arguments are used for the Pymc3 sampler.
-
-        Returns
-        -------
-        dict: Keyword arguments used for the Nestle Sampler
-
-        """
-        return self.__kwargs
-
-    @kwargs.setter
-    def kwargs(self, kwargs):
-        self.__kwargs = self.default_kwargs.copy()
-        self.__kwargs.update(kwargs)
-        self._verify_kwargs_against_default_kwargs()
-
     def setup_prior_mapping(self):
         """
         Set the mapping between predefined bilby priors and the equivalent
@@ -393,8 +327,8 @@ class Pymc3(MCMCSampler):
         # set the step method
         pymc3, STEP_METHODS, floatX = self._import_external_sampler()
         step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
-        if 'step' in self.__kwargs:
-            self.step_method = self.__kwargs.pop('step')
+        if 'step' in self._kwargs:
+            self.step_method = self._kwargs.pop('step')
 
             # 'step' could be a dictionary of methods for different parameters,
             # so check for this
@@ -402,7 +336,7 @@ class Pymc3(MCMCSampler):
                 pass
             elif isinstance(self.step_method, (dict, OrderedDict)):
                 for key in self.step_method:
-                    if key not in self.__search_parameter_keys:
+                    if key not in self._search_parameter_keys:
                         raise ValueError("Setting a step method for an unknown parameter '{}'".format(key))
                     else:
                         # check if using a compound step (a list of step
@@ -780,11 +714,11 @@ class Pymc3(MCMCSampler):
                 pymc3.StudentT('likelihood', nu=self.likelihood.nu, mu=model, sd=self.likelihood.sigma,
                                observed=self.likelihood.y)
             elif isinstance(self.likelihood, (GravitationalWaveTransient, BasicGravitationalWaveTransient)):
-                # set theano Op - pass __search_parameter_keys, which only contains non-fixed variables
-                logl = LogLike(self.__search_parameter_keys, self.likelihood, self.pymc3_priors)
+                # set theano Op - pass _search_parameter_keys, which only contains non-fixed variables
+                logl = LogLike(self._search_parameter_keys, self.likelihood, self.pymc3_priors)
 
                 parameters = OrderedDict()
-                for key in self.__search_parameter_keys:
+                for key in self._search_parameter_keys:
                     try:
                         parameters[key] = self.pymc3_priors[key]
                     except KeyError:
diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index ec39959c956b10e28a7a2ec3fb76e0cf35848ab8..7814e1f45657c860f2acce9b1647d594468a6c5c 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -841,8 +841,8 @@ def generate_component_spins(sample):
     """
     output_sample = sample.copy()
     spin_conversion_parameters =\
-        ['theta_jn', 'phi_jl', 'tilt_1', 'tilt_2', 'phi_12', 'a_1', 'a_2', 'mass_1',
-         'mass_2', 'reference_frequency', 'phase']
+        ['theta_jn', 'phi_jl', 'tilt_1', 'tilt_2', 'phi_12', 'a_1', 'a_2',
+         'mass_1', 'mass_2', 'reference_frequency', 'phase']
     if all(key in output_sample.keys() for key in spin_conversion_parameters):
         output_sample['iota'], output_sample['spin_1x'],\
             output_sample['spin_1y'], output_sample['spin_1z'], \
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 2ff4128eebd60cae69c742352ba6a27627aad065..58790e41cff09904f4accfe656f3a01f7ead8e70 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -4,8 +4,12 @@ import numpy as np
 from scipy.interpolate import UnivariateSpline
 
 from ..core.prior import (PriorDict, Uniform, Prior, DeltaFunction, Gaussian,
-                          Interped)
+                          Interped, Constraint)
 from ..core.utils import infer_args_from_method, logger
+from .conversion import (
+    convert_to_lal_binary_black_hole_parameters,
+    convert_to_lal_binary_neutron_star_parameters, generate_mass_parameters,
+    generate_tidal_parameters, fill_from_fixed_priors)
 from .cosmology import get_cosmology
 
 try:
@@ -193,7 +197,8 @@ class AlignedSpin(Interped):
 
 
 class BBHPriorDict(PriorDict):
-    def __init__(self, dictionary=None, filename=None, aligned_spin=False):
+    def __init__(self, dictionary=None, filename=None, aligned_spin=False,
+                 conversion_function=None):
         """ Initialises a Prior set for Binary Black holes
 
         Parameters
@@ -202,6 +207,10 @@ class BBHPriorDict(PriorDict):
             See superclass
         filename: str, optional
             See superclass
+        conversion_function: func
+            Function to convert between sampled parameters and constraints.
+            By default this generates many additional parameters, see
+            BBHPriorDict.default_conversion_function
         """
         basedir = os.path.join(os.path.dirname(__file__), 'prior_files')
         if dictionary is None and filename is None:
@@ -214,7 +223,36 @@ class BBHPriorDict(PriorDict):
         elif filename is not None:
             if not os.path.isfile(filename):
                 filename = os.path.join(os.path.dirname(__file__), 'prior_files', filename)
-        PriorDict.__init__(self, dictionary=dictionary, filename=filename)
+        PriorDict.__init__(self, dictionary=dictionary, filename=filename,
+                           conversion_function=conversion_function)
+
+    def default_conversion_function(self, sample):
+        """
+        Default parameter conversion function for BBH signals.
+
+        This generates:
+        - the parameters passed to source.lal_binary_black_hole
+        - all mass parameters
+
+        It does not generate:
+        - component spins
+        - source-frame parameters
+
+        Parameters
+        ----------
+        sample: dict
+            Dictionary to convert
+
+        Returns
+        -------
+        sample: dict
+            Same as input
+        """
+        out_sample = fill_from_fixed_priors(sample, self)
+        out_sample, _ = convert_to_lal_binary_black_hole_parameters(out_sample)
+        out_sample = generate_mass_parameters(out_sample)
+
+        return out_sample
 
     def test_redundancy(self, key, disable_logging=False):
         """
@@ -237,6 +275,9 @@ class BBHPriorDict(PriorDict):
             logger.debug('{} already in prior'.format(key))
             return True
 
+        sampling_parameters = {key for key in self if not isinstance(
+            self[key], (DeltaFunction, Constraint))}
+
         mass_parameters = {'mass_1', 'mass_2', 'chirp_mass', 'total_mass', 'mass_ratio', 'symmetric_mass_ratio'}
         spin_tilt_1_parameters = {'tilt_1', 'cos_tilt_1'}
         spin_tilt_2_parameters = {'tilt_2', 'cos_tilt_2'}
@@ -250,7 +291,8 @@ class BBHPriorDict(PriorDict):
                      spin_tilt_1_parameters, spin_tilt_2_parameters,
                      inclination_parameters, distance_parameters]):
             if key in parameter_set:
-                if len(parameter_set.intersection(self)) >= independent_parameters:
+                if len(parameter_set.intersection(
+                        sampling_parameters)) >= independent_parameters:
                     logger.disabled = disable_logging
                     logger.warning('{} already in prior. '
                                    'This may lead to unexpected behaviour.'
@@ -262,7 +304,8 @@ class BBHPriorDict(PriorDict):
 
 class BNSPriorDict(PriorDict):
 
-    def __init__(self, dictionary=None, filename=None, aligned_spin=True):
+    def __init__(self, dictionary=None, filename=None, aligned_spin=True,
+                 conversion_function=None):
         """ Initialises a Prior set for Binary Neutron Stars
 
         Parameters
@@ -271,6 +314,10 @@ class BNSPriorDict(PriorDict):
             See superclass
         filename: str, optional
             See superclass
+        conversion_function: func
+            Function to convert between sampled parameters and constraints.
+            By default this generates many additional parameters, see
+            BNSPriorDict.default_conversion_function
         """
         if not aligned_spin:
             logger.warning('Non-aligned spins not yet supported for BNS.')
@@ -280,7 +327,37 @@ class BNSPriorDict(PriorDict):
         elif filename is not None:
             if not os.path.isfile(filename):
                 filename = os.path.join(os.path.dirname(__file__), 'prior_files', filename)
-        PriorDict.__init__(self, dictionary=dictionary, filename=filename)
+        PriorDict.__init__(self, dictionary=dictionary, filename=filename,
+                           conversion_function=conversion_function)
+
+    def default_conversion_function(self, sample):
+        """
+        Default parameter conversion function for BNS signals.
+
+        This generates:
+        - the parameters passed to source.lal_binary_neutron_star
+        - all mass parameters
+        - all tidal parameters
+
+        It does not generate:
+        - component spins
+        - source-frame parameters
+
+        Parameters
+        ----------
+        sample: dict
+            Dictionary to convert
+
+        Returns
+        -------
+        sample: dict
+            Same as input
+        """
+        out_sample = fill_from_fixed_priors(sample, self)
+        out_sample, _ = convert_to_lal_binary_neutron_star_parameters(out_sample)
+        out_sample = generate_mass_parameters(out_sample)
+        out_sample = generate_tidal_parameters(out_sample)
+        return out_sample
 
     def test_redundancy(self, key, disable_logging=False):
         logger.disabled = disable_logging
@@ -292,18 +369,21 @@ class BNSPriorDict(PriorDict):
             return True
         redundant = False
 
+        sampling_parameters = {key for key in self if not isinstance(
+            self[key], (DeltaFunction, Constraint))}
+
         tidal_parameters = \
             {'lambda_1', 'lambda_2', 'lambda_tilde', 'delta_lambda'}
 
         if key in tidal_parameters:
-            if len(tidal_parameters.intersection(self)) > 2:
+            if len(tidal_parameters.intersection(sampling_parameters)) > 2:
                 redundant = True
                 logger.disabled = disable_logging
                 logger.warning('{} already in prior. '
                                'This may lead to unexpected behaviour.'
                                .format(tidal_parameters.intersection(self)))
                 logger.disabled = False
-            elif len(tidal_parameters.intersection(self)) == 2:
+            elif len(tidal_parameters.intersection(sampling_parameters)) == 2:
                 redundant = True
         return redundant
 
diff --git a/bilby/gw/prior_files/GW150914.prior b/bilby/gw/prior_files/GW150914.prior
index 0eef13a9734f8f0e37dc6c578e3d3a0721e28dad..410fc927682662918f3d229640b564b672e4a86e 100644
--- a/bilby/gw/prior_files/GW150914.prior
+++ b/bilby/gw/prior_files/GW150914.prior
@@ -1,6 +1,7 @@
 # These are the default priors for analysing GW150914.
 mass_1 = Uniform(name='mass_1', minimum=30, maximum=50, unit='$M_{\\odot}$')
 mass_2 = Uniform(name='mass_2', minimum=20, maximum=40, unit='$M_{\\odot}$')
+mass_ratio =  Constraint(name='mass_ratio', minimum=0.125, maximum=1)
 a_1 =  Uniform(name='a_1', minimum=0, maximum=0.8)
 a_2 =  Uniform(name='a_2', minimum=0, maximum=0.8)
 tilt_1 =  Sine(name='tilt_1')
diff --git a/bilby/gw/prior_files/aligned_spin_binary_black_holes.prior b/bilby/gw/prior_files/aligned_spin_binary_black_holes.prior
index 52649990cf3ba75125caa21572efa01fcf9ddf84..ca49a506b37a1fd7d8e855029932e8c437335a4c 100644
--- a/bilby/gw/prior_files/aligned_spin_binary_black_holes.prior
+++ b/bilby/gw/prior_files/aligned_spin_binary_black_holes.prior
@@ -4,6 +4,7 @@
 # Lines beginning "#" are ignored.
 mass_1 = Uniform(name='mass_1', minimum=5, maximum=100, unit='$M_{\\odot}$')
 mass_2 = Uniform(name='mass_2', minimum=5, maximum=100, unit='$M_{\\odot}$')
+mass_ratio =  Constraint(name='mass_ratio', minimum=0.125, maximum=1)
 # chirp_mass = Uniform(name='chirp_mass', minimum=25, maximum=100, unit='$M_{\\odot}$')
 # total_mass =  Uniform(name='total_mass', minimum=10, maximum=200, unit='$M_{\\odot}$')
 # mass_ratio =  Uniform(name='mass_ratio', minimum=0.125, maximum=1)
diff --git a/bilby/gw/prior_files/binary_black_holes.prior b/bilby/gw/prior_files/binary_black_holes.prior
index e41cc73e2f765fbddd821693604ac15e302586f2..e79bd5baa754eb09aeb2a9de41b47adf82b9452c 100644
--- a/bilby/gw/prior_files/binary_black_holes.prior
+++ b/bilby/gw/prior_files/binary_black_holes.prior
@@ -4,6 +4,7 @@
 # Lines beginning "#" are ignored.
 mass_1 = Uniform(name='mass_1', minimum=5, maximum=100, unit='$M_{\\odot}$')
 mass_2 = Uniform(name='mass_2', minimum=5, maximum=100, unit='$M_{\\odot}$')
+mass_ratio =  Constraint(name='mass_ratio', minimum=0.125, maximum=1)
 # chirp_mass = Uniform(name='chirp_mass', minimum=25, maximum=100, unit='$M_{\\odot}$')
 # total_mass =  Uniform(name='total_mass', minimum=10, maximum=200, unit='$M_{\\odot}$')
 # mass_ratio =  Uniform(name='mass_ratio', minimum=0.125, maximum=1)
diff --git a/bilby/gw/prior_files/binary_neutron_stars.prior b/bilby/gw/prior_files/binary_neutron_stars.prior
index 08ec88f2ffb8bd6536563cb20723aabb589c7c13..eff1a2f2a4923b520cf3e242ee1bc1545cb6c64f 100644
--- a/bilby/gw/prior_files/binary_neutron_stars.prior
+++ b/bilby/gw/prior_files/binary_neutron_stars.prior
@@ -4,6 +4,7 @@
 # Lines beginning "#" are ignored.
 mass_1 = Uniform(name='mass_1', minimum=1, maximum=2, unit='$M_{\\odot}$')
 mass_2 = Uniform(name='mass_2', minimum=1, maximum=2, unit='$M_{\\odot}$')
+mass_ratio =  Constraint(name='mass_ratio', minimum=0.125, maximum=1)
 # chirp_mass = Uniform(name='chirp_mass', minimum=0.87, maximum=1.74, unit='$M_{\\odot}$')
 # total_mass =  Uniform(name='total_mass', minimum=2, maximum=4, unit='$M_{\\odot}$')
 # mass_ratio =  Uniform(name='mass_ratio', minimum=0.5, maximum=1)
diff --git a/bilby/gw/source.py b/bilby/gw/source.py
index 29f033ba8e02db1cdb1723979376d021468fe8c5..ecbefa1511478723fa374cc1f207b256d77fc03b 100644
--- a/bilby/gw/source.py
+++ b/bilby/gw/source.py
@@ -221,9 +221,6 @@ def _base_lal_cbc_fd_waveform(
     frequency_bounds = ((frequency_array >= minimum_frequency) *
                         (frequency_array <= maximum_frequency))
 
-    if mass_2 > mass_1:
-        return None
-
     luminosity_distance = luminosity_distance * 1e6 * utils.parsec
     mass_1 = mass_1 * utils.solar_mass
     mass_2 = mass_2 * utils.solar_mass
@@ -378,9 +375,6 @@ def roq(frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1,
         quadratic frequency nodes.
 
     """
-    if mass_2 > mass_1:
-        return None
-
     frequency_nodes_linear = waveform_arguments['frequency_nodes_linear']
     frequency_nodes_quadratic = waveform_arguments['frequency_nodes_quadratic']
     reference_frequency = getattr(waveform_arguments,
diff --git a/examples/injection_examples/roq_example.py b/examples/injection_examples/roq_example.py
index 319251dccd89e975c76d503d88e8c753a1b51300..6d54aaa99459347cc5d791f8798d0f0450903ed7 100644
--- a/examples/injection_examples/roq_example.py
+++ b/examples/injection_examples/roq_example.py
@@ -32,7 +32,7 @@ duration = 4
 sampling_frequency = 2048
 
 injection_parameters = dict(
-    chirp_mass=36., mass_ratio=0.9, a_1=0.4, a_2=0.3, tilt_1=0.0, tilt_2=0.0,
+    mass_1=36.0, mass_2=29.0, a_1=0.4, a_2=0.3, tilt_1=0.0, tilt_2=0.0,
     phi_12=1.7, phi_jl=0.3, luminosity_distance=1000., theta_jn=0.4, psi=0.659,
     phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
 
@@ -62,15 +62,14 @@ search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
                             approximant='IMRPhenomPv2'),
     parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters)
 
+# Here we add constraints on chirp mass and mass ratio to the prior
 priors = bilby.gw.prior.BBHPriorDict()
 for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'theta_jn', 'phase', 'psi', 'ra',
             'dec', 'phi_12', 'phi_jl', 'luminosity_distance']:
     priors[key] = injection_parameters[key]
-priors.pop('mass_1')
-priors.pop('mass_2')
-priors['chirp_mass'] = bilby.core.prior.Uniform(
-    15, 40, latex_label='$\\mathcal{M}$')
-priors['mass_ratio'] = bilby.core.prior.Uniform(0.5, 1, latex_label='$q$')
+priors['chirp_mass'] = bilby.core.prior.Constraint(
+    name='chirp_mass', minimum=12.5, maximum=45)
+priors['mass_ratio'] = bilby.core.prior.Constraint(0.125, 1, name='mass_ratio')
 priors['geocent_time'] = bilby.core.prior.Uniform(
     injection_parameters['geocent_time'] - 0.1,
     injection_parameters['geocent_time'] + 0.1, latex_label='$t_c$', unit='s')
diff --git a/test/gw_likelihood_test.py b/test/gw_likelihood_test.py
index 0811070fc161af2b8619d7c19ea9682860ccce00..1c29e03dcb0e0c4ddfc644d82e0a269275035fce 100644
--- a/test/gw_likelihood_test.py
+++ b/test/gw_likelihood_test.py
@@ -54,10 +54,10 @@ class TestBasicGWTransient(unittest.TestCase):
     def test_likelihood_zero_when_waveform_is_none(self):
         """Test log likelihood returns np.nan_to_num(-np.inf) when the
         waveform is None"""
-        self.likelihood.parameters['mass_2'] = 32
+        self.likelihood.waveform_generator.frequency_domain_strain = \
+            lambda x: None
         self.assertEqual(self.likelihood.log_likelihood_ratio(),
                          np.nan_to_num(-np.inf))
-        self.likelihood.parameters['mass_2'] = 29
 
     def test_repr(self):
         expected = 'BasicGravitationalWaveTransient(interferometers={},\n\twaveform_generator={})'.format(
@@ -123,10 +123,10 @@ class TestGWTransient(unittest.TestCase):
     def test_likelihood_zero_when_waveform_is_none(self):
         """Test log likelihood returns np.nan_to_num(-np.inf) when the
         waveform is None"""
-        self.likelihood.parameters['mass_2'] = 32
+        self.likelihood.waveform_generator.frequency_domain_strain =\
+            lambda x: None
         self.assertEqual(self.likelihood.log_likelihood_ratio(),
                          np.nan_to_num(-np.inf))
-        self.likelihood.parameters['mass_2'] = 29
 
     def test_repr(self):
         expected = 'GravitationalWaveTransient(interferometers={},\n\twaveform_generator={},\n\t' \
diff --git a/test/gw_prior_test.py b/test/gw_prior_test.py
index 1eac2e76ca9311323edc6033668611cee5780a31..893d4f3de0ed3391df17b17603e9588d71d2f80e 100644
--- a/test/gw_prior_test.py
+++ b/test/gw_prior_test.py
@@ -57,7 +57,7 @@ class TestBBHPriorDict(unittest.TestCase):
 
     def test_correct_not_redundant_priors_masses(self):
         del self.bbh_prior_dict['mass_2']
-        for prior in ['mass_2', 'chirp_mass', 'total_mass', 'mass_ratio',  'symmetric_mass_ratio']:
+        for prior in ['mass_2', 'chirp_mass', 'total_mass',  'symmetric_mass_ratio']:
             self.assertFalse(self.bbh_prior_dict.test_redundancy(prior))
 
     def test_correct_not_redundant_priors_spin_magnitudes(self):
@@ -102,6 +102,11 @@ class TestBBHPriorDict(unittest.TestCase):
             self.assertTrue(self.bbh_prior_dict.test_has_redundant_keys())
             del self.bbh_prior_dict[prior]
 
+    def test_add_constraint_prior_not_redundant(self):
+        self.bbh_prior_dict['chirp_mass'] = bilby.prior.Constraint(
+            minimum=20, maximum=40, name='chirp_mass')
+        self.assertFalse(self.bbh_prior_dict.test_has_redundant_keys())
+
 
 class TestBNSPriorDict(unittest.TestCase):
 
@@ -151,7 +156,7 @@ class TestBNSPriorDict(unittest.TestCase):
 
     def test_correct_not_redundant_priors_masses(self):
         del self.bns_prior_dict['mass_2']
-        for prior in ['mass_2', 'chirp_mass', 'total_mass', 'mass_ratio',  'symmetric_mass_ratio']:
+        for prior in ['mass_2', 'chirp_mass', 'total_mass',  'symmetric_mass_ratio']:
             self.assertFalse(self.bns_prior_dict.test_redundancy(prior))
 
     def test_correct_not_redundant_priors_spin_magnitudes(self):
@@ -185,6 +190,11 @@ class TestBNSPriorDict(unittest.TestCase):
             self.assertTrue(self.bns_prior_dict.test_has_redundant_keys())
             del self.bns_prior_dict[prior]
 
+    def test_add_constraint_prior_not_redundant(self):
+        self.bns_prior_dict['chirp_mass'] = bilby.prior.Constraint(
+            minimum=1, maximum=2, name='chirp_mass')
+        self.assertFalse(self.bns_prior_dict.test_has_redundant_keys())
+
 
 class TestCalibrationPrior(unittest.TestCase):
 
diff --git a/test/gw_source_test.py b/test/gw_source_test.py
index 9c4562ffb4f3b568fa17f3e669228a7e8d9a030a..82316e8253089b20e97bbb617829af73f61f58d1 100644
--- a/test/gw_source_test.py
+++ b/test/gw_source_test.py
@@ -29,13 +29,6 @@ class TestLalBBH(unittest.TestCase):
             bilby.gw.source.lal_binary_black_hole(
                 self.frequency_array, **self.parameters), dict)
 
-    def test_mass_ratio_greater_one_returns_none(self):
-        self.parameters['mass_2'] = 1000.0
-        self.parameters.update(self.waveform_kwargs)
-        self.assertIsNone(
-            bilby.gw.source.lal_binary_black_hole(
-                self.frequency_array, **self.parameters), dict)
-
     def test_lal_bbh_works_without_waveform_parameters(self):
         self.assertIsInstance(
             bilby.gw.source.lal_binary_black_hole(
@@ -72,13 +65,6 @@ class TestLalBNS(unittest.TestCase):
             bilby.gw.source.lal_binary_neutron_star(
                 self.frequency_array, **self.parameters), dict)
 
-    def test_mass_ratio_greater_one_returns_none(self):
-        self.parameters['mass_2'] = 1000.0
-        self.parameters.update(self.waveform_kwargs)
-        self.assertIsNone(
-            bilby.gw.source.lal_binary_neutron_star(
-                self.frequency_array, **self.parameters), dict)
-
     def test_lal_bns_works_without_waveform_parameters(self):
         self.assertIsInstance(
             bilby.gw.source.lal_binary_neutron_star(
@@ -123,13 +109,6 @@ class TestEccentricLalBBH(unittest.TestCase):
             bilby.gw.source.lal_eccentric_binary_black_hole_no_spins(
                 self.frequency_array, **self.parameters), dict)
 
-    def test_mass_ratio_greater_one_returns_none(self):
-        self.parameters['mass_2'] = 1000.0
-        self.parameters.update(self.waveform_kwargs)
-        self.assertIsNone(
-            bilby.gw.source.lal_eccentric_binary_black_hole_no_spins(
-                self.frequency_array, **self.parameters), dict)
-
     def test_lal_ebbh_works_without_waveform_parameters(self):
         self.assertIsInstance(
             bilby.gw.source.lal_eccentric_binary_black_hole_no_spins(
@@ -155,8 +134,8 @@ class TestROQBBH(unittest.TestCase):
 
         self.parameters = dict(
             mass_1=30.0, mass_2=30.0, luminosity_distance=400.0, a_1=0.0,
-            tilt_1=0.0, phi_12=0.0, a_2=0.0, tilt_2=0.0, phi_jl=0.0, theta_jn=0.0,
-            phase=0.0)
+            tilt_1=0.0, phi_12=0.0, a_2=0.0, tilt_2=0.0, phi_jl=0.0,
+            theta_jn=0.0, phase=0.0)
         self.waveform_kwargs = dict(
             frequency_nodes_linear=fnodes_linear,
             frequency_nodes_quadratic=fnodes_quadratic,
@@ -174,12 +153,6 @@ class TestROQBBH(unittest.TestCase):
         self.assertIsInstance(
             bilby.gw.source.roq(self.frequency_array, **self.parameters), dict)
 
-    def test_mass_ratio_greater_one_returns_none(self):
-        self.parameters['mass_2'] = 1000.0
-        self.parameters.update(self.waveform_kwargs)
-        self.assertIsNone(
-            bilby.gw.source.roq(self.frequency_array, **self.parameters), dict)
-
     def test_roq_fails_without_frequency_nodes(self):
         self.parameters.update(self.waveform_kwargs)
         del self.parameters['frequency_nodes_linear']
diff --git a/test/prior_files/binary_black_holes.prior b/test/prior_files/binary_black_holes.prior
index e41cc73e2f765fbddd821693604ac15e302586f2..e79bd5baa754eb09aeb2a9de41b47adf82b9452c 100644
--- a/test/prior_files/binary_black_holes.prior
+++ b/test/prior_files/binary_black_holes.prior
@@ -4,6 +4,7 @@
 # Lines beginning "#" are ignored.
 mass_1 = Uniform(name='mass_1', minimum=5, maximum=100, unit='$M_{\\odot}$')
 mass_2 = Uniform(name='mass_2', minimum=5, maximum=100, unit='$M_{\\odot}$')
+mass_ratio =  Constraint(name='mass_ratio', minimum=0.125, maximum=1)
 # chirp_mass = Uniform(name='chirp_mass', minimum=25, maximum=100, unit='$M_{\\odot}$')
 # total_mass =  Uniform(name='total_mass', minimum=10, maximum=200, unit='$M_{\\odot}$')
 # mass_ratio =  Uniform(name='mass_ratio', minimum=0.125, maximum=1)
diff --git a/test/prior_files/binary_neutron_stars.prior b/test/prior_files/binary_neutron_stars.prior
index 08ec88f2ffb8bd6536563cb20723aabb589c7c13..eff1a2f2a4923b520cf3e242ee1bc1545cb6c64f 100644
--- a/test/prior_files/binary_neutron_stars.prior
+++ b/test/prior_files/binary_neutron_stars.prior
@@ -4,6 +4,7 @@
 # Lines beginning "#" are ignored.
 mass_1 = Uniform(name='mass_1', minimum=1, maximum=2, unit='$M_{\\odot}$')
 mass_2 = Uniform(name='mass_2', minimum=1, maximum=2, unit='$M_{\\odot}$')
+mass_ratio =  Constraint(name='mass_ratio', minimum=0.125, maximum=1)
 # chirp_mass = Uniform(name='chirp_mass', minimum=0.87, maximum=1.74, unit='$M_{\\odot}$')
 # total_mass =  Uniform(name='total_mass', minimum=2, maximum=4, unit='$M_{\\odot}$')
 # mass_ratio =  Uniform(name='mass_ratio', minimum=0.5, maximum=1)
diff --git a/test/prior_test.py b/test/prior_test.py
index c6fe849b9e6aaf7c8f4d66336910769ebe6acd91..6a2afc750db4a03beee8a99d9c6167a99dd6cb71 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -401,6 +401,8 @@ class TestPriorDict(unittest.TestCase):
                 name='mass_1', minimum=5, maximum=100, unit='$M_{\\odot}$'),
             mass_2=bilby.core.prior.Uniform(
                 name='mass_2', minimum=5, maximum=100, unit='$M_{\\odot}$'),
+            mass_ratio=bilby.core.prior.Constraint(
+                name='mass_ratio', minimum=0.125, maximum=1.0),
             a_1=bilby.core.prior.Uniform(name='a_1', minimum=0, maximum=0.8),
             a_2=bilby.core.prior.Uniform(name='a_2', minimum=0, maximum=0.8),
             tilt_1=bilby.core.prior.Sine(name='tilt_1'),
@@ -457,6 +459,8 @@ class TestPriorDict(unittest.TestCase):
                 name='mass_1', minimum=5, maximum=100, unit='$M_{\\odot}$'),
             mass_2=bilby.core.prior.Uniform(
                 name='mass_2', minimum=5, maximum=100, unit='$M_{\\odot}$'),
+            mass_ratio=bilby.core.prior.Constraint(
+                name='mass_ratio', minimum=0.125, maximum=1.0),
             a_1=bilby.core.prior.Uniform(name='a_1', minimum=0, maximum=0.8),
             a_2=bilby.core.prior.Uniform(name='a_2', minimum=0, maximum=0.8),
             tilt_1=bilby.core.prior.Sine(name='tilt_1'),
@@ -506,7 +510,7 @@ class TestPriorDict(unittest.TestCase):
         samples1 = self.prior_set_from_dict.sample_subset(keys=self.prior_set_from_dict.keys(), size=size)
         np.random.seed(42)
         samples2 = self.prior_set_from_dict.sample(size=size)
-        self.assertEqual(samples1.keys(), samples2.keys())
+        self.assertEqual(set(samples1.keys()), set(samples2.keys()))
         for key in samples1:
             self.assertTrue(np.array_equal(samples1[key], samples2[key]))