diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py
index dc717ba36128bec960698ee1dedb27a4c21d8f4b..05df00a7bad5a46021f6597617d833c6ee525816 100644
--- a/bilby/core/prior/base.py
+++ b/bilby/core/prior/base.py
@@ -63,21 +63,52 @@ class Prior(object):
         return self.sample()
 
     def __eq__(self, other):
+        """
+        Test equality of two prior objects.
+
+        Returns true iff:
+
+        - The class of the two priors are the same
+        - Both priors have the same keys in the __dict__ attribute
+        - The instantiation arguments match
+
+        We don't check that all entries the the __dict__ attribute
+        are equal as some attributes are variable for conditional
+        priors.
+
+        Parameters
+        ==========
+        other: Prior
+            The prior to compare with
+
+        Returns
+        =======
+        bool
+            Whether the priors are equivalent
+
+        Notes
+        =====
+        A special case is made for :code `scipy.stats.beta`: instances.
+        It may be possible to remove this as we now only check instantiation
+        arguments.
+
+        """
         if self.__class__ != other.__class__:
             return False
         if sorted(self.__dict__.keys()) != sorted(other.__dict__.keys()):
             return False
-        for key in self.__dict__:
+        this_dict = self.get_instantiation_dict()
+        other_dict = other.get_instantiation_dict()
+        for key in this_dict:
             if key == "least_recently_sampled":
-                # ignore sample drawn from prior in comparison
                 continue
-            if type(self.__dict__[key]) is np.ndarray:
-                if not np.array_equal(self.__dict__[key], other.__dict__[key]):
+            if isinstance(this_dict[key], np.ndarray):
+                if not np.array_equal(this_dict[key], other_dict[key]):
                     return False
-            elif isinstance(self.__dict__[key], type(scipy.stats.beta(1., 1.))):
+            elif isinstance(this_dict[key], type(scipy.stats.beta(1., 1.))):
                 continue
             else:
-                if not self.__dict__[key] == other.__dict__[key]:
+                if not this_dict[key] == other_dict[key]:
                     return False
         return True
 
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index af29c5997d89947d2c9394c17cedb44ea4e0aa47..7c8d7e8c34cdac1f11fea2bdd0e5770085b41f13 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -250,22 +250,6 @@ class Sampler(object):
 
         return result
 
-    def _check_if_priors_can_be_sampled(self):
-        """Check if all priors can be sampled properly.
-
-        Raises
-        ======
-        AttributeError
-            prior can't be sampled.
-        """
-        for key in self.priors:
-            if isinstance(self.priors[key], Constraint):
-                continue
-            try:
-                self.priors[key].sample()
-            except AttributeError as e:
-                logger.warning('Cannot sample from {}, {}'.format(key, e))
-
     def _verify_parameters(self):
         """ Evaluate a set of parameters drawn from the prior
 
@@ -322,7 +306,13 @@ class Sampler(object):
         Checks if use_ratio is set. Prints a warning if use_ratio is set but
         not properly implemented.
         """
-        self._check_if_priors_can_be_sampled()
+        try:
+            self.priors.sample_subset(self.search_parameter_keys)
+        except (KeyError, AttributeError):
+            logger.error("Cannot sample from priors with keys: {}.".format(
+                self.search_parameter_keys
+            ))
+            raise
         if self.use_ratio is False:
             logger.debug("use_ratio set to False")
             return
diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 093e2eda8db4324985b10aed7d62460ec2ee85a5..caf68fd260036213e6a4c30a022a860b7fdc98c2 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -234,12 +234,26 @@ def convert_to_lal_binary_black_hole_parameters(parameters):
     for idx in ['1', '2']:
         key = 'chi_{}'.format(idx)
         if key in original_keys:
-            converted_parameters['a_{}'.format(idx)] = abs(
-                converted_parameters[key])
-            converted_parameters['cos_tilt_{}'.format(idx)] = \
-                np.sign(converted_parameters[key])
-            converted_parameters['phi_jl'] = 0.0
-            converted_parameters['phi_12'] = 0.0
+            if "chi_{}_in_plane".format(idx) in original_keys:
+                converted_parameters["a_{}".format(idx)] = (
+                    converted_parameters[f"chi_{idx}"] ** 2
+                    + converted_parameters[f"chi_{idx}_in_plane"] ** 2
+                ) ** 0.5
+                converted_parameters[f"cos_tilt_{idx}"] = (
+                    converted_parameters[f"chi_{idx}"]
+                    / converted_parameters[f"a_{idx}"]
+                )
+            elif "a_{}".format(idx) not in original_keys:
+                converted_parameters['a_{}'.format(idx)] = abs(
+                    converted_parameters[key])
+                converted_parameters['cos_tilt_{}'.format(idx)] = \
+                    np.sign(converted_parameters[key])
+                converted_parameters['phi_jl'] = 0.0
+                converted_parameters['phi_12'] = 0.0
+            else:
+                converted_parameters[f"cos_tilt_{idx}"] = (
+                    converted_parameters[key] / converted_parameters[f"a_{idx}"]
+                )
 
     for angle in ['tilt_1', 'tilt_2', 'theta_jn']:
         cos_angle = str('cos_' + angle)
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 594f6f71126692539a6a66b6e863e3a7886d230a..899e98060cf7023eb616ec4eeda6ee14089c6e30 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -7,10 +7,12 @@ from scipy.integrate import cumtrapz
 from scipy.special import hyp2f1
 from scipy.stats import norm
 
-from ..core.prior import (PriorDict, Uniform, Prior, DeltaFunction, Gaussian,
-                          Interped, Constraint, conditional_prior_factory,
-                          BaseJointPriorDist, JointPrior, JointPriorDistError,
-                          PowerLaw)
+from ..core.prior import (
+    PriorDict, Uniform, Prior, DeltaFunction, Gaussian, Interped, Constraint,
+    conditional_prior_factory, PowerLaw, ConditionalLogUniform,
+    ConditionalPriorDict, ConditionalBasePrior, BaseJointPriorDist, JointPrior,
+    JointPriorDistError,
+)
 from ..core.utils import infer_args_from_method, logger
 from .conversion import (
     convert_to_lal_binary_black_hole_parameters,
@@ -417,6 +419,141 @@ class AlignedSpin(Interped):
                                           maximum=maximum)
 
 
+class ConditionalChiUniformSpinMagnitude(ConditionalLogUniform):
+    r"""
+    This prior characterizes the conditional prior on the spin magnitude given
+    the aligned component of the spin  such that the marginal prior is uniform
+    if the distribution of spin orientations is isotropic.
+
+    .. math::
+        p(a) &= \frac{1}{a_{\max}}
+        p(\chi) &= - \frac{1}{2 a_{\max}} \ln(|\chi|)
+        p(a | \chi) &\propto \frac{1}{a}
+    """
+
+    def __init__(self, minimum, maximum, name, latex_label=None, unit=None, boundary=None):
+        super(ConditionalChiUniformSpinMagnitude, self).__init__(
+            minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary,
+            condition_func=self._condition_function)
+        self._required_variables = [name.replace("a", "chi")]
+        self.__class__.__name__ = "ConditionalChiUniformSpinMagnitude"
+        self.__class__.__qualname__ = "ConditionalChiUniformSpinMagnitude"
+
+    def _condition_function(self, reference_params, **kwargs):
+        return dict(minimum=np.abs(kwargs[self._required_variables[0]]), maximum=reference_params["maximum"])
+
+    def __repr__(self):
+        return Prior.__repr__(self)
+
+    def get_instantiation_dict(self):
+        instantiation_dict = Prior.get_instantiation_dict(self)
+        for key, value in self.reference_params.items():
+            if key in instantiation_dict:
+                instantiation_dict[key] = value
+        return instantiation_dict
+
+
+class ConditionalChiInPlane(ConditionalBasePrior):
+    r"""
+    This prior characterizes the conditional prior on the in-plane spin magnitude
+    given the aligned component of the spin  such that the marginal prior is uniform
+    if the distribution of spin orientations is isotropic.
+
+    .. math::
+        p(a) &= \frac{1}{a_{\max}}
+        p(\chi_\perp) = 2 N \chi_\perp / (\chi ** 2 + \chi_\perp ** 2)
+        N^{-1} &= 2 \ln(a_\max / |\chi|)
+    """
+
+    def __init__(self, minimum, maximum, name, latex_label=None, unit=None, boundary=None):
+        super(ConditionalChiInPlane, self).__init__(
+            minimum=minimum, maximum=maximum,
+            name=name, latex_label=latex_label,
+            unit=unit, boundary=boundary,
+            condition_func=self._condition_function
+        )
+        self._required_variables = [name[:5]]
+        self._reference_maximum = maximum
+        self.__class__.__name__ = "ConditionalChiInPlane"
+        self.__class__.__qualname__ = "ConditionalChiInPlane"
+
+    def prob(self, val, **required_variables):
+        self.update_conditions(**required_variables)
+        chi_aligned = abs(required_variables[self._required_variables[0]])
+        return (
+            (val >= self.minimum) * (val <= self.maximum)
+            * val
+            / (chi_aligned ** 2 + val ** 2)
+            / np.log(self._reference_maximum / chi_aligned)
+        )
+
+    def ln_prob(self, val, **required_variables):
+        return np.log(self.prob(val, **required_variables))
+
+    def cdf(self, val, **required_variables):
+        r"""
+        .. math::
+            \text{CDF}(\chi_\per) = N ln(1 + (\chi_\perp / \chi) ** 2)
+
+        Parameters
+        ----------
+        val: (float, array-like)
+            The value at which to evaluate the CDF
+        required_variables: dict
+            A dictionary containing the aligned component of the spin
+
+        Returns
+        -------
+        (float, array-like)
+            The value of the CDF
+
+        """
+        self.update_conditions(**required_variables)
+        chi_aligned = abs(required_variables[self._required_variables[0]])
+        return np.maximum(np.minimum(
+            (val >= self.minimum) * (val <= self.maximum)
+            * np.log(1 + (val / chi_aligned) ** 2)
+            / 2 / np.log(self._reference_maximum / chi_aligned)
+            , 1
+        ), 0)
+
+    def rescale(self, val, **required_variables):
+        r"""
+        .. math::
+            \text{PPF}(\chi_\perp) = ((a_\max / \chi) ** (2x) - 1) ** 0.5 * \chi
+
+        Parameters
+        ----------
+        val: (float, array-like)
+            The value to rescale
+        required_variables: dict
+            Dictionary containing the aligned spin component
+
+        Returns
+        -------
+        (float, array-like)
+            The in-plane component of the spin
+        """
+        self.update_conditions(**required_variables)
+        chi_aligned = abs(required_variables[self._required_variables[0]])
+        return chi_aligned * ((self._reference_maximum / chi_aligned) ** (2 * val) - 1) ** 0.5
+
+    def _condition_function(self, reference_params, **kwargs):
+        return dict(minimum=0, maximum=(
+            self._reference_maximum ** 2 - kwargs[self._required_variables[0]] ** 2
+        ) ** 0.5)
+
+    def __repr__(self):
+        return Prior.__repr__(self)
+
+    def get_instantiation_dict(self):
+        instantiation_dict = Prior.get_instantiation_dict(self)
+        for key, value in self.reference_params.items():
+            if key in instantiation_dict:
+                instantiation_dict[key] = value
+        return instantiation_dict
+
+
 class EOSCheck(Constraint):
     def __init__(self, minimum=-np.inf, maximum=np.inf):
         """
@@ -443,7 +580,7 @@ class EOSCheck(Constraint):
         return result
 
 
-class CBCPriorDict(PriorDict):
+class CBCPriorDict(ConditionalPriorDict):
     @property
     def minimum_chirp_mass(self):
         mass_1 = None
diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py
index 2bbbc2bf1e87535ed56a8d47bff35ec57c2c9b4b..6afdc9933ad782fd5cc5eab603f92bf4287a7d5f 100644
--- a/test/core/prior/prior_test.py
+++ b/test/core/prior/prior_test.py
@@ -667,17 +667,6 @@ class TestPriorClasses(unittest.TestCase):
         prior_2.other_key = 5
         self.assertNotEqual(prior_1, prior_2)
 
-    def test_np_array_eq(self):
-        prior_1 = bilby.core.prior.PowerLaw(
-            name="test", unit="unit", alpha=0, minimum=0, maximum=1
-        )
-        prior_2 = bilby.core.prior.PowerLaw(
-            name="test", unit="unit", alpha=0, minimum=0, maximum=1
-        )
-        prior_1.array_attribute = np.array([1, 2, 3])
-        prior_2.array_attribute = np.array([2, 2, 3])
-        self.assertNotEqual(prior_1, prior_2)
-
     def test_repr(self):
         for prior in self.priors:
             if isinstance(prior, bilby.core.prior.Interped):
diff --git a/test/gw/prior_test.py b/test/gw/prior_test.py
index f192b645fadd99d3dc965da533863c15b7573bff..832f2a9d5213ed9f967c8b1ddbd30d6a6c3f21ea 100644
--- a/test/gw/prior_test.py
+++ b/test/gw/prior_test.py
@@ -523,5 +523,20 @@ class TestAlignedSpin(unittest.TestCase):
         self.assertAlmostEqual(max_difference, 0, 2)
 
 
+class TestConditionalChiUniformSpinMagnitude(unittest.TestCase):
+
+    def setUp(self):
+        pass
+
+    def test_marginalized_prior_is_uniform(self):
+        priors = bilby.gw.prior.BBHPriorDict(aligned_spin=True)
+        priors["a_1"] = bilby.gw.prior.ConditionalChiUniformSpinMagnitude(
+            minimum=0.1, maximum=priors["chi_1"].maximum, name="a_1"
+        )
+        samples = priors.sample(100000)["a_1"]
+        ks = ks_2samp(samples, np.random.uniform(0, priors["chi_1"].maximum, 100000))
+        self.assertTrue(ks.pvalue > 0.001)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/integration/sample_from_the_prior_test.py b/test/integration/sample_from_the_prior_test.py
index bb54aa98bcb8b6613e4464d1d0fe911fac164b5d..e48c3574dc79c44b6b73089f2b3a71d24d4ea6c1 100644
--- a/test/integration/sample_from_the_prior_test.py
+++ b/test/integration/sample_from_the_prior_test.py
@@ -40,7 +40,7 @@ class Test(unittest.TestCase):
         duration = 4.0
         sampling_frequency = 2048.0
         label = "full_15_parameters"
-        np.random.seed(8817020)
+        np.random.seed(8817021)
 
         waveform_arguments = dict(
             waveform_approximant="IMRPhenomPv2",