sampler.py 17.6 KB
Newer Older
1 2
from __future__ import print_function, division, absolute_import

3
import inspect
4
import logging
5
import os
6
import sys
7
import numpy as np
8
import matplotlib.pyplot as plt
9

10
from .result import Result, read_in_result
Colm Talbot's avatar
Colm Talbot committed
11
from .prior import Prior, fill_priors
12
from . import utils
Colm Talbot's avatar
Colm Talbot committed
13
from . import prior
moritz's avatar
moritz committed
14
import tupak
15

moritz's avatar
moritz committed
16

17
class Sampler(object):
    """ Base class connecting a likelihood and priors to an external sampler

    Subclasses override `run_sampler` (and usually the `kwargs` property) to
    drive one specific sampling package.

    Parameters
    ----------
    likelihood: likelihood.Likelihood
        An object exposing a `parameters` dict together with
        `log_likelihood` and `log_likelihood_ratio` methods
    priors: dict
        The priors to be used in the search, keyed by parameter name.
        Entries that are `Prior` instances with `is_fixed` False are sampled;
        `Prior` instances with `is_fixed` True are held fixed.
    external_sampler: str or Sampler
        The module name of the sampler package to import, or an existing
        `Sampler` instance
    outdir: str
        Directory used for output; created if it does not exist
    label: str
        Name of the run, used when looking up cached results
    use_ratio: bool
        If True, `log_likelihood` returns the likelihood's
        `log_likelihood_ratio()` rather than its `log_likelihood()`
    **kwargs:
        Stored on `self.kwargs` and forwarded to the external sampler by the
        subclasses

    """

    def __init__(self, likelihood, priors, external_sampler='nestle', outdir='outdir', label='label', use_ratio=False,
                 **kwargs):
        self.likelihood = likelihood
        self.priors = priors
        self.label = label
        self.outdir = outdir
        self.use_ratio = use_ratio
        self.external_sampler = external_sampler
        self.external_sampler_function = None

        self.__search_parameter_keys = []
        self.__fixed_parameter_keys = []
        self.initialise_parameters()
        self.verify_parameters()
        # kwargs is assigned after the parameters are initialised because
        # subclass setters may read self.ndim, self.outdir or self.label.
        self.kwargs = kwargs

        self.check_cached_result()

        self.log_summary_for_sampler()

        if not os.path.isdir(outdir):
            os.makedirs(outdir)

        self.result = self.initialise_result()

    @property
    def search_parameter_keys(self):
        """ list: names of the parameters that are sampled """
        return self.__search_parameter_keys

    @property
    def fixed_parameter_keys(self):
        """ list: names of the parameters that are held fixed """
        return self.__fixed_parameter_keys

    @property
    def ndim(self):
        """ int: number of sampled parameters """
        return len(self.__search_parameter_keys)

    @property
    def kwargs(self):
        """ dict: keyword arguments destined for the external sampler """
        return self.__kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        self.__kwargs = kwargs

    @property
    def external_sampler(self):
        """ The imported sampler module (or the Sampler instance supplied) """
        return self.__external_sampler

    @external_sampler.setter
    def external_sampler(self, sampler):
        """ Import the named sampler module, or accept a Sampler instance

        Raises
        ------
        ImportError
            If `sampler` is a string naming a module that cannot be imported
        TypeError
            If `sampler` is neither a string nor a `Sampler` instance
        """
        if isinstance(sampler, str):
            try:
                self.__external_sampler = __import__(sampler)
            except ImportError:
                raise ImportError(
                    "Sampler {} not installed on this system".format(sampler))
        elif isinstance(sampler, Sampler):
            self.__external_sampler = sampler
        else:
            raise TypeError('sampler must either be a string referring to built in sampler or a custom made class that '
                            'inherits from sampler')

    def verify_kwargs_against_external_sampler_function(self):
        """ Drop (with a warning) any kwargs the sampler function does not name

        The positional/keyword argument names of
        `self.external_sampler_function` are inspected and every key of
        `self.kwargs` that is not among them is removed in place.
        """
        # inspect.getargspec was removed in Python 3.11; prefer the
        # replacement and only fall back where it does not exist (Python 2).
        get_spec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        args = get_spec(self.external_sampler_function).args

        # Collect first, pop afterwards: the dict must not change size while
        # it is being iterated.
        bad_keys = [key for key in self.kwargs if key not in args]
        for key in bad_keys:
            logging.warning(
                "Supplied argument '{}' not an argument of '{}', removing."
                .format(key, self.external_sampler_function))
            self.kwargs.pop(key)

    def initialise_parameters(self):
        """ Split the priors into search and fixed parameters and log them

        Fixed parameters have a value drawn from their prior written into
        `self.likelihood.parameters`. Entries of `self.priors` that are not
        `Prior` instances are skipped.
        """
        for key in self.priors:
            entry = self.priors[key]
            if not isinstance(entry, Prior):
                continue
            if entry.is_fixed is False:
                self.__search_parameter_keys.append(key)
            elif entry.is_fixed is True:
                self.likelihood.parameters[key] = entry.sample()
                self.__fixed_parameter_keys.append(key)

        logging.info("Search parameters:")
        for key in self.__search_parameter_keys:
            logging.info('  {} ~ {}'.format(key, self.priors[key]))
        for key in self.__fixed_parameter_keys:
            logging.info('  {} = {}'.format(key, self.priors[key].peak))

    def initialise_result(self):
        """ Build a `Result` pre-filled with the metadata of this run

        Returns
        -------
        result: Result
            A new result carrying the parameter keys, latex labels of the
            search parameters, label, outdir and kwargs
        """
        result = Result()
        result.search_parameter_keys = self.__search_parameter_keys
        result.fixed_parameter_keys = self.__fixed_parameter_keys
        result.parameter_labels = [
            self.priors[k].latex_label for k in
            self.__search_parameter_keys]
        result.label = self.label
        result.outdir = self.outdir
        result.kwargs = self.kwargs
        return result

    def verify_parameters(self):
        """ Check the likelihood can be evaluated at a draw from the priors

        Every prior is sampled into `self.likelihood.parameters`; entries
        without a `sample` method only trigger a warning.

        Raises
        ------
        TypeError
            If evaluating `log_likelihood_ratio` raises a TypeError
        """
        for key in self.priors:
            try:
                self.likelihood.parameters[key] = self.priors[key].sample()
            except AttributeError as e:
                logging.warning('Cannot sample from {}, {}'.format(key, e))
        try:
            self.likelihood.log_likelihood_ratio()
        except TypeError:
            raise TypeError('Likelihood evaluation failed. Have you definitely specified all the parameters?\n{}'.format(
                self.likelihood.parameters))

    def prior_transform(self, theta):
        """ Rescale each element of `theta` with the matching search prior

        Parameters
        ----------
        theta: iterable
            One value per search parameter, in `search_parameter_keys` order

        Returns
        -------
        list
            The output of each prior's `rescale` method
        """
        return [self.priors[key].rescale(t) for key, t in zip(self.__search_parameter_keys, theta)]

    def log_prior(self, theta):
        """ Sum of the log prior probabilities of `theta`

        Parameters
        ----------
        theta: iterable
            One value per search parameter, in `search_parameter_keys` order
        """
        return np.sum(
            [np.log(self.priors[key].prob(t)) for key, t in
                zip(self.__search_parameter_keys, theta)])

    def log_likelihood(self, theta):
        """ Evaluate the likelihood (or likelihood ratio) at `theta`

        The search parameters of `self.likelihood.parameters` are overwritten
        with the values in `theta` before the evaluation.

        Parameters
        ----------
        theta: sequence
            One value per search parameter, in `search_parameter_keys` order
        """
        for i, k in enumerate(self.__search_parameter_keys):
            self.likelihood.parameters[k] = theta[i]
        if self.use_ratio:
            return self.likelihood.log_likelihood_ratio()
        return self.likelihood.log_likelihood()

    def get_random_draw_from_prior(self):
        """ Get a random draw from the prior distribution

        Returns
        -------
        draw: array_like
            An ndim-length array of values drawn from the prior. Parameters
            with delta-function (or fixed) priors are not returned

        """
        draw = np.array([self.priors[key].sample()
                        for key in self.__search_parameter_keys])
        # Infinite values are only reported, the draw is returned regardless.
        if np.isinf(self.log_likelihood(draw)):
            logging.info('Prior draw {} has inf likelihood'.format(draw))
        if np.isinf(self.log_prior(draw)):
            logging.info('Prior draw {} has inf prior'.format(draw))
        return draw

    def run_sampler(self):
        """ Run the sampler; a no-op here, overridden by the subclasses """
        pass

    def check_cached_result(self):
        """ Check if the cached data file exists and can be used

        Sets `self.cached_result` to the result read from disk, or to None
        when the `clean` command line flag is set, when nothing can be read,
        or when the cached search/fixed parameter keys or kwargs do not match
        this sampler. With the `use_cached` flag the comparison is skipped.
        """

        if utils.command_line_args.clean:
            logging.debug("Command line argument clean given, forcing rerun")
            self.cached_result = None
            return

        try:
            self.cached_result = read_in_result(self.outdir, self.label)
        except ValueError:
            self.cached_result = None

        if utils.command_line_args.use_cached:
            logging.debug("Command line argument cached given, no cache check performed")
            return

        logging.debug("Checking cached data")
        if self.cached_result:
            check_keys = ['search_parameter_keys', 'fixed_parameter_keys',
                          'kwargs']
            use_cache = True
            # No early exit: every mismatching attribute is logged.
            for key in check_keys:
                if self.cached_result.check_attribute_match_to_other_object(
                        key, self) is False:
                    logging.debug("Cached value {} is unmatched".format(key))
                    use_cache = False
            if use_cache is False:
                self.cached_result = None

    def log_summary_for_sampler(self):
        """ Log the sampler class and kwargs unless a cached result is used """
        if self.cached_result is None:
            logging.info("Using sampler {} with kwargs {}".format(
                self.__class__.__name__, self.kwargs))
