diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index e811612f627da3855ae96be98d12e0e9f972d1a7..98fc76689660f756fcd9c0163840ff0e92e2933a 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -9,7 +9,7 @@ import multiprocessing import pickle import numpy as np -from pandas import DataFrame +from pandas import DataFrame, Series from ..core.likelihood import MarginalizedLikelihoodReconstructionError from ..core.utils import logger, solar_mass, command_line_args @@ -32,12 +32,16 @@ def redshift_to_comoving_distance(redshift, cosmology=None): def luminosity_distance_to_redshift(distance, cosmology=None): from astropy import units cosmology = get_cosmology(cosmology) + if isinstance(distance, Series): + distance = distance.values return z_at_value(cosmology.luminosity_distance, distance * units.Mpc) def comoving_distance_to_redshift(distance, cosmology=None): from astropy import units cosmology = get_cosmology(cosmology) + if isinstance(distance, Series): + distance = distance.values return z_at_value(cosmology.comoving_distance, distance * units.Mpc) diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index 5df2a74997cab417ecf3cf5754d8de5e48793ad1..54dd064f08e956587debb4951dcf08270efaa612 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -1,6 +1,7 @@ import unittest import numpy as np +import pandas as pd import bilby from bilby.gw import conversion @@ -442,20 +443,27 @@ class TestGenerateAllParameters(unittest.TestCase): "lambda_tilde", "delta_lambda_tilde", ] + self.data_frame = pd.DataFrame({ + key: [value] * 100 for key, value in self.parameters.items() + }) def test_generate_all_bbh_parameters(self): - new_parameters = bilby.gw.conversion.generate_all_bbh_parameters( - self.parameters + self._generate( + bilby.gw.conversion.generate_all_bbh_parameters, + self.expected_bbh_keys, ) - for key in self.expected_bbh_keys: - self.assertIn(key, new_parameters) def test_generate_all_bns_parameters(self): - new_parameters = bilby.gw.conversion.generate_all_bns_parameters( - self.parameters + self._generate( + bilby.gw.conversion.generate_all_bns_parameters, + self.expected_bbh_keys + self.expected_tidal_keys, ) - for key in self.expected_bbh_keys + self.expected_tidal_keys: - self.assertIn(key, new_parameters) + + def _generate(self, func, expected): + for values in [self.parameters, self.data_frame]: + new_parameters = func(values) + for key in expected: + self.assertIn(key, new_parameters) class TestDistanceTransformations(unittest.TestCase): diff --git a/test/integration/example_test.py b/test/integration/example_test.py index b9df79121ed09e64da16d3f51ea1763d2341d0a7..266001fc1eb45f3e0b6f89b016a8058c49c125f7 100644 --- a/test/integration/example_test.py +++ b/test/integration/example_test.py @@ -29,7 +29,6 @@ def _execute_file(name, fname): dname, fname = os.path.split(fname) old_directory = os.getcwd() os.chdir(dname) - print(f"Running {fname} from {dname}") spec = importlib.util.spec_from_file_location(name, fname) foo = importlib.util.module_from_spec(spec) spec.loader.exec_module(foo) @@ -41,6 +40,17 @@ class ExampleTest(unittest.TestCase): dir_path = os.path.dirname(os.path.realpath(__file__)) dir_path = os.path.abspath(os.path.join(dir_path, os.path.pardir)) + def setUp(self): + self.init_dir = os.getcwd() + + def tearDown(self): + if os.path.isdir(self.outdir): + try: + shutil.rmtree(self.outdir) + except OSError: + logging.warning("{} not removed after tests".format(self.outdir)) + os.chdir(self.init_dir) + @classmethod def setUpClass(cls): if os.path.isdir(cls.outdir):