From faf773c56e30208e4d8a5017bdf6850392084bb1 Mon Sep 17 00:00:00 2001
From: Moritz Huebner <moritz.huebner@ligo.org>
Date: Mon, 25 Feb 2019 17:23:28 -0600
Subject: [PATCH] Resolve "Check for duplicate parameters"

---
 CHANGELOG.md                                |   3 +-
 bilby/core/prior.py                         |  20 ++-
 bilby/core/sampler/base_sampler.py          |   8 +
 bilby/gw/prior.py                           |  95 ++++-------
 test/gw_prior_test.py                       | 168 +++++++++++++++++---
 test/prior_files/binary_neutron_stars.prior |  23 +++
 6 files changed, 237 insertions(+), 80 deletions(-)
 create mode 100644 test/prior_files/binary_neutron_stars.prior

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d30ca96f..e895e5f50 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@
 ## [0.4.0] 2019-02-15
 
 ### Changed
+- Changed the logic around redundancy tests in the `PriorDict` classes
 - Fixed an accidental addition of astropy as a first-class dependency and added a check for missing dependencies to the C.I.
 - Fixed a bug in the "create-your-own-time-domain-model" example
 - Added citation guide to the readme
@@ -33,7 +34,7 @@
 - Improve the load_data_from_cache_file method
 
 ### Removed
--
+- Removed deprecated `PriorSet` classes. Use `PriorDict` instead.
 
 ## [0.3.5] 2019-01-25
 
diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 228148d6c..88ca37083 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -245,10 +245,28 @@ class PriorDict(OrderedDict):
         """
         return [self[key].rescale(sample) for key, sample in zip(keys, theta)]
 
-    def test_redundancy(self, key):
+    def test_redundancy(self, key, disable_logging=False):
         """Empty redundancy test, should be overwritten in subclasses"""
         return False
 
+    def test_has_redundant_keys(self):
+        """
+        Test whether there are redundant keys in self.
+
+        Return
+        ------
+        bool: Whether there are redundancies or not
+        """
+        redundant = False
+        for key in self:
+            temp = self.copy()
+            del temp[key]
+            if temp.test_redundancy(key, disable_logging=True):
+                logger.warning('{} is a redundant key in this {}.'
+                               .format(key, self.__class__.__name__))
+                redundant = True
+        return redundant
+
     def copy(self):
         """
         We have to overwrite the copy method as it fails due to the presence of
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 20f436666..c07f208ae 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -241,6 +241,10 @@ class Sampler(object):
             Likelihood can't be evaluated.
 
         """
+
+        if self.priors.test_has_redundant_keys():
+            raise IllegalSamplingSetError("Your sampling set contains redundant parameters.")
+
         self._check_if_priors_can_be_sampled()
         try:
             t1 = datetime.datetime.now()
@@ -502,3 +506,7 @@ class SamplerError(Error):
 
 class SamplerNotInstalledError(SamplerError):
     """ Base class for Error raised by not installed samplers """
+
+
+class IllegalSamplingSetError(Error):
+    """ Class for illegal sets of sampling parameters """
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 68f8b3450..2021a758a 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -207,63 +207,48 @@ class BBHPriorDict(PriorDict):
                 filename = os.path.join(os.path.dirname(__file__), 'prior_files', filename)
         PriorDict.__init__(self, dictionary=dictionary, filename=filename)
 
-    def test_redundancy(self, key):
+    def test_redundancy(self, key, disable_logging=False):
         """
         Test whether adding the key would add be redundant.
+        Already existing keys return True.
 
         Parameters
         ----------
         key: str
             The key to test.
+        disable_logging: bool, optional
+            Disable logging in this function call. Default is False.
 
         Return
         ------
         redundant: bool
             Whether the key is redundant or not
         """
-        redundant = False
         if key in self:
             logger.debug('{} already in prior'.format(key))
-            return redundant
+            return True
+
         mass_parameters = {'mass_1', 'mass_2', 'chirp_mass', 'total_mass', 'mass_ratio', 'symmetric_mass_ratio'}
-        spin_magnitude_parameters = {'a_1', 'a_2'}
         spin_tilt_1_parameters = {'tilt_1', 'cos_tilt_1'}
         spin_tilt_2_parameters = {'tilt_2', 'cos_tilt_2'}
         spin_azimuth_parameters = {'phi_1', 'phi_2', 'phi_12', 'phi_jl'}
         inclination_parameters = {'iota', 'cos_iota'}
         distance_parameters = {'luminosity_distance', 'comoving_distance', 'redshift'}
 
-        for parameter_set in [mass_parameters, spin_magnitude_parameters, spin_azimuth_parameters]:
-            if key in parameter_set:
-                if len(parameter_set.intersection(self)) > 2:
-                    redundant = True
-                    logger.warning('{} in prior. This may lead to unexpected behaviour.'.format(
-                        parameter_set.intersection(self)))
-                    break
-            elif len(parameter_set.intersection(self)) == 2:
-                redundant = True
-                break
-        for parameter_set in [inclination_parameters, distance_parameters, spin_tilt_1_parameters,
-                              spin_tilt_2_parameters]:
+        for independent_parameters, parameter_set in \
+                zip([2, 2, 1, 1, 1, 1],
+                    [mass_parameters, spin_azimuth_parameters,
+                     spin_tilt_1_parameters, spin_tilt_2_parameters,
+                     inclination_parameters, distance_parameters]):
             if key in parameter_set:
-                if len(parameter_set.intersection(self)) > 1:
-                    redundant = True
-                    logger.warning('{} in prior. This may lead to unexpected behaviour.'.format(
-                        parameter_set.intersection(self)))
-                    break
-                elif len(parameter_set.intersection(self)) == 1:
-                    redundant = True
-                    break
-
-        return redundant
-
-
-class BBHPriorSet(BBHPriorDict):
-
-    def __init__(self, dictionary=None, filename=None):
-        """ DEPRECATED: USE BBHPriorDict INSTEAD"""
-        logger.warning("The name 'BBHPriorSet' is deprecated use 'BBHPriorDict' instead")
-        super(BBHPriorSet, self).__init__(dictionary, filename)
+                if len(parameter_set.intersection(self)) >= independent_parameters:
+                    logger.disabled = disable_logging
+                    logger.warning('{} already in prior. '
+                                   'This may lead to unexpected behaviour.'
+                                   .format(parameter_set.intersection(self)))
+                    logger.disabled = False
+                    return True
+        return False
 
 
 class BNSPriorDict(PriorDict):
@@ -286,34 +271,32 @@ class BNSPriorDict(PriorDict):
                 filename = os.path.join(os.path.dirname(__file__), 'prior_files', filename)
         PriorDict.__init__(self, dictionary=dictionary, filename=filename)
 
-    def test_redundancy(self, key):
-        logger.info("Performing redundancy check using BBHPriorDict().test_redundancy")
-        bbh_redundancy = BBHPriorDict().test_redundancy(key)
+    def test_redundancy(self, key, disable_logging=False):
+        logger.disabled = disable_logging
+        logger.info("Performing redundancy check using BBHPriorDict(self).test_redundancy")
+        logger.disabled = False
+        bbh_redundancy = BBHPriorDict(self).test_redundancy(key)
+
         if bbh_redundancy:
             return True
         redundant = False
 
-        tidal_parameters =\
+        tidal_parameters = \
             {'lambda_1', 'lambda_2', 'lambda_tilde', 'delta_lambda'}
 
         if key in tidal_parameters:
             if len(tidal_parameters.intersection(self)) > 2:
                 redundant = True
-                logger.warning('{} in prior. This may lead to unexpected behaviour.'.format(
-                    tidal_parameters.intersection(self)))
+                logger.disabled = disable_logging
+                logger.warning('{} already in prior. '
+                               'This may lead to unexpected behaviour.'
+                               .format(tidal_parameters.intersection(self)))
+                logger.disabled = False
             elif len(tidal_parameters.intersection(self)) == 2:
                 redundant = True
         return redundant
 
 
-class BNSPriorSet(BNSPriorDict):
-
-    def __init__(self, dictionary=None, filename=None):
-        """ DEPRECATED: USE BNSPriorDict INSTEAD"""
-        super(BNSPriorSet, self).__init__(dictionary, filename)
-        logger.warning("The name 'BNSPriorSet' is deprecated use 'BNSPriorDict' instead")
-
-
 Prior._default_latex_labels = {
     'mass_1': '$m_1$',
     'mass_2': '$m_2$',
@@ -418,13 +401,13 @@ class CalibrationPriorDict(PriorDict):
         nodes = np.logspace(np.log10(minimum_frequency),
                             np.log10(maximum_frequency), n_nodes)
 
-        amplitude_mean_nodes =\
+        amplitude_mean_nodes = \
             UnivariateSpline(frequency_array, amplitude_median)(nodes)
-        amplitude_sigma_nodes =\
+        amplitude_sigma_nodes = \
             UnivariateSpline(frequency_array, amplitude_sigma)(nodes)
-        phase_mean_nodes =\
+        phase_mean_nodes = \
             UnivariateSpline(frequency_array, phase_median)(nodes)
-        phase_sigma_nodes =\
+        phase_sigma_nodes = \
             UnivariateSpline(frequency_array, phase_sigma)(nodes)
 
         prior = CalibrationPriorDict()
@@ -506,11 +489,3 @@ class CalibrationPriorDict(PriorDict):
                                         latex_label=latex_label)
 
         return prior
-
-
-class CalibrationPriorSet(CalibrationPriorDict):
-
-    def __init__(self, dictionary=None, filename=None):
-        """ DEPRECATED: USE BNSPriorDict INSTEAD"""
-        super(CalibrationPriorSet, self).__init__(dictionary, filename)
-        logger.warning("The name 'CalibrationPriorSet' is deprecated use 'CalibrationPriorDict' instead")
diff --git a/test/gw_prior_test.py b/test/gw_prior_test.py
index f7c7454e9..5679d8000 100644
--- a/test/gw_prior_test.py
+++ b/test/gw_prior_test.py
@@ -16,42 +16,174 @@ class TestBBHPriorDict(unittest.TestCase):
         self.base_directory =\
             '/'.join(os.path.dirname(
                 os.path.abspath(sys.argv[0])).split('/')[:-1])
-        self.filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prior_files/binary_black_holes.prior')
-        self.default_prior = bilby.gw.prior.BBHPriorDict(
-            filename=self.filename)
+        self.filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                                     'prior_files/binary_black_holes.prior')
+        self.bbh_prior_dict = bilby.gw.prior.BBHPriorDict(filename=self.filename)
+        for key, value in self.bbh_prior_dict.items():
+            self.prior_dict[key] = value
 
     def tearDown(self):
         del self.prior_dict
         del self.filename
+        del self.bbh_prior_dict
+        del self.base_directory
 
     def test_create_default_prior(self):
         default = bilby.gw.prior.BBHPriorDict()
-        minima = all([self.default_prior[key].minimum == default[key].minimum
+        minima = all([self.bbh_prior_dict[key].minimum == default[key].minimum
                       for key in default.keys()])
-        maxima = all([self.default_prior[key].maximum == default[key].maximum
+        maxima = all([self.bbh_prior_dict[key].maximum == default[key].maximum
                       for key in default.keys()])
-        names = all([self.default_prior[key].name == default[key].name
+        names = all([self.bbh_prior_dict[key].name == default[key].name
                      for key in default.keys()])
 
         self.assertTrue(all([minima, maxima, names]))
 
     def test_create_from_dict(self):
-        bilby.gw.prior.BBHPriorDict(dictionary=self.prior_dict)
+        new_dict = bilby.gw.prior.BBHPriorDict(dictionary=self.prior_dict)
+        for key in self.bbh_prior_dict:
+            self.assertEqual(self.bbh_prior_dict[key], new_dict[key])
+
+    def test_redundant_priors_not_in_dict_before(self):
+        for prior in ['chirp_mass', 'total_mass', 'mass_ratio', 'symmetric_mass_ratio',
+                      'cos_tilt_1', 'cos_tilt_2', 'phi_1', 'phi_2', 'cos_iota',
+                      'comoving_distance', 'redshift']:
+            self.assertTrue(self.bbh_prior_dict.test_redundancy(prior))
+
+    def test_redundant_priors_already_in_dict(self):
+        for prior in ['mass_1', 'mass_2', 'tilt_1', 'tilt_2',
+                      'phi_1', 'phi_2', 'iota', 'luminosity_distance']:
+            self.assertTrue(self.bbh_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_masses(self):
+        del self.bbh_prior_dict['mass_2']
+        for prior in ['mass_2', 'chirp_mass', 'total_mass', 'mass_ratio',  'symmetric_mass_ratio']:
+            self.assertFalse(self.bbh_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_spin_magnitudes(self):
+        del self.bbh_prior_dict['a_2']
+        self.assertFalse(self.bbh_prior_dict.test_redundancy('a_2'))
+
+    def test_correct_not_redundant_priors_spin_tilt_1(self):
+        del self.bbh_prior_dict['tilt_1']
+        for prior in ['tilt_1', 'cos_tilt_1']:
+            self.assertFalse(self.bbh_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_spin_tilt_2(self):
+        del self.bbh_prior_dict['tilt_2']
+        for prior in ['tilt_2', 'cos_tilt_2']:
+            self.assertFalse(self.bbh_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_spin_azimuth(self):
+        del self.bbh_prior_dict['phi_12']
+        for prior in ['phi_1', 'phi_2', 'phi_12']:
+            self.assertFalse(self.bbh_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_inclination(self):
+        del self.bbh_prior_dict['iota']
+        for prior in ['iota', 'cos_iota']:
+            self.assertFalse(self.bbh_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_distance(self):
+        del self.bbh_prior_dict['luminosity_distance']
+        for prior in ['luminosity_distance', 'comoving_distance',
+                      'redshift']:
+            self.assertFalse(self.bbh_prior_dict.test_redundancy(prior))
+
+    def test_add_unrelated_prior(self):
+        self.assertFalse(self.bbh_prior_dict.test_redundancy('abc'))
+
+    def test_test_has_redundant_priors(self):
+        self.assertFalse(self.bbh_prior_dict.test_has_redundant_keys())
+        for prior in ['chirp_mass', 'total_mass', 'mass_ratio', 'symmetric_mass_ratio',
+                      'cos_tilt_1', 'cos_tilt_2', 'phi_1', 'phi_2', 'cos_iota',
+                      'comoving_distance', 'redshift']:
+            self.bbh_prior_dict[prior] = 0
+            self.assertTrue(self.bbh_prior_dict.test_has_redundant_keys())
+            del self.bbh_prior_dict[prior]
+
+
+class TestBNSPriorDict(unittest.TestCase):
 
-    def test_create_from_filename(self):
-        bilby.gw.prior.BBHPriorDict(filename=self.filename)
+    def setUp(self):
+        self.prior_dict = dict()
+        self.base_directory =\
+            '/'.join(os.path.dirname(
+                os.path.abspath(sys.argv[0])).split('/')[:-1])
+        self.filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                                     'prior_files/binary_neutron_stars.prior')
+        self.bns_prior_dict = bilby.gw.prior.BNSPriorDict(filename=self.filename)
+        for key, value in self.bns_prior_dict.items():
+            self.prior_dict[key] = value
 
-    def test_key_in_prior_not_redundant(self):
-        test = self.default_prior.test_redundancy('mass_1')
-        self.assertFalse(test)
+    def tearDown(self):
+        del self.prior_dict
+        del self.filename
+        del self.bns_prior_dict
+        del self.base_directory
 
-    def test_chirp_mass_redundant(self):
-        test = self.default_prior.test_redundancy('chirp_mass')
-        self.assertTrue(test)
+    def test_create_default_prior(self):
+        default = bilby.gw.prior.BNSPriorDict()
+        minima = all([self.bns_prior_dict[key].minimum == default[key].minimum
+                      for key in default.keys()])
+        maxima = all([self.bns_prior_dict[key].maximum == default[key].maximum
+                      for key in default.keys()])
+        names = all([self.bns_prior_dict[key].name == default[key].name
+                     for key in default.keys()])
 
-    def test_comoving_distance_redundant(self):
-        test = self.default_prior.test_redundancy('comoving_distance')
-        self.assertTrue(test)
+        self.assertTrue(all([minima, maxima, names]))
+
+    def test_create_from_dict(self):
+        new_dict = bilby.gw.prior.BNSPriorDict(dictionary=self.prior_dict)
+        self.assertDictEqual(self.bns_prior_dict, new_dict)
+
+    def test_redundant_priors_not_in_dict_before(self):
+        for prior in ['chirp_mass', 'total_mass', 'mass_ratio',
+                      'symmetric_mass_ratio', 'cos_iota', 'comoving_distance',
+                      'redshift', 'lambda_tilde', 'delta_lambda']:
+            self.assertTrue(self.bns_prior_dict.test_redundancy(prior))
+
+    def test_redundant_priors_already_in_dict(self):
+        for prior in ['mass_1', 'mass_2', 'chi_1', 'chi_2',
+                      'iota', 'luminosity_distance',
+                      'lambda_1', 'lambda_2']:
+            self.assertTrue(self.bns_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_masses(self):
+        del self.bns_prior_dict['mass_2']
+        for prior in ['mass_2', 'chirp_mass', 'total_mass', 'mass_ratio',  'symmetric_mass_ratio']:
+            self.assertFalse(self.bns_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_spin_magnitudes(self):
+        del self.bns_prior_dict['chi_2']
+        self.assertFalse(self.bns_prior_dict.test_redundancy('chi_2'))
+
+    def test_correct_not_redundant_priors_inclination(self):
+        del self.bns_prior_dict['iota']
+        for prior in ['iota', 'cos_iota']:
+            self.assertFalse(self.bns_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_distance(self):
+        del self.bns_prior_dict['luminosity_distance']
+        for prior in ['luminosity_distance', 'comoving_distance',
+                      'redshift']:
+            self.assertFalse(self.bns_prior_dict.test_redundancy(prior))
+
+    def test_correct_not_redundant_priors_tidal(self):
+        del self.bns_prior_dict['lambda_1']
+        for prior in['lambda_1', 'lambda_tilde', 'delta_lambda']:
+            self.assertFalse(self.bns_prior_dict.test_redundancy(prior))
+
+    def test_add_unrelated_prior(self):
+        self.assertFalse(self.bns_prior_dict.test_redundancy('abc'))
+
+    def test_test_has_redundant_priors(self):
+        self.assertFalse(self.bns_prior_dict.test_has_redundant_keys())
+        for prior in ['chirp_mass', 'total_mass', 'mass_ratio', 'symmetric_mass_ratio',
+                      'cos_iota', 'comoving_distance', 'redshift']:
+            self.bns_prior_dict[prior] = 0
+            self.assertTrue(self.bns_prior_dict.test_has_redundant_keys())
+            del self.bns_prior_dict[prior]
 
 
 class TestCalibrationPrior(unittest.TestCase):
diff --git a/test/prior_files/binary_neutron_stars.prior b/test/prior_files/binary_neutron_stars.prior
new file mode 100644
index 000000000..1bc4d485c
--- /dev/null
+++ b/test/prior_files/binary_neutron_stars.prior
@@ -0,0 +1,23 @@
+# These are the default priors we use for BNS systems.
+# Note that you may wish to use more specific mass and distance parameters.
+# These commands are all known to bilby.gw.prior.
+# Lines beginning "#" are ignored.
+mass_1 = Uniform(name='mass_1', minimum=1, maximum=2, unit='$M_{\\odot}$')
+mass_2 = Uniform(name='mass_2', minimum=1, maximum=2, unit='$M_{\\odot}$')
+# chirp_mass = Uniform(name='chirp_mass', minimum=0.87, maximum=1.74, unit='$M_{\\odot}$')
+# total_mass =  Uniform(name='total_mass', minimum=2, maximum=4, unit='$M_{\\odot}$')
+# mass_ratio =  Uniform(name='mass_ratio', minimum=0.5, maximum=1)
+# symmetric_mass_ratio =  Uniform(name='symmetric_mass_ratio', minimum=0.22, maximum=0.25)
+chi_1 =  bilby.gw.prior.AlignedSpin(a_prior=Uniform(0, 0.05), z_prior=Uniform(-1, 1), name='chi_1', latex_label='$\\chi_1$')
+chi_2 =  bilby.gw.prior.AlignedSpin(a_prior=Uniform(0, 0.05), z_prior=Uniform(-1, 1), name='chi_2', latex_label='$\\chi_2$')
+luminosity_distance =  bilby.gw.prior.UniformComovingVolume(name='luminosity_distance', minimum=10, maximum=500, unit='Mpc')
+dec =  Cosine(name='dec')
+ra =  Uniform(name='ra', minimum=0, maximum=2 * np.pi)
+iota =  Sine(name='iota')
+# cos_iota =  Uniform(name='cos_iota', minimum=-1, maximum=1)
+psi =  Uniform(name='psi', minimum=0, maximum=np.pi)
+phase =  Uniform(name='phase', minimum=0, maximum=2 * np.pi)
+lambda_1 = Uniform(name='lambda_1', minimum=0, maximum=3000 )
+lambda_2 = Uniform(name='lambda_2', minimum=0, maximum=3000 )
+# lambda_tilde = Uniform(name='lambda_tilde', minimum=0, maximum=5000)
+# delta_lambda = Uniform(name='delta_lambda', minimum=-5000, maximum=5000)
-- 
GitLab