Skip to content
Snippets Groups Projects
Commit 375e6e90 authored by Paul Lasky's avatar Paul Lasky
Browse files

Merge branch 'add_uniform_in_log' into 'master'

Add uniform in log prior

See merge request Monash/tupak!42
parents c8d84daf 1d6a5ef3
No related branches found
No related tags found
1 merge request!42Add uniform in log prior
Pipeline #
......@@ -24,7 +24,7 @@ def sine_gaussian(f, A, f0, tau, phi0, geocent_time, ra, dec, psi):
# We now define some parameters that we will inject and then a waveform generator
injection_parameters = dict(A=1e-21, f0=10, tau=1, phi0=0, geocent_time=0,
injection_parameters = dict(A=1e-23, f0=100, tau=1, phi0=0, geocent_time=0,
ra=0, dec=0, psi=0)
waveform_generator = tupak.waveform_generator.WaveformGenerator(time_duration=time_duration,
sampling_frequency=sampling_frequency,
......@@ -42,15 +42,14 @@ IFOs = [tupak.detector.get_interferometer_with_fake_noise_and_injection(
# Here we define the priors for the search. We use the injection parameters
# except for the amplitude, f0, and geocent_time
prior = injection_parameters.copy()
prior['A'] = tupak.prior.Uniform(0, 1e-20, 'A')
prior['f0'] = tupak.prior.Uniform(0, 20, 'f')
prior['geocent_time'] = tupak.prior.Uniform(-0.01, 0.01, 'geocent_time')
prior['A'] = tupak.prior.PowerLaw(alpha=-1, minimum=1e-25, maximum=1e-21, name='A')
prior['f0'] = tupak.prior.Uniform(90, 110, 'f')
likelihood = tupak.likelihood.GravitationalWaveTransient(IFOs, waveform_generator)
result = tupak.sampler.run_sampler(
likelihood, prior, sampler='dynesty', outdir=outdir, label=label,
resume=False, sample='unif')
resume=False, sample='unif', injection_parameters=injection_parameters)
result.plot_walks()
result.plot_distributions()
result.plot_corner()
......
Source diff could not be displayed: it is too large. Options to address this: view the blob.
......@@ -92,8 +92,9 @@ class TestPriorClasses(unittest.TestCase):
tupak.prior.DeltaFunction(name='test', peak=1),
tupak.prior.Gaussian(name='test', mu=0, sigma=1),
tupak.prior.PowerLaw(name='test', alpha=0, minimum=0, maximum=1),
tupak.prior.PowerLaw(name='test', alpha=-1, minimum=1, maximum=1e2),
tupak.prior.PowerLaw(name='test', alpha=2, minimum=1, maximum=1e2),
tupak.prior.Uniform(name='test', minimum=0, maximum=1),
tupak.prior.LogUniform(name='test', minimum=5e0, maximum=1e2),
tupak.prior.UniformComovingVolume(name='test', minimum=2e2, maximum=5e3),
tupak.prior.Sine(name='test'),
tupak.prior.Cosine(name='test'),
......@@ -102,26 +103,42 @@ class TestPriorClasses(unittest.TestCase):
tupak.prior.TruncatedGaussian(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1)
]
def test_rescaling(self):
def test_minimum_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
"""Test the the rescaling works as expected."""
minimum_sample = prior.rescale(0)
self.assertAlmostEqual(minimum_sample, prior.minimum)
def test_maximum_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
maximum_sample = prior.rescale(1)
self.assertAlmostEqual(maximum_sample, prior.maximum)
def test_many_sample_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
many_samples = prior.rescale(np.random.uniform(0, 1, 1000))
self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
def test_out_of_bounds_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
self.assertRaises(ValueError, lambda: prior.rescale(-1))
def test_sampling(self):
def test_sampling_single(self):
"""Test that sampling from the prior always returns values within its domain."""
for prior in self.priors:
single_sample = prior.sample()
self.assertTrue((single_sample >= prior.minimum) & (single_sample <= prior.maximum))
def test_sampling_many(self):
"""Test that sampling from the prior always returns values within its domain."""
for prior in self.priors:
many_samples = prior.sample(1000)
self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
def test_prob(self):
def test_probability_above_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
......@@ -130,15 +147,36 @@ class TestPriorClasses(unittest.TestCase):
if prior.maximum != np.inf:
outside_domain = np.linspace(prior.maximum + 1, prior.maximum + 1e4, 1000)
self.assertTrue(all(prior.prob(outside_domain) == 0))
def test_probability_below_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
if prior.minimum != -np.inf:
outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000)
self.assertTrue(all(prior.prob(outside_domain) == 0))
def test_probability_in_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
if prior.minimum == -np.inf:
prior.minimum = -1e5
if prior.maximum == np.inf:
prior.maximum = 1e5
domain = np.linspace(prior.minimum, prior.maximum, 1000)
self.assertTrue(all(prior.prob(domain) >= 0))
def test_probability_surrounding_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000)
prior.prob(surround_domain)
......@@ -151,6 +189,7 @@ class TestPriorClasses(unittest.TestCase):
domain = np.linspace(-1e2, 1e2, 1000)
else:
domain = np.linspace(prior.minimum, prior.maximum, 1000)
print(prior.minimum, prior.maximum)
self.assertAlmostEqual(np.trapz(prior.prob(domain), domain), 1, 3)
......@@ -189,6 +228,5 @@ class TestFillPrior(unittest.TestCase):
self.assertIsInstance(self.priors['ra'], tupak.prior.Uniform)
if __name__ == '__main__':
unittest.main()
......@@ -123,23 +123,6 @@ class Prior(object):
return label
class Uniform(Prior):
"""Uniform prior"""
def __init__(self, minimum, maximum, name=None, latex_label=None):
Prior.__init__(self, name, latex_label, minimum, maximum)
self.support = maximum - minimum
def rescale(self, val):
Prior.test_valid_for_rescaling(val)
return self.minimum + val * self.support
def prob(self, val):
"""Return the prior probability of val"""
in_prior = (val >= self.minimum) & (val <= self.maximum)
return 1 / self.support * in_prior
class DeltaFunction(Prior):
"""Dirac delta function prior, this always returns peak."""
......@@ -191,6 +174,24 @@ class PowerLaw(Prior):
- self.minimum ** (1 + self.alpha))) * in_prior
class Uniform(PowerLaw):
"""Uniform prior"""
def __init__(self, minimum, maximum, name=None, latex_label=None):
Prior.__init__(self, name, latex_label, minimum, maximum)
self.alpha = 0
class LogUniform(PowerLaw):
"""Uniform prior"""
def __init__(self, minimum, maximum, name=None, latex_label=None):
Prior.__init__(self, name, latex_label, minimum, maximum)
self.alpha = -1
if self.minimum<=0:
logging.warning('You specified a uniform-in-log prior with minimum={}'.format(self.minimum))
class Cosine(Prior):
def __init__(self, name=None, latex_label=None, minimum=-np.pi / 2, maximum=np.pi / 2):
......
......@@ -466,7 +466,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
result.log_bayes_factor = result.logz - result.noise_logz
if injection_parameters is not None:
result.injection_parameters = injection_parameters
tupak.conversion.generate_all_bbh_parameters(result.injection_parameters)
if conversion_function is not None:
conversion_function(result.injection_parameters)
result.fixed_parameter_keys = [key for key in priors if isinstance(key, prior.DeltaFunction)]
# result.prior = prior # Removed as this breaks the saving of the data
result.samples_to_data_frame(likelihood=likelihood, priors=priors, conversion_function=conversion_function)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment