import datetime
import distutils.dir_util
import inspect
import os
import shutil
import signal
import time

import numpy as np
from pandas import DataFrame

from ..utils import check_directory_exists_and_if_not_mkdir, logger
from .base_sampler import NestedSampler


class Ultranest(NestedSampler):
    """
    bilby wrapper of ultranest
    (https://johannesbuchner.github.io/UltraNest/index.html)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `ultranest.ReactiveNestedSampler.run`
    or `ultranest.NestedSampler.run`; see the documentation for those classes
    for further help. Under Other Parameters, we list commonly used kwargs and
    the bilby defaults. If the number of live points is specified,
    `ultranest.NestedSampler` will be used; otherwise,
    `ultranest.ReactiveNestedSampler` will be used (see the sketch under
    Examples below).

    Other Parameters
    ----------------
    num_live_points: int
        The number of live points; this can equivalently be given as any of
        [nlive, nlives, n_live_points, num_live_points]. If not given,
        the `ultranest.ReactiveNestedSampler` will be used, which does not
        require the number of live points to be specified.
    show_status: bool
        If true, print information about convergence during sampling.
    resume: bool
        If true, resume run from checkpoint (if available)
    step_sampler:
        An UltraNest step sampler object. This defaults to None, so the default
        stepping behaviour is used.
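
    Examples
    --------
    A minimal usage sketch, assuming a bilby `likelihood` and `priors` have
    already been constructed (the keyword values below are illustrative, not
    required)::

        import bilby

        result = bilby.run_sampler(
            likelihood=likelihood,
            priors=priors,
            sampler="ultranest",
            outdir="outdir",
            label="label",
            num_live_points=500,  # if given, ultranest.NestedSampler is used
            resume=True,
        )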
    """

    default_kwargs = dict(
        resume=True,
        show_status=True,
        num_live_points=None,
        wrapped_params=None,
        log_dir=None,
        derived_param_names=[],
        run_num=None,
        vectorized=False,
        num_test_samples=2,
        draw_multiple=True,
        num_bootstraps=30,
        update_interval_iter=None,
        update_interval_ncall=None,
        log_interval=None,
        dlogz=None,
        max_iters=None,
        update_interval_volume_fraction=0.2,
        viz_callback=None,
        dKL=0.5,
        frac_remain=0.01,
        Lepsilon=0.001,
        min_ess=400,
        max_ncalls=None,
        max_num_improvement_loops=-1,
        min_num_live_points=400,
        cluster_num_live_points=40,
        step_sampler=None,
    )

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        exit_code=77,
        skip_import_verification=False,
        temporary_directory=True,
        callback_interval=10,
        **kwargs,
    ):
        super(Ultranest, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            exit_code=exit_code,
            **kwargs,
        )
        self._apply_ultranest_boundaries()
        self.use_temporary_directory = temporary_directory

        if self.use_temporary_directory:
            # set callback interval, so copying of results does not thrash the
            # disk (ultranest will call viz_callback quite a lot)
            self.callback_interval = callback_interval

        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)

    def _translate_kwargs(self, kwargs):
        if "num_live_points" not in kwargs:
            for equiv in self.npoints_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["num_live_points"] = kwargs.pop(equiv)

        if "verbose" in kwargs and "show_status" not in kwargs:
            kwargs["show_status"] = kwargs.pop("verbose")

    def _verify_kwargs_against_default_kwargs(self):
        """ Check the kwargs """

        self.outputfiles_basename = self.kwargs.pop("log_dir", None)
        if self.kwargs["viz_callback"] is None:
            self.kwargs["viz_callback"] = self._viz_callback

        NestedSampler._verify_kwargs_against_default_kwargs(self)

    def _viz_callback(self, *args, **kwargs):
        if self.use_temporary_directory:
            if not (self._viz_callback_counter % self.callback_interval):
                self._copy_temporary_directory_contents_to_proper_path()
                self._calculate_and_save_sampling_time()
            self._viz_callback_counter += 1

    def _apply_ultranest_boundaries(self):
        if (
            self.kwargs["wrapped_params"] is None
            or len(self.kwargs.get("wrapped_params", [])) == 0
        ):
            self.kwargs["wrapped_params"] = []
            for param, value in self.priors.items():
                if param in self.search_parameter_keys:
                    if value.boundary == "periodic":
                        self.kwargs["wrapped_params"].append(1)
                    else:
                        self.kwargs["wrapped_params"].append(0)

    @property
    def outputfiles_basename(self):
        return self._outputfiles_basename

    @outputfiles_basename.setter
    def outputfiles_basename(self, outputfiles_basename):
        if outputfiles_basename is None:
            outputfiles_basename = os.path.join(
                self.outdir, "ultra_{}/".format(self.label)
            )
        if not outputfiles_basename.endswith("/"):
            outputfiles_basename += "/"
        check_directory_exists_and_if_not_mkdir(self.outdir)
        self._outputfiles_basename = outputfiles_basename

    @property
    def temporary_outputfiles_basename(self):
        return self._temporary_outputfiles_basename

    @temporary_outputfiles_basename.setter
    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
        if not temporary_outputfiles_basename.endswith("/"):
            temporary_outputfiles_basename = "{}/".format(
                temporary_outputfiles_basename
            )
        self._temporary_outputfiles_basename = temporary_outputfiles_basename
        if os.path.exists(self.outputfiles_basename):
            shutil.copytree(
                self.outputfiles_basename, self.temporary_outputfiles_basename
            )

    def write_current_state_and_exit(self, signum=None, frame=None):
        """ Write current state and exit on exit_code """
        logger.info(
            "Run interrupted by signal {}: checkpoint and exit on {}".format(
                signum, self.exit_code
            )
        )
        self._calculate_and_save_sampling_time()
        if self.use_temporary_directory:
            self._move_temporary_directory_to_proper_path()
        os._exit(self.exit_code)

    def _copy_temporary_directory_contents_to_proper_path(self):
        """
        Copy the temporary directory back to the proper path.
        Do not delete the temporary directory.
        """
        if inspect.stack()[1].function != "_viz_callback":
            logger.info(
                "Overwriting {} with {}".format(
                    self.outputfiles_basename, self.temporary_outputfiles_basename
                )
            )
        if self.outputfiles_basename.endswith("/"):
            outputfiles_basename_stripped = self.outputfiles_basename[:-1]
        else:
            outputfiles_basename_stripped = self.outputfiles_basename
        distutils.dir_util.copy_tree(
            self.temporary_outputfiles_basename, outputfiles_basename_stripped
        )

    def _move_temporary_directory_to_proper_path(self):
        """
        Move the temporary directory back to the proper path.

        Anything in the proper path at this point is removed, including links.
        """
        self._copy_temporary_directory_contents_to_proper_path()
        shutil.rmtree(self.temporary_outputfiles_basename)

    @property
    def sampler_function_kwargs(self):
        if self.kwargs.get("num_live_points", None) is not None:
            keys = [
                "update_interval_iter",
                "update_interval_ncall",
                "log_interval",
                "dlogz",
                "max_iters",
            ]
        else:
            keys = [
                "update_interval_volume_fraction",
                "update_interval_ncall",
                "log_interval",
                "show_status",
                "viz_callback",
                "dlogz",
                "dKL",
                "frac_remain",
                "Lepsilon",
                "min_ess",
                "max_iters",
                "max_ncalls",
                "max_num_improvement_loops",
                "min_num_live_points",
                "cluster_num_live_points",
            ]

        function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}

        return function_kwargs

    @property
    def sampler_init_kwargs(self):
        keys = [
            "derived_param_names",
            "resume",
            "run_num",
            "vectorized",
            "log_dir",
            "wrapped_params",
        ]
        if self.kwargs.get("num_live_points", None) is not None:
            keys += ["num_live_points"]
        else:
            keys += ["num_test_samples", "draw_multiple", "num_bootstraps"]

        init_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}

        return init_kwargs

    def run_sampler(self):
        import ultranest
        import ultranest.stepsampler

        if self.kwargs["dlogz"] is None:
            # remove dlogz, so ultranest defaults (which are different for
            # NestedSampler and ReactiveNestedSampler) are used
            self.kwargs.pop("dlogz")

        self._verify_kwargs_against_default_kwargs()

        stepsampler = self.kwargs.pop("step_sampler", None)

        self._setup_run_directory()
        self.kwargs["log_dir"] = self.kwargs.pop("outputfiles_basename")
        self._check_and_load_sampling_time_file()

        # use reactive nested sampler when no live points are given
        if self.kwargs.get("num_live_points", None) is not None:
            integrator = ultranest.integrator.NestedSampler
        else:
            integrator = ultranest.integrator.ReactiveNestedSampler

        sampler = integrator(
            self.search_parameter_keys,
            self.log_likelihood,
            transform=self.prior_transform,
            **self.sampler_init_kwargs,
        )

        if stepsampler is not None:
            if isinstance(stepsampler, ultranest.stepsampler.StepSampler):
                sampler.stepsampler = stepsampler
            else:
                logger.warning(
                    "The supplied step sampler is not the correct type. "
                    "The default step sampling will be used instead."
                )

        if self.use_temporary_directory:
            self._viz_callback_counter = 1

        self.start_time = time.time()
        results = sampler.run(**self.sampler_function_kwargs)
        self._calculate_and_save_sampling_time()

        # Clean up
        self._clean_up_run_directory()

        self._generate_result(results)
        self.calc_likelihood_count()

        return self.result

    def _clean_up_run_directory(self):
        if self.use_temporary_directory:
            self._move_temporary_directory_to_proper_path()
            self.kwargs["log_dir"] = self.outputfiles_basename

    def _check_and_load_sampling_time_file(self):
        self.time_file_path = os.path.join(self.kwargs["log_dir"], "sampling_time.dat")
        if os.path.exists(self.time_file_path):
            with open(self.time_file_path, "r") as time_file:
                self.total_sampling_time = float(time_file.readline())
        else:
            self.total_sampling_time = 0

    def _calculate_and_save_sampling_time(self):
        current_time = time.time()
        new_sampling_time = current_time - self.start_time
        self.total_sampling_time += new_sampling_time
        with open(self.time_file_path, "w") as time_file:
            time_file.write(str(self.total_sampling_time))
        self.start_time = current_time

    def _generate_result(self, out):
        # extract results
        data = np.array(out["weighted_samples"]["points"])
        weights = np.array(out["weighted_samples"]["weights"])

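        # Convert the weighted nested samples to (approximately) equally
        # weighted posterior samples by rejection sampling: each sample is
        # kept with probability proportional to its weight.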
        scaledweights = weights / weights.max()
        mask = np.random.rand(len(scaledweights)) < scaledweights

        nested_samples = DataFrame(data, columns=self.search_parameter_keys)
        nested_samples["weights"] = weights
        nested_samples["log_likelihood"] = out["weighted_samples"]["logl"]
        self.result.log_likelihood_evaluations = np.array(
            out["weighted_samples"]["logl"]
        )[mask]
        self.result.sampler_output = out
        self.result.samples = data[mask, :]
        self.result.nested_samples = nested_samples
        self.result.log_evidence = out["logz"]
        self.result.log_evidence_err = out["logzerr"]
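        # Estimate the information gain H from the evidence error: for nested
        # sampling, logzerr is approximately sqrt(H / nlive), so
        # H ~ logzerr**2 * num_live_points.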
        if self.kwargs["num_live_points"] is not None:
            self.result.information_gain = (
                np.power(out["logzerr"], 2) * self.kwargs["num_live_points"]
            )

        self.result.outputfiles_basename = self.outputfiles_basename
        self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)