Skip to content
Snippets Groups Projects
Commit 72bf21a9 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add option to initialize from minimize

parent e0e010e9
No related branches found
No related tags found
1 merge request!750Improve ptemcee
Pipeline #113456 failed
......@@ -13,7 +13,7 @@ import pandas as pd
import matplotlib.pyplot as plt
from ..utils import logger
from .base_sampler import MCMCSampler
from .base_sampler import SamplerError, MCMCSampler
class Ptemcee(MCMCSampler):
......@@ -49,7 +49,7 @@ class Ptemcee(MCMCSampler):
autocorr_c=5, safety=1, frac_threshold=0.01,
autocorr_tol=50, autocorr_tau=5, min_tau=1, check_point_deltaT=600,
threads=1, exit_code=77, plot=False, store_walkers=False,
ignore_keys_for_tau="recalib", **kwargs):
ignore_keys_for_tau="recalib", pos0="prior", **kwargs):
super(Ptemcee, self).__init__(
likelihood=likelihood, priors=priors, outdir=outdir,
label=label, use_ratio=use_ratio, plot=plot,
......@@ -74,6 +74,7 @@ class Ptemcee(MCMCSampler):
self.threads = threads
self.store_walkers = store_walkers
self.ignore_keys_for_tau = ignore_keys_for_tau
self.pos0 = pos0
self.check_point_plot = check_point_plot
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)
......@@ -97,6 +98,37 @@ class Ptemcee(MCMCSampler):
for _ in range(self.sampler_init_kwargs["nwalkers"])]
for _ in range(self.kwargs['ntemps'])]
def get_pos0_from_minimize(self):
logger.info("Attempting to set pos0 from minimize")
from scipy.optimize import minimize
def neg_log_like(params):
try:
return -self.log_likelihood(params)
except RuntimeError:
return +np.inf
bounds = [(self.priors[key].minimum, self.priors[key].maximum)
for key in self.search_parameter_keys]
trials = 0
success = []
while True:
x0 = self.get_random_draw_from_prior()
res = minimize(
neg_log_like, x0, bounds=bounds, method='L-BFGS-B', tol=1e-15)
if res.success:
success.append(res.x)
if trials > 100:
raise SamplerError("Unable to set pos0 from minimize")
if len(success) >= 3:
break
pos0_min = np.min(success, axis=0)
pos0_max = np.max(success, axis=0)
pos0 = np.random.uniform(
pos0_min, pos0_max,
size=(self.kwargs["ntemps"], self.kwargs["nwalkers"], self.ndim))
return pos0
def setup_sampler(self):
import ptemcee
if os.path.isfile(self.resume_file) and self.resume is True:
......@@ -132,10 +164,18 @@ class Ptemcee(MCMCSampler):
self.time_per_check = []
# Initialize the walker postitions
pos0 = self.get_pos0_from_prior()
pos0 = self.get_pos0()
return self.sampler, pos0
def get_pos0(self):
if self.pos0.lower() == "prior":
return self.get_pos0_from_prior()
elif self.pos0.lower() == "minimize":
return self.get_pos0_from_minimize()
else:
raise SamplerError("pos0={} not implemented".format(self.pos0))
def setup_pool(self):
if self.threads > 1:
import schwimmbad
......@@ -394,8 +434,6 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
with open(resume_file, "wb") as file:
dill.dump(data, file, protocol=4)
del data, sampler_copy
logger.info("Finished writing checkpoint")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment