Skip to content
Snippets Groups Projects

Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
1 file
+ 43
5
Compare changes
  • Side-by-side
  • Inline
@@ -13,7 +13,7 @@ import pandas as pd
import matplotlib.pyplot as plt
from ..utils import logger
from .base_sampler import MCMCSampler
from .base_sampler import SamplerError, MCMCSampler
class Ptemcee(MCMCSampler):
@@ -49,7 +49,7 @@ class Ptemcee(MCMCSampler):
autocorr_c=5, safety=1, frac_threshold=0.01,
autocorr_tol=50, autocorr_tau=5, min_tau=1, check_point_deltaT=600,
threads=1, exit_code=77, plot=False, store_walkers=False,
ignore_keys_for_tau="recalib", **kwargs):
ignore_keys_for_tau="recalib", pos0="prior", **kwargs):
super(Ptemcee, self).__init__(
likelihood=likelihood, priors=priors, outdir=outdir,
label=label, use_ratio=use_ratio, plot=plot,
@@ -74,6 +74,7 @@ class Ptemcee(MCMCSampler):
self.threads = threads
self.store_walkers = store_walkers
self.ignore_keys_for_tau = ignore_keys_for_tau
self.pos0 = pos0
self.check_point_plot = check_point_plot
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)
@@ -97,6 +98,37 @@ class Ptemcee(MCMCSampler):
for _ in range(self.sampler_init_kwargs["nwalkers"])]
for _ in range(self.kwargs['ntemps'])]
def get_pos0_from_minimize(self):
logger.info("Attempting to set pos0 from minimize")
from scipy.optimize import minimize
def neg_log_like(params):
try:
return -self.log_likelihood(params)
except RuntimeError:
return +np.inf
bounds = [(self.priors[key].minimum, self.priors[key].maximum)
for key in self.search_parameter_keys]
trials = 0
success = []
while True:
x0 = self.get_random_draw_from_prior()
res = minimize(
neg_log_like, x0, bounds=bounds, method='L-BFGS-B', tol=1e-15)
if res.success:
success.append(res.x)
if trials > 100:
raise SamplerError("Unable to set pos0 from minimize")
if len(success) >= 3:
break
pos0_min = np.min(success, axis=0)
pos0_max = np.max(success, axis=0)
pos0 = np.random.uniform(
pos0_min, pos0_max,
size=(self.kwargs["ntemps"], self.kwargs["nwalkers"], self.ndim))
return pos0
def setup_sampler(self):
import ptemcee
if os.path.isfile(self.resume_file) and self.resume is True:
@@ -132,10 +164,18 @@ class Ptemcee(MCMCSampler):
self.time_per_check = []
# Initialize the walker postitions
pos0 = self.get_pos0_from_prior()
pos0 = self.get_pos0()
return self.sampler, pos0
def get_pos0(self):
if self.pos0.lower() == "prior":
return self.get_pos0_from_prior()
elif self.pos0.lower() == "minimize":
return self.get_pos0_from_minimize()
else:
raise SamplerError("pos0={} not implemented".format(self.pos0))
def setup_pool(self):
if self.threads > 1:
import schwimmbad
@@ -394,8 +434,6 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
with open(resume_file, "wb") as file:
dill.dump(data, file, protocol=4)
del data, sampler_copy
logger.info("Finished writing checkpoint")
Loading