Commit c7b36e76 authored by Colm Talbot, committed by Gregory Ashton

Add Constraint prior

parent 9dacda96
@@ -16,7 +16,8 @@ from .utils import logger, infer_args_from_method, check_directory_exists_and_if

 class PriorDict(OrderedDict):

-    def __init__(self, dictionary=None, filename=None):
+    def __init__(self, dictionary=None, filename=None,
+                 conversion_function=None):
         """ A set of priors

         Parameters
@@ -25,6 +26,9 @@ class PriorDict(OrderedDict):
             If given, a dictionary to generate the prior set.
         filename: str, None
             If given, a file containing the prior to generate the prior set.
+        conversion_function: func
+            Function to convert between sampled parameters and constraints.
+            Default is no conversion.
         """
         OrderedDict.__init__(self)
         if isinstance(dictionary, dict):
@@ -40,6 +44,35 @@ class PriorDict(OrderedDict):
         self.convert_floats_to_delta_functions()
+        if conversion_function is not None:
+            self.conversion_function = conversion_function
+        else:
+            self.conversion_function = self.default_conversion_function
+
+    def evaluate_constraints(self, sample):
+        out_sample = self.conversion_function(sample)
+        prob = 1
+        for key in self:
+            if isinstance(self[key], Constraint) and key in out_sample:
+                prob *= self[key].prob(out_sample[key])
+        return prob
+
+    def default_conversion_function(self, sample):
+        """
+        Placeholder parameter conversion function.
+
+        Parameters
+        ----------
+        sample: dict
+            Dictionary to convert
+
+        Returns
+        -------
+        sample: dict
+            Same as input
+        """
+        return sample
+
     def to_file(self, outdir, label):
         """ Write the prior distribution to file.
@@ -168,7 +201,7 @@ class PriorDict(OrderedDict):
         -------
         dict: Dictionary of the samples
         """
-        return self.sample_subset(keys=self.keys(), size=size)
+        return self.sample_subset_constrained(keys=list(self.keys()), size=size)

     def sample_subset(self, keys=iter([]), size=None):
         """Draw samples from the prior set for parameters which are not a DeltaFunction
@@ -188,11 +221,35 @@ class PriorDict(OrderedDict):
         samples = dict()
         for key in keys:
             if isinstance(self[key], Prior):
-                samples[key] = self[key].sample(size=size)
+                if isinstance(self[key], Constraint):
+                    continue
+                else:
+                    samples[key] = self[key].sample(size=size)
             else:
                 logger.debug('{} not a known prior.'.format(key))
         return samples

+    def sample_subset_constrained(self, keys=iter([]), size=None):
+        if size is None or size == 1:
+            while True:
+                sample = self.sample_subset(keys=keys, size=size)
+                if self.evaluate_constraints(sample):
+                    return sample
+        else:
+            needed = np.prod(size)
+            all_samples = {key: np.array([]) for key in keys}
+            _first_key = list(all_samples.keys())[0]
+            while len(all_samples[_first_key]) <= needed:
+                samples = self.sample_subset(keys=keys, size=needed)
+                keep = np.array(self.evaluate_constraints(samples), dtype=bool)
+                for key in samples:
+                    all_samples[key] = np.hstack(
+                        [all_samples[key], samples[key][keep].flatten()])
+            all_samples = {key: np.reshape(all_samples[key][:needed], size)
+                           for key in all_samples
+                           if not isinstance(self[key], Constraint)}
+            return all_samples
+
     def prob(self, sample, **kwargs):
         """
@@ -208,7 +265,14 @@ class PriorDict(OrderedDict):
         float: Joint probability of all individual sample probabilities
         """
-        return np.product([self[key].prob(sample[key]) for key in sample], **kwargs)
+        prob = np.product([self[key].prob(sample[key])
+                           for key in sample], **kwargs)
+        if prob == 0:
+            return 0
+        elif self.evaluate_constraints(sample):
+            return prob
+        else:
+            return 0

     def ln_prob(self, sample, axis=None):
         """
@@ -226,8 +290,14 @@ class PriorDict(OrderedDict):
             Joint log probability of all the individual sample probabilities
         """
-        return np.sum([self[key].ln_prob(sample[key]) for key in sample],
-                      axis=axis)
+        ln_prob = np.sum([self[key].ln_prob(sample[key])
+                          for key in sample], axis=axis)
+        if np.isinf(ln_prob):
+            return ln_prob
+        elif self.evaluate_constraints(sample):
+            return ln_prob
+        else:
+            return -np.inf

     def rescale(self, keys, theta):
         """Rescale samples from unit cube to prior
@@ -259,6 +329,8 @@ class PriorDict(OrderedDict):
         """
         redundant = False
         for key in self:
+            if isinstance(self[key], Constraint):
+                continue
             temp = self.copy()
             del temp[key]
             if temp.test_redundancy(key, disable_logging=True):
@@ -490,7 +562,7 @@ class Prior(object):
         bool: Whether it's fixed or not!
         """
-        return isinstance(self, DeltaFunction)
+        return isinstance(self, (Constraint, DeltaFunction))

     @property
     def latex_label(self):
@@ -553,6 +625,20 @@ class Prior(object):
         return label


+class Constraint(Prior):
+
+    def __init__(self, minimum, maximum, name=None, latex_label=None,
+                 unit=None):
+        Prior.__init__(self, minimum=minimum, maximum=maximum, name=name,
+                       latex_label=latex_label, unit=unit)
+
+    def prob(self, val):
+        return (val > self.minimum) & (val < self.maximum)
+
+    def ln_prob(self, val):
+        return np.log((val > self.minimum) & (val < self.maximum))
+
+
 class DeltaFunction(Prior):

     def __init__(self, peak, name=None, latex_label=None, unit=None):
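
Constraint is essentially an indicator function on the open interval (minimum, maximum): prob returns a boolean (acting as probability 0 or 1) and works element-wise on arrays. A quick sketch:

    import numpy as np
    from bilby.core.prior import Constraint

    c = Constraint(minimum=0, maximum=1, name='z')
    c.prob(0.5)                       # True, i.e. probability 1
    c.prob(np.array([-1., 0.5, 2.]))  # array([False,  True, False])
    c.ln_prob(2.)                     # -inf (np.log of False, i.e. log(0))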
......
@@ -91,7 +91,8 @@ def read_in_result(filename=None, outdir=None, label=None, extension='json', gzi

 class Result(object):
     def __init__(self, label='no_label', outdir='.', sampler=None,
                  search_parameter_keys=None, fixed_parameter_keys=None,
-                 priors=None, sampler_kwargs=None, injection_parameters=None,
+                 constraint_parameter_keys=None, priors=None,
+                 sampler_kwargs=None, injection_parameters=None,
                  meta_data=None, posterior=None, samples=None,
                  nested_samples=None, log_evidence=np.nan,
                  log_evidence_err=np.nan, log_noise_evidence=np.nan,
@@ -106,9 +107,10 @@ class Result(object):
         ----------
         label, outdir, sampler: str
             The label, output directory, and sampler used
-        search_parameter_keys, fixed_parameter_keys: list
-            Lists of the search and fixed parameter keys. Elemenents of the
-            list should be of type `str` and matchs the keys of the `prior`
+        search_parameter_keys, fixed_parameter_keys, constraint_parameter_keys: list
+            Lists of the search, constraint, and fixed parameter keys.
+            Elements of the list should be of type `str` and match the keys
+            of the `prior`
         priors: dict, bilby.core.prior.PriorDict
             A dictionary of the priors used in the run
         sampler_kwargs: dict
@@ -155,6 +157,7 @@ class Result(object):
         self.sampler = sampler
         self.search_parameter_keys = search_parameter_keys
         self.fixed_parameter_keys = fixed_parameter_keys
+        self.constraint_parameter_keys = constraint_parameter_keys
         self.parameter_labels = parameter_labels
         self.parameter_labels_with_unit = parameter_labels_with_unit
         self.priors = priors
@@ -384,7 +387,8 @@ class Result(object):
             'label', 'outdir', 'sampler', 'log_evidence', 'log_evidence_err',
             'log_noise_evidence', 'log_bayes_factor', 'priors', 'posterior',
             'injection_parameters', 'meta_data', 'search_parameter_keys',
-            'fixed_parameter_keys', 'sampling_time', 'sampler_kwargs',
+            'fixed_parameter_keys', 'constraint_parameter_keys',
+            'sampling_time', 'sampler_kwargs',
             'log_likelihood_evaluations', 'log_prior_evaluations', 'samples',
             'nested_samples', 'walkers', 'nburn', 'parameter_labels',
             'parameter_labels_with_unit', 'version']
@@ -1004,8 +1008,12 @@ class Result(object):
         data_frame['log_likelihood'] = getattr(
             self, 'log_likelihood_evaluations', np.nan)
         if self.log_prior_evaluations is None:
-            data_frame['log_prior'] = self.priors.ln_prob(
-                data_frame[self.search_parameter_keys], axis=0)
+            ln_prior = list()
+            for ii in range(len(data_frame)):
+                ln_prior.append(
+                    self.priors.ln_prob(dict(
+                        data_frame[self.search_parameter_keys].iloc[ii])))
+            data_frame['log_prior'] = np.array(ln_prior)
         else:
             data_frame['log_prior'] = self.log_prior_evaluations
         if conversion_function is not None:
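
The switch from one vectorised ln_prob call to a per-row loop looks wasteful but is presumably forced by the constraint logic: ln_prob now branches on evaluate_constraints(sample), and for a dict of arrays that returns an array, which cannot be used as a truth value. A minimal reproduction of the failure the loop avoids:

    import numpy as np

    flags = np.array([True, False, True])  # per-row constraint results
    try:
        if flags:  # what a single vectorised call would end up doing
            pass
    except ValueError as err:
        print(err)  # "The truth value of an array with more than one
                    #  element is ambiguous..."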
......
@@ -5,7 +5,7 @@ import numpy as np
 from pandas import DataFrame

 from ..utils import logger, command_line_args
-from ..prior import Prior, PriorDict
+from ..prior import Prior, PriorDict, DeltaFunction, Constraint
 from ..result import Result, read_in_result
@@ -102,8 +102,9 @@ class Sampler(object):
         self.external_sampler_function = None
         self.plot = plot

-        self.__search_parameter_keys = []
-        self.__fixed_parameter_keys = []
+        self._search_parameter_keys = list()
+        self._fixed_parameter_keys = list()
+        self._constraint_keys = list()
         self._initialise_parameters()
         self._verify_parameters()
         self._verify_use_ratio()
@@ -118,28 +119,33 @@ class Sampler(object):
     @property
     def search_parameter_keys(self):
         """list: List of parameter keys that are being sampled"""
-        return self.__search_parameter_keys
+        return self._search_parameter_keys

     @property
     def fixed_parameter_keys(self):
         """list: List of parameter keys that are not being sampled"""
-        return self.__fixed_parameter_keys
+        return self._fixed_parameter_keys
+
+    @property
+    def constraint_parameter_keys(self):
+        """list: List of parameters providing prior constraints"""
+        return self._constraint_keys

     @property
     def ndim(self):
         """int: Number of dimensions of the search parameter space"""
-        return len(self.__search_parameter_keys)
+        return len(self._search_parameter_keys)

     @property
     def kwargs(self):
         """dict: Container for the kwargs. Has more sophisticated logic in subclasses"""
-        return self.__kwargs
+        return self._kwargs

     @kwargs.setter
     def kwargs(self, kwargs):
-        self.__kwargs = self.default_kwargs.copy()
+        self._kwargs = self.default_kwargs.copy()
         self._translate_kwargs(kwargs)
-        self.__kwargs.update(kwargs)
+        self._kwargs.update(kwargs)
         self._verify_kwargs_against_default_kwargs()

     def _translate_kwargs(self, kwargs):
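
A note on the pervasive `__name` to `_name` renames: double-underscore attributes are name-mangled per class (self.__kwargs inside Sampler is stored as _Sampler__kwargs), so subclasses such as the pymc3 sampler could not read or override them, which is also why the duplicated pymc3 overrides are deleted further down. A minimal demonstration:

    class Base:
        def __init__(self):
            self.__hidden = 1  # actually stored as _Base__hidden
            self._shared = 2   # stored under its own name

    class Child(Base):
        def peek(self):
            return self._shared  # works; self.__hidden would raise
                                 # AttributeError (_Child__hidden)

    print(vars(Base()))  # {'_Base__hidden': 1, '_shared': 2}
    print(Child().peek())  # 2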
@@ -179,17 +185,17 @@ class Sampler(object):
         for key in self.priors:
             if isinstance(self.priors[key], Prior) \
                     and self.priors[key].is_fixed is False:
-                self.__search_parameter_keys.append(key)
-            elif isinstance(self.priors[key], Prior) \
-                    and self.priors[key].is_fixed is True:
-                self.likelihood.parameters[key] = \
-                    self.priors[key].sample()
-                self.__fixed_parameter_keys.append(key)
+                self._search_parameter_keys.append(key)
+            elif isinstance(self.priors[key], Constraint):
+                self._constraint_keys.append(key)
+            elif isinstance(self.priors[key], DeltaFunction):
+                self.likelihood.parameters[key] = self.priors[key].sample()
+                self._fixed_parameter_keys.append(key)

         logger.info("Search parameters:")
-        for key in self.__search_parameter_keys:
+        for key in self._search_parameter_keys + self._constraint_keys:
             logger.info(' {} = {}'.format(key, self.priors[key]))
-        for key in self.__fixed_parameter_keys:
+        for key in self._fixed_parameter_keys:
             logger.info(' {} = {}'.format(key, self.priors[key].peak))

     def _initialise_result(self, result_class):
@@ -202,8 +208,9 @@ class Sampler(object):
         result_kwargs = dict(
             label=self.label, outdir=self.outdir,
             sampler=self.__class__.__name__.lower(),
-            search_parameter_keys=self.__search_parameter_keys,
-            fixed_parameter_keys=self.__fixed_parameter_keys,
+            search_parameter_keys=self._search_parameter_keys,
+            fixed_parameter_keys=self._fixed_parameter_keys,
+            constraint_parameter_keys=self._constraint_keys,
             priors=self.priors, meta_data=self.meta_data,
             injection_parameters=self.injection_parameters,
             sampler_kwargs=self.kwargs)
@@ -227,6 +234,8 @@ class Sampler(object):
             prior can't be sampled.
         """
         for key in self.priors:
+            if isinstance(self.priors[key], Constraint):
+                continue
             try:
                 self.likelihood.parameters[key] = self.priors[key].sample()
             except AttributeError as e:
@@ -248,7 +257,9 @@ class Sampler(object):
         self._check_if_priors_can_be_sampled()
         try:
             t1 = datetime.datetime.now()
-            self.likelihood.log_likelihood()
+            theta = [self.priors[key].sample()
+                     for key in self._search_parameter_keys]
+            self.log_likelihood(theta)
             self._log_likelihood_eval_time = (
                 datetime.datetime.now() - t1).total_seconds()
             if self._log_likelihood_eval_time == 0:
@@ -296,7 +307,7 @@ class Sampler(object):
         -------
         list: Properly rescaled sampled values
         """
-        return self.priors.rescale(self.__search_parameter_keys, theta)
+        return self.priors.rescale(self._search_parameter_keys, theta)

     def log_prior(self, theta):
         """
@@ -308,11 +319,12 @@ class Sampler(object):
         Returns
         -------
-        float: TODO: Fill in proper explanation of what this is.
+        float: Joint ln prior probability of theta

         """
-        return self.priors.ln_prob({
-            key: t for key, t in zip(self.__search_parameter_keys, theta)})
+        params = {
+            key: t for key, t in zip(self._search_parameter_keys, theta)}
+        return self.priors.ln_prob(params)

     def log_likelihood(self, theta):
         """
@@ -328,8 +340,9 @@ class Sampler(object):
             likelihood.parameter values
         """
-        for i, k in enumerate(self.__search_parameter_keys):
-            self.likelihood.parameters[k] = theta[i]
+        params = {
+            key: t for key, t in zip(self._search_parameter_keys, theta)}
+        self.likelihood.parameters.update(params)
         if self.use_ratio:
             return self.likelihood.log_likelihood_ratio()
         else:
@@ -347,7 +360,7 @@ class Sampler(object):
         """
         new_sample = self.priors.sample()
         draw = np.array(list(new_sample[key]
-                             for key in self.__search_parameter_keys))
+                             for key in self._search_parameter_keys))
         self.check_draw(draw)
         return draw
@@ -459,6 +472,26 @@ class NestedSampler(Sampler):
             idxs.append(idx[0])
         return unsorted_loglikelihoods[idxs]

+    def log_likelihood(self, theta):
+        """
+        Since some nested samplers don't call the log_prior method, evaluate
+        the prior constraint here.
+
+        Parameters
+        ----------
+        theta: array-like
+            Parameter values at which to evaluate likelihood
+
+        Returns
+        -------
+        float: log_likelihood
+        """
+        if self.priors.evaluate_constraints({
+                key: theta[ii] for ii, key in
+                enumerate(self.search_parameter_keys)}):
+            return Sampler.log_likelihood(self, theta)
+        else:
+            return np.nan_to_num(-np.inf)
+

 class MCMCSampler(Sampler):
     nwalkers_equiv_kwargs = ['nwalker', 'nwalkers', 'draws']
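
Returning np.nan_to_num(-np.inf) rather than a bare -np.inf matters here because some nested-sampler backends cannot handle infinite log-likelihoods; nan_to_num substitutes the largest-magnitude negative finite float:

    import numpy as np
    np.nan_to_num(-np.inf)  # -1.7976931348623157e+308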
......
@@ -4,9 +4,8 @@ from collections import OrderedDict
 import numpy as np

-from ..utils import derivatives, logger, infer_args_from_method
-from ..prior import Prior, DeltaFunction, Sine, Cosine, PowerLaw
-from ..result import Result
+from ..utils import derivatives, infer_args_from_method
+from ..prior import DeltaFunction, Sine, Cosine, PowerLaw
 from .base_sampler import Sampler, MCMCSampler
 from ..likelihood import GaussianLikelihood, PoissonLikelihood, ExponentialLikelihood, \
     StudentTLikelihood
@@ -67,8 +66,8 @@ class Pymc3(MCMCSampler):
         Sampler.__init__(self, likelihood, priors, outdir=outdir, label=label,
                          use_ratio=use_ratio, plot=plot,
                          skip_import_verification=skip_import_verification, **kwargs)
-        self.draws = self.__kwargs['draws']
-        self.chains = self.__kwargs['chains']
+        self.draws = self._kwargs['draws']
+        self.chains = self._kwargs['chains']

     @staticmethod
     def _import_external_sampler():
@@ -97,71 +96,6 @@ class Pymc3(MCMCSampler):
         """
         pass

-    def _initialise_parameters(self):
-        """
-        Change `_initialise_parameters()`, so that it does call the `sample`
-        method in the Prior class.
-        """
-        self.__search_parameter_keys = []
-        self.__fixed_parameter_keys = []
-
-        for key in self.priors:
-            if isinstance(self.priors[key], Prior) \
-                    and self.priors[key].is_fixed is False:
-                self.__search_parameter_keys.append(key)
-            elif isinstance(self.priors[key], Prior) \
-                    and self.priors[key].is_fixed is True:
-                self.__fixed_parameter_keys.append(key)
-
-        logger.info("Search parameters:")
-        for key in self.__search_parameter_keys:
-            logger.info(' {} = {}'.format(key, self.priors[key]))
-        for key in self.__fixed_parameter_keys:
-            logger.info(' {} = {}'.format(key, self.priors[key].peak))
-
-    def _initialise_result(self, result_class):
-        """
-        Initialise results within Pymc3 subclass.
-        """
-        result_kwargs = dict(
-            label=self.label, outdir=self.outdir,
-            sampler=self.__class__.__name__.lower(),
-            search_parameter_keys=self.__search_parameter_keys,
-            fixed_parameter_keys=self.__fixed_parameter_keys,
-            priors=self.priors, meta_data=self.meta_data,
-            injection_parameters=self.injection_parameters,
-            sampler_kwargs=self.kwargs)
-
-        if result_class is None:
-            result = Result(**result_kwargs)
-        elif issubclass(result_class, Result):
-            result = result_class(**result_kwargs)
-        else:
-            raise ValueError(
-                "Input result_class={} not understood".format(result_class))
-
-        return result
-
-    @property
-    def kwargs(self):
-        """ Ensures that proper keyword arguments are used for the Pymc3 sampler.
-
-        Returns
-        -------
-        dict: Keyword arguments used for the Nestle Sampler
-        """
-        return self.__kwargs
-
-    @kwargs.setter
-    def kwargs(self, kwargs):
-        self.__kwargs = self.default_kwargs.copy()
-        self.__kwargs.update(kwargs)
-        self._verify_kwargs_against_default_kwargs()
-
     def setup_prior_mapping(self):
         """
         Set the mapping between predefined bilby priors and the equivalent
@@ -393,8 +327,8 @@ class Pymc3(MCMCSampler):
         # set the step method
         pymc3, STEP_METHODS, floatX = self._import_external_sampler()
         step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
-        if 'step' in self.__kwargs:
-            self.step_method = self.__kwargs.pop('step')
+        if 'step' in self._kwargs:
+            self.step_method = self._kwargs.pop('step')

             # 'step' could be a dictionary of methods for different parameters,
             # so check for this
@@ -402,7 +336,7 @@ class Pymc3(MCMCSampler):
                 pass
             elif isinstance(self.step_method, (dict, OrderedDict)):
                 for key in self.step_method:
-                    if key not in self.__search_parameter_keys:
+                    if key not in self._search_parameter_keys:
                         raise ValueError("Setting a step method for an unknown parameter '{}'".format(key))
             else:
                 # check if using a compound step (a list of step
@@ -780,11 +714,11 @@ class Pymc3(MCMCSampler):
             pymc3.StudentT('likelihood', nu=self.likelihood.nu, mu=model, sd=self.likelihood.sigma,
                            observed=self.likelihood.y)
         elif isinstance(self.likelihood, (GravitationalWaveTransient, BasicGravitationalWaveTransient)):
-            # set theano Op - pass __search_parameter_keys, which only contains non-fixed variables
-            logl = LogLike(self.__search_parameter_keys, self.likelihood, self.pymc3_priors)
+            # set theano Op - pass _search_parameter_keys, which only contains non-fixed variables
+            logl = LogLike(self._search_parameter_keys, self.likelihood, self.pymc3_priors)

             parameters = OrderedDict()
-            for key in self.__search_parameter_keys:
+            for key in self._search_parameter_keys:
                 try:
                     parameters[key] = self.pymc3_priors[key]
                 except KeyError:
......
@@ -841,8 +841,8 @@ def generate_component_spins(sample):
     """
     output_sample = sample.copy()
     spin_conversion_parameters =\
-        ['theta_jn', 'phi_jl', 'tilt_1', 'tilt_2', 'phi_12', 'a_1', 'a_2', 'mass_1',
-         'mass_2', 'reference_frequency', 'phase']
+        ['theta_jn', 'phi_jl', 'tilt_1', 'tilt_2', 'phi_12', 'a_1', 'a_2',
+         'mass_1', 'mass_2', 'reference_frequency', 'phase']
     if all(key in output_sample.keys() for key in spin_conversion_parameters):
         output_sample['iota'], output_sample['spin_1x'],\
             output_sample['spin_1y'], output_sample['spin_1z'], \
......
@@ -4,8 +4,12 @@ import numpy as np
 from scipy.interpolate import UnivariateSpline

 from ..core.prior import (PriorDict, Uniform, Prior, DeltaFunction, Gaussian,
-                          Interped)
+                          Interped, Constraint)
 from ..core.utils import infer_args_from_method, logger
+from .conversion import (
+    convert_to_lal_binary_black_hole_parameters,
+    convert_to_lal_binary_neutron_star_parameters, generate_mass_parameters,
+    generate_tidal_parameters, fill_from_fixed_priors)
 from .cosmology import get_cosmology

 try:
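
For the GW prior sets this enables cuts on derived parameters. A sketch (the 0.125-1 mass-ratio bounds are illustrative, and mass_ratio is assumed to be produced by the default conversion function via generate_mass_parameters):

    import bilby

    priors = bilby.gw.prior.BBHPriorDict()
    # mass_1 and mass_2 are sampled; the default conversion function derives
    # mass_ratio, so the Constraint below rejects draws outside the cut.
    priors['mass_ratio'] = bilby.core.prior.Constraint(
        minimum=0.125, maximum=1, name='mass_ratio')
    samples = priors.sample(1000)  # every draw satisfies the mass-ratio cut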
@@ -193,7 +197,8 @@ class AlignedSpin(Interped):

 class BBHPriorDict(PriorDict):
-    def __init__(self, dictionary=None, filename=None, aligned_spin=False):
+    def __init__(self, dictionary=None, filename=None, aligned_spin=False,
+                 conversion_function=None):
         """ Initialises a Prior set for Binary Black holes

         Parameters
@@ -202,6 +207,10 @@ class BBHPriorDict(PriorDict):
             See superclass
         filename: str, optional
             See superclass
+        conversion_function: func
+            Function to convert between sampled parameters and constraints.
+            By default this generates many additional parameters, see
+            BBHPriorDict.default_conversion_function
         """
         basedir = os.path.join(os.path.dirname(__file__), 'prior_files')
         if dictionary is None and filename is None:
@@ -214,7 +223,36 @@ class BBHPriorDict(PriorDict):
         elif filename is not None:
             if not os.path.isfile(filename):
                 filename = os.path.join(os.path.dirname(__file__), 'prior_files', filename)
-        PriorDict.__init__(self, dictionary=dictionary, filename=filename)
+        PriorDict.__init__(self, dictionary=dictionary, filename=filename,
+                           conversion_function=conversion_function)
+
+    def default_conversion_function(self, sample):
+        """
+        Default parameter conversion function for BBH signals.
+
+        This generates:
+        - the parameters passed to source.lal_binary_black_hole
+        - all mass parameters
+
+        It does not generate:
+        - component spins
+        - source-frame parameters
+
+        Parameters