Skip to content
Snippets Groups Projects

Add dynesty sample dump

Merged Gregory Ashton requested to merge add-dynesty-sample-dump into master
All threads resolved!
Files
2
@@ -18,9 +18,9 @@ from ..utils import (
@@ -18,9 +18,9 @@ from ..utils import (
check_directory_exists_and_if_not_mkdir,
check_directory_exists_and_if_not_mkdir,
reflect,
reflect,
safe_file_dump,
safe_file_dump,
kish_log_effective_sample_size
)
)
from .base_sampler import Sampler, NestedSampler
from .base_sampler import Sampler, NestedSampler
 
from ..result import rejection_sample
from numpy import linalg
from numpy import linalg
from dynesty.utils import unitcheck
from dynesty.utils import unitcheck
@@ -581,22 +581,21 @@ class Dynesty(NestedSampler):
@@ -581,22 +581,21 @@ class Dynesty(NestedSampler):
self.dump_samples_to_dat()
self.dump_samples_to_dat()
def dump_samples_to_dat(self):
def dump_samples_to_dat(self):
from dynesty.utils import resample_equal
sampler = self.sampler
sampler = self.sampler
ln_weights = sampler.saved_logwt - sampler.saved_logz[-1]
ln_weights = sampler.saved_logwt - sampler.saved_logz[-1]
neff = int(np.exp(kish_log_effective_sample_size(ln_weights)))
 
weights = np.exp(ln_weights)
 
samples = rejection_sample(np.array(sampler.saved_v), weights)
 
nsamples = len(samples)
# If we don't have enough samples, don't dump them
# If we don't have enough samples, don't dump them
if neff < 100:
if nsamples < 100:
return
return
weights = np.exp(ln_weights)
samples = resample_equal(np.array(sampler.saved_v), weights)
df = DataFrame(samples, columns=self.search_parameter_keys)
# Downsample to neff
df = df.sample(neff)
filename = "{}/{}_samples.dat".format(self.outdir, self.label)
filename = "{}/{}_samples.dat".format(self.outdir, self.label)
logger.info("Writing current samples to {} with neff={}".format(filename, neff))
logger.info("Writing {} current samples to {}".format(nsamples, filename))
 
 
df = DataFrame(samples, columns=self.search_parameter_keys)
df.to_csv(filename, index=False, header=True, sep=' ')
df.to_csv(filename, index=False, header=True, sep=' ')
def plot_current_state(self):
def plot_current_state(self):
Loading