Commit 5bb71c9d authored by Gregory Ashton's avatar Gregory Ashton

Rename sampling_parameter_keys -> non_standard_sampling_parameter_keys

parent 2ecb3655
Pipeline #19559 passed with stages
in 5 minutes and 57 seconds
...@@ -26,7 +26,7 @@ waveform_generator = tupak.waveform_generator.WaveformGenerator( ...@@ -26,7 +26,7 @@ waveform_generator = tupak.waveform_generator.WaveformGenerator(
sampling_frequency=sampling_frequency, time_duration=time_duration, sampling_frequency=sampling_frequency, time_duration=time_duration,
frequency_domain_source_model=tupak.source.lal_binary_black_hole, frequency_domain_source_model=tupak.source.lal_binary_black_hole,
parameter_conversion=tupak.conversion.convert_to_lal_binary_black_hole_parameters, parameter_conversion=tupak.conversion.convert_to_lal_binary_black_hole_parameters,
sampling_parameter_keys=['chirp_mass', 'mass_ratio', 'cos_iota'], non_standard_sampling_parameter_keys=['chirp_mass', 'mass_ratio', 'cos_iota'],
parameters=injection_parameters) parameters=injection_parameters)
hf_signal = waveform_generator.frequency_domain_strain() hf_signal = waveform_generator.frequency_domain_strain()
......
...@@ -215,7 +215,7 @@ def compute_snrs(sample, likelihood): ...@@ -215,7 +215,7 @@ def compute_snrs(sample, likelihood):
for ii in range(len(temp_sample)): for ii in range(len(temp_sample)):
for key in set(temp_sample.keys()).intersection(likelihood.waveform_generator.parameters.keys()): for key in set(temp_sample.keys()).intersection(likelihood.waveform_generator.parameters.keys()):
likelihood.waveform_generator.parameters[key] = temp_sample[key][ii] likelihood.waveform_generator.parameters[key] = temp_sample[key][ii]
for key in likelihood.waveform_generator.sampling_parameter_keys: for key in likelihood.waveform_generator.non_standard_sampling_parameter_keys:
likelihood.waveform_generator.parameters[key] = temp_sample[key][ii] likelihood.waveform_generator.parameters[key] = temp_sample[key][ii]
signal_polarizations = likelihood.waveform_generator.frequency_domain_strain() signal_polarizations = likelihood.waveform_generator.frequency_domain_strain()
for interferometer in all_interferometers: for interferometer in all_interferometers:
......
...@@ -18,7 +18,7 @@ class Likelihood(object): ...@@ -18,7 +18,7 @@ class Likelihood(object):
self.interferometers = interferometers self.interferometers = interferometers
self.waveform_generator = waveform_generator self.waveform_generator = waveform_generator
self.parameters = self.waveform_generator.parameters self.parameters = self.waveform_generator.parameters
self.sampling_parameter_keys = self.waveform_generator.sampling_parameter_keys self.non_standard_sampling_parameter_keys = self.waveform_generator.non_standard_sampling_parameter_keys
self.distance_marginalization = distance_marginalization self.distance_marginalization = distance_marginalization
self.phase_marginalization = phase_marginalization self.phase_marginalization = phase_marginalization
self.prior = prior self.prior = prior
......
...@@ -444,7 +444,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', ...@@ -444,7 +444,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if priors is None: if priors is None:
priors = dict() priors = dict()
priors = fill_priors(priors, likelihood, parameters=likelihood.sampling_parameter_keys) priors = fill_priors(priors, likelihood, parameters=likelihood.non_standard_sampling_parameter_keys)
tupak.prior.write_priors_to_file(priors, outdir) tupak.prior.write_priors_to_file(priors, outdir)
if implemented_samplers.__contains__(sampler.title()): if implemented_samplers.__contains__(sampler.title()):
......
...@@ -25,7 +25,7 @@ class WaveformGenerator(object): ...@@ -25,7 +25,7 @@ class WaveformGenerator(object):
Initial values for the parameters Initial values for the parameters
parameter_conversion: func parameter_conversion: func
Function to convert from sampled parameters to parameters of the waveform generator Function to convert from sampled parameters to parameters of the waveform generator
sampling_parameter_keys: list non_standard_sampling_parameter_keys: list
List of parameter name for *non-standard* sampling parameters. List of parameter name for *non-standard* sampling parameters.
Note: the arguments of frequency_domain_source_model (except the first, which is the Note: the arguments of frequency_domain_source_model (except the first, which is the
...@@ -35,7 +35,7 @@ class WaveformGenerator(object): ...@@ -35,7 +35,7 @@ class WaveformGenerator(object):
def __init__(self, time_duration, sampling_frequency, frequency_domain_source_model=None, def __init__(self, time_duration, sampling_frequency, frequency_domain_source_model=None,
time_domain_source_model=None, parameters=None, parameter_conversion=None, time_domain_source_model=None, parameters=None, parameter_conversion=None,
sampling_parameter_keys=None): non_standard_sampling_parameter_keys=None):
self.time_duration = time_duration self.time_duration = time_duration
self.sampling_frequency = sampling_frequency self.sampling_frequency = sampling_frequency
self.frequency_domain_source_model = frequency_domain_source_model self.frequency_domain_source_model = frequency_domain_source_model
...@@ -43,7 +43,7 @@ class WaveformGenerator(object): ...@@ -43,7 +43,7 @@ class WaveformGenerator(object):
self.time_duration = time_duration self.time_duration = time_duration
self.sampling_frequency = sampling_frequency self.sampling_frequency = sampling_frequency
self.parameter_conversion = parameter_conversion self.parameter_conversion = parameter_conversion
self.sampling_parameter_keys = sampling_parameter_keys self.non_standard_sampling_parameter_keys = non_standard_sampling_parameter_keys
self.parameters = parameters self.parameters = parameters
self.__frequency_array_updated = False self.__frequency_array_updated = False
self.__time_array_updated = False self.__time_array_updated = False
...@@ -51,7 +51,7 @@ class WaveformGenerator(object): ...@@ -51,7 +51,7 @@ class WaveformGenerator(object):
def frequency_domain_strain(self): def frequency_domain_strain(self):
""" Wrapper to source_model """ """ Wrapper to source_model """
if self.parameter_conversion is not None: if self.parameter_conversion is not None:
added_keys = self.parameter_conversion(self.parameters, self.sampling_parameter_keys) added_keys = self.parameter_conversion(self.parameters, self.non_standard_sampling_parameter_keys)
if self.frequency_domain_source_model is not None: if self.frequency_domain_source_model is not None:
model_frequency_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters) model_frequency_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
...@@ -72,7 +72,7 @@ class WaveformGenerator(object): ...@@ -72,7 +72,7 @@ class WaveformGenerator(object):
def time_domain_strain(self): def time_domain_strain(self):
if self.parameter_conversion is not None: if self.parameter_conversion is not None:
added_keys = self.parameter_conversion(self.parameters, self.sampling_parameter_keys) added_keys = self.parameter_conversion(self.parameters, self.non_standard_sampling_parameter_keys)
if self.time_domain_source_model is not None: if self.time_domain_source_model is not None:
model_time_series = self.time_domain_source_model(self.time_array, **self.parameters) model_time_series = self.time_domain_source_model(self.time_array, **self.parameters)
elif self.frequency_domain_source_model is not None: elif self.frequency_domain_source_model is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment