From a9566d17708059a1547cf9d2aedbb26f626335dd Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Tue, 10 May 2022 15:56:28 +0000 Subject: [PATCH] Fix integration tests --- bilby/gw/conversion.py | 6 +++++- test/gw/conversion_test.py | 24 ++++++++++++++++-------- test/integration/example_test.py | 12 +++++++++++- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index e811612f6..98fc76689 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -9,7 +9,7 @@ import multiprocessing import pickle import numpy as np -from pandas import DataFrame +from pandas import DataFrame, Series from ..core.likelihood import MarginalizedLikelihoodReconstructionError from ..core.utils import logger, solar_mass, command_line_args @@ -32,12 +32,16 @@ def redshift_to_comoving_distance(redshift, cosmology=None): def luminosity_distance_to_redshift(distance, cosmology=None): from astropy import units cosmology = get_cosmology(cosmology) + if isinstance(distance, Series): + distance = distance.values return z_at_value(cosmology.luminosity_distance, distance * units.Mpc) def comoving_distance_to_redshift(distance, cosmology=None): from astropy import units cosmology = get_cosmology(cosmology) + if isinstance(distance, Series): + distance = distance.values return z_at_value(cosmology.comoving_distance, distance * units.Mpc) diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index 5df2a7499..54dd064f0 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -1,6 +1,7 @@ import unittest import numpy as np +import pandas as pd import bilby from bilby.gw import conversion @@ -442,20 +443,27 @@ class TestGenerateAllParameters(unittest.TestCase): "lambda_tilde", "delta_lambda_tilde", ] + self.data_frame = pd.DataFrame({ + key: [value] * 100 for key, value in self.parameters.items() + }) def test_generate_all_bbh_parameters(self): - new_parameters = bilby.gw.conversion.generate_all_bbh_parameters( - self.parameters + self._generate( + bilby.gw.conversion.generate_all_bbh_parameters, + self.expected_bbh_keys, ) - for key in self.expected_bbh_keys: - self.assertIn(key, new_parameters) def test_generate_all_bns_parameters(self): - new_parameters = bilby.gw.conversion.generate_all_bns_parameters( - self.parameters + self._generate( + bilby.gw.conversion.generate_all_bns_parameters, + self.expected_bbh_keys + self.expected_tidal_keys, ) - for key in self.expected_bbh_keys + self.expected_tidal_keys: - self.assertIn(key, new_parameters) + + def _generate(self, func, expected): + for values in [self.parameters, self.data_frame]: + new_parameters = func(values) + for key in expected: + self.assertIn(key, new_parameters) class TestDistanceTransformations(unittest.TestCase): diff --git a/test/integration/example_test.py b/test/integration/example_test.py index b9df79121..266001fc1 100644 --- a/test/integration/example_test.py +++ b/test/integration/example_test.py @@ -29,7 +29,6 @@ def _execute_file(name, fname): dname, fname = os.path.split(fname) old_directory = os.getcwd() os.chdir(dname) - print(f"Running {fname} from {dname}") spec = importlib.util.spec_from_file_location(name, fname) foo = importlib.util.module_from_spec(spec) spec.loader.exec_module(foo) @@ -41,6 +40,17 @@ class ExampleTest(unittest.TestCase): dir_path = os.path.dirname(os.path.realpath(__file__)) dir_path = os.path.abspath(os.path.join(dir_path, os.path.pardir)) + def setUp(self): + self.init_dir = os.getcwd() + + def tearDown(self): + if os.path.isdir(self.outdir): + try: + shutil.rmtree(self.outdir) + except OSError: + logging.warning("{} not removed after tests".format(self.outdir)) + os.chdir(self.init_dir) + @classmethod def setUpClass(cls): if os.path.isdir(cls.outdir): -- GitLab