Skip to content
Snippets Groups Projects
Commit 4d25c431 authored by Moritz's avatar Moritz
Browse files

Improved sampling logic for conditional prior sets

parent de1948a0
No related branches found
No related tags found
1 merge request!332Resolve "Introduce conditional prior sets"
......@@ -494,66 +494,61 @@ class ConditionalPriorDict(PriorDict):
filename: str
See parent class
"""
self.conditioned_keys = []
self.unconditioned_keys = []
self._conditioned_keys = []
self._unconditioned_keys = []
self._sorted_keys = []
super(ConditionalPriorDict, self).__init__(dictionary=dictionary, filename=filename)
self._resolve_conditions()
def _resolve_conditions(self):
self.convert_floats_to_delta_functions()
self.conditioned_keys = []
self.unconditioned_keys = []
for key in self.keys():
if hasattr(self[key], 'condition_func'):
self.conditioned_keys.append(key)
else:
self.unconditioned_keys.append(key)
def sample_subset(self, keys=iter([]), size=None):
self.convert_floats_to_delta_functions()
conditional_keys = []
unconditional_keys = []
sampled_keys = []
for key in keys:
if key in self.unconditioned_keys:
unconditional_keys.append(key)
elif key in self.conditioned_keys:
conditional_keys.append(key)
else:
raise KeyError('Invalid key in keys argument')
samples = dict()
for key in unconditional_keys:
samples[key] = self[key].sample(size=size)
sampled_keys.append(key)
for i in range(0, 1000):
for key in conditional_keys:
if self._check_conditions_resolved(key, sampled_keys):
cvars = self._get_conditional_variables(key)
samples[key] = self[key].sample(size=size, **cvars)
conditional_keys.remove(key)
sampled_keys.append(key)
if not conditional_keys:
break
if i == 999:
raise Exception('This set contains unresolvable conditions')
return samples
def _get_conditional_variables(self, key):
conditional_variables = dict()
for k in self[key].required_variables:
conditional_variables[k] = self[k].least_recently_sampled
return conditional_variables
""" Resolves how variables depend on each other and automatically sorts them into the right order """
conditioned_keys_unsorted = [key for key in self.keys() if hasattr(self[key], 'condition_func')]
self._unconditioned_keys = [key for key in self.keys() if not hasattr(self[key], 'condition_func')]
self._conditioned_keys = []
checked_keys = self._unconditioned_keys.copy()
for i in range(len(self)):
for key in conditioned_keys_unsorted:
if self._check_conditions_resolved(key, checked_keys):
checked_keys.append(key)
self._conditioned_keys.append(key)
conditioned_keys_unsorted.remove(key)
self._sorted_keys = self._unconditioned_keys.copy()
self._sorted_keys.extend(self.conditional_keys)
if len(conditioned_keys_unsorted) != 0:
raise IllegalConditionsException('This set contains unresolvable conditions')
def _check_conditions_resolved(self, key, sampled_keys):
""" Checks if all required variables have already been sampled so we can sample this key """
conditions_resolved = True
for k in self[key].required_variables:
if k not in sampled_keys:
conditions_resolved = False
return conditions_resolved
@property
def conditional_keys(self):
return self._conditioned_keys
@property
def unconditional_keys(self):
return self._unconditioned_keys
def sample_subset(self, keys=iter([]), size=None):
self.convert_floats_to_delta_functions()
subset_unconditional_keys = [key for key in keys if key in self.unconditional_keys]
subset_conditional_keys = [key for key in keys if key in self.conditional_keys]
samples = {key: self[key].sample(size=size) for key in subset_unconditional_keys}
for key in subset_conditional_keys:
required_variables = self._get_required_variables(key)
samples[key] = self[key].sample(size=size, **required_variables)
return samples
def _get_required_variables(self, key):
""" Returns the required variables to sample a given key """
return {k: self[k].least_recently_sampled for k in self[key].required_variables}
def prob(self, sample, **kwargs):
"""
......@@ -571,7 +566,7 @@ class ConditionalPriorDict(PriorDict):
"""
ls = []
for key in sample:
if key in self.conditioned_keys:
if key in self.conditional_keys:
conditional_variables = dict([(k, sample[k]) for k in self[key].required_variables])
ls.append(self[key].prob(sample[key], **conditional_variables))
else:
......@@ -593,7 +588,7 @@ class ConditionalPriorDict(PriorDict):
"""
ls = []
for key in sample:
if key in self.conditioned_keys:
if key in self.conditional_keys:
conditional_variables = dict([(k, sample[k]) for k in self[key].required_variables])
ls.append(self[key].ln_prob(sample[key], **conditional_variables))
else:
......@@ -3484,3 +3479,15 @@ class ConditionalPriorException(PriorException):
class IllegalRequiredVariablesException(ConditionalPriorException):
""" Exception class for exceptions relating to handling the required variables. """
class PriorDictException(Exception):
""" General base class for all prior dict exceptions """
class ConditionalPriorDictException(PriorDictException):
""" General base class for all conditional prior dict exceptions """
class IllegalConditionsException(ConditionalPriorDictException):
""" Exception class to handle prior dicts that contain unresolvable conditions. """
......@@ -33,7 +33,7 @@ plt.clf()
plt.hist(res['mass_2'], bins='fd', alpha=0.6, density=True, label='Sampled')
plt.xlabel('$q$')
plt.xlabel('$m_2$')
plt.ylabel('$p(m_2 | m_1)$')
# plt.loglog()
plt.legend()
......@@ -86,8 +86,8 @@ likelihood = bilby.gw.GravitationalWaveTransient(
# Run sampler. In this case we're going to use the `dynesty` sampler
result = bilby.run_sampler(
likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
injection_parameters=injection_parameters, outdir=outdir, label=label)
likelihood=likelihood, priors=priors, sampler='dynesty', npoints=10,
injection_parameters=injection_parameters, outdir=outdir, label=label, clean=True, resume=False)
# Make a corner plot.
result.plot_corner()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment