diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py
index 3d43d7352f04f5cce79ce8f452bd7c753a015da4..67bd076482fda6c26f32d48f95d47f8800265a36 100644
--- a/bilby/bilby_mcmc/sampler.py
+++ b/bilby/bilby_mcmc/sampler.py
@@ -6,6 +6,7 @@ from pathlib import Path
 
 import numpy as np
 import pandas as pd
+from scipy.optimize import differential_evolution
 
 from ..core.result import rejection_sample
 from ..core.sampler.base_sampler import (
@@ -114,6 +115,13 @@ class Bilby_MCMC(MCMCSampler):
     evidence_method: str, [stepping_stone, thermodynamic]
         The evidence calculation method to use. Defaults to stepping_stone, but
         the results of all available methods are stored in the ln_z_dict.
+    initial_sample_method: str
+        Method to draw the initial sample. Either "prior" (a random draw
+        from the prior) or "maximize" (use an optimization approach to attempt
+        to find the maximum posterior estimate).
+    initial_sample_dict: dict
+        A dictionary of (possibly partial) initial sample values; these
+        overwrite the corresponding values drawn using initial_sample_method.
     verbose: bool
         Whether to print diagnostic output during the run.
 
@@ -144,6 +152,8 @@ class Bilby_MCMC(MCMCSampler):
         fixed_tau=None,
         tau_window=None,
         evidence_method="stepping_stone",
+        initial_sample_method="prior",
+        initial_sample_dict=None,
     )
 
     def __init__(
@@ -188,6 +198,8 @@ class Bilby_MCMC(MCMCSampler):
         self.proposal_cycle = self.kwargs["proposal_cycle"]
         self.pt_rejection_sample = self.kwargs["pt_rejection_sample"]
         self.evidence_method = self.kwargs["evidence_method"]
+        self.initial_sample_method = self.kwargs["initial_sample_method"]
+        self.initial_sample_dict = self.kwargs["initial_sample_dict"]
         self.printdt = self.kwargs["printdt"]
 
         check_directory_exists_and_if_not_mkdir(self.outdir)
@@ -287,6 +299,8 @@ class Bilby_MCMC(MCMCSampler):
             pool=self.pool,
             use_ratio=self.use_ratio,
             evidence_method=self.evidence_method,
+            initial_sample_method=self.initial_sample_method,
+            initial_sample_dict=self.initial_sample_dict,
         )
 
     def get_setup_string(self):
@@ -522,9 +536,13 @@ class BilbyPTMCMCSampler(object):
         pool,
         use_ratio,
         evidence_method,
+        initial_sample_method,
+        initial_sample_dict,
     ):
         self.set_pt_inputs(pt_inputs)
         self.use_ratio = use_ratio
+        self.initial_sample_method = initial_sample_method
+        self.initial_sample_dict = initial_sample_dict
         self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
         self.set_convergence_inputs(convergence_inputs)
         self.pt_rejection_sample = pt_rejection_sample
@@ -572,10 +590,12 @@ class BilbyPTMCMCSampler(object):
         betas = self.get_initial_betas()
         logger.info(
             f"Initializing BilbyPTMCMCSampler with:"
-            f"ntemps={self.ntemps},"
-            f"nensemble={self.nensemble},"
-            f"pt_ensemble={self.pt_ensemble},"
-            f"initial_betas={betas}\n"
+            f"ntemps={self.ntemps}, "
+            f"nensemble={self.nensemble}, "
+            f"pt_ensemble={self.pt_ensemble}, "
+            f"initial_betas={betas}, "
+            f"initial_sample_method={self.initial_sample_method}, "
+            f"initial_sample_dict={self.initial_sample_dict}\n"
         )
         self.sampler_dictionary = dict()
         for Tindex, beta in enumerate(betas):
@@ -591,6 +611,8 @@ class BilbyPTMCMCSampler(object):
                 convergence_inputs=convergence_inputs,
                 proposal_cycle=proposal_cycle,
                 use_ratio=self.use_ratio,
+                initial_sample_method=self.initial_sample_method,
+                initial_sample_dict=self.initial_sample_dict,
             )
             for Eindex in range(n)
         ]
@@ -1077,6 +1099,8 @@ class BilbyMCMCSampler(object):
         Tindex=0,
         Eindex=0,
         use_ratio=False,
+        initial_sample_method="prior",
+        initial_sample_dict=None,
     ):
         self.beta = beta
         self.Tindex = Tindex
@@ -1086,12 +1110,24 @@ class BilbyMCMCSampler(object):
         self.parameters = _sampling_convenience_dump.priors.non_fixed_keys
         self.ndim = len(self.parameters)
 
-        full_sample_dict = _sampling_convenience_dump.priors.sample()
-        initial_sample = {
-            k: v
-            for k, v in full_sample_dict.items()
-            if k in _sampling_convenience_dump.priors.non_fixed_keys
-        }
+        if initial_sample_method.lower() == "prior":
+            full_sample_dict = _sampling_convenience_dump.priors.sample()
+            initial_sample = {
+                k: v
+                for k, v in full_sample_dict.items()
+                if k in _sampling_convenience_dump.priors.non_fixed_keys
+            }
+        elif initial_sample_method.lower() in ["maximize", "maximise", "maximum"]:
+            initial_sample = get_initial_maximum_posterior_sample(self.beta)
+        else:
+            raise ValueError(
+                f"initial sample method {initial_sample_method} not understood"
+            )
+
+        if initial_sample_dict is not None:
+            initial_sample.update(initial_sample_dict)
+
+        logger.info(f"Using initial sample {initial_sample}")
         initial_sample = Sample(initial_sample)
         initial_sample[LOGLKEY] = self.log_likelihood(initial_sample)
         initial_sample[LOGPKEY] = self.log_prior(initial_sample)
@@ -1266,6 +1302,42 @@ class BilbyMCMCSampler(object):
         return samples
 
 
+def get_initial_maximum_posterior_sample(beta):
+    """Attempt to find the maximum (tempered) posterior sample
+
+    This uses scipy's differential evolution, initialized with a Sobol
+    sequence over the prior bounds, to reduce sensitivity to local optima.
+
+    """
+    logger.info("Finding initial maximum posterior estimate")
+    likelihood = _sampling_convenience_dump.likelihood
+    priors = _sampling_convenience_dump.priors
+    search_parameter_keys = _sampling_convenience_dump.search_parameter_keys
+
+    bounds = []
+    for key in search_parameter_keys:
+        bounds.append((priors[key].minimum, priors[key].maximum))
+
+    def neg_log_post(x):
+        sample = {key: val for key, val in zip(search_parameter_keys, x)}
+        ln_prior = priors.ln_prob(sample)
+
+        if np.isinf(ln_prior):
+            return np.inf  # zero prior support: infinitely bad for the minimizer
+
+        likelihood.parameters.update(sample)
+
+        return -beta * likelihood.log_likelihood() - ln_prior
+
+    res = differential_evolution(neg_log_post, bounds, popsize=100, init="sobol")
+    if res.success:
+        sample = {key: val for key, val in zip(search_parameter_keys, res.x)}
+        logger.info(f"Initial maximum posterior estimate {sample}")
+        return sample
+    else:
+        raise ValueError("Failed to find initial maximum posterior estimate")
+
+
 # Methods used to aid parallelisation:
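Usage note: a minimal sketch of how the new kwargs flow through `bilby.run_sampler`. The Gaussian likelihood and uniform prior below are illustrative placeholders, not part of this patch; only `initial_sample_method` and `initial_sample_dict` are new.

```python
import bilby
import numpy as np

def model(x, mu):
    # Constant-mean model; bilby infers parameter names from this signature
    return mu * np.ones_like(x)

priors = dict(mu=bilby.core.prior.Uniform(-5, 5, "mu"))
likelihood = bilby.core.likelihood.GaussianLikelihood(
    x=np.linspace(0, 1, 100),
    y=np.random.normal(1.0, 0.5, 100),
    func=model,
    sigma=0.5,
)

result = bilby.run_sampler(
    likelihood=likelihood,
    priors=priors,
    sampler="bilby_mcmc",
    nsamples=500,
    initial_sample_method="maximize",  # optimize for the starting point
    initial_sample_dict=dict(mu=0.0),  # then pin individual parameters
)
```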
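The override applied by `initial_sample.update(initial_sample_dict)` is a plain dict update, so a partial dictionary only replaces the keys it names (values below are illustrative):

```python
initial_sample = {"mu": 2.7, "sigma": 0.9}  # drawn via initial_sample_method
initial_sample.update({"mu": 3.0})          # user-supplied initial_sample_dict
# -> {"mu": 3.0, "sigma": 0.9}
```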
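For reference, a self-contained sketch of the optimization strategy behind `get_initial_maximum_posterior_sample`: minimize the negative tempered log posterior with `scipy.optimize.differential_evolution`, seeding the population with a Sobol sequence over the prior bounds (`init="sobol"` requires scipy >= 1.7). The toy one-parameter posterior here is an assumption for illustration only.

```python
import numpy as np
from scipy.optimize import differential_evolution

beta = 1.0                # inverse temperature; beta < 1 flattens the likelihood
bounds = [(-10.0, 10.0)]  # per-parameter (minimum, maximum) prior bounds

def toy_neg_log_post(x):
    (mu,) = x
    ln_prior = -np.log(20.0)                 # Uniform(-10, 10)
    ln_likelihood = -0.5 * (mu - 3.0) ** 2   # Gaussian likelihood centred at 3
    # Minimizing the negative tempered log posterior; as in the patch, a point
    # outside the prior support would return +inf so it is never selected.
    return -beta * ln_likelihood - ln_prior

# init="sobol" spreads the initial population quasi-uniformly over the bounds,
# reducing the chance of converging to a local optimum.
res = differential_evolution(toy_neg_log_post, bounds, popsize=100, init="sobol")
print(res.x)  # ~[3.0], the maximum posterior point
```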