From 0eec235ea567c4f8d9bb0fd8aa914aff620e770a Mon Sep 17 00:00:00 2001
From: Moritz <email@moritz-huebner.de>
Date: Tue, 13 Nov 2018 17:32:58 +1100
Subject: [PATCH] Implemented to_hdf5 and from_hdf5 for Interferometer and
 InterferometerList

---
 bilby/gw/detector.py  | 22 +++++++++++++++++++---
 test/detector_test.py | 40 ++++++++++++++++++++++++++++------------
 2 files changed, 47 insertions(+), 15 deletions(-)

diff --git a/bilby/gw/detector.py b/bilby/gw/detector.py
index f92849e75..5d92a9da8 100644
--- a/bilby/gw/detector.py
+++ b/bilby/gw/detector.py
@@ -207,6 +207,19 @@ class InterferometerList(list):
         super(InterferometerList, self).insert(index, interferometer)
         self._check_interferometers()
 
+    def to_hdf5(self, outdir='outdir', label='ifo_list'):
+        utils.check_directory_exists_and_if_not_mkdir('outdir')
+        dd.io.save('./' + outdir + '/' + label + '.h5', self)
+
+    @classmethod
+    def from_hdf5(cls, outdir, label):
+        res = dd.io.load('./' + outdir + '/' + label + '.h5')
+        if res.__class__ == list:
+            res = cls(res)
+        if res.__class__ != cls:
+            raise TypeError('The loaded object is not a InterferometerList')
+        return res
+
 
 class InterferometerStrainData(object):
     """ Strain data for an interferometer """
@@ -1608,9 +1621,12 @@ class Interferometer(object):
         utils.check_directory_exists_and_if_not_mkdir('outdir')
         dd.io.save('./' + outdir + '/' + label + '.h5', self)
 
-    @staticmethod
-    def from_hdf5(outdir, label):
-        return dd.io.load('./' + outdir + '/' + label + '.h5')
+    @classmethod
+    def from_hdf5(cls, outdir, label):
+        res = dd.io.load('./' + outdir + '/' + label + '.h5')
+        if res.__class__ != cls:
+            raise TypeError('The loaded object is not a InterferometerList')
+        return res
 
 
 class TriangularInterferometer(InterferometerList):
diff --git a/test/detector_test.py b/test/detector_test.py
index 7f88b4df2..b9fa7803f 100644
--- a/test/detector_test.py
+++ b/test/detector_test.py
@@ -10,6 +10,7 @@ import scipy.signal.windows
 import os
 from shutil import rmtree
 import logging
+import deepdish as dd
 
 
 class TestInterferometer(unittest.TestCase):
@@ -37,6 +38,7 @@ class TestInterferometer(unittest.TestCase):
                                                     xarm_tilt=self.xarm_tilt, yarm_tilt=self.yarm_tilt)
         self.ifo.strain_data.set_from_frequency_domain_strain(
             np.linspace(0, 4096, 4097), sampling_frequency=4096, duration=2)
+        bilby.core.utils.check_directory_exists_and_if_not_mkdir('outdir')
 
     def tearDown(self):
         del self.name
@@ -52,6 +54,7 @@ class TestInterferometer(unittest.TestCase):
         del self.xarm_tilt
         del self.yarm_tilt
         del self.ifo
+        rmtree('outdir')
 
     def test_name_setting(self):
         self.assertEqual(self.ifo.name, self.name)
@@ -351,11 +354,16 @@ class TestInterferometer(unittest.TestCase):
                     float(self.yarm_tilt))
         self.assertEqual(expected, repr(self.ifo))
 
-    def test_to_and_from_hdf5(self):
+    def test_to_and_from_hdf5_loading(self):
         self.ifo.to_hdf5(outdir='outdir', label='test')
         recovered_ifo = bilby.gw.detector.Interferometer.from_hdf5(outdir='outdir', label='test')
         self.assertEqual(self.ifo, recovered_ifo)
