Gitlab will migrate to a new storage backend starting 0300 UTC on 2020-04-04. We do not anticipate a maintenance window for this migration. Performance may be impacted over the weekend. Thanks for your patience.

Commit 330f2461 authored by Gregory Ashton's avatar Gregory Ashton Committed by Colm Talbot

Generalise the consistency checks

There is good reason to create a results list and check the consistency
without it being for combining results (e.g., for a PP test). This makes
the error messages more applicable in these situations.
parent d94e4f62
......@@ -1311,10 +1311,10 @@ class ResultList(list):
if result.label is not None:
result.label += '_combined'
self._check_consistent_sampler()
self._check_consistent_data()
self._check_consistent_parameters()
self._check_consistent_priors()
self.check_consistent_sampler()
self.check_consistent_data()
self.check_consistent_parameters()
self.check_consistent_priors()
# check which kind of sampler was used: MCMC or Nested Sampling
if result.nested_samples is not None:
......@@ -1327,7 +1327,7 @@ class ResultList(list):
return result
def _combine_nested_sampled_runs(self, result):
self._check_nested_samples()
self.check_nested_samples()
log_evidences = np.array([res.log_evidence for res in self])
result.log_evidence = logsumexp(log_evidences, b=1. / len(self))
if result.use_ratio:
......@@ -1350,32 +1350,31 @@ class ResultList(list):
result.sampler_kwargs = None
return posteriors, result
def _check_nested_samples(self):
def check_nested_samples(self):
for res in self:
try:
res.nested_samples
except ValueError:
raise CombineResultError("Cannot combine results: No nested samples available "
"in all results")
raise ResultListError("Not all results contain nested samples")
def _check_consistent_priors(self):
def check_consistent_priors(self):
for res in self:
for p in self[0].priors.keys():
if not self[0].priors[p] == res.priors[p] or len(self[0].priors) != len(res.priors):
raise CombineResultError("Cannot combine results: inconsistent priors")
raise ResultListError("Inconsistent priors between results")
def _check_consistent_parameters(self):
def check_consistent_parameters(self):
if not np.all([set(self[0].search_parameter_keys) == set(res.search_parameter_keys) for res in self]):
raise CombineResultError("Cannot combine results: inconsistent parameters")
raise ResultListError("Inconsistent parameters between results")
def _check_consistent_data(self):
def check_consistent_data(self):
if not np.all([res.log_noise_evidence == self[0].log_noise_evidence for res in self])\
and not np.all([np.isnan(res.log_noise_evidence) for res in self]):
raise CombineResultError("Cannot combine results: inconsistent data")
raise ResultListError("Inconsistent data between results")
def _check_consistent_sampler(self):
def check_consistent_sampler(self):
if not np.all([res.sampler == self[0].sampler for res in self]):
raise CombineResultError("Cannot combine results: inconsistent samplers")
raise ResultListError("Inconsistent samplers between results")
def plot_multiple(results, filename=None, labels=None, colours=None,
......@@ -1568,7 +1567,7 @@ class ResultError(Exception):
""" Base exception for all Result related errors """
class CombineResultError(ResultError):
class ResultListError(ResultError):
""" For Errors occuring during combining results. """
......
......@@ -391,7 +391,7 @@ class TestResult(unittest.TestCase):
self.result.kde([[0, 0.1], [0.8, 0]])))
class TestResultList(unittest.TestCase):
class TestResultListError(unittest.TestCase):
def setUp(self):
np.random.seed(7)
......@@ -478,7 +478,7 @@ class TestResultList(unittest.TestCase):
def test_combine_inconsistent_samplers(self):
self.nested_results[0].sampler = 'dynesty'
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_priors_length(self):
......@@ -486,7 +486,7 @@ class TestResultList(unittest.TestCase):
x=bilby.prior.Uniform(0, 1, 'x', latex_label='$x$', unit='s'),
y=bilby.prior.Uniform(0, 1, 'y', latex_label='$y$', unit='m'),
c=1))
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_priors_types(self):
......@@ -495,17 +495,17 @@ class TestResultList(unittest.TestCase):
y=bilby.prior.Uniform(0, 1, 'y', latex_label='$y$', unit='m'),
c=1,
d=bilby.core.prior.Cosine()))
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_search_parameters(self):
self.nested_results[0].search_parameter_keys = ['y']
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_data(self):
self.nested_results[0].log_noise_evidence = -7
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
def test_combine_inconsistent_data_nan(self):
......@@ -531,7 +531,7 @@ class TestResultList(unittest.TestCase):
result.log_noise_evidence = 13
result._nested_samples = None
self.nested_results.append(result)
with self.assertRaises(bilby.result.CombineResultError):
with self.assertRaises(bilby.result.ResultListError):
self.nested_results.combine()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment