From bfe8dffcadf0b39cda7eb9a82b638fb323324b19 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Tue, 25 Jun 2019 21:15:56 -0500 Subject: [PATCH] Add a zero-likelihood mode This adds a test-mode command like argument to run with zero likelihood. As such, this will sample from the priors alone. To use, run any standard bilby example with `--bilby-zero-likelihood-mode` in the command line. This has been intentionally obfuscated to avoid overloading the usual documentation. --- .gitlab-ci.yml | 1 + bilby/core/likelihood.py | 25 +++++++++ bilby/core/sampler/__init__.py | 5 ++ bilby/core/utils.py | 3 ++ setup.cfg | 1 + test/sample_from_the_prior_test.py | 83 ++++++++++++++++++++++++++++++ 6 files changed, 118 insertions(+) create mode 100644 test/sample_from_the_prior_test.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c22cc98c..1d38624c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -85,6 +85,7 @@ scheduled-python-3.7: # Run tests which are only done on schedule - pytest test/example_test.py - pytest test/gw_example_test.py + - pytest test/sample_from_the_prior_test.py pages: stage: deploy diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index f1286d5b..2f7933a3 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -62,6 +62,31 @@ class Likelihood(object): raise ValueError("The meta_data must be an instance of dict") +class ZeroLikelihood(Likelihood): + """ A special test-only class which already returns zero likelihood + + Parameters + ---------- + likelihood: bilby.core.likelihood.Likelihood + A likelihood object to mimic + + """ + + def __init__(self, likelihood): + Likelihood.__init__(self, dict.fromkeys(likelihood.parameters)) + self.parameters = likelihood.parameters + self._parent = likelihood + + def log_likelihood(self): + return 0 + + def noise_log_likelihood(self): + return 0 + + def __getattr__(self, name): + return getattr(self._parent, name) + + class Analytical1DLikelihood(Likelihood): """ A general class for 1D analytical functions. The model diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 424f1e10..37311298 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -133,6 +133,11 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', meta_data = dict() meta_data['likelihood'] = likelihood.meta_data + if command_line_args.bilby_zero_likelihood_mode: + from bilby.core.likelihood import ZeroLikelihood + likelihood = ZeroLikelihood(likelihood) + + if isinstance(sampler, Sampler): pass elif isinstance(sampler, str): diff --git a/bilby/core/utils.py b/bilby/core/utils.py index ee7af17b..5506a443 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -539,6 +539,9 @@ def set_up_command_line_arguments(): parser.add_argument("--bilby-test-mode", action="store_true", help=("Used for testing only: don't run full PE, but" " just check nothing breaks")) + parser.add_argument("--bilby-zero-likelihood-mode", action="store_true", + help=("Used for testing only: don't run full PE, but" + " just check nothing breaks")) args, unknown_args = parser.parse_known_args() if args.quiet: args.log_level = logging.WARNING diff --git a/setup.cfg b/setup.cfg index 5a877a55..6007f298 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,7 @@ addopts = --ignore test/other_test.py --ignore test/gw_example_test.py --ignore test/example_test.py + --ignore test/sample_from_the_prior_test.py [metadata] license_file = LICENSE.md diff --git a/test/sample_from_the_prior_test.py b/test/sample_from_the_prior_test.py new file mode 100644 index 00000000..04e5c6b2 --- /dev/null +++ b/test/sample_from_the_prior_test.py @@ -0,0 +1,83 @@ +from __future__ import absolute_import +import shutil +import os +import logging + +import unittest +import numpy as np +import bilby +from scipy.stats import ks_2samp, kstest + + +class Test(unittest.TestCase): + outdir = 'outdir_for_tests' + + @classmethod + def setUpClass(self): + if os.path.isdir(self.outdir): + try: + shutil.rmtree(self.outdir) + except OSError: + logging.warning( + "{} not removed prior to tests".format(self.outdir)) + + @classmethod + def tearDownClass(self): + if os.path.isdir(self.outdir): + try: + shutil.rmtree(self.outdir) + except OSError: + logging.warning( + "{} not removed prior to tests".format(self.outdir)) + + def test_fifteen_dimensional_cbc(self): + duration = 4. + sampling_frequency = 2048. + label = 'full_15_parameters' + np.random.seed(88170235) + + waveform_arguments = dict(waveform_approximant='IMRPhenomPv2', + reference_frequency=50., minimum_frequency=20.) + waveform_generator = bilby.gw.WaveformGenerator( + duration=duration, sampling_frequency=sampling_frequency, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, + parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, + waveform_arguments=waveform_arguments) + + ifos = bilby.gw.detector.InterferometerList(['H1', 'L1']) + ifos.set_strain_data_from_power_spectral_densities( + sampling_frequency=sampling_frequency, duration=duration, + start_time=0) + + priors = bilby.gw.prior.BBHPriorDict() + priors.pop('mass_1') + priors.pop('mass_2') + priors['chirp_mass'] = bilby.prior.Uniform( + name='chirp_mass', latex_label='$M$', minimum=10.0, maximum=100.0, + unit='$M_{\\odot}$') + priors['mass_ratio'] = bilby.prior.Uniform( + name='mass_ratio', latex_label='$q$', minimum=0.5, maximum=1.0) + priors['geocent_time'] = bilby.core.prior.Uniform( + minimum=-0.1, maximum=0.1) + + likelihood = bilby.gw.GravitationalWaveTransient( + interferometers=ifos, waveform_generator=waveform_generator, + priors=priors, distance_marginalization=False, + phase_marginalization=False, time_marginalization=False) + + likelihood = bilby.core.likelihood.ZeroLikelihood(likelihood) + + result = bilby.run_sampler( + likelihood=likelihood, priors=priors, sampler='dynesty', + npoints=1000, walks=100, outdir=self.outdir, label=label) + pvalues = [ks_2samp(result.priors[key].sample(10000), + result.posterior[key].values).pvalue + for key in priors.keys()] + print("P values per parameter") + for key, p in zip(priors.keys(), pvalues): + print(key, p) + self.assertGreater(kstest(pvalues, "uniform").pvalue, 0.01) + + +if __name__ == '__main__': + unittest.main() -- GitLab