Maintenance will be performed on git.ligo.org, chat.ligo.org, and docs.ligo.org, starting at approximately 10am CDT Tuesday 20 August 2019. The maintenance is expected to take around an hour and here will be two short periods of downtime, one at the beginning of the maintenance and another at the end.

Commit 1f724ae8 authored by Gregory Ashton's avatar Gregory Ashton

Merge branch 'master' into 'master'

Fixing some argument bugs in pymc3.py.

See merge request lscsoft/bilby!299
parents 5566d9d0 1be08c59
Pipeline #42158 passed with stage
in 10 minutes and 54 seconds
......@@ -62,12 +62,12 @@ class Pymc3(MCMCSampler):
live_plot_kwargs=None, compute_convergence_checks=True, use_mmap=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, draws=1000,
use_ratio=False, plot=False,
skip_import_verification=False, **kwargs):
Sampler.__init__(self, likelihood, priors, outdir=outdir, label=label,
use_ratio=use_ratio, plot=plot,
skip_import_verification=skip_import_verification, **kwargs)
self.draws = draws
self.draws = self.__kwargs['draws']
self.chains = self.__kwargs['chains']
@staticmethod
......@@ -401,7 +401,6 @@ class Pymc3(MCMCSampler):
sms = self.step_method[key]
else:
sms = [self.step_method[key]]
for sm in sms:
if sm.lower() not in step_methods:
raise ValueError("Using invalid step method '{}'".format(self.step_method[key]))
......@@ -425,9 +424,21 @@ class Pymc3(MCMCSampler):
# set the prior
self.set_prior()
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# get the step method keyword arguments
step_kwargs = self.kwargs.pop('step_kwargs')
nuts_kwargs = self.kwargs.pop('nuts_kwargs')
methodslist = []
# set the step method
if isinstance(self.step_method, (dict, OrderedDict)):
......@@ -439,25 +450,41 @@ class Pymc3(MCMCSampler):
if isinstance(self.step_method[key], list):
for sms in self.step_method[key]:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = step_kwargs.get('nuts', {})
args = {}
else:
args = step_kwargs.get(curmethod, {})
if step_kwargs is not None:
args = step_kwargs.get(curmethod, {})
else:
args = {}
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
curmethod = self.step_method[key].lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = step_kwargs.get('nuts', {})
args = {}
else:
args = step_kwargs.get(curmethod, {})
if step_kwargs is not None:
args = step_kwargs.get(curmethod, {})
else:
args = {}
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
with self.pymc3_model:
......@@ -466,43 +493,48 @@ class Pymc3(MCMCSampler):
compound = []
for sms in self.step_method:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = step_kwargs.get('nuts', {})
args = {}
else:
args = step_kwargs.get(curmethod, {})
compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
self.kwargs['step'] = compound
else:
self.kwargs['step'] = None
if self.step_method is not None:
curmethod = self.step_method.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = step_kwargs.get('nuts', {})
args = {}
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
else:
# re-add step_kwargs and nuts_kwargs if no step methods are set
self.kwargs['nuts_kwargs'] = nuts_kwargs
# re-add step_kwargs if no step methods are set
self.kwargs['step_kwargs'] = step_kwargs
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# check whether only NUTS step method has been assigned
if np.all([sm.lower() == 'nuts' for sm in methodslist]):
# in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
self.kwargs['step'] = None
self.kwargs['nuts_kwargs'] = nuts_kwargs
with self.pymc3_model:
# perform the sampling
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment