Commit 330f2461 authored by Gregory Ashton's avatar Gregory Ashton Committed by Colm Talbot

Generalise the consistency checks

There is good reason to create a results list and check the consistency
without it being for combining results (e.g., for a PP test). This makes
the error messages more applicable in these situations.
parent d94e4f62
......@@ -1311,10 +1311,10 @@ class ResultList(list):
if result.label is not None:
result.label += '_combined'
self._check_consistent_sampler()
self._check_consistent_data()
self._check_consistent_parameters()
self._check_consistent_priors()
self.check_consistent_sampler()
self.check_consistent_data()
self.check_consistent_parameters()
self.check_consistent_priors()
# check which kind of sampler was used: MCMC or Nested Sampling
if result.nested_samples is not None:
......@@ -1327,7 +1327,7 @@ class ResultList(list):
return result
def _combine_nested_sampled_runs(self, result):
self._check_nested_samples()
self.check_nested_samples()
log_evidences = np.array([res.log_evidence for res in self])
result.log_evidence = logsumexp(log_evidences, b=1. / len(self))
if result.use_ratio:
......@@ -1350,32 +1350,31 @@ class ResultList(list):
result.sampler_kwargs = None
return posteriors, result
def _check_nested_samples(self):
def check_nested_samples(self):
for res in self:
try:
res.nested_samples
except ValueError:
raise CombineResultError("Cannot combine results: No nested samples available "
"in all results")
raise ResultListError("Not all results contain nested samples")
def _check_consistent_priors(self):
def check_consistent_priors(self):
for res in self:
for p in self[0].priors.keys():
if not self[0].priors[p] == res.priors[p] or len(self[0].priors) != len(res.priors):
raise CombineResultError("Cannot combine results: inconsistent priors")
raise ResultListError("Inconsistent priors between results")
def _check_consistent_parameters(self):
def check_consistent_parameters(self):
if not np.all([set(self[0].search_parameter_keys) == set(res.search_parameter_keys) for res in self]):
raise CombineResultError("Cannot combine results: inconsistent parameters")
raise ResultListError("Inconsistent parameters between results")
def _check_consistent_data(self):
def check_consistent_data(self):
if not np.all([res.log_noise_evidence == self[0].log_noise_evidence for res in self])\
and not np.all([np.isnan(res.log_noise_evidence) for res in self]):
raise CombineResultError("Cannot combine results: inconsistent data")
raise ResultListError("Inconsistent data between results")
def _check_consistent_sampler(self):
def check_consistent_sampler(self):
if not np.all([res.sampler == self[0].sampler for res in self]):
raise CombineResultError("Cannot combine results: inconsistent samplers")
raise ResultListError("Inconsistent samplers between results")
def plot_multiple(results, filename=None, labels=None, colours=None,
......@@ -1568,7 +1567,7 @@ class ResultError(Exception):
""" Base exception for all Result related errors """
class CombineResultError(ResultError):
class ResultListError(ResultError):
""" For Errors occuring during combining results. """
......
......@@ -391,7 +391,7 @@ class TestResult(unittest.TestCase):
self.result.kde([[0, 0.1], [0.8, 0]])))
class TestResultList(unittest.TestCase):
class TestResultListError(unittest.TestCase):
def setUp(self):
np.random.seed(7)
......@@ -478,7 +478,7 @@ class TestResultList(unittest.TestCase):
def test_combine_inconsistent_samplers(self):
self.nested_results[0].sampler = 'dynesty'
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_priors_length(self):
......@@ -486,7 +486,7 @@ class TestResultList(unittest.TestCase):
x=bilby.prior.Uniform(0, 1, 'x', latex_label='$x$', unit='s'),
y=bilby.prior.Uniform(0, 1, 'y', latex_label='$y$', unit='m'),
c=1))
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_priors_types(self):
......@@ -495,17 +495,17 @@ class TestResultList(unittest.TestCase):
y=bilby.prior.Uniform(0, 1, 'y', latex_label='$y$', unit='m'),
c=1,
d=bilby.core.prior.Cosine()))
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_search_parameters(self):
self.nested_results[0].search_parameter_keys = ['y']
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_data(self):
self.nested_results[0].log_noise_evidence = -7
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_data_nan(self):
......@@ -531,7 +531,7 @@ class TestResultList(unittest.TestCase):
result.log_noise_evidence = 13
result._nested_samples = None
self.nested_results.append(result)
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment