diff --git a/bilby/gw/detector.py b/bilby/gw/detector.py index c7586cb86b3c3bc99039af52263cfdc755a0543f..f92849e757610294fcc05fdeb5b17e82381c56eb 100644 --- a/bilby/gw/detector.py +++ b/bilby/gw/detector.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np from scipy.signal.windows import tukey from scipy.interpolate import interp1d +import deepdish as dd from . import utils as gwutils from ..core import utils @@ -1601,6 +1602,16 @@ class Interferometer(object): fig.savefig( '{}/{}_{}_time_domain_data.png'.format(outdir, self.name, label)) + def to_hdf5(self, outdir='outdir', label=None): + if label is None: + label = self.name + utils.check_directory_exists_and_if_not_mkdir('outdir') + dd.io.save('./' + outdir + '/' + label + '.h5', self) + + @staticmethod + def from_hdf5(outdir, label): + return dd.io.load('./' + outdir + '/' + label + '.h5') + class TriangularInterferometer(InterferometerList): diff --git a/test/detector_test.py b/test/detector_test.py index 9dd9d691eb765b97ad5272a731697d1859cf0dd8..7f88b4df216dc62996b4be12d0fa6b1ee611aa37 100644 --- a/test/detector_test.py +++ b/test/detector_test.py @@ -8,14 +8,15 @@ from mock import patch import numpy as np import scipy.signal.windows import os +from shutil import rmtree import logging -class TestDetector(unittest.TestCase): +class TestInterferometer(unittest.TestCase): def setUp(self): self.name = 'name' - self.power_spectral_density = MagicMock() + self.power_spectral_density = bilby.gw.detector.PowerSpectralDensity.from_aligo() self.minimum_frequency = 10 self.maximum_frequency = 20 self.length = 30 @@ -322,7 +323,10 @@ class TestDetector(unittest.TestCase): signal = 1 expected = [signal, signal, self.ifo.power_spectral_density_array, self.ifo.strain_data.duration] actual = self.ifo.optimal_snr_squared(signal=signal) - self.assertListEqual(expected, actual) + self.assertEqual(expected[0], actual[0]) + self.assertEqual(expected[1], actual[1]) + self.assertTrue(np.array_equal(expected[2], actual[2])) + self.assertEqual(expected[3], actual[3]) def test_matched_filter_snr_squared(self): """ Merely checks parameters are given in the right order """ @@ -333,7 +337,9 @@ class TestDetector(unittest.TestCase): self.ifo.strain_data.duration]] actual = self.ifo.matched_filter_snr_squared(signal=signal) self.assertTrue(np.array_equal(expected[0], actual[0])) # array-like element has to be evaluated separately - self.assertListEqual(expected[1], actual[1]) + self.assertEqual(expected[1][0], actual[1][0]) + self.assertTrue(np.array_equal(expected[1][1], actual[1][1])) + self.assertEqual(expected[1][2], actual[1][2]) def test_repr(self): expected = 'Interferometer(name=\'{}\', power_spectral_density={}, minimum_frequency={}, ' \ @@ -345,6 +351,12 @@ class TestDetector(unittest.TestCase): float(self.yarm_tilt)) self.assertEqual(expected, repr(self.ifo)) + def test_to_and_from_hdf5(self): + self.ifo.to_hdf5(outdir='outdir', label='test') + recovered_ifo = bilby.gw.detector.Interferometer.from_hdf5(outdir='outdir', label='test') + self.assertEqual(self.ifo, recovered_ifo) + rmtree('outdir') + class TestInterferometerEquals(unittest.TestCase):