Skip to content
Snippets Groups Projects

Resolve "Update the bilby_result method to enable merging results with inconsistent parameters"

All threads resolved!
2 files
+ 45
6
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 35
5
@@ -1710,11 +1710,28 @@ class ResultList(list):
else:
raise TypeError("Could not append a non-Result type")
def combine(self, shuffle=False):
def combine(self, shuffle=False, consistency_level="error"):
"""
Return the combined results in a :class:bilby.core.result.Result`
object.
Parameters
----------
shuffle: bool
If true, shuffle the samples when combining, otherwise they are concatenated.
consistency_level: str, [warning, error]
If warning, print a warning if inconsistencies are discovered between the results before combining.
If error, raise an error if inconsistencies are discovered between the results before combining.
Returns
-------
result: bilby.core.result.Result
The combined result file
"""
self.consistency_level = consistency_level
if len(self) == 0:
return Result()
elif len(self) == 1:
@@ -1844,24 +1861,37 @@ class ResultList(list):
except ValueError:
raise ResultListError("Not all results contain nested samples")
def _error_or_warning_consistency(self, msg):
if self.consistency_level == "error":
raise ResultListError(msg)
elif self.consistency_level == "warning":
logger.warning(msg)
else:
raise ValueError(f"Input consistency_level {self.consistency_level} not understood")
def check_consistent_priors(self):
for res in self:
for p in self[0].priors.keys():
if not self[0].priors[p] == res.priors[p] or len(self[0].priors) != len(res.priors):
raise ResultListError("Inconsistent priors between results")
msg = "Inconsistent priors between results"
self._error_or_warning_consistency(msg)
def check_consistent_parameters(self):
if not np.all([set(self[0].search_parameter_keys) == set(res.search_parameter_keys) for res in self]):
raise ResultListError("Inconsistent parameters between results")
msg = "Inconsistent parameters between results"
self._error_or_warning_consistency(msg)
def check_consistent_data(self):
if not np.allclose([res.log_noise_evidence for res in self], self[0].log_noise_evidence, atol=1e-8, rtol=0.0)\
and not np.all([np.isnan(res.log_noise_evidence) for res in self]):
raise ResultListError("Inconsistent data between results")
msg = "Inconsistent data between results"
self._error_or_warning_consistency(msg)
def check_consistent_sampler(self):
if not np.all([res.sampler == self[0].sampler for res in self]):
raise ResultListError("Inconsistent samplers between results")
msg = "Inconsistent samplers between results"
self._error_or_warning_consistency(msg)
@latex_plot_format
Loading