diff --git a/bilby/core/result.py b/bilby/core/result.py index f1a18a96389209e7d07f736606b4d9c57040d46f..3bf68e8d61be24b2aecfb0a612a756cc8ea0b6f0 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -1188,31 +1188,127 @@ class Result(object): "keyword argument, e.g. " + caller_func.__name__ + "(outdir='.')") return outdir - def __add__(self, other): + +class ResultsList(object): + + def __init__(self, results=None): + """ A class to store a list of :class:`bilby.core.result.Result` objects + from equivalent runs on the same data. This provides methods for + outputing combined results. + + Parameters + ---------- + results: list + A :class:`bilby.core.result.Result` object, or list of + results. + """ + + self.results = results + self.__currentidx = 0 # index for iterator + + @property + def results(self): + return self._results + + @results.setter + def results(self, results): + self._results = [] # empty list + if results is None: + return + + self.append(results) + + def append(self, results): + """ + Append a :class:`bilby.core.result.Result`, or set of results, to the + list. + + Parameters + ---------- + results: :class:`bilby.core.result.Result`, or list of results + A :class:`bilby.core.result.Result`, list of + :class:`bilby.core.result.Result` object, or list of filenames + pointing to results objects, to append to the list. + """ + + if not isinstance(results, (list, ResultsList)): + results = [results] # make into list for iteration + + # loop over list + for result in results: + if isinstance(result, Result): + # append new result + self._results.append(result) + elif isinstance(result, str): + # try reading from file + try: + self._results.append(read_in_result(result)) + except Exception as e: + raise IOError("Could not read in results file: " + "{}".format(e)) + else: + raise TypeError("Could not append a non-Result type") + + def __len__(self): + return len(self.results) + + def __iter__(self): + self.__currentidx = 0 # reset iterator index + return self + + def __next__(self): + if self.__currentidx >= len(self): + raise StopIteration + else: + self.__currentidx += 1 + return self.results[self.__currentidx - 1] + + next = __next__ # Python 2 next + + def combine(self): """ - Method to add two Results objects. + Return the combined results in a :class:bilby.core.result.Result` + object. """ - if not isinstance(other, Result): - raise TypeError("Trying to add a non-Result object to a Result " - "object") + from copy import deepcopy + + if len(self) == 0: + return Result() + elif len(self) == 1: + return deepcopy(self.results[0]) + else: + # get first result + result = deepcopy(self.results[0]) + + # check all results are equivalent + sampler = result.sampler + if not np.all([res.sampler == sampler for res in self]): + raise ValueError("Cannot combine results: inconsistent samplers") + + log_noise_evidence = result.log_noise_evidence + if not np.all([res.log_noise_evidence == log_noise_evidence for res in self]): + raise ValueError("Cannot combine results: inconsistent data") - # check that the results have some common features + parameters = result.search_parameter_keys + if not np.all([set(parameters) == set(res.search_parameter_keys) for res in self]): + raise ValueError("Cannot combine results: inconsistent parameters") - # check parameters are the same - if not set(self.search_parameter_keys) == set(other.search_parameter_keys): - raise ValueError("Results being added contain inconsistent parameters") + priors = result.priors + for result in self: + for p in parameters: + if not priors[p] == result.priors[p]: + raise ValueError("Cannot combine results: inconsistent priors") - if self.log_noise_evidence != other.log_noise_evidence: - raise ValueError("Results being added do not have consistent " - "noise evidences") + # check which kind of sampler was used: MCMC or Nested Sampling + if result.nested_samples is not None: + if not np.all([res.nested_samples is not None for res in self]): + raise ValueError("Cannot combine results: nested samples available") - # check priors are the same - if self.priors is not None and other.priors is not None: - for p in self.search_parameter_keys: - if not self.priors[p] == other.priors[p]: - raise ValueError("Results being added used inconsistent " - "priors") + # combine all nested samples + + else: + # combine MCMC samples def plot_multiple(results, filename=None, labels=None, colours=None,