diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index efc8d580cf9567d876e5f2de198e333635efc054..10feec20783621b55cd3a5928b0405a2e601dce3 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -19,6 +19,7 @@ class Likelihood(object): """ self.parameters = parameters self._meta_data = None + self._marginalized_parameters = [] def __repr__(self): return self.__class__.__name__ + '(parameters={})'.format(self.parameters) @@ -61,6 +62,10 @@ class Likelihood(object): else: raise ValueError("The meta_data must be an instance of dict") + @property + def marginalized_parameters(self): + return self._marginalized_parameters + class ZeroLikelihood(Likelihood): """ A special test-only class which already returns zero likelihood diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 37311298ee32d65b8228938d20584f322e9f3ad9..0bc59adff77bed66f25bfd426c590e64c28a6ebb 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -4,9 +4,9 @@ import datetime from collections import OrderedDict from ..utils import command_line_args, logger -from ..prior import PriorDict +from ..prior import PriorDict, DeltaFunction -from .base_sampler import Sampler +from .base_sampler import Sampler, SamplingMarginalisedParameterError from .cpnest import Cpnest from .dynesty import Dynesty from .dynamic_dynesty import DynamicDynesty @@ -119,6 +119,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', if priors is None: priors = dict() + _check_marginalized_parameters_not_sampled(likelihood, priors) + if type(priors) in [dict, OrderedDict]: priors = PriorDict(priors) elif isinstance(priors, PriorDict): @@ -137,7 +139,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', from bilby.core.likelihood import ZeroLikelihood likelihood = ZeroLikelihood(likelihood) - if isinstance(sampler, Sampler): pass elif isinstance(sampler, str): @@ -210,3 +211,12 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', result.plot_corner() logger.info("Summary of results:\n{}".format(result)) return result + + +def _check_marginalized_parameters_not_sampled(likelihood, priors): + for key in likelihood.marginalized_parameters: + if key in priors: + if not isinstance(priors[key], (float, DeltaFunction)): + raise SamplingMarginalisedParameterError( + "Likelihood is {} marginalized but you are trying to sample in {}. " + .format(key, key)) diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index fe90758ef94da23dc47baa86326a7e37112485d9..ab6d62bd4206caf12db8fbeda6b344d4ada02e0f 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -644,3 +644,7 @@ class SamplerNotInstalledError(SamplerError): class IllegalSamplingSetError(Error): """ Class for illegal sets of sampling parameters """ + + +class SamplingMarginalisedParameterError(IllegalSamplingSetError): + """ Class for errors that occur when sampling over marginalized parameters """ diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index 3741f536cb4443533385c03be726e9bc875a31ba..9879a718729d773bbe3f1830169808da4f579cf7 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -118,6 +118,7 @@ class GravitationalWaveTransient(likelihood.Likelihood): priors['time_jitter'] = Uniform( minimum=- self._delta_tc / 2, maximum=self._delta_tc / 2, boundary='periodic') + self._marginalized_parameters.append('geocent_time') elif self.jitter_time: logger.debug( "Time jittering requested with non-time-marginalised " @@ -129,6 +130,7 @@ class GravitationalWaveTransient(likelihood.Likelihood): self._bessel_function_interped = None self._setup_phase_marginalization() priors['phase'] = float(0) + self._marginalized_parameters.append('phase') if self.distance_marginalization: self._lookup_table_filename = None @@ -142,6 +144,7 @@ class GravitationalWaveTransient(likelihood.Likelihood): self._setup_distance_marginalization( distance_marginalization_lookup_table) priors['luminosity_distance'] = float(self._ref_dist) + self._marginalized_parameters.append('luminosity_distance') def __repr__(self): return self.__class__.__name__ + '(interferometers={},\n\twaveform_generator={},\n\ttime_marginalization={}, ' \ diff --git a/test/gw_likelihood_test.py b/test/gw_likelihood_test.py index 2317baaee6baf5919d5d157fb8b9c00583653c33..25171e681e324a96ce4b44da8dda9385a0c5f516 100644 --- a/test/gw_likelihood_test.py +++ b/test/gw_likelihood_test.py @@ -327,6 +327,36 @@ class TestMarginalizedLikelihood(unittest.TestCase): self.assertTrue(same) self.prior['phase'] = temp + def test_run_sampler_flags_if_marginalized_phase_is_sampled(self): + like = bilby.gw.likelihood.GravitationalWaveTransient( + interferometers=self.interferometers, + waveform_generator=self.waveform_generator, priors=self.prior, + phase_marginalization=True + ) + new_prior = self.prior.copy() + new_prior['phase'] = bilby.prior.Uniform(minimum=0, maximum=2*np.pi) + for key, param in dict( + mass_1=31., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.0, tilt_2=0.0, + phi_12=1.7, phi_jl=0.3, theta_jn=0.4, psi=2.659, ra=1.375, dec=-1.2108).items(): + new_prior[key] = param + with self.assertRaises(bilby.core.sampler.SamplingMarginalisedParameterError): + bilby.run_sampler(like, new_prior) + + def test_run_sampler_flags_if_marginalized_time_is_sampled(self): + like = bilby.gw.likelihood.GravitationalWaveTransient( + interferometers=self.interferometers, + waveform_generator=self.waveform_generator, priors=self.prior, + time_marginalization=True + ) + new_prior = self.prior.copy() + new_prior['geocent_time'] = bilby.prior.Uniform(minimum=0, maximum=1) + for key, param in dict( + mass_1=31., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.0, tilt_2=0.0, + phi_12=1.7, phi_jl=0.3, theta_jn=0.4, psi=2.659, ra=1.375, dec=-1.2108).items(): + new_prior[key] = param + with self.assertRaises(bilby.core.sampler.SamplingMarginalisedParameterError): + bilby.run_sampler(like, new_prior) + class TestPhaseMarginalization(unittest.TestCase):