From 34cbf86ed85497331380296e6749a3d13a0c3973 Mon Sep 17 00:00:00 2001
From: Colm Talbot <talbotcolm@gmail.com>
Date: Sat, 10 Feb 2024 07:51:46 -0600
Subject: [PATCH] TST: update frame reading tests for latest version of gwpy

---
 bilby/gw/utils.py     |  4 ++--
 test/gw/utils_test.py | 30 +++++++++++++++++++++++-------
 2 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py
index c581a6c71..65f0bbf03 100644
--- a/bilby/gw/utils.py
+++ b/bilby/gw/utils.py
@@ -398,7 +398,7 @@ def read_frame_file(file_name, start_time, end_time, resample=None, channel=None
             strain = TimeSeries.read(source=file_name, channel=channel, start=start_time, end=end_time, **kwargs)
             loaded = True
             logger.info('Successfully loaded {}.'.format(channel))
-        except RuntimeError:
+        except (RuntimeError, ValueError):
             logger.warning('Channel {} not found. Trying preset channel names'.format(channel))
 
     if loaded is False:
@@ -418,7 +418,7 @@ def read_frame_file(file_name, start_time, end_time, resample=None, channel=None
                                              **kwargs)
                     loaded = True
                     logger.info('Successfully read strain data for channel {}.'.format(channel))
-                except RuntimeError:
+                except (RuntimeError, ValueError):
                     pass
 
     if loaded:
diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py
index f2aeb1c78..b67d72d5d 100644
--- a/test/gw/utils_test.py
+++ b/test/gw/utils_test.py
@@ -1,6 +1,7 @@
 import unittest
 import os
 from shutil import rmtree
+from importlib.metadata import version
 
 import numpy as np
 import lal
@@ -89,12 +90,28 @@ class TestGWUtils(unittest.TestCase):
         with self.assertRaises(ValueError):
             gwutils.get_event_time("GW010290")
 
+    @pytest.mark.skipif(version("gwpy") < "3.0.8", reason="GWpy version < 3.0.8")
     def test_read_frame_file(self):
+        """
+        Test that reading a frame file works as expected
+        for a few conditions.
+
+        1. Reading without time limits returns the full data
+        2. Reading with time limits returns the expected data
+           (inclusive of start time if present, exclusive of end time)
+        3. Reading without the channel name provided finds a standard name
+        4. Reading without the channel with a non-standard name returns None.
+
+        Notes
+        =====
+        There was a longstanding bug in gwpy that we previously tested for
+        here, but this has been fixed in gwpy 3.0.8.
+        """
         start_time = 0
         end_time = 10
         channel = "H1:GDS-CALIB_STRAIN"
         N = 100
-        times = np.linspace(start_time, end_time, N)
+        times = np.linspace(start_time, end_time, N, endpoint=False)
         data = np.random.normal(0, 1, N)
         ts = TimeSeries(data=data, times=times, t0=0)
         ts.channel = Channel(channel)
@@ -107,7 +124,7 @@ class TestGWUtils(unittest.TestCase):
             filename, start_time=None, end_time=None, channel=channel
         )
         self.assertEqual(strain.name, channel)
-        self.assertTrue(np.all(strain.value == data[:-1]))
+        self.assertTrue(np.all(strain.value == data))
 
         # Check reading with time limits
         start_cut = 2
@@ -115,19 +132,18 @@ class TestGWUtils(unittest.TestCase):
         strain = gwutils.read_frame_file(
             filename, start_time=start_cut, end_time=end_cut, channel=channel
         )
-        idxs = (times > start_cut) & (times < end_cut)
-        # Dropping the last element - for some reason gwpy drops the last element when reading in data
-        self.assertTrue(np.all(strain.value == data[idxs][:-1]))
+        idxs = (times >= start_cut) & (times < end_cut)
+        self.assertTrue(np.all(strain.value == data[idxs]))
 
         # Check reading with unknown channels
         strain = gwutils.read_frame_file(filename, start_time=None, end_time=None)
-        self.assertTrue(np.all(strain.value == data[:-1]))
+        self.assertTrue(np.all(strain.value == data))
 
         # Check reading with incorrect channel
         strain = gwutils.read_frame_file(
             filename, start_time=None, end_time=None, channel="WRONG"
         )
-        self.assertTrue(np.all(strain.value == data[:-1]))
+        self.assertTrue(np.all(strain.value == data))
 
         ts = TimeSeries(data=data, times=times, t0=0)
         ts.name = "NOT-A-KNOWN-CHANNEL"
-- 
GitLab