diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py index 1b528ec78cad6bd929feda0658b4b7f1d5c222d7..fcd4c9c5aa82d02334b1ffe02e5fbe501d4e14b8 100644 --- a/bilby/bilby_mcmc/sampler.py +++ b/bilby/bilby_mcmc/sampler.py @@ -388,7 +388,7 @@ class Bilby_MCMC(MCMCSampler): with open(self.resume_file, "rb") as file: ptsampler = dill.load(file) - if type(ptsampler) != BilbyPTMCMCSampler: + if not isinstance(ptsampler, BilbyPTMCMCSampler): logger.debug("Malformed resume file, ignoring") return False self.ptsampler = ptsampler diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 36c14de033449511d5f5b76970ff2afc4cd59cfd..6b4f3461932d2b19c4b4396ea21c9a365c9db5cc 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -682,11 +682,11 @@ class Sampler(object): if self.cached_result is None: kwargs_print = self.kwargs.copy() for k in kwargs_print: - if type(kwargs_print[k]) in (list, np.ndarray): + if isinstance(kwargs_print[k], (list, np.ndarray)): array_repr = np.array(kwargs_print[k]) if array_repr.size > 10: kwargs_print[k] = f"array_like, shape={array_repr.shape}" - elif type(kwargs_print[k]) == DataFrame: + elif isinstance(kwargs_print[k], DataFrame): kwargs_print[k] = f"DataFrame, shape={kwargs_print[k].shape}" logger.info( f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}" diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index bc3b364656d26bcff0c14e3852bbbd394c5887cf..e1f3ae19e39f37854643aa86aaac9beece78465d 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -182,7 +182,7 @@ class Cpnest(NestedSampler): if "proposals" in self.kwargs: if self.kwargs["proposals"] is None: return - if type(self.kwargs["proposals"]) == JumpProposalCycle: + if isinstance(self.kwargs["proposals"], JumpProposalCycle): self.kwargs["proposals"] = dict( mhs=self.kwargs["proposals"], hmc=self.kwargs["proposals"] ) diff --git a/bilby/core/utils/log.py b/bilby/core/utils/log.py index 4884eba9cc76cf408fdcf8e5a794b6521b9ca5c2..ca86b814b3f583aea8aac2fdba3013644bd8cb00 100644 --- a/bilby/core/utils/log.py +++ b/bilby/core/utils/log.py @@ -22,7 +22,7 @@ def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False): If true, print version information """ - if type(log_level) is str: + if isinstance(log_level, str): try: level = getattr(logging, log_level.upper()) except AttributeError: @@ -34,14 +34,14 @@ def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False): logger.propagate = False logger.setLevel(level) - if any([type(h) == logging.StreamHandler for h in logger.handlers]) is False: + if not any([isinstance(h, logging.StreamHandler) for h in logger.handlers]): stream_handler = logging.StreamHandler() stream_handler.setFormatter(logging.Formatter( '%(asctime)s %(name)s %(levelname)-8s: %(message)s', datefmt='%H:%M')) stream_handler.setLevel(level) logger.addHandler(stream_handler) - if any([type(h) == logging.FileHandler for h in logger.handlers]) is False: + if not any([isinstance(h, logging.FileHandler) for h in logger.handlers]): if label: Path(outdir).mkdir(parents=True, exist_ok=True) log_file = '{}/{}.log'.format(outdir, label) diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 20b1f54b5fc94bb10d3db4ce051b70237234d4be..7ff6f44176c96b5254d98565aba82fed15a007a7 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -25,12 +25,12 @@ class InterferometerList(list): """ super(InterferometerList, self).__init__() - if type(interferometers) == str: + if isinstance(interferometers, str): raise TypeError("Input must not be a string") for ifo in interferometers: - if type(ifo) == str: + if isinstance(ifo, str): ifo = get_empty_interferometer(ifo) - if type(ifo) not in [Interferometer, TriangularInterferometer]: + if not isinstance(ifo, (Interferometer, TriangularInterferometer)): raise TypeError( "Input list of interferometers are not all Interferometer objects" ) diff --git a/bilby/gw/eos/eos.py b/bilby/gw/eos/eos.py index 5693fb33e927153ba3a577a592bc6fe557b5944e..ca7799ebb41728443433fc448b1bd6038c3db433 100644 --- a/bilby/gw/eos/eos.py +++ b/bilby/gw/eos/eos.py @@ -60,12 +60,12 @@ class TabularEOS(object): self.sampling_flag = sampling_flag self.warning_flag = warning_flag - if type(eos) == str: + if isinstance(eos, str): if eos in valid_eos_dict.keys(): table = np.loadtxt(valid_eos_dict[eos]) else: table = np.loadtxt(eos) - elif type(eos) == np.ndarray: + elif isinstance(eos, np.ndarray): table = eos else: raise ValueError("eos provided is invalid type please supply a str name, str path to ASCII file, " diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index fb6ffa0c9d29eba66e19b82b4c1daaf5a6da9fe3..a7eb4a1c80ea99eadb3148597b74b55bac854770 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -193,7 +193,7 @@ class TestGaussianLikelihood(unittest.TestCase): likelihood.parameters["m"] = 2 likelihood.parameters["c"] = 0 likelihood.log_likelihood() - self.assertTrue(type(likelihood.sigma) == type(sigma_array)) + self.assertTrue(type(likelihood.sigma) == type(sigma_array)) # noqa: E721 self.assertTrue(all(likelihood.sigma == sigma_array)) def test_set_sigma_None(self): diff --git a/test/core/prior/joint_test.py b/test/core/prior/joint_test.py index ebadfcfaeb96ed1cd017a212097c8ff65f5c477c..c99373b00a8494b5d0a3b67e44aa40c0860db7ca 100644 --- a/test/core/prior/joint_test.py +++ b/test/core/prior/joint_test.py @@ -32,7 +32,7 @@ MultivariateGaussianDist( self.assertTrue(item == fromstr.__getattribute__(key)) elif key == "mvn": for d1, d2 in zip(fromstr.__getattribute__(key), item): - self.assertTrue(type(d1) == type(d2)) + self.assertTrue(type(d1) == type(d2)) # noqa: E721 elif isinstance(item, (list, tuple, np.ndarray)): self.assertTrue( np.all(np.array(item) == np.array(fromstr.__getattribute__(key))) diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 87a007072f7dff2f0ee00976356eef43a3d5d15b..286eadea27966d1cebb72d77b4b4a90fea5a1c61 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -189,10 +189,7 @@ class TestGWTransient(unittest.TestCase): self.assertListEqual( bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers ) - self.assertTrue( - type(self.likelihood.interferometers) - == bilby.gw.detector.InterferometerList - ) + self.assertIsInstance(self.likelihood.interferometers, bilby.gw.detector.InterferometerList) def test_interferometers_setting_interferometer_list(self): ifos = bilby.gw.detector.InterferometerList( @@ -205,10 +202,7 @@ class TestGWTransient(unittest.TestCase): self.assertListEqual( bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers ) - self.assertTrue( - type(self.likelihood.interferometers) - == bilby.gw.detector.InterferometerList - ) + self.assertIsInstance(self.likelihood.interferometers, bilby.gw.detector.InterferometerList) def test_meta_data(self): expected = dict(