From 3ca4cf5d1fab3e08af44693e8a5c76ea29d772f9 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Mon, 11 May 2020 09:45:04 +1000 Subject: [PATCH] Update to rejection sampling for the dumped posterior --- bilby/core/result.py | 7 +++++-- bilby/core/sampler/dynesty.py | 19 +++++++++---------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index 7d7f7318b..11ec22a0e 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -156,7 +156,7 @@ def rejection_sample(posterior, weights): Parameters ---------- - posterior: pd.DataFrame + posterior: pd.DataFrame or np.ndarray of shape (nsamples, nparameters) The dataframe containing posterior samples weights: np.ndarray An array of weights @@ -168,7 +168,10 @@ def rejection_sample(posterior, weights): """ keep = weights > np.random.uniform(0, max(weights), weights.shape) - return posterior.iloc[keep] + if isinstance(posterior, np.ndarray): + return posterior[keep] + elif isinstance(posterior, pd.DataFrame): + return posterior.iloc[keep] def reweight(result, label=None, new_likelihood=None, new_prior=None, diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index bc8266c82..6c99cdde1 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -18,9 +18,9 @@ from ..utils import ( check_directory_exists_and_if_not_mkdir, reflect, safe_file_dump, - kish_log_effective_sample_size ) from .base_sampler import Sampler, NestedSampler +from ..result import rejection_sample from numpy import linalg from dynesty.utils import unitcheck @@ -581,22 +581,21 @@ class Dynesty(NestedSampler): self.dump_samples_to_dat() def dump_samples_to_dat(self): - from dynesty.utils import resample_equal sampler = self.sampler ln_weights = sampler.saved_logwt - sampler.saved_logz[-1] - neff = int(np.exp(kish_log_effective_sample_size(ln_weights))) + + weights = np.exp(ln_weights) + samples = rejection_sample(np.array(sampler.saved_v), weights) + nsamples = len(samples) # If we don't have enough samples, don't dump them - if neff < 100: + if nsamples < 100: return - weights = np.exp(ln_weights) - samples = resample_equal(np.array(sampler.saved_v), weights) - df = DataFrame(samples, columns=self.search_parameter_keys) - # Downsample to neff - df = df.sample(neff) filename = "{}/{}_samples.dat".format(self.outdir, self.label) - logger.info("Writing current samples to {} with neff={}".format(filename, neff)) + logger.info("Writing {} current samples to {}".format(nsamples, filename)) + + df = DataFrame(samples, columns=self.search_parameter_keys) df.to_csv(filename, index=False, header=True, sep=' ') def plot_current_state(self): -- GitLab