
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.
Commits on Source (52)
Showing changes with 3438 additions and 97 deletions
......@@ -4,19 +4,27 @@ repos:
hooks:
- id: check-merge-conflict # prevent committing files with merge conflicts
- id: flake8 # checks for flake8 errors
#- repo: https://github.com/codespell-project/codespell
# rev: v1.16.0
# hooks:
# - id: codespell # Spellchecker
# args: [-L, nd, --skip, "*.ipynb,*.html", --ignore-words=.dictionary.txt]
# exclude: ^examples/tutorials/
#- repo: https://github.com/asottile/seed-isort-config
# rev: v1.3.0
# hooks:
# - id: seed-isort-config
# args: [--application-directories, 'bilby/']
#- repo: https://github.com/pre-commit/mirrors-isort
# rev: v4.3.21
# hooks:
# - id: isort # sort imports alphabetically and separates import into sections
# args: [-w=88, -m=3, -tc, -sp=setup.cfg ]
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
language_version: python3
files: ^bilby/bilby_mcmc/
- repo: https://github.com/codespell-project/codespell
rev: v1.16.0
hooks:
- id: codespell
args: [--ignore-words=.dictionary.txt]
files: ^bilby/bilby_mcmc/
- repo: https://github.com/asottile/seed-isort-config
rev: v1.3.0
hooks:
- id: seed-isort-config
args: [--application-directories, 'bilby/']
files: ^bilby/bilby_mcmc/
- repo: https://github.com/pre-commit/mirrors-isort
rev: v4.3.21
hooks:
- id: isort # sort imports alphabetically and separates import into sections
args: [-w=88, -m=3, -tc, -sp=setup.cfg ]
files: ^bilby/bilby_mcmc/
......@@ -30,6 +30,7 @@ Isobel Marguarethe Romero-Shaw
Jade Powell
James A Clark
John Veitch
Joshua Brandt
Katerina Chatziioannou
Kaylee de Soto
Khun Sang Phukon
......
# All notable changes will be documented in this file
## [1.1.3] 2021-07-02
Version 1.1.3 release of bilby
### Added
- Added `Categorical` prior (!982)(!990)
- Added a built-in mcmc sampler (`bilby_mcmc`) (!905)(!985)
- Added run statistics to the `dynesty` meta data (!969)
- Added `cdf` method to `PriorDict` classes (!943)
### Changes
- Removed the autoburnin causing `kombine` to fail the CI tests (!988)
- Sped up the spline interpolation in ROQ (!971)
- Replaced the Bessel interpolant with a scipy function (!976)
- Improved checkpoint stats plot (!977)
- Fixed a typo in the sampler documentation (!986)
- Fixed an issue that caused `ConditionalDeltaFunction` posterior samples not to be saved correctly (!973)
- Solved an issue where injected SNRs were logged incorrectly (!980)
- Made Python 3.6+ a specific requirement (!978)
- Fixed the calibration and time marginalized likelihood (!978)
- Removed a possible error in the distance marginalization (!960)
- Fixed an issue where `check_draw` did not catch `np.nan` values (!965)
- Removed a superfluous line in the docs configuration file (!963)
- Added a warning about class side effects to the `GravitationalWaveTransient` likelihood classes (!964)
- Allowed `ptemcee` initialization with an array (!955)
- Removed `Prior.test_valid_for_rescaling` (!956)
- Replaced deprecated numpy aliases with builtins (!970)
- Fixed a bug in the algorithm to determine time resolution of ROQ (!967)
- Restructured utils module into several submodules. API remains backwards compatible (!873)
- Changed number of default walks in `dynesty` from `10*self.ndim` to `100` (!961)
## [1.1.2] 2021-05-05
Version 1.1.2 release of bilby
......
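As a usage illustration of the built-in MCMC sampler listed in the changelog above (not part of this diff), the sketch below runs `bilby_mcmc` through the standard `run_sampler` interface. The toy likelihood, the prior, and the `nsamples` value are placeholder assumptions chosen for brevity.

import bilby

# Hypothetical one-parameter Gaussian likelihood, for illustration only
class ToyLikelihood(bilby.Likelihood):
    def __init__(self):
        super().__init__(parameters={"mu": None})

    def log_likelihood(self):
        return -0.5 * (self.parameters["mu"] - 1.0) ** 2

priors = dict(mu=bilby.core.prior.Uniform(-10, 10, name="mu"))

# sampler="bilby_mcmc" selects the new in-house sampler added in this release
result = bilby.run_sampler(
    likelihood=ToyLikelihood(),
    priors=priors,
    sampler="bilby_mcmc",
    nsamples=500,
    outdir="outdir",
    label="bilby_mcmc_demo",
)
print(result.posterior["mu"].median())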
from .sampler import Bilby_MCMC
from distutils.version import LooseVersion
import numpy as np
import pandas as pd
from ..core.sampler.base_sampler import SamplerError
from ..core.utils import logger
from .utils import LOGLKEY, LOGLLATEXKEY, LOGPKEY, LOGPLATEXKEY
class Chain(object):
def __init__(
self,
initial_sample,
burn_in_nact=1,
thin_by_nact=1,
fixed_discard=0,
autocorr_c=5,
min_tau=1,
fixed_tau=None,
tau_window=None,
block_length=100000,
):
"""Object to store a single mcmc chain
Parameters
----------
initial_sample: bilby.bilby_mcmc.chain.Sample
The starting point of the chain
burn_in_nact, thin_by_nact : int (1, 1)
The number of autocorrelation times (tau) to discard for burn-in
and the multiplicative factor to thin by (thin_by_nact < 1). I.e.,
burn_in_nact=10 and thin_by_nact=1 will discard 10*tau samples from
the start of the chain, then thin the final chain by a factor
of 1*tau (resulting in independent samples).
fixed_discard: int (0)
A fixed minimum number of samples to discard (can be used to
override the burn_in_nact if it is too small).
autocorr_c: float (5)
The step size of the window search used by emcee.autocorr when
estimating the autocorrelation time.
min_tau: int (1)
A minimum value for the autocorrelation time.
fixed_tau: int (None)
A fixed value for the autocorrelation (overrides the automated
autocorrelation time estimation). Used in testing.
tau_window: int (None)
Only calculate the autocorrelation time in a trailing window. If
None (default) this method is not used.
block_length: int
The incremental size to extend the array by when it runs out of
space.
"""
self.autocorr_c = autocorr_c
self.min_tau = min_tau
self.burn_in_nact = burn_in_nact
self.thin_by_nact = thin_by_nact
self.block_length = block_length
self.fixed_discard = int(fixed_discard)
self.fixed_tau = fixed_tau
self.tau_window = tau_window
self.ndim = initial_sample.ndim
self.current_sample = initial_sample
self.keys = self.current_sample.keys
self.parameter_keys = self.current_sample.parameter_keys
# Initialize chain
self._chain_array = self._get_zero_chain_array()
self._chain_array_length = block_length
self.position = -1
self.max_log_likelihood = -np.inf
self.max_tau_dict = {}
self.converged = False
self.cached_tau_count = 0
self._minimum_index_proposal = 0
self._minimum_index_adapt = 0
self._last_minimum_index = (0, 0, "I")
self.last_full_tau_dict = {key: np.inf for key in self.parameter_keys}
# Append the initial sample
self.append(self.current_sample)
def _get_zero_chain_array(self):
return np.zeros((self.block_length, self.ndim + 2), dtype=np.float64)
def _extend_chain_array(self):
self._chain_array = np.concatenate(
(self._chain_array, self._get_zero_chain_array()), axis=0
)
self._chain_array_length = len(self._chain_array)
@property
def current_sample(self):
return self._current_sample.copy()
@current_sample.setter
def current_sample(self, current_sample):
self._current_sample = current_sample
def append(self, sample):
self.position += 1
# Extend the array if needed
if self.position >= self._chain_array_length:
self._extend_chain_array()
# Store the current sample and append to the array
self.current_sample = sample
self._chain_array[self.position] = sample.list
# Update the maximum log_likelihood
if sample[LOGLKEY] > self.max_log_likelihood:
self.max_log_likelihood = sample[LOGLKEY]
def __getitem__(self, index):
if index < 0:
index = index + self.position + 1
if index <= self.position:
values = self._chain_array[index]
return Sample({k: v for k, v in zip(self.keys, values)})
else:
raise SamplerError(f"Requested index {index} out of bounds")
def __setitem__(self, index, sample):
if index < 0:
index = index + self.position + 1
self._chain_array[index] = sample.list
def key_to_idx(self, key):
return self.keys.index(key)
def get_1d_array(self, key):
return self._chain_array[: 1 + self.position, self.key_to_idx(key)]
@property
def _random_idx(self):
mindex = self._last_minimum_index[1]
# If the chain beyond mindex is shorter than 10 ACT (or tau is still unknown),
# draw from the full chain; otherwise draw only from beyond the minimum index
if np.isinf(self.tau_last) or self.position - mindex < 10 * self.tau_last:
mindex = 0
return np.random.randint(mindex, self.position + 1)
@property
def random_sample(self):
return self[self._random_idx]
@property
def fixed_discard(self):
return self._fixed_discard
@fixed_discard.setter
def fixed_discard(self, fixed_discard):
self._fixed_discard = int(fixed_discard)
@property
def minimum_index(self):
"""This calculates a minimum index from which to discard samples
A number of methods are provided for the calculation. A subset is
switched off (by `if False` statements) for future development
"""
position = self.position
# Return cached minimum index
last_minimum_index = self._last_minimum_index
if position == last_minimum_index[0]:
return int(last_minimum_index[1])
# If fixed discard is not yet reached, just return that
if position < self.fixed_discard:
self.minimum_index_method = "FD"
return self.fixed_discard
# Initialize list of minimum index methods with the fixed discard (FD)
minimum_index_list = [self.fixed_discard]
minimum_index_method_list = ["FD"]
# Calculate minimum index from tau
if self.tau_last < np.inf:
tau = self.tau_last
elif len(self.max_tau_dict) == 0:
# Bootstrap calculating tau when minimum index has not yet been calculated
tau = self._tau_for_full_chain
else:
tau = np.inf
if tau < np.inf:
minimum_index_list.append(self.burn_in_nact * tau)
minimum_index_method_list.append(f"{self.burn_in_nact}tau")
# Calculate points when log-posterior is within z std of the mean
if True:
zfactor = 1
N = 100
delta_lnP = zfactor * self.ndim / 2
logl = self.get_1d_array(LOGLKEY)
log_prior = self.get_1d_array(LOGPKEY)
log_posterior = logl + log_prior
max_posterior = np.max(log_posterior)
ave = pd.Series(log_posterior).rolling(window=N).mean().iloc[N - 1 :]
delta = max_posterior - ave
passes = ave[delta < delta_lnP]
if len(passes) > 0:
minimum_index_list.append(passes.index[0] + 1)
minimum_index_method_list.append(f"z{zfactor}")
# Add last minimum_index_method
if False:
minimum_index_list.append(last_minimum_index[1])
minimum_index_method_list.append(last_minimum_index[2])
# Minimum index set by proposals
minimum_index_list.append(self.minimum_index_proposal)
minimum_index_method_list.append("PR")
# Minimum index set by temperature adaptation
minimum_index_list.append(self.minimum_index_adapt)
minimum_index_method_list.append("AD")
# Calculate the maximum minimum index and associated method (reporting)
minimum_index = int(np.max(minimum_index_list))
minimum_index_method = minimum_index_method_list[np.argmax(minimum_index_list)]
# Cache the method
self._last_minimum_index = (position, minimum_index, minimum_index_method)
self.minimum_index_method = minimum_index_method
return minimum_index
@property
def minimum_index_proposal(self):
return self._minimum_index_proposal
@minimum_index_proposal.setter
def minimum_index_proposal(self, minimum_index_proposal):
if minimum_index_proposal > self._minimum_index_proposal:
self._minimum_index_proposal = minimum_index_proposal
@property
def minimum_index_adapt(self):
return self._minimum_index_adapt
@minimum_index_adapt.setter
def minimum_index_adapt(self, minimum_index_adapt):
if minimum_index_adapt > self._minimum_index_adapt:
self._minimum_index_adapt = minimum_index_adapt
@property
def tau(self):
""" The maximum ACT over all parameters """
if self.position in self.max_tau_dict:
# If we have the ACT at the current position, return it
return self.max_tau_dict[self.position]
elif (
self.tau_last < np.inf
and self.cached_tau_count < 50
and self.nsamples_last > 50
):
# If we have a recent ACT return it
self.cached_tau_count += 1
return self.tau_last
else:
# Calculate the ACT
return self.tau_nocache
@property
def tau_nocache(self):
""" Calculate tau forcing a recalculation (no cached tau) """
tau = max(self.tau_dict.values())
self.max_tau_dict[self.position] = tau
self.cached_tau_count = 0
return tau
@property
def tau_last(self):
""" Return the last-calculated tau if it exists, else inf """
if len(self.max_tau_dict) > 0:
return list(self.max_tau_dict.values())[-1]
else:
return np.inf
@property
def _tau_for_full_chain(self):
""" The maximum ACT over all parameters """
return max(self._tau_dict_for_full_chain.values())
@property
def _tau_dict_for_full_chain(self):
return self._calculate_tau_dict(minimum_index=0)
@property
def tau_dict(self):
""" Calculate a dictionary of tau (ACT) for every parameter """
return self._calculate_tau_dict(self.minimum_index)
def _calculate_tau_dict(self, minimum_index):
""" Calculate a dictionary of tau (ACT) for every parameter """
logger.debug(f"Calculating tau_dict {self}")
# If there are too few samples to calculate tau
if (self.position - minimum_index) < 2 * self.autocorr_c:
return {key: np.inf for key in self.parameter_keys}
# Choose the minimum index for the ACT calculation
last_tau = self.tau_last
if self.tau_window is not None and last_tau < np.inf:
minimum_index_for_act = max(
minimum_index, int(self.position - self.tau_window * last_tau)
)
else:
minimum_index_for_act = minimum_index
# Calculate a dictionary of tau's for each parameter
taus = {}
for key in self.parameter_keys:
if self.fixed_tau is None:
x = self.get_1d_array(key)[minimum_index_for_act:]
tau = calculate_tau(x, self.autocorr_c)
taux = round(tau, 1)
else:
taux = self.fixed_tau
taus[key] = max(taux, self.min_tau)
# Cache the last tau dictionary for future use
self.last_full_tau_dict = taus
return taus
@property
def thin(self):
if np.isfinite(self.tau):
return np.max([1, int(self.thin_by_nact * self.tau)])
else:
return 1
@property
def nsamples(self):
nuseable_steps = self.position - self.minimum_index
return int(nuseable_steps / (self.thin_by_nact * self.tau))
@property
def nsamples_last(self):
nuseable_steps = self.position - self.minimum_index
return int(nuseable_steps / (self.thin_by_nact * self.tau_last))
@property
def samples(self):
samples = self._chain_array[self.minimum_index : self.position : self.thin]
return pd.DataFrame(samples, columns=self.keys)
def plot(self, outdir=".", label="label", priors=None, all_samples=None):
import matplotlib.pyplot as plt
fig, axes = plt.subplots(
nrows=self.ndim + 3, ncols=2, figsize=(8, 9 + 3 * (self.ndim))
)
scatter_kwargs = dict(
lw=0,
marker="o",
)
K = 1000
nburn = self.minimum_index
plot_setups = zip(
[0, nburn, nburn],
[nburn, self.position, self.position],
[1, 1, self.thin], # Thin-by factor
["tab:red", "tab:grey", "tab:blue"], # Color
[0.5, 0.05, 0.5], # Alpha
[1, 1, 1], # Marker size
)
position_indexes = np.arange(self.position + 1)
# Plot the traceplots
for (start, stop, thin, color, alpha, ms) in plot_setups:
for ax, key in zip(axes[:, 0], self.keys):
xx = position_indexes[start:stop:thin] / K
yy = self.get_1d_array(key)[start:stop:thin]
# Downsample plots to max_pts: avoid memory issues
max_pts = 10000
while len(xx) > max_pts:
xx = xx[::2]
yy = yy[::2]
ax.plot(
xx,
yy,
color=color,
alpha=alpha,
ms=ms,
**scatter_kwargs,
)
ax.set_ylabel(self._get_plot_label_by_key(key, priors))
if key not in [LOGLKEY, LOGPKEY]:
msg = r"$\tau=$" + f"{self.last_full_tau_dict[key]:0.1f}"
ax.set_title(msg)
# Plot the histograms
for ax, key in zip(axes[:, 1], self.keys):
if all_samples is not None:
yy_all = all_samples[key]
ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k")
yy = self.get_1d_array(key)[nburn : self.position : self.thin]
ax.hist(yy, bins=50, alpha=0.8, density=True)
ax.set_xlabel(self._get_plot_label_by_key(key, priors))
# Add x-axes labels to the traceplots
axes[-1, 0].set_xlabel(r"Iteration $[\times 10^{3}]$")
# Plot the calculated ACT
ax = axes[-1, 0]
tausit = np.array(list(self.max_tau_dict.keys()) + [self.position]) / K
taus = list(self.max_tau_dict.values()) + [self.tau_last]
ax.plot(tausit, taus, color="C3")
ax.set(ylabel=r"Maximum $\tau$")
axes[-1, 1].set_axis_off()
filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
msg = [
r"Maximum $\tau$" + f"={self.tau:0.1f} ",
r"$n_{\rm samples}=$" + f"{self.nsamples} ",
]
if self.thin_by_nact != 1:
msg += [
r"$n_{\rm samples}^{\rm eff}=$"
+ f"{int(self.nsamples * self.thin_by_nact)} "
]
fig.suptitle(
"| ".join(msg),
y=1,
)
fig.tight_layout()
fig.savefig(filename, dpi=200)
plt.close(fig)
@staticmethod
def _get_plot_label_by_key(key, priors=None):
if priors is not None and key in priors:
return priors[key].latex_label
elif key == LOGLKEY:
return LOGLLATEXKEY
elif key == LOGPKEY:
return LOGPLATEXKEY
else:
return key
class Sample(object):
def __init__(self, sample_dict):
"""A single sample
Parameters
----------
sample_dict: dict
A dictionary of the sample
"""
self.sample_dict = sample_dict
self.keys = list(sample_dict.keys())
self.parameter_keys = [k for k in self.keys if k not in [LOGPKEY, LOGLKEY]]
self.ndim = len(self.parameter_keys)
def __getitem__(self, key):
return self.sample_dict[key]
def __setitem__(self, key, value):
self.sample_dict[key] = value
if key not in self.keys:
self.keys = list(self.sample_dict.keys())
@property
def list(self):
return list(self.sample_dict.values())
def __repr__(self):
return str(self.sample_dict)
@property
def parameter_only_dict(self):
return {key: self.sample_dict[key] for key in self.parameter_keys}
@property
def dict(self):
return {key: self.sample_dict[key] for key in self.keys}
def as_dict(self, keys=None):
sdict = self.dict
if keys is None:
return sdict
else:
return {key: sdict[key] for key in keys}
def __eq__(self, other_sample):
return self.list == other_sample.list
def copy(self):
return Sample(self.sample_dict.copy())
def calculate_tau(x, autocorr_c=5):
import emcee
if LooseVersion(emcee.__version__) < LooseVersion("3"):
raise SamplerError("bilby-mcmc requires emcee >= 3.0 for autocorr analysis")
if np.all(np.diff(x) == 0):
return np.inf
try:
# Hard code tol=1: we perform this check internally
tau = emcee.autocorr.integrated_time(x, c=autocorr_c, tol=1)[0]
if np.isnan(tau):
tau = np.inf
return tau
except emcee.autocorr.AutocorrError:
return np.inf
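To make the relationship between the `Sample` and `Chain` classes above concrete, here is a minimal sketch. It assumes the import paths `bilby.bilby_mcmc.chain` and `bilby.bilby_mcmc.utils`, and that emcee >= 3 is installed for the autocorrelation estimate; the random-walk update is purely illustrative.

import numpy as np
from bilby.bilby_mcmc.chain import Chain, Sample
from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY

# A Sample holds the parameters together with the log-likelihood and log-prior
start = Sample({"x": 0.0, "y": 0.0, LOGLKEY: -1.0, LOGPKEY: 0.0})
chain = Chain(initial_sample=start, burn_in_nact=5, thin_by_nact=1)

rng = np.random.default_rng(42)
for _ in range(2000):
    new = chain.current_sample          # copy of the most recent sample
    new["x"] += rng.normal(scale=0.1)   # toy random-walk proposal
    new["y"] += rng.normal(scale=0.1)
    new[LOGLKEY] = -0.5 * (new["x"] ** 2 + new["y"] ** 2)
    chain.append(new)

print(chain.tau)       # maximum autocorrelation time over the parameters
print(chain.nsamples)  # number of independent samples after burn-in and thinning
print(chain.samples)   # thinned pandas.DataFrame of post-burn-in samples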
import torch
from nflows.distributions.normal import StandardNormal
from nflows.flows.base import Flow
from nflows.nn import nets as nets
from nflows.transforms import (
CompositeTransform,
MaskedAffineAutoregressiveTransform,
RandomPermutation,
)
from nflows.transforms.coupling import (
AdditiveCouplingTransform,
AffineCouplingTransform,
)
from nflows.transforms.normalization import BatchNorm
from torch.nn import functional as F
# Turn off parallelism
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
class NVPFlow(Flow):
"""A simplified version of Real NVP for 1-dim inputs.
This implementation uses 1-dim checkerboard masking but doesn't use
multi-scaling.
Reference:
> L. Dinh et al., Density estimation using Real NVP, ICLR 2017.
This class has been modified from the example found at:
https://github.com/bayesiains/nflows/blob/master/nflows/flows/realnvp.py
"""
def __init__(
self,
features,
hidden_features,
num_layers,
num_blocks_per_layer,
use_volume_preserving=False,
activation=F.relu,
dropout_probability=0.0,
batch_norm_within_layers=False,
batch_norm_between_layers=False,
random_permutation=True,
):
if use_volume_preserving:
coupling_constructor = AdditiveCouplingTransform
else:
coupling_constructor = AffineCouplingTransform
mask = torch.ones(features)
mask[::2] = -1
def create_resnet(in_features, out_features):
return nets.ResidualNet(
in_features,
out_features,
hidden_features=hidden_features,
num_blocks=num_blocks_per_layer,
activation=activation,
dropout_probability=dropout_probability,
use_batch_norm=batch_norm_within_layers,
)
layers = []
for _ in range(num_layers):
transform = coupling_constructor(
mask=mask, transform_net_create_fn=create_resnet
)
layers.append(transform)
mask *= -1
if batch_norm_between_layers:
layers.append(BatchNorm(features=features))
if random_permutation:
layers.append(RandomPermutation(features=features))
super().__init__(
transform=CompositeTransform(layers),
distribution=StandardNormal([features]),
)
class BasicFlow(Flow):
def __init__(self, features):
transform = CompositeTransform(
[
MaskedAffineAutoregressiveTransform(
features=features, hidden_features=2 * features
),
RandomPermutation(features=features),
]
)
distribution = StandardNormal(shape=[features])
super().__init__(
transform=transform,
distribution=distribution,
)
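A brief sketch of instantiating and querying the flows defined above. It assumes `torch` and `nflows` are installed and that the module is importable as `bilby.bilby_mcmc.flows`; the dimensions and layer counts are arbitrary.

import torch
from bilby.bilby_mcmc.flows import BasicFlow, NVPFlow

# Two-dimensional flows with deliberately small networks
flow = NVPFlow(features=2, hidden_features=8, num_layers=2, num_blocks_per_layer=2)
basic = BasicFlow(features=2)

x = torch.randn(16, 2)                # a batch of 16 two-dimensional points
log_prob = flow.log_prob(inputs=x)    # log-density of the batch under the flow
samples = flow.sample(num_samples=8)  # draw 8 new points from the flow
print(log_prob.shape, samples.shape)  # torch.Size([16]) torch.Size([8, 2])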
from collections import namedtuple
LOGLKEY = "logl"
LOGLLATEXKEY = r"$\log\mathcal{L}$"
LOGPKEY = "logp"
LOGPLATEXKEY = r"$\log\pi$"
ConvergenceInputs = namedtuple(
"ConvergenceInputs",
[
"autocorr_c",
"burn_in_nact",
"thin_by_nact",
"fixed_discard",
"target_nsamples",
"stop_after_convergence",
"L1steps",
"L2steps",
"min_tau",
"fixed_tau",
"tau_window",
],
)
ParallelTemperingInputs = namedtuple(
"ParallelTemperingInputs",
[
"ntemps",
"nensemble",
"Tmax",
"Tmax_from_SNR",
"initial_betas",
"adapt",
"adapt_t0",
"adapt_nu",
"pt_ensemble",
],
)
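These namedtuples simply bundle the convergence and parallel-tempering settings that the sampler passes around. A small sketch of filling one in; the numbers are illustrative, not recommended defaults.

from bilby.bilby_mcmc.utils import ConvergenceInputs

ci = ConvergenceInputs(
    autocorr_c=5,
    burn_in_nact=10,
    thin_by_nact=1,
    fixed_discard=0,
    target_nsamples=1000,
    stop_after_convergence=False,
    L1steps=100,
    L2steps=3,
    min_tau=1,
    fixed_tau=None,
    tau_window=None,
)
print(ci.target_nsamples)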
......@@ -40,7 +40,6 @@ class DeltaFunction(Prior):
=======
float: Rescaled probability, equivalent to peak
"""
self.test_valid_for_rescaling(val)
return self.peak * val ** 0
def prob(self, val):
......@@ -105,7 +104,6 @@ class PowerLaw(Prior):
=======
Union[float, array_like]: Rescaled probability
"""
self.test_valid_for_rescaling(val)
if self.alpha == -1:
return self.minimum * np.exp(val * np.log(self.maximum / self.minimum))
else:
......@@ -206,7 +204,6 @@ class Uniform(Prior):
=======
Union[float, array_like]: Rescaled probability
"""
self.test_valid_for_rescaling(val)
return self.minimum + val * (self.maximum - self.minimum)
def prob(self, val):
......@@ -314,7 +311,6 @@ class SymmetricLogUniform(Prior):
=======
Union[float, array_like]: Rescaled probability
"""
self.test_valid_for_rescaling(val)
if isinstance(val, (float, int)):
if val < 0.5:
return -self.maximum * np.exp(-2 * val * np.log(self.maximum / self.minimum))
......@@ -401,7 +397,6 @@ class Cosine(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
norm = 1 / (np.sin(self.maximum) - np.sin(self.minimum))
return np.arcsin(val / norm + np.sin(self.minimum))
......@@ -456,7 +451,6 @@ class Sine(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
norm = 1 / (np.cos(self.minimum) - np.cos(self.maximum))
return np.arccos(np.cos(self.minimum) - val / norm)
......@@ -515,7 +509,6 @@ class Gaussian(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
return self.mu + erfinv(2 * val - 1) * 2 ** 0.5 * self.sigma
def prob(self, val):
......@@ -602,7 +595,6 @@ class TruncatedGaussian(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
return erfinv(2 * val * self.normalisation + erf(
(self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) * 2 ** 0.5 * self.sigma + self.mu
......@@ -695,7 +687,6 @@ class LogNormal(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
return np.exp(self.mu + np.sqrt(2 * self.sigma ** 2) * erfinv(2 * val - 1))
def prob(self, val):
......@@ -790,7 +781,6 @@ class Exponential(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
return -self.mu * log1p(-val)
def prob(self, val):
......@@ -887,7 +877,6 @@ class StudentT(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
if isinstance(val, (float, int)):
if val == 0:
rescaled = -np.inf
......@@ -977,7 +966,6 @@ class Beta(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
return btdtri(self.alpha, self.beta, val) * (self.maximum - self.minimum) + self.minimum
def prob(self, val):
......@@ -1070,7 +1058,6 @@ class Logistic(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
if isinstance(val, (float, int)):
if val == 0:
rescaled = -np.inf
......@@ -1151,7 +1138,6 @@ class Cauchy(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
rescaled = self.alpha + self.beta * np.tan(np.pi * (val - 0.5))
if isinstance(val, (float, int)):
if val == 1:
......@@ -1233,7 +1219,6 @@ class Gamma(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
self.test_valid_for_rescaling(val)
return gammaincinv(self.k, val) * self.theta
def prob(self, val):
......@@ -1385,8 +1370,6 @@ class FermiDirac(Prior):
.. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1
<https:arxiv.org/abs/1705.08978v1>`_, 2017.
"""
self.test_valid_for_rescaling(val)
inv = (-np.exp(-1. * self.r) + (1. + np.exp(self.r)) ** -val +
np.exp(-1. * self.r) * (1. + np.exp(self.r)) ** -val)
......@@ -1440,3 +1423,97 @@ class FermiDirac(Prior):
idx = val >= self.minimum
lnp[idx] = norm - np.logaddexp((val[idx] / self.sigma) - self.r, 0.)
return lnp
class Categorical(Prior):
def __init__(self, ncategories, name=None, latex_label=None,
unit=None, boundary="periodic"):
""" An equal-weighted Categorical prior
Parameters
==========
ncategories: int
The number of available categories. The prior mass support is then
the integers [0, ncategories - 1].
name: str
See superclass
latex_label: str
See superclass
unit: str
See superclass
"""
minimum = 0
# Small delta added to help with MCMC walking
maximum = ncategories - 1 + 1e-15
super(Categorical, self).__init__(
name=name, latex_label=latex_label, minimum=minimum,
maximum=maximum, unit=unit, boundary=boundary)
self.ncategories = ncategories
self.categories = np.arange(self.minimum, self.maximum)
self.p = 1 / self.ncategories
self.lnp = -np.log(self.ncategories)
def rescale(self, val):
"""
'Rescale' a sample from the unit line element to the categorical prior.
This maps to the inverse CDF. This has been analytically solved for this case.
Parameters
==========
val: Union[float, int, array_like]
Uniform probability
Returns
=======
Union[float, array_like]: Rescaled probability
"""
return np.floor(val * (1 + self.maximum))
def prob(self, val):
"""Return the prior probability of val.
Parameters
==========
val: Union[float, int, array_like]
Returns
=======
float: Prior probability of val
"""
if isinstance(val, (float, int)):
if val in self.categories:
return self.p
else:
return 0
else:
val = np.atleast_1d(val)
probs = np.zeros_like(val, dtype=np.float64)
idxs = np.isin(val, self.categories)
probs[idxs] = self.p
return probs
def ln_prob(self, val):
"""Return the logarithmic prior probability of val
Parameters
==========
val: Union[float, int, array_like]
Returns
=======
float:
"""
if isinstance(val, (float, int)):
if val in self.categories:
return self.lnp
else:
return -np.inf
else:
val = np.atleast_1d(val)
probs = -np.inf * np.ones_like(val, dtype=np.float64)
idxs = np.isin(val, self.categories)
probs[idxs] = self.lnp
return probs
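To illustrate the behaviour of the new `Categorical` prior, a short sketch follows; it assumes the class is exported alongside the other analytic priors in `bilby.core.prior`.

import numpy as np
from bilby.core.prior import Categorical

cat = Categorical(ncategories=3, name="model_index")

# Equal mass of 1/3 on each of the integers 0, 1, 2
print(cat.prob(1))                       # ~0.333
print(cat.prob(1.5))                     # 0.0, not a category
print(cat.ln_prob(np.array([0, 2, 5])))  # [-1.0986, -1.0986, -inf]

# rescale maps the unit interval onto the categories via the inverse CDF
print(cat.rescale(np.array([0.1, 0.4, 0.9])))  # [0. 1. 2.]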
......@@ -202,23 +202,6 @@ class Prior(object):
"""
return (val >= self.minimum) & (val <= self.maximum)
@staticmethod
def test_valid_for_rescaling(val):
"""Test if 0 < val < 1
Parameters
==========
val: Union[float, int, array_like]
Raises
=======
ValueError: If val is not between 0 and 1
"""
valarray = np.atleast_1d(val)
tests = (valarray < 0) + (valarray > 1)
if np.any(tests):
raise ValueError("Number to be rescaled should be in [0, 1]")
def __repr__(self):
"""Overrides the special method __repr__.
......
......@@ -518,6 +518,21 @@ class PriorDict(dict):
constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio)
return constrained_ln_prob
def cdf(self, sample):
"""Evaluate the cumulative distribution function at the provided points
Parameters
----------
sample: dict, pandas.DataFrame
Dictionary of the samples for which to calculate the CDF
Returns
-------
dict, pandas.DataFrame: Object of the same type as the input, containing the CDF values
"""
return sample.__class__({key: self[key].cdf(sample) for key, sample in sample.items()})
def rescale(self, keys, theta):
"""Rescale samples from unit cube to prior
......@@ -681,9 +696,7 @@ class ConditionalPriorDict(PriorDict):
float: Joint probability of all individual sample probabilities
"""
self._check_resolved()
for key, value in sample.items():
self[key].least_recently_sampled = value
self._prepare_evaluation(*zip(*sample.items()))
res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample]
prob = np.product(res, **kwargs)
return self.check_prob(sample, prob)
......@@ -703,13 +716,16 @@ class ConditionalPriorDict(PriorDict):
float: Joint log probability of all the individual sample probabilities
"""
self._check_resolved()
for key, value in sample.items():
self[key].least_recently_sampled = value
self._prepare_evaluation(*zip(*sample.items()))
res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample]
ln_prob = np.sum(res, axis=axis)
return self.check_ln_prob(sample, ln_prob)
def cdf(self, sample):
self._prepare_evaluation(*zip(*sample.items()))
res = {key: self[key].cdf(sample[key], **self.get_required_variables(key)) for key in sample}
return sample.__class__(res)
def rescale(self, keys, theta):
"""Rescale samples from unit cube to prior
......@@ -724,12 +740,14 @@ class ConditionalPriorDict(PriorDict):
=======
list: List of floats containing the rescaled sample
"""
keys = list(keys)
theta = list(theta)
self._check_resolved()
self._update_rescale_keys(keys)
result = dict()
for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes):
required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])}
result[key] = self[key].rescale(theta[index], **required_variables)
result[key] = self[key].rescale(theta[index], **self.get_required_variables(key))
self[key].least_recently_sampled = result[key]
return [result[key] for key in keys]
def _update_rescale_keys(self, keys):
......@@ -737,6 +755,11 @@ class ConditionalPriorDict(PriorDict):
self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters]
self._least_recently_rescaled_keys = keys
def _prepare_evaluation(self, keys, theta):
self._check_resolved()
for key, value in zip(keys, theta):
self[key].least_recently_sampled = value
def _check_resolved(self):
if not self._resolved:
raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
......
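A minimal sketch of the new `PriorDict.cdf` method in use; the priors are arbitrary examples. Each one-dimensional prior's CDF is evaluated at the supplied value and the result has the same type as the input.

from bilby.core.prior import PriorDict, Uniform, Gaussian

priors = PriorDict(dict(
    x=Uniform(0, 4, name="x"),
    y=Gaussian(mu=0, sigma=1, name="y"),
))

sample = {"x": 1.0, "y": 0.0}
print(priors.cdf(sample))  # {'x': 0.25, 'y': 0.5}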
......@@ -86,7 +86,6 @@ class Interped(Prior):
This maps to the inverse CDF. This is done using interpolation.
"""
self.test_valid_for_rescaling(val)
rescaled = self.inverse_cumulative_distribution(val)
if rescaled.shape == ():
rescaled = float(rescaled)
......
......@@ -172,7 +172,7 @@ class BaseJointPriorDist(object):
raise ValueError("Array is the wrong shape")
# check sample(s) is within bounds
outbounds = np.ones(samp.shape[0], dtype=np.bool)
outbounds = np.ones(samp.shape[0], dtype=bool)
for s, bound in zip(samp.T, self.bounds.values()):
outbounds = (s < bound[0]) | (s > bound[1])
if np.any(outbounds):
......@@ -630,7 +630,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
elif isinstance(self.__dict__[key], (np.ndarray, list)):
thisarr = np.asarray(self.__dict__[key])
otherarr = np.asarray(other.__dict__[key])
if thisarr.dtype == np.float and otherarr.dtype == np.float:
if thisarr.dtype == float and otherarr.dtype == float:
fin1 = np.isfinite(np.asarray(self.__dict__[key]))
fin2 = np.isfinite(np.asarray(other.__dict__[key]))
if not np.array_equal(fin1, fin2):
......@@ -710,7 +710,6 @@ class JointPrior(Prior):
A sample from the prior parameter.
"""
self.test_valid_for_rescaling(val)
self.dist.rescale_parameters[self.name] = val
if self.dist.filled_rescale():
......
......@@ -22,7 +22,7 @@ from .utils import (
recursively_load_dict_contents_from_group,
recursively_decode_bilby_json,
)
from .prior import Prior, PriorDict, DeltaFunction
from .prior import Prior, PriorDict, DeltaFunction, ConditionalDeltaFunction
def result_file_name(outdir, label, extension='json', gzip=False):
......@@ -1399,7 +1399,8 @@ class Result(object):
if priors is None:
return posterior
for key in priors:
if isinstance(priors[key], DeltaFunction):
if isinstance(priors[key], DeltaFunction) and \
not isinstance(priors[key], ConditionalDeltaFunction):
posterior[key] = priors[key].peak
elif isinstance(priors[key], float):
posterior[key] = priors[key]
......@@ -1775,7 +1776,7 @@ class ResultList(list):
# check which kind of sampler was used: MCMC or Nested Sampling
if result._nested_samples is not None:
posteriors, result = self._combine_nested_sampled_runs(result)
elif result.sampler in ["bilbymcmc"]:
elif result.sampler in ["bilby_mcmc", "bilbymcmc"]:
posteriors, result = self._combine_mcmc_sampled_runs(result)
else:
posteriors = [res.posterior for res in self]
......
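The `ResultList` change above lets results produced by the new sampler (labelled `bilby_mcmc`) be combined along the MCMC path. A hedged sketch of combining two runs; the file names are hypothetical.

from bilby.core.result import ResultList, read_in_result

# Two hypothetical result files from repeated runs of the same analysis
results = ResultList([
    read_in_result("outdir/run1_result.json"),
    read_in_result("outdir/run2_result.json"),
])
combined = results.combine()
print(len(combined.posterior))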
......@@ -22,10 +22,11 @@ from .pymultinest import Pymultinest
from .ultranest import Ultranest
from .fake_sampler import FakeSampler
from .dnest4 import DNest4
from bilby.bilby_mcmc import Bilby_MCMC
from . import proposal
IMPLEMENTED_SAMPLERS = {
'cpnest': Cpnest, 'dnest4': DNest4, 'dynamic_dynesty': DynamicDynesty,
'bilby_mcmc': Bilby_MCMC, 'cpnest': Cpnest, 'dnest4': DNest4, 'dynamic_dynesty': DynamicDynesty,
'dynesty': Dynesty, 'emcee': Emcee,'kombine': Kombine, 'nessai': Nessai,
'nestle': Nestle, 'ptemcee': Ptemcee, 'ptmcmcsampler': PTMCMCSampler,
'pymc3': Pymc3, 'pymultinest': Pymultinest, 'pypolychord': PyPolyChord,
......
......@@ -446,6 +446,8 @@ class Sampler(object):
==========
theta: array_like
Parameter values at which to evaluate likelihood
warning: bool
Whether or not to print a warning
Returns
=======
......@@ -453,14 +455,19 @@ class Sampler(object):
True if the likelihood and prior are finite, false otherwise
"""
bad_values = [np.inf, np.nan_to_num(np.inf), np.nan]
if abs(self.log_prior(theta)) in bad_values:
log_p = self.log_prior(theta)
log_l = self.log_likelihood(theta)
return \
self._check_bad_value(val=log_p, warning=warning, theta=theta, label='prior') and \
self._check_bad_value(val=log_l, warning=warning, theta=theta, label='likelihood')
@staticmethod
def _check_bad_value(val, warning, theta, label):
val = np.abs(val)
bad_values = [np.inf, np.nan_to_num(np.inf)]
if val in bad_values or np.isnan(val):
if warning:
logger.warning('Prior draw {} has inf prior'.format(theta))
return False
if abs(self.log_likelihood(theta)) in bad_values:
if warning:
logger.warning('Prior draw {} has inf likelihood'.format(theta))
logger.warning(f'Prior draw {theta} has inf {label}')
return False
return True
......@@ -661,6 +668,10 @@ class SamplerError(Error):
""" Base class for Error related to samplers in this module """
class ResumeError(Error):
""" Class for errors arising from resuming runs """
class SamplerNotInstalledError(SamplerError):
""" Base class for errors raised when a required sampler is not installed """
......
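The refactored `_check_bad_value` treats a log-prior or log-likelihood as unusable when its magnitude is infinite, equal to the `np.nan_to_num` stand-in for infinity, or NaN. A standalone sketch of the same test:

import numpy as np

def is_usable_log_value(val):
    # Mirrors _check_bad_value: reject inf, the nan_to_num stand-in for inf, and nan
    val = np.abs(val)
    bad_values = [np.inf, np.nan_to_num(np.inf)]
    return not (val in bad_values or np.isnan(val))

print(is_usable_log_value(-12.3))                  # True
print(is_usable_log_value(-np.inf))                # False
print(is_usable_log_value(np.nan))                 # False
print(is_usable_log_value(np.nan_to_num(np.inf)))  # False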
......@@ -13,6 +13,7 @@ from ..utils import (
check_directory_exists_and_if_not_mkdir,
reflect,
safe_file_dump,
latex_plot_format,
)
from .base_sampler import Sampler, NestedSampler
from ..result import rejection_sample
......@@ -101,7 +102,9 @@ class Dynesty(NestedSampler):
Method used to sample uniformly within the likelihood constraints,
conditioned on the provided bounds
walks: int
Number of walks taken if using `sample='rwalk'`, defaults to `ndim * 10`
Number of walks taken if using `sample='rwalk'`, defaults to `100`.
Note that the default `walks` in dynesty itself is 25, although using
`ndim * 10` can be a reasonable rule of thumb for new problems.
dlogz: float, (0.1)
Stopping criteria
verbose: Bool
......@@ -212,7 +215,7 @@ class Dynesty(NestedSampler):
def _verify_kwargs_against_default_kwargs(self):
from tqdm.auto import tqdm
if not self.kwargs['walks']:
self.kwargs['walks'] = self.ndim * 10
self.kwargs['walks'] = 100
if not self.kwargs['update_interval']:
self.kwargs['update_interval'] = int(0.6 * self.kwargs['nlive'])
if self.kwargs['print_func'] is None:
......@@ -324,7 +327,7 @@ class Dynesty(NestedSampler):
"Using the bilby-implemented rwalk sample method with ACT estimated walks")
dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby
dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
if self.kwargs.get("walks", 25) > self.kwargs.get("maxmcmc"):
if self.kwargs.get("walks") > self.kwargs.get("maxmcmc"):
raise DynestySetupError("You have maxmcmc > walks (minimum mcmc)")
if self.kwargs.get("nact", 5) < 1:
raise DynestySetupError("Unable to run with nact < 1")
......@@ -373,7 +376,6 @@ class Dynesty(NestedSampler):
dill.dump(out, file)
self._generate_result(out)
self.calc_likelihood_count()
self.result.sampling_time = self.sampling_time
if self.plot:
......@@ -383,7 +385,9 @@ class Dynesty(NestedSampler):
def _generate_result(self, out):
import dynesty
weights = np.exp(out['logwt'] - out['logz'][-1])
from scipy.special import logsumexp
logwts = out["logwt"]
weights = np.exp(logwts - out['logz'][-1])
nested_samples = DataFrame(
out.samples, columns=self.search_parameter_keys)
nested_samples['weights'] = weights
......@@ -396,6 +400,16 @@ class Dynesty(NestedSampler):
self.result.log_evidence = out.logz[-1]
self.result.log_evidence_err = out.logzerr[-1]
self.result.information_gain = out.information[-1]
self.result.num_likelihood_evaluations = getattr(self.sampler, 'ncall', 0)
logneff = logsumexp(logwts) * 2 - logsumexp(logwts * 2)
neffsamples = int(np.exp(logneff))
self.result.meta_data["run_statistics"] = dict(
nlikelihood=self.result.num_likelihood_evaluations,
neffsamples=neffsamples,
sampling_time_s=self.sampling_time.seconds,
ncores=self.kwargs.get("queue_size", 1)
)
def _run_nested_wrapper(self, kwargs):
""" Wrapper function to run_nested
......@@ -641,11 +655,7 @@ class Dynesty(NestedSampler):
plt.close('all')
try:
filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
fig, axs = plt.subplots(nrows=3, sharex=True)
for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
ax.plot(getattr(self.sampler, "saved_{}".format(name)), color="C0")
ax.set_ylabel(name)
axs[-1].set_xlabel("iteration")
fig, axs = dynesty_stats_plot(self.sampler)
fig.tight_layout()
plt.savefig(filename)
except (RuntimeError, ValueError) as e:
......@@ -701,16 +711,6 @@ class Dynesty(NestedSampler):
"""
return self.priors.rescale(self._search_parameter_keys, theta)
def calc_likelihood_count(self):
if self.likelihood_benchmark:
if hasattr(self, 'sampler'):
self.result.num_likelihood_evaluations = \
getattr(self.sampler, 'ncall', 0)
else:
self.result.num_likelihood_evaluations = 0
else:
return None
def sample_rwalk_bilby(args):
""" Modified bilby-implemented version of dynesty.sampling.sample_rwalk """
......@@ -728,8 +728,8 @@ def sample_rwalk_bilby(args):
# Setup.
n = len(u)
walks = kwargs.get('walks', 25) # minimum number of steps
maxmcmc = kwargs.get('maxmcmc', 2000) # Maximum number of steps
walks = kwargs.get('walks', 100) # minimum number of steps
maxmcmc = kwargs.get('maxmcmc', 5000) # Maximum number of steps
nact = kwargs.get('nact', 5) # Number of ACT
old_act = kwargs.get('old_act', walks)
......@@ -864,5 +864,72 @@ def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
return max(safety, int(Nmcmc_exact))
@latex_plot_format
def dynesty_stats_plot(sampler):
"""
Plot diagnostic statistics from a dynesty run
The plotted quantities per iteration are:
- nc: the number of likelihood calls
- scale: the scale applied to the MCMC steps
- lifetime: the number of iterations a point stays in the live set
There is also a histogram of the lifetime compared with the theoretical
distribution. To avoid edge effects, we discard the first 6 * nlive iterations.
Parameters
----------
sampler
The dynesty sampler object from which to read the run statistics
Returns
-------
fig: matplotlib.pyplot.figure.Figure
Figure handle for the new plot
axs: matplotlib.pyplot.axes.Axes
Axes handles for the new plot
"""
import matplotlib.pyplot as plt
from scipy.stats import geom, ks_1samp
fig, axs = plt.subplots(nrows=4, figsize=(8, 8))
for ax, name in zip(axs, ["nc", "scale"]):
ax.plot(getattr(sampler, "saved_{}".format(name)), color="blue")
ax.set_ylabel(name.title())
lifetimes = np.arange(len(sampler.saved_it)) - sampler.saved_it
axs[-2].set_ylabel("Lifetime")
nlive = sampler.nlive
burn = int(geom(p=1 / nlive).isf(1 / 2 / nlive))
if len(sampler.saved_it) > burn + sampler.nlive:
axs[-2].plot(np.arange(0, burn), lifetimes[:burn], color="grey")
axs[-2].plot(np.arange(burn, len(lifetimes) - nlive), lifetimes[burn: -nlive], color="blue")
axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
lifetimes = lifetimes[burn: -nlive]
ks_result = ks_1samp(lifetimes, geom(p=1 / nlive).cdf)
axs[-1].hist(
lifetimes,
bins=np.linspace(0, 6 * nlive, 60),
histtype="step",
density=True,
color="blue",
label=f"p value = {ks_result.pvalue:.3f}"
)
axs[-1].plot(
np.arange(1, 6 * nlive),
geom(p=1 / nlive).pmf(np.arange(1, 6 * nlive)),
color="red"
)
axs[-1].set_xlim(0, 6 * nlive)
axs[-1].legend()
axs[-1].set_yscale("log")
else:
axs[-2].plot(np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey")
axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
axs[-2].set_yscale("log")
axs[-2].set_xlabel("Iteration")
axs[-1].set_xlabel("Lifetime")
return fig, axs
class DynestySetupError(Exception):
pass
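The run statistics added to the dynesty metadata include an effective sample size computed from the nested-sampling importance weights in log space, i.e. the Kish estimate (sum w)^2 / sum(w^2). A standalone sketch with toy weights:

import numpy as np
from scipy.special import logsumexp

# Toy log importance weights; in the sampler these come from out["logwt"]
logwts = np.log(np.random.default_rng(0).dirichlet(np.ones(1000)))

# Kish effective sample size, evaluated in log space for numerical stability
logneff = 2 * logsumexp(logwts) - logsumexp(2 * logwts)
neffsamples = int(np.exp(logneff))
print(neffsamples)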
......@@ -92,13 +92,15 @@ class Ptemcee(MCMCSampler):
is not recommended for cases where tau is large.
ignore_keys_for_tau: str
A pattern used to ignore keys in estimating the autocorrelation time.
pos0: str, list ("prior")
pos0: str, list, np.ndarray
If a string, one of "prior" or "minimize". For "prior", the initial
positions of the sampler are drawn from the prior. If "minimize",
a scipy.optimize step is applied to all parameters a number of times.
The walkers are then initialized from the range of values obtained.
If a list, for the keys in the list the optimization step is applied,
otherwise the initial points are drawn from the prior.
otherwise the initial points are drawn from the prior. If a numpy array
the shape should be (ntemps, nwalkers, ndim).
niterations_per_check: int (5)
The number of iteration steps to take before checking ACT. This
effectively pre-thins the chains. Larger values reduce the per-eval
......@@ -363,6 +365,17 @@ class Ptemcee(MCMCSampler):
)
return pos0
def get_pos0_from_array(self):
if self.pos0.shape != (self.ntemps, self.nwalkers, self.ndim):
raise ValueError(
"Shape of starting array should be (ntemps, nwalkers, ndim). "
"In this case that is ({}, {}, {}), got {}".format(
self.ntemps, self.nwalkers, self.ndim, self.pos0.shape
)
)
else:
return self.pos0
def setup_sampler(self):
""" Either initialize the sampler or read in the resume file """
import ptemcee
......@@ -446,6 +459,8 @@ class Ptemcee(MCMCSampler):
return self.get_pos0_from_minimize()
elif isinstance(self.pos0, list):
return self.get_pos0_from_minimize(minimize_list=self.pos0)
elif isinstance(self.pos0, np.ndarray):
return self.get_pos0_from_array()
else:
raise SamplerError("pos0={} not implemented".format(self.pos0))
......
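The ptemcee change above allows `pos0` to be passed directly as an array of starting positions with shape (ntemps, nwalkers, ndim). A hedged sketch of building such an array from prior draws; the settings and priors are illustrative.

import numpy as np
from bilby.core.prior import PriorDict, Uniform

priors = PriorDict(dict(a=Uniform(0, 1, name="a"), b=Uniform(-1, 1, name="b")))
ntemps, nwalkers, ndim = 5, 100, len(priors)

# Independent prior draws for every temperature and walker
samples = priors.sample(ntemps * nwalkers)
pos0 = np.column_stack([samples[key] for key in priors]).reshape(ntemps, nwalkers, ndim)
print(pos0.shape)  # (5, 100, 2)

# pos0 can then be passed to the sampler, e.g.
# bilby.run_sampler(..., sampler="ptemcee", ntemps=ntemps, nwalkers=nwalkers, pos0=pos0)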