diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 65fb6aa2d264e76e12928d038ca4a8bb50187fc8..09d7aadaf5d47d150540241c09d9afab2f7588b6 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -332,11 +332,10 @@ class PriorDict(OrderedDict): self.convert_floats_to_delta_functions() samples = dict() for key in keys: - if isinstance(self[key], Prior): - if isinstance(self[key], Constraint): - continue - else: - samples[key] = self[key].sample(size=size) + if isinstance(self[key], Constraint): + continue + elif isinstance(self[key], Prior): + samples[key] = self[key].sample(size=size) else: logger.debug('{} not a known prior.'.format(key)) return samples @@ -544,18 +543,25 @@ class ConditionalPriorDict(PriorDict): self.convert_floats_to_delta_functions() subset_dict = ConditionalPriorDict({key: self[key] for key in keys}) if not subset_dict._resolved: - raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") - res = dict() + raise IllegalConditionsException("The current set of priors contains unresolvable conditions.") + samples = dict() for key in subset_dict.sorted_keys: - if isinstance(self[key], Prior): - if isinstance(self[key], Constraint): - continue - else: - res[key] = self[key].sample( - size=size, **subset_dict.get_required_variables(key)) + if isinstance(self[key], Constraint): + continue + elif isinstance(self[key], Prior): + try: + samples[key] = subset_dict[key].sample(size=size, **subset_dict.get_required_variables(key)) + except ValueError: + # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) + # If that is the case, we sample each sample individually. + required_variables = subset_dict.get_required_variables(key) + samples[key] = np.zeros(size) + for i in range(size): + rvars = {key: value[i] for key, value in required_variables.items()} + samples[key][i] = subset_dict[key].sample(**rvars) else: logger.debug('{} not a known prior.'.format(key)) - return res + return samples def get_required_variables(self, key): """ Returns the required variables to sample a given conditional key. @@ -633,7 +639,7 @@ class ConditionalPriorDict(PriorDict): for key, index in zip(self._rescale_keys, self._rescale_indexes): required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])} result[key] = self[key].rescale(theta[index], **required_variables) - return list(result.values()) + return [result[key] for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: @@ -3570,7 +3576,8 @@ def conditional_prior_factory(prior_class): return super(ConditionalPrior, self).prob(val) def ln_prob(self, val, **required_variables): - return np.log(self.prob(val, **required_variables)) + self.update_conditions(**required_variables) + return super(ConditionalPrior, self).ln_prob(val) def update_conditions(self, **required_variables): """ diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index 0ac52c7ba1454d5c39949c03afc767842d6e9844..354548889afd2c63a8ba49f848582051963a5cb4 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -61,9 +61,7 @@ class Cpnest(NestedSampler): def __init__(self, names, priors): self.names = names self.priors = priors - self.bounds = [ - [self.priors[key].minimum, self.priors[key].maximum] - for key in self.names] + self._update_bounds() @staticmethod def log_likelihood(x, **kwargs): @@ -75,10 +73,17 @@ class Cpnest(NestedSampler): theta = [x[n] for n in self.search_parameter_keys] return self.log_prior(theta) + def _update_bounds(self): + self.bounds = [ + [self.priors[key].minimum, self.priors[key].maximum] + for key in self.names] + def new_point(self): """Draw a point from the prior""" + prior_samples = self.priors.sample() + self._update_bounds() point = LivePoint( - self.names, [self.priors[name].sample() + self.names, [prior_samples[name] for name in self.names]) return point diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 32c238ccb9e865342544bf47b402fa95fee46100..4bc4e0abe1a137e88bdc9f11dc75fe01a4728c83 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -4,7 +4,7 @@ import numpy as np from scipy.interpolate import InterpolatedUnivariateSpline from ..core.prior import (ConditionalPriorDict, PriorDict, Uniform, Prior, DeltaFunction, Gaussian, - Interped, Constraint, ConditionalUniform, conditional_prior_factory) + Interped, Constraint, conditional_prior_factory) from ..core.utils import infer_args_from_method, logger from .conversion import ( convert_to_lal_binary_black_hole_parameters, @@ -635,15 +635,6 @@ def secondary_mass_condition_function(reference_params, mass_1): return dict(minimum=reference_params['minimum'], maximum=mass_1) -class ConditionalSecondaryMassPrior(ConditionalUniform): - - def __init__(self, name=None, latex_label=None, unit=None, minimum=0, maximum=np.inf): - super(ConditionalSecondaryMassPrior, self).__init__(minimum=minimum, maximum=maximum, - name=name, latex_label=latex_label, - unit=unit, - condition_func=secondary_mass_condition_function) - - ConditionalCosmological = conditional_prior_factory(Cosmological) ConditionalUniformComovingVolume = conditional_prior_factory(UniformComovingVolume) ConditionalUniformSourceFrame = conditional_prior_factory(UniformSourceFrame) diff --git a/examples/core_examples/conditional_prior.py b/examples/core_examples/conditional_prior.py index 3e46799838949838bd2c4e3d0e4be3317ce8fb63..2475e479d160f1fe0088f9142b6fdec31919a9e1 100644 --- a/examples/core_examples/conditional_prior.py +++ b/examples/core_examples/conditional_prior.py @@ -3,31 +3,33 @@ import numpy as np # This tutorial demonstrates how we can sample a prior in the shape of a ball # Note that this will not end up sampling uniformly in that space, constraint priors are more suitable for that. -# This implementation will draw a value for the x-coordinate, and given that draw a value for the -# y-coordinate, and given that draw a value for the z-coordinate. +# This implementation will draw a value for the x-coordinate from p(x), and given that draw a value for the +# y-coordinate from p(y|x), and given that draw a value for the z-coordinate from p(z|x,y). # Only the x-coordinate will end up being uniform for this class ZeroLikelihood(bilby.core.likelihood.Likelihood): """ Flat likelihood. This always returns 0. - This way we can see if we correctly sampled uniformly in the prior space""" + This way our posterior distribution is exactly the prior distribution.""" def log_likelihood(self): return 0 def condition_func_y(reference_params, x): + """ Condition function for our p(y|x) prior.""" radius = 0.5 * (reference_params['maximum'] - reference_params['minimum']) y_max = np.sqrt(radius**2 - x**2) return dict(minimum=-y_max, maximum=y_max) def condition_func_z(reference_params, x, y): - """""" + """ Condition function for our p(z|x, y) prior.""" radius = 0.5 * (reference_params['maximum'] - reference_params['minimum']) z_max = np.sqrt(radius**2 - x**2 - y**2) return dict(minimum=-z_max, maximum=z_max) +# Set up the conditional priors and the flat likelihood priors = bilby.core.prior.ConditionalPriorDict() priors['x'] = bilby.core.prior.Uniform(minimum=-1, maximum=1, latex_label="$x$") priors['y'] = bilby.core.prior.ConditionalUniform(condition_func=condition_func_y, minimum=-1, @@ -36,6 +38,7 @@ priors['z'] = bilby.core.prior.ConditionalUniform(condition_func=condition_func_ maximum=1, latex_label="$z$") likelihood = ZeroLikelihood(parameters=dict(x=0, y=0, z=0)) +# Sample the prior distribution res = bilby.run_sampler(likelihood=likelihood, priors=priors, sampler='dynesty', npoints=5000, walks=100, label='conditional_prior', outdir='outdir', resume=False, clean=True) res.plot_corner() diff --git a/examples/gw_examples/injection_examples/conditional_prior.py b/examples/gw_examples/injection_examples/conditional_prior.py deleted file mode 100644 index 4774c9aca20e9dd2e954d3a0b02257429063d1f2..0000000000000000000000000000000000000000 --- a/examples/gw_examples/injection_examples/conditional_prior.py +++ /dev/null @@ -1,90 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np - -import bilby.gw.prior - - -def condition_function(reference_params, mass_1): - return dict(minimum=reference_params['minimum'], maximum=mass_1) - - -mass_1_min = 5 -mass_1_max = 100 - -mass_1 = bilby.core.prior.Uniform(minimum=mass_1_min, maximum=mass_1_max, name='mass_1', - latex_label='$m_1$', boundary='reflective') -mass_2 = bilby.core.prior.ConditionalUniform(minimum=mass_1_min, maximum=mass_1_max, name='mass_2', - latex_label='$m_2$', condition_func=condition_function, - boundary='reflective') - -conditional_dict = bilby.core.prior.ConditionalPriorDict(dictionary=dict(mass_1=mass_1, mass_2=mass_2)) - - -res = conditional_dict.sample(100000) - -plt.hist(res['mass_1'], bins='fd', alpha=0.6, density=True, label='Sampled') -plt.plot(np.linspace(2, 100, 200), conditional_dict['mass_1'].prob(np.linspace(2, 100, 200)), label='Uniform prior') -plt.xlabel('$m_1$') -plt.ylabel('$p(m_1)$') -plt.legend() -plt.tight_layout() -plt.show() -plt.clf() - - -plt.hist(res['mass_2'], bins='fd', alpha=0.6, density=True, label='Sampled') -plt.xlabel('$m_2$') -plt.ylabel('$p(m_2 | m_1)$') -plt.legend() -plt.tight_layout() -plt.show() -plt.clf() - - -duration = 4. -sampling_frequency = 2048. -outdir = 'outdir' -label = 'conditional_prior' -bilby.core.utils.setup_logger(outdir=outdir, label=label) - -np.random.seed(88170235) - -injection_parameters = dict( - mass_1=30., mass_2=30, a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0, - phi_12=1.7, phi_jl=0.3, luminosity_distance=1500., theta_jn=0.4, psi=2.659, - phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108) - -waveform_arguments = dict(waveform_approximant='IMRPhenomPv2', - reference_frequency=50., minimum_frequency=20.) - -waveform_generator = bilby.gw.WaveformGenerator( - duration=duration, sampling_frequency=sampling_frequency, - frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=waveform_arguments) - -ifos = bilby.gw.detector.InterferometerList(['H1', 'L1']) -ifos.set_strain_data_from_power_spectral_densities( - sampling_frequency=sampling_frequency, duration=duration, - start_time=injection_parameters['geocent_time'] - 3) -ifos.inject_signal(waveform_generator=waveform_generator, - parameters=injection_parameters) - -priors = bilby.core.prior.ConditionalPriorDict() -for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'psi', 'ra', - 'dec', 'geocent_time', 'phase', 'theta_jn', 'luminosity_distance']: - priors[key] = injection_parameters[key] -priors['mass_1'] = mass_1 -priors['mass_2'] = mass_2 - -# Initialise the likelihood by passing in the interferometer data (ifos) and -# the waveform generator -likelihood = bilby.gw.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=waveform_generator) - -# Run sampler. In this case we're going to use the `dynesty` sampler -result = bilby.run_sampler( - likelihood=likelihood, priors=priors, sampler='dynesty', npoints=100, - injection_parameters=injection_parameters, outdir=outdir, label=label, clean=True, resume=False) - -# Make a corner plot. -result.plot_corner() diff --git a/test/prior_test.py b/test/prior_test.py index f6e8fc966cef8f3aae98a5d03b347254de6f17d3..6d7b0ecae1e42494c104dcb34205435918969518 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -148,6 +148,7 @@ class TestPriorBoundary(unittest.TestCase): with self.assertRaises(ValueError): self.prior.boundary = 'else' + class TestPriorClasses(unittest.TestCase): def setUp(self): @@ -1021,8 +1022,10 @@ class TestConditionalPrior(unittest.TestCase): with mock.patch.object(self.prior, 'update_conditions') as m: self.prior.ln_prob(1, test_parameter_1=self.test_variable_1, test_parameter_2=self.test_variable_2) - m.assert_called_with(test_parameter_1=self.test_variable_1, - test_parameter_2=self.test_variable_2) + calls = [mock.call(test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2), + mock.call()] + m.assert_has_calls(calls) def test_reset_to_reference_parameters(self): self.prior.minimum = 10 @@ -1122,6 +1125,17 @@ class TestConditionalPriorDict(unittest.TestCase): with self.assertRaises(bilby.core.prior.IllegalConditionsException): self.conditional_priors.sample_subset(keys=['var_1']) + def test_sample_multiple(self): + def condition_func(reference_params, a): + return dict(minimum=reference_params['minimum'], + maximum=reference_params['maximum'], + alpha=reference_params['alpha'] * a) + priors = bilby.core.prior.ConditionalPriorDict() + priors['a'] = bilby.core.prior.Uniform(minimum=0, maximum=1) + priors['b'] = bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, minimum=1, maximum=2, + alpha=-2) + print(priors.sample(2)) + def test_rescale(self): def condition_func_1_rescale(reference_parameters, var_0):