Skip to content
Snippets Groups Projects
Commit 1ea662a9 authored by Moritz's avatar Moritz
Browse files

Proof of concept for a generic implementation of correlated Priors

parent 2f8f54f2
No related branches found
No related tags found
1 merge request!332Resolve "Introduce conditional prior sets"
Pipeline #44056 failed
......@@ -336,7 +336,13 @@ class CorrelatedPriorDict(PriorDict):
float: Joint probability of all individual sample probabilities
"""
return np.product([self[key].prob(sample[key]) for key in sample], **kwargs)
ls = []
for key in sample:
method_kwargs = infer_args_from_method(self[key].prob)
method_kwargs.remove('val')
correlated_variables = {key: sample[key] for key in method_kwargs}
ls.append(self[key].prob(sample[key], **correlated_variables))
return np.product(ls, **kwargs)
def ln_prob(self, sample):
"""
......@@ -351,7 +357,13 @@ class CorrelatedPriorDict(PriorDict):
float: Joint log probability of all the individual sample probabilities
"""
return np.sum([self[key].ln_prob(sample[key]) for key in sample])
ls = []
for key in sample:
method_kwargs = infer_args_from_method(self[key].prob)
method_kwargs.remove('val')
correlated_variables = {key: sample[key] for key in method_kwargs}
ls.append(self[key].ln_prob(sample[key], **correlated_variables))
return np.sum(ls)
def rescale(self, keys, theta):
"""Rescale samples from unit cube to prior
......@@ -367,7 +379,13 @@ class CorrelatedPriorDict(PriorDict):
-------
list: List of floats containing the rescaled sample
"""
return [self[key].rescale(sample) for key, sample in zip(keys, theta)]
ls = []
for key in theta:
method_kwargs = infer_args_from_method(self[key].prob)
method_kwargs.remove('val')
correlated_variables = {key: theta for key in method_kwargs}
ls.append(self[key].rescale(sample, correlated_variables) for key, sample in zip(keys, theta))
return ls
def create_default_prior(name, default_priors_file=None):
......
......@@ -422,4 +422,32 @@ class CorrelatedSecondaryMassPrior(Uniform):
self.maximum = mass_1
res = super().sample(size)
self.maximum = maximum
return res
\ No newline at end of file
return res
def prob(self, val, mass_1=None):
if mass_1 is None:
return super().prob(val)
maximum = self.maximum
self.maximum = mass_1
res = super().prob(val)
self.maximum = maximum
return res
def ln_prob(self, val, mass_1=None):
if mass_1 is None:
return super().ln_prob(val)
maximum = self.maximum
self.maximum = mass_1
res = super().ln_prob(val)
self.maximum = maximum
return res
def rescale(self, val, mass_1=None):
Prior.test_valid_for_rescaling(val)
if mass_1 is None:
return super().rescale(val)
maximum = self.maximum
self.maximum = mass_1
res = super().rescale(val)
self.maximum = maximum
return res
import bilby
import bilby.gw.prior
mass_1 = bilby.core.prior.Uniform(5, 100)
......@@ -6,13 +5,15 @@ mass_2 = bilby.gw.prior.CorrelatedSecondaryMassPrior(minimum=5, maximum=100)
correlated_priors = bilby.core.prior.CorrelatedPriorDict(dictionary=dict(mass_1=mass_1, mass_2=mass_2))
samples = correlated_priors.sample(100)
samples = correlated_priors.sample(10)
primary_masses = samples['mass_1']
secondary_masses = samples['mass_2']
for i in range(len(primary_masses)):
if primary_masses[i] < secondary_masses[i]:
print('False')
break
else:
if primary_masses[i] > secondary_masses[i]:
print('True')
else:
print('False')
sample = dict(mass_1=25, mass_2=20)
print(correlated_priors.prob(sample))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment