Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (20)
@@ -4,31 +4,55 @@ from setuptools import setup
 import subprocess
 from os import path
 
-version = '0.2.1'
-
-# Write a version file containing the git hash and info
-try:
-    git_log = subprocess.check_output(
-        ['git', 'log', '-1', '--pretty=%h %ai']).decode('utf-8')
-    git_diff = (subprocess.check_output(['git', 'diff', '.'])
-                + subprocess.check_output(
-                    ['git', 'diff', '--cached', '.'])).decode('utf-8')
-    if git_diff == '':
-        status = '(CLEAN) ' + git_log
-    else:
-        status = '(UNCLEAN) ' + git_log
-except subprocess.CalledProcessError:
-    status = ''
-
-version_file = '.version'
-if path.isfile(version_file) is False:
-    with open('tupak/' + version_file, 'w+') as f:
-        f.write('{} - {}'.format(version, status))
-
-here = path.abspath(path.dirname(__file__))
-with open(path.join(here, 'README.rst')) as f:
-    long_description = f.read()
+
+def write_version_file(version):
+    """ Writes a file with version information to be used at run time
+
+    Parameters
+    ----------
+    version: str
+        A string containing the current version information
+
+    Returns
+    -------
+    version_file: str
+        A path to the version file
+    """
+    try:
+        git_log = subprocess.check_output(
+            ['git', 'log', '-1', '--pretty=%h %ai']).decode('utf-8')
+        git_diff = (subprocess.check_output(['git', 'diff', '.'])
+                    + subprocess.check_output(
+                        ['git', 'diff', '--cached', '.'])).decode('utf-8')
+        if git_diff == '':
+            git_status = '(CLEAN) ' + git_log
+        else:
+            git_status = '(UNCLEAN) ' + git_log
+    except Exception as e:
+        print("Unable to obtain git version information, exception: {}"
+              .format(e))
+        git_status = ''
+
+    version_file = '.version'
+    if path.isfile(version_file) is False:
+        with open('tupak/' + version_file, 'w+') as f:
+            f.write('{}: {}'.format(version, git_status))
+
+    return version_file
+
+
+def get_long_description():
+    """ Finds the README and reads in the description """
+    here = path.abspath(path.dirname(__file__))
+    with open(path.join(here, 'README.rst')) as f:
+        long_description = f.read()
+    return long_description
+
+
+version = '0.2.1'
+version_file = write_version_file(version)
+long_description = get_long_description()
 
 setup(name='tupak',
       description='The User friendly Parameter estimAtion Kode',
@@ -40,7 +64,8 @@ setup(name='tupak',
       version=version,
       packages=['tupak', 'tupak.core', 'tupak.gw', 'tupak.hyper', 'cli_tupak'],
       package_dir={'tupak': 'tupak'},
-      package_data={'tupak.gw': ['prior_files/*', 'noise_curves/*.txt', 'detectors/*'],
+      package_data={'tupak.gw': ['prior_files/*', 'noise_curves/*.txt',
+                                 'detectors/*'],
                     'tupak': [version_file]},
       install_requires=[
          'future',
@@ -50,9 +75,7 @@ setup(name='tupak',
           'matplotlib>=2.0',
           'deepdish',
           'pandas',
-          'scipy',
-          ],
+          'scipy'],
       entry_points={'console_scripts':
                     ['tupak_plot=cli_tupak.plot_multiple_posteriors:main']
       })
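
Note: the refactor above makes the version stamp reusable at run time rather than a setup.py side effect. A minimal sketch of its effect, assuming `write_version_file` is in scope; the hash and date in the comment are invented for illustration:

    version_file = write_version_file('0.2.1')      # returns '.version'
    with open('tupak/' + version_file) as f:
        # prints something like "0.2.1: (CLEAN) a1b2c3d 2018-07-10 10:32:12 +0100"
        print(f.read())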
@@ -313,8 +313,8 @@ class Result(dict):
 
         return fig
 
-    def plot_walkers(self, save=True, **kwargs):
-        """ Method to plot the trace of the walkers in an ensmble MCMC plot """
+    def plot_walkers(self, **kwargs):
+        """ Method to plot the trace of the walkers in an ensemble MCMC plot """
         if hasattr(self, 'walkers') is False:
             logger.warning("Cannot plot_walkers as no walkers are saved")
             return
@@ -341,14 +341,6 @@ class Result(dict):
             logger.debug('Saving walkers plot to {}'.format('filename'))
             fig.savefig(filename)
 
-    def plot_walks(self, save=True, **kwargs):
-        """DEPRECATED"""
-        logger.warning("plot_walks deprecated")
-
-    def plot_distributions(self, save=True, **kwargs):
-        """DEPRECATED"""
-        logger.warning("plot_distributions deprecated")
-
     def samples_to_posterior(self, likelihood=None, priors=None,
                              conversion_function=None):
         """
@@ -370,6 +362,8 @@ class Result(dict):
         if conversion_function is not None:
             data_frame = conversion_function(data_frame, likelihood, priors)
         self.posterior = data_frame
+        # We save the samples in the posterior and remove the array of samples
+        del self.samples
 
     def construct_cbc_derived_parameters(self):
         """ Construct widely used derived parameters of CBCs """
......
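Note: after this change the raw sample array is gone once `samples_to_posterior` has run, so downstream code should read from the pandas DataFrame instead. A minimal sketch, where `result` is any Result that has been through `samples_to_posterior` and the column name is illustrative:

    samples = result.posterior['mass_1'].values   # DataFrame is the only copy now
    # result.samples is no longer available after samples_to_posterior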
@@ -200,7 +200,13 @@ class Sampler(object):
         return result
 
     def _check_if_priors_can_be_sampled(self):
-        """Check if all priors can be sampled properly. Raises AttributeError if prior can't be sampled."""
+        """Check if all priors can be sampled properly.
+
+        Raises
+        ------
+        AttributeError
+            prior can't be sampled.
+        """
         for key in self.priors:
             try:
                 self.likelihood.parameters[key] = self.priors[key].sample()
@@ -208,13 +214,26 @@
                 logger.warning('Cannot sample from {}, {}'.format(key, e))
 
     def _verify_parameters(self):
-        """ Sets initial values for likelihood.parameters. Raises TypeError if likelihood can't be evaluated."""
+        """ Sets initial values for likelihood.parameters.
+
+        Raises
+        ------
+        TypeError
+            Likelihood can't be evaluated.
+        """
         self._check_if_priors_can_be_sampled()
         try:
             t1 = datetime.datetime.now()
             self.likelihood.log_likelihood()
-            self._sample_log_likelihood_eval = (datetime.datetime.now() - t1).total_seconds()
-            logger.info("Single likelihood evaluation took {:.3e} s".format(self._sample_log_likelihood_eval))
+            self._log_likelihood_eval_time = (
+                datetime.datetime.now() - t1).total_seconds()
+            if self._log_likelihood_eval_time == 0:
+                self._log_likelihood_eval_time = np.nan
+                logger.info("Unable to measure single likelihood time")
+            else:
+                logger.info("Single likelihood evaluation took {:.3e} s"
+                            .format(self._log_likelihood_eval_time))
        except TypeError as e:
            raise TypeError(
                "Likelihood evaluation failed with message: \n'{}'\n"
@@ -450,21 +469,34 @@ class Dynesty(Sampler):
 
     @kwargs.setter
     def kwargs(self, kwargs):
-        self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk', resume=True,
-                             walks=self.ndim * 5, verbose=True, check_point_delta_t=60 * 10)
+        # Set some default values
+        self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk',
+                             resume=True, walks=self.ndim * 5, verbose=True,
+                             check_point_delta_t=60 * 10, nlive=250)
+
+        # Overwrite default values with user specified values
         self.__kwargs.update(kwargs)
+
+        # Check if nlive was instead given by another name
         if 'nlive' not in self.__kwargs:
             for equiv in ['nlives', 'n_live_points', 'npoint', 'npoints']:
                 if equiv in self.__kwargs:
                     self.__kwargs['nlive'] = self.__kwargs.pop(equiv)
-        if 'nlive' not in self.__kwargs:
-            self.__kwargs['nlive'] = 250
+
+        # Set the update interval
         if 'update_interval' not in self.__kwargs:
             self.__kwargs['update_interval'] = int(0.6 * self.__kwargs['nlive'])
-        if 'n_check_point' not in kwargs:
-            # checkpointing done by default ~ every 10 minutes
+
+        # Set the checking pointing
+        # If the log_likelihood_eval_time was not able to be calculated
+        # then n_check_point is set to None (no checkpointing)
+        if np.isnan(self._log_likelihood_eval_time):
+            self.__kwargs['n_check_point'] = None
+
+        # If n_check_point is not already set, set it checkpoint every 10 mins
+        if 'n_check_point' not in self.__kwargs:
             n_check_point_raw = (self.__kwargs['check_point_delta_t']
-                                 / self._sample_log_likelihood_eval)
+                                 / self._log_likelihood_eval_time)
             n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
             self.__kwargs['n_check_point'] = n_check_point_rnd
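
Note: the checkpoint cadence is expressed as a call count derived from a wall-clock target. Worked through with illustrative numbers: a 600 s `check_point_delta_t` and a 2.7 ms likelihood give roughly 222222 calls, and the one-significant-figure rounding keeps this to a round 200000:

    check_point_delta_t = 600                        # ten minutes, the default above
    eval_time = 0.0027                               # illustrative 2.7 ms per call
    raw = check_point_delta_t / eval_time            # ~222222 calls between checkpoints
    rounded = int(float("{:1.0g}".format(raw)))      # one significant figure -> 200000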
@@ -504,46 +536,19 @@
 
     def _run_external_sampler(self):
         dynesty = self.external_sampler
 
         if self.kwargs.get('dynamic', False) is False:
-            nested_sampler = dynesty.NestedSampler(
-                loglikelihood=self.log_likelihood,
-                prior_transform=self.prior_transform,
-                ndim=self.ndim, **self.kwargs)
-
-            if self.kwargs['resume']:
-                resume = self.read_saved_state(nested_sampler, continuing=True)
-                if resume:
-                    logger.info('Resuming from previous run.')
-
-            old_ncall = nested_sampler.ncall
-            maxcall = self.kwargs['n_check_point']
-            while True:
-                maxcall += self.kwargs['n_check_point']
-                nested_sampler.run_nested(
-                    dlogz=self.kwargs['dlogz'],
-                    print_progress=self.kwargs['verbose'],
-                    print_func=self._print_func, maxcall=maxcall,
-                    add_live=False)
-                if nested_sampler.ncall == old_ncall:
-                    break
-                old_ncall = nested_sampler.ncall
-                self.write_current_state(nested_sampler)
-
-            self.read_saved_state(nested_sampler)
-            nested_sampler.run_nested(
-                dlogz=self.kwargs['dlogz'],
-                print_progress=self.kwargs['verbose'],
-                print_func=self._print_func, add_live=True)
+            nested_sampler = dynesty.NestedSampler(
+                loglikelihood=self.log_likelihood,
+                prior_transform=self.prior_transform,
+                ndim=self.ndim, **self.kwargs)
+
+            if self.kwargs['n_check_point']:
+                out = self._run_external_sampler_with_checkpointing(nested_sampler)
+            else:
+                out = self._run_external_sampler_without_checkpointing(nested_sampler)
         else:
             nested_sampler = dynesty.DynamicNestedSampler(
                 loglikelihood=self.log_likelihood,
                 prior_transform=self.prior_transform,
                 ndim=self.ndim, **self.kwargs)
             nested_sampler.run_nested(print_progress=self.kwargs['verbose'])
-            print("")
             out = nested_sampler.results
 
+        # Flushes the output to force a line break
+        if self.kwargs["verbose"]:
+            print("")
+
         # self.result.sampler_output = out
         weights = np.exp(out['logwt'] - out['logz'][-1])
@@ -556,9 +561,47 @@
 
         if self.plot:
             self.generate_trace_plots(out)
 
-        self._remove_checkpoint()
         return self.result
 
+    def _run_external_sampler_without_checkpointing(self, nested_sampler):
+        logger.debug("Running sampler without checkpointing")
+        nested_sampler.run_nested(
+            dlogz=self.kwargs['dlogz'],
+            print_progress=self.kwargs['verbose'],
+            print_func=self._print_func)
+        return nested_sampler.results
+
+    def _run_external_sampler_with_checkpointing(self, nested_sampler):
+        logger.debug("Running sampler with checkpointing")
+
+        if self.kwargs['resume']:
+            resume = self.read_saved_state(nested_sampler, continuing=True)
+            if resume:
+                logger.info('Resuming from previous run.')
+
+        old_ncall = nested_sampler.ncall
+        maxcall = self.kwargs['n_check_point']
+        while True:
+            maxcall += self.kwargs['n_check_point']
+            nested_sampler.run_nested(
+                dlogz=self.kwargs['dlogz'],
+                print_progress=self.kwargs['verbose'],
+                print_func=self._print_func, maxcall=maxcall,
+                add_live=False)
+            if nested_sampler.ncall == old_ncall:
+                break
+            old_ncall = nested_sampler.ncall
+            self.write_current_state(nested_sampler)
+
+        self.read_saved_state(nested_sampler)
+        nested_sampler.run_nested(
+            dlogz=self.kwargs['dlogz'],
+            print_progress=self.kwargs['verbose'],
+            print_func=self._print_func, add_live=True)
+        self._remove_checkpoint()
+        return nested_sampler.results
+
     def _remove_checkpoint(self):
         """Remove checkpointed state"""
         if os.path.isfile('{}/{}_resume.h5'.format(self.outdir, self.label)):
@@ -773,7 +816,9 @@ class Emcee(Sampler):
 
     def _run_external_sampler(self):
         self.nwalkers = self.kwargs.get('nwalkers', 100)
         self.nsteps = self.kwargs.get('nsteps', 100)
-        self.nburn = self.kwargs.get('nburn', 50)
+        self.nburn = self.kwargs.get('nburn', None)
+        self.burn_in_fraction = self.kwargs.get('burn_in_fraction', 0.25)
+        self.burn_in_act = self.kwargs.get('burn_in_act', 3)
         a = self.kwargs.get('a', 2)
         emcee = self.external_sampler
         tqdm = utils.get_progress_bar(self.kwargs.pop('tqdm', 'tqdm'))
@@ -806,18 +851,14 @@
             pass
 
         self.result.sampler_output = np.nan
+        self.calculate_autocorrelation(sampler)
+        self.setup_nburn()
+        self.result.nburn = self.nburn
         self.result.samples = sampler.chain[:, self.nburn:, :].reshape(
             (-1, self.ndim))
         self.result.walkers = sampler.chain[:, :, :]
-        self.result.nburn = self.nburn
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
-        try:
-            logger.info("Max autocorr time = {}".format(
-                np.max(sampler.get_autocorr_time())))
-        except emcee.autocorr.AutocorrError as e:
-            logger.info("Unable to calculate autocorr time: {}".format(e))
         return self.result
 
     def lnpostfn(self, theta):
@@ -827,6 +868,41 @@
         else:
             return self.log_likelihood(theta) + p
 
+    def setup_nburn(self):
+        """ Handles calculating nburn, either from a given value or inferred """
+        if type(self.nburn) in [float, int]:
+            self.nburn = int(self.nburn)
+            logger.info("Discarding {} steps for burn-in".format(self.nburn))
+        elif self.result.max_autocorrelation_time is None:
+            self.nburn = int(self.burn_in_fraction * self.nsteps)
+            logger.info("Autocorrelation time not calculated, discarding {} "
+                        " steps for burn-in".format(self.nburn))
+        else:
+            self.nburn = int(
+                self.burn_in_act * self.result.max_autocorrelation_time)
+            logger.info("Discarding {} steps for burn-in, estimated from "
+                        "autocorr".format(self.nburn))
+
+    def calculate_autocorrelation(self, sampler, c=3):
+        """ Uses the `emcee.autocorr` module to estimate the autocorrelation
+
+        Parameters
+        ----------
+        c: float
+            The minimum number of autocorrelation times needed to trust the
+            estimate (default: `3`). See `emcee.autocorr.integrated_time`.
+        """
+        import emcee
+        try:
+            self.result.max_autocorrelation_time = int(np.max(
+                sampler.get_autocorr_time(c=c)))
+            logger.info("Max autocorr time = {}".format(
+                self.result.max_autocorrelation_time))
+        except emcee.autocorr.AutocorrError as e:
+            self.result.max_autocorrelation_time = None
+            logger.info("Unable to calculate autocorr time: {}".format(e))
+
 
 class Ptemcee(Emcee):
     """ https://github.com/willvousden/ptemcee """
@@ -869,7 +945,7 @@ class Ptemcee(Emcee):
 
 def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
                 sampler='dynesty', use_ratio=None, injection_parameters=None,
                 conversion_function=None, plot=False, default_priors_file=None,
-                clean=None, meta_data=None, **kwargs):
+                clean=None, meta_data=None, save=True, **kwargs):
     """
     The primary interface to easy parameter estimation
@@ -908,6 +984,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         saving. For example, if `meta_data={dtype: 'signal'}`. Warning: in case
         of conflict with keys saved by tupak, the meta_data keys will be
         overwritten.
+    save: bool
+        If true, save the priors and results to disk.
     **kwargs:
         All kwargs are passed directly to the samplers `run` function
@@ -920,7 +998,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
 
     if clean:
         utils.command_line_args.clean = clean
-    utils.check_directory_exists_and_if_not_mkdir(outdir)
 
     implemented_samplers = get_implemented_samplers()
 
     if priors is None:
@@ -934,7 +1011,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         raise ValueError
 
     priors.fill_priors(likelihood, default_priors_file=default_priors_file)
-    priors.write_to_file(outdir, label)
+
+    if save:
+        utils.check_directory_exists_and_if_not_mkdir(outdir)
+        priors.write_to_file(outdir, label)
 
     if implemented_samplers.__contains__(sampler.title()):
         sampler_class = globals()[sampler.title()]
@@ -973,14 +1053,14 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         if conversion_function is not None:
             result.injection_parameters = conversion_function(result.injection_parameters)
         result.fixed_parameter_keys = sampler.fixed_parameter_keys
-        # result.prior = prior  # Removed as this breaks the saving of the data
         result.samples_to_posterior(likelihood=likelihood, priors=priors,
                                     conversion_function=conversion_function)
         result.kwargs = sampler.kwargs
-        result.save_to_file()
+        if save:
+            result.save_to_file()
+            logger.info("Results saved to {}/".format(outdir))
         if plot:
             result.plot_corner()
-        logger.info("Sampling finished, results saved to {}/".format(outdir))
+        logger.info("Summary of results:\n{}".format(result))
         return result
     else:
......
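Note: with the new flag a dry run writes nothing to disk (no prior file, no result file, no outdir created), which is handy for tests. A minimal sketch, assuming `likelihood` and `priors` have been built as in the tupak examples:

    import tupak

    # likelihood/priors construction elided; see the tupak examples
    result = tupak.run_sampler(
        likelihood=likelihood, priors=priors, sampler='dynesty',
        save=False)          # nothing is written to disk
    print(result)            # the summary is still available in memory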
@@ -18,12 +18,12 @@ except ImportError:
                    " not be able to use some of the prebuilt functions.")
 
 
-class InterferometerSet(list):
+class InterferometerList(list):
     """ A list of Interferometer objects """
 
     def __init__(self, interferometers):
-        """ Instantiate a InterferometerSet
+        """ Instantiate a InterferometerList
 
-        The InterferometerSet is a list of Interferometer objects, each
+        The InterferometerList is a list of Interferometer objects, each
         object has the data used in evaluating the likelihood
 
         Parameters
@@ -178,18 +178,18 @@ class InterferometerSet(list):
         return self[0].strain_data.frequency_array
 
     def append(self, interferometer):
-        if isinstance(interferometer, InterferometerSet):
-            super(InterferometerSet, self).extend(interferometer)
+        if isinstance(interferometer, InterferometerList):
+            super(InterferometerList, self).extend(interferometer)
         else:
-            super(InterferometerSet, self).append(interferometer)
+            super(InterferometerList, self).append(interferometer)
         self._check_interferometers()
 
     def extend(self, interferometers):
-        super(InterferometerSet, self).extend(interferometers)
+        super(InterferometerList, self).extend(interferometers)
         self._check_interferometers()
 
     def insert(self, index, interferometer):
-        super(InterferometerSet, self).insert(index, interferometer)
+        super(InterferometerList, self).insert(index, interferometer)
         self._check_interferometers()
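
Note: append special-cases another InterferometerList and extends instead, so lists never nest. A minimal sketch, where H1, L1 and V1 stand for already-constructed Interferometer objects:

    ifos = InterferometerList([H1, L1])
    ifos.append(InterferometerList([V1]))   # extends rather than nesting a list
    print(len(ifos))                        # 3 interferometers, not 2 entries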
@@ -305,8 +305,9 @@ class InterferometerStrainData(object):
 
     @property
     def maximum_frequency(self):
         """ Force the maximum frequency be less than the Nyquist frequency """
-        if 2 * self.__maximum_frequency > self.sampling_frequency:
-            self.__maximum_frequency = self.sampling_frequency / 2.
+        if self.sampling_frequency is not None:
+            if 2 * self.__maximum_frequency > self.sampling_frequency:
+                self.__maximum_frequency = self.sampling_frequency / 2.
         return self.__maximum_frequency
 
     @maximum_frequency.setter
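
Note: the extra None guard matters because strain data can exist before a sampling frequency is set; previously the comparison against None raised a TypeError. A standalone illustration of the clamp (not the tupak class itself):

    def clamped_max_freq(requested, sampling_frequency):
        # Mirrors the property: clamp to Nyquist only once the rate is known
        if sampling_frequency is not None:
            if 2 * requested > sampling_frequency:
                return sampling_frequency / 2.
        return requested

    print(clamped_max_freq(3000, 4096))  # 2048.0, the Nyquist frequency
    print(clamped_max_freq(3000, None))  # 3000, passes through unchanged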
@@ -1437,12 +1438,12 @@ class Interferometer(object):
                 outdir, self.name, label))
 
 
-class TriangularInterferometer(InterferometerSet):
+class TriangularInterferometer(InterferometerList):
 
     def __init__(self, name, power_spectral_density, minimum_frequency, maximum_frequency,
                  length, latitude, longitude, elevation, xarm_azimuth, yarm_azimuth,
                  xarm_tilt=0., yarm_tilt=0.):
-        InterferometerSet.__init__(self, [])
+        InterferometerList.__init__(self, [])
         self.name = name
         # for attr in ['power_spectral_density', 'minimum_frequency', 'maximum_frequency']:
         if isinstance(power_spectral_density, PowerSpectralDensity):
@@ -1773,9 +1774,9 @@ def load_interferometer(filename):
 
 def get_interferometer_with_open_data(
-        name, trigger_time, duration=4, start_time=None, roll_off=0.4, psd_offset=-1024,
-        psd_duration=100, cache=True, outdir='outdir', label=None, plot=True, filter_freq=None,
-        raw_data_file=None, **kwargs):
+        name, trigger_time, duration=4, start_time=None, roll_off=0.4,
+        psd_offset=-1024, psd_duration=100, cache=True, outdir='outdir',
+        label=None, plot=True, filter_freq=None, **kwargs):
     """
     Helper function to obtain an Interferometer instance with appropriate
     PSD and data, given an center_time.
@@ -1799,8 +1800,6 @@
         `center_time+psd_offset` to `center_time+psd_offset + psd_duration`.
     cache: bool, optional
         Whether or not to store the acquired data
-    raw_data_file: str
-        Name of a raw data file if this supposed to be read from a local file
     outdir: str
         Directory where the psd files are saved
     label: str
@@ -1942,7 +1941,7 @@
 
 def get_event_data(
         event, interferometer_names=None, duration=4, roll_off=0.4,
         psd_offset=-1024, psd_duration=100, cache=True, outdir='outdir',
-        label=None, plot=True, filter_freq=None, raw_data_file=None, **kwargs):
+        label=None, plot=True, filter_freq=None, **kwargs):
     """
     Get open data for a specified event.
@@ -1963,8 +1962,6 @@
         `center_time+psd_offset` to `center_time+psd_offset + psd_duration`.
     cache: bool
         Whether or not to store the acquired data.
-    raw_data_file:
-        If we want to read the event data from a local file.
     outdir: str
         Directory where the psd files are saved
     label: str
@@ -1997,9 +1994,9 @@
                 name, trigger_time=event_time, duration=duration, roll_off=roll_off,
                 psd_offset=psd_offset, psd_duration=psd_duration, cache=cache,
                 outdir=outdir, label=label, plot=plot, filter_freq=filter_freq,
-                raw_data_file=raw_data_file, **kwargs))
+                **kwargs))
         except ValueError as e:
             logger.debug("Error raised {}".format(e))
             logger.warning('No data found for {}.'.format(name))
 
-    return InterferometerSet(interferometers)
+    return InterferometerList(interferometers)
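
Note: callers that passed raw_data_file must drop it; data now always comes through the open-data fetch path. A call shaped like the tupak open-data examples (the event name and detector list are illustrative):

    interferometers = get_event_data(
        'GW150914', interferometer_names=['H1', 'L1'],
        duration=4, outdir='outdir', plot=False)   # no raw_data_file kwarg any more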
@@ -56,7 +56,7 @@ class GravitationalWaveTransient(likelihood.Likelihood):
         self.waveform_generator = waveform_generator
         likelihood.Likelihood.__init__(self, waveform_generator.parameters)
-        self.interferometers = tupak.gw.detector.InterferometerSet(interferometers)
+        self.interferometers = tupak.gw.detector.InterferometerList(interferometers)
         self.time_marginalization = time_marginalization
         self.distance_marginalization = distance_marginalization
         self.phase_marginalization = phase_marginalization
......