diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py index 0193e0f9a5566044c53e602ab6a7aa7860a8f7d5..302fbf4f582c97c20dd4f703058a20cfdd3f97b6 100644 --- a/tupak/core/sampler.py +++ b/tupak/core/sampler.py @@ -751,45 +751,59 @@ class Pymultinest(Sampler): return self.result -class Ptemcee(Sampler): +class Emcee(Sampler): + """ https://github.com/dfm/emcee """ def _run_external_sampler(self): - ntemps = self.kwargs.pop('ntemps', 2) - nwalkers = self.kwargs.pop('nwalkers', 100) - nsteps = self.kwargs.pop('nsteps', 100) - nburn = self.kwargs.pop('nburn', 50) - ptemcee = self.external_sampler + self.nwalkers = self.kwargs.pop('nwalkers', 100) + self.nsteps = self.kwargs.pop('nsteps', 100) + self.nburn = self.kwargs.pop('nburn', 50) + emcee = self.external_sampler tqdm = utils.get_progress_bar(self.kwargs.pop('tqdm', 'tqdm')) - sampler = ptemcee.Sampler( - ntemps=ntemps, nwalkers=nwalkers, dim=self.ndim, - logl=self.log_likelihood, logp=self.log_prior, + sampler = emcee.EnsembleSampler( + nwalkers=self.nwalkers, dim=self.ndim, lnpostfn=self.lnpostfn, **self.kwargs) - pos0 = [[self.get_random_draw_from_prior() - for i in range(nwalkers)] - for j in range(ntemps)] + pos0 = [self.get_random_draw_from_prior() for i in range(self.nwalkers)] for result in tqdm( - sampler.sample(pos0, iterations=nsteps, adapt=True), total=nsteps): + sampler.sample(pos0, iterations=self.nsteps), total=self.nsteps): pass self.result.sampler_output = np.nan - self.result.samples = sampler.chain[0, :, nburn:, :].reshape( + self.result.samples = sampler.chain[:, self.nburn:, :].reshape( (-1, self.ndim)) - self.result.walkers = sampler.chain[0, :, :, :] + self.result.walkers = sampler.chain[:, :, :] self.result.log_evidence = np.nan self.result.log_evidence_err = np.nan - self.plot_walkers() - logging.info("Max autocorr time = {}".format(np.max(sampler.get_autocorr_time()))) - logging.info("Tswap frac = {}".format(sampler.tswap_acceptance_fraction)) + if self.plot: + self.plot_walkers() + try: + logging.info("Max autocorr time = {}".format( + np.max(sampler.get_autocorr_time()))) + except emcee.autocorr.AutocorrError as e: + logging.info("Unable to calculate autocorr time: {}".format(e)) return self.result + def lnpostfn(self, theta): + return self.log_likelihood(theta) + self.log_prior(theta) + + def _get_walkers_to_plot(self): + return self.result.walkers[:, :, :] + def plot_walkers(self, save=True, **kwargs): nwalkers, nsteps, ndim = self.result.walkers.shape idxs = np.arange(nsteps) fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3*self.ndim)) + walkers = self._get_walkers_to_plot() for i, ax in enumerate(axes): - ax.plot(idxs, self.result.walkers[:, :, i].T, lw=0.1, color='k') + ax.plot(idxs[:self.nburn+1], walkers[:, :self.nburn+1, i].T, + lw=0.1, color='r') + ax.set_ylabel(self.result.parameter_labels[i]) + + for i, ax in enumerate(axes): + ax.plot(idxs[self.nburn:], walkers[:, self.nburn:, i].T, lw=0.1, + color='k') ax.set_ylabel(self.result.parameter_labels[i]) fig.tight_layout() @@ -798,6 +812,45 @@ class Ptemcee(Sampler): fig.savefig(filename) +class Ptemcee(Emcee): + """ https://github.com/willvousden/ptemcee """ + + def _run_external_sampler(self): + self.ntemps = self.kwargs.pop('ntemps', 2) + self.nwalkers = self.kwargs.pop('nwalkers', 100) + self.nsteps = self.kwargs.pop('nsteps', 100) + self.nburn = self.kwargs.pop('nburn', 50) + ptemcee = self.external_sampler + tqdm = utils.get_progress_bar(self.kwargs.pop('tqdm', 'tqdm')) + + sampler = ptemcee.Sampler( + ntemps=self.ntemps, nwalkers=self.nwalkers, dim=self.ndim, + logl=self.log_likelihood, logp=self.log_prior, + **self.kwargs) + pos0 = [[self.get_random_draw_from_prior() + for i in range(self.nwalkers)] + for j in range(self.ntemps)] + + for result in tqdm( + sampler.sample(pos0, iterations=self.nsteps, adapt=True), + total=self.nsteps): + pass + + self.result.sampler_output = np.nan + self.result.samples = sampler.chain[0, :, self.nburn:, :].reshape( + (-1, self.ndim)) + self.result.walkers = sampler.chain[0, :, :, :] + self.result.log_evidence = np.nan + self.result.log_evidence_err = np.nan + if self.plot: + self.plot_walkers() + logging.info("Max autocorr time = {}" + .format(np.max(sampler.get_autocorr_time()))) + logging.info("Tswap frac = {}" + .format(sampler.tswap_acceptance_fraction)) + return self.result + + def run_sampler(likelihood, priors=None, label='label', outdir='outdir', sampler='dynesty', use_ratio=None, injection_parameters=None, conversion_function=None, plot=False, default_priors_file=None,