227

228 229

class Nestle(Sampler):
    """ Sampler that drives the `sample` function of the nestle package """

    @property
    def kwargs(self):
        """ dict: keyword arguments forwarded to `nestle.sample` """
        return self.__kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        # Defaults first, so that user supplied values take precedence.
        self.__kwargs = dict(verbose=True, method='multi')
        self.__kwargs.update(kwargs)

        # Accept the live-point spellings used by the other samplers.
        if 'npoints' not in self.__kwargs:
            for equiv in ['nlive', 'nlives', 'n_live_points']:
                if equiv in self.__kwargs:
                    self.__kwargs['npoints'] = self.__kwargs.pop(equiv)

    def run_sampler(self):
        """ Run nestle and store its output on `self.result`

        Returns
        -------
        result: Result
            `self.result` with `sampler_output`, equally weighted `samples`,
            `logz` and `logzerr` filled in
        """
        nestle = self.external_sampler
        self.external_sampler_function = nestle.sample

        # 'verbose' is translated into nestle's progress callback and is
        # always removed here, so a falsy value no longer reaches the argument
        # check below (where it would be dropped with a warning).
        if self.kwargs.pop('verbose', True):
            self.kwargs['callback'] = nestle.print_progress
        self.verify_kwargs_against_external_sampler_function()

        out = self.external_sampler_function(
            loglikelihood=self.log_likelihood,
            prior_transform=self.prior_transform,
            ndim=self.ndim, **self.kwargs)
        print("")

        self.result.sampler_output = out
        self.result.samples = nestle.resample_equal(out.samples, out.weights)
        self.result.logz = out.logz
        self.result.logzerr = out.logzerr
        return self.result
