Commit 4df29f1d authored by Colm Talbot's avatar Colm Talbot

Merge branch 'master' into 'change_sampled_parameters'

# Conflicts:
#   tupak/sampler.py
parents 3f734cc2 469c5ee5
Pipeline #19458 passed with stages
in 5 minutes and 54 seconds
...@@ -36,8 +36,7 @@ class Sampler(object): ...@@ -36,8 +36,7 @@ class Sampler(object):
""" """
def __init__(self, likelihood, priors, external_sampler='nestle', def __init__(self, likelihood, priors, external_sampler='nestle', outdir='outdir', label='label', use_ratio=False,
outdir='outdir', label='label', use_ratio=False, result=None,
**kwargs): **kwargs):
self.likelihood = likelihood self.likelihood = likelihood
self.priors = priors self.priors = priors
...@@ -53,7 +52,6 @@ class Sampler(object): ...@@ -53,7 +52,6 @@ class Sampler(object):
self.verify_parameters() self.verify_parameters()
self.kwargs = kwargs self.kwargs = kwargs
self.result = result
self.check_cached_result() self.check_cached_result()
self.log_summary_for_sampler() self.log_summary_for_sampler()
...@@ -61,25 +59,7 @@ class Sampler(object): ...@@ -61,25 +59,7 @@ class Sampler(object):
if os.path.isdir(outdir) is False: if os.path.isdir(outdir) is False:
os.makedirs(outdir) os.makedirs(outdir)
@property self.result = self.initialise_result()
def result(self):
return self.__result
@result.setter
def result(self, result):
if result is None:
self.__result = Result()
self.__result.search_parameter_keys = self.__search_parameter_keys
self.__result.fixed_parameter_keys = self.__fixed_parameter_keys
self.__result.parameter_labels = [
self.priors[k].latex_label for k in
self.__search_parameter_keys]
self.__result.label = self.label
self.__result.outdir = self.outdir
elif type(result) is Result:
self.__result = result
else:
raise TypeError('result must either be a Result or None')
@property @property
def search_parameter_keys(self): def search_parameter_keys(self):
...@@ -149,6 +129,19 @@ class Sampler(object): ...@@ -149,6 +129,19 @@ class Sampler(object):
for key in self.__fixed_parameter_keys: for key in self.__fixed_parameter_keys:
logging.info(' {} = {}'.format(key, self.priors[key].peak)) logging.info(' {} = {}'.format(key, self.priors[key].peak))
def initialise_result(self):
result = Result()
result.search_parameter_keys = self.__search_parameter_keys
result.fixed_parameter_keys = self.__fixed_parameter_keys
result.parameter_labels = [
self.priors[k].latex_label for k in
self.__search_parameter_keys]
result.label = self.label
result.outdir = self.outdir
result.priors = self.priors
result.kwargs = self.kwargs
return result
def verify_parameters(self): def verify_parameters(self):
for key in self.priors: for key in self.priors:
try: try:
......
...@@ -59,7 +59,7 @@ class WaveformGenerator(object): ...@@ -59,7 +59,7 @@ class WaveformGenerator(object):
model_frequency_strain = dict() model_frequency_strain = dict()
time_domain_strain = self.time_domain_source_model(self.time_array, **self.parameters) time_domain_strain = self.time_domain_source_model(self.time_array, **self.parameters)
if isinstance(time_domain_strain, np.ndarray): if isinstance(time_domain_strain, np.ndarray):
return time_domain_strain return utils.nfft(time_domain_strain, self.sampling_frequency)
for key in time_domain_strain: for key in time_domain_strain:
model_frequency_strain[key], self.frequency_array = utils.nfft(time_domain_strain[key], model_frequency_strain[key], self.frequency_array = utils.nfft(time_domain_strain[key],
self.sampling_frequency) self.sampling_frequency)
...@@ -79,7 +79,7 @@ class WaveformGenerator(object): ...@@ -79,7 +79,7 @@ class WaveformGenerator(object):
model_time_series = dict() model_time_series = dict()
frequency_domain_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters) frequency_domain_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
if isinstance(frequency_domain_strain, np.ndarray): if isinstance(frequency_domain_strain, np.ndarray):
return frequency_domain_strain return utils.infft(frequency_domain_strain, self.sampling_frequency)
for key in frequency_domain_strain: for key in frequency_domain_strain:
model_time_series[key] = utils.infft(frequency_domain_strain[key], self.sampling_frequency) model_time_series[key] = utils.infft(frequency_domain_strain[key], self.sampling_frequency)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment