Skip to content
Snippets Groups Projects
Commit e90a156e authored by Rhys Green's avatar Rhys Green
Browse files

adding changes so you can provide step without step_kwargs every time

parent 80ee2d12
No related branches found
No related tags found
1 merge request!299Fixing some argument bugs in pymc3.py.
......@@ -62,12 +62,12 @@ class Pymc3(MCMCSampler):
live_plot_kwargs=None, compute_convergence_checks=True, use_mmap=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, draws=1000,
use_ratio=False, plot=False,
skip_import_verification=False, **kwargs):
Sampler.__init__(self, likelihood, priors, outdir=outdir, label=label,
use_ratio=use_ratio, plot=plot,
skip_import_verification=skip_import_verification, **kwargs)
self.draws = draws
self.draws = self.__kwargs['draws']
self.chains = self.__kwargs['chains']
@staticmethod
......@@ -401,7 +401,7 @@ class Pymc3(MCMCSampler):
sms = self.step_method[key]
else:
sms = [self.step_method[key]]
print(sms)
for sm in sms:
if sm.lower() not in step_methods:
raise ValueError("Using invalid step method '{}'".format(self.step_method[key]))
......@@ -454,10 +454,16 @@ class Pymc3(MCMCSampler):
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
else:
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
else :
args = {}
else:
args = step_kwargs.get(curmethod, {})
if step_kwargs is not None :
args = step_kwargs.get(curmethod, {})
print(args)
else :
args = {}
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
with self.pymc3_model:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment