diff --git a/bilby/core/result.py b/bilby/core/result.py index 7d7f7318b976e82aab4ff907beda872fcd14720c..11ec22a0ebb28f409dd41d038f9b33251ed07325 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -156,7 +156,7 @@ def rejection_sample(posterior, weights): Parameters ---------- - posterior: pd.DataFrame + posterior: pd.DataFrame or np.ndarray of shape (nsamples, nparameters) The dataframe containing posterior samples weights: np.ndarray An array of weights @@ -168,7 +168,10 @@ def rejection_sample(posterior, weights): """ keep = weights > np.random.uniform(0, max(weights), weights.shape) - return posterior.iloc[keep] + if isinstance(posterior, np.ndarray): + return posterior[keep] + elif isinstance(posterior, pd.DataFrame): + return posterior.iloc[keep] def reweight(result, label=None, new_likelihood=None, new_prior=None, diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index bc8266c828ecb627d716225becd53973afdb20f1..6c99cdde1672fe7994a6d249278ef8851453b33f 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -18,9 +18,9 @@ from ..utils import ( check_directory_exists_and_if_not_mkdir, reflect, safe_file_dump, - kish_log_effective_sample_size ) from .base_sampler import Sampler, NestedSampler +from ..result import rejection_sample from numpy import linalg from dynesty.utils import unitcheck @@ -581,22 +581,21 @@ class Dynesty(NestedSampler): self.dump_samples_to_dat() def dump_samples_to_dat(self): - from dynesty.utils import resample_equal sampler = self.sampler ln_weights = sampler.saved_logwt - sampler.saved_logz[-1] - neff = int(np.exp(kish_log_effective_sample_size(ln_weights))) + + weights = np.exp(ln_weights) + samples = rejection_sample(np.array(sampler.saved_v), weights) + nsamples = len(samples) # If we don't have enough samples, don't dump them - if neff < 100: + if nsamples < 100: return - weights = np.exp(ln_weights) - samples = resample_equal(np.array(sampler.saved_v), weights) - df = DataFrame(samples, columns=self.search_parameter_keys) - # Downsample to neff - df = df.sample(neff) filename = "{}/{}_samples.dat".format(self.outdir, self.label) - logger.info("Writing current samples to {} with neff={}".format(filename, neff)) + logger.info("Writing {} current samples to {}".format(nsamples, filename)) + + df = DataFrame(samples, columns=self.search_parameter_keys) df.to_csv(filename, index=False, header=True, sep=' ') def plot_current_state(self):