-        rmtree('outdir')
+
+    def test_to_and_from_hdf5_wrong_class(self):
+        bilby.core.utils.check_directory_exists_and_if_not_mkdir('outdir')
+        dd.io.save('./outdir/psd.h5', self.power_spectral_density)
+        with self.assertRaises(TypeError):
+            bilby.gw.detector.Interferometer.from_hdf5(outdir='outdir', label='psd')
 
 
 class TestInterferometerEquals(unittest.TestCase):
@@ -838,12 +846,8 @@ class TestInterferometerList(unittest.TestCase):
         self.frequency_arrays = np.linspace(0, 4096, 4097)
         self.name1 = 'name1'
         self.name2 = 'name2'
-        self.power_spectral_density1 = MagicMock()
-        self.power_spectral_density1.get_noise_realisation = MagicMock(return_value=(self.frequency_arrays,
-                                                                                     self.frequency_arrays))
-        self.power_spectral_density2 = MagicMock()
-        self.power_spectral_density2.get_noise_realisation = MagicMock(return_value=(self.frequency_arrays,
-                                                                                     self.frequency_arrays))
+        self.power_spectral_density1 = bilby.gw.detector.PowerSpectralDensity.from_aligo()
+        self.power_spectral_density2 = bilby.gw.detector.PowerSpectralDensity.from_aligo()
         self.minimum_frequency1 = 10
         self.minimum_frequency2 = 10
         self.maximum_frequency1 = 20
@@ -886,6 +890,7 @@ class TestInterferometerList(unittest.TestCase):
         self.ifo2.strain_data.set_from_frequency_domain_strain(
             self.frequency_arrays, sampling_frequency=4096, duration=2)
         self.ifo_list = bilby.gw.detector.InterferometerList([self.ifo1, self.ifo2])
+        bilby.core.utils.check_directory_exists_and_if_not_mkdir('outdir')
 
     def tearDown(self):
         del self.frequency_arrays
@@ -916,20 +921,21 @@ class TestInterferometerList(unittest.TestCase):
         del self.ifo1
         del self.ifo2
         del self.ifo_list
+        rmtree('outdir')
 
     def test_init_with_string(self):
-        with self.assertRaises(ValueError):
+        with self.assertRaises(TypeError):
             bilby.gw.detector.InterferometerList("string")
 
     def test_init_with_string_list(self):
         """ Merely checks if this ends up in the right bracket """
         with mock.patch('bilby.gw.detector.get_empty_interferometer') as m:
-            m.side_effect = ValueError
-            with self.assertRaises(ValueError):
+            m.side_effect = TypeError
+            with self.assertRaises(TypeError):
                 bilby.gw.detector.InterferometerList(['string'])
 
     def test_init_with_other_object(self):
-        with self.assertRaises(ValueError):
+        with self.assertRaises(TypeError):
             bilby.gw.detector.InterferometerList([object()])
 
     def test_init_with_actual_ifos(self):
@@ -1034,6 +1040,16 @@ class TestInterferometerList(unittest.TestCase):
         names = [ifo.name for ifo in self.ifo_list]
         self.assertListEqual([self.ifo1.name, new_ifo.name, self.ifo2.name], names)
 
+    def test_to_and_from_hdf5_loading(self):
+        self.ifo_list.to_hdf5(outdir='outdir', label='test')
+        recovered_ifo = bilby.gw.detector.InterferometerList.from_hdf5(outdir='outdir', label='test')
+        self.assertListEqual(self.ifo_list, recovered_ifo)
+
+    def test_to_and_from_hdf5_wrong_class(self):
+        dd.io.save('./outdir/psd.h5', self.ifo_list[0].power_spectral_density)
+        with self.assertRaises(TypeError):
+            bilby.gw.detector.InterferometerList.from_hdf5(outdir='outdir', label='psd')
+
 
 class TestPowerSpectralDensityWithoutFiles(unittest.TestCase):
 
-- 
GitLab