diff --git a/AUTHORS.md b/AUTHORS.md index 2ae71ee830649c2ab51a422f39bb46fe9a0983e1..1e38fa961c0dcacf8476adf8c8c23e1005044b30 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -89,3 +89,5 @@ Vivien Raymond Ka-Lok Lo Isaac Legred Marc Penuliar +Andrew Fowlie +Martin White diff --git a/bilby/core/result.py b/bilby/core/result.py index 4f2b610c663dc4f11b635f10e50ba89be99aafbc..f27bdd7f3bbafc4b2b0fbc26f36f24d84401c581 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -76,7 +76,7 @@ def _determine_file_name(filename, outdir, label, extension, gzip): return result_file_name(outdir, label, extension, gzip) -def read_in_result(filename=None, outdir=None, label=None, extension='json', gzip=False): +def read_in_result(filename=None, outdir=None, label=None, extension='json', gzip=False, result_class=None): """ Reads in a stored bilby result object Parameters @@ -86,21 +86,29 @@ def read_in_result(filename=None, outdir=None, label=None, extension='json', gzi outdir, label, extension: str Name of the output directory, label and extension used for the default naming scheme. - + result_class: bilby.core.result.Result, or child of + The result class to use. By default, `bilby.core.result.Result` is used, + but objects which inherit from this class can be given providing + additional methods. """ filename = _determine_file_name(filename, outdir, label, extension, gzip) + if result_class is None: + result_class = Result + elif not issubclass(result_class, Result): + raise ValueError(f"Input result_class={result_class} not understood") + # Get the actual extension (may differ from the default extension if the filename is given) extension = os.path.splitext(filename)[1].lstrip('.') if extension == 'gz': # gzipped file extension = os.path.splitext(os.path.splitext(filename)[0])[1].lstrip('.') if 'json' in extension: - result = Result.from_json(filename=filename) + result = result_class.from_json(filename=filename) elif ('hdf5' in extension) or ('h5' in extension): - result = Result.from_hdf5(filename=filename) + result = result_class.from_hdf5(filename=filename) elif ("pkl" in extension) or ("pickle" in extension): - result = Result.from_pickle(filename=filename) + result = result_class.from_pickle(filename=filename) elif extension is None: raise ValueError("No filetype extension provided") else: diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 4568b165432418141f34c3f6b8e04b2d6f9ef41d..e7ad043b36cc1cf6d65b88a0f63548679d99700b 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -253,7 +253,7 @@ class Sampler(object): self.kwargs = kwargs - self._check_cached_result() + self._check_cached_result(result_class) self._log_summary_for_sampler() @@ -631,7 +631,7 @@ class Sampler(object): """ raise ValueError("Method not yet implemented") - def _check_cached_result(self): + def _check_cached_result(self, result_class=None): """Check if the cached data file exists and can be used""" if command_line_args.clean: @@ -640,7 +640,9 @@ class Sampler(object): return try: - self.cached_result = read_in_result(outdir=self.outdir, label=self.label) + self.cached_result = read_in_result( + outdir=self.outdir, label=self.label, result_class=result_class + ) except IOError: self.cached_result = None diff --git a/test/core/result_test.py b/test/core/result_test.py index 80300ea360acd3586a8a8fb718dd5c6500844d3b..36e50aa365d9547236506f67dbfa38844215ee38 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -59,9 +59,10 @@ class TestResult(unittest.TestCase): d=2, ) ) + self.outdir = "test_outdir" result = bilby.core.result.Result( label="label", - outdir="outdir", + outdir=self.outdir, sampler="nestle", search_parameter_keys=["x", "y"], fixed_parameter_keys=["c", "d"], @@ -87,7 +88,7 @@ class TestResult(unittest.TestCase): def tearDown(self): bilby.utils.command_line_args.bilby_test_mode = True try: - shutil.rmtree(self.result.outdir) + shutil.rmtree(self.outdir) except OSError: pass del self.result @@ -491,6 +492,40 @@ class TestResult(unittest.TestCase): self.assertTrue(np.array_equal(az.log_likelihood["log_likelihood"].values.squeeze(), log_likelihood)) + def test_result_caching(self): + + class SimpleLikelihood(bilby.Likelihood): + def __init__(self): + super().__init__(parameters={"x": None}) + + def log_likelihood(self): + return -self.parameters["x"]**2 + + likelihood = SimpleLikelihood() + priors = dict(x=bilby.core.prior.Uniform(-5, 5, "x")) + + # Trivial subclass of Result + + class NotAResult(bilby.core.result.Result): + pass + + result = bilby.run_sampler( + likelihood, priors, sampler='bilby_mcmc', nsamples=10, L1steps=1, + proposal_cycle="default_noGMnoKD", printdt=1, + check_point_plot=False, + result_class=NotAResult) + # result should be specified result_class + assert isinstance(result, NotAResult) + + cached_result = bilby.run_sampler( + likelihood, priors, sampler='bilby_mcmc', nsamples=10, L1steps=1, + proposal_cycle="default_noGMnoKD", printdt=1, + check_point_plot=False, + result_class=NotAResult) + + # so should a result loaded from cache + assert isinstance(cached_result, NotAResult) + class TestResultListError(unittest.TestCase): def setUp(self):