diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c7933b7681060dfd53efaf203647b2a52f195956..a19412a78c0974cbf6e095ef4935b39e0e74d3a2 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -45,11 +45,10 @@ basic-3.7: # test example on python 3.7 python-3.7: stage: test + needs: ["basic-3.7", "precommits-py3.7"] image: quay.io/bilbydev/v2-dockerfile-test-suite-python37 script: - python -m pip install . - # temporary fix for broken astropy version - - pip install pyerfa==1.7.1.1 # Run pyflakes - flake8 . @@ -66,13 +65,12 @@ python-3.7: docs: stage: docs + needs: ["basic-3.7"] image: quay.io/bilbydev/v2-dockerfile-test-suite-python37 script: # Make the documentation - apt-get -yqq install pandoc - python -m pip install . - # temporary fix for broken astropy version - - pip install pyerfa==1.7.1.1 - cd docs - pip install ipykernel ipython jupyter - cp ../examples/tutorials/*.ipynb ./ @@ -90,33 +88,30 @@ docs: # test example on python 3.8 python-3.8: stage: test + needs: ["basic-3.7", "precommits-py3.7"] image: quay.io/bilbydev/v2-dockerfile-test-suite-python38 script: - python -m pip install . - # temporary fix for broken astropy version - - pip install pyerfa==1.7.1.1 - pytest # test example on python 3.6 python-3.6: stage: test + needs: ["basic-3.7", "precommits-py3.7"] image: quay.io/bilbydev/v2-dockerfile-test-suite-python36 script: - python -m pip install . - # temporary fix for broken astropy version - - pip install pyerfa==1.7.1.1 - pytest # test samplers on python 3.7 python-3.7-samplers: stage: test + needs: ["basic-3.7", "precommits-py3.7"] image: quay.io/bilbydev/v2-dockerfile-test-suite-python37 script: - python -m pip install . - # temporary fix for broken astropy version - - pip install pyerfa==1.7.1.1 - pytest test/integration/sampler_run_test.py --durations 10 - pytest test/integration/sample_from_the_prior_test.py @@ -124,11 +119,10 @@ python-3.7-samplers: # test samplers on python 3.6 python-3.6-samplers: stage: test + needs: ["basic-3.7", "precommits-py3.7"] image: quay.io/bilbydev/v2-dockerfile-test-suite-python36 script: - python -m pip install . - # temporary fix for broken astropy version - - pip install pyerfa==1.7.1.1 - pytest test/integration/sampler_run_test.py @@ -151,8 +145,6 @@ scheduled-python-3.7: - schedules script: - python -m pip install . - # temporary fix for broken astropy version - - pip install pyerfa==1.7.1.1 # Run tests which are only done on schedule - pytest test/integration/example_test.py @@ -177,6 +169,7 @@ authors: pages: stage: deploy + needs: ["docs", "python-3.7"] dependencies: - docs - python-3.7 diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index f05f7d3d398328b90d97d5c7c87e4be53cafbca6..fc9918b04e37054a309b47390c588128c55648c8 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -155,7 +155,7 @@ class PriorDict(dict): @classmethod def _get_from_json_dict(cls, prior_dict): try: - cls == getattr( + cls = getattr( import_module(prior_dict["__module__"]), prior_dict["__name__"]) except ImportError: @@ -231,9 +231,19 @@ class PriorDict(dict): "= {}. Error message {}".format(key, val, e) ) elif isinstance(val, dict): - logger.warning( - 'Cannot convert {} into a prior object. ' - 'Leaving as dictionary.'.format(key)) + try: + _class = getattr( + import_module(val.get("__module__", "none")), + val.get("__name__", "none")) + dictionary[key] = _class(**val.get("kwargs", dict())) + except ImportError: + logger.debug("Cannot import prior module {}.{}".format( + val.get("__module__", "none"), val.get("__name__", "none") + )) + logger.warning( + 'Cannot convert {} into a prior object. ' + 'Leaving as dictionary.'.format(key)) + continue else: raise TypeError( "Unable to parse prior, bad entry: {} " diff --git a/bilby/core/result.py b/bilby/core/result.py index d4f544d5ed892a5a47ba233b5b668769dac4211c..7dcb4b04f426e461f6aebfb1ecf5b818b99f39f5 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -3,9 +3,11 @@ import os from collections import OrderedDict, namedtuple from copy import copy from distutils.version import LooseVersion +from importlib import import_module from itertools import product import corner +import h5py import json import matplotlib import matplotlib.pyplot as plt @@ -22,7 +24,9 @@ from .utils import ( latex_plot_format, safe_save_figure, BilbyJsonEncoder, load_json, move_old_file, get_version_information, - decode_bilby_json, + decode_bilby_json, docstring, + recursively_save_dict_contents_to_group, + recursively_load_dict_contents_from_group, ) from .prior import Prior, PriorDict, DeltaFunction @@ -88,6 +92,8 @@ def read_in_result(filename=None, outdir=None, label=None, extension='json', gzi result = Result.from_json(filename=filename) elif ('hdf5' in extension) or ('h5' in extension): result = Result.from_hdf5(filename=filename) + elif ("pkl" in extension) or ("pickle" in extension): + result = Result.from_pickle(filename=filename) elif extension is None: raise ValueError("No filetype extension provided") else: @@ -336,8 +342,8 @@ class Result(object): self._kde = None @classmethod - def from_hdf5(cls, filename=None, outdir=None, label=None): - """ Read in a saved .h5 data file + def _from_hdf5_old(cls, filename=None, outdir=None, label=None): + """ Read in a saved .h5 data file in the old format. Parameters ========== @@ -374,7 +380,10 @@ class Result(object): priordict = PriorDict() for key, value in dictionary["priors"].items(): if key not in ["__module__", "__name__", "__prior_dict__"]: - priordict[key] = decode_bilby_json(value) + try: + priordict[key] = decode_bilby_json(value) + except AttributeError: + continue dictionary["priors"] = priordict except Exception as e: raise IOError( @@ -391,27 +400,62 @@ class Result(object): else: raise IOError("No result '{}' found".format(filename)) - @classmethod - def from_json(cls, filename=None, outdir=None, label=None, gzip=False): - """ Read in a saved .json data file + _load_doctstring = """ Read in a saved .{format} data file - Parameters - ========== - filename: str - If given, try to load from this filename - outdir, label: str - If given, use the default naming convention for saved results file + Parameters + ========== + filename: str + If given, try to load from this filename + outdir, label: str + If given, use the default naming convention for saved results file - Returns - ======= - result: bilby.core.result.Result + Returns + ======= + result: bilby.core.result.Result - Raises - ======= - ValueError: If no filename is given and either outdir or label is None - If no bilby.core.result.Result is found in the path + Raises + ======= + ValueError: If no filename is given and either outdir or label is None + If no bilby.core.result.Result is found in the path - """ + """ + + @staticmethod + @docstring(_load_doctstring.format(format="pickle")) + def from_pickle(filename=None, outdir=None, label=None): + filename = _determine_file_name(filename, outdir, label, 'hdf5', False) + import dill + with open(filename, "rb") as ff: + return dill.load(ff) + + @classmethod + @docstring(_load_doctstring.format(format="hdf5")) + def from_hdf5(cls, filename=None, outdir=None, label=None): + filename = _determine_file_name(filename, outdir, label, 'hdf5', False) + with h5py.File(filename, "r") as ff: + data = recursively_load_dict_contents_from_group(ff, '/') + if list(data.keys()) == ["data"]: + return cls._from_hdf5_old(filename=filename) + data["posterior"] = pd.DataFrame(data["posterior"]) + data["priors"] = PriorDict._get_from_json_dict( + json.loads(data["priors"], object_hook=decode_bilby_json) + ) + try: + cls = getattr(import_module(data['__module__']), data['__name__']) + except ImportError: + logger.debug( + "Module {}.{} not found".format(data["__module__"], data["__name__"]) + ) + except KeyError: + logger.debug("No class specified, using base Result.") + for key in ["__module__", "__name__"]: + if key in data: + del data[key] + return cls(**data) + + @classmethod + @docstring(_load_doctstring.format(format="json")) + def from_json(cls, filename=None, outdir=None, label=None, gzip=False): filename = _determine_file_name(filename, outdir, label, 'json', gzip) if os.path.isfile(filename): @@ -592,7 +636,10 @@ class Result(object): def save_to_file(self, filename=None, overwrite=False, outdir=None, extension='json', gzip=False): """ - Writes the Result to a json or deepdish h5 file + + Writes the Result to a file. + + Supported formats are: `json`, `hdf5`, `arviz`, `pickle` Parameters ========== @@ -631,8 +678,8 @@ class Result(object): try: # convert priors to JSON dictionary for both JSON and hdf5 files - dictionary["priors"] = dictionary["priors"]._get_json_dict() if extension == 'json': + dictionary["priors"] = dictionary["priors"]._get_json_dict() if gzip: import gzip # encode to a string @@ -643,16 +690,25 @@ class Result(object): with open(filename, 'w') as file: json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder) elif extension == 'hdf5': - import deepdish - for key in dictionary: - if isinstance(dictionary[key], pd.DataFrame): - dictionary[key] = dictionary[key].to_dict() - deepdish.io.save(filename, dictionary) + dictionary["__module__"] = self.__module__ + dictionary["__name__"] = self.__class__.__name__ + with h5py.File(filename, 'w') as h5file: + recursively_save_dict_contents_to_group(h5file, '/', dictionary) + elif extension == 'pkl': + import dill + with open(filename, "wb") as ff: + dill.dump(self, ff) else: raise ValueError("Extension type {} not understood".format(extension)) except Exception as e: - logger.error("\n\n Saving the data has failed with the " - "following message:\n {} \n\n".format(e)) + import dill + filename = ".".join(filename.split(".")[:-1]) + ".pkl" + with open(filename, "wb") as ff: + dill.dump(self, ff) + logger.error( + "\n\nSaving the data has failed with the following message:\n" + "{}\nData has been dumped to {}.\n\n".format(e, filename) + ) def save_posterior_samples(self, filename=None, outdir=None, label=None): """ Saves posterior samples to a file diff --git a/bilby/core/utils.py b/bilby/core/utils.py index f0a8b2ff5c8fb2d115170c727e3af30e8f8797b9..7bfb1faa98f6112e02498eb17ed3c0e7eacd843b 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -1314,3 +1314,167 @@ class tcolors: VALUE = '\033[91m' HIGHLIGHT = '\033[95m' END = '\033[0m' + + +def decode_from_hdf5(item): + """ + Decode an item from HDF5 format to python type. + + This currently just converts __none__ to None and some arrays to lists + + .. versionadded:: 1.0.0 + + Parameters + ---------- + item: object + Item to be decoded + + Returns + ------- + output: object + Converted input item + """ + if isinstance(item, str) and item == "__none__": + output = None + elif isinstance(item, bytes) and item == b"__none__": + output = None + elif isinstance(item, (bytes, bytearray)): + output = item.decode() + elif isinstance(item, np.ndarray): + if "|S" in str(item.dtype) or isinstance(item[0], bytes): + output = [it.decode() for it in item] + else: + output = item + else: + output = item + return output + + +def encode_for_hdf5(item): + """ + Encode an item to a HDF5 savable format. + + .. versionadded:: 1.1.0 + + Parameters + ---------- + item: object + Object to be encoded, specific options are provided for Bilby types + + Returns + ------- + output: object + Input item converted into HDF5 savable format + """ + from .prior.dict import PriorDict + if isinstance(item, np.int_): + item = int(item) + elif isinstance(item, np.float_): + item = float(item) + elif isinstance(item, np.complex_): + item = complex(item) + if isinstance(item, (np.ndarray, int, float, complex, str, bytes)): + output = item + elif item is None: + output = "__none__" + elif isinstance(item, list): + if len(item) == 0: + output = item + elif isinstance(item[0], (str, bytes)) or item[0] is None: + output = list() + for value in item: + if isinstance(value, str): + output.append(value.encode("utf-8")) + elif isinstance(value, bytes): + output.append(value) + else: + output.append(b"__none__") + elif isinstance(item[0], (int, float, complex)): + output = np.array(item) + elif isinstance(item, PriorDict): + output = json.dumps(item._get_json_dict()) + elif isinstance(item, pd.DataFrame): + output = item.to_dict(orient="list") + elif isinstance(item, pd.Series): + output = item.to_dict() + elif inspect.isfunction(item) or inspect.isclass(item): + output = dict(__module__=item.__module__, __name__=item.__name__) + elif isinstance(item, dict): + output = item.copy() + else: + raise ValueError(f'Cannot save {type(item)} type') + return output + + +def recursively_load_dict_contents_from_group(h5file, path): + """ + Recursively load a HDF5 file into a dictionary + + .. versionadded:: 1.1.0 + + Parameters + ---------- + h5file: h5py.File + Open h5py file object + path: str + Path within the HDF5 file + + Returns + ------- + output: dict + The contents of the HDF5 file unpacked into the dictionary. + """ + import h5py + output = dict() + for key, item in h5file[path].items(): + if isinstance(item, h5py.Dataset): + output[key] = decode_from_hdf5(item[()]) + elif isinstance(item, h5py.Group): + output[key] = recursively_load_dict_contents_from_group(h5file, path + key + '/') + return output + + +def recursively_save_dict_contents_to_group(h5file, path, dic): + """ + Recursively save a dictionary to a HDF5 group + + .. versionadded:: 1.1.0 + + Parameters + ---------- + h5file: h5py.File + Open HDF5 file + path: str + Path inside the HDF5 file + dic: dict + The dictionary containing the data + """ + for key, item in dic.items(): + item = encode_for_hdf5(item) + if isinstance(item, dict): + recursively_save_dict_contents_to_group(h5file, path + key + '/', item) + else: + h5file[path + key] = item + + +def docstring(docstr, sep="\n"): + """ + Decorator: Append to a function's docstring. + + This is required for e.g., :code:`classmethods` as the :code:`__doc__` + can't be changed after. + + Parameters + ========== + docstr: str + The docstring + sep: str + Separation character for appending the existing docstring. + """ + def _decorator(func): + if func.__doc__ is None: + func.__doc__ = docstr + else: + func.__doc__ = sep.join([func.__doc__, docstr]) + return func + return _decorator diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 9a0c5b9a80fd1f4c54d64ae7f34467434268e625..72a00168e5a4e1e08f3bc055f9c787ed199ea9d8 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -5,7 +5,7 @@ import numpy as np from matplotlib import pyplot as plt from ...core import utils -from ...core.utils import logger +from ...core.utils import docstring, logger from .. import utils as gwutils from ..utils import PropertyAccessor from .calibration import Recalibrate @@ -704,19 +704,35 @@ class Interferometer(object): plt.close(fig) @staticmethod - def _hdf5_filename_from_outdir_label(outdir, label): - return os.path.join(outdir, label + '.h5') + def _filename_from_outdir_label_extension(outdir, label, extension="h5"): + return os.path.join(outdir, label + f'.{extension}') - def to_hdf5(self, outdir='outdir', label=None): - """ Save the object to a hdf5 file + _save_ifo_docstring = """ Save the object to a {format} file - Attributes - ========== - outdir: str, optional - Output directory name of the file, defaults to 'outdir'. - label: str, optional - Output file name, is self.name if not given otherwise. - """ + {extra} + + Attributes + ========== + outdir: str, optional + Output directory name of the file, defaults to 'outdir'. + label: str, optional + Output file name, is self.name if not given otherwise. + """ + + _load_docstring = """ Loads in an Interferometer object from a {format} file + + Parameters + ========== + filename: str + If given, try to load from this filename + + """ + + @docstring(_save_ifo_docstring.format( + format="hdf5", extra=""".. deprecated:: 1.1.0 + Use :func:`to_pickle` instead.""" + )) + def to_hdf5(self, outdir='outdir', label=None): import deepdish if sys.version_info[0] < 3: raise NotImplementedError('Pickling of Interferometer is not supported in Python 2.' @@ -724,19 +740,16 @@ class Interferometer(object): if label is None: label = self.name utils.check_directory_exists_and_if_not_mkdir('outdir') - filename = self._hdf5_filename_from_outdir_label(outdir, label) - deepdish.io.save(filename, self) + try: + filename = self._filename_from_outdir_label_extension(outdir, label, "h5") + deepdish.io.save(filename, self) + except AttributeError: + logger.warning("Saving to hdf5 using deepdish failed. Pickle dumping instead.") + self.to_pickle(outdir=outdir, label=label) @classmethod + @docstring(_load_docstring.format(format="hdf5")) def from_hdf5(cls, filename=None): - """ Loads in an Interferometer object from an hdf5 file - - Parameters - ========== - filename: str - If given, try to load from this filename - - """ import deepdish if sys.version_info[0] < 3: raise NotImplementedError('Pickling of Interferometer is not supported in Python 2.' @@ -746,3 +759,23 @@ class Interferometer(object): if res.__class__ != cls: raise TypeError('The loaded object is not an Interferometer') return res + + @docstring(_save_ifo_docstring.format( + format="pickle", extra=".. versionadded:: 1.1.0" + )) + def to_pickle(self, outdir="outdir", label=None): + import dill + utils.check_directory_exists_and_if_not_mkdir('outdir') + filename = self._filename_from_outdir_label_extension(outdir, label, extension="pkl") + with open(filename, "wb") as ff: + dill.dump(self, ff) + + @classmethod + @docstring(_load_docstring.format(format="pickle")) + def from_pickle(cls, filename=None): + import dill + with open(filename, "rb") as ff: + res = dill.load(ff) + if res.__class__ != cls: + raise TypeError('The loaded object is not an Interferometer') + return res diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 93dd11b350552b6e2b3e2fc7f830556528248201..8abd64c3f7ed1eac15f61a90925b8b323fb9efb4 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -217,38 +217,47 @@ class InterferometerList(list): for interferometer in self} @staticmethod - def _hdf5_filename_from_outdir_label(outdir, label): - return os.path.join(outdir, label + '.h5') + def _filename_from_outdir_label_extension(outdir, label, extension="h5"): + return os.path.join(outdir, label + f'.{extension}') - def to_hdf5(self, outdir='outdir', label='ifo_list'): - """ Saves the object to a hdf5 file + _save_docstring = """ Saves the object to a {format} file - Parameters - ========== - outdir: str, optional - Output directory name of the file - label: str, optional - Output file name, is 'ifo_list' if not given otherwise. A list of - the included interferometers will be appended. - """ + {extra} + + Parameters + ========== + outdir: str, optional + Output directory name of the file + label: str, optional + Output file name, is 'ifo_list' if not given otherwise. A list of + the included interferometers will be appended. + """ + + _load_docstring = """ Loads in an InterferometerList object from a {format} file + + Parameters + ========== + filename: str + If given, try to load from this filename + + """ + + def to_hdf5(self, outdir='outdir', label='ifo_list'): import deepdish if sys.version_info[0] < 3: raise NotImplementedError('Pickling of InterferometerList is not supported in Python 2.' 'Use Python 3 instead.') label = label + '_' + ''.join(ifo.name for ifo in self) utils.check_directory_exists_and_if_not_mkdir(outdir) - deepdish.io.save(self._hdf5_filename_from_outdir_label(outdir, label), self) + try: + filename = self._filename_from_outdir_label_extension(outdir, label, "h5") + deepdish.io.save(filename, self) + except AttributeError: + logger.warning("Saving to hdf5 using deepdish failed. Pickle dumping instead.") + self.to_pickle(outdir=outdir, label=label) @classmethod def from_hdf5(cls, filename=None): - """ Loads in an InterferometerList object from an hdf5 file - - Parameters - ========== - filename: str - If given, try to load from this filename - - """ import deepdish if sys.version_info[0] < 3: raise NotImplementedError('Pickling of InterferometerList is not supported in Python 2.' @@ -260,6 +269,33 @@ class InterferometerList(list): raise TypeError('The loaded object is not a InterferometerList') return res + def to_pickle(self, outdir="outdir", label="ifo_list"): + import dill + utils.check_directory_exists_and_if_not_mkdir('outdir') + label = label + '_' + ''.join(ifo.name for ifo in self) + filename = self._filename_from_outdir_label_extension(outdir, label, extension="pkl") + with open(filename, "wb") as ff: + dill.dump(self, ff) + + @classmethod + def from_pickle(cls, filename=None): + import dill + with open(filename, "rb") as ff: + res = dill.load(ff) + if res.__class__ != cls: + raise TypeError('The loaded object is not an InterferometerList') + return res + + to_hdf5.__doc__ = _save_docstring.format( + format="hdf5", extra=""".. deprecated:: 1.1.0 + Use :func:`to_pickle` instead.""" + ) + to_pickle.__doc__ = _save_docstring.format( + format="pickle", extra=".. versionadded:: 1.1.0" + ) + from_hdf5.__doc__ = _load_docstring.format(format="hdf5") + from_pickle.__doc__ = _load_docstring.format(format="pickle") + class TriangularInterferometer(InterferometerList): diff --git a/requirements.txt b/requirements.txt index 3ef1d047adb45453f7f066322d58d4505684f519..50a9de636a5c35d515d796d8d9da423232bbaf50 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,8 @@ corner numpy<1.20 matplotlib>=2.0 scipy>=0.16 -pandas<1.2 +pandas mock dill tqdm +h5py diff --git a/setup.py b/setup.py index f80072463619aa009986fe58fb371ff94382a166..9f49d130bd56acb7772cfaff7fd41bf59b8e686a 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,12 @@ def get_long_description(): return long_description +def get_requirements(): + with open("requirements.txt", "r") as ff: + requirements = ff.readlines() + return requirements + + # get version info from __init__.py def readfile(filename): with open(filename) as fp: @@ -86,16 +92,7 @@ setup(name='bilby', 'bilby.gw.eos': ['eos_tables/*.dat'], 'bilby': [version_file]}, python_requires='>=3.5', - install_requires=[ - 'dynesty>=1.0.0', - 'emcee', - 'corner', - 'dill', - 'numpy<1.20', - 'matplotlib>=2.0', - 'pandas<1.2', - 'scipy', - 'tqdm'], + install_requires=get_requirements(), entry_points={'console_scripts': ['bilby_plot=cli_bilby.plot_multiple_posteriors:main', 'bilby_result=cli_bilby.bilby_result:main'] diff --git a/test/gw/detector/interferometer_test.py b/test/gw/detector/interferometer_test.py index 0e5a51e032a6f3ce9ac11466b547eacbd596530a..d204d6cbc719ff87279b9828a57d199ae74232b8 100644 --- a/test/gw/detector/interferometer_test.py +++ b/test/gw/detector/interferometer_test.py @@ -1,11 +1,14 @@ import sys import unittest +import pytest +from packaging import version from shutil import rmtree import deepdish as dd import mock import numpy as np from mock import MagicMock, patch +import pandas import bilby @@ -364,29 +367,47 @@ class TestInterferometer(unittest.TestCase): ) self.assertEqual(expected, repr(self.ifo)) + pandas_version_test = version.parse(pandas.__version__) >= version.parse("1.2.0") + skip_reason = "Deepdish requires pandas < 1.2" + + @pytest.mark.skipif(pandas_version_test, reason=skip_reason) def test_to_and_from_hdf5_loading(self): if sys.version_info[0] < 3: with self.assertRaises(NotImplementedError): self.ifo.to_hdf5(outdir="outdir", label="test") else: self.ifo.to_hdf5(outdir="outdir", label="test") - filename = self.ifo._hdf5_filename_from_outdir_label( - outdir="outdir", label="test" + filename = self.ifo._filename_from_outdir_label_extension( + outdir="outdir", label="test", extension="h5" ) recovered_ifo = bilby.gw.detector.Interferometer.from_hdf5(filename) self.assertEqual(self.ifo, recovered_ifo) + @pytest.mark.skipif(pandas_version_test or sys.version_info[0] < 3, reason=skip_reason) def test_to_and_from_hdf5_wrong_class(self): - if sys.version_info[0] < 3: - pass - else: - bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir") - dd.io.save("./outdir/psd.h5", self.power_spectral_density) - filename = self.ifo._hdf5_filename_from_outdir_label( - outdir="outdir", label="psd" - ) - with self.assertRaises(TypeError): - bilby.gw.detector.Interferometer.from_hdf5(filename) + bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir") + dd.io.save("./outdir/psd.h5", self.power_spectral_density) + filename = self.ifo._filename_from_outdir_label_extension( + outdir="outdir", label="psd", extension="h5" + ) + with self.assertRaises(TypeError): + bilby.gw.detector.Interferometer.from_hdf5(filename) + + def test_to_and_from_pkl_loading(self): + self.ifo.to_pickle(outdir="outdir", label="test") + filename = "outdir/test.pkl" + recovered_ifo = bilby.gw.detector.Interferometer.from_pickle(filename) + self.assertEqual(self.ifo, recovered_ifo) + + def test_to_and_from_pkl_wrong_class(self): + import dill + with open("./outdir/psd.pkl", "wb") as ff: + dill.dump(self.ifo.power_spectral_density, ff) + filename = self.ifo._filename_from_outdir_label_extension( + outdir="outdir", label="psd", extension="pkl" + ) + with self.assertRaises(TypeError): + bilby.gw.detector.Interferometer.from_pickle(filename) class TestInterferometerEquals(unittest.TestCase): diff --git a/test/gw/detector/networks_test.py b/test/gw/detector/networks_test.py index 3627db3ba65349979d34c9147cfe916018cd6883..4484cfb12d526ecb11940e390e78b2e10c9c21c7 100644 --- a/test/gw/detector/networks_test.py +++ b/test/gw/detector/networks_test.py @@ -1,11 +1,14 @@ import sys import unittest +import pytest from shutil import rmtree +from packaging import version import deepdish as dd import mock import numpy as np from mock import patch, MagicMock +import pandas import bilby @@ -320,6 +323,10 @@ class TestInterferometerList(unittest.TestCase): names = [ifo.name for ifo in self.ifo_list] self.assertListEqual([self.ifo1.name, new_ifo.name, self.ifo2.name], names) + pandas_version_test = version.parse(pandas.__version__) >= version.parse("1.2.0") + skip_reason = "Deepdish requires pandas < 1.2" + + @pytest.mark.skipif(pandas_version_test, reason=skip_reason) def test_to_and_from_hdf5_loading(self): if sys.version_info[0] < 3: with self.assertRaises(NotImplementedError): @@ -330,16 +337,30 @@ class TestInterferometerList(unittest.TestCase): recovered_ifo = bilby.gw.detector.InterferometerList.from_hdf5(filename) self.assertListEqual(self.ifo_list, recovered_ifo) + @pytest.mark.skipif(pandas_version_test or sys.version_info[0] < 3, reason=skip_reason) def test_to_and_from_hdf5_wrong_class(self): - if sys.version_info[0] < 3: - pass - else: - dd.io.save("./outdir/psd.h5", self.ifo_list[0].power_spectral_density) - filename = self.ifo_list._hdf5_filename_from_outdir_label( - outdir="outdir", label="psd" - ) - with self.assertRaises(TypeError): - bilby.gw.detector.InterferometerList.from_hdf5(filename) + dd.io.save("./outdir/psd.h5", self.ifo_list[0].power_spectral_density) + filename = self.ifo_list._filename_from_outdir_label_extension( + outdir="outdir", label="psd", extension="h5" + ) + with self.assertRaises(TypeError): + bilby.gw.detector.InterferometerList.from_hdf5(filename) + + def test_to_and_from_pkl_loading(self): + self.ifo_list.to_pickle(outdir="outdir", label="test") + filename = "outdir/test_name1name2.pkl" + recovered_ifo = bilby.gw.detector.InterferometerList.from_pickle(filename) + self.assertListEqual(self.ifo_list, recovered_ifo) + + def test_to_and_from_pkl_wrong_class(self): + import dill + with open("./outdir/psd.pkl", "wb") as ff: + dill.dump(self.ifo_list[0].power_spectral_density, ff) + filename = self.ifo_list._filename_from_outdir_label_extension( + outdir="outdir", label="psd", extension="pkl" + ) + with self.assertRaises(TypeError): + bilby.gw.detector.InterferometerList.from_pickle(filename) def test_plot_data(self): ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])