diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 9861bcdac4e3a35fd09b741eca501a56330252d9..7c0ed966287263afa72730d1f458a7fdd6e6847e 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -494,66 +494,61 @@ class ConditionalPriorDict(PriorDict): filename: str See parent class """ - self.conditioned_keys = [] - self.unconditioned_keys = [] + self._conditioned_keys = [] + self._unconditioned_keys = [] + self._sorted_keys = [] super(ConditionalPriorDict, self).__init__(dictionary=dictionary, filename=filename) self._resolve_conditions() def _resolve_conditions(self): - self.convert_floats_to_delta_functions() - self.conditioned_keys = [] - self.unconditioned_keys = [] - for key in self.keys(): - if hasattr(self[key], 'condition_func'): - self.conditioned_keys.append(key) - else: - self.unconditioned_keys.append(key) - - def sample_subset(self, keys=iter([]), size=None): - self.convert_floats_to_delta_functions() - conditional_keys = [] - unconditional_keys = [] - sampled_keys = [] - for key in keys: - if key in self.unconditioned_keys: - unconditional_keys.append(key) - elif key in self.conditioned_keys: - conditional_keys.append(key) - else: - raise KeyError('Invalid key in keys argument') - - samples = dict() - for key in unconditional_keys: - samples[key] = self[key].sample(size=size) - sampled_keys.append(key) - - for i in range(0, 1000): - for key in conditional_keys: - if self._check_conditions_resolved(key, sampled_keys): - cvars = self._get_conditional_variables(key) - samples[key] = self[key].sample(size=size, **cvars) - conditional_keys.remove(key) - sampled_keys.append(key) - if not conditional_keys: - break - if i == 999: - raise Exception('This set contains unresolvable conditions') - - return samples - - def _get_conditional_variables(self, key): - conditional_variables = dict() - for k in self[key].required_variables: - conditional_variables[k] = self[k].least_recently_sampled - return conditional_variables + """ Resolves how variables depend on each other and automatically sorts them into the right order """ + conditioned_keys_unsorted = [key for key in self.keys() if hasattr(self[key], 'condition_func')] + self._unconditioned_keys = [key for key in self.keys() if not hasattr(self[key], 'condition_func')] + self._conditioned_keys = [] + checked_keys = self._unconditioned_keys.copy() + for i in range(len(self)): + for key in conditioned_keys_unsorted: + if self._check_conditions_resolved(key, checked_keys): + checked_keys.append(key) + self._conditioned_keys.append(key) + conditioned_keys_unsorted.remove(key) + + self._sorted_keys = self._unconditioned_keys.copy() + self._sorted_keys.extend(self.conditional_keys) + + if len(conditioned_keys_unsorted) != 0: + raise IllegalConditionsException('This set contains unresolvable conditions') def _check_conditions_resolved(self, key, sampled_keys): + """ Checks if all required variables have already been sampled so we can sample this key """ conditions_resolved = True for k in self[key].required_variables: if k not in sampled_keys: conditions_resolved = False return conditions_resolved + @property + def conditional_keys(self): + return self._conditioned_keys + + @property + def unconditional_keys(self): + return self._unconditioned_keys + + def sample_subset(self, keys=iter([]), size=None): + self.convert_floats_to_delta_functions() + subset_unconditional_keys = [key for key in keys if key in self.unconditional_keys] + subset_conditional_keys = [key for key in keys if key in self.conditional_keys] + samples = {key: self[key].sample(size=size) for key in subset_unconditional_keys} + for key in subset_conditional_keys: + required_variables = self._get_required_variables(key) + samples[key] = self[key].sample(size=size, **required_variables) + return samples + + def _get_required_variables(self, key): + """ Returns the required variables to sample a given key """ + return {k: self[k].least_recently_sampled for k in self[key].required_variables} + def prob(self, sample, **kwargs): """ @@ -571,7 +566,7 @@ class ConditionalPriorDict(PriorDict): """ ls = [] for key in sample: - if key in self.conditioned_keys: + if key in self.conditional_keys: conditional_variables = dict([(k, sample[k]) for k in self[key].required_variables]) ls.append(self[key].prob(sample[key], **conditional_variables)) else: @@ -593,7 +588,7 @@ class ConditionalPriorDict(PriorDict): """ ls = [] for key in sample: - if key in self.conditioned_keys: + if key in self.conditional_keys: conditional_variables = dict([(k, sample[k]) for k in self[key].required_variables]) ls.append(self[key].ln_prob(sample[key], **conditional_variables)) else: @@ -3484,3 +3479,15 @@ class ConditionalPriorException(PriorException): class IllegalRequiredVariablesException(ConditionalPriorException): """ Exception class for exceptions relating to handling the required variables. """ + +class PriorDictException(Exception): + """ General base class for all prior dict exceptions """ + + +class ConditionalPriorDictException(PriorDictException): + """ General base class for all conditional prior dict exceptions """ + + +class IllegalConditionsException(ConditionalPriorDictException): + """ Exception class to handle prior dicts that contain unresolvable conditions. """ + diff --git a/examples/gw_examples/injection_examples/conditional_prior.py b/examples/gw_examples/injection_examples/conditional_prior.py index d7a855111232bd522c0db314758e1d6f92cdf582..05a2e55e22736761dd644f5cca2c634c28c1be10 100644 --- a/examples/gw_examples/injection_examples/conditional_prior.py +++ b/examples/gw_examples/injection_examples/conditional_prior.py @@ -33,7 +33,7 @@ plt.clf() plt.hist(res['mass_2'], bins='fd', alpha=0.6, density=True, label='Sampled') -plt.xlabel('$q$') +plt.xlabel('$m_2$') plt.ylabel('$p(m_2 | m_1)$') # plt.loglog() plt.legend() @@ -86,8 +86,8 @@ likelihood = bilby.gw.GravitationalWaveTransient( # Run sampler. In this case we're going to use the `dynesty` sampler result = bilby.run_sampler( - likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000, - injection_parameters=injection_parameters, outdir=outdir, label=label) + likelihood=likelihood, priors=priors, sampler='dynesty', npoints=10, + injection_parameters=injection_parameters, outdir=outdir, label=label, clean=True, resume=False) # Make a corner plot. result.plot_corner()