Commit 3f734cc2 authored by Colm Talbot's avatar Colm Talbot
Browse files

remove references to wg

parent 15db7783
Pipeline #19457 passed with stages
in 5 minutes and 33 seconds
...@@ -152,14 +152,14 @@ class Sampler(object): ...@@ -152,14 +152,14 @@ class Sampler(object):
def verify_parameters(self): def verify_parameters(self):
for key in self.priors: for key in self.priors:
try: try:
self.likelihood.waveform_generator.parameters[key] = self.priors[key].sample() self.likelihood.parameters[key] = self.priors[key].sample()
except AttributeError as e: except AttributeError as e:
logging.warning('Cannot sample from {}, {}'.format(key, e)) logging.warning('Cannot sample from {}, {}'.format(key, e))
try: try:
self.likelihood.waveform_generator.frequency_domain_strain() self.likelihood.log_likelihood_ratio()
except TypeError: except TypeError:
raise TypeError('Waveform generation failed. Have you definitely specified all the parameters?\n{}'.format( raise TypeError('Likelihood evaluation failed. Have you definitely specified all the parameters?\n{}'.format(
self.likelihood.waveform_generator.parameters)) self.likelihood.parameters))
def prior_transform(self, theta): def prior_transform(self, theta):
return [self.priors[key].rescale(t) for key, t in zip(self.__search_parameter_keys, theta)] return [self.priors[key].rescale(t) for key, t in zip(self.__search_parameter_keys, theta)]
...@@ -410,7 +410,7 @@ class Ptemcee(Sampler): ...@@ -410,7 +410,7 @@ class Ptemcee(Sampler):
def run_sampler(likelihood, priors=None, label='label', outdir='outdir', def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler='nestle', use_ratio=True, injection_parameters=None, sampler='nestle', use_ratio=True, injection_parameters=None,
sampling_parameters=None, **kwargs): conversion_function=None, **kwargs):
""" """
The primary interface to easy parameter estimation The primary interface to easy parameter estimation
...@@ -431,12 +431,15 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', ...@@ -431,12 +431,15 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
samplers samplers
use_ratio: bool (False) use_ratio: bool (False)
If True, use the likelihood's loglikelihood_ratio, rather than just If True, use the likelihood's loglikelihood_ratio, rather than just
the loglikelhood. the log likelhood.
injection_parameters: dict injection_parameters: dict
A dictionary of injection parameters used in creating the data (if A dictionary of injection parameters used in creating the data (if
using simulated data). Appended to the result object and saved. using simulated data). Appended to the result object and saved.
conversion_function: function, optional
Function to apply to posterior to generate additional parameters.
**kwargs: **kwargs:
All kwargs are passed directly to the samplers `run` functino All kwargs are passed directly to the samplers `run` function
Returns Returns
------ ------
...@@ -449,7 +452,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', ...@@ -449,7 +452,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if priors is None: if priors is None:
priors = dict() priors = dict()
priors = fill_priors(priors, likelihood, sampling_parameters) priors = fill_priors(priors, likelihood, parameters=likelihood.sampling_parameter_keys)
tupak.prior.write_priors_to_file(priors, outdir) tupak.prior.write_priors_to_file(priors, outdir)
if implemented_samplers.__contains__(sampler.title()): if implemented_samplers.__contains__(sampler.title()):
...@@ -458,8 +461,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', ...@@ -458,8 +461,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
label=label, use_ratio=use_ratio, label=label, use_ratio=use_ratio,
**kwargs) **kwargs)
likelihood.waveform_generator.search_parameter_keys = [
key for key in priors if not isinstance(priors[key], tupak.prior.DeltaFunction)]
if sampler.cached_result: if sampler.cached_result:
logging.info("Using cached result") logging.info("Using cached result")
return sampler.cached_result return sampler.cached_result
...@@ -476,8 +477,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', ...@@ -476,8 +477,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
tupak.conversion.generate_all_bbh_parameters(result.injection_parameters) tupak.conversion.generate_all_bbh_parameters(result.injection_parameters)
result.fixed_parameter_keys = [key for key in priors if isinstance(key, prior.DeltaFunction)] result.fixed_parameter_keys = [key for key in priors if isinstance(key, prior.DeltaFunction)]
# result.prior = prior # Removed as this breaks the saving of the data # result.prior = prior # Removed as this breaks the saving of the data
result.samples_to_data_frame(waveform_generator=likelihood.waveform_generator, result.samples_to_data_frame(likelihood=likelihood, priors=priors, conversion_function=conversion_function)
interferometers=likelihood.interferometers, priors=priors)
result.kwargs = sampler.kwargs result.kwargs = sampler.kwargs
result.save_to_file(outdir=outdir, label=label) result.save_to_file(outdir=outdir, label=label)
return result return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment