From 01338356505abfaf2d5e9a99fe94658d90b932db Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Fri, 25 Feb 2022 17:02:03 +0000 Subject: [PATCH] Make sampling time saving work with hdf5 --- bilby/core/result.py | 2 +- bilby/core/utils/io.py | 3 +++ test/core/result_test.py | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index c8d9a517c..700c6469a 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -363,7 +363,7 @@ class Result(object): The number of times the likelihood function is called log_prior_evaluations: array_like The evaluations of the prior for each sample point - sampling_time: (datetime.timedelta, float) + sampling_time: datetime.timedelta, float The time taken to complete the sampling nburn: int The number of burn-in steps discarded for MCMC samplers diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index 880abbfe6..190f55ea4 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -1,3 +1,4 @@ +import datetime import inspect import json import os @@ -301,6 +302,8 @@ def encode_for_hdf5(key, item): output = item.copy() elif isinstance(item, tuple): output = {str(ii): elem for ii, elem in enumerate(item)} + elif isinstance(item, datetime.timedelta): + output = item.total_seconds() else: raise ValueError(f'Cannot save {key}: {type(item)} type') return output diff --git a/test/core/result_test.py b/test/core/result_test.py index f49e4a3a2..fe83c8201 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -69,6 +69,7 @@ class TestResult(unittest.TestCase): sampler_kwargs=dict(test="test", func=lambda x: x), injection_parameters=dict(x=0.5, y=0.5), meta_data=dict(test="test"), + sampling_time=100.0, ) n = 100 @@ -254,6 +255,7 @@ class TestResult(unittest.TestCase): self.assertEqual(self.result.priors["y"], loaded_result.priors["y"]) self.assertEqual(self.result.priors["c"], loaded_result.priors["c"]) self.assertEqual(self.result.priors["d"], loaded_result.priors["d"]) + self.assertEqual(self.result.sampling_time, loaded_result.sampling_time) def test_save_and_dont_overwrite_json(self): self._save_and_dont_overwrite_test(extension='json') -- GitLab