From a9566d17708059a1547cf9d2aedbb26f626335dd Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Tue, 10 May 2022 15:56:28 +0000
Subject: [PATCH] Fix integration tests

---
 bilby/gw/conversion.py           |  6 +++++-
 test/gw/conversion_test.py       | 24 ++++++++++++++++--------
 test/integration/example_test.py | 12 +++++++++++-
 3 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index e811612f6..98fc76689 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -9,7 +9,7 @@ import multiprocessing
 import pickle
 
 import numpy as np
-from pandas import DataFrame
+from pandas import DataFrame, Series
 
 from ..core.likelihood import MarginalizedLikelihoodReconstructionError
 from ..core.utils import logger, solar_mass, command_line_args
@@ -32,12 +32,16 @@ def redshift_to_comoving_distance(redshift, cosmology=None):
 def luminosity_distance_to_redshift(distance, cosmology=None):
     from astropy import units
     cosmology = get_cosmology(cosmology)
+    if isinstance(distance, Series):
+        distance = distance.values
     return z_at_value(cosmology.luminosity_distance, distance * units.Mpc)
 
 
 def comoving_distance_to_redshift(distance, cosmology=None):
     from astropy import units
     cosmology = get_cosmology(cosmology)
+    if isinstance(distance, Series):
+        distance = distance.values
     return z_at_value(cosmology.comoving_distance, distance * units.Mpc)
 
 
diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py
index 5df2a7499..54dd064f0 100644
--- a/test/gw/conversion_test.py
+++ b/test/gw/conversion_test.py
@@ -1,6 +1,7 @@
 import unittest
 
 import numpy as np
+import pandas as pd
 
 import bilby
 from bilby.gw import conversion
@@ -442,20 +443,27 @@ class TestGenerateAllParameters(unittest.TestCase):
             "lambda_tilde",
             "delta_lambda_tilde",
         ]
+        self.data_frame = pd.DataFrame({
+            key: [value] * 100 for key, value in self.parameters.items()
+        })
 
     def test_generate_all_bbh_parameters(self):
-        new_parameters = bilby.gw.conversion.generate_all_bbh_parameters(
-            self.parameters
+        self._generate(
+            bilby.gw.conversion.generate_all_bbh_parameters,
+            self.expected_bbh_keys,
         )
-        for key in self.expected_bbh_keys:
-            self.assertIn(key, new_parameters)
 
     def test_generate_all_bns_parameters(self):
-        new_parameters = bilby.gw.conversion.generate_all_bns_parameters(
-            self.parameters
+        self._generate(
+            bilby.gw.conversion.generate_all_bns_parameters,
+            self.expected_bbh_keys + self.expected_tidal_keys,
         )
-        for key in self.expected_bbh_keys + self.expected_tidal_keys:
-            self.assertIn(key, new_parameters)
+
+    def _generate(self, func, expected):
+        for values in [self.parameters, self.data_frame]:
+            new_parameters = func(values)
+            for key in expected:
+                self.assertIn(key, new_parameters)
 
 
 class TestDistanceTransformations(unittest.TestCase):
diff --git a/test/integration/example_test.py b/test/integration/example_test.py
index b9df79121..266001fc1 100644
--- a/test/integration/example_test.py
+++ b/test/integration/example_test.py
@@ -29,7 +29,6 @@ def _execute_file(name, fname):
     dname, fname = os.path.split(fname)
     old_directory = os.getcwd()
     os.chdir(dname)
-    print(f"Running {fname} from {dname}")
     spec = importlib.util.spec_from_file_location(name, fname)
     foo = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(foo)
@@ -41,6 +40,17 @@ class ExampleTest(unittest.TestCase):
     dir_path = os.path.dirname(os.path.realpath(__file__))
     dir_path = os.path.abspath(os.path.join(dir_path, os.path.pardir))
 
+    def setUp(self):
+        self.init_dir = os.getcwd()
+
+    def tearDown(self):
+        if os.path.isdir(self.outdir):
+            try:
+                shutil.rmtree(self.outdir)
+            except OSError:
+                logging.warning("{} not removed after tests".format(self.outdir))
+        os.chdir(self.init_dir)
+
     @classmethod
     def setUpClass(cls):
         if os.path.isdir(cls.outdir):
-- 
GitLab