diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index db7d61e1243abb347207593dc00e053772b0602d..6d3bf6353f3a62d9659122bb2dcb1487a5ed48ea 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -157,6 +157,10 @@ class Ptemcee(MCMCSampler):
             **kwargs
         )
 
+        self.nwalkers = self.sampler_init_kwargs["nwalkers"]
+        self.ntemps = self.sampler_init_kwargs["ntemps"]
+        self.max_steps = 500
+
         # Setup up signal handling
         signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
         signal.signal(signal.SIGINT, self.write_current_state_and_exit)
@@ -228,7 +232,7 @@ class Ptemcee(MCMCSampler):
         return [
             [
                 self.get_random_draw_from_prior()
-                for _ in range(self.sampler_init_kwargs["nwalkers"])
+                for _ in range(self.nwalkers)
             ]
             for _ in range(self.kwargs["ntemps"])
         ]
@@ -319,6 +323,10 @@ class Ptemcee(MCMCSampler):
 
             # Extract the check-point data
             self.sampler = data["sampler"]
+            self.iteration = data["iteration"]
+            self.chain_array = data["chain_array"]
+            self.log_likelihood_array = data["log_likelihood_array"]
+            self.pos0 = data["pos0"]
             self.tau_list = data["tau_list"]
             self.tau_list_n = data["tau_list_n"]
             self.time_per_check = data["time_per_check"]
@@ -327,11 +335,8 @@ class Ptemcee(MCMCSampler):
             self.sampler.pool = self.pool
             self.sampler.threads = self.threads
 
-            # Set pos0 to None for resuming
-            pos0 = None
-
             logger.info(
-                "Resuming from previous run with time={}".format(self.sampler.time)
+                "Resuming from previous run with time={}".format(self.iteration)
             )
 
         else:
@@ -357,15 +362,22 @@ class Ptemcee(MCMCSampler):
                 self.search_parameter_keys, use_ratio=self.use_ratio
            )
 
-            # Set up empty lists
+            # Initialize storing results
+            self.iteration = 0
+            self.chain_array = self.get_zero_chain_array()
+            self.log_likelihood_array = self.get_zero_log_likelihood_array()
             self.tau_list = []
             self.tau_list_n = []
             self.time_per_check = []
+            self.pos0 = self.get_pos0()
+
+        return self.sampler
 
-            # Initialize the walker postitions
-            pos0 = self.get_pos0()
+    def get_zero_chain_array(self):
+        return np.zeros((self.nwalkers, self.max_steps, self.ndim))
 
-        return self.sampler, pos0
+    def get_zero_log_likelihood_array(self):
+        return np.zeros((self.ntemps, self.nwalkers, self.max_steps))
 
     def get_pos0(self):
         """ Master logic for setting pos0 """
@@ -392,14 +404,25 @@ class Ptemcee(MCMCSampler):
     def run_sampler(self):
         self.setup_pool()
-        sampler, pos0 = self.setup_sampler()
+        sampler = self.setup_sampler()
 
         t0 = datetime.datetime.now()
         logger.info("Starting to sample")
 
         while True:
-            for (pos0, _, _) in sampler.sample(pos0, **self.sampler_function_kwargs):
+            for (pos0, log_posterior, log_likelihood) in sampler.sample(
+                    self.pos0, storechain=False, **self.sampler_function_kwargs):
                 pass
 
+            if self.iteration == self.chain_array.shape[1]:
+                self.chain_array = np.concatenate((
+                    self.chain_array, self.get_zero_chain_array()), axis=1)
+                self.log_likelihood_array = np.concatenate((
+                    self.log_likelihood_array, self.get_zero_log_likelihood_array()),
+                    axis=2)
+            self.pos0 = pos0
+            self.chain_array[:, self.iteration, :] = pos0[0, :, :]
+            self.log_likelihood_array[:, :, self.iteration] = log_likelihood
+
             # Calculate time per iteration
             self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
             t0 = datetime.datetime.now()
@@ -411,6 +434,7 @@ class Ptemcee(MCMCSampler):
                 self.tau_int,
                 self.nsamples_effective,
             ) = check_iteration(
+                self.chain_array[:, :self.iteration + 1, :],
                 sampler,
                 self.convergence_inputs,
                 self.search_parameter_keys,
@@ -419,6 +443,8 @@ class Ptemcee(MCMCSampler):
                 self.tau_list_n,
             )
 
+            self.iteration += 1
+
             if stop:
                 logger.info("Finished sampling")
                 break
@@ -436,12 +462,11 @@ class Ptemcee(MCMCSampler):
         self.write_current_state(plot=self.check_point_plot)
 
         # Get 0-likelihood samples and store in the result
-        samples = sampler.chain[0, :, :, :]  # nwalkers, nsteps, ndim
-        self.result.samples = samples[
-            :, self.nburn : sampler.time : self.thin, :
+        self.result.samples = self.chain_array[
+            :, self.nburn : self.iteration : self.thin, :
         ].reshape((-1, self.ndim))
-        loglikelihood = sampler.loglikelihood[
-            0, :, self.nburn : sampler.time : self.thin
+        loglikelihood = self.log_likelihood_array[
+            0, :, self.nburn : self.iteration : self.thin
         ]  # nwalkers, nsteps
         self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
 
@@ -450,7 +475,8 @@ class Ptemcee(MCMCSampler):
         self.result.nburn = self.nburn
 
         log_evidence, log_evidence_err = compute_evidence(
-            sampler, self.outdir, self.label, self.nburn, self.thin
+            sampler, self.log_likelihood_array, self.outdir, self.label, self.nburn,
+            self.thin, self.iteration,
         )
         self.result.log_evidence = log_evidence
         self.result.log_evidence_err = log_evidence_err
@@ -475,6 +501,7 @@ class Ptemcee(MCMCSampler):
 
     def write_current_state(self, plot=True):
         checkpoint(
+            self.iteration,
             self.outdir,
             self.label,
             self.nsamples_effective,
@@ -483,6 +510,9 @@ class Ptemcee(MCMCSampler):
             self.thin,
             self.search_parameter_keys,
             self.resume_file,
+            self.log_likelihood_array,
+            self.chain_array,
+            self.pos0,
             self.tau_list,
             self.tau_list_n,
             self.time_per_check,
@@ -491,7 +521,7 @@ class Ptemcee(MCMCSampler):
         if plot and not np.isnan(self.nburn):
             # Generate the walkers plot diagnostic
             plot_walkers(
-                self.sampler.chain[0, :, : self.sampler.time, :],
+                self.chain_array[:, : self.iteration, :],
                 self.nburn,
                 self.thin,
                 self.search_parameter_keys,
@@ -512,6 +542,7 @@
 
 
 def check_iteration(
+    samples,
     sampler,
     convergence_inputs,
     search_parameter_keys,
@@ -523,9 +554,6 @@ def check_iteration(
 
     Parameters
     ----------
-    sampler: ptemcee.Sampler
-        The initialized and running sampler object, this assumes it is run with
-        storechain=True in order to pull out the chain
     convergence_inputs: bilby.core.sampler.ptemcee.ConvergenceInputs
         A named tuple of the convergence checking inputs
     search_parameter_keys: list
@@ -549,55 +577,51 @@ def check_iteration(
     import emcee
 
     ci = convergence_inputs
+    nwalkers, iteration, ndim = samples.shape
 
     # Compute ACT tau for 0-temperature chains
-    samples = sampler.chain[0, :, : sampler.time, :]
-    taus = []
-    for ii in range(sampler.nwalkers):
-        tau_ii = []
+    tau_array = np.zeros((nwalkers, ndim))
+    for ii in range(nwalkers):
         for jj, key in enumerate(search_parameter_keys):
             if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key:
                 continue
             try:
-                tau_ii.append(
-                    emcee.autocorr.integrated_time(
-                        samples[ii, :, jj], c=ci.autocorr_c, tol=0
-                    )[0]
-                )
+                tau_array[ii, jj] = emcee.autocorr.integrated_time(
+                    samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0]
             except emcee.autocorr.AutocorrError:
-                taus.append(np.inf)
-        taus.append(tau_ii)
+                tau_array[ii, jj] = np.inf
 
-    tau = np.max(np.mean(taus, axis=0))
+    # Maximum over parameters, mean over walkers
+    tau = np.max(np.mean(tau_array, axis=0))
 
     # Apply multiplicitive safety factor
     tau = ci.safety * tau
 
     # Store for convergence checking and plotting
-    tau_list.append(np.mean(taus, axis=0))
-    tau_list_n.append(sampler.time)
+    tau_list.append(list(np.mean(tau_array, axis=0)))
+    tau_list_n.append(iteration)
 
     # Convert to an integer
     tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
 
     if np.isnan(tau_int) or np.isinf(tau_int):
         print_progress(
-            sampler, time_per_check, ci.nsamples, np.nan, np.nan, np.nan, np.nan, False,
+            iteration, sampler, time_per_check, ci.nsamples, np.nan, np.nan, np.nan, np.nan, False,
         )
         return False, np.nan, np.nan, np.nan, np.nan
 
     # Calculate the effective number of samples available
     nburn = int(ci.burn_in_nact * tau_int)
     thin = int(np.max([1, ci.thin_by_nact * tau_int]))
-    samples_per_check = sampler.nwalkers / thin
-    nsamples_effective = int(sampler.nwalkers * (sampler.time - nburn) / thin)
+    samples_per_check = nwalkers / thin
+    nsamples_effective = int(nwalkers * (iteration - nburn) / thin)
 
     # Calculate convergence boolean
     converged = ci.nsamples < nsamples_effective
 
-    # Calculate fractional change in tau from previous iterations
+    # Calculate fractional change in tau from previous iteration
     check_taus = np.array(tau_list[-tau_int * ci.autocorr_tau :])
-    taus_per_parameter = np.array(taus)[-1, :]
+    taus_per_parameter = check_taus[-1, :]
     if not np.any(np.isnan(check_taus)):
         frac = (taus_per_parameter - check_taus) / taus_per_parameter
         max_frac = np.max(frac)
@@ -606,11 +630,12 @@ def check_iteration(
         max_frac = np.nan
         tau_usable = False
 
-    if sampler.time < tau_int * ci.autocorr_tol or tau_int < ci.min_tau:
+    if iteration < tau_int * ci.autocorr_tol or tau_int < ci.min_tau:
         tau_usable = False
 
     # Print an update on the progress
     print_progress(
+        iteration,
        sampler,
        time_per_check,
        ci.nsamples,
@@ -625,6 +650,7 @@ def check_iteration(
 
 
 def print_progress(
+    iteration,
     sampler,
     time_per_check,
     nsamples,
@@ -664,13 +690,13 @@ def print_progress(
     evals_per_check = sampler.nwalkers * sampler.ntemps
 
-    ncalls = "{:1.1e}".format(sampler.time * sampler.nwalkers * sampler.ntemps)
+    ncalls = "{:1.1e}".format(iteration * sampler.nwalkers * sampler.ntemps)
 
     eval_timing = "{:1.1f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
     samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)
 
     print(
         "{}| {} | nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
-            sampler.time,
+            iteration,
             str(sampling_time).split(".")[0],
             ncalls,
             acceptance_str,
@@ -686,6 +712,7 @@ def print_progress(
 
 
 def checkpoint(
+    iteration,
     outdir,
     label,
     nsamples_effective,
@@ -694,6 +721,9 @@ def checkpoint(
     thin,
     search_parameter_keys,
     resume_file,
+    log_likelihood_array,
+    chain_array,
+    pos0,
     tau_list,
     tau_list_n,
     time_per_check,
@@ -704,7 +734,7 @@ def checkpoint(
     # Store the samples if possible
     if nsamples_effective > 0:
         filename = "{}/{}_samples.txt".format(outdir, label)
-        samples = sampler.chain[0, :, nburn : sampler.time : thin, :].reshape(
+        samples = np.array(chain_array)[:, nburn : iteration : thin, :].reshape(
             (-1, ndim)
         )
         df = pd.DataFrame(samples, columns=search_parameter_keys)
@@ -713,16 +743,16 @@ def checkpoint(
     # Pickle the resume artefacts
     sampler_copy = copy.copy(sampler)
     del sampler_copy.pool
-    sampler_copy._chain = sampler._chain[:, :, : sampler.time, :]
-    sampler_copy._logposterior = sampler._logposterior[:, :, : sampler.time]
-    sampler_copy._loglikelihood = sampler._loglikelihood[:, :, : sampler.time]
-    sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
 
     data = dict(
+        iteration=iteration,
         sampler=sampler_copy,
         tau_list=tau_list,
         tau_list_n=tau_list_n,
         time_per_check=time_per_check,
+        log_likelihood_array=log_likelihood_array,
+        chain_array=chain_array,
+        pos0=pos0,
     )
 
     with open(resume_file, "wb") as file:
@@ -778,11 +808,12 @@ def plot_tau(
     plt.close(fig)
 
 
-def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
+def compute_evidence(sampler, log_likelihood_array, outdir, label, nburn, thin,
+                     iteration, make_plots=True):
     """ Computes the evidence using thermodynamic integration """
     betas = sampler.betas
     # We compute the evidence without the burnin samples, but we do not thin
-    lnlike = sampler.loglikelihood[:, :, nburn : sampler.time]
+    lnlike = log_likelihood_array[:, :, nburn : iteration]
     mean_lnlikes = np.mean(np.mean(lnlike, axis=1), axis=1)
     mean_lnlikes = mean_lnlikes[::-1]
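
Note on the storage change: the patch stops relying on ptemcee's internal `storechain` buffers and instead has the class own two numpy arrays, writing one step per iteration and extending the buffers by a `max_steps`-sized block whenever the write index reaches the allocated length. A minimal standalone sketch of that grow-on-demand pattern (the sizes below are illustrative toys, not bilby's defaults):

```python
import numpy as np

nwalkers, ndim, max_steps = 4, 2, 3  # toy sizes for illustration


def get_zero_chain_array():
    # Same layout as Ptemcee.get_zero_chain_array: (nwalkers, steps, ndim)
    return np.zeros((nwalkers, max_steps, ndim))


chain_array = get_zero_chain_array()
iteration = 0
for _ in range(10):  # stand-in for the `while True` sampling loop
    pos0 = np.random.rand(nwalkers, ndim)  # stand-in for sampler.sample output
    if iteration == chain_array.shape[1]:
        # Buffer full: extend along the step axis in one block, as
        # run_sampler does with np.concatenate(..., axis=1)
        chain_array = np.concatenate((chain_array, get_zero_chain_array()), axis=1)
    chain_array[:, iteration, :] = pos0
    iteration += 1

print(chain_array.shape)  # (4, 12, 2): grown in max_steps-sized blocks
```

Growing in blocks keeps the per-iteration cost at a plain array write, rather than reallocating or concatenating on every step.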
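The reworked `check_iteration` estimates the autocorrelation time per walker and parameter with `emcee.autocorr.integrated_time`, passing `tol=0` so an estimate is returned even when the chain is shorter than emcee's default 50-tau reliability threshold (failures still map to `np.inf` via `AutocorrError`). A self-contained sketch on a fabricated correlated series, assuming emcee 3 is installed:

```python
import numpy as np
import emcee

# Fabricated AR(1) series standing in for one walker's chain of one parameter
rng = np.random.default_rng(0)
x = np.zeros(5000)
for i in range(1, len(x)):
    x[i] = 0.9 * x[i - 1] + rng.normal()

# tol=0 mirrors the patch: always return an estimate instead of raising
# AutocorrError on short chains; [0] unpacks the one-parameter result
tau = emcee.autocorr.integrated_time(x, c=5, tol=0)[0]
print("integrated autocorrelation time ~", tau)  # roughly (1 + 0.9) / (1 - 0.9) = 19
```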
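`compute_evidence` still performs thermodynamic integration, now fed from the stored `log_likelihood_array`: average ln L over walkers and post-burn-in steps at each inverse temperature, reverse the arrays (ptemcee orders betas from 1 down to 0, hence the `[::-1]`), and integrate <ln L>(beta) over [0, 1]. A toy sketch of that integral with fabricated numbers; the trapezoid rule here is an assumption standing in for whatever quadrature the truncated excerpt goes on to use:

```python
import numpy as np


def thermodynamic_log_evidence(betas, mean_lnlikes):
    """Estimate ln Z = integral_0^1 <ln L>_beta d(beta) by the trapezoid rule.

    betas: inverse temperatures sorted ascending, with betas[0] == 0
    mean_lnlikes: <ln L> averaged over walkers and steps at each beta
    """
    return np.trapz(mean_lnlikes, betas)


# Fabricated <ln L>(beta) curve, just to show the call shape
betas = np.linspace(0, 1, 11)
mean_lnlikes = -50.0 + 40.0 * betas
print(thermodynamic_log_evidence(betas, mean_lnlikes))  # -30.0 for this toy curve
```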