Skip to content
Snippets Groups Projects
Commit 3ca4cf5d authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Update to rejection sampling for the dumped posterior

parent a6646c08
No related branches found
No related tags found
1 merge request!778Add dynesty sample dump
......@@ -156,7 +156,7 @@ def rejection_sample(posterior, weights):
Parameters
----------
posterior: pd.DataFrame
posterior: pd.DataFrame or np.ndarray of shape (nsamples, nparameters)
The dataframe containing posterior samples
weights: np.ndarray
An array of weights
......@@ -168,7 +168,10 @@ def rejection_sample(posterior, weights):
"""
keep = weights > np.random.uniform(0, max(weights), weights.shape)
return posterior.iloc[keep]
if isinstance(posterior, np.ndarray):
return posterior[keep]
elif isinstance(posterior, pd.DataFrame):
return posterior.iloc[keep]
def reweight(result, label=None, new_likelihood=None, new_prior=None,
......
......@@ -18,9 +18,9 @@ from ..utils import (
check_directory_exists_and_if_not_mkdir,
reflect,
safe_file_dump,
kish_log_effective_sample_size
)
from .base_sampler import Sampler, NestedSampler
from ..result import rejection_sample
from numpy import linalg
from dynesty.utils import unitcheck
......@@ -581,22 +581,21 @@ class Dynesty(NestedSampler):
self.dump_samples_to_dat()
def dump_samples_to_dat(self):
from dynesty.utils import resample_equal
sampler = self.sampler
ln_weights = sampler.saved_logwt - sampler.saved_logz[-1]
neff = int(np.exp(kish_log_effective_sample_size(ln_weights)))
weights = np.exp(ln_weights)
samples = rejection_sample(np.array(sampler.saved_v), weights)
nsamples = len(samples)
# If we don't have enough samples, don't dump them
if neff < 100:
if nsamples < 100:
return
weights = np.exp(ln_weights)
samples = resample_equal(np.array(sampler.saved_v), weights)
df = DataFrame(samples, columns=self.search_parameter_keys)
# Downsample to neff
df = df.sample(neff)
filename = "{}/{}_samples.dat".format(self.outdir, self.label)
logger.info("Writing current samples to {} with neff={}".format(filename, neff))
logger.info("Writing {} current samples to {}".format(nsamples, filename))
df = DataFrame(samples, columns=self.search_parameter_keys)
df.to_csv(filename, index=False, header=True, sep=' ')
def plot_current_state(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment