From 35e54f852b8965cc0484634950031dbbc63f684c Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Thu, 31 Aug 2023 15:12:38 +0000
Subject: [PATCH] FORMAT: enforce isinstance type checking

---
 bilby/bilby_mcmc/sampler.py        |  2 +-
 bilby/core/sampler/base_sampler.py |  4 ++--
 bilby/core/sampler/cpnest.py       |  2 +-
 bilby/core/utils/log.py            |  6 +++---
 bilby/gw/detector/networks.py      |  6 +++---
 bilby/gw/eos/eos.py                |  4 ++--
 test/core/likelihood_test.py       |  2 +-
 test/core/prior/joint_test.py      |  2 +-
 test/gw/likelihood_test.py         | 10 ++--------
 9 files changed, 16 insertions(+), 22 deletions(-)

diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py
index 1b528ec78..fcd4c9c5a 100644
--- a/bilby/bilby_mcmc/sampler.py
+++ b/bilby/bilby_mcmc/sampler.py
@@ -388,7 +388,7 @@ class Bilby_MCMC(MCMCSampler):
 
         with open(self.resume_file, "rb") as file:
             ptsampler = dill.load(file)
-            if type(ptsampler) != BilbyPTMCMCSampler:
+            if not isinstance(ptsampler, BilbyPTMCMCSampler):
                 logger.debug("Malformed resume file, ignoring")
                 return False
             self.ptsampler = ptsampler
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 36c14de03..6b4f34619 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -682,11 +682,11 @@ class Sampler(object):
         if self.cached_result is None:
             kwargs_print = self.kwargs.copy()
             for k in kwargs_print:
-                if type(kwargs_print[k]) in (list, np.ndarray):
+                if isinstance(kwargs_print[k], (list, np.ndarray)):
                     array_repr = np.array(kwargs_print[k])
                     if array_repr.size > 10:
                         kwargs_print[k] = f"array_like, shape={array_repr.shape}"
-                elif type(kwargs_print[k]) == DataFrame:
+                elif isinstance(kwargs_print[k], DataFrame):
                     kwargs_print[k] = f"DataFrame, shape={kwargs_print[k].shape}"
             logger.info(
                 f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}"
diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py
index bc3b36465..e1f3ae19e 100644
--- a/bilby/core/sampler/cpnest.py
+++ b/bilby/core/sampler/cpnest.py
@@ -182,7 +182,7 @@ class Cpnest(NestedSampler):
         if "proposals" in self.kwargs:
             if self.kwargs["proposals"] is None:
                 return
-            if type(self.kwargs["proposals"]) == JumpProposalCycle:
+            if isinstance(self.kwargs["proposals"], JumpProposalCycle):
                 self.kwargs["proposals"] = dict(
                     mhs=self.kwargs["proposals"], hmc=self.kwargs["proposals"]
                 )
diff --git a/bilby/core/utils/log.py b/bilby/core/utils/log.py
index 4884eba9c..ca86b814b 100644
--- a/bilby/core/utils/log.py
+++ b/bilby/core/utils/log.py
@@ -22,7 +22,7 @@ def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False):
         If true, print version information
     """
 
-    if type(log_level) is str:
+    if isinstance(log_level, str):
         try:
             level = getattr(logging, log_level.upper())
         except AttributeError:
@@ -34,14 +34,14 @@ def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False):
     logger.propagate = False
     logger.setLevel(level)
 
-    if any([type(h) == logging.StreamHandler for h in logger.handlers]) is False:
+    if not any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
         stream_handler = logging.StreamHandler()
         stream_handler.setFormatter(logging.Formatter(
             '%(asctime)s %(name)s %(levelname)-8s: %(message)s', datefmt='%H:%M'))
         stream_handler.setLevel(level)
         logger.addHandler(stream_handler)
 
-    if any([type(h) == logging.FileHandler for h in logger.handlers]) is False:
+    if not any([isinstance(h, logging.FileHandler) for h in logger.handlers]):
         if label:
             Path(outdir).mkdir(parents=True, exist_ok=True)
             log_file = '{}/{}.log'.format(outdir, label)
diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py
index 20b1f54b5..7ff6f4417 100644
--- a/bilby/gw/detector/networks.py
+++ b/bilby/gw/detector/networks.py
@@ -25,12 +25,12 @@ class InterferometerList(list):
         """
 
         super(InterferometerList, self).__init__()
