Commit 4df29f1d authored by Colm Talbot's avatar Colm Talbot

Merge branch 'master' into 'change_sampled_parameters'

# Conflicts:
#   tupak/sampler.py
parents 3f734cc2 469c5ee5
Pipeline #19458 passed with stages
in 5 minutes and 54 seconds
......@@ -36,8 +36,7 @@ class Sampler(object):
"""
def __init__(self, likelihood, priors, external_sampler='nestle',
outdir='outdir', label='label', use_ratio=False, result=None,
def __init__(self, likelihood, priors, external_sampler='nestle', outdir='outdir', label='label', use_ratio=False,
**kwargs):
self.likelihood = likelihood
self.priors = priors
......@@ -53,7 +52,6 @@ class Sampler(object):
self.verify_parameters()
self.kwargs = kwargs
self.result = result
self.check_cached_result()
self.log_summary_for_sampler()
......@@ -61,25 +59,7 @@ class Sampler(object):
if os.path.isdir(outdir) is False:
os.makedirs(outdir)
@property
def result(self):
return self.__result
@result.setter
def result(self, result):
if result is None:
self.__result = Result()
self.__result.search_parameter_keys = self.__search_parameter_keys
self.__result.fixed_parameter_keys = self.__fixed_parameter_keys
self.__result.parameter_labels = [
self.priors[k].latex_label for k in
self.__search_parameter_keys]
self.__result.label = self.label
self.__result.outdir = self.outdir
elif type(result) is Result:
self.__result = result
else:
raise TypeError('result must either be a Result or None')
self.result = self.initialise_result()
@property
def search_parameter_keys(self):
......@@ -149,6 +129,19 @@ class Sampler(object):
for key in self.__fixed_parameter_keys:
logging.info(' {} = {}'.format(key, self.priors[key].peak))
def initialise_result(self):
result = Result()
result.search_parameter_keys = self.__search_parameter_keys
result.fixed_parameter_keys = self.__fixed_parameter_keys
result.parameter_labels = [
self.priors[k].latex_label for k in
self.__search_parameter_keys]
result.label = self.label
result.outdir = self.outdir
result.priors = self.priors
result.kwargs = self.kwargs
return result
def verify_parameters(self):
for key in self.priors:
try:
......
......@@ -59,7 +59,7 @@ class WaveformGenerator(object):
model_frequency_strain = dict()
time_domain_strain = self.time_domain_source_model(self.time_array, **self.parameters)
if isinstance(time_domain_strain, np.ndarray):
return time_domain_strain
return utils.nfft(time_domain_strain, self.sampling_frequency)
for key in time_domain_strain:
model_frequency_strain[key], self.frequency_array = utils.nfft(time_domain_strain[key],
self.sampling_frequency)
......@@ -79,7 +79,7 @@ class WaveformGenerator(object):
model_time_series = dict()
frequency_domain_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
if isinstance(frequency_domain_strain, np.ndarray):
return frequency_domain_strain
return utils.infft(frequency_domain_strain, self.sampling_frequency)
for key in frequency_domain_strain:
model_time_series[key] = utils.infft(frequency_domain_strain[key], self.sampling_frequency)
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment