Skip to content
Snippets Groups Projects

Resolve "Update the bilby_result method to enable merging results with inconsistent parameters"

All threads resolved!
2 files
+ 28
21
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 18
20
@@ -1710,7 +1710,7 @@ class ResultList(list):
else:
raise TypeError("Could not append a non-Result type")
def combine(self, shuffle=False, consistency_warning=False, consistency_error=True):
def combine(self, shuffle=False, consistency_level="error"):
"""
Return the combined results in a :class:bilby.core.result.Result`
object.
@@ -1719,10 +1719,9 @@ class ResultList(list):
----------
shuffle: bool
If true, shuffle the samples when combining, otherwise they are concatenated.
consistency_warning: bool
If true, print a warning if inconsistencies are discovered between the results before combining.
consistency_error: bool
If true, raise an error if inconsistencies are discovered between the results before combining.
consistency_level: str, [warning, error]
If warning, print a warning if inconsistencies are discovered between the results before combining.
If error, raise an error if inconsistencies are discovered between the results before combining.
Returns
-------
@@ -1731,8 +1730,7 @@ class ResultList(list):
"""
self.consistency_error = consistency_error
self.consistency_warning = consistency_warning
self.consistency_level = consistency_level
if len(self) == 0:
return Result()
@@ -1863,37 +1861,37 @@ class ResultList(list):
except ValueError:
raise ResultListError("Not all results contain nested samples")
def _error_or_warning_consistency(self, msg):
if self.consistency_level == "error":
raise ResultListError(msg)
elif self.consistency_level == "warning":
logger.warning(msg)
else:
raise ValueError(f"Input consistency_level {self.consistency_level} not understood")
def check_consistent_priors(self):
for res in self:
for p in self[0].priors.keys():
if not self[0].priors[p] == res.priors[p] or len(self[0].priors) != len(res.priors):
raise ResultListError("Inconsistent priors between results")
msg = "Inconsistent priors between results"
self._error_or_warning_consistency(msg)
def check_consistent_parameters(self):
if not np.all([set(self[0].search_parameter_keys) == set(res.search_parameter_keys) for res in self]):
msg = "Inconsistent parameters between results"
if self.consistency_error:
raise ResultListError(msg)
elif self.consistency_warning:
logger.warning(msg)
self._error_or_warning_consistency(msg)
def check_consistent_data(self):
if not np.allclose([res.log_noise_evidence for res in self], self[0].log_noise_evidence, atol=1e-8, rtol=0.0)\
and not np.all([np.isnan(res.log_noise_evidence) for res in self]):
msg = "Inconsistent data between results"
if self.consistency_error:
raise ResultListError(msg)
elif self.consistency_warning:
logger.warning(msg)
self._error_or_warning_consistency(msg)
def check_consistent_sampler(self):
if not np.all([res.sampler == self[0].sampler for res in self]):
msg = "Inconsistent samplers between results"
if self.consistency_error:
raise ResultListError(msg)
elif self.consistency_warning:
logger.warning(msg)
self._error_or_warning_consistency(msg)
@latex_plot_format
Loading