Skip to content
Snippets Groups Projects
Commit 412c11de authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add some checking to the samplers

For the pymultinest and nestle samplers only
- Adds equivalent parameters nlive, n_live_points, npoints etc all do
  the same thing
- Adds checking for any kwargs given to run_sampler that will not be
  used. Removes them and warns the user.
parent 3b78354c
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -100,6 +100,15 @@ class Sampler(object):
def kwargs(self, kwargs):
self.__kwargs = kwargs
def verify_kwargs_against_external_sampler_function(self):
args = inspect.getargspec(self.external_sampler_function).args
for user_input in self.kwargs.keys():
if user_input not in args:
logging.warning(
"Supplied argument '{}' not an argument of '{}', removing."
.format(user_input, self.external_sampler_function))
self.kwargs.pop(user_input)
def initialise_parameters(self):
for key in self.priors:
......@@ -177,13 +186,19 @@ class Nestle(Sampler):
@kwargs.setter
def kwargs(self, kwargs):
self.__kwargs = kwargs
if 'npoints' not in self.__kwargs:
for equiv in ['nlive', 'nlives', 'n_live_points']:
if equiv in self.__kwargs:
self.__kwargs['npoints'] = self.__kwargs.pop(equiv)
def run_sampler(self):
nestle = self.external_sampler
self.external_sampler_function = nestle.sample
if self.kwargs.get('verbose', True):
self.kwargs['callback'] = nestle.print_progress
self.verify_kwargs_against_external_sampler_function()
out = nestle.sample(
out = self.external_sampler_function(
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
ndim=self.ndim, **self.kwargs)
......@@ -231,9 +246,18 @@ class Pymultinest(Sampler):
if self.__kwargs['outputfiles_basename'].endswith('/') is False:
self.__kwargs['outputfiles_basename'] = '{}/'.format(
self.__kwargs['outputfiles_basename'])
if 'n_live_points' not in self.__kwargs:
for equiv in ['nlive', 'nlives', 'npoints', 'npoint']:
if equiv in self.__kwargs:
self.__kwargs['n_live_points'] = self.__kwargs.pop(equiv)
def run_sampler(self):
pymultinest = self.external_sampler
self.external_sampler_function = pymultinest.run
self.verify_kwargs_against_external_sampler_function()
# Note: pymultinest.solve adds some extra steps, but underneath
# we are calling pymultinest.run - hence why it is used in checking
# the arguments.
out = pymultinest.solve(
LogLikelihood=self.log_likelihood, Prior=self.prior_transform,
n_dims=self.ndim, **self.kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment