Skip to content
Snippets Groups Projects

Allow time domain approximants

Merged Colm Talbot requested to merge allow_time_domain_approximants into master
7 files
+ 109
51
Compare changes
  • Side-by-side
  • Inline
Files
7
@@ -51,6 +51,11 @@ class Emcee(MCMCSampler):
use_ratio=False, plot=False, skip_import_verification=False,
pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
burn_in_act=3, **kwargs):
import emcee
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
self.prerelease = True
else:
self.prerelease = False
MCMCSampler.__init__(
self, likelihood=likelihood, priors=priors, outdir=outdir,
label=label, use_ratio=use_ratio, plot=plot,
@@ -82,7 +87,6 @@ class Emcee(MCMCSampler):
@property
def sampler_function_kwargs(self):
import emcee
keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'mh_proposal']
# updated function keywords for emcee > v2.2.1
@@ -93,7 +97,7 @@ class Emcee(MCMCSampler):
function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
function_kwargs['p0'] = self.pos0
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
if self.prerelease:
if function_kwargs['mh_proposal'] is not None:
logger.warning("The 'mh_proposal' option is no longer used "
"in emcee v{}, and will be ignored.".format(emcee.__version__))
@@ -109,8 +113,6 @@ class Emcee(MCMCSampler):
@property
def sampler_init_kwargs(self):
import emcee
init_kwargs = {key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs}
@@ -122,7 +124,7 @@ class Emcee(MCMCSampler):
updatekeys = {'dim': 'ndim',
'lnpostfn': 'log_prob_fn'}
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
if self.prerelease:
for key in updatekeys:
if key in init_kwargs:
init_kwargs[updatekeys[key]] = init_kwargs.pop(key)
@@ -193,8 +195,10 @@ class Emcee(MCMCSampler):
for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
total=self.nsteps):
points = np.hstack([sample[0], np.array(sample[3])])
# import IPython; IPython.embed()
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
else:
points = np.hstack([sample[0], np.array(sample[3])])
with open(out_file, "a") as ff:
for ii, point in enumerate(points):
ff.write(template.format(ii, *point))
Loading