-        if type(interferometers) == str:
+        if isinstance(interferometers, str):
             raise TypeError("Input must not be a string")
         for ifo in interferometers:
-            if type(ifo) == str:
+            if isinstance(ifo, str):
                 ifo = get_empty_interferometer(ifo)
-            if type(ifo) not in [Interferometer, TriangularInterferometer]:
+            if not isinstance(ifo, (Interferometer, TriangularInterferometer)):
                 raise TypeError(
                     "Input list of interferometers are not all Interferometer objects"
                 )
diff --git a/bilby/gw/eos/eos.py b/bilby/gw/eos/eos.py
index 5693fb33e..ca7799ebb 100644
--- a/bilby/gw/eos/eos.py
+++ b/bilby/gw/eos/eos.py
@@ -60,12 +60,12 @@ class TabularEOS(object):
         self.sampling_flag = sampling_flag
         self.warning_flag = warning_flag
 
-        if type(eos) == str:
+        if isinstance(eos, str):
             if eos in valid_eos_dict.keys():
                 table = np.loadtxt(valid_eos_dict[eos])
             else:
                 table = np.loadtxt(eos)
-        elif type(eos) == np.ndarray:
+        elif isinstance(eos, np.ndarray):
             table = eos
         else:
             raise ValueError("eos provided is invalid type please supply a str name, str path to ASCII file, "
diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py
index fb6ffa0c9..a7eb4a1c8 100644
--- a/test/core/likelihood_test.py
+++ b/test/core/likelihood_test.py
@@ -193,7 +193,7 @@ class TestGaussianLikelihood(unittest.TestCase):
         likelihood.parameters["m"] = 2
         likelihood.parameters["c"] = 0
         likelihood.log_likelihood()
-        self.assertTrue(type(likelihood.sigma) == type(sigma_array))
+        self.assertTrue(type(likelihood.sigma) == type(sigma_array))  # noqa: E721
         self.assertTrue(all(likelihood.sigma == sigma_array))
 
     def test_set_sigma_None(self):
diff --git a/test/core/prior/joint_test.py b/test/core/prior/joint_test.py
index ebadfcfae..c99373b00 100644
--- a/test/core/prior/joint_test.py
+++ b/test/core/prior/joint_test.py
@@ -32,7 +32,7 @@ MultivariateGaussianDist(
                 self.assertTrue(item == fromstr.__getattribute__(key))
             elif key == "mvn":
                 for d1, d2 in zip(fromstr.__getattribute__(key), item):
-                    self.assertTrue(type(d1) == type(d2))
+                    self.assertTrue(type(d1) == type(d2))  # noqa: E721
             elif isinstance(item, (list, tuple, np.ndarray)):
                 self.assertTrue(
                     np.all(np.array(item) == np.array(fromstr.__getattribute__(key)))
diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py
index 87a007072..286eadea2 100644
--- a/test/gw/likelihood_test.py
+++ b/test/gw/likelihood_test.py
@@ -189,10 +189,7 @@ class TestGWTransient(unittest.TestCase):
         self.assertListEqual(
             bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers
         )
-        self.assertTrue(
-            type(self.likelihood.interferometers)
-            == bilby.gw.detector.InterferometerList
-        )
+        self.assertIsInstance(self.likelihood.interferometers, bilby.gw.detector.InterferometerList)
 
     def test_interferometers_setting_interferometer_list(self):
         ifos = bilby.gw.detector.InterferometerList(
@@ -205,10 +202,7 @@ class TestGWTransient(unittest.TestCase):
         self.assertListEqual(
             bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers
         )
-        self.assertTrue(
-            type(self.likelihood.interferometers)
-            == bilby.gw.detector.InterferometerList
-        )
+        self.assertIsInstance(self.likelihood.interferometers, bilby.gw.detector.InterferometerList)
 
     def test_meta_data(self):
         expected = dict(
-- 
GitLab