From 145d3989afd1d2933773262434d7337927fb1af4 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Wed, 4 Jul 2018 16:37:24 +1000 Subject: [PATCH] Adds in the ability to pass through pos0 of the walkers --- tupak/core/sampler.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py index 22e0797e..89d542df 100644 --- a/tupak/core/sampler.py +++ b/tupak/core/sampler.py @@ -348,8 +348,15 @@ class Sampler(object): def _log_summary_for_sampler(self): """Print a summary of the sampler used and its kwargs""" if self.cached_result is None: + kwargs_print = self.kwargs.copy() + for k in kwargs_print: + if type(kwargs_print[k]) in (list, np.ndarray): + array_repr = np.array(kwargs_print[k]) + if array_repr.shape > 10: + kwargs_print[k] = ('array_like, shape={}' + .format(array_repr.shape)) logging.info("Using sampler {} with kwargs {}".format( - self.__class__.__name__, self.kwargs)) + self.__class__.__name__, kwargs_print)) class Nestle(Sampler): @@ -764,7 +771,17 @@ class Emcee(Sampler): sampler = emcee.EnsembleSampler( nwalkers=self.nwalkers, dim=self.ndim, lnpostfn=self.lnpostfn, a=a) - pos0 = [self.get_random_draw_from_prior() for i in range(self.nwalkers)] + + if 'pos0' in self.kwargs: + logging.debug("Using given initial positions for walkers") + pos0 = np.squeeze(self.kwargs['pos0']) + if pos0.shape != (self.ndim, self.nwalkers): + raise ValueError( + 'Input pos0 should be of shape ndim, nwalkers') + else: + logging.debug("Generating initial walker positions from prior") + pos0 = [self.get_random_draw_from_prior() + for i in range(self.nwalkers)] for result in tqdm( sampler.sample(pos0, iterations=self.nsteps), total=self.nsteps): -- GitLab