diff --git a/.dictionary.txt b/.dictionary.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b8adc596cdc293dfd4d158aebeb908adf705139 --- /dev/null +++ b/.dictionary.txt @@ -0,0 +1,2 @@ +hist +livetime diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 584f8a6395f249eb1ea039ff8fff466e968b99ef..0d7ff4d82f43c8bdb696ab211490d216a91aa682 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,19 +4,27 @@ repos: hooks: - id: check-merge-conflict # prevent committing files with merge conflicts - id: flake8 # checks for flake8 errors -#- repo: https://github.com/codespell-project/codespell -# rev: v1.16.0 -# hooks: -# - id: codespell # Spellchecker -# args: [-L, nd, --skip, "*.ipynb,*.html", --ignore-words=.dictionary.txt] -# exclude: ^examples/tutorials/ -#- repo: https://github.com/asottile/seed-isort-config -# rev: v1.3.0 -# hooks: -# - id: seed-isort-config -# args: [--application-directories, 'bilby/'] -#- repo: https://github.com/pre-commit/mirrors-isort -# rev: v4.3.21 -# hooks: -# - id: isort # sort imports alphabetically and separates import into sections -# args: [-w=88, -m=3, -tc, -sp=setup.cfg ] +- repo: https://github.com/psf/black + rev: 20.8b1 + hooks: + - id: black + language_version: python3 + files: ^bilby/bilby_mcmc/ +- repo: https://github.com/codespell-project/codespell + rev: v1.16.0 + hooks: + - id: codespell + args: [--ignore-words=.dictionary.txt] + files: ^bilby/bilby_mcmc/ +- repo: https://github.com/asottile/seed-isort-config + rev: v1.3.0 + hooks: + - id: seed-isort-config + args: [--application-directories, 'bilby/'] + files: ^bilby/bilby_mcmc/ +- repo: https://github.com/pre-commit/mirrors-isort + rev: v4.3.21 + hooks: + - id: isort # sort imports alphabetically and separates import into sections + args: [-w=88, -m=3, -tc, -sp=setup.cfg ] + files: ^bilby/bilby_mcmc/ diff --git a/bilby/bilby_mcmc/__init__.py b/bilby/bilby_mcmc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf49adb19d17adfc5553b0e5c06e46e7123c554 --- /dev/null +++ b/bilby/bilby_mcmc/__init__.py @@ -0,0 +1 @@ +from .sampler import Bilby_MCMC diff --git a/bilby/bilby_mcmc/chain.py b/bilby/bilby_mcmc/chain.py new file mode 100644 index 0000000000000000000000000000000000000000..43c47ef2ea0616ca0ecbe54525e0c5e36134a4cd --- /dev/null +++ b/bilby/bilby_mcmc/chain.py @@ -0,0 +1,526 @@ +from distutils.version import LooseVersion + +import numpy as np +import pandas as pd + +from ..core.sampler.base_sampler import SamplerError +from ..core.utils import logger +from .utils import LOGLKEY, LOGLLATEXKEY, LOGPKEY, LOGPLATEXKEY + + +class Chain(object): + def __init__( + self, + initial_sample, + burn_in_nact=1, + thin_by_nact=1, + fixed_discard=0, + autocorr_c=5, + min_tau=1, + fixed_tau=None, + tau_window=None, + block_length=100000, + ): + """Object to store a single mcmc chain + + Parameters + ---------- + initial_sample: bilby.bilby_mcmc.chain.Sample + The starting point of the chain + burn_in_nact, thin_by_nact : int (1, 1) + The number of autocorrelation times (tau) to discard for burn-in + and the multiplicative factor to thin by (thin_by_nact < 1). I.e + burn_in_nact=10 and thin_by_nact=1 will discard 10*tau samples from + the start of the chain, then thin the final chain by a factor + of 1*tau (resulting in independent samples). + fixed_discard: int (0) + A fixed minimum number of samples to discard (can be used to + override the burn_in_nact if it is too small). 
+ autocorr_c: float (5) + The step size of the window search used by emcee.autocorr when + estimating the autocorrelation time. + min_tau: int (1) + A minimum value for the autocorrelation time. + fixed_tau: int (None) + A fixed value for the autocorrelation (overrides the automated + autocorrelation time estimation). Used in testing. + tau_window: int (None) + Only calculate the autocorrelation time in a trailing window. If + None (default) this method is not used. + block_length: int + The incremental size to extend the array by when it runs out of + space. + """ + self.autocorr_c = autocorr_c + self.min_tau = min_tau + self.burn_in_nact = burn_in_nact + self.thin_by_nact = thin_by_nact + self.block_length = block_length + self.fixed_discard = int(fixed_discard) + self.fixed_tau = fixed_tau + self.tau_window = tau_window + + self.ndim = initial_sample.ndim + self.current_sample = initial_sample + self.keys = self.current_sample.keys + self.parameter_keys = self.current_sample.parameter_keys + + # Initialize chain + self._chain_array = self._get_zero_chain_array() + self._chain_array_length = block_length + self.position = -1 + self.max_log_likelihood = -np.inf + self.max_tau_dict = {} + self.converged = False + self.cached_tau_count = 0 + self._minimum_index_proposal = 0 + self._minimum_index_adapt = 0 + self._last_minimum_index = (0, 0, "I") + self.last_full_tau_dict = {key: np.inf for key in self.parameter_keys} + + # Append the initial sample + self.append(self.current_sample) + + def _get_zero_chain_array(self): + return np.zeros((self.block_length, self.ndim + 2), dtype=np.float64) + + def _extend_chain_array(self): + self._chain_array = np.concatenate( + (self._chain_array, self._get_zero_chain_array()), axis=0 + ) + self._chain_array_length = len(self._chain_array) + + @property + def current_sample(self): + return self._current_sample.copy() + + @current_sample.setter + def current_sample(self, current_sample): + self._current_sample = current_sample + + def append(self, sample): + self.position += 1 + + # Extend the array if needed + if self.position >= self._chain_array_length: + self._extend_chain_array() + + # Store the current sample and append to the array + self.current_sample = sample + self._chain_array[self.position] = sample.list + + # Update the maximum log_likelihood + if sample[LOGLKEY] > self.max_log_likelihood: + self.max_log_likelihood = sample[LOGLKEY] + + def __getitem__(self, index): + if index < 0: + index = index + self.position + 1 + + if index <= self.position: + values = self._chain_array[index] + return Sample({k: v for k, v in zip(self.keys, values)}) + else: + raise SamplerError(f"Requested index {index} out of bounds") + + def __setitem__(self, index, sample): + if index < 0: + index = index + self.position + 1 + + self._chain_array[index] = sample.list + + def key_to_idx(self, key): + return self.keys.index(key) + + def get_1d_array(self, key): + return self._chain_array[: 1 + self.position, self.key_to_idx(key)] + + @property + def _random_idx(self): + mindex = self._last_minimum_index[1] + # Check if mindex exceeds current position by 10 ACT: if so use a random sample + # otherwise we draw only from the chain past the minimum_index + if np.isinf(self.tau_last) or self.position - mindex < 10 * self.tau_last: + mindex = 0 + return np.random.randint(mindex, self.position + 1) + + @property + def random_sample(self): + return self[self._random_idx] + + @property + def fixed_discard(self): + return self._fixed_discard + + @fixed_discard.setter + def 
fixed_discard(self, fixed_discard):
+        self._fixed_discard = int(fixed_discard)
+
+    @property
+    def minimum_index(self):
+        """This calculates a minimum index from which to discard samples
+
+        A number of methods are provided for the calculation; a subset is
+        switched off (by `if False` statements) for future development.
+
+        """
+        position = self.position
+
+        # Return cached minimum index
+        last_minimum_index = self._last_minimum_index
+        if position == last_minimum_index[0]:
+            return int(last_minimum_index[1])
+
+        # If fixed discard is not yet reached, just return that
+        if position < self.fixed_discard:
+            self.minimum_index_method = "FD"
+            return self.fixed_discard
+
+        # Initialize list of minimum index methods with the fixed discard (FD)
+        minimum_index_list = [self.fixed_discard]
+        minimum_index_method_list = ["FD"]
+
+        # Calculate minimum index from tau
+        if self.tau_last < np.inf:
+            tau = self.tau_last
+        elif len(self.max_tau_dict) == 0:
+            # Bootstrap calculating tau when minimum index has not yet been calculated
+            tau = self._tau_for_full_chain
+        else:
+            tau = np.inf
+
+        if tau < np.inf:
+            minimum_index_list.append(self.burn_in_nact * tau)
+            minimum_index_method_list.append(f"{self.burn_in_nact}tau")
+
+        # Calculate points when log-posterior is within z std of the mean
+        if True:
+            zfactor = 1
+            N = 100
+            delta_lnP = zfactor * self.ndim / 2
+            logl = self.get_1d_array(LOGLKEY)
+            log_prior = self.get_1d_array(LOGPKEY)
+            log_posterior = logl + log_prior
+            max_posterior = np.max(log_posterior)
+
+            ave = pd.Series(log_posterior).rolling(window=N).mean().iloc[N - 1 :]
+            delta = max_posterior - ave
+            passes = ave[delta < delta_lnP]
+            if len(passes) > 0:
+                minimum_index_list.append(passes.index[0] + 1)
+                minimum_index_method_list.append(f"z{zfactor}")
+
+        # Add last minimum_index_method
+        if False:
+            minimum_index_list.append(last_minimum_index[1])
+            minimum_index_method_list.append(last_minimum_index[2])
+
+        # Minimum index set by proposals
+        minimum_index_list.append(self.minimum_index_proposal)
+        minimum_index_method_list.append("PR")
+
+        # Minimum index set by temperature adaptation
+        minimum_index_list.append(self.minimum_index_adapt)
+        minimum_index_method_list.append("AD")
+
+        # Calculate the maximum minimum index and associated method (reporting)
+        minimum_index = int(np.max(minimum_index_list))
+        minimum_index_method = minimum_index_method_list[np.argmax(minimum_index_list)]
+
+        # Cache the method
+        self._last_minimum_index = (position, minimum_index, minimum_index_method)
+        self.minimum_index_method = minimum_index_method
+
+        return minimum_index
+
+    @property
+    def minimum_index_proposal(self):
+        return self._minimum_index_proposal
+
+    @minimum_index_proposal.setter
+    def minimum_index_proposal(self, minimum_index_proposal):
+        if minimum_index_proposal > self._minimum_index_proposal:
+            self._minimum_index_proposal = minimum_index_proposal
+
+    @property
+    def minimum_index_adapt(self):
+        return self._minimum_index_adapt
+
+    @minimum_index_adapt.setter
+    def minimum_index_adapt(self, minimum_index_adapt):
+        if minimum_index_adapt > self._minimum_index_adapt:
+            self._minimum_index_adapt = minimum_index_adapt
+
+    @property
+    def tau(self):
+        """ The maximum ACT over all parameters """
+
+        if self.position in self.max_tau_dict:
+            # If we have the ACT at the current position, return it
+            return self.max_tau_dict[self.position]
+        elif (
+            self.tau_last < np.inf
+            and self.cached_tau_count < 50
+            and self.nsamples_last > 50
+        ):
+            # If we have a recent ACT return it
+            
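# (the cached value is reused for at most 50 calls, after which
+            # tau_nocache forces a recalculation and resets the counter)
+            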
self.cached_tau_count += 1 + return self.tau_last + else: + # Calculate the ACT + return self.tau_nocache + + @property + def tau_nocache(self): + """ Calculate tau forcing a recalculation (no cached tau) """ + tau = max(self.tau_dict.values()) + self.max_tau_dict[self.position] = tau + self.cached_tau_count = 0 + return tau + + @property + def tau_last(self): + """ Return the last-calculated tau if it exists, else inf """ + if len(self.max_tau_dict) > 0: + return list(self.max_tau_dict.values())[-1] + else: + return np.inf + + @property + def _tau_for_full_chain(self): + """ The maximum ACT over all parameters """ + return max(self._tau_dict_for_full_chain.values()) + + @property + def _tau_dict_for_full_chain(self): + return self._calculate_tau_dict(minimum_index=0) + + @property + def tau_dict(self): + """ Calculate a dictionary of tau (ACT) for every parameter """ + return self._calculate_tau_dict(self.minimum_index) + + def _calculate_tau_dict(self, minimum_index): + """ Calculate a dictionary of tau (ACT) for every parameter """ + logger.debug(f"Calculating tau_dict {self}") + + # If there are too few samples to calculate tau + if (self.position - minimum_index) < 2 * self.autocorr_c: + return {key: np.inf for key in self.parameter_keys} + + # Choose minimimum index for the ACT calculation + last_tau = self.tau_last + if self.tau_window is not None and last_tau < np.inf: + minimum_index_for_act = max( + minimum_index, int(self.position - self.tau_window * last_tau) + ) + else: + minimum_index_for_act = minimum_index + + # Calculate a dictionary of tau's for each parameter + taus = {} + for key in self.parameter_keys: + if self.fixed_tau is None: + x = self.get_1d_array(key)[minimum_index_for_act:] + tau = calculate_tau(x, self.autocorr_c) + taux = round(tau, 1) + else: + taux = self.fixed_tau + taus[key] = max(taux, self.min_tau) + + # Cache the last tau dictionary for future use + self.last_full_tau_dict = taus + + return taus + + @property + def thin(self): + if np.isfinite(self.tau): + return np.max([1, int(self.thin_by_nact * self.tau)]) + else: + return 1 + + @property + def nsamples(self): + nuseable_steps = self.position - self.minimum_index + return int(nuseable_steps / (self.thin_by_nact * self.tau)) + + @property + def nsamples_last(self): + nuseable_steps = self.position - self.minimum_index + return int(nuseable_steps / (self.thin_by_nact * self.tau_last)) + + @property + def samples(self): + samples = self._chain_array[self.minimum_index : self.position : self.thin] + return pd.DataFrame(samples, columns=self.keys) + + def plot(self, outdir=".", label="label", priors=None, all_samples=None): + import matplotlib.pyplot as plt + + fig, axes = plt.subplots( + nrows=self.ndim + 3, ncols=2, figsize=(8, 9 + 3 * (self.ndim)) + ) + scatter_kwargs = dict( + lw=0, + marker="o", + ) + K = 1000 + + nburn = self.minimum_index + plot_setups = zip( + [0, nburn, nburn], + [nburn, self.position, self.position], + [1, 1, self.thin], # Thin-by factor + ["tab:red", "tab:grey", "tab:blue"], # Color + [0.5, 0.05, 0.5], # Alpha + [1, 1, 1], # Marker size + ) + + position_indexes = np.arange(self.position + 1) + + # Plot the traceplots + for (start, stop, thin, color, alpha, ms) in plot_setups: + for ax, key in zip(axes[:, 0], self.keys): + xx = position_indexes[start:stop:thin] / K + yy = self.get_1d_array(key)[start:stop:thin] + + # Downsample plots to max_pts: avoid memory issues + max_pts = 10000 + while len(xx) > max_pts: + xx = xx[::2] + yy = yy[::2] + + ax.plot( + xx, + yy, + 
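# markers only (lw=0): red shows the discarded burn-in segment,
+                    # grey the full post-burn-in chain, blue the thinned samples
+                    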
color=color, + alpha=alpha, + ms=ms, + **scatter_kwargs, + ) + ax.set_ylabel(self._get_plot_label_by_key(key, priors)) + if key not in [LOGLKEY, LOGPKEY]: + msg = r"$\tau=$" + f"{self.last_full_tau_dict[key]:0.1f}" + ax.set_title(msg) + + # Plot the histograms + for ax, key in zip(axes[:, 1], self.keys): + yy_all = all_samples[key] + ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k") + + yy = self.get_1d_array(key)[nburn : self.position : self.thin] + ax.hist(yy, bins=50, alpha=0.8, density=True) + + ax.set_xlabel(self._get_plot_label_by_key(key, priors)) + + # Add x-axes labels to the traceplots + axes[-1, 0].set_xlabel(r"Iteration $[\times 10^{3}]$") + + # Plot the calculated ACT + ax = axes[-1, 0] + tausit = np.array(list(self.max_tau_dict.keys()) + [self.position]) / K + taus = list(self.max_tau_dict.values()) + [self.tau_last] + ax.plot(tausit, taus, color="C3") + ax.set(ylabel=r"Maximum $\tau$") + + axes[-1, 1].set_axis_off() + + filename = "{}/{}_checkpoint_trace.png".format(outdir, label) + msg = [ + r"Maximum $\tau$" + f"={self.tau:0.1f} ", + r"$n_{\rm samples}=$" + f"{self.nsamples} ", + ] + if self.thin_by_nact != 1: + msg += [ + r"$n_{\rm samples}^{\rm eff}=$" + + f"{int(self.nsamples * self.thin_by_nact)} " + ] + fig.suptitle( + "| ".join(msg), + y=1, + ) + fig.tight_layout() + fig.savefig(filename, dpi=200) + plt.close(fig) + + @staticmethod + def _get_plot_label_by_key(key, priors=None): + if priors is not None and key in priors: + return priors[key].latex_label + elif key == LOGLKEY: + return LOGLLATEXKEY + elif key == LOGPKEY: + return LOGPLATEXKEY + else: + return key + + +class Sample(object): + def __init__(self, sample_dict): + """A single sample + + Parameters + ---------- + sample_dict: dict + A dictionary of the sample + """ + + self.sample_dict = sample_dict + self.keys = list(sample_dict.keys()) + self.parameter_keys = [k for k in self.keys if k not in [LOGPKEY, LOGLKEY]] + self.ndim = len(self.parameter_keys) + + def __getitem__(self, key): + return self.sample_dict[key] + + def __setitem__(self, key, value): + self.sample_dict[key] = value + if key not in self.keys: + self.keys = list(self.sample_dict.keys()) + + @property + def list(self): + return list(self.sample_dict.values()) + + def __repr__(self): + return str(self.sample_dict) + + @property + def parameter_only_dict(self): + return {key: self.sample_dict[key] for key in self.parameter_keys} + + @property + def dict(self): + return {key: self.sample_dict[key] for key in self.keys} + + def as_dict(self, keys=None): + sdict = self.dict + if keys is None: + return sdict + else: + return {key: sdict[key] for key in keys} + + def __eq__(self, other_sample): + return self.list == other_sample.list + + def copy(self): + return Sample(self.sample_dict.copy()) + + +def calculate_tau(x, autocorr_c=5): + import emcee + + if LooseVersion(emcee.__version__) < LooseVersion("3"): + raise SamplerError("bilby-mcmc requires emcee > 3.0 for autocorr analysis") + + if np.all(np.diff(x) == 0): + return np.inf + try: + # Hard code tol=1: we perform this check internally + tau = emcee.autocorr.integrated_time(x, c=autocorr_c, tol=1)[0] + if np.isnan(tau): + tau = np.inf + return tau + except emcee.autocorr.AutocorrError: + return np.inf diff --git a/bilby/bilby_mcmc/flows.py b/bilby/bilby_mcmc/flows.py new file mode 100644 index 0000000000000000000000000000000000000000..5fbaf196b38859660307554e8bc371ce0b03edf8 --- /dev/null +++ b/bilby/bilby_mcmc/flows.py @@ -0,0 +1,100 @@ +import torch +from 
nflows.distributions.normal import StandardNormal +from nflows.flows.base import Flow +from nflows.nn import nets as nets +from nflows.transforms import ( + CompositeTransform, + MaskedAffineAutoregressiveTransform, + RandomPermutation, +) +from nflows.transforms.coupling import ( + AdditiveCouplingTransform, + AffineCouplingTransform, +) +from nflows.transforms.normalization import BatchNorm +from torch.nn import functional as F + +# Turn off parallelism +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +class NVPFlow(Flow): + """A simplified version of Real NVP for 1-dim inputs. + + This implementation uses 1-dim checkerboard masking but doesn't use + multi-scaling. + Reference: + > L. Dinh et al., Density estimation using Real NVP, ICLR 2017. + + This class has been modified from the example found at: + https://github.com/bayesiains/nflows/blob/master/nflows/flows/realnvp.py + """ + + def __init__( + self, + features, + hidden_features, + num_layers, + num_blocks_per_layer, + use_volume_preserving=False, + activation=F.relu, + dropout_probability=0.0, + batch_norm_within_layers=False, + batch_norm_between_layers=False, + random_permutation=True, + ): + + if use_volume_preserving: + coupling_constructor = AdditiveCouplingTransform + else: + coupling_constructor = AffineCouplingTransform + + mask = torch.ones(features) + mask[::2] = -1 + + def create_resnet(in_features, out_features): + return nets.ResidualNet( + in_features, + out_features, + hidden_features=hidden_features, + num_blocks=num_blocks_per_layer, + activation=activation, + dropout_probability=dropout_probability, + use_batch_norm=batch_norm_within_layers, + ) + + layers = [] + for _ in range(num_layers): + transform = coupling_constructor( + mask=mask, transform_net_create_fn=create_resnet + ) + layers.append(transform) + mask *= -1 + if batch_norm_between_layers: + layers.append(BatchNorm(features=features)) + + if random_permutation: + layers.append(RandomPermutation(features=features)) + + super().__init__( + transform=CompositeTransform(layers), + distribution=StandardNormal([features]), + ) + + +class BasicFlow(Flow): + def __init__(self, features): + transform = CompositeTransform( + [ + MaskedAffineAutoregressiveTransform( + features=features, hidden_features=2 * features + ), + RandomPermutation(features=features), + ] + ) + distribution = StandardNormal(shape=[features]) + super().__init__( + transform=transform, + distribution=distribution, + ) diff --git a/bilby/bilby_mcmc/proposals.py b/bilby/bilby_mcmc/proposals.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1d33bdf1b2e6d92802fe34be4c86e2976dd8a5 --- /dev/null +++ b/bilby/bilby_mcmc/proposals.py @@ -0,0 +1,1110 @@ +import importlib +import time +from abc import ABCMeta, abstractmethod + +import numpy as np +from scipy.spatial.distance import jensenshannon +from scipy.stats import gaussian_kde + +from ..core.prior import PriorDict +from ..core.sampler.base_sampler import SamplerError +from ..core.utils import logger, reflect +from ..gw.source import PARAMETER_SETS + + +class ProposalCycle(object): + def __init__(self, proposal_list): + self.proposal_list = proposal_list + self.weights = [prop.weight for prop in self.proposal_list] + self.normalized_weights = [w / sum(self.weights) for w in self.weights] + self.weighted_proposal_list = [ + np.random.choice(self.proposal_list, p=self.normalized_weights) + for _ in range(10 * int(1 / min(self.normalized_weights))) + ] + self.nproposals = len(self.weighted_proposal_list) + 
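# The cycle is realised as a fixed list pre-drawn at random in
+        # proportion to the weights; get_proposal simply walks this list,
+        # wrapping around via the position setter below
+        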
self._position = 0 + + @property + def position(self): + return self._position + + @position.setter + def position(self, position): + self._position = np.mod(position, self.nproposals) + + def get_proposal(self): + prop = self.weighted_proposal_list[self._position] + self.position += 1 + return prop + + def __str__(self): + string = "ProposalCycle:\n" + for prop in self.proposal_list: + string += f" {prop}\n" + return string + + +class BaseProposal(object): + _accepted = 0 + _rejected = 0 + __metaclass__ = ABCMeta + + def __init__(self, priors, weight=1, subset=None): + self._str_attrs = ["acceptance_ratio", "n"] + + self.parameters = priors.non_fixed_keys + self.weight = weight + self.subset = subset + + # Restrict to a subset + if self.subset is not None: + self.parameters = [p for p in self.parameters if p in subset] + self._str_attrs.append("parameters") + + self.ndim = len(self.parameters) + + self.prior_boundary_dict = {key: priors[key].boundary for key in priors} + self.prior_minimum_dict = {key: np.max(priors[key].minimum) for key in priors} + self.prior_maximum_dict = {key: np.min(priors[key].maximum) for key in priors} + self.prior_width_dict = {key: np.max(priors[key].width) for key in priors} + + @property + def accepted(self): + return self._accepted + + @accepted.setter + def accepted(self, accepted): + self._accepted = accepted + + @property + def rejected(self): + return self._rejected + + @rejected.setter + def rejected(self, rejected): + self._rejected = rejected + + @property + def acceptance_ratio(self): + if self.n == 0: + return np.nan + else: + return self.accepted / self.n + + @property + def n(self): + return self.accepted + self.rejected + + def __str__(self): + msg = [f"{type(self).__name__}("] + for attr in self._str_attrs: + val = getattr(self, attr, "N/A") + if isinstance(val, (float, int)): + val = f"{val:1.2g}" + msg.append(f"{attr}:{val},") + return "".join(msg) + ")" + + def apply_boundaries(self, point): + for key in self.parameters: + boundary = self.prior_boundary_dict[key] + if boundary is None: + continue + elif boundary == "periodic": + point[key] = self.apply_periodic_boundary(key, point[key]) + elif boundary == "reflective": + point[key] = self.apply_reflective_boundary(key, point[key]) + else: + raise SamplerError(f"Boundary {boundary} not implemented") + return point + + def apply_periodic_boundary(self, key, val): + minimum = self.prior_minimum_dict[key] + width = self.prior_width_dict[key] + return minimum + np.mod(val - minimum, width) + + def apply_reflective_boundary(self, key, val): + minimum = self.prior_minimum_dict[key] + width = self.prior_width_dict[key] + val_normalised = (val - minimum) / width + val_normalised_reflected = reflect(np.array(val_normalised)) + return minimum + width * val_normalised_reflected + + def __call__(self, chain): + sample, log_factor = self.propose(chain) + sample = self.apply_boundaries(sample) + return sample, log_factor + + @abstractmethod + def propose(self, chain): + """Propose a new point + + This method must be overwritten by implemented proposals. The propose + method is called by __call__, then boundaries applied, before returning + the proposed point. + + Parameters + ---------- + chain: bilby.core.sampler.bilby_mcmc.chain.Chain + The chain to use for the proposal + + Returns + ------- + proposal: bilby.core.sampler.bilby_mcmc.Sample + The proposed point + log_factor: float + The natural-log of the additional factor entering the acceptance + probability to ensure detailed balance. 
For symmetric proposals, + a value of 0 should be returned. + """ + pass + + @staticmethod + def check_dependencies(warn=True): + """Check the dependencies required to use the proposal + + Parameters + ---------- + warn: bool + If true, print a warning + + Returns + ------- + check: bool + If true, dependencies exist + """ + return True + + +class FixedGaussianProposal(BaseProposal): + """A proposal using a fixed non-correlated Gaussian distribution + + Parameters + ---------- + priors: bilby.core.prior.PriorDict + The set of priors + weight: float + Weighting factor + subset: list + A list of keys for which to restrict the proposal to (other parameters + will be kept fixed) + sigma: float + The scaling factor for proposals + """ + + def __init__(self, priors, weight=1, subset=None, sigma=0.01): + super(FixedGaussianProposal, self).__init__(priors, weight, subset) + self.sigmas = {} + for key in self.parameters: + if np.isinf(self.prior_width_dict[key]): + self.prior_width_dict[key] = 1 + if isinstance(sigma, float): + self.sigmas[key] = sigma + elif isinstance(sigma, dict): + self.sigmas[key] = sigma[key] + else: + raise SamplerError("FixedGaussianProposal sigma not understood") + + def propose(self, chain): + sample = chain.current_sample + for key in self.parameters: + sigma = self.prior_width_dict[key] * self.sigmas[key] + sample[key] += sigma * np.random.randn() + log_factor = 0 + return sample, log_factor + + +class AdaptiveGaussianProposal(BaseProposal): + def __init__( + self, + priors, + weight=1, + subset=None, + sigma=1, + scale_init=1e0, + stop=1e5, + target_facc=0.234, + ): + super(AdaptiveGaussianProposal, self).__init__(priors, weight, subset) + self.sigmas = {} + for key in self.parameters: + if np.isinf(self.prior_width_dict[key]): + self.prior_width_dict[key] = 1 + if isinstance(sigma, (float, int)): + self.sigmas[key] = sigma + elif isinstance(sigma, dict): + self.sigmas[key] = sigma[key] + else: + raise SamplerError("AdaptiveGaussianProposal sigma not understood") + + self.target_facc = target_facc + self.scale = scale_init + self.stop = stop + self._str_attrs.append("scale") + self._last_accepted = 0 + + def propose(self, chain): + sample = chain.current_sample + self.update_scale(chain) + if np.random.random() < 1e-3: + factor = 1e1 + elif np.random.random() < 1e-4: + factor = 1e2 + else: + factor = 1 + for key in self.parameters: + sigma = factor * self.scale * self.prior_width_dict[key] * self.sigmas[key] + sample[key] += sigma * np.random.randn() + log_factor = 0 + return sample, log_factor + + def update_scale(self, chain): + """ + The adaptation of the scale follows (35)/(36) of https://arxiv.org/abs/1409.7215 + """ + if 0 < self.n < self.stop: + s_gamma = (self.stop / self.n) ** 0.2 - 1 + if self.accepted > self._last_accepted: + self.scale += s_gamma * (1 - self.target_facc) / 100 + else: + self.scale -= s_gamma * self.target_facc / 100 + self._last_accepted = self.accepted + self.scale = max(self.scale, 1 / self.stop) + + +class DifferentialEvolutionProposal(BaseProposal): + """A proposal using Differential Evolution + + Parameters + ---------- + priors: bilby.core.prior.PriorDict + The set of priors + weight: float + Weighting factor + subset: list + A list of keys for which to restrict the proposal to (other parameters + will be kept fixed) + mode_hopping_frac: float + The fraction of proposals which use 'mode hopping' + """ + + def __init__(self, priors, weight=1, subset=None, mode_hopping_frac=0.5): + super(DifferentialEvolutionProposal, 
self).__init__(priors, weight, subset) + self.mode_hopping_frac = mode_hopping_frac + + def propose(self, chain): + theta = chain.current_sample + theta1 = chain.random_sample + theta2 = chain.random_sample + if np.random.rand() > self.mode_hopping_frac: + gamma = 1 + else: + # Base jump size + gamma = np.random.normal(0, 2.38 / np.sqrt(2 * self.ndim)) + # Scale uniformly in log between 0.1 and 10 times + gamma *= np.exp(np.log(0.1) + np.log(100.0) * np.random.rand()) + + for key in self.parameters: + theta[key] += gamma * (theta2[key] - theta1[key]) + + log_factor = 0 + return theta, log_factor + + +class UniformProposal(BaseProposal): + """A proposal using uniform draws from the prior support + + Parameters + ---------- + priors: bilby.core.prior.PriorDict + The set of priors + weight: float + Weighting factor + subset: list + A list of keys for which to restrict the proposal to (other parameters + will be kept fixed) + """ + + def __init__(self, priors, weight=1, subset=None): + super(UniformProposal, self).__init__(priors, weight, subset) + + def propose(self, chain): + sample = chain.current_sample + for key in self.parameters: + sample[key] = np.random.uniform( + self.prior_minimum_dict[key], self.prior_maximum_dict[key] + ) + log_factor = 0 + return sample, log_factor + + +class PriorProposal(BaseProposal): + """A proposal using draws from the prior distribution + + Note: for priors which use interpolation, this proposal can be problematic + as the proposal gets pickled in multiprocessing. Either, use serial + processing (npool=1) or fall back to a UniformProposal. + + Parameters + ---------- + priors: bilby.core.prior.PriorDict + The set of priors + weight: float + Weighting factor + subset: list + A list of keys for which to restrict the proposal to (other parameters + will be kept fixed) + """ + + def __init__(self, priors, weight=1, subset=None): + super(PriorProposal, self).__init__(priors, weight, subset) + self.priors = PriorDict({key: priors[key] for key in self.parameters}) + + def propose(self, chain): + sample = chain.current_sample + lnp_theta = self.priors.ln_prob(sample.as_dict(self.parameters)) + prior_sample = self.priors.sample() + for key in self.parameters: + sample[key] = prior_sample[key] + lnp_thetaprime = self.priors.ln_prob(sample.as_dict(self.parameters)) + log_factor = lnp_theta - lnp_thetaprime + return sample, log_factor + + +_density_estimate_doc = """ A proposal using draws from a {estimator} fit to the chain + +Parameters +---------- +priors: bilby.core.prior.PriorDict + The set of priors +weight: float + Weighting factor +subset: list + A list of keys for which to restrict the proposal to (other parameters + will be kept fixed) +first_fit: int + The number of steps to take before first fitting the KDE +fit_multiplier: int + The multiplier for the next fit +nsamples_for_density: int + The number of samples to use when fitting the KDE +fallback: bilby.core.sampler.bilby_mcmc.proposal.BaseProposal + A proposal to use before first training +scale_fits: int + A scaling factor for both the initial and subsequent updates +""" + + +class DensityEstimateProposal(BaseProposal): + def __init__( + self, + priors, + weight=1, + subset=None, + first_fit=1000, + fit_multiplier=10, + nsamples_for_density=1000, + fallback=AdaptiveGaussianProposal, + scale_fits=1, + ): + super(DensityEstimateProposal, self).__init__(priors, weight, subset) + self.nsamples_for_density = nsamples_for_density + self.fallback = fallback(priors, weight, subset) + self.fit_multiplier = 
fit_multiplier * scale_fits + + # Counters + self.steps_since_refit = 0 + self.next_refit_time = first_fit * scale_fits + self.density = None + self.trained = False + self._str_attrs.append("trained") + + density_name = None + __doc__ = _density_estimate_doc.format(estimator=density_name) + + def _fit(self, dataset): + raise NotImplementedError + + def _evaluate(self, point): + raise NotImplementedError + + def _sample(self, nsamples=None): + raise NotImplementedError + + def refit(self, chain): + current_density = self.density + start = time.time() + + # Draw two (possibly overlapping) data sets for training and verification + dataset = [] + verification_dataset = [] + nsamples_for_density = min(chain.position, self.nsamples_for_density) + for _ in range(nsamples_for_density): + s = chain.random_sample + dataset.append([s[key] for key in self.parameters]) + s = chain.random_sample + verification_dataset.append([s[key] for key in self.parameters]) + + # Fit the density + self.density = self._fit(np.array(dataset).T) + + # Print a log message + took = time.time() - start + logger.info( + f"{self.density_name} construction at {self.steps_since_refit} finished" + f" for length {chain.position} chain, took {took:0.2f}s." + f" Current accept-ratio={self.acceptance_ratio:0.2f}" + ) + + # Reset counters for next training + self.steps_since_refit = 0 + self.next_refit_time *= self.fit_multiplier + + # Verify training hasn't overconstrained + new_draws = np.atleast_2d(self._sample(1000)) + verification_dataset = np.array(verification_dataset) + fail_parameters = [] + for ii, key in enumerate(self.parameters): + std_draws = np.std(new_draws[:, ii]) + std_verification = np.std(verification_dataset[:, ii]) + if std_draws < 0.1 * std_verification: + fail_parameters.append(key) + + if len(fail_parameters) > 0: + logger.info( + f"{self.density_name} construction failed verification and is discarded" + ) + self.density = current_density + else: + self.trained = True + + def propose(self, chain): + self.steps_since_refit += 1 + + # Check if we refit + testA = self.steps_since_refit >= self.next_refit_time + if testA: + self.refit(chain) + + # If KDE is yet to be fitted, use the fallback + if self.trained is False: + return self.fallback.propose(chain) + + # Grab the current sample and it's probability under the KDE + theta = chain.current_sample + ln_p_theta = self._evaluate(list(theta.as_dict(self.parameters).values())) + + # Sample and update theta + new_sample = self._sample(1) + for key, val in zip(self.parameters, new_sample): + theta[key] = val + + # Calculate the probability of the new sample and the KDE + ln_p_thetaprime = self._evaluate(list(theta.as_dict(self.parameters).values())) + + # Calculate Q(theta|theta') / Q(theta'|theta) + log_factor = ln_p_theta - ln_p_thetaprime + + return theta, log_factor + + +class KDEProposal(DensityEstimateProposal): + density_name = "Gaussian KDE" + __doc__ = _density_estimate_doc.format(estimator=density_name) + + def _fit(self, dataset): + return gaussian_kde(dataset) + + def _evaluate(self, point): + return self.density.logpdf(point)[0] + + def _sample(self, nsamples=None): + return np.atleast_1d(np.squeeze(self.density.resample(nsamples))) + + +class GMMProposal(DensityEstimateProposal): + density_name = "Gaussian Mixture Model" + __doc__ = _density_estimate_doc.format(estimator=density_name) + + def _fit(self, dataset): + from sklearn.mixture import GaussianMixture + + density = GaussianMixture(n_components=10) + density.fit(dataset.T) + return density + + 
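# Note: sklearn's score_samples returns the log-density per point and
+    # expects 2D input, hence the atleast_2d/squeeze handling below;
+    # GaussianMixture.sample returns a tuple (X, component_labels), of
+    # which only X is used
+    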
def _evaluate(self, point):
+        return np.squeeze(self.density.score_samples(np.atleast_2d(point)))
+
+    def _sample(self, nsamples=None):
+        return np.squeeze(self.density.sample(n_samples=nsamples)[0])
+
+    def check_dependencies(warn=True):
+        if importlib.util.find_spec("sklearn") is None:
+            if warn:
+                logger.warning(
+                    "Unable to utilise GMMProposal as sklearn is not installed"
+                )
+            return False
+        else:
+            return True
+
+
+class NormalizingFlowProposal(DensityEstimateProposal):
+    density_name = "Normalizing Flow"
+    __doc__ = _density_estimate_doc.format(estimator=density_name) + (
+        """
+    js_factor: float
+        The factor to use in determining the max-JS factor to terminate
+        training.
+    max_training_epochs: int
+        The maximum number of training steps to take
+    """
+    )
+
+    def __init__(
+        self,
+        priors,
+        weight=1,
+        subset=None,
+        first_fit=1000,
+        fit_multiplier=10,
+        max_training_epochs=1000,
+        scale_fits=1,
+        nsamples_for_density=1000,
+        js_factor=10,
+        fallback=AdaptiveGaussianProposal,
+    ):
+        super(NormalizingFlowProposal, self).__init__(
+            priors=priors,
+            weight=weight,
+            subset=subset,
+            first_fit=first_fit,
+            fit_multiplier=fit_multiplier,
+            nsamples_for_density=nsamples_for_density,
+            fallback=fallback,
+            scale_fits=scale_fits,
+        )
+        self.setup_flow()
+        self.setup_optimizer()
+
+        self.max_training_epochs = max_training_epochs
+        self.js_factor = js_factor
+
+    def setup_flow(self):
+        if self.ndim < 3:
+            self.setup_basic_flow()
+        else:
+            self.setup_NVP_flow()
+
+    def setup_NVP_flow(self):
+        from .flows import NVPFlow
+
+        self.flow = NVPFlow(
+            features=self.ndim,
+            hidden_features=self.ndim * 2,
+            num_layers=2,
+            num_blocks_per_layer=2,
+            batch_norm_between_layers=True,
+            batch_norm_within_layers=True,
+        )
+
+    def setup_basic_flow(self):
+        from .flows import BasicFlow
+
+        self.flow = BasicFlow(features=self.ndim)
+
+    def setup_optimizer(self):
+        from torch import optim
+
+        self.optimizer = optim.Adam(self.flow.parameters())
+
+    def get_training_data(self, chain):
+        training_data = []
+        nsamples_for_density = min(chain.position, self.nsamples_for_density)
+        for _ in range(nsamples_for_density):
+            s = chain.random_sample
+            training_data.append([s[key] for key in self.parameters])
+        return training_data
+
+    def _calculate_js(self, validation_samples, training_samples_draw):
+        # Calculate the maximum JS between the validation and draw
+        max_js = 0
+        for i in range(self.ndim):
+            A = validation_samples[:, i]
+            B = training_samples_draw[:, i]
+            xmin = np.min([np.min(A), np.min(B)])
+            xmax = np.min([np.max(A), np.max(B)])
+            xval = np.linspace(xmin, xmax, 100)
+            Apdf = gaussian_kde(A)(xval)
+            Bpdf = gaussian_kde(B)(xval)
+            js = jensenshannon(Apdf, Bpdf)
+            max_js = max(max_js, js)
+        return np.power(max_js, 2)
+
+    def train(self, chain):
+        logger.info("Starting NF training")
+
+        import torch
+
+        start = time.time()
+
+        training_samples = np.array(self.get_training_data(chain))
+        validation_samples = np.array(self.get_training_data(chain))
+
+        training_tensor = torch.tensor(training_samples, dtype=torch.float32)
+
+        max_js_threshold = self.js_factor / self.nsamples_for_density
+
+        for epoch in range(1, self.max_training_epochs + 1):
+            self.optimizer.zero_grad()
+            loss = -self.flow.log_prob(inputs=training_tensor).mean()
+            loss.backward()
+            self.optimizer.step()
+
+            # Draw from the current flow
+            self.flow.eval()
+            training_samples_draw = (
+                self.flow.sample(self.nsamples_for_density).detach().numpy()
+            )
+            self.flow.train()
+
+            if np.mod(epoch, 10) == 0:
+                max_js_bits = self._calculate_js(
+                    
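# held-out chain draws vs. draws from the current flow;
+                    # _calculate_js returns the maximum squared JS divergence
+                    # over the 1D marginals
+                    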
validation_samples, training_samples_draw + ) + if max_js_bits < max_js_threshold: + logger.info( + f"Training complete after {epoch} steps, " + f"max_js_bits={max_js_bits:0.5f}<{max_js_threshold}" + ) + break + + took = time.time() - start + logger.info( + f"Flow training step ({self.steps_since_refit}) finished" + f" for length {chain.position} chain, took {took:0.2f}s." + f" Current accept-ratio={self.acceptance_ratio:0.2f}" + ) + self.steps_since_refit = 0 + self.next_refit_time *= self.fit_multiplier + self.trained = True + + def propose(self, chain): + import torch + + self.steps_since_refit += 1 + theta = chain.current_sample + + # Check if we retrain the NF + testA = self.steps_since_refit >= self.next_refit_time + if testA: + self.train(chain) + + if self.trained is False: + return self.fallback.propose(chain) + + self.flow.eval() + theta_prime_T = self.flow.sample(1) + + logp_theta_prime = self.flow.log_prob(theta_prime_T).detach().numpy()[0] + theta_T = torch.tensor( + np.atleast_2d([theta[key] for key in self.parameters]), dtype=torch.float32 + ) + logp_theta = self.flow.log_prob(theta_T).detach().numpy()[0] + log_factor = logp_theta - logp_theta_prime + + flow_sample_values = np.atleast_1d(np.squeeze(theta_prime_T.detach().numpy())) + for key, val in zip(self.parameters, flow_sample_values): + theta[key] = val + + return theta, float(log_factor) + + def check_dependencies(warn=True): + if importlib.util.find_spec("nflows") is None: + if warn: + logger.warning( + "Unable to utilise NormalizingFlowProposal as nflows is not installed" + ) + return False + else: + return True + + +class FixedJumpProposal(BaseProposal): + def __init__(self, priors, jumps=1, subset=None, weight=1, scale=1e-4): + super(FixedJumpProposal, self).__init__(priors, weight, subset) + self.scale = scale + if isinstance(jumps, (int, float)): + self.jumps = {key: jumps for key in self.parameters} + elif isinstance(jumps, dict): + self.jumps = jumps + else: + raise SamplerError("jumps not understood") + + def propose(self, chain): + sample = chain.current_sample + for key, jump in self.jumps.items(): + sign = np.random.randint(2) * 2 - 1 + sample[key] += sign * jump + self.epsilon * self.prior_width_dict[key] + log_factor = 0 + return sample, log_factor + + @property + def epsilon(self): + return self.scale * np.random.normal() + + +class BaseGravitationalWaveTransientProposal(BaseProposal): + def __init__(self, priors, weight=1): + super(BaseGravitationalWaveTransientProposal, self).__init__( + priors, weight=weight + ) + if "phase" in priors: + self.phase_key = "phase" + elif "delta_phase" in priors: + self.phase_key = "delta_phase" + else: + self.phase_key = None + + def get_cos_theta_jn(self, sample): + if "cos_theta_jn" in sample.parameter_keys: + cos_theta_jn = sample["cos_theta_jn"] + elif "theta_jn" in sample.parameter_keys: + cos_theta_jn = np.cos(sample["theta_jn"]) + else: + raise SamplerError() + return cos_theta_jn + + def get_phase(self, sample): + if "phase" in sample.parameter_keys: + return sample["phase"] + elif "delta_phase" in sample.parameter_keys: + cos_theta_jn = self.get_cos_theta_jn(sample) + delta_phase = sample["delta_phase"] + psi = sample["psi"] + phase = np.mod(delta_phase - np.sign(cos_theta_jn) * psi, 2 * np.pi) + else: + raise SamplerError() + return phase + + def get_delta_phase(self, phase, sample): + cos_theta_jn = self.get_cos_theta_jn(sample) + psi = sample["psi"] + delta_phase = phase + np.sign(cos_theta_jn) * psi + return delta_phase + + +class 
CorrelatedPolarisationPhaseJump(BaseGravitationalWaveTransientProposal): + def __init__(self, priors, weight=1): + super(CorrelatedPolarisationPhaseJump, self).__init__(priors, weight=weight) + + def propose(self, chain): + sample = chain.current_sample + phase = self.get_phase(sample) + + alpha = sample["psi"] + phase + beta = sample["psi"] - phase + + draw = np.random.random() + if draw < 0.5: + alpha = 3.0 * np.pi * np.random.random() + else: + beta = 3.0 * np.pi * np.random.random() - 2 * np.pi + + # Update + sample["psi"] = (alpha + beta) * 0.5 + phase = (alpha - beta) * 0.5 + + if self.phase_key == "delta_phase": + sample["delta_phase"] = self.get_delta_phase(phase, sample) + else: + sample["phase"] = phase + + log_factor = 0 + return sample, log_factor + + +class PhaseReversalProposal(BaseGravitationalWaveTransientProposal): + def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-1): + super(PhaseReversalProposal, self).__init__(priors, weight) + self.fuzz = fuzz + self.fuzz_sigma = fuzz_sigma + if self.phase_key is None: + raise SamplerError( + f"{type(self).__name__} initialised without a phase prior" + ) + + def propose(self, chain): + sample = chain.current_sample + phase = sample[self.phase_key] + sample[self.phase_key] = np.mod(phase + np.pi + self.epsilon, 2 * np.pi) + log_factor = 0 + return sample, log_factor + + @property + def epsilon(self): + if self.fuzz: + return np.random.normal(0, self.fuzz_sigma) + else: + return 0 + + +class PolarisationReversalProposal(PhaseReversalProposal): + def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-3): + super(PolarisationReversalProposal, self).__init__( + priors, weight, fuzz, fuzz_sigma + ) + self.fuzz = fuzz + + def propose(self, chain): + sample = chain.current_sample + psi = sample["psi"] + sample["psi"] = np.mod(psi + np.pi / 2 + self.epsilon, np.pi) + log_factor = 0 + return sample, log_factor + + +class PhasePolarisationReversalProposal(PhaseReversalProposal): + def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-1): + super(PhasePolarisationReversalProposal, self).__init__( + priors, weight, fuzz, fuzz_sigma + ) + self.fuzz = fuzz + + def propose(self, chain): + sample = chain.current_sample + sample[self.phase_key] = np.mod( + sample[self.phase_key] + np.pi + self.epsilon, 2 * np.pi + ) + sample["psi"] = np.mod(sample["psi"] + np.pi / 2 + self.epsilon, np.pi) + log_factor = 0 + return sample, log_factor + + +class StretchProposal(BaseProposal): + """The Goodman & Weare (2010) Stretch proposal for an MCMC chain + + Implementation of the Stretch proposal using a sample drawn from the chain. + We assume the form of g(z) from Equation (9) of [1]. 
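+    The move proposes x -> x_j + z * (x - x_j) for a randomly drawn chain
+    sample x_j, where z has density g(z) proportional to 1/sqrt(z) on
+    [1/scale, scale]; the factor z**(ndim - 1) applied in `_stretch_move`
+    is the corresponding detailed-balance correction.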
+
+    References
+    ----------
+    [1] Goodman & Weare (2010)
+        https://ui.adsabs.harvard.edu/abs/2010CAMCS...5...65G/abstract
+
+    """
+
+    def __init__(self, priors, weight=1, subset=None, scale=2):
+        super(StretchProposal, self).__init__(priors, weight, subset)
+        self.scale = scale
+
+    def propose(self, chain):
+        sample = chain.current_sample
+
+        # Draw a random sample
+        rand = chain.random_sample
+
+        return _stretch_move(sample, rand, self.scale, self.ndim, self.parameters)
+
+
+def _stretch_move(sample, complement, scale, ndim, parameters):
+    # Draw z
+    u = np.random.rand()
+    z = (u * (scale - 1) + 1) ** 2 / scale
+
+    log_factor = (ndim - 1) * np.log(z)
+
+    for key in parameters:
+        sample[key] = complement[key] + (sample[key] - complement[key]) * z
+
+    return sample, log_factor
+
+
+class EnsembleProposal(BaseProposal):
+    """ Base EnsembleProposal class for ensemble-based swap proposals """
+
+    def __init__(self, priors, weight=1):
+        super(EnsembleProposal, self).__init__(priors, weight)
+
+    def __call__(self, chain, chain_complement):
+        sample, log_factor = self.propose(chain, chain_complement)
+        sample = self.apply_boundaries(sample)
+        return sample, log_factor
+
+
+class EnsembleStretch(EnsembleProposal):
+    """The Goodman & Weare (2010) Stretch proposal for an Ensemble
+
+    Implementation of the Stretch proposal using a sample drawn from the
+    complement. We assume the form of g(z) from Equation (9) of [1].
+
+    References
+    ----------
+    [1] Goodman & Weare (2010)
+        https://ui.adsabs.harvard.edu/abs/2010CAMCS...5...65G/abstract
+
+    """
+
+    def __init__(self, priors, weight=1, scale=2):
+        super(EnsembleStretch, self).__init__(priors, weight)
+        self.scale = scale
+
+    def propose(self, chain, chain_complement):
+        sample = chain.current_sample
+        complement = chain_complement[
+            np.random.randint(len(chain_complement))
+        ].current_sample
+        return _stretch_move(
+            sample, complement, self.scale, self.ndim, self.parameters
+        )
+
+
+def get_default_ensemble_proposal_cycle(priors):
+    return ProposalCycle([EnsembleStretch(priors)])
+
+
+def get_proposal_cycle(string, priors, L1steps=1, warn=True):
+    big_weight = 10
+    small_weight = 5
+    tiny_weight = 0.1
+
+    if "gwA" in string:
+        # Parameters for learning proposals
+        learning_kwargs = dict(
+            first_fit=1000, nsamples_for_density=10000, fit_multiplier=2
+        )
+
+        plist = [
+            AdaptiveGaussianProposal(priors, weight=small_weight),
+            DifferentialEvolutionProposal(priors, weight=small_weight),
+        ]
+
+        if GMMProposal.check_dependencies(warn=warn) is False:
+            raise SamplerError(
+                "the gwA proposal_cycle requires the GMMProposal dependencies"
+            )
+
+        if priors.intrinsic:
+            intrinsic = PARAMETER_SETS["intrinsic"]
+            plist += [
+                AdaptiveGaussianProposal(priors, weight=big_weight, subset=intrinsic),
+                DifferentialEvolutionProposal(
+                    priors, weight=big_weight, subset=intrinsic
+                ),
+                KDEProposal(
+                    priors, weight=big_weight, subset=intrinsic, **learning_kwargs
+                ),
+                GMMProposal(
+                    priors, weight=big_weight, subset=intrinsic, **learning_kwargs
+                ),
+            ]
+
+        if priors.extrinsic:
+            extrinsic = PARAMETER_SETS["extrinsic"]
+            plist += [
+                AdaptiveGaussianProposal(priors, weight=small_weight, subset=extrinsic),
+                DifferentialEvolutionProposal(
+                    priors, weight=big_weight, subset=extrinsic
+                ),
+                KDEProposal(
+                    priors, weight=big_weight, subset=extrinsic, **learning_kwargs
+                ),
+                GMMProposal(
+                    priors, weight=big_weight, subset=extrinsic, **learning_kwargs
+                ),
+            ]
+
+        if priors.mass:
+            mass = PARAMETER_SETS["mass"]
+            plist += [
+                DifferentialEvolutionProposal(priors, 
weight=small_weight, subset=mass), + GMMProposal( + priors, weight=small_weight, subset=mass, **learning_kwargs + ), + ] + + if priors.spin: + spin = PARAMETER_SETS["spin"] + plist += [ + DifferentialEvolutionProposal(priors, weight=small_weight, subset=spin), + GMMProposal( + priors, weight=small_weight, subset=spin, **learning_kwargs + ), + ] + if priors.precession: + measured_spin = ["chi_1", "chi_2", "a_1", "a_2", "chi_1_in_plane"] + plist += [ + AdaptiveGaussianProposal( + priors, weight=small_weight, subset=measured_spin + ), + ] + + if priors.mass and priors.spin: + primary_spin_and_q = PARAMETER_SETS["primary_spin_and_q"] + plist += [ + DifferentialEvolutionProposal( + priors, weight=small_weight, subset=primary_spin_and_q + ), + ] + + if getattr(priors, "tidal", False): + tidal = PARAMETER_SETS["tidal"] + plist += [ + DifferentialEvolutionProposal( + priors, weight=small_weight, subset=tidal + ), + PriorProposal(priors, weight=small_weight, subset=tidal), + ] + if priors.phase: + plist += [ + PhaseReversalProposal(priors, weight=tiny_weight), + ] + if priors.phase and "psi" in priors.non_fixed_keys: + plist += [ + CorrelatedPolarisationPhaseJump(priors, weight=tiny_weight), + PhasePolarisationReversalProposal(priors, weight=tiny_weight), + ] + for key in ["time_jitter", "psi", "phi_12", "tilt_2", "lambda_1", "lambda_2"]: + if key in priors.non_fixed_keys: + plist.append(PriorProposal(priors, subset=[key], weight=tiny_weight)) + if "chi_1_in_plane" in priors and "chi_2_in_plane" in priors: + in_plane = ["chi_1_in_plane", "chi_2_in_plane", "phi_12"] + plist.append(UniformProposal(priors, subset=in_plane, weight=tiny_weight)) + if any("recalib_" in key for key in priors): + calibration = [key for key in priors if "recalib_" in key] + plist.append(PriorProposal(priors, subset=calibration, weight=small_weight)) + else: + plist = [ + AdaptiveGaussianProposal(priors, weight=big_weight), + DifferentialEvolutionProposal(priors, weight=big_weight), + UniformProposal(priors, weight=tiny_weight), + KDEProposal(priors, weight=big_weight, scale_fits=L1steps), + ] + if GMMProposal.check_dependencies(warn=warn): + plist.append(GMMProposal(priors, weight=big_weight, scale_fits=L1steps)) + if NormalizingFlowProposal.check_dependencies(warn=warn): + plist.append( + NormalizingFlowProposal(priors, weight=big_weight, scale_fits=L1steps) + ) + + plist = remove_proposals_using_string(plist, string) + return ProposalCycle(plist) + + +def remove_proposals_using_string(plist, string): + mapping = dict( + DE=DifferentialEvolutionProposal, + AG=AdaptiveGaussianProposal, + ST=StretchProposal, + FG=FixedGaussianProposal, + NF=NormalizingFlowProposal, + KD=KDEProposal, + GM=GMMProposal, + PR=PriorProposal, + UN=UniformProposal, + ) + + for element in string.split("no")[1:]: + if element in mapping: + plist = [p for p in plist if isinstance(p, mapping[element]) is False] + return plist diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..6cad693ae9c1879388141d99d49fc282a3e0487a --- /dev/null +++ b/bilby/bilby_mcmc/sampler.py @@ -0,0 +1,1348 @@ +import datetime +import os +import signal +import time +from collections import Counter + +import numpy as np +import pandas as pd + +from ..core.result import rejection_sample +from ..core.sampler.base_sampler import MCMCSampler, ResumeError, SamplerError +from ..core.utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump +from . 
import proposals
+from .chain import Chain, Sample
+from .utils import LOGLKEY, LOGPKEY, ConvergenceInputs, ParallelTemperingInputs
+
+
+class Bilby_MCMC(MCMCSampler):
+    """The built-in Bilby MCMC sampler
+
+    Parameters
+    ----------
+    likelihood: likelihood.Likelihood
+        An object with a log_likelihood method
+    priors: bilby.core.prior.PriorDict, dict
+        Priors to be used in the search.
+        This has attributes for each parameter to be sampled.
+    outdir: str, optional
+        Name of the output directory
+    label: str, optional
+        Naming scheme of the output files
+    use_ratio: bool, optional
+        Switch to set whether or not you want to use the log-likelihood ratio
+        or just the log-likelihood
+    skip_import_verification: bool
+        If true, skip the check that the sampler is installed. This is
+        only advisable for testing environments
+    check_point_plot: bool
+        If true, create plots at the check point
+    check_point_delta_t: float
+        The time in seconds after which to checkpoint (defaults to 30 minutes)
+    diagnostic: bool
+        If true, create deep-diagnostic plots used for checking convergence
+        problems.
+    resume: bool
+        If true, resume from any existing check point files
+    exit_code: int
+        The exit code to raise when exiting
+
+    Sampling Parameters
+    -------------------
+    nsamples: int (1000)
+        The number of samples to draw
+    nensemble: int (1)
+        The number of ensemble-chains to run (with periodic communication)
+    pt_ensemble: bool (False)
+        If true, run a parallel-tempered set of chains for each
+        ensemble-chain (in which case the total number of chains is
+        nensemble * ntemps). Else, only the zero-ensemble chain is run with
+        parallel-tempering (in which case the total number of chains is
+        nensemble + ntemps - 1).
+    ntemps: int (1)
+        The number of parallel-tempered chains to run
+    Tmax: float (None)
+        If given, the maximum temperature used to set the initial
+        temperature ladder
+    Tmax_from_SNR: float (20)
+        (Alternative to Tmax): The SNR to estimate an appropriate Tmax from.
+    initial_betas: list (None)
+        (Alternative to Tmax and Tmax_from_SNR): If given, an initial choice of
+        the inverse temperature ladder.
+    pt_rejection_sample: bool (False)
+        If true, use rejection sampling to draw samples from the pt-chains.
+    adapt, adapt_t0, adapt_nu: bool, float, float (True, 100, 10)
+        Whether to use adaptation and the adaptation parameters.
+        See arXiv:1501.05823 for a description of adapt_t0 and adapt_nu.
+    burn_in_nact, thin_by_nact, fixed_discard: float, float, float (10, 1, 0)
+        The number of auto-correlation times to discard for burn-in and to
+        thin by. The fixed_discard is the number of steps discarded before
+        automatic autocorrelation time analysis begins.
+    autocorr_c: float (5)
+        The step-size for the window search. See emcee.autocorr.integrated_time
+        for additional details.
+    L1steps: int
+        The number of internal steps to take. Improves the scaling performance
+        of multiprocessing. Note, all ACTs are calculated based on the saved
+        steps. So, the total ACT (or number of steps) is L1steps * tau
+        (or L1steps * position).
+    L2steps: int
+        The number of steps to take before swapping between parallel-tempered
+        and ensemble chains.
+    npool: int
+        The number of multiprocessing cores to use. For efficiency, this
+        should evenly divide the total number of chains.
+    printdt: float
+        Print an update on the progress every printdt seconds. Note, each
+        print requires an evaluation of the ACT so short print times are
+        unwise.
+    min_tau: int (1)
+        The minimum allowed ACT. Can be used to force a larger ACT.
+    proposal_cycle: str, bilby.core.sampler.bilby_mcmc.proposals.ProposalCycle
+        Either a string pointing to one of the built-in proposal cycles or a
+        proposal cycle.
+    stop_after_convergence: bool
+        If running with parallel-tempered chains, stop updating the chains
+        once they have converged. After this time, random samples will be
+        drawn at swap time.
+    fixed_tau: int
+        A fixed value for the ACT: used for testing purposes.
+    tau_window: int, None
+        Using tau', a previous estimate of tau, calculate the new tau using
+        the last tau_window * tau' steps. If None, the entire chain is used.
+    evidence_method: str, [stepping_stone, thermodynamic]
+        The evidence calculation method to use. Defaults to stepping_stone, but
+        the results of all available methods are stored in the ln_z_dict.
+
+    """
+
+    default_kwargs = dict(
+        nsamples=1000,
+        nensemble=1,
+        pt_ensemble=False,
+        ntemps=1,
+        Tmax=None,
+        Tmax_from_SNR=20,
+        initial_betas=None,
+        adapt=True,
+        adapt_t0=100,
+        adapt_nu=10,
+        pt_rejection_sample=False,
+        burn_in_nact=10,
+        thin_by_nact=1,
+        fixed_discard=0,
+        autocorr_c=5,
+        L1steps=100,
+        L2steps=3,
+        npool=1,
+        printdt=60,
+        min_tau=1,
+        proposal_cycle="default",
+        stop_after_convergence=False,
+        fixed_tau=None,
+        tau_window=None,
+        evidence_method="stepping_stone",
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        skip_import_verification=True,
+        check_point_plot=True,
+        check_point_delta_t=1800,
+        diagnostic=False,
+        resume=True,
+        exit_code=130,
+        **kwargs,
+    ):
+
+        super(Bilby_MCMC, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            skip_import_verification=skip_import_verification,
+            exit_code=exit_code,
+            **kwargs,
+        )
+
+        self.check_point_plot = check_point_plot
+        self.diagnostic = diagnostic
+        self.kwargs["target_nsamples"] = self.kwargs["nsamples"]
+        self.npool = self.kwargs["npool"]
+        self.L1steps = self.kwargs["L1steps"]
+        self.L2steps = self.kwargs["L2steps"]
+        self.pt_inputs = ParallelTemperingInputs(
+            **{key: self.kwargs[key] for key in ParallelTemperingInputs._fields}
+        )
+        self.convergence_inputs = ConvergenceInputs(
+            **{key: self.kwargs[key] for key in ConvergenceInputs._fields}
+        )
+        self.proposal_cycle = self.kwargs["proposal_cycle"]
+        self.pt_rejection_sample = self.kwargs["pt_rejection_sample"]
+        self.evidence_method = self.kwargs["evidence_method"]
+
+        self.printdt = self.kwargs["printdt"]
+        check_directory_exists_and_if_not_mkdir(self.outdir)
+        self.resume = resume
+        self.check_point_delta_t = check_point_delta_t
+        self.resume_file = "{}/{}_resume.pickle".format(self.outdir, self.label)
+
+        self.verify_configuration()
+
+        try:
+            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
+            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
+            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
+        except AttributeError:
+            logger.debug(
+                "Setting signal attributes unavailable on this system. "
+                "This is likely the case if you are running on a Windows machine"
+                " and is no further concern."
+ ) + + def verify_configuration(self): + if self.convergence_inputs.burn_in_nact / self.kwargs["target_nsamples"] > 0.1: + logger.warning("Burn-in inefficiency fraction greater than 10%") + + def _translate_kwargs(self, kwargs): + if "printdt" not in kwargs: + for equiv in ["print_dt", "print_update"]: + if equiv in kwargs: + kwargs["printdt"] = kwargs.pop(equiv) + if "npool" not in kwargs: + for equiv in self.npool_equiv_kwargs: + if equiv in kwargs: + kwargs["npool"] = kwargs.pop(equiv) + + @property + def target_nsamples(self): + return self.kwargs["target_nsamples"] + + def run_sampler(self): + self._setup_pool() + self.setup_chain_set() + self.start_time = datetime.datetime.now() + self.draw() + self._close_pool() + self.check_point(ignore_time=True) + + self.result = self.add_data_to_result( + result=self.result, + ptsampler=self.ptsampler, + outdir=self.outdir, + label=self.label, + make_plots=self.check_point_plot, + ) + + return self.result + + @staticmethod + def add_data_to_result(result, ptsampler, outdir, label, make_plots): + result.samples = ptsampler.samples + result.log_likelihood_evaluations = result.samples[LOGLKEY] + result.log_prior_evaluations = result.samples[LOGPKEY] + ptsampler.compute_evidence( + outdir=outdir, + label=label, + make_plots=make_plots, + ) + result.log_evidence = ptsampler.ln_z + result.log_evidence_err = ptsampler.ln_z_err + result.sampling_time = datetime.timedelta(seconds=ptsampler.sampling_time) + result.meta_data["bilby_mcmc"] = dict( + tau=ptsampler.tau, + convergence_inputs=ptsampler.convergence_inputs._asdict(), + pt_inputs=ptsampler.pt_inputs._asdict(), + total_steps=ptsampler.position, + nsamples=ptsampler.nsamples, + ) + if ptsampler.pool is not None: + npool = ptsampler.pool._processes + else: + npool = 1 + result.meta_data["run_statistics"] = dict( + nlikelihood=ptsampler.position * ptsampler.L1steps * ptsampler._nsamplers, + neffsamples=ptsampler.nsamples * ptsampler.convergence_inputs.thin_by_nact, + sampling_time_s=result.sampling_time.seconds, + ncores=npool, + ) + + return result + + def setup_chain_set(self): + if os.path.isfile(self.resume_file) and self.resume is True: + self.read_current_state() + self.ptsampler.pool = self.pool + else: + self.init_ptsampler() + + def init_ptsampler(self): + + logger.info(f"Initializing BilbyPTMCMCSampler with:\n{self.get_setup_string()}") + self.ptsampler = BilbyPTMCMCSampler( + convergence_inputs=self.convergence_inputs, + pt_inputs=self.pt_inputs, + proposal_cycle=self.proposal_cycle, + pt_rejection_sample=self.pt_rejection_sample, + pool=self.pool, + use_ratio=self.use_ratio, + evidence_method=self.evidence_method, + ) + + def get_setup_string(self): + string = ( + f" Convergence settings: {self.convergence_inputs}\n" + f" Parallel-tempering settings: {self.pt_inputs}\n" + f" proposal_cycle: {self.proposal_cycle}\n" + f" pt_rejection_sample: {self.pt_rejection_sample}" + ) + return string + + def draw(self): + self._steps_since_last_print = 0 + self._time_since_last_print = 0 + logger.info(f"Drawing {self.target_nsamples} samples") + logger.info(f"Checkpoint every {self.check_point_delta_t}s") + logger.info(f"Print update every {self.printdt}s") + + while True: + t0 = datetime.datetime.now() + self.ptsampler.step_all_chains() + dt = (datetime.datetime.now() - t0).total_seconds() + self.ptsampler.sampling_time += dt + self._time_since_last_print += dt + self._steps_since_last_print += self.ptsampler.L1steps + + if self._time_since_last_print > self.printdt: + tp0 = datetime.datetime.now() + 
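# Time the progress print itself: it triggers an ACT evaluation,
+                # so its relative cost (ppt_frac) is checked below
+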
self.print_progress() + tp = datetime.datetime.now() + ppt_frac = (tp - tp0).total_seconds() / self._time_since_last_print + if ppt_frac > 0.01: + logger.warning( + f"Non-negligible print progress time (ppt_frac={ppt_frac:0.2f})" + ) + self._steps_since_last_print = 0 + self._time_since_last_print = 0 + + self.check_point() + + if self.ptsampler.nsamples_last >= self.target_nsamples: + # Perform a second check without cached values + if self.ptsampler.nsamples_nocache >= self.target_nsamples: + logger.info("Reached convergence: exiting sampling") + break + + def check_point(self, ignore_time=False): + tS = (datetime.datetime.now() - self.start_time).total_seconds() + if os.path.isfile(self.resume_file): + tR = time.time() - os.path.getmtime(self.resume_file) + else: + tR = np.inf + + if ignore_time or np.min([tS, tR]) > self.check_point_delta_t: + logger.info("Checkpoint start") + self.write_current_state() + self.print_long_progress() + logger.info("Checkpoint finished") + + def _remove_checkpoint(self): + """Remove checkpointed state""" + if os.path.isfile(self.resume_file): + os.remove(self.resume_file) + + def read_current_state(self): + import dill + + with open(self.resume_file, "rb") as file: + self.ptsampler = dill.load(file) + if self.ptsampler.pt_inputs != self.pt_inputs: + msg = ( + f"pt_inputs has changed: {self.ptsampler.pt_inputs} " + f"-> {self.pt_inputs}" + ) + raise ResumeError(msg) + self.ptsampler.set_convergence_inputs(self.convergence_inputs) + self.ptsampler.proposal_cycle = self.proposal_cycle + self.ptsampler.pt_rejection_sample = self.pt_rejection_sample + + logger.info( + f"Loaded resume file {self.resume_file} " + f"with {self.ptsampler.position} steps " + f"setup:\n{self.get_setup_string()}" + ) + + def write_current_state_and_exit(self, signum=None, frame=None): + """ + Make sure that if a pool of jobs is running only the parent tries to + checkpoint and exit. Only the parent has a 'pool' attribute. + """ + if self.npool == 1 or getattr(self, "pool", None) is not None: + if signum == 14: + logger.info( + "Run interrupted by alarm signal {}: checkpoint and exit on {}".format( + signum, self.exit_code + ) + ) + else: + logger.info( + "Run interrupted by signal {}: checkpoint and exit on {}".format( + signum, self.exit_code + ) + ) + self.write_current_state() + self._close_pool() + os._exit(self.exit_code) + + def write_current_state(self): + import dill + + logger.debug("Check point") + check_directory_exists_and_if_not_mkdir(self.outdir) + + _pool = self.ptsampler.pool + self.ptsampler.pool = None + if dill.pickles(self.ptsampler): + safe_file_dump(self.ptsampler, self.resume_file, dill) + logger.info("Written checkpoint file {}".format(self.resume_file)) + else: + logger.warning( + "Cannot write pickle resume file! " + "Job will not resume if interrupted." 
+            )
+        self.ptsampler.pool = _pool
+
+    def print_long_progress(self):
+        self.print_per_proposal()
+        self.print_tau_dict()
+        if self.ptsampler.ntemps > 1:
+            self.print_pt_acceptance()
+        if self.ptsampler.nensemble > 1:
+            self.print_ensemble_acceptance()
+        if self.check_point_plot:
+            self.plot_progress(
+                self.ptsampler, self.label, self.outdir, self.priors, self.diagnostic
+            )
+            self.ptsampler.compute_evidence(
+                outdir=self.outdir, label=self.label, make_plots=True
+            )
+
+    def print_ensemble_acceptance(self):
+        logger.info(f"Ensemble swaps = {self.ptsampler.swap_counter['ensemble']}")
+        logger.info(self.ptsampler.ensemble_proposal_cycle)
+
+    def print_progress(self):
+        position = self.ptsampler.position
+
+        # Total sampling time
+        sampling_time = datetime.timedelta(seconds=self.ptsampler.sampling_time)
+        time_str = str(sampling_time).split(".")[0]
+
+        # Time for last evaluation set
+        time_per_eval_ms = (
+            1000 * self._time_since_last_print / self._steps_since_last_print
+        )
+
+        # Pull out progress summary
+        tau = self.ptsampler.tau
+        nsamples = self.ptsampler.nsamples
+        minimum_index = self.ptsampler.primary_sampler.chain.minimum_index
+        method = self.ptsampler.primary_sampler.chain.minimum_index_method
+        mindex_str = f"{minimum_index:0.2e}({method})"
+        alpha = self.ptsampler.primary_sampler.acceptance_ratio
+        maxl = self.ptsampler.primary_sampler.chain.max_log_likelihood
+
+        nlikelihood = position * self.L1steps * self.ptsampler._nsamplers
+        eff = 100 * nsamples / nlikelihood
+
+        # Estimated time until finish (ETF)
+        if tau < np.inf:
+            remaining_samples = self.target_nsamples - nsamples
+            remaining_evals = (
+                remaining_samples
+                * self.convergence_inputs.thin_by_nact
+                * tau
+                * self.L1steps
+            )
+            remaining_time_s = time_per_eval_ms * 1e-3 * remaining_evals
+            remaining_time_dt = datetime.timedelta(seconds=remaining_time_s)
+            if remaining_samples > 0:
+                remaining_time = str(remaining_time_dt).split(".")[0]
+            else:
+                remaining_time = "0"
+        else:
+            remaining_time = "-"
+
+        msg = (
+            f"{position:0.2e}|{time_str}|{mindex_str}|t={tau:0.0f}|"
+            f"n={nsamples:0.0f}|a={alpha:0.2f}|e={eff:0.1e}%|"
+            f"{time_per_eval_ms:0.2f}ms/ev|maxl={maxl:0.2f}|"
+            f"ETF={remaining_time}"
+        )
+
+        if self.pt_rejection_sample:
+            count = self.ptsampler.rejection_sampling_count
+            rse = 100 * count / nsamples
+            msg += f"|rse={rse:0.2f}%"
+
+        print(msg, flush=True)
+
+    def print_per_proposal(self):
+        logger.info("Zero-temperature proposals:")
+        for prop in self.ptsampler[0].proposal_cycle.proposal_list:
+            logger.info(prop)
+
+    def print_pt_acceptance(self):
+        logger.info(f"Temperature swaps = {self.ptsampler.swap_counter['temperature']}")
+        for column in self.ptsampler.sampler_list_of_tempered_lists:
+            for ii, sampler in enumerate(column):
+                total = sampler.pt_accepted + sampler.pt_rejected
+                beta = sampler.beta
+                if total > 0:
+                    ratio = f"{sampler.pt_accepted / total:0.2f}"
+                else:
+                    ratio = "-"
+                logger.info(
+                    f"Temp:{ii}<->{ii+1}|"
+                    f"beta={beta:0.4g}|"
+                    f"hot-samp={sampler.nsamples}|"
+                    f"swap={ratio}|"
+                    f"conv={sampler.chain.converged}|"
+                )
+
+    def print_tau_dict(self):
+        msg = f"Current taus={self.ptsampler.primary_sampler.chain.tau_dict}"
+        logger.info(msg)
+
+    @staticmethod
+    def plot_progress(ptsampler, label, outdir, priors, diagnostic=False):
+        logger.info("Creating diagnostic plots")
+        for ii, row in ptsampler.sampler_dictionary.items():
+            for jj, sampler in enumerate(row):
+                plot_label = f"{label}_E{sampler.Eindex}_T{sampler.Tindex}"
+                if diagnostic is True or sampler.beta == 1:
+                    sampler.chain.plot(
+
outdir=outdir,
+                        label=plot_label,
+                        priors=priors,
+                        all_samples=ptsampler.samples,
+                    )
+
+    def _setup_pool(self):
+        if self.npool > 1:
+            logger.info(f"Setting up multiprocessing pool with {self.npool} processes")
+            import multiprocessing
+
+            self.pool = multiprocessing.Pool(
+                processes=self.npool,
+                initializer=_initialize_global_variables,
+                initargs=(
+                    self.likelihood,
+                    self.priors,
+                    self._search_parameter_keys,
+                    self.use_ratio,
+                ),
+            )
+        else:
+            self.pool = None
+
+        _initialize_global_variables(
+            likelihood=self.likelihood,
+            priors=self.priors,
+            search_parameter_keys=self._search_parameter_keys,
+            use_ratio=self.use_ratio,
+        )
+
+    def _close_pool(self):
+        if getattr(self, "pool", None) is not None:
+            logger.info("Starting to close worker pool.")
+            self.pool.close()
+            self.pool.join()
+            self.pool = None
+            logger.info("Finished closing worker pool.")
+
+
+class BilbyPTMCMCSampler(object):
+    def __init__(
+        self,
+        convergence_inputs,
+        pt_inputs,
+        proposal_cycle,
+        pt_rejection_sample,
+        pool,
+        use_ratio,
+        evidence_method,
+    ):
+
+        self.set_pt_inputs(pt_inputs)
+        self.use_ratio = use_ratio
+        self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
+        self.set_convergence_inputs(convergence_inputs)
+        self.pt_rejection_sample = pt_rejection_sample
+        self.pool = pool
+        self.evidence_method = evidence_method
+
+        # Initialize counters
+        self.swap_counter = Counter()
+        self.swap_counter["temperature"] = 0
+        self.swap_counter["L2-temperature"] = 0
+        self.swap_counter["ensemble"] = 0
+        self.swap_counter["L2-ensemble"] = int(self.L2steps / 2) + 1
+
+        self._nsamples_dict = {}
+        self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle(
+            _priors
+        )
+        self.sampling_time = 0
+        self.ln_z_dict = dict()
+        self.ln_z_err_dict = dict()
+
+    def get_initial_betas(self):
+        pt_inputs = self.pt_inputs
+        if self.ntemps == 1:
+            betas = np.array([1])
+        elif pt_inputs.initial_betas is not None:
+            betas = np.array(pt_inputs.initial_betas)
+        elif pt_inputs.Tmax is not None:
+            betas = np.logspace(0, -np.log10(pt_inputs.Tmax), pt_inputs.ntemps)
+        elif pt_inputs.Tmax_from_SNR is not None:
+            ndim = len(_priors.non_fixed_keys)
+            target_hot_likelihood = ndim / 2
+            Tmax = pt_inputs.Tmax_from_SNR ** 2 / (2 * target_hot_likelihood)
+            betas = np.logspace(0, -np.log10(Tmax), pt_inputs.ntemps)
+        else:
+            raise SamplerError("Unable to set temperature ladder from inputs")
+
+        if len(betas) != self.ntemps:
+            raise SamplerError("Temperatures do not match ntemps")
+
+        return betas
+
+    def setup_sampler_dictionary(self, convergence_inputs, proposal_cycle):
+
+        betas = self.get_initial_betas()
+        logger.info(
+            f"Initializing BilbyPTMCMCSampler with: "
+            f"ntemps={self.ntemps}, "
+            f"nensemble={self.nensemble}, "
+            f"pt_ensemble={self.pt_ensemble}, "
+            f"initial_betas={betas}\n"
+        )
+        self.sampler_dictionary = dict()
+        for Tindex, beta in enumerate(betas):
+            if beta == 1 or self.pt_ensemble:
+                n = self.nensemble
+            else:
+                n = 1
+            temp_sampler_list = [
+                BilbyMCMCSampler(
+                    beta=beta,
+                    Tindex=Tindex,
+                    Eindex=Eindex,
+                    convergence_inputs=convergence_inputs,
+                    proposal_cycle=proposal_cycle,
+                    use_ratio=self.use_ratio,
+                )
+                for Eindex in range(n)
+            ]
+            self.sampler_dictionary[Tindex] = temp_sampler_list
+
+        # Store data
+        self._nsamplers = len(self.sampler_list)
+
+    @property
+    def sampler_list(self):
+        """ A list of all individual samplers """
+        return [s for item in self.sampler_dictionary.values() for s in item]
+
+    @sampler_list.setter
+    def sampler_list(self, sampler_list):
+        for sampler in
sampler_list: + self.sampler_dictionary[sampler.Tindex][sampler.Eindex] = sampler + + def sampler_list_by_column(self, column): + return [row[column] for row in self.sampler_dictionary.values()] + + @property + def sampler_list_of_tempered_lists(self): + if self.pt_ensemble: + return [self.sampler_list_by_column(ii) for ii in range(self.nensemble)] + else: + return [self.sampler_list_by_column(0)] + + @property + def tempered_sampler_list(self): + return [s for s in self.sampler_list if s.beta < 1] + + @property + def zerotemp_sampler_list(self): + return [s for s in self.sampler_list if s.beta == 1] + + @property + def primary_sampler(self): + return self.sampler_dictionary[0][0] + + def set_pt_inputs(self, pt_inputs): + logger.info(f"Setting parallel tempering inputs={pt_inputs}") + self.pt_inputs = pt_inputs + + # Pull out only what is needed + self.ntemps = pt_inputs.ntemps + self.nensemble = pt_inputs.nensemble + self.pt_ensemble = pt_inputs.pt_ensemble + self.adapt = pt_inputs.adapt + self.adapt_t0 = pt_inputs.adapt_t0 + self.adapt_nu = pt_inputs.adapt_nu + + def set_convergence_inputs(self, convergence_inputs): + logger.info(f"Setting convergence_inputs={convergence_inputs}") + self.convergence_inputs = convergence_inputs + self.L1steps = convergence_inputs.L1steps + self.L2steps = convergence_inputs.L2steps + for sampler in self.sampler_list: + sampler.set_convergence_inputs(convergence_inputs) + + @property + def tau(self): + return self.primary_sampler.chain.tau + + @property + def minimum_index(self): + return self.primary_sampler.chain.minimum_index + + @property + def nsamples(self): + pos = self.primary_sampler.chain.position + if hasattr(self, "_nsamples_dict") is False: + self._nsamples_dict = {} + if pos in self._nsamples_dict: + return self._nsamples_dict[pos] + logger.debug(f"Calculating nsamples at {pos}") + self._nsamples_dict[pos] = self._calculate_nsamples() + return self._nsamples_dict[pos] + + @property + def nsamples_last(self): + if len(self._nsamples_dict) > 0: + return list(self._nsamples_dict.values())[-1] + else: + return 0 + + @property + def nsamples_nocache(self): + for sampler in self.sampler_list: + sampler.chain.tau_nocache + pos = self.primary_sampler.chain.position + self._nsamples_dict[pos] = self._calculate_nsamples() + return self._nsamples_dict[pos] + + def _calculate_nsamples(self): + nsamples_list = [] + for sampler in self.zerotemp_sampler_list: + nsamples_list.append(sampler.nsamples) + if self.pt_rejection_sample: + for samp in self.sampler_list[1:]: + nsamples_list.append( + len(samp.rejection_sample_zero_temperature_samples()) + ) + return sum(nsamples_list) + + @property + def samples(self): + sample_list = [] + for sampler in self.zerotemp_sampler_list: + sample_list.append(sampler.samples) + if self.pt_rejection_sample: + for sampler in self.tempered_sampler_list: + sample_list.append(sampler.samples) + return pd.concat(sample_list) + + @property + def position(self): + return self.primary_sampler.chain.position + + @property + def evaluations(self): + return int(self.position * len(self.sampler_list)) + + def __getitem__(self, index): + return self.sampler_list[index] + + def step_all_chains(self): + if self.pool: + self.sampler_list = self.pool.map(call_step, self.sampler_list) + else: + for ii, sampler in enumerate(self.sampler_list): + self.sampler_list[ii] = sampler.step() + + if self.nensemble > 1 and self.swap_counter["L2-ensemble"] >= self.L2steps: + self.swap_counter["ensemble"] += 1 + self.swap_counter["L2-ensemble"] = 0 + 
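# Propose ensemble moves, using the other chains at this
+            # temperature as the complement
+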
self.ensemble_step() + + if self.ntemps > 1 and self.swap_counter["L2-temperature"] >= self.L2steps: + self.swap_counter["temperature"] += 1 + self.swap_counter["L2-temperature"] = 0 + self.swap_tempered_chains() + if self.position < self.adapt_t0 * 10: + if self.adapt: + self.adapt_temperatures() + elif self.adapt: + logger.info( + f"Adaptation of temperature chains finished at step {self.position}" + ) + self.adapt = False + + self.swap_counter["L2-ensemble"] += 1 + self.swap_counter["L2-temperature"] += 1 + + @staticmethod + def _get_sample_to_swap(sampler): + if sampler.chain.converged is False: + v = sampler.chain[-1] + else: + v = sampler.chain.random_sample + logl = v[LOGLKEY] + return v, logl + + def swap_tempered_chains(self): + if self.pt_ensemble: + Eindexs = range(self.nensemble) + else: + Eindexs = [0] + for Eindex in Eindexs: + for Tindex in range(self.ntemps - 1): + sampleri = self.sampler_dictionary[Tindex][Eindex] + vi, logli = self._get_sample_to_swap(sampleri) + betai = sampleri.beta + + samplerj = self.sampler_dictionary[Tindex + 1][Eindex] + vj, loglj = self._get_sample_to_swap(samplerj) + betaj = samplerj.beta + + dbeta = betaj - betai + with np.errstate(over="ignore"): + alpha_swap = np.exp(dbeta * (logli - loglj)) + + if np.random.uniform(0, 1) <= alpha_swap: + sampleri.chain[-1] = vj + samplerj.chain[-1] = vi + self.sampler_dictionary[Tindex][Eindex] = sampleri + self.sampler_dictionary[Tindex + 1][Eindex] = samplerj + sampleri.pt_accepted += 1 + else: + sampleri.pt_rejected += 1 + + def ensemble_step(self): + for Tindex, sampler_list in self.sampler_dictionary.items(): + if len(sampler_list) > 1: + for Eindex, sampler in enumerate(sampler_list): + curr = sampler.chain.current_sample + proposal = self.ensemble_proposal_cycle.get_proposal() + complement = [s.chain for s in sampler_list if s != sampler] + prop, log_factor = proposal(sampler.chain, complement) + logp = sampler.log_prior(prop) + + if logp == -np.inf: + sampler.reject_proposal(curr, proposal) + self.sampler_dictionary[Tindex][Eindex] = sampler + continue + + prop[LOGPKEY] = logp + prop[LOGLKEY] = sampler.log_likelihood(prop) + alpha = np.exp( + log_factor + + sampler.beta * prop[LOGLKEY] + + prop[LOGPKEY] + - sampler.beta * curr[LOGLKEY] + - curr[LOGPKEY] + ) + + if np.random.uniform(0, 1) <= alpha: + sampler.accept_proposal(prop, proposal) + else: + sampler.reject_proposal(curr, proposal) + self.sampler_dictionary[Tindex][Eindex] = sampler + + def adapt_temperatures(self): + """Adapt the temperature of the chains + + Using the dynamic temperature selection described in arXiv:1501.05823, + adapt the chains to target a constant swap ratio. This method is based + on github.com/willvousden/ptemcee/tree/master/ptemcee + """ + + self.primary_sampler.chain.minimum_index_adapt = self.position + tt = self.swap_counter["temperature"] + for sampler_list in self.sampler_list_of_tempered_lists: + betas = np.array([s.beta for s in sampler_list]) + ratios = np.array([s.acceptance_ratio for s in sampler_list[:-1]]) + + # Modulate temperature adjustments with a hyperbolic decay. + decay = self.adapt_t0 / (tt + self.adapt_t0) + kappa = decay / self.adapt_nu + + # Construct temperature adjustments. + dSs = kappa * (ratios[:-1] - ratios[1:]) + + # Compute new ladder (hottest and coldest chains don't move). 
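+            # Following the ptemcee scheme, each log temperature-spacing is
+            # nudged by dS = kappa * (a_i - a_{i+1}), where a_i is the swap
+            # acceptance between rungs i and i+1, so a rung swapping more
+            # readily than its hotter neighbour moves away from it.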
+            deltaTs = np.diff(1 / betas[:-1])
+            deltaTs *= np.exp(dSs)
+            betas[1:-1] = 1 / (np.cumsum(deltaTs) + 1 / betas[0])
+            for sampler, beta in zip(sampler_list, betas):
+                sampler.beta = beta
+
+    @property
+    def ln_z(self):
+        return self.ln_z_dict.get(self.evidence_method, np.nan)
+
+    @property
+    def ln_z_err(self):
+        return self.ln_z_err_dict.get(self.evidence_method, np.nan)
+
+    def compute_evidence(self, outdir, label, make_plots=True):
+        if self.ntemps == 1:
+            return
+        kwargs = dict(outdir=outdir, label=label, make_plots=make_plots)
+        methods = dict(
+            thermodynamic=self.thermodynamic_integration_evidence,
+            stepping_stone=self.stepping_stone_evidence,
+        )
+        for key, method in methods.items():
+            ln_z, ln_z_err = self.compute_evidence_per_ensemble(method, kwargs)
+            self.ln_z_dict[key] = ln_z
+            self.ln_z_err_dict[key] = ln_z_err
+            logger.info(
+                f"Log-evidence of {ln_z:0.2f}+/-{ln_z_err:0.2f} calculated using {key} method"
+            )
+
+    def compute_evidence_per_ensemble(self, method, kwargs):
+        from scipy.special import logsumexp
+
+        if self.ntemps == 1:
+            return np.nan, np.nan
+
+        lnZ_list = []
+        lnZerr_list = []
+        for index, ptchain in enumerate(self.sampler_list_of_tempered_lists):
+            lnZ, lnZerr = method(ptchain, **kwargs)
+            lnZ_list.append(lnZ)
+            lnZerr_list.append(lnZerr)
+
+        N = len(lnZ_list)
+
+        # Average lnZ
+        lnZ = logsumexp(lnZ_list, b=1.0 / N)
+
+        # Propagate uncertainty in combined evidence
+        lnZerr = 0.5 * logsumexp(2 * np.array(lnZerr_list), b=1.0 / N)
+
+        return lnZ, lnZerr
+
+    def thermodynamic_integration_evidence(
+        self, ptchain, outdir, label, make_plots=True
+    ):
+        """Computes the evidence using thermodynamic integration
+
+        The evidence is computed from the post burn-in samples, without
+        thinning
+        """
+        from scipy.stats import sem
+
+        betas = []
+        mean_lnlikes = []
+        sem_lnlikes = []
+        for sampler in ptchain:
+            lnlikes = sampler.chain.get_1d_array(LOGLKEY)
+            mindex = sampler.chain.minimum_index
+            lnlikes = lnlikes[mindex:]
+            mean_lnlikes.append(np.mean(lnlikes))
+            sem_lnlikes.append(sem(lnlikes))
+            betas.append(sampler.beta)
+
+        # Convert to array and re-order
+        betas = np.array(betas)[::-1]
+        mean_lnlikes = np.array(mean_lnlikes)[::-1]
+        sem_lnlikes = np.array(sem_lnlikes)[::-1]
+
+        lnZ, lnZerr = self._compute_evidence_from_mean_lnlikes(betas, mean_lnlikes)
+
+        if make_plots:
+            plot_label = f"{label}_E{ptchain[0].Eindex}"
+            self._create_lnZ_plots(
+                betas=betas,
+                mean_lnlikes=mean_lnlikes,
+                outdir=outdir,
+                label=plot_label,
+                sem_lnlikes=sem_lnlikes,
+            )
+
+        return lnZ, lnZerr
+
+    def stepping_stone_evidence(self, ptchain, outdir, label, make_plots=True):
+        """
+        Compute the evidence using the stepping stone approximation.
+
+        See https://arxiv.org/abs/1810.04488 and
+        https://pubmed.ncbi.nlm.nih.gov/21187451/ for details.
+
+        The uncertainty is estimated by block-bootstrapping the thinned
+        log-likelihood samples, following Maturana-Russel et al. (2019).
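+        In outline: with inverse temperatures beta_1 < ... < beta_K = 1, each
+        ratio r_k = Z(beta_{k+1}) / Z(beta_k) is estimated from the N thinned
+        samples of the beta_k chain as
+        r_k ~ (1/N) * sum_j exp[(beta_{k+1} - beta_k) * logL_j],
+        and ln Z is the sum over k of ln r_k.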
+
+        Returns
+        -------
+        ln_z: float
+            Estimate of the natural log evidence
+        ln_z_err: float
+            Estimate of the uncertainty in the evidence
+        """
+        # Order in increasing beta
+        ptchain.reverse()
+
+        # Get maximum usable set of samples across the ptchain
+        min_index = max([samp.chain.minimum_index for samp in ptchain])
+        max_index = min([len(samp.chain.get_1d_array(LOGLKEY)) for samp in ptchain])
+        tau = self.tau
+
+        if max_index - min_index <= 1 or np.isinf(tau):
+            return np.nan, np.nan
+
+        # Read in log likelihoods
+        ln_likes = np.array(
+            [samp.chain.get_1d_array(LOGLKEY)[min_index:max_index] for samp in ptchain]
+        )[:-1].T
+
+        # Thin to only independent samples
+        ln_likes = ln_likes[:: int(self.tau), :]
+        steps = ln_likes.shape[0]
+
+        # Calculate delta betas
+        betas = np.array([samp.beta for samp in ptchain])
+
+        ln_z, ln_ratio = self._calculate_stepping_stone(betas, ln_likes)
+
+        # Implementation of the bootstrap method described in Maturana-Russel
+        # et al. (2019) to estimate the evidence uncertainty.
+        ll = 50  # Block length
+        repeats = 100  # Number of bootstrap realisations
+        ln_z_realisations = []
+        try:
+            for _ in range(repeats):
+                idxs = [np.random.randint(i, i + ll) for i in range(steps - ll)]
+                ln_z_realisations.append(
+                    self._calculate_stepping_stone(betas, ln_likes[idxs, :])[0]
+                )
+            ln_z_err = np.std(ln_z_realisations)
+        except ValueError:
+            logger.info("Failed to estimate stepping stone uncertainty")
+            ln_z_err = np.nan
+
+        if make_plots:
+            plot_label = f"{label}_E{ptchain[0].Eindex}"
+            self._create_stepping_stone_plot(
+                means=ln_ratio,
+                outdir=outdir,
+                label=plot_label,
+            )
+
+        return ln_z, ln_z_err
+
+    @staticmethod
+    def _calculate_stepping_stone(betas, ln_likes):
+        from scipy.special import logsumexp
+
+        n_samples = ln_likes.shape[0]
+        d_betas = betas[1:] - betas[:-1]
+        ln_ratio = logsumexp(d_betas * ln_likes, axis=0) - np.log(n_samples)
+        return sum(ln_ratio), ln_ratio
+
+    @staticmethod
+    def _compute_evidence_from_mean_lnlikes(betas, mean_lnlikes):
+        lnZ = np.trapz(mean_lnlikes, betas)
+        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1], betas[::-1][::2][::-1])
+        lnZerr = np.abs(lnZ - z2)
+        return lnZ, lnZerr
+
+    def _create_lnZ_plots(self, betas, mean_lnlikes, outdir, label, sem_lnlikes=None):
+        import matplotlib.pyplot as plt
+
+        logger.debug("Creating thermodynamic evidence diagnostic plot")
+
+        fig, ax1 = plt.subplots()
+        if betas[-1] == 0:
+            x, y = betas[:-1], mean_lnlikes[:-1]
+        else:
+            x, y = betas, mean_lnlikes
+        if sem_lnlikes is not None:
+            ax1.errorbar(x, y, sem_lnlikes, fmt="-")
+        else:
+            ax1.plot(x, y, "-o")
+        ax1.set_xscale("log")
+        ax1.set_xlabel(r"$\beta$")
+        ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")
+
+        plt.tight_layout()
+        fig.savefig("{}/{}_beta_lnl.png".format(outdir, label))
+        plt.close()
+
+    def _create_stepping_stone_plot(self, means, outdir, label):
+        import matplotlib.pyplot as plt
+
+        logger.debug("Creating stepping stone evidence diagnostic plot")
+
+        n_steps = len(means)
+
+        fig, axes = plt.subplots(nrows=2, figsize=(8, 10))
+
+        ax = axes[0]
+        ax.plot(np.arange(1, n_steps + 1), means)
+        ax.set_xlabel("$k$")
+        ax.set_ylabel("$r_{k}$")
+
+        ax = axes[1]
+        ax.plot(np.arange(1, n_steps + 1), np.cumsum(means))
+        ax.set_xlabel("$k$")
+        ax.set_ylabel("Cumulative $\\ln Z$")
+
+        plt.tight_layout()
+        fig.savefig("{}/{}_stepping_stone.png".format(outdir, label))
+        plt.close()
+
+    @property
+    def rejection_sampling_count(self):
+        if self.pt_rejection_sample:
+            counts = 0
+            for column in self.sampler_list_of_tempered_lists:
+                for sampler in column:
+                    counts
+= sampler.rejection_sampling_count + return counts + else: + return None + + +class BilbyMCMCSampler(object): + def __init__( + self, + convergence_inputs, + proposal_cycle=None, + beta=1, + Tindex=0, + Eindex=0, + use_ratio=False, + ): + self.beta = beta + self.Tindex = Tindex + self.Eindex = Eindex + self.use_ratio = use_ratio + + self.parameters = _priors.non_fixed_keys + self.ndim = len(self.parameters) + + full_sample_dict = _priors.sample() + initial_sample = { + k: v for k, v in full_sample_dict.items() if k in _priors.non_fixed_keys + } + initial_sample = Sample(initial_sample) + initial_sample[LOGLKEY] = self.log_likelihood(initial_sample) + initial_sample[LOGPKEY] = self.log_prior(initial_sample) + + self.chain = Chain(initial_sample=initial_sample) + self.set_convergence_inputs(convergence_inputs) + + self.accepted = 0 + self.rejected = 0 + self.pt_accepted = 0 + self.pt_rejected = 0 + self.rejection_sampling_count = 0 + + if isinstance(proposal_cycle, str): + # Only print warnings for the primary sampler + if Tindex == 0 and Eindex == 0: + warn = True + else: + warn = False + + self.proposal_cycle = proposals.get_proposal_cycle( + proposal_cycle, _priors, L1steps=self.chain.L1steps, warn=warn + ) + elif isinstance(proposal_cycle, proposals.ProposalCycle): + self.proposal_cycle = proposal_cycle + else: + raise SamplerError("Proposal cycle not understood") + + if self.Tindex == 0 and self.Eindex == 0: + logger.info(f"Using {self.proposal_cycle}") + + def set_convergence_inputs(self, convergence_inputs): + for key, val in convergence_inputs._asdict().items(): + setattr(self.chain, key, val) + self.target_nsamples = convergence_inputs.target_nsamples + self.stop_after_convergence = convergence_inputs.stop_after_convergence + + def log_likelihood(self, sample): + _likelihood.parameters.update(sample.sample_dict) + + if self.use_ratio: + logl = _likelihood.log_likelihood_ratio() + else: + logl = _likelihood.log_likelihood() + + return logl + + def log_prior(self, sample): + return _priors.ln_prob(sample.parameter_only_dict) + + def accept_proposal(self, prop, proposal): + self.chain.append(prop) + self.accepted += 1 + proposal.accepted += 1 + + def reject_proposal(self, curr, proposal): + self.chain.append(curr) + self.rejected += 1 + proposal.rejected += 1 + + def step(self): + if self.stop_after_convergence and self.chain.converged: + return self + + internal_steps = 0 + internal_accepted = 0 + internal_rejected = 0 + curr = self.chain.current_sample.copy() + while internal_steps < self.chain.L1steps: + internal_steps += 1 + proposal = self.proposal_cycle.get_proposal() + prop, log_factor = proposal(self.chain) + logp = self.log_prior(prop) + + if np.isinf(logp) or np.isnan(logp): + internal_rejected += 1 + proposal.rejected += 1 + continue + + prop[LOGPKEY] = logp + prop[LOGLKEY] = self.log_likelihood(prop) + + if np.isinf(prop[LOGLKEY]) or np.isnan(prop[LOGLKEY]): + internal_rejected += 1 + proposal.rejected += 1 + continue + + with np.errstate(over="ignore"): + alpha = np.exp( + log_factor + + self.beta * prop[LOGLKEY] + + prop[LOGPKEY] + - self.beta * curr[LOGLKEY] + - curr[LOGPKEY] + ) + + if np.random.uniform(0, 1) <= alpha: + internal_accepted += 1 + proposal.accepted += 1 + curr = prop + self.chain.current_sample = curr + else: + internal_rejected += 1 + proposal.rejected += 1 + + self.chain.append(curr) + self.rejected += internal_rejected + self.accepted += internal_accepted + return self + + @property + def nsamples(self): + nsamples = self.chain.nsamples + if nsamples 
> self.target_nsamples and self.chain.converged is False: + logger.debug(f"Temperature {self.Tindex} chain reached convergence") + self.chain.converged = True + return nsamples + + @property + def acceptance_ratio(self): + return self.accepted / (self.accepted + self.rejected) + + @property + def samples(self): + if self.beta == 1: + return self.chain.samples + else: + return self.rejection_sample_zero_temperature_samples(print_message=True) + + def rejection_sample_zero_temperature_samples(self, print_message=False): + beta = self.beta + chain = self.chain + hot_samples = pd.DataFrame( + chain._chain_array[chain.minimum_index : chain.position], columns=chain.keys + ) + if len(hot_samples) == 0: + logger.debug( + f"Rejection sampling for Temp {self.Tindex} failed: " + "no usable hot samples" + ) + return hot_samples + + # Pull out log likelihood + zerotemp_logl = hot_samples[LOGLKEY] + + # Revert to true likelihood if needed + if _use_ratio: + zerotemp_logl += _likelihood.noise_log_likelihood() + + # Calculate normalised weights + log_weights = (1 - beta) * zerotemp_logl + max_weight = np.max(log_weights) + unnormalised_weights = np.exp(log_weights - max_weight) + weights = unnormalised_weights / np.sum(unnormalised_weights) + + # Rejection sample + samples = rejection_sample(hot_samples, weights) + + # Logging + self.rejection_sampling_count = len(samples) + + if print_message: + logger.info( + f"Rejection sampling Temp {self.Tindex}, beta={beta:0.2f} " + f"yielded {len(samples)} samples" + ) + return samples + + +# Methods used to aid parallelisation: + + +def call_step(sampler): + sampler = sampler.step() + return sampler + + +_likelihood = None +_priors = None +_search_parameter_keys = None +_use_ratio = False + + +def _initialize_global_variables( + likelihood, + priors, + search_parameter_keys, + use_ratio, +): + """ + Store a global copy of the likelihood, priors, and search keys for + multiprocessing. 
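+    Each pool worker runs this once (via the Pool initializer), so that
+    likelihood evaluations use worker-local globals rather than pickling
+    the likelihood into every sampler passed between processes.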
+ """ + global _likelihood + global _priors + global _search_parameter_keys + global _use_ratio + _likelihood = likelihood + _priors = priors + _search_parameter_keys = search_parameter_keys + _use_ratio = use_ratio diff --git a/bilby/bilby_mcmc/utils.py b/bilby/bilby_mcmc/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..25c8ba2327d7bacc9927c6796aeeb48b09647155 --- /dev/null +++ b/bilby/bilby_mcmc/utils.py @@ -0,0 +1,38 @@ +from collections import namedtuple + +LOGLKEY = "logl" +LOGLLATEXKEY = r"$\log\mathcal{L}$" +LOGPKEY = "logp" +LOGPLATEXKEY = r"$\log\pi$" + +ConvergenceInputs = namedtuple( + "ConvergenceInputs", + [ + "autocorr_c", + "burn_in_nact", + "thin_by_nact", + "fixed_discard", + "target_nsamples", + "stop_after_convergence", + "L1steps", + "L2steps", + "min_tau", + "fixed_tau", + "tau_window", + ], +) + +ParallelTemperingInputs = namedtuple( + "ParallelTemperingInputs", + [ + "ntemps", + "nensemble", + "Tmax", + "Tmax_from_SNR", + "initial_betas", + "adapt", + "adapt_t0", + "adapt_nu", + "pt_ensemble", + ], +) diff --git a/bilby/core/result.py b/bilby/core/result.py index 00b66869e14586eb330da5d7a4fd6447f9e941bc..0c234df022ac513507f3a4547c07a29a93a55bbf 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -1776,7 +1776,7 @@ class ResultList(list): # check which kind of sampler was used: MCMC or Nested Sampling if result._nested_samples is not None: posteriors, result = self._combine_nested_sampled_runs(result) - elif result.sampler in ["bilbymcmc"]: + elif result.sampler in ["bilby_mcmc", "bilbymcmc"]: posteriors, result = self._combine_mcmc_sampled_runs(result) else: posteriors = [res.posterior for res in self] diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 93202387f35767e4e58dad4298ca309c4b441506..e586f138510cf5c8637b3faa27b4d3753a26b131 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -22,10 +22,11 @@ from .pymultinest import Pymultinest from .ultranest import Ultranest from .fake_sampler import FakeSampler from .dnest4 import DNest4 +from bilby.bilby_mcmc import Bilby_MCMC from . 
import proposal IMPLEMENTED_SAMPLERS = { - 'cpnest': Cpnest, 'dnest4': DNest4, 'dynamic_dynesty': DynamicDynesty, + 'bilby_mcmc': Bilby_MCMC, 'cpnest': Cpnest, 'dnest4': DNest4, 'dynamic_dynesty': DynamicDynesty, 'dynesty': Dynesty, 'emcee': Emcee,'kombine': Kombine, 'nessai': Nessai, 'nestle': Nestle, 'ptemcee': Ptemcee, 'ptmcmcsampler': PTMCMCSampler, 'pymc3': Pymc3, 'pymultinest': Pymultinest, 'pypolychord': PyPolyChord, diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 548e1a9b35c6b1e5a5e86264ad422d00844d28cf..5e25a7e54247fa0dcc7f5e9841e308bbfd071173 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -668,6 +668,10 @@ class SamplerError(Error): """ Base class for Error related to samplers in this module """ +class ResumeError(Error): + """ Class for errors arising from resuming runs """ + + class SamplerNotInstalledError(SamplerError): """ Base class for Error raised by not installed samplers """ diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index 080e591f5c181aa7460356ce9ed34f093fa74521..49f77b7739baf8bf3490185bad3bdf67a30b699c 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -29,6 +29,7 @@ class BilbyJsonEncoder(json.JSONEncoder): def default(self, obj): from ..prior import MultivariateGaussianDist, Prior, PriorDict from ...gw.prior import HealPixMapPriorDist + from ...bilby_mcmc.proposals import ProposalCycle if isinstance(obj, np.integer): return int(obj) if isinstance(obj, np.floating): @@ -39,6 +40,8 @@ class BilbyJsonEncoder(json.JSONEncoder): return {'__prior__': True, '__module__': obj.__module__, '__name__': obj.__class__.__name__, 'kwargs': dict(obj.get_instantiation_dict())} + if isinstance(obj, ProposalCycle): + return str(obj) try: from astropy import cosmology as cosmo, units if isinstance(obj, cosmo.FLRW): diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 291cf412da04f2626de4d24dc4a6216c89438106..23356d48970b911d0d1d85923010e7124a005dba 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -957,7 +957,8 @@ class CalibrationPriorDict(PriorDict): @staticmethod def from_envelope_file(envelope_file, minimum_frequency, - maximum_frequency, n_nodes, label): + maximum_frequency, n_nodes, label, + boundary="reflective"): """ Load in the calibration envelope. @@ -980,6 +981,8 @@ class CalibrationPriorDict(PriorDict): Number of nodes for the spline. 
        label: str
            Label for the names of the parameters, e.g., `recalib_H1_`
+        boundary: None, 'reflective', 'periodic'
+            The type of prior boundary to assign
 
         Returns
         =======
@@ -1013,14 +1016,14 @@
             prior[name] = Gaussian(mu=amplitude_mean_nodes[ii],
                                    sigma=amplitude_sigma_nodes[ii],
                                    name=name, latex_label=latex_label,
-                                   boundary='reflective')
+                                   boundary=boundary)
         for ii in range(n_nodes):
             name = "recalib_{}_phase_{}".format(label, ii)
             latex_label = "$\\phi^{}_{}$".format(label, ii)
             prior[name] = Gaussian(mu=phase_mean_nodes[ii],
                                    sigma=phase_sigma_nodes[ii],
                                    name=name, latex_label=latex_label,
-                                   boundary='reflective')
+                                   boundary=boundary)
         for ii in range(n_nodes):
             name = "recalib_{}_frequency_{}".format(label, ii)
             latex_label = "$f^{}_{}$".format(label, ii)
diff --git a/optional_requirements.txt b/optional_requirements.txt
index aad48e7d087f2717a6f42483fc95805d3e763fd3..86acd396e5bba1759be86b07ade32ff1b3dfad60 100644
--- a/optional_requirements.txt
+++ b/optional_requirements.txt
@@ -5,3 +5,5 @@ theano
 plotly
 tables
 pyfftw
+scikit-learn
+nflows
diff --git a/setup.py b/setup.py
index 575b97ce124f0ba2e93f4fbd29268013588a333f..49747e8bf0a8e619840aff5f55e0a623d1de102e 100644
--- a/setup.py
+++ b/setup.py
@@ -85,7 +85,8 @@ setup(name='bilby',
       version=VERSION,
       packages=['bilby', 'bilby.core', 'bilby.core.prior', 'bilby.core.sampler',
                 'bilby.core.utils', 'bilby.gw', 'bilby.gw.detector',
-                'bilby.gw.sampler', 'bilby.hyper', 'bilby.gw.eos', 'cli_bilby'],
+                'bilby.gw.sampler', 'bilby.hyper', 'bilby.gw.eos', 'bilby.bilby_mcmc',
+                'cli_bilby'],
       package_dir={'bilby': 'bilby', 'cli_bilby': 'cli_bilby'},
       package_data={'bilby.gw': ['prior_files/*'],
                     'bilby.gw.detector': ['noise_curves/*.txt', 'detectors/*'],
diff --git a/test/bilby_mcmc/chain.py b/test/bilby_mcmc/chain.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8017dc814369fedeea6f4cc4f3121971fb7b69a
--- /dev/null
+++ b/test/bilby_mcmc/chain.py
@@ -0,0 +1,227 @@
+import os
+import shutil
+import unittest
+
+import bilby
+from bilby.bilby_mcmc.chain import Chain, Sample, calculate_tau
+from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY
+from bilby.core.sampler.base_sampler import SamplerError
+import numpy as np
+import pandas as pd
+
+
+class TestChain(unittest.TestCase):
+    def setUp(self):
+        self.initial_sample = self.create_random_sample()
+        self.outdir = "chain_test"
+        if os.path.isdir(self.outdir) is False:
+            os.mkdir(self.outdir)
+
+    def tearDown(self):
+        if os.path.isdir(self.outdir):
+            shutil.rmtree(self.outdir)
+
+    def create_random_sample(self):
+        return Sample({
+            "a": np.random.normal(0, 1),
+            "b": np.random.normal(0, 1),
+            LOGLKEY: np.random.normal(0, 1),
+            LOGPKEY: -1
+        })
+
+    def create_chain(self, n=1000):
+        chain = Chain(initial_sample=self.initial_sample)
+        for i in range(n):
+            chain.append(self.create_random_sample())
+        return chain
+
+    def test_initialize(self):
+        chain = Chain(initial_sample=self.initial_sample)
+        self.assertEqual(chain.position, 0)
+
+    def test_append(self):
+        chain = Chain(initial_sample=self.initial_sample)
+        chain.append(self.create_random_sample())
+        self.assertEqual(chain.position, 1)
+        self.assertEqual(len(chain.get_1d_array('a')), 2)
+
+    def test_append_within_init_space(self):
+        chain = Chain(initial_sample=self.initial_sample)
+        N = chain.block_length - 1
+        for i in range(N):
+            chain.append(self.create_random_sample())
+
+        self.assertEqual(chain.position, N)
+
+        # N samples + 1 initial position
+        self.assertEqual(len(chain.get_1d_array('a')), N +
1)
+
+    def test_append_with_extending(self):
+        block_length = 100
+        chain = Chain(initial_sample=self.initial_sample, block_length=block_length)
+
+        # Check the array is the block length
+        self.assertEqual(len(chain._chain_array), block_length)
+        for i in range(3 * block_length):
+            chain.append(self.create_random_sample())
+
+        # Check the array is now longer than the block length (successfully extended)
+        self.assertEqual(len(chain._chain_array), 4 * block_length)
+
+    def test_get_item(self):
+        chain = self.create_chain()
+        tenth_sample = chain[10]
+        self.assertTrue(isinstance(tenth_sample, Sample))
+
+        last_sample = chain[-1]
+        self.assertEqual(last_sample, chain.current_sample)
+
+        with self.assertRaises(SamplerError):
+            chain[chain.position + 10]
+
+    def test_set_item(self):
+        chain = self.create_chain()
+        s = self.create_random_sample()
+
+        chain[10] = s
+        self.assertEqual(s, chain[10])
+
+        chain[-1] = s
+        self.assertEqual(s, chain[-1])
+
+    def test_random_sample(self):
+        chain = self.create_chain()
+        c1 = chain.random_sample
+        c2 = chain.random_sample
+        self.assertNotEqual(c1, c2)
+
+    def test_fixed_discard(self):
+        chain = self.create_chain()
+        self.assertEqual(chain.fixed_discard, 0)
+        chain.fixed_discard = 10
+        self.assertEqual(chain.fixed_discard, 10)
+
+    def test_minimum_index(self):
+        chain = self.create_chain()
+        self.assertEqual(chain.minimum_index, 0)
+
+        chain._last_minimum_index = (chain.position, 10, "I")
+        self.assertEqual(chain.minimum_index, 10)
+        chain._last_minimum_index = (0, 0, "I")
+
+        chain.fixed_discard = 200
+        self.assertEqual(chain.minimum_index, 200)
+        chain._last_minimum_index = (0, 0, "I")
+
+        chain.fixed_discard = 100000
+        self.assertEqual(chain.minimum_index, 100000)
+
+    def test_tau(self):
+        chain = self.create_chain(n=1000)
+        self.assertGreaterEqual(chain.tau, chain.min_tau)
+        self.assertLess(chain.tau, np.inf)
+        self.assertEqual(chain.tau, chain.max_tau_dict[chain.position])
+        self.assertEqual(chain.tau, chain.tau_last)
+
+        # Check the cached tau calc works
+        for i in range(5):
+            chain.append(self.create_random_sample())
+        chain.tau
+        self.assertEqual(chain.cached_tau_count, 1)
+
+    def test_nsamples(self):
+        chain = self.create_chain(n=1000)
+        self.assertGreaterEqual(chain.nsamples, 1)
+        self.assertLessEqual(chain.nsamples, chain.position)
+
+    def test_thin(self):
+        chain = self.create_chain(n=1000)
+        self.assertEqual(chain.thin, int(chain.thin_by_nact * chain.tau))
+
+    def test_samples(self):
+        chain = self.create_chain(n=1000)
+        samples = chain.samples
+        self.assertTrue(isinstance(samples, pd.DataFrame))
+        self.assertTrue("a" in samples)
+        self.assertTrue("b" in samples)
+        self.assertTrue(LOGLKEY in samples)
+        self.assertTrue(LOGPKEY in samples)
+
+    def test_plot(self):
+        chain = self.create_chain(n=1000)
+        chain.plot(outdir=self.outdir, label="test")
+        self.assertTrue(os.path.exists(f"{self.outdir}/test_checkpoint_trace.png"))
+        priors = dict(
+            a=bilby.core.prior.Uniform(-10, 10, latex_label='a'),
+            b=bilby.core.prior.Uniform(-10, 10),
+        )
+        chain.thin_by_nact = 0.5
+        chain.plot(outdir=self.outdir, label="test", priors=priors)
+        self.assertTrue(os.path.exists(f"{self.outdir}/test_checkpoint_trace.png"))
+
+
+class TestSample(unittest.TestCase):
+    def setUp(self):
+        self.sample_dict = dict(a=1, b=2)
+
+    def tearDown(self):
+        del self.sample_dict
+
+    def test_init(self):
+        s = Sample(self.sample_dict)
+        self.assertEqual(s.keys, list(self.sample_dict.keys()))
+
+    def test_dict_access(self):
+        s = Sample(self.sample_dict)
+        for key in s.keys:
+
self.assertEqual(s[key], self.sample_dict[key]) + + def test_list_access(self): + s = Sample(self.sample_dict) + slist = s.list + self.assertEqual(slist, [self.sample_dict['a'], self.sample_dict['b']]) + + def test_setitem(self): + s = Sample(self.sample_dict) + + # Set existing parameter + s['a'] = 100 + self.assertEqual(s['a'], 100) + + # Add parameter + s['c'] = 100 + self.assertEqual(s['c'], 100) + + def test_parameter_only_dict(self): + s = Sample(self.sample_dict) + self.assertEqual(s.parameter_only_dict, dict(a=1, b=2)) + + def test_update(self): + sample_dict = dict(a=1, b=2) + curr = Sample(sample_dict) + prop = curr.copy() + prop['a'] = 200 + self.assertEqual(prop['a'], 200) + self.assertEqual(curr['a'], 1) + + +class TestACT(unittest.TestCase): + def test_act_normal(self): + x = np.random.normal(0, 1, 1000) + tau = calculate_tau(x) + self.assertLess(tau, 10) + + def test_act_identical(self): + x = np.array([0] * 1000) + tau = calculate_tau(x) + self.assertEqual(tau, np.inf) + + def test_act_long(self): + t = np.linspace(0, 1, 1000) + x = np.sin(2 * np.pi * t) + tau = calculate_tau(x) + self.assertGreater(tau, 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/bilby_mcmc/proposals.py b/test/bilby_mcmc/proposals.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf71f213e01ddaac2ec90b989e7a6f666d86ae7 --- /dev/null +++ b/test/bilby_mcmc/proposals.py @@ -0,0 +1,187 @@ +import os +import copy +import shutil +import unittest +import inspect +import sys +import time +import bilby +from bilby.bilby_mcmc.chain import Chain, Sample +from bilby.bilby_mcmc import proposals +from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY +import numpy as np + + +class GivenProposal(proposals.BaseProposal): + """ A simple proposal class used for testing """ + def __init__(self, priors, weight=1, subset=None, sigma=0.01): + super(GivenProposal, self).__init__(priors, weight, subset) + + def propose(self, chain): + log_factor = 0 + return self.given_sample, log_factor + + +class TestBaseProposals(unittest.TestCase): + def create_priors(self, ndim=2, boundary=None): + priors = bilby.core.prior.PriorDict({ + f'x{i}': bilby.core.prior.Uniform(-10, 10, name=f'x{i}', boundary=boundary) + for i in range(ndim) + }) + priors["fixedA"] = bilby.core.prior.DeltaFunction(1) + return priors + + def create_random_sample(self, ndim=2): + p = {f"x{i}": np.random.normal(0, 1) for i in range(ndim)} + p[LOGLKEY] = np.random.normal(0, 1) + p[LOGPKEY] = -1 + p["fixedA"] = 1 + return Sample(p) + + def create_chain(self, n=1000, ndim=2): + initial_sample = self.create_random_sample(ndim) + chain = Chain(initial_sample=initial_sample) + for i in range(n): + chain.append(self.create_random_sample(ndim)) + return chain + + def test_GivenProposal(self): + priors = self.create_priors() + chain = self.create_chain() + proposal = GivenProposal(priors) + proposal.given_sample = self.create_random_sample() + prop, _ = proposal(chain) + self.assertEqual(prop, proposal.given_sample) + + def test_noboundary(self): + priors = self.create_priors() + chain = self.create_chain() + proposal = GivenProposal(priors) + + sample = self.create_random_sample() + sample["x0"] = priors["x0"].maximum + 0.5 + proposal.given_sample = sample + + prop, _ = proposal(chain) + self.assertEqual(prop, proposal.given_sample) + self.assertEqual(prop["x0"], priors["x0"].maximum + 0.5) + + def test_periodic_boundary_above(self): + priors = self.create_priors(boundary="periodic") + chain = self.create_chain() + proposal = 
GivenProposal(priors) + + sample = self.create_random_sample() + sample["x0"] = priors["x0"].maximum + 0.5 + proposal.given_sample = copy.deepcopy(sample) + + prop, _ = proposal(chain) + self.assertFalse(prop["x0"] == priors["x0"].maximum + 0.5) + self.assertEqual(prop["x0"], priors["x0"].minimum + 0.5) + + def test_periodic_boundary_below(self): + priors = self.create_priors(boundary="periodic") + chain = self.create_chain() + proposal = GivenProposal(priors) + + sample = self.create_random_sample() + sample["x0"] = priors["x0"].minimum - 0.5 + proposal.given_sample = copy.deepcopy(sample) + + prop, _ = proposal(chain) + self.assertFalse(prop["x0"] == priors["x0"].minimum - 0.5) + self.assertEqual(prop["x0"], priors["x0"].maximum - 0.5) + + +class TestProposals(TestBaseProposals): + def setUp(self): + self.outdir = "chain_test" + if os.path.isdir(self.outdir) is False: + os.mkdir(self.outdir) + + def tearDown(self): + if os.path.isdir(self.outdir): + shutil.rmtree(self.outdir) + + def get_simple_proposals(self): + clsmembers = inspect.getmembers( + sys.modules[proposals.__name__], inspect.isclass + ) + clsmembers_clean = [] + for name, cls in clsmembers: + a = "Proposal" in name + b = "Base" not in name + c = "Ensemble" not in name + d = "Phase" not in name + e = "Polarisation" not in name + f = "Cycle" not in name + g = "KDE" not in name + h = "NormalizingFlow" not in name + if a * b * c * d * e * f * g * h: + clsmembers_clean.append((name, cls)) + + return clsmembers_clean + + def proposal_check(self, prop, ndim=2, N=100): + chain = self.create_chain(ndim=ndim) + + print(f"Testing {prop.__class__.__name__}") + # Timing and return type + start = time.time() + for _ in range(N): + p, w = prop(chain) + chain.append(p) + dt = 1e3 * (time.time() - start) / N + print(f"Testing {prop.__class__.__name__}: dt~{dt:0.2g} [ms]") + + self.assertTrue(isinstance(p, Sample)) + self.assertTrue(isinstance(w, (int, float))) + + def test_proposal_return_type(self): + priors = self.create_priors() + for name, cls in self.get_simple_proposals(): + prop = cls(priors) + self.proposal_check(prop) + + def test_KDE_proposal(self): + priors = self.create_priors() + prop = proposals.KDEProposal(priors) + self.proposal_check(prop, N=20000) + self.assertTrue(prop.trained) + + def test_GMM_proposal(self): + priors = self.create_priors() + prop = proposals.GMMProposal(priors) + self.proposal_check(prop, N=20000) + self.assertTrue(prop.trained) + + def test_NF_proposal(self): + priors = self.create_priors() + chain = self.create_chain(10000) + prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) + prop.steps_since_refit = 9999 + start = time.time() + p, w = prop(chain) + dt = time.time() - start + print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") + self.assertTrue(prop.trained) + + self.proposal_check(prop) + + def test_NF_proposal_15D(self): + ndim = 15 + priors = self.create_priors(ndim) + chain = self.create_chain(10000, ndim=ndim) + prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) + prop.steps_since_refit = 9999 + start = time.time() + p, w = prop(chain) + dt = time.time() - start + print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") + self.assertTrue(prop.trained) + + self.proposal_check(prop, ndim=ndim) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/bilby_mcmc/sampler.py b/test/bilby_mcmc/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b5ff98b7f40445a07d4d3aeae682cf662bc95d --- /dev/null +++ 
b/test/bilby_mcmc/sampler.py
@@ -0,0 +1,86 @@
+import os
+import shutil
+import unittest
+
+import bilby
+from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler, _initialize_global_variables
+from bilby.bilby_mcmc.utils import ConvergenceInputs
+from bilby.core.sampler.base_sampler import SamplerError
+import numpy as np
+import pandas as pd
+
+
+class TestBilbyMCMCSampler(unittest.TestCase):
+    def setUp(self):
+        default_kwargs = Bilby_MCMC.default_kwargs
+        default_kwargs["target_nsamples"] = 100
+        default_kwargs["L1steps"] = 1
+        self.convergence_inputs = ConvergenceInputs(
+            **{key: default_kwargs[key] for key in ConvergenceInputs._fields}
+        )
+
+        self.outdir = "bilby_mcmc_sampler_test"
+        if os.path.isdir(self.outdir) is False:
+            os.mkdir(self.outdir)
+
+        def model(time, m, c):
+            return time * m + c
+
+        injection_parameters = dict(m=0.5, c=0.2)
+        sampling_frequency = 10
+        time_duration = 10
+        time = np.arange(0, time_duration, 1 / sampling_frequency)
+        N = len(time)
+        sigma = np.random.normal(1, 0.01, N)
+        data = model(time, **injection_parameters) + np.random.normal(0, sigma, N)
+        likelihood = bilby.likelihood.GaussianLikelihood(time, data, model, sigma)
+
+        # From here on, the syntax is exactly equivalent to other bilby examples
+        # We make a prior
+        priors = dict()
+        priors['m'] = bilby.core.prior.Uniform(0, 5, 'm')
+        priors['c'] = bilby.core.prior.Uniform(-2, 2, 'c')
+        priors = bilby.core.prior.PriorDict(priors)
+
+        search_parameter_keys = ['m', 'c']
+        use_ratio = False
+
+        _initialize_global_variables(likelihood, priors, search_parameter_keys, use_ratio)
+
+    def tearDown(self):
+        if os.path.isdir(self.outdir):
+            shutil.rmtree(self.outdir)
+
+    def test_None_proposal_cycle(self):
+        with self.assertRaises(SamplerError):
+            BilbyMCMCSampler(
+                convergence_inputs=self.convergence_inputs,
+                proposal_cycle=None,
+                beta=1,
+                Tindex=0,
+                Eindex=0,
+                use_ratio=False
+            )
+
+    def test_default_proposal_cycle(self):
+        sampler = BilbyMCMCSampler(
+            convergence_inputs=self.convergence_inputs,
+            proposal_cycle="default_noNFnoGMnoKD",
+            beta=1,
+            Tindex=0,
+            Eindex=0,
+            use_ratio=False
+        )
+
+        nsteps = 0
+        while sampler.nsamples < 500:
+            sampler.step()
+            nsteps += 1
+        self.assertEqual(sampler.chain.position, nsteps)
+        self.assertEqual(sampler.accepted + sampler.rejected, nsteps)
+        self.assertTrue(isinstance(sampler.samples, pd.DataFrame))
+        for prop in sampler.proposal_cycle.proposal_list:
+            self.assertGreater(prop.n, 50)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py
index 46f09bbb6988e839d8c4390f278b470b1d882fab..2b225c01bdf318ff0eadfd85597aeed95d2a9f4c 100644
--- a/test/integration/sampler_run_test.py
+++ b/test/integration/sampler_run_test.py
@@ -23,7 +23,7 @@ class TestRunningSamplers(unittest.TestCase):
         )
         self.priors = bilby.core.prior.PriorDict()
-        self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="reflective")
+        self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
         self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
 
         bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
@@ -189,6 +189,13 @@ class TestRunningSamplers(unittest.TestCase):
             sampler='ultranest', save=False,
         )
 
+    def test_run_bilby_mcmc(self):
+        _ = bilby.run_sampler(
+            likelihood=self.likelihood, priors=self.priors,
+            sampler="bilby_mcmc", nsamples=200, save=False,
+            printdt=1,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()