Skip to content
Snippets Groups Projects
Commit 145d3989 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds in the ability to pass through pos0 of the walkers

parent 82a23f1c
No related branches found
No related tags found
No related merge requests found
......@@ -348,8 +348,15 @@ class Sampler(object):
def _log_summary_for_sampler(self):
"""Print a summary of the sampler used and its kwargs"""
if self.cached_result is None:
kwargs_print = self.kwargs.copy()
for k in kwargs_print:
if type(kwargs_print[k]) in (list, np.ndarray):
array_repr = np.array(kwargs_print[k])
if array_repr.shape > 10:
kwargs_print[k] = ('array_like, shape={}'
.format(array_repr.shape))
logging.info("Using sampler {} with kwargs {}".format(
self.__class__.__name__, self.kwargs))
self.__class__.__name__, kwargs_print))
class Nestle(Sampler):
......@@ -764,7 +771,17 @@ class Emcee(Sampler):
sampler = emcee.EnsembleSampler(
nwalkers=self.nwalkers, dim=self.ndim, lnpostfn=self.lnpostfn,
a=a)
pos0 = [self.get_random_draw_from_prior() for i in range(self.nwalkers)]
if 'pos0' in self.kwargs:
logging.debug("Using given initial positions for walkers")
pos0 = np.squeeze(self.kwargs['pos0'])
if pos0.shape != (self.ndim, self.nwalkers):
raise ValueError(
'Input pos0 should be of shape ndim, nwalkers')
else:
logging.debug("Generating initial walker positions from prior")
pos0 = [self.get_random_draw_from_prior()
for i in range(self.nwalkers)]
for result in tqdm(
sampler.sample(pos0, iterations=self.nsteps), total=self.nsteps):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment