Skip to content
Snippets Groups Projects
Commit 1f724ae8 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'master' into 'master'

Fixing some argument bugs in pymc3.py.

See merge request lscsoft/bilby!299
parents 5566d9d0 1be08c59
No related branches found
No related tags found
No related merge requests found
......@@ -62,12 +62,12 @@ class Pymc3(MCMCSampler):
live_plot_kwargs=None, compute_convergence_checks=True, use_mmap=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, draws=1000,
use_ratio=False, plot=False,
skip_import_verification=False, **kwargs):
Sampler.__init__(self, likelihood, priors, outdir=outdir, label=label,
use_ratio=use_ratio, plot=plot,
skip_import_verification=skip_import_verification, **kwargs)
self.draws = draws
self.draws = self.__kwargs['draws']
self.chains = self.__kwargs['chains']
@staticmethod
......@@ -401,7 +401,6 @@ class Pymc3(MCMCSampler):
sms = self.step_method[key]
else:
sms = [self.step_method[key]]
for sm in sms:
if sm.lower() not in step_methods:
raise ValueError("Using invalid step method '{}'".format(self.step_method[key]))
......@@ -425,9 +424,21 @@ class Pymc3(MCMCSampler):
# set the prior
self.set_prior()
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# get the step method keyword arguments
step_kwargs = self.kwargs.pop('step_kwargs')
nuts_kwargs = self.kwargs.pop('nuts_kwargs')
methodslist = []
# set the step method
if isinstance(self.step_method, (dict, OrderedDict)):
......@@ -439,25 +450,41 @@ class Pymc3(MCMCSampler):
if isinstance(self.step_method[key], list):
for sms in self.step_method[key]:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = step_kwargs.get('nuts', {})
args = {}
else:
args = step_kwargs.get(curmethod, {})
if step_kwargs is not None:
args = step_kwargs.get(curmethod, {})
else:
args = {}
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
curmethod = self.step_method[key].lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = step_kwargs.get('nuts', {})
args = {}
else:
args = step_kwargs.get(curmethod, {})
if step_kwargs is not None:
args = step_kwargs.get(curmethod, {})
else:
args = {}
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
with self.pymc3_model:
......@@ -466,43 +493,48 @@ class Pymc3(MCMCSampler):
compound = []
for sms in self.step_method:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = step_kwargs.get('nuts', {})
args = {}
else:
args = step_kwargs.get(curmethod, {})
compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
self.kwargs['step'] = compound
else:
self.kwargs['step'] = None
if self.step_method is not None:
curmethod = self.step_method.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = step_kwargs.get('nuts', {})
args = {}
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
else:
# re-add step_kwargs and nuts_kwargs if no step methods are set
self.kwargs['nuts_kwargs'] = nuts_kwargs
# re-add step_kwargs if no step methods are set
self.kwargs['step_kwargs'] = step_kwargs
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# check whether only NUTS step method has been assigned
if np.all([sm.lower() == 'nuts' for sm in methodslist]):
# in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
self.kwargs['step'] = None
self.kwargs['nuts_kwargs'] = nuts_kwargs
with self.pymc3_model:
# perform the sampling
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment