From 35e54f852b8965cc0484634950031dbbc63f684c Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Thu, 31 Aug 2023 15:12:38 +0000 Subject: [PATCH] FORMAT: enforce isinstance type checking --- bilby/bilby_mcmc/sampler.py | 2 +- bilby/core/sampler/base_sampler.py | 4 ++-- bilby/core/sampler/cpnest.py | 2 +- bilby/core/utils/log.py | 6 +++--- bilby/gw/detector/networks.py | 6 +++--- bilby/gw/eos/eos.py | 4 ++-- test/core/likelihood_test.py | 2 +- test/core/prior/joint_test.py | 2 +- test/gw/likelihood_test.py | 10 ++-------- 9 files changed, 16 insertions(+), 22 deletions(-) diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py index 1b528ec78..fcd4c9c5a 100644 --- a/bilby/bilby_mcmc/sampler.py +++ b/bilby/bilby_mcmc/sampler.py @@ -388,7 +388,7 @@ class Bilby_MCMC(MCMCSampler): with open(self.resume_file, "rb") as file: ptsampler = dill.load(file) - if type(ptsampler) != BilbyPTMCMCSampler: + if not isinstance(ptsampler, BilbyPTMCMCSampler): logger.debug("Malformed resume file, ignoring") return False self.ptsampler = ptsampler diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 36c14de03..6b4f34619 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -682,11 +682,11 @@ class Sampler(object): if self.cached_result is None: kwargs_print = self.kwargs.copy() for k in kwargs_print: - if type(kwargs_print[k]) in (list, np.ndarray): + if isinstance(kwargs_print[k], (list, np.ndarray)): array_repr = np.array(kwargs_print[k]) if array_repr.size > 10: kwargs_print[k] = f"array_like, shape={array_repr.shape}" - elif type(kwargs_print[k]) == DataFrame: + elif isinstance(kwargs_print[k], DataFrame): kwargs_print[k] = f"DataFrame, shape={kwargs_print[k].shape}" logger.info( f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}" diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index bc3b36465..e1f3ae19e 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -182,7 +182,7 @@ class Cpnest(NestedSampler): if "proposals" in self.kwargs: if self.kwargs["proposals"] is None: return - if type(self.kwargs["proposals"]) == JumpProposalCycle: + if isinstance(self.kwargs["proposals"], JumpProposalCycle): self.kwargs["proposals"] = dict( mhs=self.kwargs["proposals"], hmc=self.kwargs["proposals"] ) diff --git a/bilby/core/utils/log.py b/bilby/core/utils/log.py index 4884eba9c..ca86b814b 100644 --- a/bilby/core/utils/log.py +++ b/bilby/core/utils/log.py @@ -22,7 +22,7 @@ def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False): If true, print version information """ - if type(log_level) is str: + if isinstance(log_level, str): try: level = getattr(logging, log_level.upper()) except AttributeError: @@ -34,14 +34,14 @@ def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False): logger.propagate = False logger.setLevel(level) - if any([type(h) == logging.StreamHandler for h in logger.handlers]) is False: + if not any([isinstance(h, logging.StreamHandler) for h in logger.handlers]): stream_handler = logging.StreamHandler() stream_handler.setFormatter(logging.Formatter( '%(asctime)s %(name)s %(levelname)-8s: %(message)s', datefmt='%H:%M')) stream_handler.setLevel(level) logger.addHandler(stream_handler) - if any([type(h) == logging.FileHandler for h in logger.handlers]) is False: + if not any([isinstance(h, logging.FileHandler) for h in logger.handlers]): if label: Path(outdir).mkdir(parents=True, exist_ok=True) log_file = '{}/{}.log'.format(outdir, label) diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 20b1f54b5..7ff6f4417 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -25,12 +25,12 @@ class InterferometerList(list): """ super(InterferometerList, self).__init__() - if type(interferometers) == str: + if isinstance(interferometers, str): raise TypeError("Input must not be a string") for ifo in interferometers: - if type(ifo) == str: + if isinstance(ifo, str): ifo = get_empty_interferometer(ifo) - if type(ifo) not in [Interferometer, TriangularInterferometer]: + if not isinstance(ifo, (Interferometer, TriangularInterferometer)): raise TypeError( "Input list of interferometers are not all Interferometer objects" ) diff --git a/bilby/gw/eos/eos.py b/bilby/gw/eos/eos.py index 5693fb33e..ca7799ebb 100644 --- a/bilby/gw/eos/eos.py +++ b/bilby/gw/eos/eos.py @@ -60,12 +60,12 @@ class TabularEOS(object): self.sampling_flag = sampling_flag self.warning_flag = warning_flag - if type(eos) == str: + if isinstance(eos, str): if eos in valid_eos_dict.keys(): table = np.loadtxt(valid_eos_dict[eos]) else: table = np.loadtxt(eos) - elif type(eos) == np.ndarray: + elif isinstance(eos, np.ndarray): table = eos else: raise ValueError("eos provided is invalid type please supply a str name, str path to ASCII file, " diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index fb6ffa0c9..a7eb4a1c8 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -193,7 +193,7 @@ class TestGaussianLikelihood(unittest.TestCase): likelihood.parameters["m"] = 2 likelihood.parameters["c"] = 0 likelihood.log_likelihood() - self.assertTrue(type(likelihood.sigma) == type(sigma_array)) + self.assertTrue(type(likelihood.sigma) == type(sigma_array)) # noqa: E721 self.assertTrue(all(likelihood.sigma == sigma_array)) def test_set_sigma_None(self): diff --git a/test/core/prior/joint_test.py b/test/core/prior/joint_test.py index ebadfcfae..c99373b00 100644 --- a/test/core/prior/joint_test.py +++ b/test/core/prior/joint_test.py @@ -32,7 +32,7 @@ MultivariateGaussianDist( self.assertTrue(item == fromstr.__getattribute__(key)) elif key == "mvn": for d1, d2 in zip(fromstr.__getattribute__(key), item): - self.assertTrue(type(d1) == type(d2)) + self.assertTrue(type(d1) == type(d2)) # noqa: E721 elif isinstance(item, (list, tuple, np.ndarray)): self.assertTrue( np.all(np.array(item) == np.array(fromstr.__getattribute__(key))) diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 87a007072..286eadea2 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -189,10 +189,7 @@ class TestGWTransient(unittest.TestCase): self.assertListEqual( bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers ) - self.assertTrue( - type(self.likelihood.interferometers) - == bilby.gw.detector.InterferometerList - ) + self.assertIsInstance(self.likelihood.interferometers, bilby.gw.detector.InterferometerList) def test_interferometers_setting_interferometer_list(self): ifos = bilby.gw.detector.InterferometerList( @@ -205,10 +202,7 @@ class TestGWTransient(unittest.TestCase): self.assertListEqual( bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers ) - self.assertTrue( - type(self.likelihood.interferometers) - == bilby.gw.detector.InterferometerList - ) + self.assertIsInstance(self.likelihood.interferometers, bilby.gw.detector.InterferometerList) def test_meta_data(self): expected = dict( -- GitLab