From 50b4c66f1f9751c6cb57e17224db28ce817af3a5 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Fri, 21 May 2021 02:29:23 +0000 Subject: [PATCH] Allow ptemcee initialization with array --- bilby/core/sampler/ptemcee.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 471439a1d..d7f79d80c 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -92,13 +92,15 @@ class Ptemcee(MCMCSampler): is not recommended for cases where tau is large. ignore_keys_for_tau: str A pattern used to ignore keys in estimating the autocorrelation time. - pos0: str, list ("prior") + pos0: str, list, np.ndarray If a string, one of "prior" or "minimize". For "prior", the initial positions of the sampler are drawn from the sampler. If "minimize", a scipy.optimize step is applied to all parameters a number of times. The walkers are then initialized from the range of values obtained. If a list, for the keys in the list the optimization step is applied, - otherwise the initial points are drawn from the prior. + otherwise the initial points are drawn from the prior. If a numpy array + the shape should be (ntemps, nwalkers, ndim). + niterations_per_check: int (5) The number of iteration steps to take before checking ACT. This effectively pre-thins the chains. Larger values reduce the per-eval @@ -363,6 +365,17 @@ class Ptemcee(MCMCSampler): ) return pos0 + def get_pos0_from_array(self): + if self.pos0.shape != (self.ntemps, self.nwalkers, self.ndim): + raise ValueError( + "Shape of starting array should be (ntemps, nwalkers, ndim). " + "In this case that is ({}, {}, {}), got {}".format( + self.ntemps, self.nwalkers, self.ndim, self.pos0.shape + ) + ) + else: + return self.pos0 + def setup_sampler(self): """ Either initialize the sampler or read in the resume file """ import ptemcee @@ -446,6 +459,8 @@ class Ptemcee(MCMCSampler): return self.get_pos0_from_minimize() elif isinstance(self.pos0, list): return self.get_pos0_from_minimize(minimize_list=self.pos0) + elif isinstance(self.pos0, np.ndarray): + return self.get_pos0_from_array() else: raise SamplerError("pos0={} not implemented".format(self.pos0)) -- GitLab