from __future__ import print_function, division, absolute_import

import inspect
import logging
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

from .result import Result, read_in_result
from .prior import Prior, fill_priors
from . import utils
from . import prior
import tupak


class Sampler(object):
    """ A sampler object to aid in setting up an inference run

    Parameters
    ----------
    likelihood: likelihood.GravitationalWaveTransient
        An object with a log_l method
    priors: dict
        The priors to be used in the search. Elements can either be floats
        (indicating a fixed value or delta function prior) or they can be of
        type parameter.Parameter with an associated prior
    external_sampler: str
        A string containing the module name of the sampler

    Returns
    -------
    results:
        A dictionary of the results
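
    Notes
    -----
    Most users should call the module-level `run_sampler` function rather
    than instantiating this class directly. A minimal sketch of direct use,
    assuming `likelihood` and `priors` are pre-constructed tupak objects::

        sampler = Nestle(likelihood, priors, 'nestle', outdir='outdir',
                         label='direct', npoints=500)
        result = sampler.run_sampler()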

    """

    def __init__(self, likelihood, priors, external_sampler='nestle',
                 outdir='outdir', label='label', use_ratio=False, **kwargs):
        self.likelihood = likelihood
        self.priors = priors
        self.label = label
        self.outdir = outdir
        self.use_ratio = use_ratio
        self.external_sampler = external_sampler
        self.external_sampler_function = None

        self.__search_parameter_keys = []
        self.__fixed_parameter_keys = []
        self.initialise_parameters()
        self.verify_parameters()
        self.kwargs = kwargs

        self.check_cached_result()

        self.log_summary_for_sampler()

        if os.path.isdir(outdir) is False:
            os.makedirs(outdir)

        self.result = self.initialise_result()

    @property
    def search_parameter_keys(self):
        return self.__search_parameter_keys

    @property
    def fixed_parameter_keys(self):
        return self.__fixed_parameter_keys

    @property
    def ndim(self):
        return len(self.__search_parameter_keys)

    @property
    def kwargs(self):
        return self.__kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        self.__kwargs = kwargs

    @property
    def external_sampler(self):
        return self.__external_sampler

    @external_sampler.setter
    def external_sampler(self, sampler):
        if type(sampler) is str:
            try:
                self.__external_sampler = __import__(sampler)
            except ImportError:
                raise ImportError(
                    "Sampler {} not installed on this system".format(sampler))
        elif isinstance(sampler, Sampler):
            self.__external_sampler = sampler
        else:
            raise TypeError(
                'sampler must either be a string referring to a built-in '
                'sampler or a custom class that inherits from Sampler')

    def verify_kwargs_against_external_sampler_function(self):
        args = inspect.getargspec(self.external_sampler_function).args
        bad_keys = []
        for user_input in self.kwargs.keys():
            if user_input not in args:
                logging.warning(
                    "Supplied argument '{}' not an argument of '{}', "
                    "removing.".format(
                        user_input, self.external_sampler_function))
                bad_keys.append(user_input)
        for key in bad_keys:
            self.kwargs.pop(key)

    def initialise_parameters(self):
        for key in self.priors:
            if isinstance(self.priors[key], Prior) is True \
                    and self.priors[key].is_fixed is False:
                self.__search_parameter_keys.append(key)
            elif isinstance(self.priors[key], Prior) \
                    and self.priors[key].is_fixed is True:
                self.likelihood.parameters[key] = \
                    self.priors[key].sample()
                self.__fixed_parameter_keys.append(key)

        logging.info("Search parameters:")
        for key in self.__search_parameter_keys:
            logging.info('  {} ~ {}'.format(key, self.priors[key]))
        for key in self.__fixed_parameter_keys:
            logging.info('  {} = {}'.format(key, self.priors[key].peak))

    def initialise_result(self):
        result = Result()
        result.search_parameter_keys = self.__search_parameter_keys
        result.fixed_parameter_keys = self.__fixed_parameter_keys
        result.parameter_labels = [
            self.priors[k].latex_label for k in self.__search_parameter_keys]
        result.label = self.label
        result.outdir = self.outdir
        result.kwargs = self.kwargs
        return result

    def verify_parameters(self):
        for key in self.priors:
            try:
                self.likelihood.parameters[key] = self.priors[key].sample()
            except AttributeError as e:
                logging.warning('Cannot sample from {}, {}'.format(key, e))
        try:
            self.likelihood.log_likelihood_ratio()
        except TypeError:
            raise TypeError(
                'GravitationalWaveTransient evaluation failed. Have you '
                'definitely specified all the parameters?\n{}'.format(
                    self.likelihood.parameters))

    def prior_transform(self, theta):
        # Map a sample from the unit hypercube to the physical parameter
        # space using each prior's rescale method
        return [self.priors[key].rescale(t) for key, t in
                zip(self.__search_parameter_keys, theta)]

    def log_prior(self, theta):
        return np.sum(
            [np.log(self.priors[key].prob(t)) for key, t in
             zip(self.__search_parameter_keys, theta)])

    def log_likelihood(self, theta):
        for i, k in enumerate(self.__search_parameter_keys):
            self.likelihood.parameters[k] = theta[i]
        if self.use_ratio:
            return self.likelihood.log_likelihood_ratio()
        else:
            return self.likelihood.log_likelihood()

    def get_random_draw_from_prior(self):
        """ Get a random draw from the prior distribution

        Returns
        -------
        draw: array_like
            An ndim-length array of values drawn from the prior.
            Parameters with delta-function (or fixed) priors are not returned

        """
        draw = np.array([self.priors[key].sample()
                         for key in self.__search_parameter_keys])
        if np.isinf(self.log_likelihood(draw)):
            logging.info('Prior draw {} has inf likelihood'.format(draw))
        if np.isinf(self.log_prior(draw)):
            logging.info('Prior draw {} has inf prior'.format(draw))
        return draw

    def run_sampler(self):
        pass

    def check_cached_result(self):
        """ Check if the cached data file exists and can be used """

        if utils.command_line_args.clean:
            logging.debug("Command line argument clean given, forcing rerun")
            self.cached_result = None
            return

        try:
            self.cached_result = read_in_result(self.outdir, self.label)
        except ValueError:
            self.cached_result = None

        if utils.command_line_args.use_cached:
            logging.debug("Command line argument cached given, no cache "
                          "check performed")
            return

        logging.debug("Checking cached data")
        if self.cached_result:
            check_keys = ['search_parameter_keys', 'fixed_parameter_keys',
                          'kwargs']
            use_cache = True
            for key in check_keys:
                if self.cached_result.check_attribute_match_to_other_object(
                        key, self) is False:
                    logging.debug("Cached value {} is unmatched".format(key))
                    use_cache = False
            if use_cache is False:
                self.cached_result = None

    def log_summary_for_sampler(self):
        if self.cached_result is None:
            logging.info("Using sampler {} with kwargs {}".format(
                self.__class__.__name__, self.kwargs))


class Nestle(Sampler):

    @property
    def kwargs(self):
        return self.__kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        self.__kwargs = dict(verbose=True, method='multi')
        self.__kwargs.update(kwargs)
        if 'npoints' not in self.__kwargs:
            for equiv in ['nlive', 'nlives', 'n_live_points']:
                if equiv in self.__kwargs:
                    self.__kwargs['npoints'] = self.__kwargs.pop(equiv)

    def run_sampler(self):
        nestle = self.external_sampler
        self.external_sampler_function = nestle.sample
        if self.kwargs.get('verbose', True):
            self.kwargs['callback'] = nestle.print_progress
            self.kwargs.pop('verbose')
        self.verify_kwargs_against_external_sampler_function()

        out = self.external_sampler_function(
            loglikelihood=self.log_likelihood,
            prior_transform=self.prior_transform,
            ndim=self.ndim, **self.kwargs)
        print("")

        self.result.sampler_output = out
        self.result.samples = nestle.resample_equal(out.samples, out.weights)
        self.result.logz = out.logz
        self.result.logzerr = out.logzerr
        return self.result


class Dynesty(Sampler):

    @property
    def kwargs(self):
        return self.__kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk',
                             walks=self.ndim * 5, verbose=True)
        self.__kwargs.update(kwargs)
        if 'nlive' not in self.__kwargs:
            for equiv in ['nlives', 'n_live_points', 'npoint', 'npoints']:
                if equiv in self.__kwargs:
                    self.__kwargs['nlive'] = self.__kwargs.pop(equiv)
        if 'nlive' not in self.__kwargs:
            self.__kwargs['nlive'] = 250
        if 'update_interval' not in self.__kwargs:
            self.__kwargs['update_interval'] = int(
                0.6 * self.__kwargs['nlive'])

    def run_sampler(self):
        dynesty = self.external_sampler

        if self.kwargs.get('dynamic', False) is False:
            nested_sampler = dynesty.NestedSampler(
                loglikelihood=self.log_likelihood,
                prior_transform=self.prior_transform,
                ndim=self.ndim, **self.kwargs)
            nested_sampler.run_nested(
                dlogz=self.kwargs['dlogz'],
                print_progress=self.kwargs['verbose'])
        else:
            nested_sampler = dynesty.DynamicNestedSampler(
                loglikelihood=self.log_likelihood,
                prior_transform=self.prior_transform,
                ndim=self.ndim, **self.kwargs)
            nested_sampler.run_nested(print_progress=self.kwargs['verbose'])
        print("")
        out = nested_sampler.results
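
        # Convert the nested-sampling output into equally-weighted posterior
        # samples: each point carries importance weight
        # exp(logwt - logz[-1]), where logz[-1] is the final log-evidence,
        # and resample_equal draws samples in proportion to those weights.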
        self.result.sampler_output = out
        weights = np.exp(out['logwt'] - out['logz'][-1])
        self.result.samples = dynesty.utils.resample_equal(
            out.samples, weights)
        self.result.logz = out.logz[-1]
        self.result.logzerr = out.logzerr[-1]
        return self.result


class Pymultinest(Sampler):

    @property
    def kwargs(self):
        return self.__kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        outputfiles_basename = self.outdir + '/pymultinest_{}/'.format(
            self.label)
        utils.check_directory_exists_and_if_not_mkdir(outputfiles_basename)
        self.__kwargs = dict(importance_nested_sampling=False, resume=True,
                             verbose=True, sampling_efficiency='parameter',
                             outputfiles_basename=outputfiles_basename)
        self.__kwargs.update(kwargs)
        if self.__kwargs['outputfiles_basename'].endswith('/') is False:
            self.__kwargs['outputfiles_basename'] = '{}/'.format(
                self.__kwargs['outputfiles_basename'])
        if 'n_live_points' not in self.__kwargs:
            for equiv in ['nlive', 'nlives', 'npoints', 'npoint']:
                if equiv in self.__kwargs:
                    self.__kwargs['n_live_points'] = self.__kwargs.pop(equiv)

    def run_sampler(self):
        pymultinest = self.external_sampler
        self.external_sampler_function = pymultinest.run
        self.verify_kwargs_against_external_sampler_function()
        # Note: pymultinest.solve adds some extra steps, but underneath
        # we are calling pymultinest.run - hence why it is used in checking
        # the arguments.
        out = pymultinest.solve(
            LogLikelihood=self.log_likelihood, Prior=self.prior_transform,
            n_dims=self.ndim, **self.kwargs)

        self.result.sampler_output = out
        self.result.samples = out['samples']
        self.result.logz = out['logZ']
        self.result.logzerr = out['logZerr']
        self.result.outputfiles_basename = self.kwargs['outputfiles_basename']
        return self.result


class Ptemcee(Sampler):

    def run_sampler(self):
        ntemps = self.kwargs.pop('ntemps', 2)
        nwalkers = self.kwargs.pop('nwalkers', 100)
        nsteps = self.kwargs.pop('nsteps', 100)
        nburn = self.kwargs.pop('nburn', 50)
        ptemcee = self.external_sampler
        tqdm = utils.get_progress_bar(self.kwargs.pop('tqdm', 'tqdm'))

        sampler = ptemcee.Sampler(
            ntemps=ntemps, nwalkers=nwalkers, dim=self.ndim,
            logl=self.log_likelihood, logp=self.log_prior, **self.kwargs)
        pos0 = [[self.get_random_draw_from_prior()
                 for i in range(nwalkers)]
                for j in range(ntemps)]

        for result in tqdm(
                sampler.sample(pos0, iterations=nsteps, adapt=True),
                total=nsteps):
            pass

        self.result.sampler_output = np.nan
        self.result.samples = sampler.chain[0, :, nburn:, :].reshape(
            (-1, self.ndim))
        self.result.walkers = sampler.chain[0, :, :, :]
        self.result.logz = np.nan
        self.result.logzerr = np.nan
        self.plot_walkers()
        logging.info("Max autocorr time = {}".format(
            np.max(sampler.get_autocorr_time())))
        logging.info("Tswap frac = {}".format(
            sampler.tswap_acceptance_fraction))
        return self.result

    def plot_walkers(self, save=True, **kwargs):
        nwalkers, nsteps, ndim = self.result.walkers.shape
        idxs = np.arange(nsteps)
        fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3 * self.ndim))
        for i, ax in enumerate(axes):
            ax.plot(idxs, self.result.walkers[:, :, i].T, lw=0.1, color='k')
            ax.set_ylabel(self.result.parameter_labels[i])
        fig.tight_layout()
        filename = '{}/{}_walkers.png'.format(self.outdir, self.label)
        logging.info('Saving walkers plot to {}'.format(filename))
        fig.savefig(filename)


def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
                sampler='nestle', use_ratio=True, injection_parameters=None,
                conversion_function=None, **kwargs):
    """ The primary interface to easy parameter estimation

    Parameters
    ----------
    likelihood: `tupak.likelihood.GravitationalWaveTransient`
        A `GravitationalWaveTransient` instance
    priors: dict
        A dictionary of the priors for each parameter - missing parameters
        will use default priors; if None, all priors will be default
    label: str
        Name for the run, used in output files
    outdir: str
        A string used in defining output files
    sampler: str
        The name of the sampler to use - see
        `tupak.sampler.get_implemented_samplers()` for a list of available
        samplers
    use_ratio: bool (True)
        If True, use the likelihood's log_likelihood_ratio, rather than just
        the log_likelihood
    injection_parameters: dict
        A dictionary of injection parameters used in creating the data (if
        using simulated data). Appended to the result object and saved.
    conversion_function: function, optional
        Function to apply to posterior to generate additional parameters.
    **kwargs:
        All kwargs are passed directly to the sampler's `run` function

    Returns
    -------
    result
        An object containing the results
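
    Examples
    --------
    A minimal sketch, assuming `likelihood` and `priors` have already been
    constructed (e.g. a `GravitationalWaveTransient` likelihood and a prior
    dictionary); the keyword arguments shown are illustrative and are passed
    straight through to the chosen sampler::

        result = run_sampler(likelihood, priors=priors, label='example',
                             outdir='outdir', sampler='dynesty', nlive=500)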

    """
    utils.check_directory_exists_and_if_not_mkdir(outdir)
    implemented_samplers = get_implemented_samplers()

    if priors is None:
        priors = dict()
    priors = fill_priors(priors, likelihood)
    tupak.prior.write_priors_to_file(priors, outdir)

    if sampler.title() in implemented_samplers:
        sampler_class = globals()[sampler.title()]
        sampler = sampler_class(likelihood, priors, sampler, outdir=outdir,
                                label=label, use_ratio=use_ratio, **kwargs)

        if sampler.cached_result:
            logging.info("Using cached result")
            return sampler.cached_result

        result = sampler.run_sampler()
        result.noise_logz = likelihood.noise_log_likelihood()
        if use_ratio:
            result.log_bayes_factor = result.logz
            result.logz = result.log_bayes_factor + result.noise_logz
        else:
            result.log_bayes_factor = result.logz - result.noise_logz
        if injection_parameters is not None:
            result.injection_parameters = injection_parameters
            if conversion_function is not None:
                conversion_function(result.injection_parameters)
        result.fixed_parameter_keys = [
            key for key in priors
            if isinstance(priors[key], prior.DeltaFunction)]
        # result.prior = prior  # Removed as this breaks the saving of the data
        result.samples_to_data_frame(likelihood=likelihood, priors=priors,
                                     conversion_function=conversion_function)
        result.kwargs = sampler.kwargs
        result.save_to_file(outdir=outdir, label=label)
        return result
    else:
        raise ValueError(
            "Sampler {} not yet implemented".format(sampler))


def get_implemented_samplers():
    implemented_samplers = []
    for name, obj in inspect.getmembers(sys.modules[__name__]):
        if inspect.isclass(obj):
            implemented_samplers.append(obj.__name__)
    return implemented_samplers