-
Matthew David Pitkin authoredMatthew David Pitkin authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
ultranest.py 11.45 KiB
from __future__ import absolute_import
import os
import shutil
import signal
import tempfile
import numpy as np
from pandas import DataFrame
from ..utils import check_directory_exists_and_if_not_mkdir, logger
from .base_sampler import NestedSampler
class Ultranest(NestedSampler):
"""
bilby wrapper of ultranest
(https://johannesbuchner.github.io/UltraNest/index.html)
All positional and keyword arguments (i.e., the args and kwargs) passed to
`run_sampler` will be propagated to `ultranest.ReactiveNestedSampler.run`
or `ultranest.NestedSampler.run`, see documentation for those classes for
further help. Under Other Parameters, we list commonly used kwargs and the
bilby defaults. If the number of live points is specified the
`ultranest.NestedSampler` will be used, otherwise the
`ultranest.ReactiveNestedSampler` will be used.
Other Parameters
----------------
num_live_points: int
The number of live points, note this can also equivalently be given as
one of [nlive, nlives, n_live_points, num_live_points]. If not given
then the `ultranest.ReactiveNestedSampler` will be used, which does not
require the number of live points to be specified.
show_status: Bool
If true, print information information about the convergence during
resume: bool
If true, resume run from checkpoint (if available)
step_sampler:
An UltraNest step sampler object. This defaults to None, so the default
stepping behaviour is used.
"""
default_kwargs = dict(
resume=True,
show_status=True,
num_live_points=None,
wrapped_params=None,
log_dir=None,
derived_param_names=[],
run_num=None,
vectorized=False,
num_test_samples=2,
draw_multiple=True,
num_bootstraps=30,
update_interval_iter=None,
update_interval_ncall=None,
log_interval=None,
dlogz=None,
max_iters=None,
update_interval_iter_fraction=0.2,
viz_callback="auto",
dKL=0.5,
frac_remain=0.01,
Lepsilon=0.001,
min_ess=400,
max_ncalls=None,
max_num_improvement_loops=-1,
min_num_live_points=400,
cluster_num_live_points=40,
step_sampler=None,
)
def __init__(
self,
likelihood,
priors,
outdir="outdir",
label="label",
use_ratio=False,
plot=False,
exit_code=77,
skip_import_verification=False,
**kwargs,
):
super(Ultranest, self).__init__(
likelihood=likelihood,
priors=priors,
outdir=outdir,
label=label,
use_ratio=use_ratio,
plot=plot,
skip_import_verification=skip_import_verification,
exit_code=exit_code,
**kwargs,
)
self._apply_ultranest_boundaries()
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
def _translate_kwargs(self, kwargs):
if "num_live_points" not in kwargs:
for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs:
kwargs["num_live_points"] = kwargs.pop(equiv)
if "verbose" in kwargs and "show_status" not in kwargs:
kwargs["show_status"] = kwargs.pop("verbose")
def _verify_kwargs_against_default_kwargs(self):
""" Check the kwargs """
self.outputfiles_basename = self.kwargs.pop("log_dir", None)
NestedSampler._verify_kwargs_against_default_kwargs(self)
def _apply_ultranest_boundaries(self):
if (
self.kwargs["wrapped_params"] is None
or len(self.kwargs.get("wrapped_params", [])) == 0
):
self.kwargs["wrapped_params"] = []
for param, value in self.priors.items():
if param in self.search_parameter_keys:
if value.boundary == "periodic":
self.kwargs["wrapped_params"].append(1)
else:
self.kwargs["wrapped_params"].append(0)
@property
def outputfiles_basename(self):
return self._outputfiles_basename
@outputfiles_basename.setter
def outputfiles_basename(self, outputfiles_basename):
if outputfiles_basename is None:
outputfiles_basename = "{}/ultra_{}".format(self.outdir, self.label)
if outputfiles_basename.endswith("/") is True:
outputfiles_basename = outputfiles_basename.rstrip("/")
check_directory_exists_and_if_not_mkdir(self.outdir)
self._outputfiles_basename = outputfiles_basename
@property
def temporary_outputfiles_basename(self):
return self._temporary_outputfiles_basename
@temporary_outputfiles_basename.setter
def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
if temporary_outputfiles_basename.endswith("/") is False:
temporary_outputfiles_basename = "{}/".format(
temporary_outputfiles_basename
)
self._temporary_outputfiles_basename = temporary_outputfiles_basename
if os.path.exists(self.outputfiles_basename):
shutil.copytree(
self.outputfiles_basename, self.temporary_outputfiles_basename
)
if os.path.islink(self.outputfiles_basename):
os.unlink(self.outputfiles_basename)
else:
shutil.rmtree(self.outputfiles_basename)
def write_current_state_and_exit(self, signum=None, frame=None):
""" Write current state and exit on exit_code """
logger.info(
"Run interrupted by signal {}: checkpoint and exit on {}".format(
signum, self.exit_code
)
)
# self.copy_temporary_directory_to_proper_path()
os._exit(self.exit_code)
def copy_temporary_directory_to_proper_path(self):
logger.info(
"Overwriting {} with {}".format(
self.outputfiles_basename, self.temporary_outputfiles_basename
)
)
# First remove anything in the outputfiles_basename for overwriting
if os.path.exists(self.outputfiles_basename):
if os.path.islink(self.outputfiles_basename):
os.unlink(self.outputfiles_basename)
else:
shutil.rmtree(self.outputfiles_basename, ignore_errors=True)
shutil.copytree(self.temporary_outputfiles_basename, self.outputfiles_basename)
@property
def sampler_function_kwargs(self):
if self.kwargs.get("num_live_points", None) is not None:
keys = [
"update_interval_iter",
"update_interval_ncall",
"log_interval",
"dlogz",
"max_iters",
]
else:
keys = [
"update_interval_iter_fraction",
"update_interval_ncall",
"log_interval",
"show_status",
"viz_callback",
"dlogz",
"dKL",
"frac_remain",
"Lepsilon",
"min_ess",
"max_iters",
"max_ncalls",
"max_num_improvement_loops",
"min_num_live_points",
"cluster_num_live_points",
]
function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
return function_kwargs
@property
def sampler_init_kwargs(self):
keys = [
"derived_param_names",
"resume",
"run_num",
"vectorized",
"log_dir",
"wrapped_params",
]
if self.kwargs.get("num_live_points", None) is not None:
keys += ["num_live_points"]
else:
keys += ["num_test_samples", "draw_multiple", "num_bootstraps"]
init_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
return init_kwargs
def run_sampler(self):
import ultranest
import ultranest.stepsampler
if self.kwargs["dlogz"] is None:
# remove dlogz, so ultranest defaults (which are different for
# NestedSampler and ReactiveNestedSampler) are used
self.kwargs.pop("dlogz")
self._verify_kwargs_against_default_kwargs()
stepsampler = self.kwargs.pop("step_sampler", None)
temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
self.temporary_outputfiles_basename = temporary_outputfiles_basename
logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
self.kwargs["log_dir"] = self.temporary_outputfiles_basename
# Symlink the temporary directory with the target directory: ensures data is stored on exit
os.symlink(
os.path.abspath(self.temporary_outputfiles_basename),
os.path.abspath(self.outputfiles_basename),
target_is_directory=True,
)
# use reactive nested sampler when no live points are given
if self.kwargs.get("num_live_points", None) is not None:
integrator = ultranest.integrator.NestedSampler
else:
integrator = ultranest.integrator.ReactiveNestedSampler
sampler = integrator(
self.search_parameter_keys,
self.log_likelihood,
transform=self.prior_transform,
**self.sampler_init_kwargs,
)
if stepsampler is not None:
if isinstance(stepsampler, ultranest.stepsampler.StepSampler):
sampler.stepsampler = stepsampler
else:
logger.warning(
"The supplied step sampler is not the correct type. "
"The default step sampling will be used instead."
)
results = sampler.run(**self.sampler_function_kwargs)
self.copy_temporary_directory_to_proper_path()
# Clean up
shutil.rmtree(temporary_outputfiles_basename)
self._generate_result(results)
self.calc_likelihood_count()
return self.result
def _generate_result(self, out):
# extract results (samples stored in "v" will change to "points",
# weights stored in "w" will change to "weights")
datakey = "v" if "v" in out["weighted_samples"] else "points"
weightskey = "w" if "w" in out["weighted_samples"] else "weights"
data = np.array(out["weighted_samples"][datakey])
weights = np.array(out["weighted_samples"][weightskey])
scaledweights = weights / weights.max()
mask = np.random.rand(len(scaledweights)) < scaledweights
nested_samples = DataFrame(data, columns=self.search_parameter_keys)
nested_samples["weights"] = weights
nested_samples["log_likelihood"] = out["weighted_samples"]["L"]
self.result.log_likelihood_evaluations = np.array(out["weighted_samples"]["L"])[
mask
]
self.result.sampler_output = out
self.result.samples = data[mask, :]
self.result.nested_samples = nested_samples
self.result.log_evidence = out["logz"]
self.result.log_evidence_err = out["logzerr"]
self.result.outputfiles_basename = self.outputfiles_basename