264 265 266


class Dynesty(Sampler):
    """ Sampler that drives the nested samplers of the dynesty package """

    @property
    def kwargs(self):
        """ dict: keyword arguments forwarded to the dynesty sampler """
        return self.__kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        # `settings` is the very dict stored on the instance; the local name
        # only shortens the lines below.
        settings = self.__kwargs = dict(
            dlogz=0.1, bound='multi', sample='rwalk',
            walks=self.ndim * 5, verbose=True)
        settings.update(kwargs)

        # Map the live-point spellings of other samplers onto 'nlive'.
        if 'nlive' not in settings:
            for alias in ('nlives', 'n_live_points', 'npoint', 'npoints'):
                if alias in settings:
                    settings['nlive'] = settings.pop(alias)
        settings.setdefault('nlive', 250)

        if 'update_interval' not in settings:
            settings['update_interval'] = int(0.6 * settings['nlive'])

    def run_sampler(self):
        """ Run the static or dynamic nested sampler and fill `self.result`

        The dynamic sampler is used unless the 'dynamic' kwarg is missing or
        exactly False.

        Returns
        -------
        result: Result
            `self.result` with equally weighted `samples`, `logz` and
            `logzerr` filled in
        """
        dynesty = self.external_sampler

        if self.kwargs.get('dynamic', False) is not False:
            nested_sampler = dynesty.DynamicNestedSampler(
                loglikelihood=self.log_likelihood,
                prior_transform=self.prior_transform,
                ndim=self.ndim, **self.kwargs)
            nested_sampler.run_nested(print_progress=self.kwargs['verbose'])
        else:
            nested_sampler = dynesty.NestedSampler(
                loglikelihood=self.log_likelihood,
                prior_transform=self.prior_transform,
                ndim=self.ndim, **self.kwargs)
            nested_sampler.run_nested(
                dlogz=self.kwargs['dlogz'],
                print_progress=self.kwargs['verbose'])
        print("")

        # NOTE: the raw dynesty results are not attached to
        # self.result.sampler_output here.
        out = nested_sampler.results
        final_logz = out['logz'][-1]
        weights = np.exp(out['logwt'] - final_logz)
        self.result.samples = dynesty.utils.resample_equal(
            out.samples, weights)
        self.result.logz = out.logz[-1]
        self.result.logzerr = out.logzerr[-1]
        return self.result
313 314


315
class Pymultinest(Sampler):
    """ Sampler that drives the pymultinest package """

    @property
    def kwargs(self):
        """ dict: keyword arguments forwarded to `pymultinest.solve` """
        return self.__kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        # The default output directory is created even if the user goes on to
        # supply a different 'outputfiles_basename'.
        default_basename = self.outdir + '/pymultinest_{}/'.format(self.label)
        utils.check_directory_exists_and_if_not_mkdir(default_basename)

        settings = self.__kwargs = dict(
            importance_nested_sampling=False, resume=True,
            verbose=True, sampling_efficiency='parameter',
            outputfiles_basename=default_basename)
        settings.update(kwargs)

        # Ensure the basename ends with a trailing slash.
        basename = settings['outputfiles_basename']
        if not basename.endswith('/'):
            settings['outputfiles_basename'] = '{}/'.format(basename)

        # Map the live-point spellings of other samplers onto 'n_live_points'.
        if 'n_live_points' not in settings:
            for alias in ('nlive', 'nlives', 'npoints', 'npoint'):
                if alias in settings:
                    settings['n_live_points'] = settings.pop(alias)

    def run_sampler(self):
        """ Run pymultinest and store its output on `self.result`

        Returns
        -------
        result: Result
            `self.result` with `sampler_output`, `samples`, `logz`,
            `logzerr` and `outputfiles_basename` filled in
        """
        pmn = self.external_sampler
        # Note: pymultinest.solve adds some extra steps, but underneath
        # we are calling pymultinest.run - hence why it is used in checking
        # the arguments.
        self.external_sampler_function = pmn.run
        self.verify_kwargs_against_external_sampler_function()

        solution = pmn.solve(
            LogLikelihood=self.log_likelihood, Prior=self.prior_transform,
            n_dims=self.ndim, **self.kwargs)

        res = self.result
        res.sampler_output = solution
        res.samples = solution['samples']
        res.logz = solution['logZ']
        res.logzerr = solution['logZerr']
        res.outputfiles_basename = self.kwargs['outputfiles_basename']
        return res


356 357
class Ptemcee(Sampler):
    """ Sampler that drives the parallel-tempered sampler of ptemcee """

    def run_sampler(self):
        """ Run ptemcee and store the chains on `self.result`

        The kwargs 'ntemps', 'nwalkers', 'nsteps', 'nburn' and 'tqdm' are
        removed from `self.kwargs` here; whatever remains is passed to
        `ptemcee.Sampler`.

        Returns
        -------
        result: Result
            `self.result` with `samples` (first temperature index, first
            `nburn` steps discarded), `walkers`, and NaN `logz`/`logzerr`
        """
        ntemps = self.kwargs.pop('ntemps', 2)
        nwalkers = self.kwargs.pop('nwalkers', 100)
        nsteps = self.kwargs.pop('nsteps', 100)
        nburn = self.kwargs.pop('nburn', 50)
        ptemcee = self.external_sampler
        tqdm = utils.get_progress_bar(self.kwargs.pop('tqdm', 'tqdm'))

        sampler = ptemcee.Sampler(
            ntemps=ntemps, nwalkers=nwalkers, dim=self.ndim,
            logl=self.log_likelihood, logp=self.log_prior,
            **self.kwargs)
        # One prior draw per walker and per temperature as starting points.
        pos0 = [[self.get_random_draw_from_prior()
                 for i in range(nwalkers)]
                for j in range(ntemps)]

        # The generator is only consumed to advance the sampler; its yielded
        # values are not used.
        for _ in tqdm(
                sampler.sample(pos0, iterations=nsteps, adapt=True), total=nsteps):
            pass

        self.result.sampler_output = np.nan
        self.result.samples = sampler.chain[0, :, nburn:, :].reshape(
            (-1, self.ndim))
        self.result.walkers = sampler.chain[0, :, :, :]
        self.result.logz = np.nan
        self.result.logzerr = np.nan
        self.plot_walkers()
        logging.info("Max autocorr time = {}".format(np.max(sampler.get_autocorr_time())))
        logging.info("Tswap frac = {}".format(sampler.tswap_acceptance_fraction))
        return self.result

    def plot_walkers(self, save=True, **kwargs):
        """ Plot the trace of every walker, one panel per parameter

        Parameters
        ----------
        save: bool
            If True (default) write the figure to
            `<outdir>/<label>_walkers.png` and close it afterwards
        **kwargs:
            Accepted for interface compatibility; currently unused
        """
        nwalkers, nsteps, ndim = self.result.walkers.shape
        idxs = np.arange(nsteps)
        # squeeze=False keeps `axes` a 2-D array, so a single-parameter run
        # (ndim == 1) can be iterated like any other.
        fig, axes = plt.subplots(
            nrows=ndim, figsize=(6, 3 * self.ndim), squeeze=False)
        for i, ax in enumerate(axes[:, 0]):
            ax.plot(idxs, self.result.walkers[:, :, i].T, lw=0.1, color='k')
            ax.set_ylabel(self.result.parameter_labels[i])

        fig.tight_layout()
        if save:
            filename = '{}/{}_walkers.png'.format(self.outdir, self.label)
            logging.info('Saving walkers plot to {}'.format(filename))
            fig.savefig(filename)
            # Release the figure once it is on disk.
            plt.close(fig)

402

403
def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
                sampler='nestle', use_ratio=True, injection_parameters=None,
                conversion_function=None, **kwargs):
    """
    The primary interface to easy parameter estimation

    Parameters
    ----------
    likelihood: `tupak.likelihood.Likelihood`
        A `Likelihood` instance
    priors: dict
        A dictionary of the priors for each parameter - missing parameters will
        use default priors, if None, all priors will be default
    label: str
        Name for the run, used in output files
    outdir: str
        A string used in defining output files
    sampler: str
        The name of the sampler to use - see
        `tupak.sampler.get_implemented_samplers()` for a list of available
        samplers
    use_ratio: bool (True)
        If True, use the likelihood's loglikelihood_ratio, rather than just
        the log likelihood.
    injection_parameters: dict
        A dictionary of injection parameters used in creating the data (if
        using simulated data). Appended to the result object and saved.
    conversion_function: function, optional
        Function to apply to posterior to generate additional parameters.
    **kwargs:
        All kwargs are passed to the sampler class, which forwards them to
        the external sampler

    Returns
    -------
    result
        An object containing the results (the cached result if a matching
        one was found on disk)

    Raises
    ------
    ValueError
        If `sampler` does not name an implemented sampler class
    """

    utils.check_directory_exists_and_if_not_mkdir(outdir)
    implemented_samplers = get_implemented_samplers()

    if priors is None:
        priors = dict()
    priors = fill_priors(priors, likelihood, parameters=likelihood.non_standard_sampling_parameter_keys)
    tupak.prior.write_priors_to_file(priors, outdir)

    if sampler.title() not in implemented_samplers:
        raise ValueError(
            "Sampler {} not yet implemented".format(sampler))

    # The sampler classes of this module are named after the title-cased
    # package name, e.g. 'nestle' -> Nestle.
    sampler_class = globals()[sampler.title()]
    sampler = sampler_class(likelihood, priors, sampler, outdir=outdir,
                            label=label, use_ratio=use_ratio,
                            **kwargs)

    if sampler.cached_result:
        logging.info("Using cached result")
        return sampler.cached_result

    result = sampler.run_sampler()
    result.noise_logz = likelihood.noise_log_likelihood()
    # With use_ratio the sampler integrates the likelihood ratio, so its
    # evidence is the Bayes factor and the noise evidence is added back on.
    if use_ratio:
        result.log_bayes_factor = result.logz
        result.logz = result.log_bayes_factor + result.noise_logz
    else:
        result.log_bayes_factor = result.logz - result.noise_logz
    if injection_parameters is not None:
        result.injection_parameters = injection_parameters
        tupak.conversion.generate_all_bbh_parameters(result.injection_parameters)
    # Test the prior object, not its (string) key - testing the key always
    # produced an empty list.
    result.fixed_parameter_keys = [key for key in priors if isinstance(priors[key], prior.DeltaFunction)]
    # result.prior = prior  # Removed as this breaks the saving of the data
    result.samples_to_data_frame(likelihood=likelihood, priors=priors, conversion_function=conversion_function)
    result.kwargs = sampler.kwargs
    result.save_to_file(outdir=outdir, label=label)
    return result
479 480 481 482 483 484 485 486


def get_implemented_samplers():
    """ List the names of the sampler classes defined in this module

    Only strict subclasses of `Sampler` are reported; the base class itself
    and unrelated classes that merely live in the module namespace (e.g.
    imported helpers) are excluded.

    Returns
    -------
    implemented_samplers: list
        Class names, in the alphabetical order given by `inspect.getmembers`
    """
    module = sys.modules[__name__]
    return [obj.__name__
            for _, obj in inspect.getmembers(module, inspect.isclass)
            if issubclass(obj, Sampler) and obj is not Sampler]