diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index bdfecffa86ade96880f3265b23813bfca9cf8ce3..4f45928ce584d85dd031e4f21130cd4a58832f3c 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -36,11 +36,13 @@ class Pymc3(MCMCSampler): chains. step: str, dict Provide a step method name, or dictionary of step method names keyed to - particular variable names (these are case insensitive). If no method is - provided for any particular variable then PyMC3 will automatically - decide upon a default, with the first option being the NUTS sampler. - The currently allowed methods are 'NUTS', 'HamiltonianMC', - 'Metropolis', 'BinaryMetropolis', 'BinaryGibbsMetropolis', 'Slice', and + particular variable names (these are case insensitive). If passing a + dictionary of methods, the values keyed on particular variables can be + lists of methods to form compound steps. If no method is provided for + any particular variable then PyMC3 will automatically decide upon a + default, with the first option being the NUTS sampler. The currently + allowed methods are 'NUTS', 'HamiltonianMC', 'Metropolis', + 'BinaryMetropolis', 'BinaryGibbsMetropolis', 'Slice', and 'CategoricalGibbsMetropolis'. Note: you cannot provide a PyMC3 step method function itself here as it is outside of the model context manager. @@ -393,13 +395,27 @@ class Pymc3(MCMCSampler): if key not in self.__search_parameter_keys: raise ValueError("Setting a step method for an unknown parameter '{}'".format(key)) else: - if self.step_method[key].lower() not in step_methods: - raise ValueError("Using invalid step method '{}'".format(self.step_method[key])) + # check if using a compound step (a list of step + # methods for a particular parameter) + if isinstance(self.step_method[key], list): + sms = self.step_method[key] + else: + sms = [self.step_method[key]] + + for sm in sms: + if sm.lower() not in step_methods: + raise ValueError("Using invalid step method '{}'".format(self.step_method[key])) else: - self.step_method = self.step_method.lower() + # check if using a compound step (a list of step + # methods for a particular parameter) + if isinstance(self.step_method, list): + sms = self.step_method + else: + sms = [self.step_method] - if self.step_method not in step_methods: - raise ValueError("Using invalid step method '{}'".format(self.step_method)) + for i in range(len(sms)): + if sms[i].lower() not in step_methods: + raise ValueError("Using invalid step method '{}'".format(sms[i])) else: self.step_method = None @@ -415,11 +431,23 @@ class Pymc3(MCMCSampler): self.kwargs['step'] = [] with self.pymc3_model: for key in self.step_method: - curmethod = self.step_method[key].lower() - self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]])) + # check for a compound step list + if isinstance(self.step_method[key], list): + for sms in self.step_method[key]: + curmethod = sms.lower() + self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]])) + else: + curmethod = self.step_method[key].lower() + self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]])) else: with self.pymc3_model: - self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method]]() + # check for a compound step list + if isinstance(self.step_method, list): + compound = [] + for sms in self.step_method: + compound.append(pymc3.__dict__[step_methods[sms.lower()]]()) + else: + self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method.lower()]]() # if a custom log_likelihood function requires a `sampler` argument # then use that log_likelihood function, with the assumption that it