Skip to content
Snippets Groups Projects
Commit 5de5ea82 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'adding-emcee' into 'master'

Adds the emcee sampler class

See merge request Monash/tupak!88
parents 6d3734d8 c2e3e3d1
No related branches found
No related tags found
1 merge request!88Adds the emcee sampler class
Pipeline #
......@@ -751,45 +751,59 @@ class Pymultinest(Sampler):
return self.result
class Ptemcee(Sampler):
class Emcee(Sampler):
""" https://github.com/dfm/emcee """
def _run_external_sampler(self):
ntemps = self.kwargs.pop('ntemps', 2)
nwalkers = self.kwargs.pop('nwalkers', 100)
nsteps = self.kwargs.pop('nsteps', 100)
nburn = self.kwargs.pop('nburn', 50)
ptemcee = self.external_sampler
self.nwalkers = self.kwargs.pop('nwalkers', 100)
self.nsteps = self.kwargs.pop('nsteps', 100)
self.nburn = self.kwargs.pop('nburn', 50)
emcee = self.external_sampler
tqdm = utils.get_progress_bar(self.kwargs.pop('tqdm', 'tqdm'))
sampler = ptemcee.Sampler(
ntemps=ntemps, nwalkers=nwalkers, dim=self.ndim,
logl=self.log_likelihood, logp=self.log_prior,
sampler = emcee.EnsembleSampler(
nwalkers=self.nwalkers, dim=self.ndim, lnpostfn=self.lnpostfn,
**self.kwargs)
pos0 = [[self.get_random_draw_from_prior()
for i in range(nwalkers)]
for j in range(ntemps)]
pos0 = [self.get_random_draw_from_prior() for i in range(self.nwalkers)]
for result in tqdm(
sampler.sample(pos0, iterations=nsteps, adapt=True), total=nsteps):
sampler.sample(pos0, iterations=self.nsteps), total=self.nsteps):
pass
self.result.sampler_output = np.nan
self.result.samples = sampler.chain[0, :, nburn:, :].reshape(
self.result.samples = sampler.chain[:, self.nburn:, :].reshape(
(-1, self.ndim))
self.result.walkers = sampler.chain[0, :, :, :]
self.result.walkers = sampler.chain[:, :, :]
self.result.log_evidence = np.nan
self.result.log_evidence_err = np.nan
self.plot_walkers()
logging.info("Max autocorr time = {}".format(np.max(sampler.get_autocorr_time())))
logging.info("Tswap frac = {}".format(sampler.tswap_acceptance_fraction))
if self.plot:
self.plot_walkers()
try:
logging.info("Max autocorr time = {}".format(
np.max(sampler.get_autocorr_time())))
except emcee.autocorr.AutocorrError as e:
logging.info("Unable to calculate autocorr time: {}".format(e))
return self.result
def lnpostfn(self, theta):
return self.log_likelihood(theta) + self.log_prior(theta)
def _get_walkers_to_plot(self):
return self.result.walkers[:, :, :]
def plot_walkers(self, save=True, **kwargs):
nwalkers, nsteps, ndim = self.result.walkers.shape
idxs = np.arange(nsteps)
fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3*self.ndim))
walkers = self._get_walkers_to_plot()
for i, ax in enumerate(axes):
ax.plot(idxs, self.result.walkers[:, :, i].T, lw=0.1, color='k')
ax.plot(idxs[:self.nburn+1], walkers[:, :self.nburn+1, i].T,
lw=0.1, color='r')
ax.set_ylabel(self.result.parameter_labels[i])
for i, ax in enumerate(axes):
ax.plot(idxs[self.nburn:], walkers[:, self.nburn:, i].T, lw=0.1,
color='k')
ax.set_ylabel(self.result.parameter_labels[i])
fig.tight_layout()
......@@ -798,6 +812,45 @@ class Ptemcee(Sampler):
fig.savefig(filename)
class Ptemcee(Emcee):
""" https://github.com/willvousden/ptemcee """
def _run_external_sampler(self):
self.ntemps = self.kwargs.pop('ntemps', 2)
self.nwalkers = self.kwargs.pop('nwalkers', 100)
self.nsteps = self.kwargs.pop('nsteps', 100)
self.nburn = self.kwargs.pop('nburn', 50)
ptemcee = self.external_sampler
tqdm = utils.get_progress_bar(self.kwargs.pop('tqdm', 'tqdm'))
sampler = ptemcee.Sampler(
ntemps=self.ntemps, nwalkers=self.nwalkers, dim=self.ndim,
logl=self.log_likelihood, logp=self.log_prior,
**self.kwargs)
pos0 = [[self.get_random_draw_from_prior()
for i in range(self.nwalkers)]
for j in range(self.ntemps)]
for result in tqdm(
sampler.sample(pos0, iterations=self.nsteps, adapt=True),
total=self.nsteps):
pass
self.result.sampler_output = np.nan
self.result.samples = sampler.chain[0, :, self.nburn:, :].reshape(
(-1, self.ndim))
self.result.walkers = sampler.chain[0, :, :, :]
self.result.log_evidence = np.nan
self.result.log_evidence_err = np.nan
if self.plot:
self.plot_walkers()
logging.info("Max autocorr time = {}"
.format(np.max(sampler.get_autocorr_time())))
logging.info("Tswap frac = {}"
.format(sampler.tswap_acceptance_fraction))
return self.result
def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler='dynesty', use_ratio=None, injection_parameters=None,
conversion_function=None, plot=False, default_priors_file=None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment