Skip to content
Snippets Groups Projects
Commit 76482db7 authored by Colm Talbot's avatar Colm Talbot
Browse files

Allow ptemcee to be initialized from a dictionary

parent 4081c24a
No related branches found
No related tags found
1 merge request!1049Allow ptemcee to be initialized from a dictionary
......@@ -64,7 +64,7 @@ class Ptemcee(MCMCSampler):
autocorr_c: int, (5)
The step size for the window search used by emcee.autocorr.integrated_time
safety: int, (1)
A multiplicitive factor for the estimated autocorrelation. Useful for
A multiplicative factor for the estimated autocorrelation. Useful for
cases where non-convergence can be observed by eye but the automated
tools are failing.
autocorr_tau: int, (1)
......@@ -92,14 +92,18 @@ class Ptemcee(MCMCSampler):
is not recommended for cases where tau is large.
ignore_keys_for_tau: str
A pattern used to ignore keys in estimating the autocorrelation time.
pos0: str, list, np.ndarray
pos0: str, list, np.ndarray, dict
If a string, one of "prior" or "minimize". For "prior", the initial
positions of the sampler are drawn from the sampler. If "minimize",
a scipy.optimize step is applied to all parameters a number of times.
The walkers are then initialized from the range of values obtained.
If a list, for the keys in the list the optimization step is applied,
otherwise the initial points are drawn from the prior. If a numpy array
the shape should be (ntemps, nwalkers, ndim).
otherwise the initial points are drawn from the prior.
If a :code:`numpy` array the shape should be
:code:`(ntemps, nwalkers, ndim)`.
If a :code:`dict`, this should be a dictionary with keys matching the
:code:`search_parameter_keys`. Each entry should be an array with
shape :code:`(ntemps, nwalkers)`.
niterations_per_check: int (5)
The number of iteration steps to take before checking ACT. This
......@@ -282,13 +286,13 @@ class Ptemcee(MCMCSampler):
"""
logger.info("Generating pos0 samples")
return [
return np.array([
[
self.get_random_draw_from_prior()
for _ in range(self.nwalkers)
]
for _ in range(self.kwargs["ntemps"])
]
])
def get_pos0_from_minimize(self, minimize_list=None):
""" Draw the initial positions using an initial minimization step
......@@ -376,6 +380,18 @@ class Ptemcee(MCMCSampler):
else:
return self.pos0
def get_pos0_from_dict(self):
"""
Initialize the starting points from a passed dictionary.
The :code:`pos0` passed to the :code:`Sampler` should be a dictionary
with keys matching the :code:`search_parameter_keys`.
Each entry should have shape :code:`(ntemps, nwalkers)`.
"""
pos0 = np.array([self.pos0[key] for key in self.search_parameter_keys])
self.pos0 = np.moveaxis(pos0, 0, -1)
return self.get_pos0_from_array()
def setup_sampler(self):
""" Either initialize the sampler or read in the resume file """
import ptemcee
......@@ -461,6 +477,8 @@ class Ptemcee(MCMCSampler):
return self.get_pos0_from_minimize(minimize_list=self.pos0)
elif isinstance(self.pos0, np.ndarray):
return self.get_pos0_from_array()
elif isinstance(self.pos0, dict):
return self.get_pos0_from_dict()
else:
raise SamplerError("pos0={} not implemented".format(self.pos0))
......
import unittest
from unittest.mock import MagicMock
import bilby
from bilby.core.likelihood import GaussianLikelihood
from bilby.core.prior import Uniform, PriorDict
from bilby.core.sampler import Ptemcee
from bilby.core.sampler.base_sampler import MCMCSampler
import numpy as np
class TestPTEmcee(unittest.TestCase):
def setUp(self):
self.likelihood = MagicMock()
self.priors = bilby.core.prior.PriorDict(
dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
self.likelihood = GaussianLikelihood(
x=np.linspace(0, 1, 2),
y=np.linspace(0, 1, 2),
func=lambda x, **kwargs: x,
sigma=1,
)
self.sampler = bilby.core.sampler.Ptemcee(
self.likelihood,
self.priors,
self.priors = PriorDict(dict(a=Uniform(0, 1), b=Uniform(0, 1)))
self._args = (self.likelihood, self.priors)
self._kwargs = dict(
outdir="outdir",
label="label",
use_ratio=False,
plot=False,
skip_import_verification=True,
)
def tearDown(self):
del self.likelihood
del self.priors
del self.sampler
def test_default_kwargs(self):
expected = dict(
self.sampler = Ptemcee(*self._args, **self._kwargs)
self.expected = dict(
ntemps=10,
nwalkers=100,
Tmax=None,
......@@ -38,27 +37,56 @@ class TestPTEmcee(unittest.TestCase):
adapt=False,
swap_ratios=False,
)
self.assertDictEqual(expected, self.sampler.kwargs)
def tearDown(self):
del self.likelihood
del self.priors
del self.sampler
def test_default_kwargs(self):
self.assertDictEqual(self.expected, self.sampler.kwargs)
def test_translate_kwargs(self):
expected = dict(
ntemps=10,
nwalkers=100,
Tmax=None,
betas=None,
a=2.0,
adaptation_lag=10000,
adaptation_time=100,
random=None,
adapt=False,
swap_ratios=False,
)
for equiv in bilby.core.sampler.base_sampler.MCMCSampler.nwalkers_equiv_kwargs:
for equiv in MCMCSampler.nwalkers_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
del new_kwargs["nwalkers"]
new_kwargs[equiv] = 100
self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs)
self.assertDictEqual(self.expected, self.sampler.kwargs)
def test_set_pos0_using_array(self):
"""
Verify that setting the initial points from an array matches the
default method.
"""
pos0 = self.sampler.get_pos0()
new_sampler = Ptemcee(*self._args, **self._kwargs, pos0=pos0)
self.assertTrue(np.array_equal(new_sampler.get_pos0(), pos0))
def test_set_pos0_using_dict(self):
"""
Verify that setting the initial points from a dictionary matches the
default method.
"""
old = np.array(self.sampler.get_pos0())
pos0 = np.moveaxis(old, -1, 0)
pos0 = {
key: points for key, points in
zip(self.sampler.search_parameter_keys, pos0)
}
new_sampler = Ptemcee(*self._args, **self._kwargs, pos0=pos0)
new = new_sampler.get_pos0()
self.assertTrue(np.array_equal(new, old))
def test_set_pos0_from_minimize(self):
"""
Verify that the minimize method of setting the initial points
returns the same shape as the default.
"""
old = self.sampler.get_pos0().shape
new_sampler = Ptemcee(*self._args, **self._kwargs, pos0="minimize")
new = new_sampler.get_pos0().shape
self.assertEqual(old, new)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment