From 01338356505abfaf2d5e9a99fe94658d90b932db Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Fri, 25 Feb 2022 17:02:03 +0000
Subject: [PATCH] Make sampling time saving work with hdf5

---
 bilby/core/result.py     | 2 +-
 bilby/core/utils/io.py   | 3 +++
 test/core/result_test.py | 2 ++
 3 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/bilby/core/result.py b/bilby/core/result.py
index c8d9a517c..700c6469a 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -363,7 +363,7 @@ class Result(object):
             The number of times the likelihood function is called
         log_prior_evaluations: array_like
             The evaluations of the prior for each sample point
-        sampling_time: (datetime.timedelta, float)
+        sampling_time: datetime.timedelta, float
             The time taken to complete the sampling
         nburn: int
             The number of burn-in steps discarded for MCMC samplers
diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py
index 880abbfe6..190f55ea4 100644
--- a/bilby/core/utils/io.py
+++ b/bilby/core/utils/io.py
@@ -1,3 +1,4 @@
+import datetime
 import inspect
 import json
 import os
@@ -301,6 +302,8 @@ def encode_for_hdf5(key, item):
         output = item.copy()
     elif isinstance(item, tuple):
         output = {str(ii): elem for ii, elem in enumerate(item)}
+    elif isinstance(item, datetime.timedelta):
+        output = item.total_seconds()
     else:
         raise ValueError(f'Cannot save {key}: {type(item)} type')
     return output
diff --git a/test/core/result_test.py b/test/core/result_test.py
index f49e4a3a2..fe83c8201 100644
--- a/test/core/result_test.py
+++ b/test/core/result_test.py
@@ -69,6 +69,7 @@ class TestResult(unittest.TestCase):
             sampler_kwargs=dict(test="test", func=lambda x: x),
             injection_parameters=dict(x=0.5, y=0.5),
             meta_data=dict(test="test"),
+            sampling_time=100.0,
         )
 
         n = 100
@@ -254,6 +255,7 @@ class TestResult(unittest.TestCase):
         self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
         self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
         self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
+        self.assertEqual(self.result.sampling_time, loaded_result.sampling_time)
 
     def test_save_and_dont_overwrite_json(self):
         self._save_and_dont_overwrite_test(extension='json')
-- 
GitLab