Skip to content
Snippets Groups Projects
Commit be643f91 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Convert to iterations per check

parent b0bdcaac
No related branches found
No related tags found
1 merge request!750Improve ptemcee
Pipeline #114488 failed
......@@ -30,6 +30,7 @@ ConvergenceInputs = namedtuple(
"nsamples",
"ignore_keys_for_tau",
"min_tau",
"niterations_per_check",
],
)
......@@ -144,6 +145,7 @@ class Ptemcee(MCMCSampler):
store_walkers=False,
ignore_keys_for_tau=None,
pos0="prior",
niterations_per_check=10,
**kwargs
):
super(Ptemcee, self).__init__(
......@@ -187,6 +189,7 @@ class Ptemcee(MCMCSampler):
nsamples=nsamples,
ignore_keys_for_tau=ignore_keys_for_tau,
min_tau=min_tau,
niterations_per_check=niterations_per_check,
)
self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)
......@@ -410,7 +413,9 @@ class Ptemcee(MCMCSampler):
logger.info("Starting to sample")
while True:
for (pos0, log_posterior, log_likelihood) in sampler.sample(
self.pos0, storechain=False, **self.sampler_function_kwargs):
self.pos0, storechain=False,
iterations=self.convergence_inputs.niterations_per_check,
**self.sampler_function_kwargs):
pass
if self.iteration == self.chain_array.shape[1]:
......@@ -419,6 +424,7 @@ class Ptemcee(MCMCSampler):
self.log_likelihood_array = np.concatenate((
self.log_likelihood_array, self.get_zero_log_likelihood_array()),
axis=2)
self.pos0 = pos0
self.chain_array[:, self.iteration, :] = pos0[0, :, :]
self.log_likelihood_array[:, :, self.iteration] = log_likelihood
......@@ -427,6 +433,8 @@ class Ptemcee(MCMCSampler):
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
t0 = datetime.datetime.now()
self.iteration += 1
(
stop,
self.nburn,
......@@ -443,7 +451,6 @@ class Ptemcee(MCMCSampler):
self.tau_list_n,
)
self.iteration += 1
if stop:
logger.info("Finished sampling")
......@@ -606,7 +613,8 @@ def check_iteration(
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
iteration, sampler, time_per_check, ci.nsamples, np.nan, np.nan, np.nan, np.nan, False,
iteration, sampler, time_per_check, np.nan, np.nan,
np.nan, np.nan, False, convergence_inputs,
)
return False, np.nan, np.nan, np.nan, np.nan
......@@ -638,12 +646,12 @@ def check_iteration(
iteration,
sampler,
time_per_check,
ci.nsamples,
nsamples_effective,
samples_per_check,
tau_int,
max_frac,
tau_usable,
convergence_inputs,
)
stop = converged and tau_usable
return stop, nburn, thin, tau_int, nsamples_effective
......@@ -653,12 +661,12 @@ def print_progress(
iteration,
sampler,
time_per_check,
nsamples,
nsamples_effective,
samples_per_check,
tau_int,
max_frac,
tau_usable,
convergence_inputs,
):
# Setup acceptance string
acceptance = sampler.acceptance_fraction[0, :]
......@@ -671,7 +679,7 @@ def print_progress(
)
ave_time_per_check = np.mean(time_per_check[-3:])
time_left = (nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
time_left = (convergence_inputs.nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
if time_left > 0:
time_left = str(datetime.timedelta(seconds=int(time_left)))
else:
......@@ -688,21 +696,22 @@ def print_progress(
else:
tau_str = "!{}".format(tau_str)
evals_per_check = sampler.nwalkers * sampler.ntemps
evals_per_check = sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check
ncalls = "{:1.1e}".format(iteration * sampler.nwalkers * sampler.ntemps)
eval_timing = "{:1.1f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
ncalls = "{:1.1e}".format(
convergence_inputs.niterations_per_check * iteration * sampler.nwalkers * sampler.ntemps)
eval_timing = "{:1.2f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)
print(
"{}| {} | nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
"{}| {}| nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
iteration,
str(sampling_time).split(".")[0],
ncalls,
acceptance_str,
tswap_acceptance_str,
nsamples_effective,
nsamples,
convergence_inputs.nsamples,
tau_str,
eval_timing,
samp_timing,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment