diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 12a355c322abc32ddc26e258b6529387d2bd57a3..eaf6150f9741f99d0bdda8d326ca3c877bac648e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -156,9 +156,10 @@ python-3.9: stage: test script: - python -m pip install . + - python -m pip install schwimmbad - python -m pip list installed - - pytest test/integration/sampler_run_test.py --durations 10 + - pytest test/integration/sampler_run_test.py --durations 10 -v python-3.8-samplers: <<: *test-sampler diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c2c49cea108d3c1ae687cabc8bbc479cd9b6c596..a78a604a57217f0d579213ef06be28307664da50 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: hooks: - id: black language_version: python3 - files: '(^bilby/bilby_mcmc/|^examples/)' + files: '^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)' - repo: https://github.com/codespell-project/codespell rev: v2.1.0 hooks: @@ -20,7 +20,7 @@ repos: hooks: - id: isort # sort imports alphabetically and separates import into sections args: [-w=88, -m=3, -tc, -sp=setup.cfg ] - files: '(^bilby/bilby_mcmc/|^examples/)' + files: '^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)' - repo: https://github.com/datarootsio/databooks rev: 0.1.14 hooks: diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py index 66e6e4392a9456161c123ff84a1d5d73c4ed216f..ec6c2abaaf424f8904149cebdde60095262fe1a0 100644 --- a/bilby/bilby_mcmc/sampler.py +++ b/bilby/bilby_mcmc/sampler.py @@ -1,6 +1,5 @@ import datetime import os -import signal import time from collections import Counter @@ -8,7 +7,13 @@ import numpy as np import pandas as pd from ..core.result import rejection_sample -from ..core.sampler.base_sampler import MCMCSampler, ResumeError, SamplerError +from ..core.sampler.base_sampler import ( + MCMCSampler, + ResumeError, + SamplerError, + _sampling_convenience_dump, + signal_wrapper, +) from ..core.utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump from . import proposals from .chain import Chain, Sample @@ -131,7 +136,6 @@ class Bilby_MCMC(MCMCSampler): autocorr_c=5, L1steps=100, L2steps=3, - npool=1, printdt=60, min_tau=1, proposal_cycle="default", @@ -172,7 +176,6 @@ class Bilby_MCMC(MCMCSampler): self.check_point_plot = check_point_plot self.diagnostic = diagnostic self.kwargs["target_nsamples"] = self.kwargs["nsamples"] - self.npool = self.kwargs["npool"] self.L1steps = self.kwargs["L1steps"] self.L2steps = self.kwargs["L2steps"] self.pt_inputs = ParallelTemperingInputs( @@ -194,17 +197,6 @@ class Bilby_MCMC(MCMCSampler): self.verify_configuration() self.verbose = verbose - try: - signal.signal(signal.SIGTERM, self.write_current_state_and_exit) - signal.signal(signal.SIGINT, self.write_current_state_and_exit) - signal.signal(signal.SIGALRM, self.write_current_state_and_exit) - except AttributeError: - logger.debug( - "Setting signal attributes unavailable on this system. " - "This is likely the case if you are running on a Windows machine" - " and is no further concern." 
- ) - def verify_configuration(self): if self.convergence_inputs.burn_in_nact / self.kwargs["target_nsamples"] > 0.1: logger.warning("Burn-in inefficiency fraction greater than 10%") @@ -223,6 +215,7 @@ class Bilby_MCMC(MCMCSampler): def target_nsamples(self): return self.kwargs["target_nsamples"] + @signal_wrapper def run_sampler(self): self._setup_pool() self.setup_chain_set() @@ -377,31 +370,12 @@ class Bilby_MCMC(MCMCSampler): f"setup:\n{self.get_setup_string()}" ) - def write_current_state_and_exit(self, signum=None, frame=None): - """ - Make sure that if a pool of jobs is running only the parent tries to - checkpoint and exit. Only the parent has a 'pool' attribute. - """ - if self.npool == 1 or getattr(self, "pool", None) is not None: - if signum == 14: - logger.info( - "Run interrupted by alarm signal {}: checkpoint and exit on {}".format( - signum, self.exit_code - ) - ) - else: - logger.info( - "Run interrupted by signal {}: checkpoint and exit on {}".format( - signum, self.exit_code - ) - ) - self.write_current_state() - self._close_pool() - os._exit(self.exit_code) - def write_current_state(self): import dill + if not hasattr(self, "ptsampler"): + logger.debug("Attempted checkpoint before initialization") + return logger.debug("Check point") check_directory_exists_and_if_not_mkdir(self.outdir) @@ -534,39 +508,6 @@ class Bilby_MCMC(MCMCSampler): all_samples=ptsampler.samples, ) - def _setup_pool(self): - if self.npool > 1: - logger.info(f"Setting up multiproccesing pool with {self.npool} processes") - import multiprocessing - - self.pool = multiprocessing.Pool( - processes=self.npool, - initializer=_initialize_global_variables, - initargs=( - self.likelihood, - self.priors, - self._search_parameter_keys, - self.use_ratio, - ), - ) - else: - self.pool = None - - _initialize_global_variables( - likelihood=self.likelihood, - priors=self.priors, - search_parameter_keys=self._search_parameter_keys, - use_ratio=self.use_ratio, - ) - - def _close_pool(self): - if getattr(self, "pool", None) is not None: - logger.info("Starting to close worker pool.") - self.pool.close() - self.pool.join() - self.pool = None - logger.info("Finished closing worker pool.") - class BilbyPTMCMCSampler(object): def __init__( @@ -579,7 +520,6 @@ class BilbyPTMCMCSampler(object): use_ratio, evidence_method, ): - self.set_pt_inputs(pt_inputs) self.use_ratio = use_ratio self.setup_sampler_dictionary(convergence_inputs, proposal_cycle) @@ -597,7 +537,7 @@ class BilbyPTMCMCSampler(object): self._nsamples_dict = {} self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle( - _priors + _sampling_convenience_dump.priors ) self.sampling_time = 0 self.ln_z_dict = dict() @@ -612,7 +552,7 @@ class BilbyPTMCMCSampler(object): elif pt_inputs.Tmax is not None: betas = np.logspace(0, -np.log10(pt_inputs.Tmax), pt_inputs.ntemps) elif pt_inputs.Tmax_from_SNR is not None: - ndim = len(_priors.non_fixed_keys) + ndim = len(_sampling_convenience_dump.priors.non_fixed_keys) target_hot_likelihood = ndim / 2 Tmax = pt_inputs.Tmax_from_SNR**2 / (2 * target_hot_likelihood) betas = np.logspace(0, -np.log10(Tmax), pt_inputs.ntemps) @@ -1140,12 +1080,14 @@ class BilbyMCMCSampler(object): self.Eindex = Eindex self.use_ratio = use_ratio - self.parameters = _priors.non_fixed_keys + self.parameters = _sampling_convenience_dump.priors.non_fixed_keys self.ndim = len(self.parameters) - full_sample_dict = _priors.sample() + full_sample_dict = _sampling_convenience_dump.priors.sample() initial_sample = { - k: v for k, v in 
full_sample_dict.items() if k in _priors.non_fixed_keys + k: v + for k, v in full_sample_dict.items() + if k in _sampling_convenience_dump.priors.non_fixed_keys } initial_sample = Sample(initial_sample) initial_sample[LOGLKEY] = self.log_likelihood(initial_sample) @@ -1168,7 +1110,10 @@ class BilbyMCMCSampler(object): warn = False self.proposal_cycle = proposals.get_proposal_cycle( - proposal_cycle, _priors, L1steps=self.chain.L1steps, warn=warn + proposal_cycle, + _sampling_convenience_dump.priors, + L1steps=self.chain.L1steps, + warn=warn, ) elif isinstance(proposal_cycle, proposals.ProposalCycle): self.proposal_cycle = proposal_cycle @@ -1185,17 +1130,17 @@ class BilbyMCMCSampler(object): self.stop_after_convergence = convergence_inputs.stop_after_convergence def log_likelihood(self, sample): - _likelihood.parameters.update(sample.sample_dict) + _sampling_convenience_dump.likelihood.parameters.update(sample.sample_dict) if self.use_ratio: - logl = _likelihood.log_likelihood_ratio() + logl = _sampling_convenience_dump.likelihood.log_likelihood_ratio() else: - logl = _likelihood.log_likelihood() + logl = _sampling_convenience_dump.likelihood.log_likelihood() return logl def log_prior(self, sample): - return _priors.ln_prob(sample.parameter_only_dict) + return _sampling_convenience_dump.priors.ln_prob(sample.parameter_only_dict) def accept_proposal(self, prop, proposal): self.chain.append(prop) @@ -1293,8 +1238,10 @@ class BilbyMCMCSampler(object): zerotemp_logl = hot_samples[LOGLKEY] # Revert to true likelihood if needed - if _use_ratio: - zerotemp_logl += _likelihood.noise_log_likelihood() + if _sampling_convenience_dump.use_ratio: + zerotemp_logl += ( + _sampling_convenience_dump.likelihood.noise_log_likelihood() + ) # Calculate normalised weights log_weights = (1 - beta) * zerotemp_logl @@ -1322,29 +1269,3 @@ class BilbyMCMCSampler(object): def call_step(sampler): sampler = sampler.step() return sampler - - -_likelihood = None -_priors = None -_search_parameter_keys = None -_use_ratio = False - - -def _initialize_global_variables( - likelihood, - priors, - search_parameter_keys, - use_ratio, -): - """ - Store a global copy of the likelihood, priors, and search keys for - multiprocessing. - """ - global _likelihood - global _priors - global _search_parameter_keys - global _use_ratio - _likelihood = likelihood - _priors = priors - _search_parameter_keys = search_parameter_keys - _use_ratio = use_ratio diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index bd259f97a74c4784ecf4cf8b5fe2c742e7bc79b0..56c32ed3d88e1c95062954c81f15b78c12b4d792 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -1,15 +1,20 @@ +import datetime import inspect import sys -import datetime import bilby -from ..utils import command_line_args, logger, loaded_modules_dict -from ..prior import PriorDict, DeltaFunction +from bilby.bilby_mcmc import Bilby_MCMC + +from ..prior import DeltaFunction, PriorDict +from ..utils import command_line_args, loaded_modules_dict, logger +from . 
import proposal from .base_sampler import Sampler, SamplingMarginalisedParameterError from .cpnest import Cpnest +from .dnest4 import DNest4 from .dynamic_dynesty import DynamicDynesty from .dynesty import Dynesty from .emcee import Emcee +from .fake_sampler import FakeSampler from .kombine import Kombine from .nessai import Nessai from .nestle import Nestle @@ -19,11 +24,7 @@ from .ptmcmc import PTMCMCSampler from .pymc3 import Pymc3 from .pymultinest import Pymultinest from .ultranest import Ultranest -from .fake_sampler import FakeSampler -from .dnest4 import DNest4 from .zeus import Zeus -from bilby.bilby_mcmc import Bilby_MCMC -from . import proposal IMPLEMENTED_SAMPLERS = { "bilby_mcmc": Bilby_MCMC, @@ -49,7 +50,7 @@ if command_line_args.sampler_help: sampler = command_line_args.sampler_help if sampler in IMPLEMENTED_SAMPLERS: sampler_class = IMPLEMENTED_SAMPLERS[sampler] - print('Help for sampler "{}":'.format(sampler)) + print(f'Help for sampler "{sampler}":') print(sampler_class.__doc__) else: if sampler == "None": @@ -58,8 +59,8 @@ if command_line_args.sampler_help: "the name of the sampler" ) else: - print("Requested sampler {} not implemented".format(sampler)) - print("Available samplers = {}".format(IMPLEMENTED_SAMPLERS)) + print(f"Requested sampler {sampler} not implemented") + print(f"Available samplers = {IMPLEMENTED_SAMPLERS}") sys.exit() @@ -81,7 +82,7 @@ def run_sampler( gzip=False, result_class=None, npool=1, - **kwargs + **kwargs, ): """ The primary interface to easy parameter estimation @@ -144,9 +145,7 @@ def run_sampler( An object containing the results """ - logger.info( - "Running for label '{}', output will be saved to '{}'".format(label, outdir) - ) + logger.info(f"Running for label '{label}', output will be saved to '{outdir}'") if clean: command_line_args.clean = clean @@ -174,7 +173,7 @@ def run_sampler( meta_data = dict() likelihood.label = label likelihood.outdir = outdir - meta_data['likelihood'] = likelihood.meta_data + meta_data["likelihood"] = likelihood.meta_data meta_data["loaded_modules"] = loaded_modules_dict() if command_line_args.bilby_zero_likelihood_mode: @@ -198,11 +197,11 @@ def run_sampler( plot=plot, result_class=result_class, npool=npool, - **kwargs + **kwargs, ) else: print(IMPLEMENTED_SAMPLERS) - raise ValueError("Sampler {} not yet implemented".format(sampler)) + raise ValueError(f"Sampler {sampler} not yet implemented") elif inspect.isclass(sampler): sampler = sampler.__init__( likelihood, @@ -214,12 +213,12 @@ def run_sampler( injection_parameters=injection_parameters, meta_data=meta_data, npool=npool, - **kwargs + **kwargs, ) else: raise ValueError( "Provided sampler should be a Sampler object or name of a known " - "sampler: {}.".format(", ".join(IMPLEMENTED_SAMPLERS.keys())) + f"sampler: {', '.join(IMPLEMENTED_SAMPLERS.keys())}." 
         )
 
     if sampler.cached_result:
@@ -240,23 +239,22 @@
     elif isinstance(result.sampling_time, (float, int)):
         result.sampling_time = datetime.timedelta(result.sampling_time)
 
-    logger.info('Sampling time: {}'.format(result.sampling_time))
+    logger.info(f"Sampling time: {result.sampling_time}")
     # Convert sampling time into seconds
     result.sampling_time = result.sampling_time.total_seconds()
 
     if sampler.use_ratio:
         result.log_noise_evidence = likelihood.noise_log_likelihood()
         result.log_bayes_factor = result.log_evidence
-        result.log_evidence = \
-            result.log_bayes_factor + result.log_noise_evidence
+        result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
     else:
         result.log_noise_evidence = likelihood.noise_log_likelihood()
-        result.log_bayes_factor = \
-            result.log_evidence - result.log_noise_evidence
+        result.log_bayes_factor = result.log_evidence - result.log_noise_evidence
 
     if None not in [result.injection_parameters, conversion_function]:
         result.injection_parameters = conversion_function(
-            result.injection_parameters)
+            result.injection_parameters
+        )
 
     # Initial save of the sampler in case of failure in samples_to_posterior
     if save:
@@ -267,9 +265,12 @@
 
     # Check if the posterior has already been created
     if getattr(result, "_posterior", None) is None:
-        result.samples_to_posterior(likelihood=likelihood, priors=result.priors,
-                                    conversion_function=conversion_function,
-                                    npool=npool)
+        result.samples_to_posterior(
+            likelihood=likelihood,
+            priors=result.priors,
+            conversion_function=conversion_function,
+            npool=npool,
+        )
 
     if save:
         # The overwrite here ensures we overwrite the initially stored data
@@ -277,7 +278,7 @@
     if plot:
         result.plot_corner()
 
-    logger.info("Summary of results:\n{}".format(result))
+    logger.info(f"Summary of results:\n{result}")
     return result
 
 
@@ -286,7 +287,5 @@ def _check_marginalized_parameters_not_sampled(likelihood, priors):
         if key in priors:
             if not isinstance(priors[key], (float, DeltaFunction)):
                 raise SamplingMarginalisedParameterError(
-                    "Likelihood is {} marginalized but you are trying to sample in {}. ".format(
-                        key, key
-                    )
+                    f"Likelihood is {key} marginalized but you are trying to sample in {key}. "
                 )
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 215104a98087bb91abc3964684edf1a0a5d0d458..c30f76045d285d0362e6d972330694202e4db381 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -1,18 +1,111 @@
 import datetime
 import distutils.dir_util
-import numpy as np
 import os
+import shutil
+import signal
+import sys
 import tempfile
+import time
 
+import attr
+import numpy as np
 from pandas import DataFrame
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir, command_line_args, Counter
-from ..prior import Prior, PriorDict, DeltaFunction, Constraint
+from ..prior import Constraint, DeltaFunction, Prior, PriorDict
 from ..result import Result, read_in_result
+from ..utils import (
+    Counter,
+    check_directory_exists_and_if_not_mkdir,
+    command_line_args,
+    logger,
+)
+
+
+@attr.s
+class _SamplingContainer:
+    """
+    A container class for objects that are stored independently in each thread
+    for some samplers.
+
+    A single instance of this will appear in this module and can be accessed
+    by the individual samplers.
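+
+    It is populated by :code:`_initialize_global_variables`, either called
+    directly or used as the initializer of a :code:`multiprocessing.Pool`
+    in :code:`Sampler._setup_pool`.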
+ + This includes the: + + - likelihood (bilby.core.likelihood.Likelihood) + - priors (bilby.core.prior.PriorDict) + - search_parameter_keys (list) + - use_ratio (bool) + """ + + likelihood = attr.ib(default=None) + priors = attr.ib(default=None) + search_parameter_keys = attr.ib(default=None) + use_ratio = attr.ib(default=False) + + +_sampling_convenience_dump = _SamplingContainer() + + +def _initialize_global_variables( + likelihood, + priors, + search_parameter_keys, + use_ratio, +): + """ + Store a global copy of the likelihood, priors, and search keys for + multiprocessing. + """ + global _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood + _sampling_convenience_dump.priors = priors + _sampling_convenience_dump.search_parameter_keys = search_parameter_keys + _sampling_convenience_dump.use_ratio = use_ratio + + +def signal_wrapper(method): + """ + Decorator to wrap a method of a class to set system signals before running + and reset them after. + + Parameters + ========== + method: callable + The method to call, this assumes the first argument is `self` + and that `self` has a `write_current_state_and_exit` method. + + Returns + ======= + output: callable + The wrapped method. + """ + + def wrapped(self, *args, **kwargs): + try: + old_term = signal.signal(signal.SIGTERM, self.write_current_state_and_exit) + old_int = signal.signal(signal.SIGINT, self.write_current_state_and_exit) + old_alarm = signal.signal(signal.SIGALRM, self.write_current_state_and_exit) + _set = True + except (AttributeError, ValueError): + _set = False + logger.debug( + "Setting signal attributes unavailable on this system. " + "This is likely the case if you are running on a Windows machine " + "and can be safely ignored." + ) + output = method(self, *args, **kwargs) + if _set: + signal.signal(signal.SIGTERM, old_term) + signal.signal(signal.SIGINT, old_int) + signal.signal(signal.SIGALRM, old_alarm) + return output + + return wrapped class Sampler(object): - """ A sampler object to aid in setting up an inference run + """A sampler object to aid in setting up an inference run Parameters ========== @@ -76,6 +169,10 @@ class Sampler(object): System exit code to return on interrupt kwargs: dict Dictionary of keyword arguments that can be used in the external sampler + hard_exit: bool + Whether the implemented sampler exits hard (:code:`os._exit` rather + than :code:`sys.exit`). The latter can be escaped as :code:`SystemExit`. + The former cannot. 
Raises ====== @@ -89,15 +186,36 @@ class Sampler(object): If some of the priors can't be sampled """ + default_kwargs = dict() - npool_equiv_kwargs = ['queue_size', 'threads', 'nthreads', 'npool'] + npool_equiv_kwargs = [ + "npool", + "queue_size", + "threads", + "nthreads", + "cores", + "n_pool", + ] + hard_exit = False def __init__( - self, likelihood, priors, outdir='outdir', label='label', - use_ratio=False, plot=False, skip_import_verification=False, - injection_parameters=None, meta_data=None, result_class=None, - likelihood_benchmark=False, soft_init=False, exit_code=130, - **kwargs): + self, + likelihood, + priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + skip_import_verification=False, + injection_parameters=None, + meta_data=None, + result_class=None, + likelihood_benchmark=False, + soft_init=False, + exit_code=130, + npool=1, + **kwargs, + ): self.likelihood = likelihood if isinstance(priors, PriorDict): self.priors = priors @@ -108,6 +226,7 @@ class Sampler(object): self.injection_parameters = injection_parameters self.meta_data = meta_data self.use_ratio = use_ratio + self._npool = npool if not skip_import_verification: self._verify_external_sampler() self.external_sampler_function = None @@ -159,7 +278,7 @@ class Sampler(object): @property def kwargs(self): - """dict: Container for the kwargs. Has more sophisticated logic in subclasses """ + """dict: Container for the kwargs. Has more sophisticated logic in subclasses""" return self._kwargs @kwargs.setter @@ -170,7 +289,7 @@ class Sampler(object): self._verify_kwargs_against_default_kwargs() def _translate_kwargs(self, kwargs): - """ Template for child classes """ + """Template for child classes""" pass @property @@ -180,10 +299,11 @@ class Sampler(object): def _verify_external_sampler(self): external_sampler_name = self.external_sampler_name try: - self.external_sampler = __import__(external_sampler_name) + __import__(external_sampler_name) except (ImportError, SystemExit): raise SamplerNotInstalledError( - "Sampler {} is not installed on this system".format(external_sampler_name)) + f"Sampler {external_sampler_name} is not installed on this system" + ) def _verify_kwargs_against_default_kwargs(self): """ @@ -195,8 +315,8 @@ class Sampler(object): for user_input in self.kwargs.keys(): if user_input not in args: logger.warning( - "Supplied argument '{}' not an argument of '{}', removing." - .format(user_input, self.__class__.__name__)) + f"Supplied argument '{user_input}' not an argument of '{self.__class__.__name__}', removing." + ) bad_keys.append(user_input) for key in bad_keys: self.kwargs.pop(key) @@ -208,8 +328,10 @@ class Sampler(object): the respective parameter is fixed. 
""" for key in self.priors: - if isinstance(self.priors[key], Prior) \ - and self.priors[key].is_fixed is False: + if ( + isinstance(self.priors[key], Prior) + and self.priors[key].is_fixed is False + ): self._search_parameter_keys.append(key) elif isinstance(self.priors[key], Constraint): self._constraint_parameter_keys.append(key) @@ -219,9 +341,9 @@ class Sampler(object): logger.info("Search parameters:") for key in self._search_parameter_keys + self._constraint_parameter_keys: - logger.info(' {} = {}'.format(key, self.priors[key])) + logger.info(f" {key} = {self.priors[key]}") for key in self._fixed_parameter_keys: - logger.info(' {} = {}'.format(key, self.priors[key].peak)) + logger.info(f" {key} = {self.priors[key].peak}") def _initialise_result(self, result_class): """ @@ -231,27 +353,30 @@ class Sampler(object): """ result_kwargs = dict( - label=self.label, outdir=self.outdir, + label=self.label, + outdir=self.outdir, sampler=self.__class__.__name__.lower(), search_parameter_keys=self._search_parameter_keys, fixed_parameter_keys=self._fixed_parameter_keys, constraint_parameter_keys=self._constraint_parameter_keys, - priors=self.priors, meta_data=self.meta_data, + priors=self.priors, + meta_data=self.meta_data, injection_parameters=self.injection_parameters, - sampler_kwargs=self.kwargs, use_ratio=self.use_ratio) + sampler_kwargs=self.kwargs, + use_ratio=self.use_ratio, + ) if result_class is None: result = Result(**result_kwargs) elif issubclass(result_class, Result): result = result_class(**result_kwargs) else: - raise ValueError( - "Input result_class={} not understood".format(result_class)) + raise ValueError(f"Input result_class={result_class} not understood") return result def _verify_parameters(self): - """ Evaluate a set of parameters drawn from the prior + """Evaluate a set of parameters drawn from the prior Tests if the likelihood evaluation passes @@ -264,20 +389,22 @@ class Sampler(object): if self.priors.test_has_redundant_keys(): raise IllegalSamplingSetError( - "Your sampling set contains redundant parameters.") + "Your sampling set contains redundant parameters." 
+ ) theta = self.priors.sample_subset_constrained_as_array( - self.search_parameter_keys, size=1)[:, 0] + self.search_parameter_keys, size=1 + )[:, 0] try: self.log_likelihood(theta) except TypeError as e: raise TypeError( - "Likelihood evaluation failed with message: \n'{}'\n" - "Have you specified all the parameters:\n{}" - .format(e, self.likelihood.parameters)) + f"Likelihood evaluation failed with message: \n'{e}'\n" + f"Have you specified all the parameters:\n{self.likelihood.parameters}" + ) def _time_likelihood(self, n_evaluations=100): - """ Times the likelihood evaluation and print an info message + """Times the likelihood evaluation and print an info message Parameters ========== @@ -289,7 +416,8 @@ class Sampler(object): t1 = datetime.datetime.now() for _ in range(n_evaluations): theta = self.priors.sample_subset_constrained_as_array( - self._search_parameter_keys, size=1)[:, 0] + self._search_parameter_keys, size=1 + )[:, 0] self.log_likelihood(theta) total_time = (datetime.datetime.now() - t1).total_seconds() self._log_likelihood_eval_time = total_time / n_evaluations @@ -298,8 +426,9 @@ class Sampler(object): self._log_likelihood_eval_time = np.nan logger.info("Unable to measure single likelihood time") else: - logger.info("Single likelihood evaluation took {:.3e} s" - .format(self._log_likelihood_eval_time)) + logger.info( + f"Single likelihood evaluation took {self._log_likelihood_eval_time:.3e} s" + ) def _verify_use_ratio(self): """ @@ -309,9 +438,9 @@ class Sampler(object): try: self.priors.sample_subset(self.search_parameter_keys) except (KeyError, AttributeError): - logger.error("Cannot sample from priors with keys: {}.".format( - self.search_parameter_keys - )) + logger.error( + f"Cannot sample from priors with keys: {self.search_parameter_keys}." + ) raise if self.use_ratio is False: logger.debug("use_ratio set to False") @@ -322,14 +451,14 @@ class Sampler(object): if self.use_ratio is True and ratio_is_nan: logger.warning( "You have requested to use the loglikelihood_ratio, but it " - " returns a NaN") + " returns a NaN" + ) elif self.use_ratio is None and not ratio_is_nan: - logger.debug( - "use_ratio not spec. but gives valid answer, setting True") + logger.debug("use_ratio not spec. but gives valid answer, setting True") self.use_ratio = True def prior_transform(self, theta): - """ Prior transform method that is passed into the external sampler. + """Prior transform method that is passed into the external sampler. 
Parameters ========== @@ -355,8 +484,7 @@ class Sampler(object): float: Joint ln prior probability of theta """ - params = { - key: t for key, t in zip(self._search_parameter_keys, theta)} + params = {key: t for key, t in zip(self._search_parameter_keys, theta)} return self.priors.ln_prob(params) def log_likelihood(self, theta): @@ -378,8 +506,7 @@ class Sampler(object): self.likelihood_count.increment() except AttributeError: pass - params = { - key: t for key, t in zip(self._search_parameter_keys, theta)} + params = {key: t for key, t in zip(self._search_parameter_keys, theta)} self.likelihood.parameters.update(params) if self.use_ratio: return self.likelihood.log_likelihood_ratio() @@ -387,7 +514,7 @@ class Sampler(object): return self.likelihood.log_likelihood() def get_random_draw_from_prior(self): - """ Get a random draw from the prior distribution + """Get a random draw from the prior distribution Returns ======= @@ -397,13 +524,12 @@ class Sampler(object): """ new_sample = self.priors.sample() - draw = np.array(list(new_sample[key] - for key in self._search_parameter_keys)) + draw = np.array(list(new_sample[key] for key in self._search_parameter_keys)) self.check_draw(draw) return draw def get_initial_points_from_prior(self, npoints=1): - """ Method to draw a set of live points from the prior + """Method to draw a set of live points from the prior This iterates over draws from the prior until all the samples have a finite prior and likelihood (relevant for constrained priors). @@ -457,9 +583,11 @@ class Sampler(object): """ log_p = self.log_prior(theta) log_l = self.log_likelihood(theta) - return \ - self._check_bad_value(val=log_p, warning=warning, theta=theta, label='prior') and \ - self._check_bad_value(val=log_l, warning=warning, theta=theta, label='likelihood') + return self._check_bad_value( + val=log_p, warning=warning, theta=theta, label="prior" + ) and self._check_bad_value( + val=log_l, warning=warning, theta=theta, label="likelihood" + ) @staticmethod def _check_bad_value(val, warning, theta, label): @@ -467,7 +595,7 @@ class Sampler(object): bad_values = [np.inf, np.nan_to_num(np.inf)] if val in bad_values or np.isnan(val): if warning: - logger.warning(f'Prior draw {theta} has inf {label}') + logger.warning(f"Prior draw {theta} has inf {label}") return False return True @@ -485,7 +613,7 @@ class Sampler(object): raise ValueError("Method not yet implemented") def _check_cached_result(self): - """ Check if the cached data file exists and can be used """ + """Check if the cached data file exists and can be used""" if command_line_args.clean: logger.debug("Command line argument clean given, forcing rerun") @@ -493,30 +621,30 @@ class Sampler(object): return try: - self.cached_result = read_in_result( - outdir=self.outdir, label=self.label) + self.cached_result = read_in_result(outdir=self.outdir, label=self.label) except IOError: self.cached_result = None if command_line_args.use_cached: - logger.debug( - "Command line argument cached given, no cache check performed") + logger.debug("Command line argument cached given, no cache check performed") return logger.debug("Checking cached data") if self.cached_result: - check_keys = ['search_parameter_keys', 'fixed_parameter_keys'] + check_keys = ["search_parameter_keys", "fixed_parameter_keys"] use_cache = True for key in check_keys: - if self.cached_result._check_attribute_match_to_other_object( - key, self) is False: - logger.debug("Cached value {} is unmatched".format(key)) + if ( + 
self.cached_result._check_attribute_match_to_other_object(key, self)
+                    is False
+                ):
+                    logger.debug(f"Cached value {key} is unmatched")
                     use_cache = False
             try:
                 # Recursive check the dictionaries allowing for numpy arrays
                 np.testing.assert_equal(
                     self.meta_data["likelihood"],
-                    self.cached_result.meta_data["likelihood"]
+                    self.cached_result.meta_data["likelihood"],
                 )
             except AssertionError:
                 use_cache = False
@@ -531,13 +659,12 @@
             if type(kwargs_print[k]) in (list, np.ndarray):
                 array_repr = np.array(kwargs_print[k])
                 if array_repr.size > 10:
-                    kwargs_print[k] = ('array_like, shape={}'
-                                       .format(array_repr.shape))
+                    kwargs_print[k] = f"array_like, shape={array_repr.shape}"
             elif type(kwargs_print[k]) == DataFrame:
-                kwargs_print[k] = ('DataFrame, shape={}'
-                                   .format(kwargs_print[k].shape))
-        logger.info("Using sampler {} with kwargs {}".format(
-            self.__class__.__name__, kwargs_print))
+                kwargs_print[k] = f"DataFrame, shape={kwargs_print[k].shape}"
+        logger.info(
+            f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}"
+        )
 
     def calc_likelihood_count(self):
         if self.likelihood_benchmark:
@@ -545,15 +672,100 @@
         else:
             return None
 
+    @property
+    def npool(self):
+        for key in self.npool_equiv_kwargs:
+            if key in self.kwargs:
+                return self.kwargs[key]
+        return self._npool
+
+    def _log_interruption(self, signum=None):
+        if signum == 14:
+            logger.info(
+                f"Run interrupted by alarm signal {signum}: checkpoint and exit on {self.exit_code}"
+            )
+        else:
+            logger.info(
+                f"Run interrupted by signal {signum}: checkpoint and exit on {self.exit_code}"
+            )
+
+    def write_current_state_and_exit(self, signum=None, frame=None):
+        """
+        Make sure that if a pool of jobs is running only the parent tries to
+        checkpoint and exit. Only the parent has a 'pool' attribute.
+
+        For samplers that must hard exit (typically due to a non-Python
+        process), use :code:`os._exit`, which cannot be caught. Other samplers
+        exit via :code:`sys.exit`, which can be caught as a :code:`SystemExit`.
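+
+        The handlers installed by :code:`signal_wrapper` route :code:`SIGTERM`,
+        :code:`SIGINT`, and :code:`SIGALRM` to this method.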
+ """ + if self.npool in (1, None) or getattr(self, "pool", None) is not None: + self._log_interruption(signum=signum) + self.write_current_state() + self._close_pool() + if self.hard_exit: + os._exit(self.exit_code) + else: + sys.exit(self.exit_code) + + def _close_pool(self): + if getattr(self, "pool", None) is not None: + logger.info("Starting to close worker pool.") + self.pool.close() + self.pool.join() + self.pool = None + self.kwargs["pool"] = self.pool + logger.info("Finished closing worker pool.") + + def _setup_pool(self): + if self.kwargs.get("pool", None) is not None: + logger.info("Using user defined pool.") + self.pool = self.kwargs["pool"] + elif self.npool is not None and self.npool > 1: + logger.info(f"Setting up multiproccesing pool with {self.npool} processes") + import multiprocessing + + self.pool = multiprocessing.Pool( + processes=self.npool, + initializer=_initialize_global_variables, + initargs=( + self.likelihood, + self.priors, + self._search_parameter_keys, + self.use_ratio, + ), + ) + else: + self.pool = None + _initialize_global_variables( + likelihood=self.likelihood, + priors=self.priors, + search_parameter_keys=self._search_parameter_keys, + use_ratio=self.use_ratio, + ) + self.kwargs["pool"] = self.pool + + def write_current_state(self): + raise NotImplementedError() + class NestedSampler(Sampler): - npoints_equiv_kwargs = ['nlive', 'nlives', 'n_live_points', 'npoints', - 'npoint', 'Nlive', 'num_live_points', 'num_particles'] - walks_equiv_kwargs = ['walks', 'steps', 'nmcmc'] + npoints_equiv_kwargs = [ + "nlive", + "nlives", + "n_live_points", + "npoints", + "npoint", + "Nlive", + "num_live_points", + "num_particles", + ] + walks_equiv_kwargs = ["walks", "steps", "nmcmc"] - def reorder_loglikelihoods(self, unsorted_loglikelihoods, unsorted_samples, - sorted_samples): - """ Reorders the stored log-likelihood after they have been reweighted + @staticmethod + def reorder_loglikelihoods( + unsorted_loglikelihoods, unsorted_samples, sorted_samples + ): + """Reorders the stored log-likelihood after they have been reweighted This creates a sorting index by matching the reweights `result.samples` against the raw samples, then uses this index to sort the @@ -578,12 +790,12 @@ class NestedSampler(Sampler): idxs = [] for ii in range(len(unsorted_loglikelihoods)): - idx = np.where(np.all(sorted_samples[ii] == unsorted_samples, - axis=1))[0] + idx = np.where(np.all(sorted_samples[ii] == unsorted_samples, axis=1))[0] if len(idx) > 1: logger.warning( "Multiple likelihood matches found between sorted and " - "unsorted samples. Taking the first match.") + "unsorted samples. Taking the first match." + ) idxs.append(idx[0]) return unsorted_loglikelihoods[idxs] @@ -601,52 +813,34 @@ class NestedSampler(Sampler): ======= float: log_likelihood """ - if self.priors.evaluate_constraints({ - key: theta[ii] for ii, key in - enumerate(self.search_parameter_keys)}): + if self.priors.evaluate_constraints( + {key: theta[ii] for ii, key in enumerate(self.search_parameter_keys)} + ): return Sampler.log_likelihood(self, theta) else: return np.nan_to_num(-np.inf) - def _setup_run_directory(self): - """ - If using a temporary directory, the output directory is moved to the - temporary directory. - Used for Dnest4, Pymultinest, and Ultranest. 
- """ - if self.use_temporary_directory: - temporary_outputfiles_basename = tempfile.TemporaryDirectory().name - self.temporary_outputfiles_basename = temporary_outputfiles_basename - - if os.path.exists(self.outputfiles_basename): - distutils.dir_util.copy_tree(self.outputfiles_basename, self.temporary_outputfiles_basename) - check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename) - - self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename - logger.info("Using temporary file {}".format(temporary_outputfiles_basename)) - else: - check_directory_exists_and_if_not_mkdir(self.outputfiles_basename) - self.kwargs["outputfiles_basename"] = self.outputfiles_basename - logger.info("Using output file {}".format(self.outputfiles_basename)) - class MCMCSampler(Sampler): - nwalkers_equiv_kwargs = ['nwalker', 'nwalkers', 'draws', 'Niter'] - nburn_equiv_kwargs = ['burn', 'nburn'] + nwalkers_equiv_kwargs = ["nwalker", "nwalkers", "draws", "Niter"] + nburn_equiv_kwargs = ["burn", "nburn"] def print_nburn_logging_info(self): - """ Prints logging info as to how nburn was calculated """ + """Prints logging info as to how nburn was calculated""" if type(self.nburn) in [float, int]: - logger.info("Discarding {} steps for burn-in".format(self.nburn)) + logger.info(f"Discarding {self.nburn} steps for burn-in") elif self.result.max_autocorrelation_time is None: - logger.info("Autocorrelation time not calculated, discarding {} " - " steps for burn-in".format(self.nburn)) + logger.info( + f"Autocorrelation time not calculated, discarding " + f"{self.nburn} steps for burn-in" + ) else: - logger.info("Discarding {} steps for burn-in, estimated from " - "autocorr".format(self.nburn)) + logger.info( + f"Discarding {self.nburn} steps for burn-in, estimated from autocorr" + ) def calculate_autocorrelation(self, samples, c=3): - """ Uses the `emcee.autocorr` module to estimate the autocorrelation + """Uses the `emcee.autocorr` module to estimate the autocorrelation Parameters ========== @@ -657,35 +851,155 @@ class MCMCSampler(Sampler): estimate (default: `3`). See `emcee.autocorr.integrated_time`. """ import emcee + try: - self.result.max_autocorrelation_time = int(np.max( - emcee.autocorr.integrated_time(samples, c=c))) - logger.info("Max autocorr time = {}".format( - self.result.max_autocorrelation_time)) + self.result.max_autocorrelation_time = int( + np.max(emcee.autocorr.integrated_time(samples, c=c)) + ) + logger.info(f"Max autocorr time = {self.result.max_autocorrelation_time}") except emcee.autocorr.AutocorrError as e: self.result.max_autocorrelation_time = None - logger.info("Unable to calculate autocorr time: {}".format(e)) + logger.info(f"Unable to calculate autocorr time: {e}") + + +class _TemporaryFileSamplerMixin: + """ + A mixin class to handle storing sampler intermediate products in a temporary + location. See, e.g., `this SO <https://stackoverflow.com/a/547714>` for a + basic background on mixins. + + This class makes sure that any subclasses can seamlessly use the temporary + file functionality. 
+ """ + + short_name = "" + + def __init__(self, temporary_directory, **kwargs): + super(_TemporaryFileSamplerMixin, self).__init__(**kwargs) + self.use_temporary_directory = temporary_directory + self._outputfiles_basename = None + self._temporary_outputfiles_basename = None + + def _check_and_load_sampling_time_file(self): + if os.path.exists(self.time_file_path): + with open(self.time_file_path, "r") as time_file: + self.total_sampling_time = float(time_file.readline()) + else: + self.total_sampling_time = 0 + + def _calculate_and_save_sampling_time(self): + current_time = time.time() + new_sampling_time = current_time - self.start_time + self.total_sampling_time += new_sampling_time + + with open(self.time_file_path, "w") as time_file: + time_file.write(str(self.total_sampling_time)) + + self.start_time = current_time + + def _clean_up_run_directory(self): + if self.use_temporary_directory: + self._move_temporary_directory_to_proper_path() + self.kwargs["outputfiles_basename"] = self.outputfiles_basename + + @property + def outputfiles_basename(self): + return self._outputfiles_basename + + @outputfiles_basename.setter + def outputfiles_basename(self, outputfiles_basename): + if outputfiles_basename is None: + outputfiles_basename = f"{self.outdir}/{self.short_name}_{self.label}/" + if not outputfiles_basename.endswith("/"): + outputfiles_basename += "/" + check_directory_exists_and_if_not_mkdir(self.outdir) + self._outputfiles_basename = outputfiles_basename + + @property + def temporary_outputfiles_basename(self): + return self._temporary_outputfiles_basename + + @temporary_outputfiles_basename.setter + def temporary_outputfiles_basename(self, temporary_outputfiles_basename): + if not temporary_outputfiles_basename.endswith("/"): + temporary_outputfiles_basename += "/" + self._temporary_outputfiles_basename = temporary_outputfiles_basename + if os.path.exists(self.outputfiles_basename): + shutil.copytree( + self.outputfiles_basename, self.temporary_outputfiles_basename + ) + + def write_current_state(self): + self._calculate_and_save_sampling_time() + if self.use_temporary_directory: + self._move_temporary_directory_to_proper_path() + + def _move_temporary_directory_to_proper_path(self): + """ + Move the temporary back to the proper path + + Anything in the proper path at this point is removed including links + """ + self._copy_temporary_directory_contents_to_proper_path() + shutil.rmtree(self.temporary_outputfiles_basename) + + def _copy_temporary_directory_contents_to_proper_path(self): + """ + Copy the temporary back to the proper path. + Do not delete the temporary directory. + """ + logger.info( + f"Overwriting {self.outputfiles_basename} with {self.temporary_outputfiles_basename}" + ) + outputfiles_basename_stripped = self.outputfiles_basename.rstrip("/") + distutils.dir_util.copy_tree( + self.temporary_outputfiles_basename, outputfiles_basename_stripped + ) + + def _setup_run_directory(self): + """ + If using a temporary directory, the output directory is moved to the + temporary directory. + Used for Dnest4, Pymultinest, and Ultranest. 
+ """ + check_directory_exists_and_if_not_mkdir(self.outputfiles_basename) + if self.use_temporary_directory: + temporary_outputfiles_basename = tempfile.TemporaryDirectory().name + self.temporary_outputfiles_basename = temporary_outputfiles_basename + + if os.path.exists(self.outputfiles_basename): + distutils.dir_util.copy_tree( + self.outputfiles_basename, self.temporary_outputfiles_basename + ) + check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename) + + self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename + logger.info(f"Using temporary file {temporary_outputfiles_basename}") + else: + self.kwargs["outputfiles_basename"] = self.outputfiles_basename + logger.info(f"Using output file {self.outputfiles_basename}") + self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat" class Error(Exception): - """ Base class for all exceptions raised by this module """ + """Base class for all exceptions raised by this module""" class SamplerError(Error): - """ Base class for Error related to samplers in this module """ + """Base class for Error related to samplers in this module""" class ResumeError(Error): - """ Class for errors arising from resuming runs """ + """Class for errors arising from resuming runs""" class SamplerNotInstalledError(SamplerError): - """ Base class for Error raised by not installed samplers """ + """Base class for Error raised by not installed samplers""" class IllegalSamplingSetError(Error): - """ Class for illegal sets of sampling parameters """ + """Class for illegal sets of sampling parameters""" class SamplingMarginalisedParameterError(IllegalSamplingSetError): - """ Class for errors that occur when sampling over marginalized parameters """ + """Class for errors that occur when sampling over marginalized parameters""" diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index e64365f2e423c9ac4968e8c9964f4d26cd55bc00..b124643759fc742465d612918523161b585bb17c 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -1,18 +1,18 @@ - import array import copy +import sys import numpy as np from numpy.lib.recfunctions import structured_to_unstructured from pandas import DataFrame -from .base_sampler import NestedSampler -from .proposal import Sample, JumpProposalCycle -from ..utils import logger, check_directory_exists_and_if_not_mkdir +from ..utils import check_directory_exists_and_if_not_mkdir, logger +from .base_sampler import NestedSampler, signal_wrapper +from .proposal import JumpProposalCycle, Sample class Cpnest(NestedSampler): - """ bilby wrapper of cpnest (https://github.com/johnveitch/cpnest) + """bilby wrapper of cpnest (https://github.com/johnveitch/cpnest) All positional and keyword arguments (i.e., the args and kwargs) passed to `run_sampler` will be propagated to `cpnest.CPNest`, see documentation @@ -39,30 +39,44 @@ class Cpnest(NestedSampler): {self.outdir}/cpnest_{self.label}/ """ - default_kwargs = dict(verbose=3, nthreads=1, nlive=500, maxmcmc=1000, - seed=None, poolsize=100, nhamiltonian=0, resume=True, - output=None, proposals=None, n_periodic_checkpoint=8000) + + default_kwargs = dict( + verbose=3, + nthreads=1, + nlive=500, + maxmcmc=1000, + seed=None, + poolsize=100, + nhamiltonian=0, + resume=True, + output=None, + proposals=None, + n_periodic_checkpoint=8000, + ) + hard_exit = True def _translate_kwargs(self, kwargs): - if 'nlive' not in kwargs: + if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: - 
kwargs['nlive'] = kwargs.pop(equiv) - if 'nthreads' not in kwargs: + kwargs["nlive"] = kwargs.pop(equiv) + if "nthreads" not in kwargs: for equiv in self.npool_equiv_kwargs: if equiv in kwargs: - kwargs['nthreads'] = kwargs.pop(equiv) + kwargs["nthreads"] = kwargs.pop(equiv) - if 'seed' not in kwargs: - logger.warning('No seed provided, cpnest will use 1234.') + if "seed" not in kwargs: + logger.warning("No seed provided, cpnest will use 1234.") + @signal_wrapper def run_sampler(self): - from cpnest import model as cpmodel, CPNest - from cpnest.parameter import LivePoint + from cpnest import CPNest + from cpnest import model as cpmodel from cpnest.nest2pos import compute_weights + from cpnest.parameter import LivePoint class Model(cpmodel.Model): - """ A wrapper class to pass our log_likelihood into cpnest """ + """A wrapper class to pass our log_likelihood into cpnest""" def __init__(self, names, priors): self.names = names @@ -82,14 +96,16 @@ class Cpnest(NestedSampler): def _update_bounds(self): self.bounds = [ [self.priors[key].minimum, self.priors[key].maximum] - for key in self.names] + for key in self.names + ] def new_point(self): """Draw a point from the prior""" prior_samples = self.priors.sample() self._update_bounds() point = LivePoint( - self.names, array.array('d', [prior_samples[name] for name in self.names]) + self.names, + array.array("d", [prior_samples[name] for name in self.names]), ) return point @@ -105,18 +121,14 @@ class Cpnest(NestedSampler): kwarg = remove_kwargs.pop(0) else: raise TypeError("Unable to initialise cpnest sampler") - logger.info( - "CPNest init. failed with error {}, please update" - .format(e)) - logger.info( - "Attempting to rerun with kwarg {} removed".format(kwarg)) + logger.info(f"CPNest init. failed with error {e}, please update") + logger.info(f"Attempting to rerun with kwarg {kwarg} removed") self.kwargs.pop(kwarg) try: out.run() - except SystemExit as e: - import sys - logger.info("Caught exit code {}, exiting with signal {}".format(e.args[0], self.exit_code)) - sys.exit(self.exit_code) + except SystemExit: + out.checkpoint() + self.write_current_state_and_exit() if self.plot: out.plot() @@ -125,42 +137,58 @@ class Cpnest(NestedSampler): self.result.samples = structured_to_unstructured( out.posterior_samples[self.search_parameter_keys] ) - self.result.log_likelihood_evaluations = out.posterior_samples['logL'] - self.result.nested_samples = DataFrame(out.get_nested_samples(filename='')) - self.result.nested_samples.rename(columns=dict(logL='log_likelihood'), inplace=True) - _, log_weights = compute_weights(np.array(self.result.nested_samples.log_likelihood), - np.array(out.NS.state.nlive)) - self.result.nested_samples['weights'] = np.exp(log_weights) + self.result.log_likelihood_evaluations = out.posterior_samples["logL"] + self.result.nested_samples = DataFrame(out.get_nested_samples(filename="")) + self.result.nested_samples.rename( + columns=dict(logL="log_likelihood"), inplace=True + ) + _, log_weights = compute_weights( + np.array(self.result.nested_samples.log_likelihood), + np.array(out.NS.state.nlive), + ) + self.result.nested_samples["weights"] = np.exp(log_weights) self.result.log_evidence = out.NS.state.logZ self.result.log_evidence_err = np.sqrt(out.NS.state.info / out.NS.state.nlive) self.result.information_gain = out.NS.state.info return self.result + def write_current_state_and_exit(self, signum=None, frame=None): + """ + Overwrites the base class to make sure that :code:`CPNest` terminates + properly as :code:`CPNest` handles 
all the multiprocessing internally. + """ + self._log_interruption(signum=signum) + sys.exit(self.exit_code) + def _verify_kwargs_against_default_kwargs(self): """ Set the directory where the output will be written and check resume and checkpoint status. """ - if not self.kwargs['output']: - self.kwargs['output'] = \ - '{}/cpnest_{}/'.format(self.outdir, self.label) - if self.kwargs['output'].endswith('/') is False: - self.kwargs['output'] = '{}/'.format(self.kwargs['output']) - check_directory_exists_and_if_not_mkdir(self.kwargs['output']) - if self.kwargs['n_periodic_checkpoint'] and not self.kwargs['resume']: - self.kwargs['n_periodic_checkpoint'] = None + if not self.kwargs["output"]: + self.kwargs["output"] = f"{self.outdir}/cpnest_{self.label}/" + if self.kwargs["output"].endswith("/") is False: + self.kwargs["output"] = f"{self.kwargs['output']}/" + check_directory_exists_and_if_not_mkdir(self.kwargs["output"]) + if self.kwargs["n_periodic_checkpoint"] and not self.kwargs["resume"]: + self.kwargs["n_periodic_checkpoint"] = None NestedSampler._verify_kwargs_against_default_kwargs(self) def _resolve_proposal_functions(self): from cpnest.proposal import ProposalCycle - if 'proposals' in self.kwargs: - if self.kwargs['proposals'] is None: + + if "proposals" in self.kwargs: + if self.kwargs["proposals"] is None: return - if type(self.kwargs['proposals']) == JumpProposalCycle: - self.kwargs['proposals'] = dict(mhs=self.kwargs['proposals'], hmc=self.kwargs['proposals']) - for key, proposal in self.kwargs['proposals'].items(): + if type(self.kwargs["proposals"]) == JumpProposalCycle: + self.kwargs["proposals"] = dict( + mhs=self.kwargs["proposals"], hmc=self.kwargs["proposals"] + ) + for key, proposal in self.kwargs["proposals"].items(): if isinstance(proposal, JumpProposalCycle): - self.kwargs['proposals'][key] = cpnest_proposal_cycle_factory(proposal) + self.kwargs["proposals"][key] = cpnest_proposal_cycle_factory( + proposal + ) elif isinstance(proposal, ProposalCycle): pass else: @@ -171,7 +199,6 @@ def cpnest_proposal_factory(jump_proposal): import cpnest.proposal class CPNestEnsembleProposal(cpnest.proposal.EnsembleProposal): - def __init__(self, jp): self.jump_proposal = jp self.ensemble = None @@ -181,8 +208,8 @@ def cpnest_proposal_factory(jump_proposal): def get_sample(self, cpnest_sample, **kwargs): sample = Sample.from_cpnest_live_point(cpnest_sample) - self.ensemble = kwargs.get('coordinates', self.ensemble) - sample = self.jump_proposal(sample=sample, sampler_name='cpnest', **kwargs) + self.ensemble = kwargs.get("coordinates", self.ensemble) + sample = self.jump_proposal(sample=sample, sampler_name="cpnest", **kwargs) self.log_J = self.jump_proposal.log_j return self._update_cpnest_sample(cpnest_sample, sample) @@ -203,11 +230,15 @@ def cpnest_proposal_cycle_factory(jump_proposals): def __init__(self): self.jump_proposals = copy.deepcopy(jump_proposals) for i, prop in enumerate(self.jump_proposals.proposal_functions): - self.jump_proposals.proposal_functions[i] = cpnest_proposal_factory(prop) + self.jump_proposals.proposal_functions[i] = cpnest_proposal_factory( + prop + ) self.jump_proposals.update_cycle() - super(CPNestProposalCycle, self).__init__(proposals=self.jump_proposals.proposal_functions, - weights=self.jump_proposals.weights, - cyclelength=self.jump_proposals.cycle_length) + super(CPNestProposalCycle, self).__init__( + proposals=self.jump_proposals.proposal_functions, + weights=self.jump_proposals.weights, + cyclelength=self.jump_proposals.cycle_length, + ) def 
get_sample(self, old, **kwargs): return self.jump_proposals(sample=old, coordinates=self.ensemble, **kwargs) diff --git a/bilby/core/sampler/dnest4.py b/bilby/core/sampler/dnest4.py index ef80c13e933e4dfd6fcaa5e4c3ea8f113b15928e..7d5b97092919e0951de2329de8633e4a6bc7fc2b 100644 --- a/bilby/core/sampler/dnest4.py +++ b/bilby/core/sampler/dnest4.py @@ -1,21 +1,17 @@ -import os -import shutil -import distutils.dir_util -import signal -import time import datetime -import sys +import time import numpy as np import pandas as pd -from ..utils import check_directory_exists_and_if_not_mkdir, logger -from .base_sampler import NestedSampler +from ..utils import logger +from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper class _DNest4Model(object): - - def __init__(self, log_likelihood_func, from_prior_func, widths, centers, highs, lows): + def __init__( + self, log_likelihood_func, from_prior_func, widths, centers, highs, lows + ): """Initialize the DNest4 model. Args: log_likelihood_func: function @@ -48,7 +44,7 @@ class _DNest4Model(object): """The perturb function to perform Monte Carlo trial moves.""" idx = np.random.randint(self._n_dim) - coords[idx] += (self._widths[idx] * (np.random.uniform(size=1) - 0.5)) + coords[idx] += self._widths[idx] * (np.random.uniform(size=1) - 0.5) cw = self._widths[idx] cc = self._centers[idx] @@ -59,11 +55,13 @@ class _DNest4Model(object): @staticmethod def wrap(x, minimum, maximum): if maximum <= minimum: - raise ValueError("maximum {} <= minimum {}, when trying to wrap coordinates".format(maximum, minimum)) + raise ValueError( + f"maximum {maximum} <= minimum {minimum}, when trying to wrap coordinates" + ) return (x - minimum) % (maximum - minimum) + minimum -class DNest4(NestedSampler): +class DNest4(_TemporaryFileSamplerMixin, NestedSampler): """ Bilby wrapper of DNest4 @@ -100,35 +98,58 @@ class DNest4(NestedSampler): If True, prints information during run """ - default_kwargs = dict(max_num_levels=20, num_steps=500, - new_level_interval=10000, num_per_step=10000, - thread_steps=1, num_particles=1000, lam=10.0, - beta=100, seed=None, verbose=True, outputfiles_basename=None, - backend='memory') - - def __init__(self, likelihood, priors, outdir="outdir", label="label", use_ratio=False, plot=False, - exit_code=77, skip_import_verification=False, temporary_directory=True, **kwargs): + default_kwargs = dict( + max_num_levels=20, + num_steps=500, + new_level_interval=10000, + num_per_step=10000, + thread_steps=1, + num_particles=1000, + lam=10.0, + beta=100, + seed=None, + verbose=True, + outputfiles_basename=None, + backend="memory", + ) + short_name = "dn4" + hard_exit = True + + def __init__( + self, + likelihood, + priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + exit_code=77, + skip_import_verification=False, + temporary_directory=True, + **kwargs, + ): super(DNest4, self).__init__( - likelihood=likelihood, priors=priors, outdir=outdir, label=label, - use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification, - exit_code=exit_code, **kwargs) + likelihood=likelihood, + priors=priors, + outdir=outdir, + label=label, + use_ratio=use_ratio, + plot=plot, + skip_import_verification=skip_import_verification, + temporary_directory=temporary_directory, + exit_code=exit_code, + **kwargs, + ) self.num_particles = self.kwargs["num_particles"] self.max_num_levels = self.kwargs["max_num_levels"] self._verbose = self.kwargs["verbose"] self._backend = self.kwargs["backend"] - 
self.use_temporary_directory = temporary_directory self.start_time = np.nan self.sampler = None self._information = np.nan self._last_live_sample_info = np.nan - self._outputfiles_basename = None - self._temporary_outputfiles_basename = None - - signal.signal(signal.SIGTERM, self.write_current_state_and_exit) - signal.signal(signal.SIGINT, self.write_current_state_and_exit) - signal.signal(signal.SIGALRM, self.write_current_state_and_exit) # Get the estimates of the prior distributions' widths and centers. widths = [] @@ -155,13 +176,22 @@ class DNest4(NestedSampler): self._highs = np.array(highs) self._lows = np.array(lows) - self._dnest4_model = _DNest4Model(self.log_likelihood, self.get_random_draw_from_prior, self._widths, - self._centers, self._highs, self._lows) + self._dnest4_model = _DNest4Model( + self.log_likelihood, + self.get_random_draw_from_prior, + self._widths, + self._centers, + self._highs, + self._lows, + ) def _set_backend(self): import dnest4 - if self._backend == 'csv': - return dnest4.backends.CSVBackend("{}/dnest4{}/".format(self.outdir, self.label), sep=" ") + + if self._backend == "csv": + return dnest4.backends.CSVBackend( + f"{self.outdir}/dnest4{self.label}/", sep=" " + ) else: return dnest4.backends.MemoryBackend() @@ -169,6 +199,7 @@ class DNest4(NestedSampler): dnest4_keys = ["num_steps", "new_level_interval", "lam", "beta", "seed"] self.dnest4_kwargs = {key: self.kwargs[key] for key in dnest4_keys} + @signal_wrapper def run_sampler(self): import dnest4 @@ -181,31 +212,37 @@ class DNest4(NestedSampler): self.start_time = time.time() self.sampler = dnest4.DNest4Sampler(self._dnest4_model, backend=backend) - out = self.sampler.sample(self.max_num_levels, - num_particles=self.num_particles, - **self.dnest4_kwargs) + out = self.sampler.sample( + self.max_num_levels, num_particles=self.num_particles, **self.dnest4_kwargs + ) for i, sample in enumerate(out): if self._verbose and ((i + 1) % 100 == 0): stats = self.sampler.postprocess() - logger.info("Iteration: {0} log(Z): {1}".format(i + 1, stats['log_Z'])) + logger.info(f"Iteration: {i + 1} log(Z): {stats['log_Z']}") self._calculate_and_save_sampling_time() self._clean_up_run_directory() stats = self.sampler.postprocess(resample=1) - self.result.log_evidence = stats['log_Z'] - self._information = stats['H'] + self.result.log_evidence = stats["log_Z"] + self._information = stats["H"] self.result.log_evidence_err = np.sqrt(self._information / self.num_particles) - if self._backend == 'memory': - self._last_live_sample_info = pd.DataFrame(self.sampler.backend.sample_info[-1]) - self.result.log_likelihood_evaluations = self._last_live_sample_info['log_likelihood'] + if self._backend == "memory": + self._last_live_sample_info = pd.DataFrame( + self.sampler.backend.sample_info[-1] + ) + self.result.log_likelihood_evaluations = self._last_live_sample_info[ + "log_likelihood" + ] self.result.samples = np.array(self.sampler.backend.posterior_samples) else: - sample_info_path = './' + self.kwargs["outputfiles_basename"] + '/sample_info.txt' - sample_info = np.genfromtxt(sample_info_path, comments='#', names=True) - self.result.log_likelihood_evaluations = sample_info['log_likelihood'] + sample_info_path = ( + "./" + self.kwargs["outputfiles_basename"] + "/sample_info.txt" + ) + sample_info = np.genfromtxt(sample_info_path, comments="#", names=True) + self.result.log_likelihood_evaluations = sample_info["log_likelihood"] self.result.samples = np.array(self.sampler.backend.posterior_samples) self.result.sampler_output = out 
@@ -217,100 +254,11 @@ class DNest4(NestedSampler):
         return self.result
 
     def _translate_kwargs(self, kwargs):
-        if 'num_steps' not in kwargs:
+        if "num_steps" not in kwargs:
             for equiv in self.walks_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['num_steps'] = kwargs.pop(equiv)
+                    kwargs["num_steps"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
         self.outputfiles_basename = self.kwargs.pop("outputfiles_basename", None)
         super(DNest4, self)._verify_kwargs_against_default_kwargs()
-
-    def _check_and_load_sampling_time_file(self):
-        self.time_file_path = self.kwargs["outputfiles_basename"] + '/sampling_time.dat'
-        if os.path.exists(self.time_file_path):
-            with open(self.time_file_path, 'r') as time_file:
-                self.total_sampling_time = float(time_file.readline())
-        else:
-            self.total_sampling_time = 0
-
-    def _calculate_and_save_sampling_time(self):
-        current_time = time.time()
-        new_sampling_time = current_time - self.start_time
-        self.total_sampling_time += new_sampling_time
-
-        with open(self.time_file_path, 'w') as time_file:
-            time_file.write(str(self.total_sampling_time))
-
-        self.start_time = current_time
-
-    def _clean_up_run_directory(self):
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
-
-    @property
-    def outputfiles_basename(self):
-        return self._outputfiles_basename
-
-    @outputfiles_basename.setter
-    def outputfiles_basename(self, outputfiles_basename):
-        if outputfiles_basename is None:
-            outputfiles_basename = "{}/dnest4{}/".format(self.outdir, self.label)
-        if not outputfiles_basename.endswith("/"):
-            outputfiles_basename += "/"
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        self._outputfiles_basename = outputfiles_basename
-
-    @property
-    def temporary_outputfiles_basename(self):
-        return self._temporary_outputfiles_basename
-
-    @temporary_outputfiles_basename.setter
-    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
-        if not temporary_outputfiles_basename.endswith("/"):
-            temporary_outputfiles_basename = "{}/".format(
-                temporary_outputfiles_basename
-            )
-        self._temporary_outputfiles_basename = temporary_outputfiles_basename
-        if os.path.exists(self.outputfiles_basename):
-            shutil.copytree(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """ Write current state and exit on exit_code """
-        logger.info(
-            "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                signum, self.exit_code
-            )
-        )
-        self._calculate_and_save_sampling_time()
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-        sys.exit(self.exit_code)
-
-    def _move_temporary_directory_to_proper_path(self):
-        """
-        Move the temporary back to the proper path
-
-        Anything in the proper path at this point is removed including links
-        """
-        self._copy_temporary_directory_contents_to_proper_path()
-        shutil.rmtree(self.temporary_outputfiles_basename)
-
-    def _copy_temporary_directory_contents_to_proper_path(self):
-        """
-        Copy the temporary back to the proper path.
-        Do not delete the temporary directory.
- """ - logger.info( - "Overwriting {} with {}".format( - self.outputfiles_basename, self.temporary_outputfiles_basename - ) - ) - if self.outputfiles_basename.endswith('/'): - outputfiles_basename_stripped = self.outputfiles_basename[:-1] - else: - outputfiles_basename_stripped = self.outputfiles_basename - distutils.dir_util.copy_tree(self.temporary_outputfiles_basename, outputfiles_basename_stripped) diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py index 8bb6d647aad013c5be957cfde6311f10f9feda82..ef28f22ddbb2ce099d481c3adebaae1ef1d3b0cd 100644 --- a/bilby/core/sampler/dynamic_dynesty.py +++ b/bilby/core/sampler/dynamic_dynesty.py @@ -1,12 +1,10 @@ - -import os -import signal +import datetime import numpy as np -from ..utils import logger, check_directory_exists_and_if_not_mkdir -from .base_sampler import Sampler -from .dynesty import Dynesty +from ..utils import logger +from .base_sampler import Sampler, signal_wrapper +from .dynesty import Dynesty, _log_likelihood_wrapper, _prior_transform_wrapper class DynamicDynesty(Dynesty): @@ -62,33 +60,77 @@ class DynamicDynesty(Dynesty): resume: bool If true, resume run from checkpoint (if available) """ - default_kwargs = dict(bound='multi', sample='rwalk', - verbose=True, - check_point_delta_t=600, - first_update=None, - npdim=None, rstate=None, queue_size=None, pool=None, - use_pool=None, - logl_args=None, logl_kwargs=None, - ptform_args=None, ptform_kwargs=None, - enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0, - facc=0.5, slices=5, - walks=None, update_interval=0.6, - nlive_init=500, maxiter_init=None, maxcall_init=None, - dlogz_init=0.01, logl_max_init=np.inf, nlive_batch=500, - wt_function=None, wt_kwargs=None, maxiter_batch=None, - maxcall_batch=None, maxiter=None, maxcall=None, - maxbatch=None, stop_function=None, stop_kwargs=None, - use_stop=True, save_bounds=True, - print_progress=True, print_func=None, live_points=None, - ) - - def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, - skip_import_verification=False, check_point=True, n_check_point=None, check_point_delta_t=600, - resume=True, **kwargs): - super(DynamicDynesty, self).__init__(likelihood=likelihood, priors=priors, - outdir=outdir, label=label, use_ratio=use_ratio, - plot=plot, skip_import_verification=skip_import_verification, - **kwargs) + + default_kwargs = dict( + bound="multi", + sample="rwalk", + verbose=True, + check_point_delta_t=600, + first_update=None, + npdim=None, + rstate=None, + queue_size=None, + pool=None, + use_pool=None, + logl_args=None, + logl_kwargs=None, + ptform_args=None, + ptform_kwargs=None, + enlarge=None, + bootstrap=None, + vol_dec=0.5, + vol_check=2.0, + facc=0.5, + slices=5, + walks=None, + update_interval=0.6, + nlive_init=500, + maxiter_init=None, + maxcall_init=None, + dlogz_init=0.01, + logl_max_init=np.inf, + nlive_batch=500, + wt_function=None, + wt_kwargs=None, + maxiter_batch=None, + maxcall_batch=None, + maxiter=None, + maxcall=None, + maxbatch=None, + stop_function=None, + stop_kwargs=None, + use_stop=True, + save_bounds=True, + print_progress=True, + print_func=None, + live_points=None, + ) + + def __init__( + self, + likelihood, + priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + skip_import_verification=False, + check_point=True, + n_check_point=None, + check_point_delta_t=600, + resume=True, + **kwargs, + ): + super(DynamicDynesty, self).__init__( + likelihood=likelihood, + priors=priors, + 
diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py
index 8bb6d647aad013c5be957cfde6311f10f9feda82..ef28f22ddbb2ce099d481c3adebaae1ef1d3b0cd 100644
--- a/bilby/core/sampler/dynamic_dynesty.py
+++ b/bilby/core/sampler/dynamic_dynesty.py
@@ -1,12 +1,10 @@
-
-import os
-import signal
+import datetime
 
 import numpy as np
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import Sampler
-from .dynesty import Dynesty
+from ..utils import logger
+from .base_sampler import Sampler, signal_wrapper
+from .dynesty import Dynesty, _log_likelihood_wrapper, _prior_transform_wrapper
 
 
 class DynamicDynesty(Dynesty):
@@ -62,33 +60,77 @@ class DynamicDynesty(Dynesty):
     resume: bool
         If true, resume run from checkpoint (if available)
     """
-    default_kwargs = dict(bound='multi', sample='rwalk',
-                          verbose=True,
-                          check_point_delta_t=600,
-                          first_update=None,
-                          npdim=None, rstate=None, queue_size=None, pool=None,
-                          use_pool=None,
-                          logl_args=None, logl_kwargs=None,
-                          ptform_args=None, ptform_kwargs=None,
-                          enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0,
-                          facc=0.5, slices=5,
-                          walks=None, update_interval=0.6,
-                          nlive_init=500, maxiter_init=None, maxcall_init=None,
-                          dlogz_init=0.01, logl_max_init=np.inf, nlive_batch=500,
-                          wt_function=None, wt_kwargs=None, maxiter_batch=None,
-                          maxcall_batch=None, maxiter=None, maxcall=None,
-                          maxbatch=None, stop_function=None, stop_kwargs=None,
-                          use_stop=True, save_bounds=True,
-                          print_progress=True, print_func=None, live_points=None,
-                          )
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False,
-                 skip_import_verification=False, check_point=True, n_check_point=None, check_point_delta_t=600,
-                 resume=True, **kwargs):
-        super(DynamicDynesty, self).__init__(likelihood=likelihood, priors=priors,
-                                             outdir=outdir, label=label, use_ratio=use_ratio,
-                                             plot=plot, skip_import_verification=skip_import_verification,
-                                             **kwargs)
+
+    default_kwargs = dict(
+        bound="multi",
+        sample="rwalk",
+        verbose=True,
+        check_point_delta_t=600,
+        first_update=None,
+        npdim=None,
+        rstate=None,
+        queue_size=None,
+        pool=None,
+        use_pool=None,
+        logl_args=None,
+        logl_kwargs=None,
+        ptform_args=None,
+        ptform_kwargs=None,
+        enlarge=None,
+        bootstrap=None,
+        vol_dec=0.5,
+        vol_check=2.0,
+        facc=0.5,
+        slices=5,
+        walks=None,
+        update_interval=0.6,
+        nlive_init=500,
+        maxiter_init=None,
+        maxcall_init=None,
+        dlogz_init=0.01,
+        logl_max_init=np.inf,
+        nlive_batch=500,
+        wt_function=None,
+        wt_kwargs=None,
+        maxiter_batch=None,
+        maxcall_batch=None,
+        maxiter=None,
+        maxcall=None,
+        maxbatch=None,
+        stop_function=None,
+        stop_kwargs=None,
+        use_stop=True,
+        save_bounds=True,
+        print_progress=True,
+        print_func=None,
+        live_points=None,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        check_point=True,
+        n_check_point=None,
+        check_point_delta_t=600,
+        resume=True,
+        **kwargs,
+    ):
+        super(DynamicDynesty, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
         self.n_check_point = n_check_point
         self.check_point = check_point
         self.resume = resume
@@ -97,39 +139,59 @@ class DynamicDynesty(Dynesty):
             # check_point is set to False.
             if np.isnan(self._log_likelihood_eval_time):
                 self.check_point = False
-            n_check_point_raw = (check_point_delta_t / self._log_likelihood_eval_time)
-            n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
+            n_check_point_raw = check_point_delta_t / self._log_likelihood_eval_time
+            n_check_point_rnd = int(float(f"{n_check_point_raw:1.0g}"))
             self.n_check_point = n_check_point_rnd
 
-        self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
-
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
+        self.resume_file = f"{self.outdir}/{self.label}_resume.pickle"
 
     @property
     def external_sampler_name(self):
-        return 'dynesty'
+        return "dynesty"
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['nlive_init', 'maxiter_init', 'maxcall_init', 'dlogz_init',
-                'logl_max_init', 'nlive_batch', 'wt_function', 'wt_kwargs',
-                'maxiter_batch', 'maxcall_batch', 'maxiter', 'maxcall',
-                'maxbatch', 'stop_function', 'stop_kwargs', 'use_stop',
-                'save_bounds', 'print_progress', 'print_func', 'live_points']
+        keys = [
+            "nlive_init",
+            "maxiter_init",
+            "maxcall_init",
+            "dlogz_init",
+            "logl_max_init",
+            "nlive_batch",
+            "wt_function",
+            "wt_kwargs",
+            "maxiter_batch",
+            "maxcall_batch",
+            "maxiter",
+            "maxcall",
+            "maxbatch",
+            "stop_function",
+            "stop_kwargs",
+            "use_stop",
+            "save_bounds",
+            "print_progress",
+            "print_func",
+            "live_points",
+        ]
         return {key: self.kwargs[key] for key in keys}
 
+    @signal_wrapper
     def run_sampler(self):
         import dynesty
+
+        self._setup_pool()
         self.sampler = dynesty.DynamicNestedSampler(
-            loglikelihood=self.log_likelihood,
-            prior_transform=self.prior_transform,
-            ndim=self.ndim, **self.sampler_init_kwargs)
+            loglikelihood=_log_likelihood_wrapper,
+            prior_transform=_prior_transform_wrapper,
+            ndim=self.ndim,
+            **self.sampler_init_kwargs,
+        )
 
         if self.check_point:
             out = self._run_external_sampler_with_checkpointing()
         else:
             out = self._run_external_sampler_without_checkpointing()
+        self._close_pool()
 
         # Flushes the output to force a line break
         if self.kwargs["verbose"]:
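In the `__init__` hunk above, the checkpoint cadence is converted from a wall-clock interval into a number of likelihood calls and rounded to one significant figure via the `g` format specifier. A worked example with a hypothetical evaluation time:

```python
check_point_delta_t = 600  # target seconds between checkpoints
eval_time = 0.0023         # hypothetical seconds per likelihood evaluation

n_check_point_raw = check_point_delta_t / eval_time  # ~260869.6 calls
# "1.0g" keeps one significant figure: '3e+05' -> 300000.0 -> 300000
n_check_point = int(float(f"{n_check_point_raw:1.0g}"))
print(n_check_point)  # 300000
```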
@@ -147,13 +209,14 @@ class DynamicDynesty(Dynesty):
         if self.resume:
             resume = self.read_saved_state(continuing=True)
             if resume:
-                logger.info('Resuming from previous run.')
+                logger.info("Resuming from previous run.")
 
         old_ncall = self.sampler.ncall
         sampler_kwargs = self.sampler_function_kwargs.copy()
-        sampler_kwargs['maxcall'] = self.n_check_point
+        sampler_kwargs["maxcall"] = self.n_check_point
+        self.start_time = datetime.datetime.now()
         while True:
-            sampler_kwargs['maxcall'] += self.n_check_point
+            sampler_kwargs["maxcall"] += self.n_check_point
             self.sampler.run_nested(**sampler_kwargs)
             if self.sampler.ncall == old_ncall:
                 break
@@ -164,27 +227,8 @@ class DynamicDynesty(Dynesty):
         self._remove_checkpoint()
         return self.sampler.results
 
-    def write_current_state(self):
-        """
-        """
-        import dill
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        with open(self.resume_file, 'wb') as file:
-            dill.dump(self, file)
-
-    def read_saved_state(self, continuing=False):
-        """
-        """
-        import dill
-
-        logger.debug("Reading resume file {}".format(self.resume_file))
-        if os.path.isfile(self.resume_file):
-            with open(self.resume_file, 'rb') as file:
-                self = dill.load(file)
-        else:
-            logger.debug(
-                "Failed to read resume file {}".format(self.resume_file))
-            return False
+    def write_current_state_and_exit(self, signum=None, frame=None):
+        Sampler.write_current_state_and_exit(self=self, signum=signum, frame=frame)
 
     def _verify_kwargs_against_default_kwargs(self):
         Sampler._verify_kwargs_against_default_kwargs(self)
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index d9f1ae94bc4995c0b1f0bd88649ff3850e571085..ab0af61be6dc55ffac996931e4ed48a00cc348f1 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -1,62 +1,51 @@
 import datetime
 import os
 import sys
-import signal
 import time
 import warnings
 
 import numpy as np
 from pandas import DataFrame
 
+from ..result import rejection_sample
 from ..utils import (
-    logger,
     check_directory_exists_and_if_not_mkdir,
+    latex_plot_format,
+    logger,
     reflect,
     safe_file_dump,
-    latex_plot_format,
 )
-from .base_sampler import Sampler, NestedSampler
-from ..result import rejection_sample
-
-_likelihood = None
-_priors = None
-_search_parameter_keys = None
-_use_ratio = False
-
-
-def _initialize_global_variables(
-        likelihood, priors, search_parameter_keys, use_ratio
-):
-    """
-    Store a global copy of the likelihood, priors, and search keys for
-    multiprocessing.
-    """
-    global _likelihood
-    global _priors
-    global _search_parameter_keys
-    global _use_ratio
-    _likelihood = likelihood
-    _priors = priors
-    _search_parameter_keys = search_parameter_keys
-    _use_ratio = use_ratio
+from .base_sampler import NestedSampler, Sampler, signal_wrapper
 
 
 def _prior_transform_wrapper(theta):
     """Wrapper to the prior transformation. Needed for multiprocessing."""
-    return _priors.rescale(_search_parameter_keys, theta)
+    from .base_sampler import _sampling_convenience_dump
+
+    return _sampling_convenience_dump.priors.rescale(
+        _sampling_convenience_dump.search_parameter_keys, theta
+    )
 
 
 def _log_likelihood_wrapper(theta):
     """Wrapper to the log likelihood. Needed for multiprocessing."""
-    if _priors.evaluate_constraints({
-        key: theta[ii] for ii, key in enumerate(_search_parameter_keys)
-    }):
-        params = {key: t for key, t in zip(_search_parameter_keys, theta)}
-        _likelihood.parameters.update(params)
-        if _use_ratio:
-            return _likelihood.log_likelihood_ratio()
+    from .base_sampler import _sampling_convenience_dump
+
+    if _sampling_convenience_dump.priors.evaluate_constraints(
+        {
+            key: theta[ii]
+            for ii, key in enumerate(_sampling_convenience_dump.search_parameter_keys)
+        }
+    ):
+        params = {
+            key: t
+            for key, t in zip(_sampling_convenience_dump.search_parameter_keys, theta)
+        }
+        _sampling_convenience_dump.likelihood.parameters.update(params)
+        if _sampling_convenience_dump.use_ratio:
+            return _sampling_convenience_dump.likelihood.log_likelihood_ratio()
         else:
-            return _likelihood.log_likelihood()
+            return _sampling_convenience_dump.likelihood.log_likelihood()
     else:
         return np.nan_to_num(-np.inf)
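`_prior_transform_wrapper` and `_log_likelihood_wrapper` must stay module-level functions so `multiprocessing` can pickle them, and they read the likelihood and priors from the process-global `_sampling_convenience_dump` (defined in `base_sampler`, outside this diff) that each worker fills via a pool initializer. A self-contained sketch of the pattern, with a toy log-likelihood standing in for bilby's:

```python
import multiprocessing
from types import SimpleNamespace

# Process-global namespace, filled per worker by the pool initializer;
# a stand-in for base_sampler._sampling_convenience_dump.
_dump = SimpleNamespace(log_likelihood=None)


def _initializer(log_likelihood):
    _dump.log_likelihood = log_likelihood


def _log_likelihood_wrapper(theta):
    # Module level, so the pool pickles only a plain function reference;
    # the expensive state already lives in each worker's _dump.
    return _dump.log_likelihood(theta)


def gaussian_logl(theta):  # toy stand-in for a bilby likelihood
    return -0.5 * sum(x**2 for x in theta)


if __name__ == "__main__":
    with multiprocessing.Pool(
        2, initializer=_initializer, initargs=(gaussian_logl,)
    ) as pool:
        print(pool.map(_log_likelihood_wrapper, [(0.0, 0.0), (1.0, 2.0)]))
```

Shipping the heavy objects once per worker, instead of pickling them into every task, keeps the per-call overhead of the pool tolerable.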
@@ -130,32 +119,77 @@ class Dynesty(NestedSampler):
           e.g., 'interval-10' prints every ten seconds, this does not
           print every iteration
         - else: print to `stdout` at every iteration
     """
-    default_kwargs = dict(bound='multi', sample='rwalk',
-                          periodic=None, reflective=None,
-                          check_point_delta_t=1800, nlive=1000,
-                          first_update=None, walks=100,
-                          npdim=None, rstate=None, queue_size=1, pool=None,
-                          use_pool=None, live_points=None,
-                          logl_args=None, logl_kwargs=None,
-                          ptform_args=None, ptform_kwargs=None,
-                          enlarge=1.5, bootstrap=None, vol_dec=0.5, vol_check=8.0,
-                          facc=0.2, slices=5,
-                          update_interval=None, print_func=None,
-                          dlogz=0.1, maxiter=None, maxcall=None,
-                          logl_max=np.inf, add_live=True, print_progress=True,
-                          save_bounds=False, n_effective=None,
-                          maxmcmc=5000, nact=5, print_method="tqdm")
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 check_point=True, check_point_plot=True, n_check_point=None,
-                 check_point_delta_t=600, resume=True, nestcheck=False, exit_code=130, **kwargs):
-
-        super(Dynesty, self).__init__(likelihood=likelihood, priors=priors,
-                                      outdir=outdir, label=label, use_ratio=use_ratio,
-                                      plot=plot, skip_import_verification=skip_import_verification,
-                                      exit_code=exit_code,
-                                      **kwargs)
+
+    default_kwargs = dict(
+        bound="multi",
+        sample="rwalk",
+        print_progress=True,
+        periodic=None,
+        reflective=None,
+        check_point_delta_t=1800,
+        nlive=1000,
+        first_update=None,
+        walks=100,
+        npdim=None,
+        rstate=None,
+        queue_size=1,
+        pool=None,
+        use_pool=None,
+        live_points=None,
+        logl_args=None,
+        logl_kwargs=None,
+        ptform_args=None,
+        ptform_kwargs=None,
+        enlarge=1.5,
+        bootstrap=None,
+        vol_dec=0.5,
+        vol_check=8.0,
+        facc=0.2,
+        slices=5,
+        update_interval=None,
+        print_func=None,
+        dlogz=0.1,
+        maxiter=None,
+        maxcall=None,
+        logl_max=np.inf,
+        add_live=True,
+        save_bounds=False,
+        n_effective=None,
+        maxmcmc=5000,
+        nact=5,
+        print_method="tqdm",
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        check_point=True,
+        check_point_plot=True,
+        n_check_point=None,
+        check_point_delta_t=600,
+        resume=True,
+        nestcheck=False,
+        exit_code=130,
+        **kwargs,
+    ):
+        self._translate_kwargs(kwargs)
+        super(Dynesty, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            exit_code=exit_code,
+            **kwargs,
+        )
         self.n_check_point = n_check_point
         self.check_point = check_point
         self.check_point_plot = check_point_plot
@@ -169,77 +203,79 @@ class Dynesty(NestedSampler):
         if self.n_check_point is None:
             self.n_check_point = 1000
         self.check_point_delta_t = check_point_delta_t
-        logger.info("Checkpoint every check_point_delta_t = {}s"
-                    .format(check_point_delta_t))
+        logger.info(f"Checkpoint every check_point_delta_t = {check_point_delta_t}s")
 
-        self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
+        self.resume_file = f"{self.outdir}/{self.label}_resume.pickle"
         self.sampling_time = datetime.timedelta()
 
-        try:
-            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-        except AttributeError:
-            logger.debug(
-                "Setting signal attributes unavailable on this system. "
-                "This is likely the case if you are running on a Windows machine"
-                " and is no further concern.")
-
     def __getstate__(self):
-        """ For pickle: remove external_sampler, which can be an unpicklable "module" """
+        """For pickle: remove external_sampler, which can be an unpicklable "module" """
         state = self.__dict__.copy()
         if "external_sampler" in state:
-            del state['external_sampler']
+            del state["external_sampler"]
         return state
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['dlogz', 'print_progress', 'print_func', 'maxiter',
-                'maxcall', 'logl_max', 'add_live', 'save_bounds',
-                'n_effective']
+        keys = [
+            "dlogz",
+            "print_progress",
+            "print_func",
+            "maxiter",
+            "maxcall",
+            "logl_max",
+            "add_live",
+            "save_bounds",
+            "n_effective",
+        ]
         return {key: self.kwargs[key] for key in keys}
 
     @property
     def sampler_init_kwargs(self):
-        return {key: value
-                for key, value in self.kwargs.items()
-                if key not in self.sampler_function_kwargs}
+        return {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+        }
 
     def _translate_kwargs(self, kwargs):
-        if 'nlive' not in kwargs:
+        if "nlive" not in kwargs:
            for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nlive'] = kwargs.pop(equiv)
-        if 'print_progress' not in kwargs:
-            if 'verbose' in kwargs:
-                kwargs['print_progress'] = kwargs.pop('verbose')
-        if 'walks' not in kwargs:
+                    kwargs["nlive"] = kwargs.pop(equiv)
+        if "print_progress" not in kwargs:
+            if "verbose" in kwargs:
+                kwargs["print_progress"] = kwargs.pop("verbose")
+        if "walks" not in kwargs:
             for equiv in self.walks_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['walks'] = kwargs.pop(equiv)
+                    kwargs["walks"] = kwargs.pop(equiv)
         if "queue_size" not in kwargs:
             for equiv in self.npool_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['queue_size'] = kwargs.pop(equiv)
+                    kwargs["queue_size"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
         from tqdm.auto import tqdm
-        if not self.kwargs['walks']:
-            self.kwargs['walks'] = 100
-        if not self.kwargs['update_interval']:
-            self.kwargs['update_interval'] = int(0.6 * self.kwargs['nlive'])
-        if self.kwargs['print_func'] is None:
-            self.kwargs['print_func'] = self._print_func
+
+        if not self.kwargs["walks"]:
+            self.kwargs["walks"] = 100
+        if not self.kwargs["update_interval"]:
+            self.kwargs["update_interval"] = int(0.6 * self.kwargs["nlive"])
+        if self.kwargs["print_func"] is None:
+            self.kwargs["print_func"] = self._print_func
         print_method = self.kwargs["print_method"]
         if print_method == "tqdm" and self.kwargs["print_progress"]:
             self.pbar = tqdm(file=sys.stdout)
         elif "interval" in print_method:
             self._last_print_time = datetime.datetime.now()
-            self._print_interval = datetime.timedelta(seconds=float(print_method.split("-")[1]))
+            self._print_interval = datetime.timedelta(
+                seconds=float(print_method.split("-")[1])
+            )
         Sampler._verify_kwargs_against_default_kwargs(self)
 
     def _print_func(self, results, niter, ncall=None, dlogz=None, *args, **kwargs):
-        """ Replacing status update for dynesty.result.print_func """
+        """Replacing status update for dynesty.result.print_func"""
         if "interval" in self.kwargs["print_method"]:
             _time = datetime.datetime.now()
             if _time - self._last_print_time < self._print_interval:
@@ -251,17 +287,31 @@ class Dynesty(NestedSampler):
             total_time = self.sampling_time + _time - self.start_time
 
             # Remove fractional seconds
-            total_time_str = str(total_time).split('.')[0]
+            total_time_str = str(total_time).split(".")[0]
 
         # Extract results at the current iteration.
-        (worst, ustar, vstar, loglstar, logvol, logwt,
-         logz, logzvar, h, nc, worst_it, boundidx, bounditer,
-         eff, delta_logz) = results
+        (
+            worst,
+            ustar,
+            vstar,
+            loglstar,
+            logvol,
+            logwt,
+            logz,
+            logzvar,
+            h,
+            nc,
+            worst_it,
+            boundidx,
+            bounditer,
+            eff,
+            delta_logz,
+        ) = results
 
         # Adjusting outputs for printing.
         if delta_logz > 1e6:
             delta_logz = np.inf
-        if 0. <= logzvar <= 1e6:
+        if 0.0 <= logzvar <= 1e6:
             logzerr = np.sqrt(logzvar)
         else:
             logzerr = np.nan
@@ -271,38 +321,38 @@ class Dynesty(NestedSampler):
             loglstar = -np.inf
 
         if self.use_ratio:
-            key = 'logz-ratio'
+            key = "logz-ratio"
         else:
-            key = 'logz'
+            key = "logz"
 
         # Constructing output.
         string = list()
-        string.append("bound:{:d}".format(bounditer))
-        string.append("nc:{:3d}".format(nc))
-        string.append("ncall:{:.1e}".format(ncall))
-        string.append("eff:{:0.1f}%".format(eff))
-        string.append("{}={:0.2f}+/-{:0.2f}".format(key, logz, logzerr))
-        string.append("dlogz:{:0.3f}>{:0.2g}".format(delta_logz, dlogz))
+        string.append(f"bound:{bounditer:d}")
+        string.append(f"nc:{nc:3d}")
+        string.append(f"ncall:{ncall:.1e}")
+        string.append(f"eff:{eff:0.1f}%")
+        string.append(f"{key}={logz:0.2f}+/-{logzerr:0.2f}")
+        string.append(f"dlogz:{delta_logz:0.3f}>{dlogz:0.2g}")
 
         if self.kwargs["print_method"] == "tqdm":
             self.pbar.set_postfix_str(" ".join(string), refresh=False)
             self.pbar.update(niter - self.pbar.n)
         elif "interval" in self.kwargs["print_method"]:
             formatted = " ".join([total_time_str] + string)
-            print("{}it [{}]".format(niter, formatted), file=sys.stdout, flush=True)
+            print(f"{niter}it [{formatted}]", file=sys.stdout, flush=True)
         else:
             formatted = " ".join([total_time_str] + string)
-            print("{}it [{}]".format(niter, formatted), file=sys.stdout, flush=True)
 
     def _apply_dynesty_boundaries(self):
         self._periodic = list()
         self._reflective = list()
         for ii, key in enumerate(self.search_parameter_keys):
-            if self.priors[key].boundary == 'periodic':
-                logger.debug("Setting periodic boundary for {}".format(key))
+            if self.priors[key].boundary == "periodic":
+                logger.debug(f"Setting periodic boundary for {key}")
                 self._periodic.append(ii)
-            elif self.priors[key].boundary == 'reflective':
-                logger.debug("Setting reflective boundary for {}".format(key))
+            elif self.priors[key].boundary == "reflective":
+                logger.debug(f"Setting reflective boundary for {key}")
                 self._reflective.append(ii)
 
         # The periodic kwargs passed into dynesty allows the parameters to
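The `print_method="interval-$TIME"` convention handled in `_verify_kwargs_against_default_kwargs` and `_print_func` above is parsed by simply splitting on the hyphen; a minimal illustration:

```python
import datetime

print_method = "interval-10"  # print status at most every ten seconds

if "interval" in print_method:
    interval = datetime.timedelta(seconds=float(print_method.split("-")[1]))
    print(interval)  # 0:00:10
```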
@@ -312,61 +362,26 @@ class Dynesty(NestedSampler):
         self.kwargs["reflective"] = self._reflective
 
     def nestcheck_data(self, out_file):
-        import nestcheck.data_processing
         import pickle
+
+        import nestcheck.data_processing
+
         ns_run = nestcheck.data_processing.process_dynesty_run(out_file)
-        nestcheck_result = "{}/{}_nestcheck.pickle".format(self.outdir, self.label)
-        with open(nestcheck_result, 'wb') as file_nest:
+        nestcheck_result = f"{self.outdir}/{self.label}_nestcheck.pickle"
+        with open(nestcheck_result, "wb") as file_nest:
             pickle.dump(ns_run, file_nest)
 
-    def _setup_pool(self):
-        if self.kwargs["pool"] is not None:
-            logger.info("Using user defined pool.")
-            self.pool = self.kwargs["pool"]
-        elif self.kwargs["queue_size"] > 1:
-            logger.info(
-                "Setting up multiproccesing pool with {} processes.".format(
-                    self.kwargs["queue_size"]
-                )
-            )
-            import multiprocessing
-            self.pool = multiprocessing.Pool(
-                processes=self.kwargs["queue_size"],
-                initializer=_initialize_global_variables,
-                initargs=(
-                    self.likelihood,
-                    self.priors,
-                    self._search_parameter_keys,
-                    self.use_ratio
-                )
-            )
-        else:
-            _initialize_global_variables(
-                likelihood=self.likelihood,
-                priors=self.priors,
-                search_parameter_keys=self._search_parameter_keys,
-                use_ratio=self.use_ratio
-            )
-            self.pool = None
-        self.kwargs["pool"] = self.pool
-
-    def _close_pool(self):
-        if getattr(self, "pool", None) is not None:
-            logger.info("Starting to close worker pool.")
-            self.pool.close()
-            self.pool.join()
-            self.pool = None
-            self.kwargs["pool"] = self.pool
-            logger.info("Finished closing worker pool.")
-
+    @signal_wrapper
     def run_sampler(self):
-        import dynesty
         import dill
-        logger.info("Using dynesty version {}".format(dynesty.__version__))
+        import dynesty
+
+        logger.info(f"Using dynesty version {dynesty.__version__}")
 
         if self.kwargs.get("sample", "rwalk") == "rwalk":
             logger.info(
-                "Using the bilby-implemented rwalk sample method with ACT estimated walks")
+                "Using the bilby-implemented rwalk sample method with ACT estimated walks"
+            )
             dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby
             dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
             if self.kwargs.get("walks") > self.kwargs.get("maxmcmc"):
@@ -375,12 +390,10 @@ class Dynesty(NestedSampler):
                 raise DynestySetupError("Unable to run with nact < 1")
         elif self.kwargs.get("sample") == "rwalk_dynesty":
             self._kwargs["sample"] = "rwalk"
-            logger.info(
-                "Using the dynesty-implemented rwalk sample method")
+            logger.info("Using the dynesty-implemented rwalk sample method")
         elif self.kwargs.get("sample") == "rstagger_dynesty":
             self._kwargs["sample"] = "rstagger"
-            logger.info(
-                "Using the dynesty-implemented rstagger sample method")
+            logger.info("Using the dynesty-implemented rstagger sample method")
 
         self._setup_pool()
 
@@ -388,22 +401,25 @@ class Dynesty(NestedSampler):
         self.resume = self.read_saved_state(continuing=True)
 
         if self.resume:
-            logger.info('Resume file successfully loaded.')
+            logger.info("Resume file successfully loaded.")
         else:
-            if self.kwargs['live_points'] is None:
-                self.kwargs['live_points'] = (
-                    self.get_initial_points_from_prior(self.kwargs['nlive'])
+            if self.kwargs["live_points"] is None:
+                self.kwargs["live_points"] = self.get_initial_points_from_prior(
+                    self.kwargs["nlive"]
                 )
             self.sampler = dynesty.NestedSampler(
                 loglikelihood=_log_likelihood_wrapper,
                 prior_transform=_prior_transform_wrapper,
-                ndim=self.ndim, **self.sampler_init_kwargs
+                ndim=self.ndim,
+                **self.sampler_init_kwargs,
             )
+        self.start_time = datetime.datetime.now()
         if self.check_point:
             out = self._run_external_sampler_with_checkpointing()
         else:
             out = self._run_external_sampler_without_checkpointing()
+        self._update_sampling_time()
         self._close_pool()
 
@@ -417,8 +433,8 @@ class Dynesty(NestedSampler):
         if self.nestcheck:
             self.nestcheck_data(out)
 
-        dynesty_result = "{}/{}_dynesty.pickle".format(self.outdir, self.label)
-        with open(dynesty_result, 'wb') as file:
+        dynesty_result = f"{self.outdir}/{self.label}_dynesty.pickle"
+        with open(dynesty_result, "wb") as file:
             dill.dump(out, file)
 
         self._generate_result(out)
@@ -432,21 +448,23 @@ class Dynesty(NestedSampler):
     def _generate_result(self, out):
         import dynesty
         from scipy.special import logsumexp
+
         logwts = out["logwt"]
-        weights = np.exp(logwts - out['logz'][-1])
-        nested_samples = DataFrame(
-            out.samples, columns=self.search_parameter_keys)
-        nested_samples['weights'] = weights
-        nested_samples['log_likelihood'] = out.logl
+        weights = np.exp(logwts - out["logz"][-1])
+        nested_samples = DataFrame(out.samples, columns=self.search_parameter_keys)
+        nested_samples["weights"] = weights
+        nested_samples["log_likelihood"] = out.logl
         self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
         self.result.nested_samples = nested_samples
         self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
-            unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples,
-            sorted_samples=self.result.samples)
+            unsorted_loglikelihoods=out.logl,
+            unsorted_samples=out.samples,
+            sorted_samples=self.result.samples,
+        )
         self.result.log_evidence = out.logz[-1]
         self.result.log_evidence_err = out.logzerr[-1]
         self.result.information_gain = out.information[-1]
-        self.result.num_likelihood_evaluations = getattr(self.sampler, 'ncall', 0)
+        self.result.num_likelihood_evaluations = getattr(self.sampler, "ncall", 0)
 
         logneff = logsumexp(logwts) * 2 - logsumexp(logwts * 2)
         neffsamples = int(np.exp(logneff))
@@ -454,11 +472,16 @@
             nlikelihood=self.result.num_likelihood_evaluations,
             neffsamples=neffsamples,
             sampling_time_s=self.sampling_time.seconds,
-            ncores=self.kwargs.get("queue_size", 1)
+            ncores=self.kwargs.get("queue_size", 1),
         )
 
+    def _update_sampling_time(self):
+        end_time = datetime.datetime.now()
+        self.sampling_time += end_time - self.start_time
+        self.start_time = end_time
+
     def _run_nested_wrapper(self, kwargs):
-        """ Wrapper function to run_nested
+        """Wrapper function to run_nested
 
         This wrapper catches exceptions related to different versions
         of dynesty accepting different arguments.
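`_generate_result` above estimates the effective sample size from the nested-sampling weights using Kish's formula, n_eff = (sum w_i)^2 / (sum w_i^2), evaluated in log space for numerical stability. A quick check with hypothetical weights:

```python
import numpy as np
from scipy.special import logsumexp

logwts = np.log(np.array([0.5, 0.3, 0.1, 0.1]))  # hypothetical weights

# log n_eff = 2 log(sum w) - log(sum w^2), computed stably in log space
logneff = logsumexp(logwts) * 2 - logsumexp(logwts * 2)
print(np.exp(logneff))                    # ~2.78
print(1.0 / np.sum(np.exp(logwts) ** 2))  # same, since the weights sum to 1
```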
@@ -469,8 +492,7 @@
             The dictionary of kwargs to pass to run_nested
 
         """
-        logger.debug("Calling run_nested with sampler_function_kwargs {}"
-                     .format(kwargs))
+        logger.debug(f"Calling run_nested with sampler_function_kwargs {kwargs}")
         try:
             self.sampler.run_nested(**kwargs)
         except TypeError:
@@ -487,9 +509,8 @@
 
         old_ncall = self.sampler.ncall
         sampler_kwargs = self.sampler_function_kwargs.copy()
-        sampler_kwargs['maxcall'] = self.n_check_point
-        sampler_kwargs['add_live'] = True
-        self.start_time = datetime.datetime.now()
+        sampler_kwargs["maxcall"] = self.n_check_point
+        sampler_kwargs["add_live"] = True
         while True:
             self._run_nested_wrapper(sampler_kwargs)
             if self.sampler.ncall == old_ncall:
@@ -499,14 +520,16 @@
             if os.path.isfile(self.resume_file):
                 last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file)
             else:
-                last_checkpoint_s = (datetime.datetime.now() - self.start_time).total_seconds()
+                last_checkpoint_s = (
+                    datetime.datetime.now() - self.start_time
+                ).total_seconds()
             if last_checkpoint_s > self.check_point_delta_t:
                 self.write_current_state()
                 self.plot_current_state()
 
         if self.sampler.added_live:
             self.sampler._remove_live_points()
-            sampler_kwargs['add_live'] = True
+            sampler_kwargs["add_live"] = True
             self._run_nested_wrapper(sampler_kwargs)
         self.write_current_state()
         self.plot_current_state()
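The checkpointing loop above runs `run_nested` in `maxcall`-sized slices, stops once the call count no longer grows, and writes a resume file whenever the existing one is older than `check_point_delta_t`. A condensed sketch of the loop shape (hypothetical helper; the real method also re-adds live points and plots):

```python
import os
import time


def run_with_checkpointing(sampler, run_kwargs, n_check_point, delta_t, resume_file, save):
    """Run a dynesty-style sampler in maxcall-sized slices, checkpointing.

    Each pass allows n_check_point more likelihood calls; an unchanged
    call count means the run has converged. `save` is a callable that
    writes `resume_file`.
    """
    run_kwargs["maxcall"] = 0
    old_ncall = sampler.ncall
    while True:
        run_kwargs["maxcall"] += n_check_point
        sampler.run_nested(**run_kwargs)
        if sampler.ncall == old_ncall:
            break
        old_ncall = sampler.ncall
        last_checkpoint_s = (
            time.time() - os.path.getmtime(resume_file)
            if os.path.isfile(resume_file)
            else float("inf")
        )
        if last_checkpoint_s > delta_t:
            save()
    return sampler.results
```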
@@ -534,39 +557,41 @@ class Dynesty(NestedSampler):
             Whether the run is continuing or terminating, if True, the loaded
             state is mostly written back to disk.
         """
-        from ... import __version__ as bilby_version
-        from dynesty import __version__ as dynesty_version
         import dill
+        from dynesty import __version__ as dynesty_version
+
+        from ... import __version__ as bilby_version
+
         versions = dict(bilby=bilby_version, dynesty=dynesty_version)
         if os.path.isfile(self.resume_file):
-            logger.info("Reading resume file {}".format(self.resume_file))
-            with open(self.resume_file, 'rb') as file:
+            logger.info(f"Reading resume file {self.resume_file}")
+            with open(self.resume_file, "rb") as file:
                 sampler = dill.load(file)
 
                 if not hasattr(sampler, "versions"):
                     logger.warning(
-                        "The resume file {} is corrupted or the version of "
-                        "bilby has changed between runs. This resume file will "
-                        "be ignored."
-                        .format(self.resume_file)
+                        f"The resume file {self.resume_file} is corrupted or "
+                        "the version of bilby has changed between runs. This "
+                        "resume file will be ignored."
                    )
                     return False
                 version_warning = (
                     "The {code} version has changed between runs. "
                     "This may cause unpredictable behaviour and/or failure. "
                     "Old version = {old}, new version = {new}."
-                )
                 for code in versions:
                     if not versions[code] == sampler.versions.get(code, None):
-                        logger.warning(version_warning.format(
-                            code=code,
-                            old=sampler.versions.get(code, "None"),
-                            new=versions[code]
-                        ))
+                        logger.warning(
+                            version_warning.format(
+                                code=code,
+                                old=sampler.versions.get(code, "None"),
+                                new=versions[code],
+                            )
+                        )
                 del sampler.versions
                 self.sampler = sampler
-                if self.sampler.added_live and continuing:
+                if getattr(self.sampler, "added_live", False) and continuing:
                     self.sampler._remove_live_points()
                 self.sampler.nqueue = -1
                 self.sampler.rstate = np.random
@@ -579,27 +604,13 @@ class Dynesty(NestedSampler):
                 self.sampler.M = map
             return True
         else:
-            logger.info(
-                "Resume file {} does not exist.".format(self.resume_file))
+            logger.info(f"Resume file {self.resume_file} does not exist.")
             return False
 
     def write_current_state_and_exit(self, signum=None, frame=None):
-        """
-        Make sure that if a pool of jobs is running only the parent tries to
-        checkpoint and exit. Only the parent has a 'pool' attribute.
-        """
-        if self.kwargs["queue_size"] == 1 or getattr(self, "pool", None) is not None:
-            if signum == 14:
-                logger.info(
-                    "Run interrupted by alarm signal {}: checkpoint and exit on {}"
-                    .format(signum, self.exit_code))
-            else:
-                logger.info(
-                    "Run interrupted by signal {}: checkpoint and exit on {}"
-                    .format(signum, self.exit_code))
-            self.write_current_state()
-            self._close_pool()
-            os._exit(self.exit_code)
+        if self.kwargs["print_method"] == "tqdm":
+            self.pbar.close()
+        super(Dynesty, self).write_current_state_and_exit(signum=signum, frame=frame)
 
     def write_current_state(self):
         """
@@ -613,29 +624,26 @@ class Dynesty(NestedSampler):
         normal running.
         """
-        from ... import __version__ as bilby_version
-        from dynesty import __version__ as dynesty_version
         import dill
+        from dynesty import __version__ as dynesty_version
+
+        from ... import __version__ as bilby_version
 
         if getattr(self, "sampler", None) is None:
             # Sampler not initialized, not able to write current state
             return
 
         check_directory_exists_and_if_not_mkdir(self.outdir)
-        end_time = datetime.datetime.now()
-        if hasattr(self, 'start_time'):
-            self.sampling_time += end_time - self.start_time
-            self.start_time = end_time
+        if hasattr(self, "start_time"):
+            self._update_sampling_time()
         self.sampler.kwargs["sampling_time"] = self.sampling_time
         self.sampler.kwargs["start_time"] = self.start_time
-        self.sampler.versions = dict(
-            bilby=bilby_version, dynesty=dynesty_version
-        )
+        self.sampler.versions = dict(bilby=bilby_version, dynesty=dynesty_version)
         self.sampler.pool = None
         self.sampler.M = map
         if dill.pickles(self.sampler):
             safe_file_dump(self.sampler, self.resume_file, dill)
-            logger.info("Written checkpoint file {}".format(self.resume_file))
+            logger.info(f"Written checkpoint file {self.resume_file}")
         else:
             logger.warning(
                 "Cannot write pickle resume file! "
" @@ -657,86 +665,108 @@ class Dynesty(NestedSampler): if nsamples < 100: return - filename = "{}/{}_samples.dat".format(self.outdir, self.label) - logger.info("Writing {} current samples to {}".format(nsamples, filename)) + filename = f"{self.outdir}/{self.label}_samples.dat" + logger.info(f"Writing {nsamples} current samples to {filename}") df = DataFrame(samples, columns=self.search_parameter_keys) - df.to_csv(filename, index=False, header=True, sep=' ') + df.to_csv(filename, index=False, header=True, sep=" ") def plot_current_state(self): import matplotlib.pyplot as plt + if self.check_point_plot: import dynesty.plotting as dyplot - labels = [label.replace('_', ' ') for label in self.search_parameter_keys] + + labels = [label.replace("_", " ") for label in self.search_parameter_keys] try: - filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label) + filename = f"{self.outdir}/{self.label}_checkpoint_trace.png" fig = dyplot.traceplot(self.sampler.results, labels=labels)[0] fig.tight_layout() fig.savefig(filename) - except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError, OverflowError, Exception) as e: + except ( + RuntimeError, + np.linalg.linalg.LinAlgError, + ValueError, + OverflowError, + Exception, + ) as e: logger.warning(e) - logger.warning('Failed to create dynesty state plot at checkpoint') + logger.warning("Failed to create dynesty state plot at checkpoint") finally: plt.close("all") try: - filename = "{}/{}_checkpoint_trace_unit.png".format(self.outdir, self.label) + filename = f"{self.outdir}/{self.label}_checkpoint_trace_unit.png" from copy import deepcopy + temp = deepcopy(self.sampler.results) temp["samples"] = temp["samples_u"] fig = dyplot.traceplot(temp, labels=labels)[0] fig.tight_layout() fig.savefig(filename) - except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError, OverflowError, Exception) as e: + except ( + RuntimeError, + np.linalg.linalg.LinAlgError, + ValueError, + OverflowError, + Exception, + ) as e: logger.warning(e) - logger.warning('Failed to create dynesty unit state plot at checkpoint') + logger.warning("Failed to create dynesty unit state plot at checkpoint") finally: plt.close("all") try: - filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label) + filename = f"{self.outdir}/{self.label}_checkpoint_run.png" fig, axs = dyplot.runplot( - self.sampler.results, logplot=False, use_math_text=False) + self.sampler.results, logplot=False, use_math_text=False + ) fig.tight_layout() plt.savefig(filename) except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e: logger.warning(e) - logger.warning('Failed to create dynesty run plot at checkpoint') + logger.warning("Failed to create dynesty run plot at checkpoint") finally: - plt.close('all') + plt.close("all") try: - filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label) + filename = f"{self.outdir}/{self.label}_checkpoint_stats.png" fig, axs = dynesty_stats_plot(self.sampler) fig.tight_layout() plt.savefig(filename) except (RuntimeError, ValueError) as e: logger.warning(e) - logger.warning('Failed to create dynesty stats plot at checkpoint') + logger.warning("Failed to create dynesty stats plot at checkpoint") finally: - plt.close('all') + plt.close("all") def generate_trace_plots(self, dynesty_results): check_directory_exists_and_if_not_mkdir(self.outdir) - filename = '{}/{}_trace.png'.format(self.outdir, self.label) - logger.debug("Writing trace plot to {}".format(filename)) + filename = f"{self.outdir}/{self.label}_trace.png" + 
logger.debug(f"Writing trace plot to {filename}") from dynesty import plotting as dyplot - fig, axes = dyplot.traceplot(dynesty_results, - labels=self.result.parameter_labels) + + fig, axes = dyplot.traceplot( + dynesty_results, labels=self.result.parameter_labels + ) fig.tight_layout() fig.savefig(filename) def _run_test(self): import dynesty import pandas as pd + self.sampler = dynesty.NestedSampler( loglikelihood=self.log_likelihood, prior_transform=self.prior_transform, - ndim=self.ndim, **self.sampler_init_kwargs) + ndim=self.ndim, + **self.sampler_init_kwargs, + ) sampler_kwargs = self.sampler_function_kwargs.copy() - sampler_kwargs['maxiter'] = 2 + sampler_kwargs["maxiter"] = 2 self.sampler.run_nested(**sampler_kwargs) N = 100 - self.result.samples = pd.DataFrame( - self.priors.sample(N))[self.search_parameter_keys].values + self.result.samples = pd.DataFrame(self.priors.sample(N))[ + self.search_parameter_keys + ].values self.result.nested_samples = self.result.samples self.result.log_likelihood_evaluations = np.ones(N) self.result.log_evidence = 1 @@ -745,7 +775,7 @@ class Dynesty(NestedSampler): return self.result def prior_transform(self, theta): - """ Prior transform method that is passed into the external sampler. + """Prior transform method that is passed into the external sampler. cube we map this back to [0, 1]. Parameters @@ -762,25 +792,24 @@ class Dynesty(NestedSampler): def sample_rwalk_bilby(args): - """ Modified bilby-implemented version of dynesty.sampling.sample_rwalk """ + """Modified bilby-implemented version of dynesty.sampling.sample_rwalk""" from dynesty.utils import unitcheck # Unzipping. - (u, loglstar, axes, scale, - prior_transform, loglikelihood, kwargs) = args + (u, loglstar, axes, scale, prior_transform, loglikelihood, kwargs) = args rstate = np.random # Bounds - nonbounded = kwargs.get('nonbounded', None) - periodic = kwargs.get('periodic', None) - reflective = kwargs.get('reflective', None) + nonbounded = kwargs.get("nonbounded", None) + periodic = kwargs.get("periodic", None) + reflective = kwargs.get("reflective", None) # Setup. 
 
     # Initialize internal variables
     accept = 0
@@ -848,19 +877,21 @@ def sample_rwalk_bilby(args):
         if accept + reject > walks:
             act = estimate_nmcmc(
                 accept_ratio=accept / (accept + reject + nfail),
-                old_act=old_act, maxmcmc=maxmcmc)
+                old_act=old_act,
+                maxmcmc=maxmcmc,
+            )
 
         # If we've taken too many likelihood evaluations then break
         if accept + reject > maxmcmc:
             warnings.warn(
-                "Hit maximum number of walks {} with accept={}, reject={}, "
-                "and nfail={} try increasing maxmcmc"
-                .format(maxmcmc, accept, reject, nfail))
+                f"Hit maximum number of walks {maxmcmc} with accept={accept},"
+                f" reject={reject}, and nfail={nfail} try increasing maxmcmc"
+            )
             break
 
     # If the act is finite, pick randomly from within the chain
-    if np.isfinite(act) and int(.5 * nact * act) < len(u_list):
-        idx = np.random.randint(int(.5 * nact * act), len(u_list))
+    if np.isfinite(act) and int(0.5 * nact * act) < len(u_list):
+        idx = np.random.randint(int(0.5 * nact * act), len(u_list))
         u = u_list[idx]
         v = v_list[idx]
         logl = logl_list[idx]
@@ -870,7 +901,7 @@ def sample_rwalk_bilby(args):
         v = prior_transform(u)
         logl = loglikelihood(v)
 
-    blob = {'accept': accept, 'reject': reject, 'fail': nfail, 'scale': scale}
+    blob = {"accept": accept, "reject": reject, "fail": nfail, "scale": scale}
     kwargs["old_act"] = act
 
     ncall = accept + reject
@@ -878,7 +909,7 @@ def sample_rwalk_bilby(args):
 
 
 def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
-    """ Estimate autocorrelation length of chain using acceptance fraction
+    """Estimate autocorrelation length of chain using acceptance fraction
 
     Using ACL = (2/acc) - 1 multiplied by a safety margin. Code adapted from CPNest:
@@ -905,9 +936,8 @@ def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
     if accept_ratio == 0.0:
         Nmcmc_exact = (1 + 1 / tau) * old_act
     else:
-        Nmcmc_exact = (
-            (1. - 1. / tau) * old_act +
-            (safety / tau) * (2. / accept_ratio - 1.)
-        )
+        Nmcmc_exact = (1.0 - 1.0 / tau) * old_act + (safety / tau) * (
+            2.0 / accept_ratio - 1.0
+        )
     Nmcmc_exact = float(min(Nmcmc_exact, maxmcmc))
     return max(safety, int(Nmcmc_exact))
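Plugging representative numbers into `estimate_nmcmc` above, and assuming the default `tau = maxmcmc / safety` when `tau` is None (that branch lies outside this hunk; the inputs below are hypothetical):

```python
safety, maxmcmc = 5, 5000
tau = maxmcmc / safety             # 1000, assumed default when tau is None
old_act, accept_ratio = 100, 0.2   # hypothetical chain state

nmcmc = (1.0 - 1.0 / tau) * old_act + (safety / tau) * (2.0 / accept_ratio - 1.0)
print(nmcmc)                       # 99.945 -> max(safety, int(...)) = 99
```

With a large `tau` the update is dominated by the previous autocorrelation time, so the estimate drifts slowly rather than jumping on every noisy acceptance fraction.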
@@ -943,7 +973,7 @@ def dynesty_stats_plot(sampler):
 
     fig, axs = plt.subplots(nrows=4, figsize=(8, 8))
     for ax, name in zip(axs, ["nc", "scale"]):
-        ax.plot(getattr(sampler, "saved_{}".format(name)), color="blue")
+        ax.plot(getattr(sampler, f"saved_{name}"), color="blue")
         ax.set_ylabel(name.title())
     lifetimes = np.arange(len(sampler.saved_it)) - sampler.saved_it
     axs[-2].set_ylabel("Lifetime")
@@ -951,9 +981,17 @@ def dynesty_stats_plot(sampler):
     burn = int(geom(p=1 / nlive).isf(1 / 2 / nlive))
     if len(sampler.saved_it) > burn + sampler.nlive:
         axs[-2].plot(np.arange(0, burn), lifetimes[:burn], color="grey")
-        axs[-2].plot(np.arange(burn, len(lifetimes) - nlive), lifetimes[burn: -nlive], color="blue")
-        axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
-        lifetimes = lifetimes[burn: -nlive]
+        axs[-2].plot(
+            np.arange(burn, len(lifetimes) - nlive),
+            lifetimes[burn:-nlive],
+            color="blue",
+        )
+        axs[-2].plot(
+            np.arange(len(lifetimes) - nlive, len(lifetimes)),
+            lifetimes[-nlive:],
+            color="red",
+        )
+        lifetimes = lifetimes[burn:-nlive]
         ks_result = ks_1samp(lifetimes, geom(p=1 / nlive).cdf)
         axs[-1].hist(
             lifetimes,
@@ -961,19 +999,25 @@ def dynesty_stats_plot(sampler):
             histtype="step",
             density=True,
             color="blue",
-            label=f"p value = {ks_result.pvalue:.3f}"
+            label=f"p value = {ks_result.pvalue:.3f}",
         )
         axs[-1].plot(
             np.arange(1, 6 * nlive),
             geom(p=1 / nlive).pmf(np.arange(1, 6 * nlive)),
-            color="red"
+            color="red",
        )
         axs[-1].set_xlim(0, 6 * nlive)
         axs[-1].legend()
         axs[-1].set_yscale("log")
     else:
-        axs[-2].plot(np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey")
-        axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
+        axs[-2].plot(
+            np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey"
+        )
+        axs[-2].plot(
+            np.arange(len(lifetimes) - nlive, len(lifetimes)),
+            lifetimes[-nlive:],
+            color="red",
+        )
 
     axs[-2].set_yscale("log")
     axs[-2].set_xlabel("Iteration")
     axs[-1].set_xlabel("Lifetime")
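`dynesty_stats_plot` above compares the live-point insertion lifetimes against the Geometric(1/nlive) distribution expected of a well-behaved nested-sampling run, flagging problems via a KS p-value. A self-contained version of that check on synthetic lifetimes (note the KS test is only approximate for discrete data, as in the plot itself):

```python
from scipy.stats import geom, ks_1samp

nlive = 500

# For a healthy run, point-insertion lifetimes are Geometric(1 / nlive);
# draw synthetic lifetimes to see the test pass.
lifetimes = geom(p=1 / nlive).rvs(size=5000, random_state=42)

ks_result = ks_1samp(lifetimes, geom(p=1 / nlive).cdf)
print(f"p value = {ks_result.pvalue:.3f}")  # large p: consistent with geometric
```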
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index d976d391ed34f0871b3dc3121f6924ee395cee41..5afa169b8959faa4c8e1e2df3a3d7397819ecd2f 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -1,7 +1,5 @@
 import os
-import signal
 import shutil
-import sys
 from collections import namedtuple
 from distutils.version import LooseVersion
 from shutil import copyfile
@@ -9,8 +7,11 @@ from shutil import copyfile
 import numpy as np
 from pandas import DataFrame
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import MCMCSampler, SamplerError
+from ..utils import check_directory_exists_and_if_not_mkdir, logger
+from .base_sampler import MCMCSampler, SamplerError, signal_wrapper
+from .ptemcee import LikePriorEvaluator
+
+_evaluator = LikePriorEvaluator()
 
 
 class Emcee(MCMCSampler):
@@ -45,81 +46,110 @@ class Emcee(MCMCSampler):
     """
 
     default_kwargs = dict(
-        nwalkers=500, a=2, args=[], kwargs={}, postargs=None, pool=None,
-        live_dangerously=False, runtime_sortingfn=None, lnprob0=None,
-        rstate0=None, blobs0=None, iterations=100, thin=1, storechain=True,
-        mh_proposal=None)
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
-                 burn_in_act=3, verbose=True, **kwargs):
+        nwalkers=500,
+        a=2,
+        args=[],
+        kwargs={},
+        postargs=None,
+        pool=None,
+        live_dangerously=False,
+        runtime_sortingfn=None,
+        lnprob0=None,
+        rstate0=None,
+        blobs0=None,
+        iterations=100,
+        thin=1,
+        storechain=True,
+        mh_proposal=None,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        pos0=None,
+        nburn=None,
+        burn_in_fraction=0.25,
+        resume=True,
+        burn_in_act=3,
+        **kwargs,
+    ):
         import emcee
-        self.emcee = emcee
 
-        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+        if LooseVersion(emcee.__version__) > LooseVersion("2.2.1"):
             self.prerelease = True
         else:
             self.prerelease = False
         super(Emcee, self).__init__(
-            likelihood=likelihood, priors=priors, outdir=outdir,
-            label=label, use_ratio=use_ratio, plot=plot,
-            skip_import_verification=skip_import_verification, **kwargs)
-        self.emcee = self._check_version()
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
+        self._check_version()
         self.resume = resume
         self.pos0 = pos0
         self.nburn = nburn
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
-        self.verbose = verbose
-
-        signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
-        signal.signal(signal.SIGINT, self.checkpoint_and_exit)
+        self.verbose = kwargs.get("verbose", True)
 
     def _check_version(self):
         import emcee
-        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+
+        if LooseVersion(emcee.__version__) > LooseVersion("2.2.1"):
             self.prerelease = True
         else:
             self.prerelease = False
         return emcee
 
     def _translate_kwargs(self, kwargs):
-        if 'nwalkers' not in kwargs:
+        if "nwalkers" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nwalkers'] = kwargs.pop(equiv)
-        if 'iterations' not in kwargs:
-            if 'nsteps' in kwargs:
-                kwargs['iterations'] = kwargs.pop('nsteps')
-        if 'threads' in kwargs:
-            if kwargs['threads'] != 1:
-                logger.warning("The 'threads' argument cannot be used for "
-                               "parallelisation. This run will proceed "
-                               "without parallelisation, but consider the use "
-                               "of an appropriate Pool object passed to the "
-                               "'pool' keyword.")
-                kwargs['threads'] = 1
+                    kwargs["nwalkers"] = kwargs.pop(equiv)
+        if "iterations" not in kwargs:
+            if "nsteps" in kwargs:
+                kwargs["iterations"] = kwargs.pop("nsteps")
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin',
-                'storechain', 'mh_proposal']
+        keys = [
+            "lnprob0",
+            "rstate0",
+            "blobs0",
+            "iterations",
+            "thin",
+            "storechain",
+            "mh_proposal",
+        ]
 
         # updated function keywords for emcee > v2.2.1
-        updatekeys = {'p0': 'initial_state',
-                      'lnprob0': 'log_prob0',
-                      'storechain': 'store'}
+        updatekeys = {
+            "p0": "initial_state",
+            "lnprob0": "log_prob0",
+            "storechain": "store",
+        }
 
         function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
-        function_kwargs['p0'] = self.pos0
+        function_kwargs["p0"] = self.pos0
 
         if self.prerelease:
-            if function_kwargs['mh_proposal'] is not None:
-                logger.warning("The 'mh_proposal' option is no longer used "
-                               "in emcee v{}, and will be ignored.".format(
-                                   self.emcee.__version__))
-                del function_kwargs['mh_proposal']
+            if function_kwargs["mh_proposal"] is not None:
+                logger.warning(
+                    "The 'mh_proposal' option is no longer used "
+                    "in emcee > 2.2.1, and will be ignored."
+                )
+                del function_kwargs["mh_proposal"]
 
             for key in updatekeys:
                 if updatekeys[key] not in function_kwargs:
@@ -131,37 +161,30 @@ class Emcee(MCMCSampler):
 
     @property
     def sampler_init_kwargs(self):
-        init_kwargs = {key: value
-                       for key, value in self.kwargs.items()
-                       if key not in self.sampler_function_kwargs}
+        init_kwargs = {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+        }
 
-        init_kwargs['lnpostfn'] = self.lnpostfn
-        init_kwargs['dim'] = self.ndim
+        init_kwargs["lnpostfn"] = _evaluator.call_emcee
+        init_kwargs["dim"] = self.ndim
 
         # updated init keywords for emcee > v2.2.1
-        updatekeys = {'dim': 'ndim',
-                      'lnpostfn': 'log_prob_fn'}
+        updatekeys = {"dim": "ndim", "lnpostfn": "log_prob_fn"}
 
         if self.prerelease:
             for key in updatekeys:
                 if key in init_kwargs:
                     init_kwargs[updatekeys[key]] = init_kwargs.pop(key)
 
-            oldfunckeys = ['p0', 'lnprob0', 'storechain', 'mh_proposal']
+            oldfunckeys = ["p0", "lnprob0", "storechain", "mh_proposal"]
             for key in oldfunckeys:
                 if key in init_kwargs:
                     del init_kwargs[key]
 
         return init_kwargs
 
-    def lnpostfn(self, theta):
-        log_prior = self.log_prior(theta)
-        if np.isinf(log_prior):
-            return -np.inf, [np.nan, np.nan]
-        else:
-            log_likelihood = self.log_likelihood(theta)
-            return log_likelihood + log_prior, [log_likelihood, log_prior]
-
     @property
     def nburn(self):
         if type(self.__nburn) in [float, int]:
@@ -174,52 +197,54 @@ class Emcee(MCMCSampler):
     @nburn.setter
     def nburn(self, nburn):
         if isinstance(nburn, (float, int)):
-            if nburn > self.kwargs['iterations'] - 1:
-                raise ValueError('Number of burn-in samples must be smaller '
-                                 'than the total number of iterations')
+            if nburn > self.kwargs["iterations"] - 1:
+                raise ValueError(
+                    "Number of burn-in samples must be smaller "
+                    "than the total number of iterations"
+                )
 
         self.__nburn = nburn
 
     @property
     def nwalkers(self):
-        return self.kwargs['nwalkers']
+        return self.kwargs["nwalkers"]
 
     @property
     def nsteps(self):
-        return self.kwargs['iterations']
+        return self.kwargs["iterations"]
 
     @nsteps.setter
     def nsteps(self, nsteps):
-        self.kwargs['iterations'] = nsteps
+        self.kwargs["iterations"] = nsteps
 
     @property
     def stored_chain(self):
-        """ Read the stored zero-temperature chain data in from disk """
+        """Read the stored zero-temperature chain data in from disk"""
         return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
 
     @property
     def stored_samples(self):
-        """ Returns the samples stored on disk """
+        """Returns the samples stored on disk"""
         return self.stored_chain[self.search_parameter_keys]
 
     @property
     def stored_loglike(self):
-        """ Returns the log-likelihood stored on disk """
-        return self.stored_chain['log_l']
+        """Returns the log-likelihood stored on disk"""
+        return self.stored_chain["log_l"]
 
     @property
     def stored_logprior(self):
-        """ Returns the log-prior stored on disk """
-        return self.stored_chain['log_p']
+        """Returns the log-prior stored on disk"""
+        return self.stored_chain["log_p"]
 
     def _init_chain_file(self):
         with open(self.checkpoint_info.chain_file, "w+") as ff:
-            ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
-                '\t'.join(self.search_parameter_keys)))
+            search_keys_str = "\t".join(self.search_parameter_keys)
+            ff.write(f"walker\t{search_keys_str}\tlog_l\tlog_p\n")
 
     @property
     def checkpoint_info(self):
-        """ Defines various things related to checkpointing and storing data
+        """Defines various things related to checkpointing and storing data
 
         Returns
         =======
@@ -231,21 +256,25 @@ class Emcee(MCMCSampler):
         """
         out_dir = os.path.join(
-            self.outdir, '{}_{}'.format(self.__class__.__name__.lower(),
-                                        self.label))
+            self.outdir, f"{self.__class__.__name__.lower()}_{self.label}"
+        )
         check_directory_exists_and_if_not_mkdir(out_dir)
 
-        chain_file = os.path.join(out_dir, 'chain.dat')
-        sampler_file = os.path.join(out_dir, 'sampler.pickle')
-        chain_template =\
-            '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
+        chain_file = os.path.join(out_dir, "chain.dat")
+        sampler_file = os.path.join(out_dir, "sampler.pickle")
+        chain_template = (
+            "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n"
+        )
 
         CheckpointInfo = namedtuple(
-            'CheckpointInfo', ['sampler_file', 'chain_file', 'chain_template'])
+            "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"]
+        )
 
         checkpoint_info = CheckpointInfo(
-            sampler_file=sampler_file, chain_file=chain_file,
-            chain_template=chain_template)
+            sampler_file=sampler_file,
+            chain_file=chain_file,
+            chain_template=chain_template,
+        )
 
         return checkpoint_info
 
@@ -254,43 +283,48 @@ class Emcee(MCMCSampler):
         nsteps = self._previous_iterations
         return self.sampler.chain[:, :nsteps, :]
 
-    def checkpoint(self):
-        """ Writes a pickle file of the sampler to disk using dill """
+    def write_current_state(self):
+        """Writes a pickle file of the sampler to disk using dill"""
         import dill
-        logger.info("Checkpointing sampler to file {}"
-                    .format(self.checkpoint_info.sampler_file))
-        with open(self.checkpoint_info.sampler_file, 'wb') as f:
+
+        logger.info(
+            f"Checkpointing sampler to file {self.checkpoint_info.sampler_file}"
+        )
+        with open(self.checkpoint_info.sampler_file, "wb") as f:
             # Overwrites the stored sampler chain with one that is truncated
             # to only the completed steps
             self.sampler._chain = self.sampler_chain
+            _pool = self.sampler.pool
+            self.sampler.pool = None
             dill.dump(self._sampler, f)
-
-    def checkpoint_and_exit(self, signum, frame):
-        logger.info("Received signal {}".format(signum))
-        self.checkpoint()
-        sys.exit()
+        self.sampler.pool = _pool
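`Emcee.write_current_state` above detaches the worker pool before the dill dump and reattaches it afterwards, because multiprocessing pools cannot be pickled; the same trick appears in `Dynesty.write_current_state` (`self.sampler.pool = None`). A reusable sketch of the save/restore pattern (hypothetical helper):

```python
import dill


def checkpoint(sampler, path):
    """Dump `sampler` with dill, temporarily detaching its worker pool.

    Pools (and map methods bound to them) are not picklable, so they
    are stripped before the dump and reattached in a finally block.
    """
    _pool = getattr(sampler, "pool", None)
    sampler.pool = None
    try:
        with open(path, "wb") as f:
            dill.dump(sampler, f)
    finally:
        sampler.pool = _pool
```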
 
     def _initialise_sampler(self):
-        self._sampler = self.emcee.EnsembleSampler(**self.sampler_init_kwargs)
+        from emcee import EnsembleSampler
+
+        self._sampler = EnsembleSampler(**self.sampler_init_kwargs)
         self._init_chain_file()
 
     @property
     def sampler(self):
-        """ Returns the emcee sampler object
+        """Returns the emcee sampler object
 
         If, already initialized, returns the stored _sampler value. Otherwise,
         first checks if there is a pickle file from which to load. If there is
         not, then initialize the sampler and set the initial random draw
 
         """
-        if hasattr(self, '_sampler'):
+        if hasattr(self, "_sampler"):
             pass
         elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
             import dill
-            logger.info("Resuming run from checkpoint file {}"
-                        .format(self.checkpoint_info.sampler_file))
-            with open(self.checkpoint_info.sampler_file, 'rb') as f:
+
+            logger.info(
+                f"Resuming run from checkpoint file {self.checkpoint_info.sampler_file}"
+            )
+            with open(self.checkpoint_info.sampler_file, "rb") as f:
                 self._sampler = dill.load(f)
+            self._sampler.pool = self.pool
             self._set_pos0_for_resume()
         else:
             self._initialise_sampler()
@@ -299,7 +333,7 @@ class Emcee(MCMCSampler):
 
     def write_chains_to_file(self, sample):
         chain_file = self.checkpoint_info.chain_file
-        temp_chain_file = chain_file + '.temp'
+        temp_chain_file = chain_file + ".temp"
         if os.path.isfile(chain_file):
             copyfile(chain_file, temp_chain_file)
         if self.prerelease:
@@ -313,7 +347,7 @@ class Emcee(MCMCSampler):
 
     @property
     def _previous_iterations(self):
-        """ Returns the number of iterations that the sampler has saved
+        """Returns the number of iterations that the sampler has saved
 
         This is used when loading in a sampler from a pickle file to figure out
         how much of the run has already been completed
@@ -325,7 +359,8 @@ class Emcee(MCMCSampler):
 
     def _draw_pos0_from_prior(self):
         return np.array(
-            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)])
+            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
+        )
 
     @property
     def _pos0_shape(self):
@@ -340,8 +375,7 @@ class Emcee(MCMCSampler):
             self.pos0 = np.squeeze(self.pos0)
 
         if self.pos0.shape != self._pos0_shape:
-            raise ValueError(
-                'Input pos0 should be of shape ndim, nwalkers')
+            raise ValueError("Input pos0 should be of shape ndim, nwalkers")
         logger.debug("Checking input pos0")
         for draw in self.pos0:
             self.check_draw(draw)
@@ -352,38 +386,39 @@ class Emcee(MCMCSampler):
     def _set_pos0_for_resume(self):
         self.pos0 = self.sampler.chain[:, -1, :]
 
+    @signal_wrapper
     def run_sampler(self):
+        self._setup_pool()
         from tqdm.auto import tqdm
+
         sampler_function_kwargs = self.sampler_function_kwargs
-        iterations = sampler_function_kwargs.pop('iterations')
+        iterations = sampler_function_kwargs.pop("iterations")
         iterations -= self._previous_iterations
 
         if self.prerelease:
-            sampler_function_kwargs['initial_state'] = self.pos0
+            sampler_function_kwargs["initial_state"] = self.pos0
         else:
-            sampler_function_kwargs['p0'] = self.pos0
+            sampler_function_kwargs["p0"] = self.pos0
 
         # main iteration loop
-        iterator = self.sampler.sample(
-            iterations=iterations, **sampler_function_kwargs
-        )
+        iterator = self.sampler.sample(iterations=iterations, **sampler_function_kwargs)
         if self.verbose:
             iterator = tqdm(iterator, total=iterations)
         for sample in iterator:
             self.write_chains_to_file(sample)
         if self.verbose:
             iterator.close()
-        self.checkpoint()
+        self.write_current_state()
 
         self.result.sampler_output = np.nan
-        self.calculate_autocorrelation(
-            self.sampler.chain.reshape((-1, self.ndim)))
+        self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
         self.print_nburn_logging_info()
 
         self._generate_result()
 
self.ndim)) + self.result.samples = self.sampler.chain[:, self.nburn :, :].reshape( + (-1, self.ndim) + ) self.result.walkers = self.sampler.chain return self.result @@ -393,10 +428,11 @@ class Emcee(MCMCSampler): if self.result.nburn > self.nsteps: raise SamplerError( "The run has finished, but the chain is not burned in: " - "`nburn < nsteps` ({} < {}). Try increasing the " - "number of steps.".format(self.result.nburn, self.nsteps)) + f"`nburn < nsteps` ({self.result.nburn} < {self.nsteps})." + " Try increasing the number of steps." + ) blobs = np.array(self.sampler.blobs) - blobs_trimmed = blobs[self.nburn:, :, :].reshape((-1, 2)) + blobs_trimmed = blobs[self.nburn :, :, :].reshape((-1, 2)) log_likelihoods, log_priors = blobs_trimmed.T self.result.log_likelihood_evaluations = log_likelihoods self.result.log_prior_evaluations = log_priors diff --git a/bilby/core/sampler/fake_sampler.py b/bilby/core/sampler/fake_sampler.py index 8d218472d13f7769c4b88cdd0155cb8f7c04dd0d..5f375fdbad8055e6a5bdaf7dd7e99caabe330f33 100644 --- a/bilby/core/sampler/fake_sampler.py +++ b/bilby/core/sampler/fake_sampler.py @@ -1,8 +1,7 @@ - import numpy as np -from .base_sampler import Sampler from ..result import read_in_result +from .base_sampler import Sampler class FakeSampler(Sampler): @@ -17,17 +16,38 @@ class FakeSampler(Sampler): sample_file: str A string pointing to the posterior data file to be loaded. """ - default_kwargs = dict(verbose=True, logl_args=None, logl_kwargs=None, - print_progress=True) - - def __init__(self, likelihood, priors, sample_file, outdir='outdir', - label='label', use_ratio=False, plot=False, - injection_parameters=None, meta_data=None, result_class=None, - **kwargs): - super(FakeSampler, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label, - use_ratio=False, plot=False, skip_import_verification=True, - injection_parameters=None, meta_data=None, result_class=None, - **kwargs) + + default_kwargs = dict( + verbose=True, logl_args=None, logl_kwargs=None, print_progress=True + ) + + def __init__( + self, + likelihood, + priors, + sample_file, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + injection_parameters=None, + meta_data=None, + result_class=None, + **kwargs + ): + super(FakeSampler, self).__init__( + likelihood=likelihood, + priors=priors, + outdir=outdir, + label=label, + use_ratio=False, + plot=False, + skip_import_verification=True, + injection_parameters=None, + meta_data=None, + result_class=None, + **kwargs + ) self._read_parameter_list_from_file(sample_file) self.result.outdir = outdir self.result.label = label @@ -41,7 +61,7 @@ class FakeSampler(Sampler): def run_sampler(self): """Compute the likelihood for the list of parameter space points.""" - self.sampler = 'fake_sampler' + self.sampler = "fake_sampler" # Flushes the output to force a line break if self.kwargs["verbose"]: @@ -59,8 +79,12 @@ class FakeSampler(Sampler): likelihood_ratios.append(logl) if self.kwargs["verbose"]: - print(self.likelihood.parameters['log_likelihood'], likelihood_ratios[-1], - self.likelihood.parameters['log_likelihood'] - likelihood_ratios[-1]) + print( + self.likelihood.parameters["log_likelihood"], + likelihood_ratios[-1], + self.likelihood.parameters["log_likelihood"] + - likelihood_ratios[-1], + ) self.result.log_likelihood_evaluations = np.array(likelihood_ratios) diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py index 83947fc88378c5401508eac458192141cd9f221e..1f09387cc33520a7c8408db7cd7af2a924fa85cf 100644 
--- a/bilby/core/sampler/kombine.py +++ b/bilby/core/sampler/kombine.py @@ -2,8 +2,12 @@ import os import numpy as np -from .emcee import Emcee from ..utils import logger +from .base_sampler import signal_wrapper +from .emcee import Emcee +from .ptemcee import LikePriorEvaluator + +_evaluator = LikePriorEvaluator() class Kombine(Emcee): @@ -35,21 +39,61 @@ class Kombine(Emcee): """ - default_kwargs = dict(nwalkers=500, args=[], pool=None, transd=False, - lnpost0=None, blob0=None, iterations=500, storechain=True, processes=1, update_interval=None, - kde=None, kde_size=None, spaces=None, freeze_transd=False, test_steps=16, critical_pval=0.05, - max_steps=None, burnin_verbose=False) - - def __init__(self, likelihood, priors, outdir='outdir', label='label', - use_ratio=False, plot=False, skip_import_verification=False, - pos0=None, nburn=None, burn_in_fraction=0.25, resume=True, - burn_in_act=3, autoburnin=False, **kwargs): - super(Kombine, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label, - use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification, - pos0=pos0, nburn=nburn, burn_in_fraction=burn_in_fraction, - burn_in_act=burn_in_act, resume=resume, **kwargs) - - if self.kwargs['nwalkers'] > self.kwargs['iterations']: + default_kwargs = dict( + nwalkers=500, + args=[], + pool=None, + transd=False, + lnpost0=None, + blob0=None, + iterations=500, + storechain=True, + processes=1, + update_interval=None, + kde=None, + kde_size=None, + spaces=None, + freeze_transd=False, + test_steps=16, + critical_pval=0.05, + max_steps=None, + burnin_verbose=False, + ) + + def __init__( + self, + likelihood, + priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + skip_import_verification=False, + pos0=None, + nburn=None, + burn_in_fraction=0.25, + resume=True, + burn_in_act=3, + autoburnin=False, + **kwargs, + ): + super(Kombine, self).__init__( + likelihood=likelihood, + priors=priors, + outdir=outdir, + label=label, + use_ratio=use_ratio, + plot=plot, + skip_import_verification=skip_import_verification, + pos0=pos0, + nburn=nburn, + burn_in_fraction=burn_in_fraction, + burn_in_act=burn_in_act, + resume=resume, + **kwargs, + ) + + if self.kwargs["nwalkers"] > self.kwargs["iterations"]: raise ValueError("Kombine Sampler requires Iterations be > nWalkers") self.autoburnin = autoburnin @@ -57,42 +101,34 @@ class Kombine(Emcee): # set prerelease to False to prevent checks for newer emcee versions in parent class self.prerelease = False - def _translate_kwargs(self, kwargs): - if 'nwalkers' not in kwargs: - for equiv in self.nwalkers_equiv_kwargs: - if equiv in kwargs: - kwargs['nwalkers'] = kwargs.pop(equiv) - if 'iterations' not in kwargs: - if 'nsteps' in kwargs: - kwargs['iterations'] = kwargs.pop('nsteps') - # make sure processes kwarg is 1 - if 'processes' in kwargs: - if kwargs['processes'] != 1: - logger.warning("The 'processes' argument cannot be used for " - "parallelisation. 
This run will proceed " - "without parallelisation, but consider the use " - "of an appropriate Pool object passed to the " - "'pool' keyword.") - kwargs['processes'] = 1 - @property def sampler_function_kwargs(self): - keys = ['lnpost0', 'blob0', 'iterations', 'storechain', 'lnprop0', 'update_interval', 'kde', - 'kde_size', 'spaces', 'freeze_transd'] + keys = [ + "lnpost0", + "blob0", + "iterations", + "storechain", + "lnprop0", + "update_interval", + "kde", + "kde_size", + "spaces", + "freeze_transd", + ] function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs} - function_kwargs['p0'] = self.pos0 + function_kwargs["p0"] = self.pos0 return function_kwargs @property def sampler_burnin_kwargs(self): - extra_keys = ['test_steps', 'critical_pval', 'max_steps', 'burnin_verbose'] - removal_keys = ['iterations', 'spaces', 'freeze_transd'] + extra_keys = ["test_steps", "critical_pval", "max_steps", "burnin_verbose"] + removal_keys = ["iterations", "spaces", "freeze_transd"] burnin_kwargs = self.sampler_function_kwargs.copy() for key in extra_keys: if key in self.kwargs: burnin_kwargs[key] = self.kwargs[key] - if 'burnin_verbose' in burnin_kwargs.keys(): - burnin_kwargs['verbose'] = burnin_kwargs.pop('burnin_verbose') + if "burnin_verbose" in burnin_kwargs.keys(): + burnin_kwargs["verbose"] = burnin_kwargs.pop("burnin_verbose") for key in removal_keys: if key in burnin_kwargs.keys(): burnin_kwargs.pop(key) @@ -100,19 +136,21 @@ class Kombine(Emcee): @property def sampler_init_kwargs(self): - init_kwargs = {key: value - for key, value in self.kwargs.items() - if key not in self.sampler_function_kwargs and key not in self.sampler_burnin_kwargs} + init_kwargs = { + key: value + for key, value in self.kwargs.items() + if key not in self.sampler_function_kwargs + and key not in self.sampler_burnin_kwargs + } init_kwargs.pop("burnin_verbose") - init_kwargs['lnpostfn'] = self.lnpostfn - init_kwargs['ndim'] = self.ndim + init_kwargs["lnpostfn"] = _evaluator.call_emcee + init_kwargs["ndim"] = self.ndim - # have to make sure pool is None so sampler will be pickleable - init_kwargs['pool'] = None return init_kwargs def _initialise_sampler(self): import kombine + self._sampler = kombine.Sampler(**self.sampler_init_kwargs) self._init_chain_file() @@ -129,7 +167,9 @@ class Kombine(Emcee): def check_resume(self): return self.resume and os.path.isfile(self.checkpoint_info.sampler_file) + @signal_wrapper def run_sampler(self): + self._setup_pool() if self.autoburnin: if self.check_resume(): logger.info("Resuming with autoburnin=True skips burnin process:") @@ -138,29 +178,50 @@ class Kombine(Emcee): self.sampler.burnin(**self.sampler_burnin_kwargs) self.kwargs["iterations"] += self._previous_iterations self.nburn = self._previous_iterations - logger.info("Kombine auto-burnin complete. Removing {} samples from chains".format(self.nburn)) + logger.info( + f"Kombine auto-burnin complete. 
Removing {self.nburn} samples from chains" + ) self._set_pos0_for_resume() from tqdm.auto import tqdm + sampler_function_kwargs = self.sampler_function_kwargs - iterations = sampler_function_kwargs.pop('iterations') + iterations = sampler_function_kwargs.pop("iterations") iterations -= self._previous_iterations - sampler_function_kwargs['p0'] = self.pos0 + sampler_function_kwargs["p0"] = self.pos0 for sample in tqdm( - self.sampler.sample(iterations=iterations, **sampler_function_kwargs), - total=iterations): + self.sampler.sample(iterations=iterations, **sampler_function_kwargs), + total=iterations, + ): self.write_chains_to_file(sample) - self.checkpoint() + self.write_current_state() self.result.sampler_output = np.nan if not self.autoburnin: tmp_chain = self.sampler.chain.copy() self.calculate_autocorrelation(tmp_chain.reshape((-1, self.ndim))) self.print_nburn_logging_info() + self._close_pool() self._generate_result() self.result.log_evidence_err = np.nan - tmp_chain = self.sampler.chain[self.nburn:, :, :].copy() + tmp_chain = self.sampler.chain[self.nburn :, :, :].copy() self.result.samples = tmp_chain.reshape((-1, self.ndim)) - self.result.walkers = self.sampler.chain.reshape((self.nwalkers, self.nsteps, self.ndim)) + self.result.walkers = self.sampler.chain.reshape( + (self.nwalkers, self.nsteps, self.ndim) + ) return self.result + + def _setup_pool(self): + from kombine import SerialPool + + super(Kombine, self)._setup_pool() + if self.pool is None: + self.pool = SerialPool() + + def _close_pool(self): + from kombine import SerialPool + + if isinstance(self.pool, SerialPool): + self.pool = None + super(Kombine, self)._close_pool() diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py index d8bb578a4d681c8aca425cd0b4825c651dfa73e0..fdee87b058647a47a5671aed9a9ab14d23ad0573 100644 --- a/bilby/core/sampler/nessai.py +++ b/bilby/core/sampler/nessai.py @@ -1,9 +1,10 @@ -import numpy as np import os + +import numpy as np from pandas import DataFrame +from ..utils import check_directory_exists_and_if_not_mkdir, load_json, logger from .base_sampler import NestedSampler -from ..utils import logger, check_directory_exists_and_if_not_mkdir, load_json class Nessai(NestedSampler): @@ -16,8 +17,9 @@ class Nessai(NestedSampler): Documentation: https://nessai.readthedocs.io/ """ + _default_kwargs = None - seed_equiv_kwargs = ['sampling_seed'] + seed_equiv_kwargs = ["sampling_seed"] @property def default_kwargs(self): @@ -29,6 +31,7 @@ class Nessai(NestedSampler): """ if not self._default_kwargs: from inspect import signature + from nessai.flowsampler import FlowSampler from nessai.nestedsampler import NestedSampler from nessai.proposal import AugmentedFlowProposal, FlowProposal @@ -42,12 +45,14 @@ class Nessai(NestedSampler): ] for c in classes: kwargs.update( - {k: v.default for k, v in signature(c).parameters.items() if v.default is not v.empty} + { + k: v.default + for k, v in signature(c).parameters.items() + if v.default is not v.empty + } ) # Defaults for bilby that will override nessai defaults - bilby_defaults = dict( - output=None, - ) + bilby_defaults = dict(output=None, exit_code=self.exit_code) kwargs.update(bilby_defaults) self._default_kwargs = kwargs return self._default_kwargs @@ -69,8 +74,8 @@ class Nessai(NestedSampler): def run_sampler(self): from nessai.flowsampler import FlowSampler - from nessai.model import Model as BaseModel from nessai.livepoint import dict_to_live_points, live_points_to_array + from nessai.model import Model as BaseModel from 
nessai.posterior import compute_weights from nessai.utils import setup_logger @@ -85,6 +90,7 @@ class Nessai(NestedSampler): Priors to use for sampling. Needed for the bounds and the `sample` method. """ + def __init__(self, names, priors): self.names = names self.priors = priors @@ -103,8 +109,10 @@ class Nessai(NestedSampler): return self.log_prior(theta) def _update_bounds(self): - self.bounds = {key: [self.priors[key].minimum, self.priors[key].maximum] - for key in self.names} + self.bounds = { + key: [self.priors[key].minimum, self.priors[key].maximum] + for key in self.names + } def new_point(self, N=1): """Draw a point from the prior""" @@ -117,20 +125,22 @@ class Nessai(NestedSampler): return self.log_prior(x) # Setup the logger for nessai using the same settings as the bilby logger - setup_logger(self.outdir, label=self.label, - log_level=logger.getEffectiveLevel()) + setup_logger( + self.outdir, label=self.label, log_level=logger.getEffectiveLevel() + ) model = Model(self.search_parameter_keys, self.priors) - out = None - while out is None: - try: - out = FlowSampler(model, **self.kwargs) - except TypeError as e: - raise TypeError("Unable to initialise nessai sampler with error: {}".format(e)) try: + out = FlowSampler(model, **self.kwargs) out.run(save=True, plot=self.plot) - except SystemExit as e: + except TypeError as e: + raise TypeError(f"Unable to initialise nessai sampler with error: {e}") + except (SystemExit, KeyboardInterrupt) as e: import sys - logger.info("Caught exit code {}, exiting with signal {}".format(e.args[0], self.exit_code)) + + logger.info( + f"Caught {type(e).__name__} with args {e.args}, " + f"exiting with signal {self.exit_code}" + ) sys.exit(self.exit_code) # Manually set likelihood evaluations because parallelisation breaks the counter @@ -139,53 +149,61 @@ class Nessai(NestedSampler): self.result.samples = live_points_to_array( out.posterior_samples, self.search_parameter_keys ) - self.result.log_likelihood_evaluations = out.posterior_samples['logL'] + self.result.log_likelihood_evaluations = out.posterior_samples["logL"] self.result.nested_samples = DataFrame(out.nested_samples) self.result.nested_samples.rename( - columns=dict(logL='log_likelihood', logP='log_prior'), inplace=True) - _, log_weights = compute_weights(np.array(self.result.nested_samples.log_likelihood), - np.array(out.ns.state.nlive)) - self.result.nested_samples['weights'] = np.exp(log_weights) + columns=dict(logL="log_likelihood", logP="log_prior"), inplace=True + ) + _, log_weights = compute_weights( + np.array(self.result.nested_samples.log_likelihood), + np.array(out.ns.state.nlive), + ) + self.result.nested_samples["weights"] = np.exp(log_weights) self.result.log_evidence = out.ns.log_evidence self.result.log_evidence_err = np.sqrt(out.ns.information / out.ns.nlive) return self.result def _translate_kwargs(self, kwargs): - if 'nlive' not in kwargs: + if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: - kwargs['nlive'] = kwargs.pop(equiv) - if 'n_pool' not in kwargs: + kwargs["nlive"] = kwargs.pop(equiv) + if "n_pool" not in kwargs: for equiv in self.npool_equiv_kwargs: if equiv in kwargs: - kwargs['n_pool'] = kwargs.pop(equiv) - if 'seed' not in kwargs: + kwargs["n_pool"] = kwargs.pop(equiv) + if "n_pool" not in kwargs: + kwargs["n_pool"] = self._npool + if "seed" not in kwargs: for equiv in self.seed_equiv_kwargs: if equiv in kwargs: - kwargs['seed'] = kwargs.pop(equiv) + kwargs["seed"] = kwargs.pop(equiv) def 
_verify_kwargs_against_default_kwargs(self): """ Set the directory where the output will be written and check resume and checkpoint status. """ - if 'config_file' in self.kwargs: - d = load_json(self.kwargs['config_file'], None) + if "config_file" in self.kwargs: + d = load_json(self.kwargs["config_file"], None) self.kwargs.update(d) - self.kwargs.pop('config_file') + self.kwargs.pop("config_file") - if not self.kwargs['plot']: - self.kwargs['plot'] = self.plot + if not self.kwargs["plot"]: + self.kwargs["plot"] = self.plot - if self.kwargs['n_pool'] == 1 and self.kwargs['max_threads'] == 1: - logger.warning('Setting pool to None (n_pool=1 & max_threads=1)') - self.kwargs['n_pool'] = None + if self.kwargs["n_pool"] == 1 and self.kwargs["max_threads"] == 1: + logger.warning("Setting pool to None (n_pool=1 & max_threads=1)") + self.kwargs["n_pool"] = None - if not self.kwargs['output']: - self.kwargs['output'] = os.path.join( - self.outdir, '{}_nessai'.format(self.label), '' + if not self.kwargs["output"]: + self.kwargs["output"] = os.path.join( + self.outdir, f"{self.label}_nessai", "" ) - check_directory_exists_and_if_not_mkdir(self.kwargs['output']) + check_directory_exists_and_if_not_mkdir(self.kwargs["output"]) NestedSampler._verify_kwargs_against_default_kwargs(self) + + def _setup_pool(self): + pass diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py index f598d8b1751b217ae9515019c5963001eb0da840..2ea8787a63a85b62101e6f0d193bae85765f3884 100644 --- a/bilby/core/sampler/nestle.py +++ b/bilby/core/sampler/nestle.py @@ -1,8 +1,7 @@ - import numpy as np from pandas import DataFrame -from .base_sampler import NestedSampler +from .base_sampler import NestedSampler, signal_wrapper class Nestle(NestedSampler): @@ -25,30 +24,44 @@ class Nestle(NestedSampler): sampling """ - default_kwargs = dict(verbose=True, method='multi', npoints=500, - update_interval=None, npdim=None, maxiter=None, - maxcall=None, dlogz=None, decline_factor=None, - rstate=None, callback=None, steps=20, enlarge=1.2) + + default_kwargs = dict( + verbose=True, + method="multi", + npoints=500, + update_interval=None, + npdim=None, + maxiter=None, + maxcall=None, + dlogz=None, + decline_factor=None, + rstate=None, + callback=None, + steps=20, + enlarge=1.2, + ) def _translate_kwargs(self, kwargs): - if 'npoints' not in kwargs: + if "npoints" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: - kwargs['npoints'] = kwargs.pop(equiv) - if 'steps' not in kwargs: + kwargs["npoints"] = kwargs.pop(equiv) + if "steps" not in kwargs: for equiv in self.walks_equiv_kwargs: if equiv in kwargs: - kwargs['steps'] = kwargs.pop(equiv) + kwargs["steps"] = kwargs.pop(equiv) def _verify_kwargs_against_default_kwargs(self): - if self.kwargs['verbose']: + if self.kwargs["verbose"]: import nestle - self.kwargs['callback'] = nestle.print_progress - self.kwargs.pop('verbose') + + self.kwargs["callback"] = nestle.print_progress + self.kwargs.pop("verbose") NestedSampler._verify_kwargs_against_default_kwargs(self) + @signal_wrapper def run_sampler(self): - """ Runs Nestle sampler with given kwargs and returns the result + """Runs Nestle sampler with given kwargs and returns the result Returns ======= @@ -56,21 +69,27 @@ class Nestle(NestedSampler): """ import nestle + out = nestle.sample( loglikelihood=self.log_likelihood, prior_transform=self.prior_transform, - ndim=self.ndim, **self.kwargs) + ndim=self.ndim, + **self.kwargs + ) print("") self.result.sampler_output = out self.result.samples = 
nestle.resample_equal(out.samples, out.weights) self.result.nested_samples = DataFrame( - out.samples, columns=self.search_parameter_keys) - self.result.nested_samples['weights'] = out.weights - self.result.nested_samples['log_likelihood'] = out.logl + out.samples, columns=self.search_parameter_keys + ) + self.result.nested_samples["weights"] = out.weights + self.result.nested_samples["log_likelihood"] = out.logl self.result.log_likelihood_evaluations = self.reorder_loglikelihoods( - unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples, - sorted_samples=self.result.samples) + unsorted_loglikelihoods=out.logl, + unsorted_samples=out.samples, + sorted_samples=self.result.samples, + ) self.result.log_evidence = out.logz self.result.log_evidence_err = out.logzerr self.result.information_gain = out.h @@ -88,14 +107,24 @@ class Nestle(NestedSampler): """ import nestle + kwargs = self.kwargs.copy() - kwargs['maxiter'] = 2 + kwargs["maxiter"] = 2 nestle.sample( loglikelihood=self.log_likelihood, prior_transform=self.prior_transform, - ndim=self.ndim, **kwargs) + ndim=self.ndim, + **kwargs + ) self.result.samples = np.random.uniform(0, 1, (100, self.ndim)) self.result.log_evidence = np.nan self.result.log_evidence_err = np.nan self.calc_likelihood_count() return self.result + + def write_current_state(self): + """ + Nestle doesn't support checkpointing so no current state will be + written on interrupt. + """ + pass diff --git a/bilby/core/sampler/polychord.py b/bilby/core/sampler/polychord.py index 943a5c413abe7e45ff54eb4dde2c9aa8d35b7d91..617d6c7d17b22569f1b9bc23e37e58e292033625 100644 --- a/bilby/core/sampler/polychord.py +++ b/bilby/core/sampler/polychord.py @@ -1,7 +1,6 @@ - import numpy as np -from .base_sampler import NestedSampler +from .base_sampler import NestedSampler, signal_wrapper class PyPolyChord(NestedSampler): @@ -21,32 +20,66 @@ class PyPolyChord(NestedSampler): To see what the keyword arguments are for, see the docstring of PyPolyChordSettings """ - default_kwargs = dict(use_polychord_defaults=False, nlive=None, num_repeats=None, - nprior=-1, do_clustering=True, feedback=1, precision_criterion=0.001, - logzero=-1e30, max_ndead=-1, boost_posterior=0.0, posteriors=True, - equals=True, cluster_posteriors=True, write_resume=True, - write_paramnames=False, read_resume=True, write_stats=True, - write_live=True, write_dead=True, write_prior=True, - compression_factor=np.exp(-1), base_dir='outdir', - file_root='polychord', seed=-1, grade_dims=None, grade_frac=None, nlives={}) - + default_kwargs = dict( + use_polychord_defaults=False, + nlive=None, + num_repeats=None, + nprior=-1, + do_clustering=True, + feedback=1, + precision_criterion=0.001, + logzero=-1e30, + max_ndead=-1, + boost_posterior=0.0, + posteriors=True, + equals=True, + cluster_posteriors=True, + write_resume=True, + write_paramnames=False, + read_resume=True, + write_stats=True, + write_live=True, + write_dead=True, + write_prior=True, + compression_factor=np.exp(-1), + base_dir="outdir", + file_root="polychord", + seed=-1, + grade_dims=None, + grade_frac=None, + nlives={}, + ) + hard_exit = True + + @signal_wrapper def run_sampler(self): import pypolychord from pypolychord.settings import PolyChordSettings - if self.kwargs['use_polychord_defaults']: - settings = PolyChordSettings(nDims=self.ndim, nDerived=self.ndim, - base_dir=self._sample_file_directory, - file_root=self.label) + + if self.kwargs["use_polychord_defaults"]: + settings = PolyChordSettings( + nDims=self.ndim, + nDerived=self.ndim, + 
base_dir=self._sample_file_directory, + file_root=self.label, + ) else: self._setup_dynamic_defaults() pc_kwargs = self.kwargs.copy() - pc_kwargs['base_dir'] = self._sample_file_directory - pc_kwargs['file_root'] = self.label - pc_kwargs.pop('use_polychord_defaults') - settings = PolyChordSettings(nDims=self.ndim, nDerived=self.ndim, **pc_kwargs) + pc_kwargs["base_dir"] = self._sample_file_directory + pc_kwargs["file_root"] = self.label + pc_kwargs.pop("use_polychord_defaults") + settings = PolyChordSettings( + nDims=self.ndim, nDerived=self.ndim, **pc_kwargs + ) self._verify_kwargs_against_default_kwargs() - out = pypolychord.run_polychord(loglikelihood=self.log_likelihood, nDims=self.ndim, - nDerived=self.ndim, settings=settings, prior=self.prior_transform) + out = pypolychord.run_polychord( + loglikelihood=self.log_likelihood, + nDims=self.ndim, + nDerived=self.ndim, + settings=settings, + prior=self.prior_transform, + ) self.result.log_evidence = out.logZ self.result.log_evidence_err = out.logZerr log_likelihoods, physical_parameters = self._read_sample_file() @@ -56,24 +89,24 @@ class PyPolyChord(NestedSampler): return self.result def _setup_dynamic_defaults(self): - """ Sets up some interdependent default argument if none are given by the user """ - if not self.kwargs['grade_dims']: - self.kwargs['grade_dims'] = [self.ndim] - if not self.kwargs['grade_frac']: - self.kwargs['grade_frac'] = [1.0] * len(self.kwargs['grade_dims']) - if not self.kwargs['nlive']: - self.kwargs['nlive'] = self.ndim * 25 - if not self.kwargs['num_repeats']: - self.kwargs['num_repeats'] = self.ndim * 5 + """Sets up some interdependent default arguments if none are given by the user""" + if not self.kwargs["grade_dims"]: + self.kwargs["grade_dims"] = [self.ndim] + if not self.kwargs["grade_frac"]: + self.kwargs["grade_frac"] = [1.0] * len(self.kwargs["grade_dims"]) + if not self.kwargs["nlive"]: + self.kwargs["nlive"] = self.ndim * 25 + if not self.kwargs["num_repeats"]: + self.kwargs["num_repeats"] = self.ndim * 5 def _translate_kwargs(self, kwargs): - if 'nlive' not in kwargs: + if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: - kwargs['nlive'] = kwargs.pop(equiv) + kwargs["nlive"] = kwargs.pop(equiv) def log_likelihood(self, theta): - """ Overrides the log_likelihood so that PolyChord understands it """ + """Overrides the log_likelihood so that PolyChord understands it""" return super(PyPolyChord, self).log_likelihood(theta), theta def _read_sample_file(self): @@ -87,12 +120,14 @@ class PyPolyChord(NestedSampler): array_like, array_like: The log_likelihoods and the associated parameters """ - sample_file = self._sample_file_directory + '/' + self.label + '_equal_weights.txt' + sample_file = ( + self._sample_file_directory + "/" + self.label + "_equal_weights.txt" + ) samples = np.loadtxt(sample_file) log_likelihoods = -0.5 * samples[:, 1] - physical_parameters = samples[:, -self.ndim:] + physical_parameters = samples[:, -self.ndim :] return log_likelihoods, physical_parameters @property def _sample_file_directory(self): - return self.outdir + '/chains' + return self.outdir + "/chains" diff --git a/bilby/core/sampler/proposal.py b/bilby/core/sampler/proposal.py index 2d52616588328c8b5bdb02247d20ff5d48b71e8b..023caac5744de968d338cea30125574248b91f95 100644 --- a/bilby/core/sampler/proposal.py +++ b/bilby/core/sampler/proposal.py @@ -1,13 +1,12 @@ +import random from inspect import isclass import numpy as np -import random from ..prior import Uniform class Sample(dict): -
def __init__(self, dictionary=None): if dictionary is None: dictionary = dict() @@ -31,15 +30,14 @@ class Sample(dict): @classmethod def from_external_type(cls, external_sample, sampler_name): - if sampler_name == 'cpnest': + if sampler_name == "cpnest": return cls.from_cpnest_live_point(external_sample) return external_sample class JumpProposal(object): - def __init__(self, priors=None): - """ A generic class for jump proposals + """A generic class for jump proposals Parameters ========== @@ -56,7 +54,7 @@ class JumpProposal(object): self.log_j = 0.0 def __call__(self, sample, **kwargs): - """ A generic wrapper for the jump proposal function + """A generic wrapper for the jump proposal function Parameters ========== @@ -71,26 +69,35 @@ class JumpProposal(object): return self._apply_boundaries(sample) def _move_reflecting_keys(self, sample): - keys = [key for key in sample.keys() - if self.priors[key].boundary == 'reflective'] + keys = [ + key for key in sample.keys() if self.priors[key].boundary == "reflective" + ] for key in keys: - if sample[key] > self.priors[key].maximum or sample[key] < self.priors[key].minimum: + if ( + sample[key] > self.priors[key].maximum + or sample[key] < self.priors[key].minimum + ): r = self.priors[key].maximum - self.priors[key].minimum delta = (sample[key] - self.priors[key].minimum) % (2 * r) if delta > r: - sample[key] = 2 * self.priors[key].maximum - self.priors[key].minimum - delta + sample[key] = ( + 2 * self.priors[key].maximum - self.priors[key].minimum - delta + ) elif delta < r: sample[key] = self.priors[key].minimum + delta return sample def _move_periodic_keys(self, sample): - keys = [key for key in sample.keys() - if self.priors[key].boundary == 'periodic'] + keys = [key for key in sample.keys() if self.priors[key].boundary == "periodic"] for key in keys: - if sample[key] > self.priors[key].maximum or sample[key] < self.priors[key].minimum: - sample[key] = (self.priors[key].minimum + - ((sample[key] - self.priors[key].minimum) % - (self.priors[key].maximum - self.priors[key].minimum))) + if ( + sample[key] > self.priors[key].maximum + or sample[key] < self.priors[key].minimum + ): + sample[key] = self.priors[key].minimum + ( + (sample[key] - self.priors[key].minimum) + % (self.priors[key].maximum - self.priors[key].minimum) + ) return sample def _apply_boundaries(self, sample): @@ -100,9 +107,8 @@ class JumpProposal(object): class JumpProposalCycle(object): - def __init__(self, proposal_functions, weights, cycle_length=100): - """ A generic wrapper class for proposal cycles + """A generic wrapper class for proposal cycles Parameters ========== @@ -129,8 +135,12 @@ class JumpProposalCycle(object): return len(self.proposal_functions) def update_cycle(self): - self._cycle = np.random.choice(self.proposal_functions, size=self.cycle_length, - p=self.weights, replace=True) + self._cycle = np.random.choice( + self.proposal_functions, + size=self.cycle_length, + p=self.weights, + replace=True, + ) @property def proposal_functions(self): @@ -190,9 +200,13 @@ class NormJump(JumpProposal): class EnsembleWalk(JumpProposal): - - def __init__(self, random_number_generator=random.random, n_points=3, priors=None, - **random_number_generator_args): + def __init__( + self, + random_number_generator=random.random, + n_points=3, + priors=None, + **random_number_generator_args + ): """ An ensemble walk @@ -213,12 +227,16 @@ class EnsembleWalk(JumpProposal): self.random_number_generator_args = random_number_generator_args def __call__(self, sample, **kwargs): - 
subset = random.sample(kwargs['coordinates'], self.n_points) + subset = random.sample(kwargs["coordinates"], self.n_points) for i in range(len(subset)): - subset[i] = Sample.from_external_type(subset[i], kwargs.get('sampler_name', None)) + subset[i] = Sample.from_external_type( + subset[i], kwargs.get("sampler_name", None) + ) center_of_mass = self.get_center_of_mass(subset) for x in subset: - sample += (x - center_of_mass) * self.random_number_generator(**self.random_number_generator_args) + sample += (x - center_of_mass) * self.random_number_generator( + **self.random_number_generator_args + ) return super(EnsembleWalk, self).__call__(sample) @staticmethod @@ -227,7 +245,6 @@ class EnsembleWalk(JumpProposal): class EnsembleStretch(JumpProposal): - def __init__(self, scale=2.0, priors=None): """ Stretch move. Calculates the log Jacobian which can be used in cpnest to bias future moves. @@ -241,8 +258,10 @@ class EnsembleStretch(JumpProposal): self.scale = scale def __call__(self, sample, **kwargs): - second_sample = random.choice(kwargs['coordinates']) - second_sample = Sample.from_external_type(second_sample, kwargs.get('sampler_name', None)) + second_sample = random.choice(kwargs["coordinates"]) + second_sample = Sample.from_external_type( + second_sample, kwargs.get("sampler_name", None) + ) step = random.uniform(-1, 1) * np.log(self.scale) sample = second_sample + (sample - second_sample) * np.exp(step) self.log_j = len(sample) * step @@ -250,7 +269,6 @@ class EnsembleStretch(JumpProposal): class DifferentialEvolution(JumpProposal): - def __init__(self, sigma=1e-4, mu=1.0, priors=None): """ Differential evolution step. Takes two elements from the existing coordinates and differentially evolves the @@ -268,13 +286,12 @@ class DifferentialEvolution(JumpProposal): self.mu = mu def __call__(self, sample, **kwargs): - a, b = random.sample(kwargs['coordinates'], 2) + a, b = random.sample(kwargs["coordinates"], 2) sample = sample + (b - a) * random.gauss(self.mu, self.sigma) return super(DifferentialEvolution, self).__call__(sample) class EnsembleEigenVector(JumpProposal): - def __init__(self, priors=None): """ Ensemble step based on the ensemble eigenvectors. @@ -316,7 +333,7 @@ class EnsembleEigenVector(JumpProposal): self.eigen_values, self.eigen_vectors = np.linalg.eigh(self.covariance) def __call__(self, sample, **kwargs): - self.update_eigenvectors(kwargs['coordinates']) + self.update_eigenvectors(kwargs["coordinates"]) i = random.randrange(len(sample)) jump_size = np.sqrt(np.fabs(self.eigen_values[i])) * random.gauss(0, 1) for j, key in enumerate(sample.keys()): diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 6191af0ec51ccbb984ef034d317f26b544cb8e9b..063e2af7e472812e0a9a8f88cc2ea7c1fe384cc0 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -2,17 +2,19 @@ import copy import datetime import logging import os -import signal -import sys import time from collections import namedtuple import numpy as np import pandas as pd -from ..utils import logger, check_directory_exists_and_if_not_mkdir -from .base_sampler import SamplerError, MCMCSampler - +from ..utils import check_directory_exists_and_if_not_mkdir, logger +from .base_sampler import ( + MCMCSampler, + SamplerError, + _sampling_convenience_dump, + signal_wrapper, +) ConvergenceInputs = namedtuple( "ConvergenceInputs", @@ -81,7 +83,7 @@ class Ptemcee(MCMCSampler): the Gelman-Rubin statistic). min_tau: int, (1) A minimum tau (autocorrelation time) to accept. 
- check_point_deltaT: float, (600) + check_point_delta_t: float, (600) The period with which to checkpoint (in seconds). threads: int, (1) If threads > 1, a MultiPool object is setup and used. @@ -163,7 +165,7 @@ class Ptemcee(MCMCSampler): gradient_mean_log_posterior=0.1, Q_tol=1.02, min_tau=1, - check_point_deltaT=600, + check_point_delta_t=600, threads=1, exit_code=77, plot=False, @@ -173,7 +175,7 @@ class Ptemcee(MCMCSampler): niterations_per_check=5, log10beta_min=None, verbose=True, - **kwargs + **kwargs, ): super(Ptemcee, self).__init__( likelihood=likelihood, @@ -184,25 +186,18 @@ class Ptemcee(MCMCSampler): plot=plot, skip_import_verification=skip_import_verification, exit_code=exit_code, - **kwargs + **kwargs, ) self.nwalkers = self.sampler_init_kwargs["nwalkers"] self.ntemps = self.sampler_init_kwargs["ntemps"] self.max_steps = 500 - # Setup up signal handling - signal.signal(signal.SIGTERM, self.write_current_state_and_exit) - signal.signal(signal.SIGINT, self.write_current_state_and_exit) - signal.signal(signal.SIGALRM, self.write_current_state_and_exit) - # Checkpointing inputs self.resume = resume - self.check_point_deltaT = check_point_deltaT + self.check_point_delta_t = check_point_delta_t self.check_point_plot = check_point_plot - self.resume_file = "{}/{}_checkpoint_resume.pickle".format( - self.outdir, self.label - ) + self.resume_file = f"{self.outdir}/{self.label}_checkpoint_resume.pickle" # Store convergence checking inputs in a named tuple convergence_inputs_dict = dict( @@ -223,7 +218,7 @@ class Ptemcee(MCMCSampler): niterations_per_check=niterations_per_check, ) self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict) - logger.info("Using convergence inputs: {}".format(self.convergence_inputs)) + logger.info(f"Using convergence inputs: {self.convergence_inputs}") # Check if threads was given as an equivalent arg if threads == 1: @@ -239,32 +234,50 @@ class Ptemcee(MCMCSampler): self.pos0 = pos0 self._periodic = [ - self.priors[key].boundary == "periodic" for key in self.search_parameter_keys + self.priors[key].boundary == "periodic" + for key in self.search_parameter_keys ] self.priors.sample() - self._minima = np.array([ - self.priors[key].minimum for key in self.search_parameter_keys - ]) - self._range = np.array([ - self.priors[key].maximum for key in self.search_parameter_keys - ]) - self._minima + self._minima = np.array( + [self.priors[key].minimum for key in self.search_parameter_keys] + ) + self._range = ( + np.array([self.priors[key].maximum for key in self.search_parameter_keys]) + - self._minima + ) self.log10beta_min = log10beta_min if self.log10beta_min is not None: betas = np.logspace(0, self.log10beta_min, self.ntemps) - logger.warning("Using betas {}".format(betas)) + logger.warning(f"Using betas {betas}") self.kwargs["betas"] = betas self.verbose = verbose + self.iteration = 0 + self.chain_array = self.get_zero_chain_array() + self.log_likelihood_array = self.get_zero_array() + self.log_posterior_array = self.get_zero_array() + self.beta_list = list() + self.tau_list = list() + self.tau_list_n = list() + self.Q_list = list() + self.time_per_check = list() + + self.nburn = np.nan + self.thin = np.nan + self.tau_int = np.nan + self.nsamples_effective = 0 + self.discard = 0 + @property def sampler_function_kwargs(self): - """ Kwargs passed to samper.sampler() """ + """Kwargs passed to sampler.sample()""" keys = ["adapt", "swap_ratios"] return {key: self.kwargs[key] for key in keys} @property def sampler_init_kwargs(self): - """ Kwargs
passed to initialize ptemcee.Sampler() """ + """Kwargs passed to initialize ptemcee.Sampler()""" return { key: value for key, value in self.kwargs.items() @@ -272,14 +285,14 @@ class Ptemcee(MCMCSampler): } def _translate_kwargs(self, kwargs): - """ Translate kwargs """ + """Translate kwargs""" if "nwalkers" not in kwargs: for equiv in self.nwalkers_equiv_kwargs: if equiv in kwargs: kwargs["nwalkers"] = kwargs.pop(equiv) def get_pos0_from_prior(self): - """ Draw the initial positions from the prior + """Draw the initial positions from the prior Returns ======= @@ -288,16 +301,15 @@ class Ptemcee(MCMCSampler): """ logger.info("Generating pos0 samples") - return np.array([ + return np.array( [ - self.get_random_draw_from_prior() - for _ in range(self.nwalkers) + [self.get_random_draw_from_prior() for _ in range(self.nwalkers)] + for _ in range(self.kwargs["ntemps"]) ] - for _ in range(self.kwargs["ntemps"]) - ]) + ) def get_pos0_from_minimize(self, minimize_list=None): - """ Draw the initial positions using an initial minimization step + """Draw the initial positions using an initial minimization step See pos0 in the class initialization for details. @@ -318,12 +330,12 @@ class Ptemcee(MCMCSampler): else: pos0 = np.array(self.get_pos0_from_prior()) - logger.info("Attempting to set pos0 for {} from minimize".format(minimize_list)) + logger.info(f"Attempting to set pos0 for {minimize_list} from minimize") likelihood_copy = copy.copy(self.likelihood) def neg_log_like(params): - """ Internal function to minimize """ + """Internal function to minimize""" likelihood_copy.parameters.update( {key: val for key, val in zip(minimize_list, params)} ) @@ -360,9 +372,7 @@ class Ptemcee(MCMCSampler): for i, key in enumerate(minimize_list): pos0_min = np.min(success[:, i]) pos0_max = np.max(success[:, i]) - logger.info( - "Initialize {} walkers from {}->{}".format(key, pos0_min, pos0_max) - ) + logger.info(f"Initialize {key} walkers from {pos0_min}->{pos0_max}") j = self.search_parameter_keys.index(key) pos0[:, :, j] = np.random.uniform( pos0_min, @@ -375,9 +385,8 @@ class Ptemcee(MCMCSampler): if self.pos0.shape != (self.ntemps, self.nwalkers, self.ndim): raise ValueError( "Shape of starting array should be (ntemps, nwalkers, ndim). 
" - "In this case that is ({}, {}, {}), got {}".format( - self.ntemps, self.nwalkers, self.ndim, self.pos0.shape - ) + f"In this case that is ({self.ntemps}, {self.nwalkers}, " + f"{self.ndim}), got {self.pos0.shape}" ) else: return self.pos0 @@ -395,12 +404,13 @@ class Ptemcee(MCMCSampler): return self.get_pos0_from_array() def setup_sampler(self): - """ Either initialize the sampler or read in the resume file """ + """Either initialize the sampler or read in the resume file""" import ptemcee if os.path.isfile(self.resume_file) and self.resume is True: import dill - logger.info("Resume data {} found".format(self.resume_file)) + + logger.info(f"Resume data {self.resume_file} found") with open(self.resume_file, "rb") as file: data = dill.load(file) @@ -422,9 +432,7 @@ class Ptemcee(MCMCSampler): self.sampler.pool = self.pool self.sampler.threads = self.threads - logger.info( - "Resuming from previous run with time={}".format(self.iteration) - ) + logger.info(f"Resuming from previous run with time={self.iteration}") else: # Initialize the PTSampler @@ -433,32 +441,29 @@ class Ptemcee(MCMCSampler): dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior, - **self.sampler_init_kwargs + **self.sampler_init_kwargs, ) else: self.sampler = ptemcee.Sampler( dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function, - pool=self.pool, threads=self.threads, - **self.sampler_init_kwargs + **self.sampler_init_kwargs, ) - self.sampler._likeprior = LikePriorEvaluator( - self.search_parameter_keys, use_ratio=self.use_ratio - ) + self.sampler._likeprior = LikePriorEvaluator() # Initialize storing results self.iteration = 0 self.chain_array = self.get_zero_chain_array() self.log_likelihood_array = self.get_zero_array() self.log_posterior_array = self.get_zero_array() - self.beta_list = [] - self.tau_list = [] - self.tau_list_n = [] - self.Q_list = [] - self.time_per_check = [] + self.beta_list = list() + self.tau_list = list() + self.tau_list_n = list() + self.Q_list = list() + self.time_per_check = list() self.pos0 = self.get_pos0() return self.sampler @@ -470,7 +475,7 @@ class Ptemcee(MCMCSampler): return np.zeros((self.ntemps, self.nwalkers, self.max_steps)) def get_pos0(self): - """ Master logic for setting pos0 """ + """Master logic for setting pos0""" if isinstance(self.pos0, str) and self.pos0.lower() == "prior": return self.get_pos0_from_prior() elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize": @@ -482,52 +487,55 @@ class Ptemcee(MCMCSampler): elif isinstance(self.pos0, dict): return self.get_pos0_from_dict() else: - raise SamplerError("pos0={} not implemented".format(self.pos0)) + raise SamplerError(f"pos0={self.pos0} not implemented") - def setup_pool(self): - """ If threads > 1, setup a MultiPool, else run in serial mode """ - if self.threads > 1: - import schwimmbad - - logger.info("Creating MultiPool with {} processes".format(self.threads)) - self.pool = schwimmbad.MultiPool( - self.threads, initializer=init, initargs=(self.likelihood, self.priors) - ) - else: - self.pool = None + def _close_pool(self): + if getattr(self.sampler, "pool", None) is not None: + self.sampler.pool = None + if "pool" in self.result.sampler_kwargs: + del self.result.sampler_kwargs["pool"] + super(Ptemcee, self)._close_pool() + @signal_wrapper def run_sampler(self): - self.setup_pool() + self._setup_pool() sampler = self.setup_sampler() t0 = datetime.datetime.now() logger.info("Starting to sample") while True: for (pos0, log_posterior, log_likelihood) in sampler.sample( - self.pos0, 
storechain=False, - iterations=self.convergence_inputs.niterations_per_check, - **self.sampler_function_kwargs): - pos0[:, :, self._periodic] = np.mod( - pos0[:, :, self._periodic] - self._minima[self._periodic], - self._range[self._periodic] - ) + self._minima[self._periodic] + self.pos0, + storechain=False, + iterations=self.convergence_inputs.niterations_per_check, + **self.sampler_function_kwargs, + ): + pos0[:, :, self._periodic] = ( + np.mod( + pos0[:, :, self._periodic] - self._minima[self._periodic], + self._range[self._periodic], + ) + + self._minima[self._periodic] + ) if self.iteration == self.chain_array.shape[1]: - self.chain_array = np.concatenate(( - self.chain_array, self.get_zero_chain_array()), axis=1) - self.log_likelihood_array = np.concatenate(( - self.log_likelihood_array, self.get_zero_array()), - axis=2) - self.log_posterior_array = np.concatenate(( - self.log_posterior_array, self.get_zero_array()), - axis=2) + self.chain_array = np.concatenate( + (self.chain_array, self.get_zero_chain_array()), axis=1 + ) + self.log_likelihood_array = np.concatenate( + (self.log_likelihood_array, self.get_zero_array()), axis=2 + ) + self.log_posterior_array = np.concatenate( + (self.log_posterior_array, self.get_zero_array()), axis=2 + ) self.pos0 = pos0 self.chain_array[:, self.iteration, :] = pos0[0, :, :] self.log_likelihood_array[:, :, self.iteration] = log_likelihood self.log_posterior_array[:, :, self.iteration] = log_posterior self.mean_log_posterior = np.mean( - self.log_posterior_array[:, :, :self. iteration], axis=1) + self.log_posterior_array[:, :, : self.iteration], axis=1 + ) # Calculate time per iteration self.time_per_check.append((datetime.datetime.now() - t0).total_seconds()) @@ -537,15 +545,13 @@ class Ptemcee(MCMCSampler): # Calculate minimum iteration step to discard minimum_iteration = get_minimum_stable_itertion( - self.mean_log_posterior, - frac=self.convergence_inputs.mean_logl_frac + self.mean_log_posterior, frac=self.convergence_inputs.mean_logl_frac ) - logger.debug("Minimum iteration = {}".format(minimum_iteration)) + logger.debug(f"Minimum iteration = {minimum_iteration}") # Calculate the maximum discard number discard_max = np.max( - [self.convergence_inputs.burn_in_fixed_discard, - minimum_iteration] + [self.convergence_inputs.burn_in_fixed_discard, minimum_iteration] ) if self.iteration > discard_max + self.nwalkers: @@ -565,7 +571,7 @@ class Ptemcee(MCMCSampler): self.nsamples_effective, ) = check_iteration( self.iteration, - self.chain_array[:, self.discard:self.iteration, :], + self.chain_array[:, self.discard : self.iteration, :], sampler, self.convergence_inputs, self.search_parameter_keys, @@ -588,7 +594,7 @@ class Ptemcee(MCMCSampler): else: last_checkpoint_s = np.sum(self.time_per_check) - if last_checkpoint_s > self.check_point_deltaT: + if last_checkpoint_s > self.check_point_delta_t: self.write_current_state(plot=self.check_point_plot) # Run a final checkpoint to update the plots and samples @@ -609,9 +615,14 @@ class Ptemcee(MCMCSampler): self.result.discard = self.discard log_evidence, log_evidence_err = compute_evidence( - sampler, self.log_likelihood_array, self.outdir, - self.label, self.discard, self.nburn, - self.thin, self.iteration, + sampler, + self.log_likelihood_array, + self.outdir, + self.label, + self.discard, + self.nburn, + self.thin, + self.iteration, ) self.result.log_evidence = log_evidence self.result.log_evidence_err = log_evidence_err @@ -620,21 +631,10 @@ class Ptemcee(MCMCSampler): 
seconds=np.sum(self.time_per_check) ) - if self.pool: - self.pool.close() + self._close_pool() return self.result - def write_current_state_and_exit(self, signum=None, frame=None): - logger.warning("Run terminated with signal {}".format(signum)) - if getattr(self, "pool", None) or self.threads == 1: - self.write_current_state(plot=False) - if getattr(self, "pool", None): - logger.info("Closing pool") - self.pool.close() - logger.info("Exit on signal {}".format(self.exit_code)) - sys.exit(self.exit_code) - def write_current_state(self, plot=True): check_directory_exists_and_if_not_mkdir(self.outdir) checkpoint( @@ -672,7 +672,7 @@ class Ptemcee(MCMCSampler): self.discard, ) except Exception as e: - logger.info("Walkers plot failed with exception {}".format(e)) + logger.info(f"Walkers plot failed with exception {e}") try: # Generate the tau plot diagnostic if DEBUG @@ -687,7 +687,7 @@ class Ptemcee(MCMCSampler): self.convergence_inputs.autocorr_tau, ) except Exception as e: - logger.info("tau plot failed with exception {}".format(e)) + logger.info(f"tau plot failed with exception {e}") try: plot_mean_log_posterior( @@ -696,7 +696,7 @@ class Ptemcee(MCMCSampler): self.label, ) except Exception as e: - logger.info("mean_logl plot failed with exception {}".format(e)) + logger.info(f"mean_logl plot failed with exception {e}") def get_minimum_stable_itertion(mean_array, frac, nsteps_min=10): @@ -728,7 +728,7 @@ def check_iteration( mean_log_posterior, verbose=True, ): - """ Per-iteration logic to calculate the convergence check + """Per-iteration logic to calculate the convergence check Parameters ========== @@ -780,8 +780,17 @@ def check_iteration( if np.isnan(tau) or np.isinf(tau): if verbose: print_progress( - iteration, sampler, time_per_check, np.nan, np.nan, - np.nan, np.nan, np.nan, False, convergence_inputs, Q, + iteration, + sampler, + time_per_check, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + False, + convergence_inputs, + Q, ) return False, np.nan, np.nan, np.nan, np.nan @@ -796,45 +805,47 @@ def check_iteration( # Calculate convergence boolean converged = Q < ci.Q_tol and ci.nsamples < nsamples_effective - logger.debug("Convergence: Q<Q_tol={}, nsamples<nsamples_effective={}" - .format(Q < ci.Q_tol, ci.nsamples < nsamples_effective)) + logger.debug( + f"Convergence: Q<Q_tol={Q < ci.Q_tol}, " + f"nsamples<nsamples_effective={ci.nsamples < nsamples_effective}" + ) GRAD_WINDOW_LENGTH = nwalkers + 1 nsteps_to_check = ci.autocorr_tau * np.max([2 * GRAD_WINDOW_LENGTH, tau_int]) lower_tau_index = np.max([0, len(tau_list) - nsteps_to_check]) - check_taus = np.array(tau_list[lower_tau_index :]) + check_taus = np.array(tau_list[lower_tau_index:]) if not np.any(np.isnan(check_taus)) and check_taus.shape[0] > GRAD_WINDOW_LENGTH: - gradient_tau = get_max_gradient( - check_taus, axis=0, window_length=11) + gradient_tau = get_max_gradient(check_taus, axis=0, window_length=11) if gradient_tau < ci.gradient_tau: logger.debug( - "tau usable as {} < gradient_tau={}" - .format(gradient_tau, ci.gradient_tau) + f"tau usable as {gradient_tau} < gradient_tau={ci.gradient_tau}" ) tau_usable = True else: logger.debug( - "tau not usable as {} > gradient_tau={}" - .format(gradient_tau, ci.gradient_tau) + f"tau not usable as {gradient_tau} > gradient_tau={ci.gradient_tau}" ) tau_usable = False check_mean_log_posterior = mean_log_posterior[:, -nsteps_to_check:] gradient_mean_log_posterior = get_max_gradient( - check_mean_log_posterior, axis=1, window_length=GRAD_WINDOW_LENGTH, - smooth=True) + 
check_mean_log_posterior, + axis=1, + window_length=GRAD_WINDOW_LENGTH, + smooth=True, + ) if gradient_mean_log_posterior < ci.gradient_mean_log_posterior: logger.debug( - "tau usable as {} < gradient_mean_log_posterior={}" - .format(gradient_mean_log_posterior, ci.gradient_mean_log_posterior) + f"tau usable as {gradient_mean_log_posterior} < " + f"gradient_mean_log_posterior={ci.gradient_mean_log_posterior}" ) tau_usable *= True else: logger.debug( - "tau not usable as {} > gradient_mean_log_posterior={}" - .format(gradient_mean_log_posterior, ci.gradient_mean_log_posterior) + f"tau not usable as {gradient_mean_log_posterior} > " + f"gradient_mean_log_posterior={ci.gradient_mean_log_posterior}" ) tau_usable = False @@ -864,7 +875,7 @@ def check_iteration( gradient_mean_log_posterior, tau_usable, convergence_inputs, - Q + Q, ) stop = converged and tau_usable return stop, nburn, thin, tau_int, nsamples_effective @@ -872,13 +883,14 @@ def check_iteration( def get_max_gradient(x, axis=0, window_length=11, polyorder=2, smooth=False): from scipy.signal import savgol_filter + if smooth: - x = savgol_filter( - x, axis=axis, window_length=window_length, polyorder=3 + x = savgol_filter(x, axis=axis, window_length=window_length, polyorder=3) + return np.max( + savgol_filter( + x, axis=axis, window_length=window_length, polyorder=polyorder, deriv=1 ) - return np.max(savgol_filter( - x, axis=axis, window_length=window_length, polyorder=polyorder, - deriv=1)) + ) def get_Q_convergence(samples): @@ -887,7 +899,7 @@ def get_Q_convergence(samples): W = np.mean(np.var(samples, axis=1), axis=0) per_walker_mean = np.mean(samples, axis=1) mean = np.mean(per_walker_mean, axis=0) - B = nsteps / (nwalkers - 1.) * np.sum((per_walker_mean - mean)**2, axis=0) + B = nsteps / (nwalkers - 1.0) * np.sum((per_walker_mean - mean) ** 2, axis=0) Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B Q_per_dim = np.sqrt(Vhat / W) return np.max(Q_per_dim) @@ -910,16 +922,18 @@ def print_progress( ): # Setup acceptance string acceptance = sampler.acceptance_fraction[0, :] - acceptance_str = "{:1.2f}-{:1.2f}".format(np.min(acceptance), np.max(acceptance)) + acceptance_str = f"{np.min(acceptance):1.2f}-{np.max(acceptance):1.2f}" # Setup tswap acceptance string tswap_acceptance_fraction = sampler.tswap_acceptance_fraction - tswap_acceptance_str = "{:1.2f}-{:1.2f}".format( - np.min(tswap_acceptance_fraction), np.max(tswap_acceptance_fraction) - ) + tswap_acceptance_str = f"{np.min(tswap_acceptance_fraction):1.2f}-{np.max(tswap_acceptance_fraction):1.2f}" ave_time_per_check = np.mean(time_per_check[-3:]) - time_left = (convergence_inputs.nsamples - nsamples_effective) * ave_time_per_check / samples_per_check + time_left = ( + (convergence_inputs.nsamples - nsamples_effective) + * ave_time_per_check + / samples_per_check + ) if time_left > 0: time_left = str(datetime.timedelta(seconds=int(time_left))) else: @@ -927,46 +941,44 @@ def print_progress( sampling_time = datetime.timedelta(seconds=np.sum(time_per_check)) - tau_str = "{}(+{:0.2f},+{:0.2f})".format( - tau_int, gradient_tau, gradient_mean_log_posterior - ) + tau_str = f"{tau_int}(+{gradient_tau:0.2f},+{gradient_mean_log_posterior:0.2f})" if tau_usable: - tau_str = "={}".format(tau_str) + tau_str = f"={tau_str}" else: - tau_str = "!{}".format(tau_str) + tau_str = f"!{tau_str}" - Q_str = "{:0.2f}".format(Q) + Q_str = f"{Q:0.2f}" - evals_per_check = sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check + evals_per_check = ( + 
sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check + ) - ncalls = "{:1.1e}".format( - convergence_inputs.niterations_per_check * iteration * sampler.nwalkers * sampler.ntemps) - eval_timing = "{:1.2f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check) + approximate_ncalls = ( + convergence_inputs.niterations_per_check + * iteration + * sampler.nwalkers + * sampler.ntemps + ) + ncalls = f"{approximate_ncalls:1.1e}" + eval_timing = f"{1000.0 * ave_time_per_check / evals_per_check:1.2f}ms/ev" try: print( - "{}|{}|nc:{}|a0:{}|swp:{}|n:{}<{}|t{}|q:{}|{}".format( - iteration, - str(sampling_time).split(".")[0], - ncalls, - acceptance_str, - tswap_acceptance_str, - nsamples_effective, - convergence_inputs.nsamples, - tau_str, - Q_str, - eval_timing, - ), + f"{iteration}|{str(sampling_time).split('.')[0]}|nc:{ncalls}|" + f"a0:{acceptance_str}|swp:{tswap_acceptance_str}|" + f"n:{nsamples_effective}<{convergence_inputs.nsamples}|t{tau_str}|" + f"q:{Q_str}|{eval_timing}", flush=True, ) except OSError as e: - logger.debug("Failed to print iteration due to :{}".format(e)) + logger.debug(f"Failed to print iteration due to: {e}") def calculate_tau_array(samples, search_parameter_keys, ci): - """ Compute ACT tau for 0-temperature chains """ + """Compute ACT tau for 0-temperature chains""" import emcee + nwalkers, nsteps, ndim = samples.shape tau_array = np.zeros((nwalkers, ndim)) + np.inf if nsteps > 1: @@ -976,7 +988,8 @@ def calculate_tau_array(samples, search_parameter_keys, ci): continue try: tau_array[ii, jj] = emcee.autocorr.integrated_time( - samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0] + samples[ii, :, jj], c=ci.autocorr_c, tol=0 + )[0] except emcee.autocorr.AutocorrError: tau_array[ii, jj] = np.inf return tau_array @@ -1004,21 +1017,24 @@ def checkpoint( time_per_check, ): import dill + logger.info("Writing checkpoint and diagnostics") ndim = sampler.dim # Store the samples if possible if nsamples_effective > 0: - filename = "{}/{}_samples.txt".format(outdir, label) - samples = np.array(chain_array)[:, discard + nburn : iteration : thin, :].reshape( - (-1, ndim) - ) + filename = f"{outdir}/{label}_samples.txt" + samples = np.array(chain_array)[ + :, discard + nburn : iteration : thin, : + ].reshape((-1, ndim)) df = pd.DataFrame(samples, columns=search_parameter_keys) df.to_csv(filename, index=False, header=True, sep=" ") # Pickle the resume artefacts - sampler_copy = copy.copy(sampler) - del sampler_copy.pool + pool = sampler.pool + sampler.pool = None + sampler_copy = copy.deepcopy(sampler) + sampler.pool = pool data = dict( iteration=iteration, @@ -1040,10 +1056,10 @@ logger.info("Finished writing checkpoint") -def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label, - discard=0): - """ Method to plot the trace of the walkers in an ensemble MCMC plot """ +def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label, discard=0): + """Method to plot the trace of the walkers in an ensemble MCMC plot""" import matplotlib.pyplot as plt + nwalkers, nsteps, ndim = walkers.shape if np.isnan(nburn): nburn = nsteps @@ -1051,51 +1067,65 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label, thin = 1 idxs = np.arange(nsteps) fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim)) - scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.1,) + scatter_kwargs = dict( + lw=0, + marker="o", + markersize=1, + alpha=0.1, + ) # Plot the fixed burn-in if discard > 0: for i, (ax, axh) in
enumerate(axes): ax.plot( - idxs[: discard], - walkers[:, : discard, i].T, + idxs[:discard], + walkers[:, :discard, i].T, color="gray", - **scatter_kwargs + **scatter_kwargs, ) # Plot the burn-in for i, (ax, axh) in enumerate(axes): ax.plot( - idxs[discard: discard + nburn + 1], - walkers[:, discard: discard + nburn + 1, i].T, + idxs[discard : discard + nburn + 1], + walkers[:, discard : discard + nburn + 1, i].T, color="C1", - **scatter_kwargs + **scatter_kwargs, ) # Plot the thinned posterior samples for i, (ax, axh) in enumerate(axes): ax.plot( - idxs[discard + nburn::thin], - walkers[:, discard + nburn::thin, i].T, + idxs[discard + nburn :: thin], + walkers[:, discard + nburn :: thin, i].T, color="C0", - **scatter_kwargs + **scatter_kwargs, + ) + axh.hist( + walkers[:, discard + nburn :: thin, i].reshape((-1)), bins=50, alpha=0.8 ) - axh.hist(walkers[:, discard + nburn::thin, i].reshape((-1)), bins=50, alpha=0.8) for i, (ax, axh) in enumerate(axes): axh.set_xlabel(parameter_labels[i]) ax.set_ylabel(parameter_labels[i]) fig.tight_layout() - filename = "{}/{}_checkpoint_trace.png".format(outdir, label) + filename = f"{outdir}/{label}_checkpoint_trace.png" fig.savefig(filename) plt.close(fig) def plot_tau( - tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau, + tau_list_n, + tau_list, + search_parameter_keys, + outdir, + label, + tau, + autocorr_tau, ): import matplotlib.pyplot as plt + fig, ax = plt.subplots() for i, key in enumerate(search_parameter_keys): ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key) @@ -1103,7 +1133,7 @@ def plot_tau( ax.set_ylabel(r"$\langle \tau \rangle$") ax.legend() fig.tight_layout() - fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label)) + fig.savefig(f"{outdir}/{label}_checkpoint_tau.png") plt.close(fig) @@ -1119,17 +1149,30 @@ def plot_mean_log_posterior(mean_log_posterior, outdir, label): fig, ax = plt.subplots() idxs = np.arange(nsteps) ax.plot(idxs, mean_log_posterior.T) - ax.set(xlabel="Iteration", ylabel=r"$\langle\mathrm{log-posterior}\rangle$", - ylim=(ymin, ymax)) + ax.set( + xlabel="Iteration", + ylabel=r"$\langle\mathrm{log-posterior}\rangle$", + ylim=(ymin, ymax), + ) fig.tight_layout() - fig.savefig("{}/{}_checkpoint_meanlogposterior.png".format(outdir, label)) + fig.savefig(f"{outdir}/{label}_checkpoint_meanlogposterior.png") plt.close(fig) -def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nburn, thin, - iteration, make_plots=True): - """ Computes the evidence using thermodynamic integration """ +def compute_evidence( + sampler, + log_likelihood_array, + outdir, + label, + discard, + nburn, + thin, + iteration, + make_plots=True, +): + """Computes the evidence using thermodynamic integration""" import matplotlib.pyplot as plt + betas = sampler.betas # We compute the evidence without the burnin samples, but we do not thin lnlike = log_likelihood_array[:, :, discard + nburn : iteration] @@ -1141,7 +1184,7 @@ def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nbur if any(np.isinf(mean_lnlikes)): logger.warning( "mean_lnlikes contains inf: recalculating without" - " the {} infs".format(len(betas[np.isinf(mean_lnlikes)])) + f" the {len(betas[np.isinf(mean_lnlikes)])} infs" ) idxs = np.isinf(mean_lnlikes) mean_lnlikes = mean_lnlikes[~idxs] @@ -1165,33 +1208,23 @@ def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nbur ax2.semilogx(min_betas, evidence, "-o") ax2.set_ylabel( - r"$\int_{\beta_{min}}^{\beta=1}" + r"\langle 
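compute_evidence uses thermodynamic integration: ln Z is the integral of the mean log-likelihood over inverse temperature beta from 0 to 1, with the per-rung means estimated from the post-burn-in samples. The quadrature itself amounts to a trapezium rule over the beta ladder; a sketch under that assumption, taking betas and the matching mean log-likelihoods as given:

    import numpy as np

    def ln_z_thermodynamic(betas, mean_lnlikes):
        # betas are typically stored from 1 down towards 0, so sort
        # ascending before applying the trapezium rule over [0, 1]
        order = np.argsort(betas)
        return np.trapz(np.asarray(mean_lnlikes)[order], x=np.asarray(betas)[order])

The second diagnostic panel produced below plots the same integral truncated at increasing beta_min, which makes the sensitivity of the result to the hottest rungs visible.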
\log(\mathcal{L})\rangle d\beta$", + r"$\int_{\beta_{min}}^{\beta=1}" + + r"\langle \log(\mathcal{L})\rangle d\beta$", size=16, ) ax2.set_xlabel(r"$\beta_{min}$") plt.tight_layout() - fig.savefig("{}/{}_beta_lnl.png".format(outdir, label)) + fig.savefig(f"{outdir}/{label}_beta_lnl.png") plt.close(fig) return lnZ, lnZerr def do_nothing_function(): - """ This is a do-nothing function, we overwrite the likelihood and prior elsewhere """ + """This is a do-nothing function, we overwrite the likelihood and prior elsewhere""" pass -likelihood = None -priors = None - - -def init(likelihood_in, priors_in): - global likelihood - global priors - likelihood = likelihood_in - priors = priors_in - - class LikePriorEvaluator(object): """ This class is copied and modified from ptemcee.LikePriorEvaluator, see @@ -1203,38 +1236,43 @@ class LikePriorEvaluator(object): """ - def __init__(self, search_parameter_keys, use_ratio=False): - self.search_parameter_keys = search_parameter_keys - self.use_ratio = use_ratio + def __init__(self): self.periodic_set = False def _setup_periodic(self): + priors = _sampling_convenience_dump.priors + search_parameter_keys = _sampling_convenience_dump.search_parameter_keys self._periodic = [ - priors[key].boundary == "periodic" for key in self.search_parameter_keys + priors[key].boundary == "periodic" for key in search_parameter_keys ] priors.sample() - self._minima = np.array([ - priors[key].minimum for key in self.search_parameter_keys - ]) - self._range = np.array([ - priors[key].maximum for key in self.search_parameter_keys - ]) - self._minima + self._minima = np.array([priors[key].minimum for key in search_parameter_keys]) + self._range = ( + np.array([priors[key].maximum for key in search_parameter_keys]) + - self._minima + ) self.periodic_set = True def _wrap_periodic(self, array): if not self.periodic_set: self._setup_periodic() - array[self._periodic] = np.mod( - array[self._periodic] - self._minima[self._periodic], - self._range[self._periodic] - ) + self._minima[self._periodic] + array[self._periodic] = ( + np.mod( + array[self._periodic] - self._minima[self._periodic], + self._range[self._periodic], + ) + + self._minima[self._periodic] + ) return array def logl(self, v_array): - parameters = {key: v for key, v in zip(self.search_parameter_keys, v_array)} + priors = _sampling_convenience_dump.priors + likelihood = _sampling_convenience_dump.likelihood + search_parameter_keys = _sampling_convenience_dump.search_parameter_keys + parameters = {key: v for key, v in zip(search_parameter_keys, v_array)} if priors.evaluate_constraints(parameters) > 0: likelihood.parameters.update(parameters) - if self.use_ratio: + if _sampling_convenience_dump.use_ratio: return likelihood.log_likelihood() - likelihood.noise_log_likelihood() else: return likelihood.log_likelihood() @@ -1242,9 +1280,15 @@ class LikePriorEvaluator(object): return np.nan_to_num(-np.inf) def logp(self, v_array): - params = {key: t for key, t in zip(self.search_parameter_keys, v_array)} + priors = _sampling_convenience_dump.priors + search_parameter_keys = _sampling_convenience_dump.search_parameter_keys + params = {key: t for key, t in zip(search_parameter_keys, v_array)} return priors.ln_prob(params) + def call_emcee(self, theta): + ll, lp = self.__call__(theta) + return ll + lp, [ll, lp] + def __call__(self, x): lp = self.logp(x) if np.isnan(lp): diff --git a/bilby/core/sampler/ptmcmc.py b/bilby/core/sampler/ptmcmc.py index 49b86d7392ff4d21af698fe7b7034b25eaa7ac22..6b9c3c96eb83a81486df1e38dd5948b668cfb358 
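LikePriorEvaluator._wrap_periodic above maps proposed points for periodic parameters back into their prior range while leaving the remaining parameters untouched. The same modular arithmetic in isolation, with illustrative one-parameter inputs:

    import numpy as np

    def wrap_periodic(values, periodic, minima, ranges):
        values = np.array(values, dtype=float)
        values[periodic] = (
            np.mod(values[periodic] - minima[periodic], ranges[periodic])
            + minima[periodic]
        )
        return values

    # a phase of 7 rad wrapped into [0, 2*pi)
    print(wrap_periodic([7.0], np.array([True]), np.array([0.0]), np.array([2 * np.pi])))
    # [0.71681469]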
100644 --- a/bilby/core/sampler/ptmcmc.py +++ b/bilby/core/sampler/ptmcmc.py @@ -1,11 +1,10 @@ - import glob import shutil import numpy as np -from .base_sampler import MCMCSampler, SamplerNotInstalledError from ..utils import logger +from .base_sampler import MCMCSampler, SamplerNotInstalledError, signal_wrapper class PTMCMCSampler(MCMCSampler): @@ -42,29 +41,66 @@ class PTMCMCSampler(MCMCSampler): """ - default_kwargs = {'p0': None, 'Niter': 2 * 10 ** 4 + 1, 'neff': 10 ** 4, - 'burn': 5 * 10 ** 3, 'verbose': True, - 'ladder': None, 'Tmin': 1, 'Tmax': None, 'Tskip': 100, - 'isave': 1000, 'thin': 1, 'covUpdate': 1000, - 'SCAMweight': 1, 'AMweight': 1, 'DEweight': 1, - 'HMCweight': 0, 'MALAweight': 0, 'NUTSweight': 0, - 'HMCstepsize': 0.1, 'HMCsteps': 300, - 'groups': None, 'custom_proposals': None, - 'loglargs': {}, 'loglkwargs': {}, 'logpargs': {}, - 'logpkwargs': {}, 'logl_grad': None, 'logp_grad': None, - 'outDir': None} - - def __init__(self, likelihood, priors, outdir='outdir', label='label', - use_ratio=False, plot=False, skip_import_verification=False, - pos0=None, burn_in_fraction=0.25, **kwargs): - - super(PTMCMCSampler, self).__init__(likelihood=likelihood, priors=priors, - outdir=outdir, label=label, use_ratio=use_ratio, - plot=plot, - skip_import_verification=skip_import_verification, - **kwargs) - - self.p0 = self.get_random_draw_from_prior() + default_kwargs = { + "p0": None, + "Niter": 2 * 10**4 + 1, + "neff": 10**4, + "burn": 5 * 10**3, + "verbose": True, + "ladder": None, + "Tmin": 1, + "Tmax": None, + "Tskip": 100, + "isave": 1000, + "thin": 1, + "covUpdate": 1000, + "SCAMweight": 1, + "AMweight": 1, + "DEweight": 1, + "HMCweight": 0, + "MALAweight": 0, + "NUTSweight": 0, + "HMCstepsize": 0.1, + "HMCsteps": 300, + "groups": None, + "custom_proposals": None, + "loglargs": {}, + "loglkwargs": {}, + "logpargs": {}, + "logpkwargs": {}, + "logl_grad": None, + "logp_grad": None, + "outDir": None, + } + hard_exit = True + + def __init__( + self, + likelihood, + priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + skip_import_verification=False, + **kwargs, + ): + + super(PTMCMCSampler, self).__init__( + likelihood=likelihood, + priors=priors, + outdir=outdir, + label=label, + use_ratio=use_ratio, + plot=plot, + skip_import_verification=skip_import_verification, + **kwargs, + ) + + if self.kwargs["p0"] is None: + self.p0 = self.get_random_draw_from_prior() + else: + self.p0 = self.kwargs["p0"] self.likelihood = likelihood self.priors = priors @@ -73,88 +109,102 @@ class PTMCMCSampler(MCMCSampler): # which forces `__name__.lower() external_sampler_name = self.__class__.__name__ try: - self.external_sampler = __import__(external_sampler_name) + __import__(external_sampler_name) except (ImportError, SystemExit): raise SamplerNotInstalledError( - "Sampler {} is not installed on this system".format(external_sampler_name)) + f"Sampler {external_sampler_name} is not installed on this system" + ) def _translate_kwargs(self, kwargs): - if 'Niter' not in kwargs: + if "Niter" not in kwargs: for equiv in self.nwalkers_equiv_kwargs: if equiv in kwargs: - kwargs['Niter'] = kwargs.pop(equiv) - if 'burn' not in kwargs: + kwargs["Niter"] = kwargs.pop(equiv) + if "burn" not in kwargs: for equiv in self.nburn_equiv_kwargs: if equiv in kwargs: - kwargs['burn'] = kwargs.pop(equiv) + kwargs["burn"] = kwargs.pop(equiv) @property def custom_proposals(self): - return self.kwargs['custom_proposals'] + return self.kwargs["custom_proposals"] @property def sampler_init_kwargs(self): 
- keys = ['groups', - 'loglargs', - 'logp_grad', - 'logpkwargs', - 'loglkwargs', - 'logl_grad', - 'logpargs', - 'outDir', - 'verbose'] + keys = [ + "groups", + "loglargs", + "logp_grad", + "logpkwargs", + "loglkwargs", + "logl_grad", + "logpargs", + "outDir", + "verbose", + ] init_kwargs = {key: self.kwargs[key] for key in keys} - if init_kwargs['outDir'] is None: - init_kwargs['outDir'] = '{}/ptmcmc_temp_{}/'.format(self.outdir, self.label) + if init_kwargs["outDir"] is None: + init_kwargs["outDir"] = f"{self.outdir}/ptmcmc_temp_{self.label}/" return init_kwargs @property def sampler_function_kwargs(self): - keys = ['Niter', - 'neff', - 'Tmin', - 'HMCweight', - 'covUpdate', - 'SCAMweight', - 'ladder', - 'burn', - 'NUTSweight', - 'AMweight', - 'MALAweight', - 'thin', - 'HMCstepsize', - 'isave', - 'Tskip', - 'HMCsteps', - 'Tmax', - 'DEweight'] + keys = [ + "Niter", + "neff", + "Tmin", + "HMCweight", + "covUpdate", + "SCAMweight", + "ladder", + "burn", + "NUTSweight", + "AMweight", + "MALAweight", + "thin", + "HMCstepsize", + "isave", + "Tskip", + "HMCsteps", + "Tmax", + "DEweight", + ] sampler_kwargs = {key: self.kwargs[key] for key in keys} return sampler_kwargs @staticmethod def _import_external_sampler(): from PTMCMCSampler import PTMCMCSampler + return PTMCMCSampler + @signal_wrapper def run_sampler(self): PTMCMCSampler = self._import_external_sampler() - sampler = PTMCMCSampler.PTSampler(ndim=self.ndim, logp=self.log_prior, - logl=self.log_likelihood, cov=np.eye(self.ndim), - **self.sampler_init_kwargs) + sampler = PTMCMCSampler.PTSampler( + ndim=self.ndim, + logp=self.log_prior, + logl=self.log_likelihood, + cov=np.eye(self.ndim), + **self.sampler_init_kwargs, + ) if self.custom_proposals is not None: for proposal in self.custom_proposals: - logger.info('Adding {} to proposals with weight {}'.format( - proposal, self.custom_proposals[proposal][1])) - sampler.addProposalToCycle(self.custom_proposals[proposal][0], - self.custom_proposals[proposal][1]) + logger.info( + f"Adding {proposal} to proposals with weight {self.custom_proposals[proposal][1]}" + ) + sampler.addProposalToCycle( + self.custom_proposals[proposal][0], + self.custom_proposals[proposal][1], + ) sampler.sample(p0=self.p0, **self.sampler_function_kwargs) samples, meta, loglike = self.__read_in_data() self.calc_likelihood_count() - self.result.nburn = self.sampler_function_kwargs['burn'] - self.result.samples = samples[self.result.nburn:] - self.meta_data['sampler_meta'] = meta - self.result.log_likelihood_evaluations = loglike[self.result.nburn:] + self.result.nburn = self.sampler_function_kwargs["burn"] + self.result.samples = samples[self.result.nburn :] + self.meta_data["sampler_meta"] = meta + self.result.log_likelihood_evaluations = loglike[self.result.nburn :] self.result.sampler_output = np.nan self.result.walkers = np.nan self.result.log_evidence = np.nan @@ -162,30 +212,34 @@ class PTMCMCSampler(MCMCSampler): return self.result def __read_in_data(self): - """ Read the data stored by PTMCMC to disk """ - temp_outDir = self.sampler_init_kwargs['outDir'] + """Read the data stored by PTMCMC to disk""" + temp_outDir = self.sampler_init_kwargs["outDir"] try: - data = np.loadtxt('{}chain_1.txt'.format(temp_outDir)) + data = np.loadtxt(f"{temp_outDir}chain_1.txt") except OSError: - data = np.loadtxt('{}chain_1.0.txt'.format(temp_outDir)) - jumpfiles = glob.glob('{}/*jump.txt'.format(temp_outDir)) + data = np.loadtxt(f"{temp_outDir}chain_1.0.txt") + jumpfiles = glob.glob(f"{temp_outDir}/*jump.txt") jumps = map(np.loadtxt, 
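run_sampler above registers each entry of custom_proposals with sampler.addProposalToCycle, indexing [0] for the jump function and [1] for its weight. A hypothetical proposal, assuming PTMCMCSampler's usual jump signature of (position, iteration, beta) returning the new position and the log proposal-density ratio:

    import numpy as np

    def small_gaussian_jump(x, it, beta):
        # symmetric proposal, so the log forward/backward density ratio is zero
        return x + 1e-2 * np.random.randn(len(x)), 0.0

    # custom_proposals maps a name to a (function, weight) pair
    custom_proposals = {"small_gaussian": (small_gaussian_jump, 5)}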
jumpfiles) samples = data[:, :-4] loglike = data[:, -3] jump_accept = {} for ct, j in enumerate(jumps): - label = jumpfiles[ct].split('/')[-1].split('_jump.txt')[0] + label = jumpfiles[ct].split("/")[-1].split("_jump.txt")[0] jump_accept[label] = j - PT_swap = {'swap_accept': data[:, -1]} - tot_accept = {'tot_accept': data[:, -2]} - log_post = {'log_post': data[:, -4]} + PT_swap = {"swap_accept": data[:, -1]} + tot_accept = {"tot_accept": data[:, -2]} + log_post = {"log_post": data[:, -4]} meta = {} - meta['tot_accept'] = tot_accept - meta['PT_swap'] = PT_swap - meta['proposals'] = jump_accept - meta['log_post'] = log_post + meta["tot_accept"] = tot_accept + meta["PT_swap"] = PT_swap + meta["proposals"] = jump_accept + meta["log_post"] = log_post shutil.rmtree(temp_outDir) return samples, meta, loglike + + def write_current_state(self): + """TODO: implement a checkpointing method""" + pass diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index 4ff4b232adcfb3432a1f8ce2973fb24fc4778d7b..1c6a2790a74c61001577bc8606ee0834a1b80455 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -2,16 +2,20 @@ from distutils.version import StrictVersion import numpy as np +from ...gw.likelihood import BasicGravitationalWaveTransient, GravitationalWaveTransient +from ..likelihood import ( + ExponentialLikelihood, + GaussianLikelihood, + PoissonLikelihood, + StudentTLikelihood, +) +from ..prior import Cosine, DeltaFunction, MultivariateGaussian, PowerLaw, Sine from ..utils import derivatives, infer_args_from_method -from ..prior import DeltaFunction, Sine, Cosine, PowerLaw, MultivariateGaussian from .base_sampler import MCMCSampler -from ..likelihood import GaussianLikelihood, PoissonLikelihood, ExponentialLikelihood, \ StudentTLikelihood -from ...gw.likelihood import BasicGravitationalWaveTransient, GravitationalWaveTransient class Pymc3(MCMCSampler): - """ bilby wrapper of the PyMC3 sampler (https://docs.pymc.io/) + """bilby wrapper of the PyMC3 sampler (https://docs.pymc.io/) All keyword arguments (i.e., the kwargs) passed to `run_sampler` will be propagated to `pymc3.sample` where appropriate, see documentation for that @@ -51,39 +55,77 @@ class Pymc3(MCMCSampler): """ default_kwargs = dict( - draws=500, step=None, init='auto', n_init=200000, start=None, trace=None, chain_idx=0, - chains=2, cores=1, tune=500, progressbar=True, model=None, random_seed=None, - discard_tuned_samples=True, compute_convergence_checks=True, nuts_kwargs=None, + draws=500, + step=None, + init="auto", + n_init=200000, + start=None, + trace=None, + chain_idx=0, + chains=2, + cores=1, + tune=500, + progressbar=True, + model=None, + random_seed=None, + discard_tuned_samples=True, + compute_convergence_checks=True, + nuts_kwargs=None, step_kwargs=None, ) default_nuts_kwargs = dict( - target_accept=None, max_treedepth=None, step_scale=None, Emax=None, - gamma=None, k=None, t0=None, adapt_step_size=None, early_max_treedepth=None, - scaling=None, is_cov=None, potential=None, + target_accept=None, + max_treedepth=None, + step_scale=None, + Emax=None, + gamma=None, + k=None, + t0=None, + adapt_step_size=None, + early_max_treedepth=None, + scaling=None, + is_cov=None, + potential=None, ) default_kwargs.update(default_nuts_kwargs) - def __init__(self, likelihood, priors, outdir='outdir', label='label', - use_ratio=False, plot=False, - skip_import_verification=False, **kwargs): + def __init__( + self, + likelihood, + priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, +
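__read_in_data above depends on the fixed column layout of the chain_1.txt file that PTMCMCSampler writes: the sampled parameters come first, followed by four diagnostic columns. Reading it back standalone, with the default run directory (outdir='outdir', label='label'):

    import numpy as np

    data = np.loadtxt("outdir/ptmcmc_temp_label/chain_1.txt")
    samples = data[:, :-4]     # parameter columns
    log_post = data[:, -4]     # log-posterior
    log_like = data[:, -3]     # log-likelihood
    tot_accept = data[:, -2]   # overall acceptance
    swap_accept = data[:, -1]  # parallel-tempering swap acceptance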
skip_import_verification=False, + **kwargs, + ): # add default step kwargs _, STEP_METHODS, _ = self._import_external_sampler() self.default_step_kwargs = {m.__name__.lower(): None for m in STEP_METHODS} self.default_kwargs.update(self.default_step_kwargs) - super(Pymc3, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label, - use_ratio=use_ratio, plot=plot, - skip_import_verification=skip_import_verification, **kwargs) - self.draws = self._kwargs['draws'] - self.chains = self._kwargs['chains'] + super(Pymc3, self).__init__( + likelihood=likelihood, + priors=priors, + outdir=outdir, + label=label, + use_ratio=use_ratio, + plot=plot, + skip_import_verification=skip_import_verification, + **kwargs, + ) + self.draws = self._kwargs["draws"] + self.chains = self._kwargs["chains"] @staticmethod def _import_external_sampler(): import pymc3 from pymc3.sampling import STEP_METHODS from pymc3.theanof import floatX + return pymc3, STEP_METHODS, floatX @staticmethod @@ -91,6 +133,7 @@ class Pymc3(MCMCSampler): import theano # noqa import theano.tensor as tt from theano.compile.ops import as_op # noqa + return theano, tt, as_op def _verify_parameters(self): @@ -116,77 +159,79 @@ class Pymc3(MCMCSampler): self.prior_map = prior_map # predefined PyMC3 distributions - prior_map['Gaussian'] = { - 'pymc3': 'Normal', - 'argmap': {'mu': 'mu', 'sigma': 'sd'}} - prior_map['TruncatedGaussian'] = { - 'pymc3': 'TruncatedNormal', - 'argmap': {'mu': 'mu', - 'sigma': 'sd', - 'minimum': 'lower', - 'maximum': 'upper'}} - prior_map['HalfGaussian'] = { - 'pymc3': 'HalfNormal', - 'argmap': {'sigma': 'sd'}} - prior_map['Uniform'] = { - 'pymc3': 'Uniform', - 'argmap': {'minimum': 'lower', - 'maximum': 'upper'}} - prior_map['LogNormal'] = { - 'pymc3': 'Lognormal', - 'argmap': {'mu': 'mu', - 'sigma': 'sd'}} - prior_map['Exponential'] = { - 'pymc3': 'Exponential', - 'argmap': {'mu': 'lam'}, - 'argtransform': {'mu': lambda mu: 1. / mu}} - prior_map['StudentT'] = { - 'pymc3': 'StudentT', - 'argmap': {'df': 'nu', - 'mu': 'mu', - 'scale': 'sd'}} - prior_map['Beta'] = { - 'pymc3': 'Beta', - 'argmap': {'alpha': 'alpha', - 'beta': 'beta'}} - prior_map['Logistic'] = { - 'pymc3': 'Logistic', - 'argmap': {'mu': 'mu', - 'scale': 's'}} - prior_map['Cauchy'] = { - 'pymc3': 'Cauchy', - 'argmap': {'alpha': 'alpha', - 'beta': 'beta'}} - prior_map['Gamma'] = { - 'pymc3': 'Gamma', - 'argmap': {'k': 'alpha', - 'theta': 'beta'}, - 'argtransform': {'theta': lambda theta: 1. 
/ theta}} - prior_map['ChiSquared'] = { - 'pymc3': 'ChiSquared', - 'argmap': {'nu': 'nu'}} - prior_map['Interped'] = { - 'pymc3': 'Interpolated', - 'argmap': {'xx': 'x_points', - 'yy': 'pdf_points'}} - prior_map['Normal'] = prior_map['Gaussian'] - prior_map['TruncatedNormal'] = prior_map['TruncatedGaussian'] - prior_map['HalfNormal'] = prior_map['HalfGaussian'] - prior_map['LogGaussian'] = prior_map['LogNormal'] - prior_map['Lorentzian'] = prior_map['Cauchy'] - prior_map['FromFile'] = prior_map['Interped'] + prior_map["Gaussian"] = { + "pymc3": "Normal", + "argmap": {"mu": "mu", "sigma": "sd"}, + } + prior_map["TruncatedGaussian"] = { + "pymc3": "TruncatedNormal", + "argmap": { + "mu": "mu", + "sigma": "sd", + "minimum": "lower", + "maximum": "upper", + }, + } + prior_map["HalfGaussian"] = {"pymc3": "HalfNormal", "argmap": {"sigma": "sd"}} + prior_map["Uniform"] = { + "pymc3": "Uniform", + "argmap": {"minimum": "lower", "maximum": "upper"}, + } + prior_map["LogNormal"] = { + "pymc3": "Lognormal", + "argmap": {"mu": "mu", "sigma": "sd"}, + } + prior_map["Exponential"] = { + "pymc3": "Exponential", + "argmap": {"mu": "lam"}, + "argtransform": {"mu": lambda mu: 1.0 / mu}, + } + prior_map["StudentT"] = { + "pymc3": "StudentT", + "argmap": {"df": "nu", "mu": "mu", "scale": "sd"}, + } + prior_map["Beta"] = { + "pymc3": "Beta", + "argmap": {"alpha": "alpha", "beta": "beta"}, + } + prior_map["Logistic"] = { + "pymc3": "Logistic", + "argmap": {"mu": "mu", "scale": "s"}, + } + prior_map["Cauchy"] = { + "pymc3": "Cauchy", + "argmap": {"alpha": "alpha", "beta": "beta"}, + } + prior_map["Gamma"] = { + "pymc3": "Gamma", + "argmap": {"k": "alpha", "theta": "beta"}, + "argtransform": {"theta": lambda theta: 1.0 / theta}, + } + prior_map["ChiSquared"] = {"pymc3": "ChiSquared", "argmap": {"nu": "nu"}} + prior_map["Interped"] = { + "pymc3": "Interpolated", + "argmap": {"xx": "x_points", "yy": "pdf_points"}, + } + prior_map["Normal"] = prior_map["Gaussian"] + prior_map["TruncatedNormal"] = prior_map["TruncatedGaussian"] + prior_map["HalfNormal"] = prior_map["HalfGaussian"] + prior_map["LogGaussian"] = prior_map["LogNormal"] + prior_map["Lorentzian"] = prior_map["Cauchy"] + prior_map["FromFile"] = prior_map["Interped"] # GW specific priors - prior_map['UniformComovingVolume'] = prior_map['Interped'] + prior_map["UniformComovingVolume"] = prior_map["Interped"] # internally defined mappings for bilby priors - prior_map['DeltaFunction'] = {'internal': self._deltafunction_prior} - prior_map['Sine'] = {'internal': self._sine_prior} - prior_map['Cosine'] = {'internal': self._cosine_prior} - prior_map['PowerLaw'] = {'internal': self._powerlaw_prior} - prior_map['LogUniform'] = {'internal': self._powerlaw_prior} - prior_map['MultivariateGaussian'] = {'internal': self._multivariate_normal_prior} - prior_map['MultivariateNormal'] = {'internal': self._multivariate_normal_prior} + prior_map["DeltaFunction"] = {"internal": self._deltafunction_prior} + prior_map["Sine"] = {"internal": self._sine_prior} + prior_map["Cosine"] = {"internal": self._cosine_prior} + prior_map["PowerLaw"] = {"internal": self._powerlaw_prior} + prior_map["LogUniform"] = {"internal": self._powerlaw_prior} + prior_map["MultivariateGaussian"] = { + "internal": self._multivariate_normal_prior + } + prior_map["MultivariateNormal"] = {"internal": self._multivariate_normal_prior} def _deltafunction_prior(self, key, **kwargs): """ @@ -197,7 +242,7 @@ class Pymc3(MCMCSampler): if isinstance(self.priors[key], DeltaFunction): return self.priors[key].peak 
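Each prior_map entry above names the target PyMC3 distribution, an argmap from bilby argument names to PyMC3 keywords, and optionally an argtransform applied on the way through; for instance bilby's Exponential mean mu becomes PyMC3's rate lam = 1/mu. A small illustration of how such an entry is consumed:

    entry = {
        "pymc3": "Exponential",
        "argmap": {"mu": "lam"},
        "argtransform": {"mu": lambda mu: 1.0 / mu},
    }

    bilby_args = {"mu": 2.0}
    pymc3_kwargs = {
        entry["argmap"][key]: entry.get("argtransform", {}).get(key, lambda x: x)(value)
        for key, value in bilby_args.items()
    }
    print(pymc3_kwargs)  # {'lam': 0.5}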
else: - raise ValueError("Prior for '{}' is not a DeltaFunction".format(key)) + raise ValueError(f"Prior for '{key}' is not a DeltaFunction") def _sine_prior(self, key): """ @@ -210,20 +255,22 @@ class Pymc3(MCMCSampler): if isinstance(self.priors[key], Sine): class Pymc3Sine(pymc3.Continuous): - def __init__(self, lower=0., upper=np.pi): + def __init__(self, lower=0.0, upper=np.pi): if lower >= upper: raise ValueError("Lower bound is above upper bound!") # set the mode self.lower = lower = tt.as_tensor_variable(floatX(lower)) self.upper = upper = tt.as_tensor_variable(floatX(upper)) - self.norm = (tt.cos(lower) - tt.cos(upper)) - self.mean = \ - (tt.sin(upper) + lower * tt.cos(lower) - - tt.sin(lower) - upper * tt.cos(upper)) / self.norm + self.norm = tt.cos(lower) - tt.cos(upper) + self.mean = ( + tt.sin(upper) + + lower * tt.cos(lower) + - tt.sin(lower) + - upper * tt.cos(upper) + ) / self.norm - transform = pymc3.distributions.transforms.interval(lower, - upper) + transform = pymc3.distributions.transforms.interval(lower, upper) super(Pymc3Sine, self).__init__(transform=transform) @@ -232,12 +279,15 @@ class Pymc3(MCMCSampler): lower = self.lower return pymc3.distributions.dist_math.bound( tt.log(tt.sin(value) / self.norm), - lower <= value, value <= upper) + lower <= value, + value <= upper, + ) - return Pymc3Sine(key, lower=self.priors[key].minimum, - upper=self.priors[key].maximum) + return Pymc3Sine( + key, lower=self.priors[key].minimum, upper=self.priors[key].maximum + ) else: - raise ValueError("Prior for '{}' is not a Sine".format(key)) + raise ValueError(f"Prior for '{key}' is not a Sine") def _cosine_prior(self, key): """ @@ -250,19 +300,21 @@ class Pymc3(MCMCSampler): if isinstance(self.priors[key], Cosine): class Pymc3Cosine(pymc3.Continuous): - def __init__(self, lower=-np.pi / 2., upper=np.pi / 2.): + def __init__(self, lower=-np.pi / 2.0, upper=np.pi / 2.0): if lower >= upper: raise ValueError("Lower bound is above upper bound!") self.lower = lower = tt.as_tensor_variable(floatX(lower)) self.upper = upper = tt.as_tensor_variable(floatX(upper)) - self.norm = (tt.sin(upper) - tt.sin(lower)) - self.mean = \ - (upper * tt.sin(upper) + tt.cos(upper) - - lower * tt.sin(lower) - tt.cos(lower)) / self.norm + self.norm = tt.sin(upper) - tt.sin(lower) + self.mean = ( + upper * tt.sin(upper) + + tt.cos(upper) + - lower * tt.sin(lower) + - tt.cos(lower) + ) / self.norm - transform = pymc3.distributions.transforms.interval(lower, - upper) + transform = pymc3.distributions.transforms.interval(lower, upper) super(Pymc3Cosine, self).__init__(transform=transform) @@ -271,12 +323,15 @@ class Pymc3(MCMCSampler): lower = self.lower return pymc3.distributions.dist_math.bound( tt.log(tt.cos(value) / self.norm), - lower <= value, value <= upper) + lower <= value, + value <= upper, + ) - return Pymc3Cosine(key, lower=self.priors[key].minimum, - upper=self.priors[key].maximum) + return Pymc3Cosine( + key, lower=self.priors[key].minimum, upper=self.priors[key].maximum + ) else: - raise ValueError("Prior for '{}' is not a Cosine".format(key)) + raise ValueError(f"Prior for '{key}' is not a Cosine") def _powerlaw_prior(self, key): """ @@ -289,17 +344,18 @@ class Pymc3(MCMCSampler): if isinstance(self.priors[key], PowerLaw): # check power law is set - if not hasattr(self.priors[key], 'alpha'): + if not hasattr(self.priors[key], "alpha"): raise AttributeError("No 'alpha' attribute set for PowerLaw prior") - if self.priors[key].alpha < -1.: + if self.priors[key].alpha < -1.0: # use Pareto distribution 
- palpha = -(1. + self.priors[key].alpha) + palpha = -(1.0 + self.priors[key].alpha) - return pymc3.Bound( - pymc3.Pareto, upper=self.priors[key].minimum)( - key, alpha=palpha, m=self.priors[key].maximum) + return pymc3.Bound(pymc3.Pareto, upper=self.priors[key].minimum)( + key, alpha=palpha, m=self.priors[key].maximum + ) else: + class Pymc3PowerLaw(pymc3.Continuous): def __init__(self, lower, upper, alpha, testval=1): falpha = alpha @@ -308,17 +364,21 @@ class Pymc3(MCMCSampler): self.alpha = alpha = tt.as_tensor_variable(floatX(alpha)) if falpha == -1: - self.norm = 1. / (tt.log(self.upper / self.lower)) + self.norm = 1.0 / (tt.log(self.upper / self.lower)) else: - beta = (1. + self.alpha) - self.norm = 1. / (beta * (tt.pow(self.upper, beta) - - tt.pow(self.lower, beta))) + beta = 1.0 + self.alpha + self.norm = 1.0 / ( + beta + * (tt.pow(self.upper, beta) - tt.pow(self.lower, beta)) + ) transform = pymc3.distributions.transforms.interval( - lower, upper) + lower, upper + ) super(Pymc3PowerLaw, self).__init__( - transform=transform, testval=testval) + transform=transform, testval=testval + ) def logp(self, value): upper = self.upper @@ -327,13 +387,18 @@ class Pymc3(MCMCSampler): return pymc3.distributions.dist_math.bound( alpha * tt.log(value) + tt.log(self.norm), - lower <= value, value <= upper) - - return Pymc3PowerLaw(key, lower=self.priors[key].minimum, - upper=self.priors[key].maximum, - alpha=self.priors[key].alpha) + lower <= value, + value <= upper, + ) + + return Pymc3PowerLaw( + key, + lower=self.priors[key].minimum, + upper=self.priors[key].maximum, + alpha=self.priors[key].alpha, + ) else: - raise ValueError("Prior for '{}' is not a Power Law".format(key)) + raise ValueError(f"Prior for '{key}' is not a Power Law") def _multivariate_normal_prior(self, key): """ @@ -359,14 +424,14 @@ class Pymc3(MCMCSampler): testvals = [] for bound in mvg.bounds.values(): if np.isinf(bound[0]) and np.isinf(bound[1]): - testvals.append(0.) + testvals.append(0.0) elif np.isinf(bound[0]): - testvals.append(bound[1] - 1.) + testvals.append(bound[1] - 1.0) elif np.isinf(bound[1]): - testvals.append(bound[0] + 1.) + testvals.append(bound[0] + 1.0) else: # half-way between the two bounds - testvals.append(bound[0] + (bound[1] - bound[0]) / 2.) + testvals.append(bound[0] + (bound[1] - bound[0]) / 2.0) # if bounds are at +/-infinity set to 100 sigmas as infinities # cause problems for the Bound class @@ -375,44 +440,54 @@ class Pymc3(MCMCSampler): maxsigma = np.max(mvg.sigmas, axis=0) for i in range(len(mvpars)): if np.isinf(lower[i]): - lower[i] = minmu[i] - 100. * maxsigma[i] + lower[i] = minmu[i] - 100.0 * maxsigma[i] if np.isinf(upper[i]): - upper[i] = maxmu[i] + 100. 
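The alpha < -1 branch above leans on an algebraic identity: a PowerLaw prior with density proportional to x**alpha is a Pareto distribution with shape palpha = -(1 + alpha), since the Pareto density is proportional to x**(-(palpha + 1)) = x**alpha. A quick numerical check:

    import numpy as np

    alpha = -2.5             # bilby PowerLaw exponent
    palpha = -(1.0 + alpha)  # Pareto shape parameter, here 1.5
    x = np.array([1.0, 2.0, 4.0, 8.0])
    print(x ** (-(palpha + 1)) / x**alpha)  # [1. 1. 1. 1.] -> identical shapes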
* maxsigma[i] + upper[i] = maxmu[i] + 100.0 * maxsigma[i] # create a bounded MultivariateNormal distribution BoundedMvN = pymc3.Bound(pymc3.MvNormal, lower=lower, upper=upper) comp_dists = [] # list of any component modes for i in range(mvg.nmodes): - comp_dists.append(BoundedMvN('comp{}'.format(i), mu=mvg.mus[i], - cov=mvg.covs[i], - shape=len(mvpars)).distribution) + comp_dists.append( + BoundedMvN( + f"comp{i}", + mu=mvg.mus[i], + cov=mvg.covs[i], + shape=len(mvpars), + ).distribution + ) # create a Mixture model - setname = 'mixture{}'.format(self.multivariate_normal_num_sets) - mix = pymc3.Mixture(setname, w=mvg.weights, comp_dists=comp_dists, - shape=len(mvpars), testval=testvals) + setname = f"mixture{self.multivariate_normal_num_sets}" + mix = pymc3.Mixture( + setname, + w=mvg.weights, + comp_dists=comp_dists, + shape=len(mvpars), + testval=testvals, + ) for i, p in enumerate(mvpars): self.multivariate_normal_sets[p] = {} - self.multivariate_normal_sets[p]['prior'] = mix[i] - self.multivariate_normal_sets[p]['set'] = setname - self.multivariate_normal_sets[p]['index'] = i + self.multivariate_normal_sets[p]["prior"] = mix[i] + self.multivariate_normal_sets[p]["set"] = setname + self.multivariate_normal_sets[p]["index"] = i self.multivariate_normal_num_sets += 1 # return required parameter - return self.multivariate_normal_sets[key]['prior'] + return self.multivariate_normal_sets[key]["prior"] else: - raise ValueError("Prior for '{}' is not a MultivariateGaussian".format(key)) + raise ValueError(f"Prior for '{key}' is not a MultivariateGaussian") def run_sampler(self): # set the step method pymc3, STEP_METHODS, floatX = self._import_external_sampler() step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS} - if 'step' in self._kwargs: - self.step_method = self._kwargs.pop('step') + if "step" in self._kwargs: + self.step_method = self._kwargs.pop("step") # 'step' could be a dictionary of methods for different parameters, # so check for this @@ -421,7 +496,9 @@ class Pymc3(MCMCSampler): elif isinstance(self.step_method, dict): for key in self.step_method: if key not in self._search_parameter_keys: - raise ValueError("Setting a step method for an unknown parameter '{}'".format(key)) + raise ValueError( + f"Setting a step method for an unknown parameter '{key}'" + ) else: # check if using a compound step (a list of step # methods for a particular parameter) @@ -431,7 +508,9 @@ class Pymc3(MCMCSampler): sms = [self.step_method[key]] for sm in sms: if sm.lower() not in step_methods: - raise ValueError("Using invalid step method '{}'".format(self.step_method[key])) + raise ValueError( + f"Using invalid step method '{self.step_method[key]}'" + ) else: # check if using a compound step (a list of step # methods for a particular parameter) @@ -442,7 +521,7 @@ class Pymc3(MCMCSampler): for i in range(len(sms)): if sms[i].lower() not in step_methods: - raise ValueError("Using invalid step method '{}'".format(sms[i])) + raise ValueError(f"Using invalid step method '{sms[i]}'") else: self.step_method = None @@ -457,7 +536,7 @@ class Pymc3(MCMCSampler): # takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines # the likelihood within that context manager likeargs = infer_args_from_method(self.likelihood.log_likelihood) - if 'sampler' in likeargs: + if "sampler" in likeargs: self.likelihood.log_likelihood(sampler=self) else: # set the likelihood function from predefined functions @@ -498,7 +577,7 @@ class Pymc3(MCMCSampler): # set the step method if 
isinstance(self.step_method, dict): # create list of step methods (any not given will default to NUTS) - self.kwargs['step'] = [] + self.kwargs["step"] = [] with self.pymc3_model: for key in self.step_method: # check for a compound step list @@ -506,13 +585,25 @@ class Pymc3(MCMCSampler): for sms in self.step_method[key]: curmethod = sms.lower() methodslist.append(curmethod) - nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs, - step_methods) + nuts_kwargs = self._create_nuts_kwargs( + curmethod, + key, + nuts_kwargs, + pymc3, + step_kwargs, + step_methods, + ) else: curmethod = self.step_method[key].lower() methodslist.append(curmethod) - nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs, - step_methods) + nuts_kwargs = self._create_nuts_kwargs( + curmethod, + key, + nuts_kwargs, + pymc3, + step_kwargs, + step_methods, + ) else: with self.pymc3_model: # check for a compound step list @@ -521,28 +612,38 @@ class Pymc3(MCMCSampler): for sms in self.step_method: curmethod = sms.lower() methodslist.append(curmethod) - args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs) + args, nuts_kwargs = self._create_args_and_nuts_kwargs( + curmethod, nuts_kwargs, step_kwargs + ) compound.append(pymc3.__dict__[step_methods[curmethod]](**args)) - self.kwargs['step'] = compound + self.kwargs["step"] = compound else: - self.kwargs['step'] = None + self.kwargs["step"] = None if self.step_method is not None: curmethod = self.step_method.lower() methodslist.append(curmethod) - args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs) - self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args) + args, nuts_kwargs = self._create_args_and_nuts_kwargs( + curmethod, nuts_kwargs, step_kwargs + ) + self.kwargs["step"] = pymc3.__dict__[step_methods[curmethod]]( + **args + ) else: # re-add step_kwargs if no step methods are set - if len(step_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"): - self.kwargs['step_kwargs'] = step_kwargs + if len(step_kwargs) > 0 and StrictVersion( + pymc3.__version__ + ) < StrictVersion("3.7"): + self.kwargs["step_kwargs"] = step_kwargs # check whether only NUTS step method has been assigned - if np.all([sm.lower() == 'nuts' for sm in methodslist]): + if np.all([sm.lower() == "nuts" for sm in methodslist]): # in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs - self.kwargs['step'] = None + self.kwargs["step"] = None - if len(nuts_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"): - self.kwargs['nuts_kwargs'] = nuts_kwargs + if len(nuts_kwargs) > 0 and StrictVersion( + pymc3.__version__ + ) < StrictVersion("3.7"): + self.kwargs["nuts_kwargs"] = nuts_kwargs elif len(nuts_kwargs) > 0: # add NUTS kwargs to standard kwargs self.kwargs.update(nuts_kwargs) @@ -564,22 +665,27 @@ class Pymc3(MCMCSampler): return self.result def _create_args_and_nuts_kwargs(self, curmethod, nuts_kwargs, step_kwargs): - if curmethod == 'nuts': + if curmethod == "nuts": args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs) else: args = step_kwargs.get(curmethod, {}) return args, nuts_kwargs - def _create_nuts_kwargs(self, curmethod, key, nuts_kwargs, pymc3, step_kwargs, step_methods): - if curmethod == 'nuts': + def _create_nuts_kwargs( + self, curmethod, key, nuts_kwargs, pymc3, step_kwargs, step_methods + ): + if curmethod == "nuts": args, nuts_kwargs = 
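The validation logic here accepts three shapes for the step kwarg: a single method name, a dict keyed by parameter name, or lists of names for compound steps, with anything left unspecified falling back to NUTS. The accepted forms, using placeholder parameter names m and c:

    step = "nuts"                                       # one method for all parameters
    step = {"m": "slice"}                               # per-parameter; others default to NUTS
    step = {"m": "slice", "c": ["metropolis", "nuts"]}  # compound step for c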
self._get_nuts_args(nuts_kwargs, step_kwargs) else: if step_kwargs is not None: args = step_kwargs.get(curmethod, {}) else: args = {} - self.kwargs['step'].append( - pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args)) + self.kwargs["step"].append( + pymc3.__dict__[step_methods[curmethod]]( + vars=[self.pymc3_priors[key]], **args + ) + ) return nuts_kwargs @staticmethod @@ -587,7 +693,7 @@ class Pymc3(MCMCSampler): if nuts_kwargs is not None: args = nuts_kwargs elif step_kwargs is not None: - args = step_kwargs.pop('nuts', {}) + args = step_kwargs.pop("nuts", {}) # add values into nuts_kwargs nuts_kwargs = args else: @@ -618,55 +724,85 @@ class Pymc3(MCMCSampler): # if the prior contains ln_prob method that takes a 'sampler' argument # then try using that lnprobargs = infer_args_from_method(self.priors[key].ln_prob) - if 'sampler' in lnprobargs: + if "sampler" in lnprobargs: try: self.pymc3_priors[key] = self.priors[key].ln_prob(sampler=self) except RuntimeError: - raise RuntimeError(("Problem setting PyMC3 prior for ", - "'{}'".format(key))) + raise RuntimeError( + ("Problem setting PyMC3 prior for ", f"'{key}'") + ) else: # use Prior distribution name distname = self.priors[key].__class__.__name__ if distname in self.prior_map: # check if we have a predefined PyMC3 distribution - if 'pymc3' in self.prior_map[distname] and 'argmap' in self.prior_map[distname]: + if ( + "pymc3" in self.prior_map[distname] + and "argmap" in self.prior_map[distname] + ): # check the required arguments for the PyMC3 distribution - pymc3distname = self.prior_map[distname]['pymc3'] + pymc3distname = self.prior_map[distname]["pymc3"] if pymc3distname not in pymc3.__dict__: - raise ValueError("Prior '{}' is not a known PyMC3 distribution.".format(pymc3distname)) + raise ValueError( + f"Prior '{pymc3distname}' is not a known PyMC3 distribution." + ) - reqargs = infer_args_from_method(pymc3.__dict__[pymc3distname].__init__) + reqargs = infer_args_from_method( + pymc3.__dict__[pymc3distname].__init__ + ) # set keyword arguments priorkwargs = {} - for (targ, parg) in self.prior_map[distname]['argmap'].items(): + for (targ, parg) in self.prior_map[distname][ + "argmap" + ].items(): if hasattr(self.priors[key], targ): if parg in reqargs: - if 'argtransform' in self.prior_map[distname]: - if targ in self.prior_map[distname]['argtransform']: - tfunc = self.prior_map[distname]['argtransform'][targ] + if "argtransform" in self.prior_map[distname]: + if ( + targ + in self.prior_map[distname][ + "argtransform" + ] + ): + tfunc = self.prior_map[distname][ + "argtransform" + ][targ] else: + def tfunc(x): return x + else: + def tfunc(x): return x - priorkwargs[parg] = tfunc(getattr(self.priors[key], targ)) + priorkwargs[parg] = tfunc( + getattr(self.priors[key], targ) + ) else: - raise ValueError("Unknown argument {}".format(parg)) + raise ValueError(f"Unknown argument {parg}") else: if parg in reqargs: priorkwargs[parg] = None - self.pymc3_priors[key] = pymc3.__dict__[pymc3distname](key, **priorkwargs) - elif 'internal' in self.prior_map[distname]: - self.pymc3_priors[key] = self.prior_map[distname]['internal'](key) + self.pymc3_priors[key] = pymc3.__dict__[pymc3distname]( + key, **priorkwargs + ) + elif "internal" in self.prior_map[distname]: + self.pymc3_priors[key] = self.prior_map[distname][ + "internal" + ](key) else: - raise ValueError("Prior '{}' is not a known distribution.".format(distname)) + raise ValueError( + f"Prior '{distname}' is not a known distribution." 
+ ) else: - raise ValueError("Prior '{}' is not a known distribution.".format(distname)) + raise ValueError( + f"Prior '{distname}' is not a known distribution." + ) def set_likelihood(self): """ @@ -692,17 +828,19 @@ class Pymc3(MCMCSampler): if isinstance(self.priors[key], float): self.likelihood.parameters[key] = self.priors[key] - self.logpgrad = LogLikeGrad(self.parameters, self.likelihood, self.priors) + self.logpgrad = LogLikeGrad( + self.parameters, self.likelihood, self.priors + ) def perform(self, node, inputs, outputs): - theta, = inputs + (theta,) = inputs for i, key in enumerate(self.parameters): self.likelihood.parameters[key] = theta[i] outputs[0][0] = np.array(self.likelihood.log_likelihood()) def grad(self, inputs, g): - theta, = inputs + (theta,) = inputs return [g[0] * self.logpgrad(theta)] # create theano Op for calculating the gradient of the log likelihood @@ -723,7 +861,7 @@ class Pymc3(MCMCSampler): self.likelihood.parameters[key] = self.priors[key] def perform(self, node, inputs, outputs): - theta, = inputs + (theta,) = inputs # define version of likelihood function to pass to derivative function def lnlike(values): @@ -732,7 +870,9 @@ class Pymc3(MCMCSampler): return self.likelihood.log_likelihood() # calculate gradients - grads = derivatives(theta, lnlike, abseps=1e-5, mineps=1e-12, reltol=1e-2) + grads = derivatives( + theta, lnlike, abseps=1e-5, mineps=1e-12, reltol=1e-2 + ) outputs[0][0] = grads @@ -740,84 +880,114 @@ # check if it is a predefined likelihood function if isinstance(self.likelihood, GaussianLikelihood): # check required attributes exist - if (not hasattr(self.likelihood, 'sigma') or - not hasattr(self.likelihood, 'x') or - not hasattr(self.likelihood, 'y')): - raise ValueError("Gaussian Likelihood does not have all the correct attributes!") - - if 'sigma' in self.pymc3_priors: + if ( + not hasattr(self.likelihood, "sigma") + or not hasattr(self.likelihood, "x") + or not hasattr(self.likelihood, "y") + ): + raise ValueError( + "Gaussian Likelihood does not have all the correct attributes!" + ) + + if "sigma" in self.pymc3_priors: # if sigma is supplied use that value if self.likelihood.sigma is None: - self.likelihood.sigma = self.pymc3_priors.pop('sigma') + self.likelihood.sigma = self.pymc3_priors.pop("sigma") else: - del self.pymc3_priors['sigma'] + del self.pymc3_priors["sigma"] for key in self.pymc3_priors: if key not in self.likelihood.function_keys: - raise ValueError("Prior key '{}' is not a function key!".format(key)) + raise ValueError(f"Prior key '{key}' is not a function key!") model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors) # set the distribution - pymc3.Normal('likelihood', mu=model, sd=self.likelihood.sigma, - observed=self.likelihood.y) + pymc3.Normal( + "likelihood", + mu=model, + sd=self.likelihood.sigma, + observed=self.likelihood.y, + ) elif isinstance(self.likelihood, PoissonLikelihood): # check required attributes exist - if (not hasattr(self.likelihood, 'x') or - not hasattr(self.likelihood, 'y')): - raise ValueError("Poisson Likelihood does not have all the correct attributes!") + if not hasattr(self.likelihood, "x") or not hasattr( + self.likelihood, "y" + ): + raise ValueError( + "Poisson Likelihood does not have all the correct attributes!"
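set_likelihood translates bilby's predefined likelihoods into native PyMC3 observed distributions, so a GaussianLikelihood ends up as the pymc3.Normal(mu=model, sd=sigma, observed=y) call above. A minimal sketch of a fit exercising this path; the linear model and priors are illustrative:

    import numpy as np
    import bilby

    def line(x, m, c):
        return m * x + c

    x = np.linspace(0, 1, 100)
    y = line(x, 2.0, 1.0) + np.random.normal(0, 0.1, len(x))

    likelihood = bilby.core.likelihood.GaussianLikelihood(x, y, line, sigma=0.1)
    priors = dict(
        m=bilby.core.prior.Uniform(0, 5, "m"),
        c=bilby.core.prior.Uniform(-2, 2, "c"),
    )
    # result = bilby.run_sampler(likelihood, priors, sampler="pymc3", draws=500)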
+ ) for key in self.pymc3_priors: if key not in self.likelihood.function_keys: - raise ValueError("Prior key '{}' is not a function key!".format(key)) + raise ValueError(f"Prior key '{key}' is not a function key!") # get rate function model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors) # set the distribution - pymc3.Poisson('likelihood', mu=model, observed=self.likelihood.y) + pymc3.Poisson("likelihood", mu=model, observed=self.likelihood.y) elif isinstance(self.likelihood, ExponentialLikelihood): # check required attributes exist - if (not hasattr(self.likelihood, 'x') or - not hasattr(self.likelihood, 'y')): - raise ValueError("Exponential Likelihood does not have all the correct attributes!") + if not hasattr(self.likelihood, "x") or not hasattr( + self.likelihood, "y" + ): + raise ValueError( + "Exponential Likelihood does not have all the correct attributes!" + ) for key in self.pymc3_priors: if key not in self.likelihood.function_keys: - raise ValueError("Prior key '{}' is not a function key!".format(key)) + raise ValueError(f"Prior key '{key}' is not a function key!") # get mean function model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors) # set the distribution - pymc3.Exponential('likelihood', lam=1. / model, observed=self.likelihood.y) + pymc3.Exponential( + "likelihood", lam=1.0 / model, observed=self.likelihood.y + ) elif isinstance(self.likelihood, StudentTLikelihood): # check required attributes exist - if (not hasattr(self.likelihood, 'x') or - not hasattr(self.likelihood, 'y') or - not hasattr(self.likelihood, 'nu') or - not hasattr(self.likelihood, 'sigma')): - raise ValueError("StudentT Likelihood does not have all the correct attributes!") - - if 'nu' in self.pymc3_priors: + if ( + not hasattr(self.likelihood, "x") + or not hasattr(self.likelihood, "y") + or not hasattr(self.likelihood, "nu") + or not hasattr(self.likelihood, "sigma") + ): + raise ValueError( + "StudentT Likelihood does not have all the correct attributes!" 
+ ) + + if "nu" in self.pymc3_priors: # if nu is supplied use that value if self.likelihood.nu is None: - self.likelihood.nu = self.pymc3_priors.pop('nu') + self.likelihood.nu = self.pymc3_priors.pop("nu") else: - del self.pymc3_priors['nu'] + del self.pymc3_priors["nu"] for key in self.pymc3_priors: if key not in self.likelihood.function_keys: - raise ValueError("Prior key '{}' is not a function key!".format(key)) + raise ValueError(f"Prior key '{key}' is not a function key!") model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors) # set the distribution - pymc3.StudentT('likelihood', nu=self.likelihood.nu, mu=model, sd=self.likelihood.sigma, - observed=self.likelihood.y) - elif isinstance(self.likelihood, (GravitationalWaveTransient, BasicGravitationalWaveTransient)): + pymc3.StudentT( + "likelihood", + nu=self.likelihood.nu, + mu=model, + sd=self.likelihood.sigma, + observed=self.likelihood.y, + ) + elif isinstance( + self.likelihood, + (GravitationalWaveTransient, BasicGravitationalWaveTransient), + ): # set theano Op - pass _search_parameter_keys, which only contains non-fixed variables - logl = LogLike(self._search_parameter_keys, self.likelihood, self.pymc3_priors) + logl = LogLike( + self._search_parameter_keys, self.likelihood, self.pymc3_priors + ) parameters = dict() for key in self._search_parameter_keys: @@ -825,11 +995,14 @@ class Pymc3(MCMCSampler): parameters[key] = self.pymc3_priors[key] except KeyError: raise KeyError( - "Unknown key '{}' when setting GravitationalWaveTransient likelihood".format(key)) + f"Unknown key '{key}' when setting GravitationalWaveTransient likelihood" + ) # convert to theano tensor variable values = tt.as_tensor_variable(list(parameters.values())) - pymc3.DensityDist('likelihood', lambda v: logl(v), observed={'v': values}) + pymc3.DensityDist( + "likelihood", lambda v: logl(v), observed={"v": values} + ) else: raise ValueError("Unknown likelihood has been provided") diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py index d9869362256ad60c8029acdd0e19020a31dea8f8..da6e7a9778273f0bb25a2702bddcbc2fd18f4ae3 100644 --- a/bilby/core/sampler/pymultinest.py +++ b/bilby/core/sampler/pymultinest.py @@ -1,20 +1,15 @@ +import datetime import importlib import os -import shutil -import distutils.dir_util -import signal import time -import datetime -import sys import numpy as np -from ..utils import check_directory_exists_and_if_not_mkdir from ..utils import logger -from .base_sampler import NestedSampler +from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper -class Pymultinest(NestedSampler): +class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler): """ bilby wrapper of pymultinest (https://github.com/JohannesBuchner/PyMultiNest) @@ -65,6 +60,8 @@ class Pymultinest(NestedSampler): init_MPI=False, dump_callback=None, ) + short_name = "pm" + hard_exit = True def __init__( self, @@ -94,6 +91,7 @@ class Pymultinest(NestedSampler): plot=plot, skip_import_verification=skip_import_verification, exit_code=exit_code, + temporary_directory=temporary_directory, **kwargs ) self._apply_multinest_boundaries() @@ -105,10 +103,6 @@ class Pymultinest(NestedSampler): ) self.use_temporary_directory = temporary_directory and not using_mpi - signal.signal(signal.SIGTERM, self.write_current_state_and_exit) - signal.signal(signal.SIGINT, self.write_current_state_and_exit) - signal.signal(signal.SIGALRM, self.write_current_state_and_exit) - def _translate_kwargs(self, kwargs): if "n_live_points" not in
kwargs: for equiv in self.npoints_equiv_kwargs: @@ -141,74 +135,7 @@ class Pymultinest(NestedSampler): else: self.kwargs["wrapped_params"].append(0) - @property - def outputfiles_basename(self): - return self._outputfiles_basename - - @outputfiles_basename.setter - def outputfiles_basename(self, outputfiles_basename): - if outputfiles_basename is None: - outputfiles_basename = "{}/pm_{}/".format(self.outdir, self.label) - if not outputfiles_basename.endswith("/"): - outputfiles_basename += "/" - check_directory_exists_and_if_not_mkdir(self.outdir) - self._outputfiles_basename = outputfiles_basename - - @property - def temporary_outputfiles_basename(self): - return self._temporary_outputfiles_basename - - @temporary_outputfiles_basename.setter - def temporary_outputfiles_basename(self, temporary_outputfiles_basename): - if not temporary_outputfiles_basename.endswith("/"): - temporary_outputfiles_basename = "{}/".format( - temporary_outputfiles_basename - ) - self._temporary_outputfiles_basename = temporary_outputfiles_basename - if os.path.exists(self.outputfiles_basename): - shutil.copytree( - self.outputfiles_basename, self.temporary_outputfiles_basename - ) - - def write_current_state_and_exit(self, signum=None, frame=None): - """Write current state and exit on exit_code""" - logger.info( - "Run interrupted by signal {}: checkpoint and exit on {}".format( - signum, self.exit_code - ) - ) - self._calculate_and_save_sampling_time() - if self.use_temporary_directory: - self._move_temporary_directory_to_proper_path() - sys.exit(self.exit_code) - - def _copy_temporary_directory_contents_to_proper_path(self): - """ - Copy the temporary back to the proper path. - Do not delete the temporary directory. - """ - logger.info( - "Overwriting {} with {}".format( - self.outputfiles_basename, self.temporary_outputfiles_basename - ) - ) - if self.outputfiles_basename.endswith("/"): - outputfiles_basename_stripped = self.outputfiles_basename[:-1] - else: - outputfiles_basename_stripped = self.outputfiles_basename - distutils.dir_util.copy_tree( - self.temporary_outputfiles_basename, outputfiles_basename_stripped - ) - - def _move_temporary_directory_to_proper_path(self): - """ - Copy the temporary back to the proper path - - Anything in the temporary directory at this point is removed - """ - self._copy_temporary_directory_contents_to_proper_path() - shutil.rmtree(self.temporary_outputfiles_basename) - + @signal_wrapper def run_sampler(self): import pymultinest @@ -247,27 +174,6 @@ class Pymultinest(NestedSampler): self.result.nested_samples = self._nested_samples return self.result - def _check_and_load_sampling_time_file(self): - self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat" - if os.path.exists(self.time_file_path): - with open(self.time_file_path, "r") as time_file: - self.total_sampling_time = float(time_file.readline()) - else: - self.total_sampling_time = 0 - - def _calculate_and_save_sampling_time(self): - current_time = time.time() - new_sampling_time = current_time - self.start_time - self.total_sampling_time += new_sampling_time - self.start_time = current_time - with open(self.time_file_path, "w") as time_file: - time_file.write(str(self.total_sampling_time)) - - def _clean_up_run_directory(self): - if self.use_temporary_directory: - self._move_temporary_directory_to_proper_path() - self.kwargs["outputfiles_basename"] = self.outputfiles_basename - @property def _nested_samples(self): """ diff --git a/bilby/core/sampler/ultranest.py 
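The per-sampler signal.signal calls removed here, and from the other samplers throughout this diff, are consolidated into the signal_wrapper decorator applied to run_sampler. Its body lives in base_sampler.py and is not shown in this diff; a sketch of the pattern it implements, with illustrative details only:

    import signal

    def signal_wrapper(method):
        def wrapped(self, *args, **kwargs):
            try:
                signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
                signal.signal(signal.SIGINT, self.write_current_state_and_exit)
                signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
            except (AttributeError, ValueError):
                # e.g. Windows has no SIGALRM, or we are not in the main thread
                pass
            return method(self, *args, **kwargs)
        return wrapped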
b/bilby/core/sampler/ultranest.py index 2348319e4b6048eb2669e798e94c6a06af9b6e0e..fc70b38ad4f0b04169254e8031c0a996087abdd2 100644 --- a/bilby/core/sampler/ultranest.py +++ b/bilby/core/sampler/ultranest.py @@ -1,20 +1,15 @@ - import datetime -import distutils.dir_util import inspect -import os -import shutil -import signal import time import numpy as np from pandas import DataFrame -from ..utils import check_directory_exists_and_if_not_mkdir, logger -from .base_sampler import NestedSampler +from ..utils import logger +from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper -class Ultranest(NestedSampler): +class Ultranest(_TemporaryFileSamplerMixin, NestedSampler): """ bilby wrapper of ultranest (https://johannesbuchner.github.io/UltraNest/index.html) @@ -73,6 +68,8 @@ class Ultranest(NestedSampler): step_sampler=None, ) + short_name = "ultra" + def __init__( self, likelihood, @@ -96,31 +93,31 @@ class Ultranest(NestedSampler): plot=plot, skip_import_verification=skip_import_verification, exit_code=exit_code, + temporary_directory=temporary_directory, **kwargs, ) self._apply_ultranest_boundaries() - self.use_temporary_directory = temporary_directory if self.use_temporary_directory: # set callback interval, so copying of results does not thrash the # disk (ultranest will call viz_callback quite a lot) self.callback_interval = callback_interval - signal.signal(signal.SIGTERM, self.write_current_state_and_exit) - signal.signal(signal.SIGINT, self.write_current_state_and_exit) - signal.signal(signal.SIGALRM, self.write_current_state_and_exit) - def _translate_kwargs(self, kwargs): if "num_live_points" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: kwargs["num_live_points"] = kwargs.pop(equiv) - if "verbose" in kwargs and "show_status" not in kwargs: kwargs["show_status"] = kwargs.pop("verbose") + resume = kwargs.get("resume", False) + if resume is True: + kwargs["resume"] = "overwrite" + elif resume is False: + kwargs["resume"] = "overwrite" def _verify_kwargs_against_default_kwargs(self): - """ Check the kwargs """ + """Check the kwargs""" self.outputfiles_basename = self.kwargs.pop("log_dir", None) if self.kwargs["viz_callback"] is None: @@ -148,76 +145,13 @@ class Ultranest(NestedSampler): else: self.kwargs["wrapped_params"].append(0) - @property - def outputfiles_basename(self): - return self._outputfiles_basename - - @outputfiles_basename.setter - def outputfiles_basename(self, outputfiles_basename): - if outputfiles_basename is None: - outputfiles_basename = os.path.join( - self.outdir, "ultra_{}/".format(self.label) - ) - if not outputfiles_basename.endswith("/"): - outputfiles_basename += "/" - check_directory_exists_and_if_not_mkdir(self.outdir) - self._outputfiles_basename = outputfiles_basename - - @property - def temporary_outputfiles_basename(self): - return self._temporary_outputfiles_basename - - @temporary_outputfiles_basename.setter - def temporary_outputfiles_basename(self, temporary_outputfiles_basename): - if not temporary_outputfiles_basename.endswith("/"): - temporary_outputfiles_basename = "{}/".format( - temporary_outputfiles_basename - ) - self._temporary_outputfiles_basename = temporary_outputfiles_basename - if os.path.exists(self.outputfiles_basename): - shutil.copytree( - self.outputfiles_basename, self.temporary_outputfiles_basename - ) - - def write_current_state_and_exit(self, signum=None, frame=None): - """ Write current state and exit on exit_code """ - logger.info( - "Run interrupted by signal {}: 
checkpoint and exit on {}".format( - signum, self.exit_code - ) - ) - self._calculate_and_save_sampling_time() - if self.use_temporary_directory: - self._move_temporary_directory_to_proper_path() - os._exit(self.exit_code) - def _copy_temporary_directory_contents_to_proper_path(self): """ Copy the temporary back to the proper path. Do not delete the temporary directory. """ if inspect.stack()[1].function != "_viz_callback": - logger.info( - "Overwriting {} with {}".format( - self.outputfiles_basename, self.temporary_outputfiles_basename - ) - ) - if self.outputfiles_basename.endswith("/"): - outputfiles_basename_stripped = self.outputfiles_basename[:-1] - else: - outputfiles_basename_stripped = self.outputfiles_basename - distutils.dir_util.copy_tree( - self.temporary_outputfiles_basename, outputfiles_basename_stripped - ) - - def _move_temporary_directory_to_proper_path(self): - """ - Move the temporary back to the proper path - - Anything in the proper path at this point is removed including links - """ - self._copy_temporary_directory_contents_to_proper_path() - shutil.rmtree(self.temporary_outputfiles_basename) + super(Ultranest, self)._copy_temporary_directory_contents_to_proper_path() @property def sampler_function_kwargs(self): @@ -271,6 +205,7 @@ class Ultranest(NestedSampler): return init_kwargs + @signal_wrapper def run_sampler(self): import ultranest import ultranest.stepsampler @@ -285,7 +220,7 @@ class Ultranest(NestedSampler): stepsampler = self.kwargs.pop("step_sampler", None) self._setup_run_directory() - self.kwargs["log_dir"] = self.kwargs.pop("outputfiles_basename") + self.kwargs["log_dir"] = self.kwargs["outputfiles_basename"] self._check_and_load_sampling_time_file() # use reactive nested sampler when no live points are given @@ -317,7 +252,6 @@ class Ultranest(NestedSampler): results = sampler.run(**self.sampler_function_kwargs) self._calculate_and_save_sampling_time() - # Clean up self._clean_up_run_directory() self._generate_result(results) @@ -325,27 +259,6 @@ class Ultranest(NestedSampler): return self.result - def _clean_up_run_directory(self): - if self.use_temporary_directory: - self._move_temporary_directory_to_proper_path() - self.kwargs["log_dir"] = self.outputfiles_basename - - def _check_and_load_sampling_time_file(self): - self.time_file_path = os.path.join(self.kwargs["log_dir"], "sampling_time.dat") - if os.path.exists(self.time_file_path): - with open(self.time_file_path, "r") as time_file: - self.total_sampling_time = float(time_file.readline()) - else: - self.total_sampling_time = 0 - - def _calculate_and_save_sampling_time(self): - current_time = time.time() - new_sampling_time = current_time - self.start_time - self.total_sampling_time += new_sampling_time - with open(self.time_file_path, "w") as time_file: - time_file.write(str(self.total_sampling_time)) - self.start_time = current_time - def _generate_result(self, out): # extract results data = np.array(out["weighted_samples"]["points"]) @@ -357,16 +270,22 @@ class Ultranest(NestedSampler): nested_samples = DataFrame(data, columns=self.search_parameter_keys) nested_samples["weights"] = weights nested_samples["log_likelihood"] = out["weighted_samples"]["logl"] - self.result.log_likelihood_evaluations = np.array(out["weighted_samples"]["logl"])[ - mask - ] + self.result.log_likelihood_evaluations = np.array( + out["weighted_samples"]["logl"] + )[mask] self.result.sampler_output = out self.result.samples = data[mask, :] self.result.nested_samples = nested_samples self.result.log_evidence = 
out["logz"] self.result.log_evidence_err = out["logzerr"] if self.kwargs["num_live_points"] is not None: - self.result.information_gain = np.power(out["logzerr"], 2) * self.kwargs["num_live_points"] + self.result.information_gain = ( + np.power(out["logzerr"], 2) * self.kwargs["num_live_points"] + ) self.result.outputfiles_basename = self.outputfiles_basename self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time) + + def log_likelihood(self, theta): + log_l = super(Ultranest, self).log_likelihood(theta=theta) + return np.nan_to_num(log_l) diff --git a/bilby/core/sampler/zeus.py b/bilby/core/sampler/zeus.py index 78c3529ea00c1c9ee518a8698e2f70bc29fe194d..c7ae40da222201e5b29c53635c16c3edc94744f0 100644 --- a/bilby/core/sampler/zeus.py +++ b/bilby/core/sampler/zeus.py @@ -1,18 +1,17 @@ import os -import signal import shutil -import sys -from collections import namedtuple from shutil import copyfile import numpy as np -from pandas import DataFrame -from ..utils import logger, check_directory_exists_and_if_not_mkdir -from .base_sampler import MCMCSampler, SamplerError +from .base_sampler import SamplerError, signal_wrapper +from .emcee import Emcee +from .ptemcee import LikePriorEvaluator +_evaluator = LikePriorEvaluator() -class Zeus(MCMCSampler): + +class Zeus(Emcee): """bilby wrapper for Zeus (https://zeus-mcmc.readthedocs.io/) All positional and keyword arguments (i.e., the args and kwargs) passed to @@ -65,12 +64,8 @@ class Zeus(MCMCSampler): burn_in_fraction=0.25, resume=True, burn_in_act=3, - **kwargs + **kwargs, ): - import zeus - - self.zeus = zeus - super(Zeus, self).__init__( likelihood=likelihood, priors=priors, @@ -79,25 +74,16 @@ class Zeus(MCMCSampler): use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification, - **kwargs + pos0=pos0, + nburn=nburn, + burn_in_fraction=burn_in_fraction, + resume=resume, + burn_in_act=burn_in_act, + **kwargs, ) - self.resume = resume - self.pos0 = pos0 - self.nburn = nburn - self.burn_in_fraction = burn_in_fraction - self.burn_in_act = burn_in_act - - signal.signal(signal.SIGTERM, self.checkpoint_and_exit) - signal.signal(signal.SIGINT, self.checkpoint_and_exit) def _translate_kwargs(self, kwargs): - if "nwalkers" not in kwargs: - for equiv in self.nwalkers_equiv_kwargs: - if equiv in kwargs: - kwargs["nwalkers"] = kwargs.pop(equiv) - if "iterations" not in kwargs: - if "nsteps" in kwargs: - kwargs["iterations"] = kwargs.pop("nsteps") + super(Zeus, self)._translate_kwargs(kwargs=kwargs) # check if using emcee-style arguments if "start" not in kwargs: @@ -107,17 +93,6 @@ class Zeus(MCMCSampler): if "lnprob0" in kwargs: kwargs["log_prob0"] = kwargs.pop("lnprob0") - if "threads" in kwargs: - if kwargs["threads"] != 1: - logger.warning( - "The 'threads' argument cannot be used for " - "parallelisation. This run will proceed " - "without parallelisation, but consider the use " - "of an appropriate Pool object passed to the " - "'pool' keyword." 
- ) - kwargs["threads"] = 1 - @property def sampler_function_kwargs(self): keys = ["log_prob0", "start", "blobs0", "iterations", "thin", "progress"] @@ -134,168 +109,21 @@ class Zeus(MCMCSampler): if key not in self.sampler_function_kwargs } - init_kwargs["logprob_fn"] = self.lnpostfn + init_kwargs["logprob_fn"] = _evaluator.call_emcee init_kwargs["ndim"] = self.ndim return init_kwargs - def lnpostfn(self, theta): - log_prior = self.log_prior(theta) - if np.isinf(log_prior): - return -np.inf, [np.nan, np.nan] - else: - log_likelihood = self.log_likelihood(theta) - return log_likelihood + log_prior, [log_likelihood, log_prior] - - @property - def nburn(self): - if type(self.__nburn) in [float, int]: - return int(self.__nburn) - elif self.result.max_autocorrelation_time is None: - return int(self.burn_in_fraction * self.nsteps) - else: - return int(self.burn_in_act * self.result.max_autocorrelation_time) - - @nburn.setter - def nburn(self, nburn): - if isinstance(nburn, (float, int)): - if nburn > self.kwargs["iterations"] - 1: - raise ValueError( - "Number of burn-in samples must be smaller " - "than the total number of iterations" - ) - - self.__nburn = nburn - - @property - def nwalkers(self): - return self.kwargs["nwalkers"] - - @property - def nsteps(self): - return self.kwargs["iterations"] - - @nsteps.setter - def nsteps(self, nsteps): - self.kwargs["iterations"] = nsteps - - @property - def stored_chain(self): - """Read the stored zero-temperature chain data in from disk""" - return np.genfromtxt(self.checkpoint_info.chain_file, names=True) - - @property - def stored_samples(self): - """Returns the samples stored on disk""" - return self.stored_chain[self.search_parameter_keys] - - @property - def stored_loglike(self): - """Returns the log-likelihood stored on disk""" - return self.stored_chain["log_l"] - - @property - def stored_logprior(self): - """Returns the log-prior stored on disk""" - return self.stored_chain["log_p"] - - def _init_chain_file(self): - with open(self.checkpoint_info.chain_file, "w+") as ff: - ff.write( - "walker\t{}\tlog_l\tlog_p\n".format( - "\t".join(self.search_parameter_keys) - ) - ) - - @property - def checkpoint_info(self): - """Defines various things related to checkpointing and storing data - - Returns - ======= - checkpoint_info: named_tuple - An object with attributes `sampler_file`, `chain_file`, and - `chain_template`. 
The first two give paths to where the sampler and - chain data is stored, the last a formatted-str-template with which - to write the chain data to disk - - """ - out_dir = os.path.join( - self.outdir, "{}_{}".format(self.__class__.__name__.lower(), self.label) - ) - check_directory_exists_and_if_not_mkdir(out_dir) - - chain_file = os.path.join(out_dir, "chain.dat") - sampler_file = os.path.join(out_dir, "sampler.pickle") - chain_template = ( - "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n" - ) - - CheckpointInfo = namedtuple( - "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"] - ) - - checkpoint_info = CheckpointInfo( - sampler_file=sampler_file, - chain_file=chain_file, - chain_template=chain_template, - ) - - return checkpoint_info - - @property - def sampler_chain(self): - nsteps = self._previous_iterations - return self.sampler.chain[:, :nsteps, :] - - def checkpoint(self): - """Writes a pickle file of the sampler to disk using dill""" - import dill - - logger.info( - "Checkpointing sampler to file {}".format(self.checkpoint_info.sampler_file) - ) - with open(self.checkpoint_info.sampler_file, "wb") as f: - # Overwrites the stored sampler chain with one that is truncated - # to only the completed steps - self.sampler._chain = self.sampler_chain - dill.dump(self._sampler, f) - - def checkpoint_and_exit(self, signum, frame): - logger.info("Received signal {}".format(signum)) - self.checkpoint() - sys.exit() + def write_current_state(self): + self._sampler.distribute = map + super(Zeus, self).write_current_state() + self._sampler.distribute = getattr(self._sampler.pool, "map", map) def _initialise_sampler(self): - self._sampler = self.zeus.EnsembleSampler(**self.sampler_init_kwargs) - self._init_chain_file() + from zeus import EnsembleSampler - @property - def sampler(self): - """Returns the Zeus sampler object - - If, already initialized, returns the stored _sampler value. Otherwise, - first checks if there is a pickle file from which to load. 
If there is - not, then initialize the sampler and set the initial random draw - - """ - if hasattr(self, "_sampler"): - pass - elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file): - import dill - - logger.info( - "Resuming run from checkpoint file {}".format( - self.checkpoint_info.sampler_file - ) - ) - with open(self.checkpoint_info.sampler_file, "rb") as f: - self._sampler = dill.load(f) - self._set_pos0_for_resume() - else: - self._initialise_sampler() - self._set_pos0() - return self._sampler + self._sampler = EnsembleSampler(**self.sampler_init_kwargs) + self._init_chain_file() def write_chains_to_file(self, sample): chain_file = self.checkpoint_info.chain_file @@ -310,48 +138,12 @@ class Zeus(MCMCSampler): ff.write(self.checkpoint_info.chain_template.format(ii, *point)) shutil.move(temp_chain_file, chain_file) - @property - def _previous_iterations(self): - """Returns the number of iterations that the sampler has saved - - This is used when loading in a sampler from a pickle file to figure out - how much of the run has already been completed - """ - try: - return len(self.sampler.get_blobs()) - except AttributeError: - return 0 - - def _draw_pos0_from_prior(self): - return np.array( - [self.get_random_draw_from_prior() for _ in range(self.nwalkers)] - ) - - @property - def _pos0_shape(self): - return (self.nwalkers, self.ndim) - - def _set_pos0(self): - if self.pos0 is not None: - logger.debug("Using given initial positions for walkers") - if isinstance(self.pos0, DataFrame): - self.pos0 = self.pos0[self.search_parameter_keys].values - elif type(self.pos0) in (list, np.ndarray): - self.pos0 = np.squeeze(self.pos0) - - if self.pos0.shape != self._pos0_shape: - raise ValueError("Input pos0 should be of shape ndim, nwalkers") - logger.debug("Checking input pos0") - for draw in self.pos0: - self.check_draw(draw) - else: - logger.debug("Generating initial walker positions from prior") - self.pos0 = self._draw_pos0_from_prior() - def _set_pos0_for_resume(self): self.pos0 = self.sampler.get_last_sample() + @signal_wrapper def run_sampler(self): + self._setup_pool() sampler_function_kwargs = self.sampler_function_kwargs iterations = sampler_function_kwargs.pop("iterations") iterations -= self._previous_iterations @@ -363,7 +155,8 @@ class Zeus(MCMCSampler): iterations=iterations, **sampler_function_kwargs ): self.write_chains_to_file(sample) - self.checkpoint() + self._close_pool() + self.write_current_state() self.result.sampler_output = np.nan self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim))) @@ -381,10 +174,12 @@ class Zeus(MCMCSampler): if self.result.nburn > self.nsteps: raise SamplerError( "The run has finished, but the chain is not burned in: " - "`nburn < nsteps` ({} < {}). Try increasing the " - "number of steps.".format(self.result.nburn, self.nsteps) + f"`nburn < nsteps` ({self.result.nburn} < {self.nsteps})." + " Try increasing the number of steps." 
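The write_current_state override above is a pickling workaround: zeus's EnsembleSampler keeps its parallelisation callable in a distribute attribute, and a multiprocessing pool's bound map cannot be pickled, so the builtin map is swapped in while the checkpoint is written and the pool's map is restored afterwards. The same pattern in isolation looks like the sketch below; the dill-based dump and the file name are assumptions here, mirroring the inherited checkpointing.

    # Sketch: checkpoint an object that holds a reference to pool.map
    import dill

    def checkpoint_without_pool(sampler, filename="sampler.pickle"):
        parallel_map = sampler.distribute   # typically pool.map
        sampler.distribute = map            # builtin; pickles by reference
        try:
            with open(filename, "wb") as f:
                dill.dump(sampler, f)
        finally:
            sampler.distribute = parallel_map  # restore parallel execution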
) - blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape((-1, 2)) + blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape( + (-1, 2) + ) log_likelihoods, log_priors = blobs.T self.result.log_likelihood_evaluations = log_likelihoods self.result.log_prior_evaluations = log_priors diff --git a/containers/dockerfile-template b/containers/dockerfile-template index a594a9a996c33e50a2c873e42585f6269c3eda12..f4eca79d814edd4ecb9a6ba2e2fbd31bbeafa2ad 100644 --- a/containers/dockerfile-template +++ b/containers/dockerfile-template @@ -33,6 +33,7 @@ RUN pip install corner healpy cython tables RUN conda install -n ${{conda_env}} -c conda-forge dynesty emcee nestle ptemcee RUN conda install -n ${{conda_env}} -c conda-forge pymultinest ultranest RUN conda install -n ${{conda_env}} -c conda-forge cpnest kombine dnest4 zeus-mcmc +RUN conda install -n ${{conda_env}} -c conda-forge ptmcmcsampler RUN conda install -n ${{conda_env}} -c conda-forge pytorch RUN conda install -n ${{conda_env}} -c conda-forge theano-pymc RUN conda install -n ${{conda_env}} -c conda-forge pymc3 @@ -49,10 +50,6 @@ RUN apt-get install -y gfortran RUN git clone https://github.com/PolyChord/PolyChordLite.git \ && (cd PolyChordLite && python setup.py --no-mpi install) -# Install PTMCMCSampler -RUN git clone https://github.com/jellis18/PTMCMCSampler.git \ -&& (cd PTMCMCSampler && python setup.py install) - # Install GW packages RUN conda install -n ${{conda_env}} -c conda-forge python-lalsimulation bilby.cython RUN pip install ligo-gracedb gwpy ligo.skymap diff --git a/containers/v3-dockerfile-test-suite-python38 b/containers/v3-dockerfile-test-suite-python38 index 04452eac65334c2b813219d66033802a98ed7215..4ee250e8b5792e2c96c78342a974c6a4fd8b97d9 100644 --- a/containers/v3-dockerfile-test-suite-python38 +++ b/containers/v3-dockerfile-test-suite-python38 @@ -35,6 +35,7 @@ RUN pip install corner healpy cython tables RUN conda install -n ${conda_env} -c conda-forge dynesty emcee nestle ptemcee RUN conda install -n ${conda_env} -c conda-forge pymultinest ultranest RUN conda install -n ${conda_env} -c conda-forge cpnest kombine dnest4 zeus-mcmc +RUN conda install -n ${conda_env} -c conda-forge ptmcmcsampler RUN conda install -n ${conda_env} -c conda-forge pytorch RUN conda install -n ${conda_env} -c conda-forge theano-pymc RUN conda install -n ${conda_env} -c conda-forge pymc3 @@ -51,10 +52,6 @@ RUN apt-get install -y gfortran RUN git clone https://github.com/PolyChord/PolyChordLite.git \ && (cd PolyChordLite && python setup.py --no-mpi install) -# Install PTMCMCSampler -RUN git clone https://github.com/jellis18/PTMCMCSampler.git \ -&& (cd PTMCMCSampler && python setup.py install) - # Install GW packages RUN conda install -n ${conda_env} -c conda-forge python-lalsimulation bilby.cython RUN pip install ligo-gracedb gwpy ligo.skymap diff --git a/containers/v3-dockerfile-test-suite-python39 b/containers/v3-dockerfile-test-suite-python39 index abaa39f235771b5489e2a8c647072599e7a93e1d..218b3d3e24b3b65b278a95958ffb0444b4f85d7c 100644 --- a/containers/v3-dockerfile-test-suite-python39 +++ b/containers/v3-dockerfile-test-suite-python39 @@ -35,6 +35,7 @@ RUN pip install corner healpy cython tables RUN conda install -n ${conda_env} -c conda-forge dynesty emcee nestle ptemcee RUN conda install -n ${conda_env} -c conda-forge pymultinest ultranest RUN conda install -n ${conda_env} -c conda-forge cpnest kombine dnest4 zeus-mcmc +RUN conda install -n ${conda_env} -c conda-forge ptmcmcsampler RUN conda 
install -n ${conda_env} -c conda-forge pytorch RUN conda install -n ${conda_env} -c conda-forge theano-pymc RUN conda install -n ${conda_env} -c conda-forge pymc3 @@ -51,10 +52,6 @@ RUN apt-get install -y gfortran RUN git clone https://github.com/PolyChord/PolyChordLite.git \ && (cd PolyChordLite && python setup.py --no-mpi install) -# Install PTMCMCSampler -RUN git clone https://github.com/jellis18/PTMCMCSampler.git \ -&& (cd PTMCMCSampler && python setup.py install) - # Install GW packages RUN conda install -n ${conda_env} -c conda-forge python-lalsimulation bilby.cython RUN pip install ligo-gracedb gwpy ligo.skymap diff --git a/test/bilby_mcmc/test_sampler.py b/test/bilby_mcmc/test_sampler.py index aa52967da16cbfbf754279deed3c83139632d1e7..746eb1a9e1150732e93d5c31664751040e7b639c 100644 --- a/test/bilby_mcmc/test_sampler.py +++ b/test/bilby_mcmc/test_sampler.py @@ -3,7 +3,7 @@ import shutil import unittest import bilby -from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler, _initialize_global_variables +from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler from bilby.bilby_mcmc.utils import ConvergenceInputs from bilby.core.sampler.base_sampler import SamplerError import numpy as np @@ -44,7 +44,12 @@ class TestBilbyMCMCSampler(unittest.TestCase): search_parameter_keys = ['m', 'c'] use_ratio = False - _initialize_global_variables(likelihood, priors, search_parameter_keys, use_ratio) + bilby.core.sampler.base_sampler._initialize_global_variables( + likelihood, + priors, + search_parameter_keys, + use_ratio, + ) def tearDown(self): if os.path.isdir(self.outdir): diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py index 30be5e2ba542205d4cdbef10c6f8e9d681bcbbf1..3a1059e0dd82ac27b1897713b6b40f9f702bf644 100644 --- a/test/core/sampler/base_sampler_test.py +++ b/test/core/sampler/base_sampler_test.py @@ -1,7 +1,9 @@ import copy import os +import shutil import unittest from unittest.mock import MagicMock +from parameterized import parameterized import numpy as np @@ -102,22 +104,100 @@ class TestSampler(unittest.TestCase): self.sampler._check_bad_value(val=np.nan, warning=False, theta=None, label=None) def test_bad_value_np_abs_nan(self): - self.sampler._check_bad_value(val=np.abs(np.nan), warning=False, theta=None, label=None) + self.sampler._check_bad_value( + val=np.abs(np.nan), warning=False, theta=None, label=None + ) def test_bad_value_abs_nan(self): - self.sampler._check_bad_value(val=abs(np.nan), warning=False, theta=None, label=None) + self.sampler._check_bad_value( + val=abs(np.nan), warning=False, theta=None, label=None + ) def test_bad_value_pos_inf(self): self.sampler._check_bad_value(val=np.inf, warning=False, theta=None, label=None) def test_bad_value_neg_inf(self): - self.sampler._check_bad_value(val=-np.inf, warning=False, theta=None, label=None) + self.sampler._check_bad_value( + val=-np.inf, warning=False, theta=None, label=None + ) def test_bad_value_pos_inf_nan_to_num(self): - self.sampler._check_bad_value(val=np.nan_to_num(np.inf), warning=False, theta=None, label=None) + self.sampler._check_bad_value( + val=np.nan_to_num(np.inf), warning=False, theta=None, label=None + ) def test_bad_value_neg_inf_nan_to_num(self): - self.sampler._check_bad_value(val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None) + self.sampler._check_bad_value( + val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None + ) + + +samplers = [ + "bilby_mcmc", + "dynamic_dynesty", + "dynesty", + "emcee", + "kombine", + 
"ptemcee", + "zeus", +] + + +class GenericSamplerTest(unittest.TestCase): + def setUp(self): + self.likelihood = bilby.core.likelihood.Likelihood(dict()) + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) + ) + + def tearDown(self): + if os.path.isdir("outdir"): + shutil.rmtree("outdir") + + @parameterized.expand(samplers) + def test_pool_creates_properly_no_pool(self, sampler_name): + sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler_name]( + self.likelihood, self.priors + ) + sampler._setup_pool() + if sampler_name == "kombine": + from kombine import SerialPool + + self.assertIsInstance(sampler.pool, SerialPool) + pass + else: + self.assertIsNone(sampler.pool) + + @parameterized.expand(samplers) + def test_pool_creates_properly_pool(self, sampler): + sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler]( + self.likelihood, self.priors, npool=2 + ) + sampler._setup_pool() + if hasattr(sampler, "setup_sampler"): + sampler.setup_sampler() + self.assertEqual(sampler.pool._processes, 2) + sampler._close_pool() + + +class ReorderLikelihoodsTest(unittest.TestCase): + def setUp(self): + self.unsorted_ln_likelihoods = np.array([1, 5, 2, 5, 1]) + self.unsorted_samples = np.array([[0, 1], [1, 1], [1, 0], [0, 0], [0, 1]]) + self.sorted_samples = np.array([[0, 1], [0, 1], [1, 0], [1, 1], [0, 0]]) + self.sorted_ln_likelihoods = np.array([1, 1, 2, 5, 5]) + + def tearDown(self): + pass + + def test_ordering(self): + func = bilby.core.sampler.base_sampler.NestedSampler.reorder_loglikelihoods + sorted_ln_likelihoods = func( + self.unsorted_ln_likelihoods, self.unsorted_samples, self.sorted_samples + ) + self.assertTrue( + np.array_equal(sorted_ln_likelihoods, self.sorted_ln_likelihoods) + ) if __name__ == "__main__": diff --git a/test/core/sampler/ultranest_test.py b/test/core/sampler/ultranest_test.py index dc578cd71932c877f0de8414361781cc86837789..be22c1a1f50b8d304000fcb8d0e4816e57c9c1b9 100644 --- a/test/core/sampler/ultranest_test.py +++ b/test/core/sampler/ultranest_test.py @@ -28,7 +28,7 @@ class TestUltranest(unittest.TestCase): def test_default_kwargs(self): expected = dict( - resume=True, + resume="overwrite", show_status=True, num_live_points=None, wrapped_params=None, @@ -63,7 +63,7 @@ class TestUltranest(unittest.TestCase): def test_translate_kwargs(self): expected = dict( - resume=True, + resume="overwrite", show_status=True, num_live_points=123, wrapped_params=None, diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py index 2bf1d355e03cf48a91a3f9ff29b97ec47f10264e..3aa2157c04a6ad5d4008fde4a60710225e59bcb3 100644 --- a/test/integration/sampler_run_test.py +++ b/test/integration/sampler_run_test.py @@ -1,34 +1,100 @@ +import multiprocessing +import os +import sys +import threading +import time +from signal import SIGINT + +multiprocessing.set_start_method("fork") # noqa + import unittest import pytest +from parameterized import parameterized import shutil import bilby import numpy as np +_sampler_kwargs = dict( + bilby_mcmc=dict(nsamples=200, printdt=1), + cpnest=dict(nlive=100), + dnest4=dict( + max_num_levels=2, + num_steps=10, + new_level_interval=10, + num_per_step=10, + thread_steps=1, + num_particles=50, + max_pool=1, + ), + dynesty=dict(nlive=100), + dynamic_dynesty=dict( + nlive_init=100, + nlive_batch=100, + dlogz_init=1.0, + maxbatch=0, + maxcall=100, + bound="single", + ), + emcee=dict(iterations=1000, nwalkers=10), + kombine=dict(iterations=200, nwalkers=10, 
autoburnin=False), + nessai=dict( + nlive=100, + poolsize=1000, + max_iteration=1000, + max_threads=3, + ), + nestle=dict(nlive=100), + ptemcee=dict( + nsamples=100, + nwalkers=50, + burn_in_act=1, + ntemps=1, + frac_threshold=0.5, + ), + PTMCMCSampler=dict(Niter=101, burn=2, isave=100), + # pymc3=dict(draws=50, tune=50, n_init=250), removed until testing issue can be resolved + pymultinest=dict(nlive=100), + pypolychord=dict(nlive=100), + ultranest=dict(nlive=100, temporary_directory=False), +) + +sampler_imports = dict( + bilby_mcmc="bilby", + dynamic_dynesty="dynesty" +) + +no_pool_test = ["dnest4", "pymultinest", "nestle", "ptmcmcsampler", "pypolychord", "ultranest"] + + +def slow_func(x, m, c): + time.sleep(0.01) + return m * x + c + + +def model(x, m, c): + return m * x + c + + class TestRunningSamplers(unittest.TestCase): def setUp(self): np.random.seed(42) bilby.core.utils.command_line_args.bilby_test_mode = False self.x = np.linspace(0, 1, 11) - self.model = lambda x, m, c: m * x + c self.injection_parameters = dict(m=0.5, c=0.2) self.sigma = 0.1 - self.y = self.model(self.x, **self.injection_parameters) + np.random.normal( + self.y = model(self.x, **self.injection_parameters) + np.random.normal( 0, self.sigma, len(self.x) ) self.likelihood = bilby.likelihood.GaussianLikelihood( - self.x, self.y, self.model, self.sigma + self.x, self.y, model, self.sigma ) self.priors = bilby.core.prior.PriorDict() self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic") self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective") - self.kwargs = dict( - save=False, - conversion_function=self.conversion_function, - verbose=True, - ) + self._remove_tree() bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir") @staticmethod @@ -42,226 +108,83 @@ class TestRunningSamplers(unittest.TestCase): del self.likelihood del self.priors bilby.core.utils.command_line_args.bilby_test_mode = False - shutil.rmtree("outdir") - - def test_run_cpnest(self): - pytest.importorskip("cpnest") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="cpnest", - nlive=100, - resume=False, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_dnest4(self): - pytest.importorskip("dnest4") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="dnest4", - max_num_levels=2, - num_steps=10, - new_level_interval=10, - num_per_step=10, - thread_steps=1, - num_particles=50, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_dynesty(self): - pytest.importorskip("dynesty") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="dynesty", - nlive=100, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_dynamic_dynesty(self): - pytest.importorskip("dynesty") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="dynamic_dynesty", - nlive_init=100, - nlive_batch=100, - dlogz_init=1.0, - maxbatch=0, - maxcall=100, - bound="single", - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_emcee(self): - pytest.importorskip("emcee") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="emcee", - iterations=1000, - nwalkers=10, - **self.kwargs, - ) 
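The rewritten integration-test module calls multiprocessing.set_start_method("fork") before anything else (see the top of the sampler_run_test.py diff above). That line is load-bearing: pooled samplers rely on worker processes inheriting module-level state such as the likelihood and priors stashed by _initialize_global_variables, and only the fork start method provides that inheritance; under the spawn default on macOS, each worker re-imports the module and sees the pristine defaults. A self-contained, POSIX-only illustration with hypothetical names:

    # Forked workers inherit module state; spawned workers re-import
    # the module and would see the default None instead.
    import multiprocessing

    _state = {"likelihood": None}

    def worker(_):
        return _state["likelihood"]

    if __name__ == "__main__":
        _state["likelihood"] = "configured"
        with multiprocessing.get_context("fork").Pool(2) as pool:
            print(pool.map(worker, [0, 1]))  # ['configured', 'configured']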
- assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_kombine(self): - pytest.importorskip("kombine") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="kombine", - iterations=2000, - nwalkers=20, - autoburnin=False, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_nestle(self): - pytest.importorskip("nestle") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="nestle", - nlive=100, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_nessai(self): - pytest.importorskip("nessai") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="nessai", - nlive=100, - poolsize=1000, - max_iteration=1000, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_pypolychord(self): - pytest.importorskip("pypolychord") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="pypolychord", - nlive=100, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_ptemcee(self): - pytest.importorskip("ptemcee") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="ptemcee", - nsamples=100, - nwalkers=50, - burn_in_act=1, - ntemps=1, - frac_threshold=0.5, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - @pytest.mark.xfail( - raises=AttributeError, - reason="Dependency issue with pymc3 causes attribute error on import", - ) - def test_run_pymc3(self): - pytest.importorskip("pymc3") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="pymc3", - draws=50, - tune=50, - n_init=250, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_pymultinest(self): - pytest.importorskip("pymultinest") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="pymultinest", - nlive=100, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_PTMCMCSampler(self): - pytest.importorskip("PTMCMCSampler") - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="PTMCMCsampler", - Niter=101, - burn=2, - isave=100, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_ultranest(self): - pytest.importorskip("ultranest") - # run using NestedSampler (with nlive specified) - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="ultranest", - nlive=100, - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - # run using ReactiveNestedSampler (with no nlive given) - res = bilby.run_sampler( - likelihood=self.likelihood, - priors=self.priors, - sampler="ultranest", - **self.kwargs, - ) - assert "derived" in res.posterior - assert res.log_likelihood_evaluations is not None - - def test_run_bilby_mcmc(self): + self._remove_tree() + + def _remove_tree(self): + try: + shutil.rmtree("outdir") + except OSError: + pass + + @parameterized.expand(_sampler_kwargs.keys()) + def 
test_run_sampler_single(self, sampler):
+        self._run_sampler(sampler, pool_size=1)
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_run_sampler_pool(self, sampler):
+        self._run_sampler(sampler, pool_size=2)
+
+    def _run_sampler(self, sampler, pool_size, **extra_kwargs):
+        pytest.importorskip(sampler_imports.get(sampler, sampler))
+        if pool_size > 1 and sampler.lower() in no_pool_test:
+            pytest.skip(f"{sampler} cannot be parallelized")
+        bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
+        kwargs = _sampler_kwargs[sampler]
         res = bilby.run_sampler(
             likelihood=self.likelihood,
             priors=self.priors,
-            sampler="bilby_mcmc",
-            nsamples=200,
-            **self.kwargs,
-            printdt=1,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
+            sampler=sampler,
+            save=False,
+            npool=pool_size,
+            conversion_function=self.conversion_function,
+            **kwargs,
+            **extra_kwargs,
+        )
+        assert "derived" in res.posterior
+        assert res.log_likelihood_evaluations is not None
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_interrupt_sampler_single(self, sampler):
+        self._run_with_signal_handling(sampler, pool_size=1)
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_interrupt_sampler_pool(self, sampler):
+        self._run_with_signal_handling(sampler, pool_size=2)
+
+    def _run_with_signal_handling(self, sampler, pool_size=1):
+        pytest.importorskip(sampler_imports.get(sampler, sampler))
+        if bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler.lower()].hard_exit:
+            pytest.skip(f"{sampler} hard exits, can't test signal handling.")
+        if pool_size > 1 and sampler.lower() in no_pool_test:
+            pytest.skip(f"{sampler} cannot be parallelized")
+        if sys.version_info.minor == 8 and sampler.lower() == "cpnest":
+            pytest.skip("Pool interrupting broken for cpnest with py3.8")
+        if sampler.lower() == "nessai" and pool_size > 1:
+            pytest.skip(
+                "Interrupting with a pool is failing in pytest. "
+                "Likely due to interactions with the signal handling in nessai."
+            )
+        pid = os.getpid()
+        print(sampler)
+
+        def trigger_signal():
+            # more robustly, one could wait until sampling has actually begun
+            time.sleep(4)
+            os.kill(pid, SIGINT)
+
+        thread = threading.Thread(target=trigger_signal)
+        thread.daemon = True
+        thread.start()
+
+        self.likelihood._func = slow_func
+
+        with self.assertRaises((SystemExit, KeyboardInterrupt)):
+            try:
+                while True:
+                    self._run_sampler(sampler=sampler, pool_size=pool_size, exit_code=5)
+            except SystemExit as error:
+                self.assertEqual(error.code, 5)
+                raise


 if __name__ == "__main__":
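Stripped of the bilby specifics, the interrupt tests above follow one reusable pattern: a daemon thread delivers SIGINT to the current process after a short delay while the main thread runs the long job and asserts that it exits with the expected code. A standalone sketch, with an arbitrary delay and a busy loop standing in for a sampler run:

    # Standalone sketch of the signal-interrupt test pattern used above
    import os
    import signal
    import threading
    import time

    def send_sigint_after(delay):
        def trigger():
            time.sleep(delay)
            os.kill(os.getpid(), signal.SIGINT)
        threading.Thread(target=trigger, daemon=True).start()

    if __name__ == "__main__":
        send_sigint_after(1.0)
        try:
            while True:            # stand-in for a long sampler run
                time.sleep(0.1)
        except KeyboardInterrupt:
            print("caught interrupt; a sampler would checkpoint here")

Run directly, this prints the caught-interrupt message after about a second; in the tests the same delivery instead exercises each sampler's checkpoint-and-exit path.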