diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index caa03ede669f3161a39443b934328dceb98d13aa..0930ccb4115623fc86ec5130493288028275b24e 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -62,12 +62,12 @@ class Pymc3(MCMCSampler): live_plot_kwargs=None, compute_convergence_checks=True, use_mmap=False) def __init__(self, likelihood, priors, outdir='outdir', label='label', - use_ratio=False, plot=False, draws=1000, + use_ratio=False, plot=False, skip_import_verification=False, **kwargs): Sampler.__init__(self, likelihood, priors, outdir=outdir, label=label, use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification, **kwargs) - self.draws = draws + self.draws = self.__kwargs['draws'] self.chains = self.__kwargs['chains'] @staticmethod @@ -401,7 +401,7 @@ class Pymc3(MCMCSampler): sms = self.step_method[key] else: sms = [self.step_method[key]] - + print(sms) for sm in sms: if sm.lower() not in step_methods: raise ValueError("Using invalid step method '{}'".format(self.step_method[key])) @@ -454,10 +454,16 @@ class Pymc3(MCMCSampler): if curmethod == 'nuts': if nuts_kwargs is not None: args = nuts_kwargs - else: + elif step_kwargs is not None: args = step_kwargs.get('nuts', {}) + else : + args = {} else: - args = step_kwargs.get(curmethod, {}) + if step_kwargs is not None : + args = step_kwargs.get(curmethod, {}) + print(args) + else : + args = {} self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args)) else: with self.pymc3_model: