Skip to content
Snippets Groups Projects
Commit 214489cd authored by ChiWai Chan's avatar ChiWai Chan
Browse files

svd_bank.py: moved the clipping processes to the Bank class and modified...

svd_bank.py: moved the clipping processes to the Bank class and modified gstlal_svd_bank accordingly.
parent a1761eee
No related branches found
No related tags found
No related merge requests found
...@@ -148,19 +148,19 @@ svd_bank.write_bank( ...@@ -148,19 +148,19 @@ svd_bank.write_bank(
options.ortho_gate_fap, options.ortho_gate_fap,
inspiral_lr.LnLRDensity.snr_min, inspiral_lr.LnLRDensity.snr_min,
options.svd_tolerance, options.svd_tolerance,
clipleft,
clipright,
padding = options.padding, padding = options.padding,
identity_transform = options.identity_transform, identity_transform = options.identity_transform,
verbose = options.verbose, verbose = options.verbose,
autocorrelation_length = options.autocorrelation_length, autocorrelation_length = options.autocorrelation_length,
samples_min = options.samples_min, samples_min = options.samples_min,
samples_max_256 = options.samples_max_256, samples_max_256 = options.samples_max_256,
samples_max_64 = options.samples_max_64, samples_max_64 = options.samples_max_64,
samples_max = options.samples_max, samples_max = options.samples_max,
bank_id = bank_id, bank_id = bank_id,
contenthandler = svd_bank.DefaultContentHandler, contenthandler = svd_bank.DefaultContentHandler,
sample_rate = options.sample_rate sample_rate = options.sample_rate
) for (template_bank, bank_id) in zip(options.template_bank, options.bank_id)], ) for (template_bank, bank_id, clipleft, clipright) in zip(options.template_bank, options.bank_id, options.clipleft, options.clipright)],
psd, psd
options.clipleft,
options.clipright
) )
...@@ -174,7 +174,7 @@ class BankFragment(object): ...@@ -174,7 +174,7 @@ class BankFragment(object):
class Bank(object): class Bank(object):
def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, verbose = False, bank_id = None, fhigh = None): def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, clipleft = None, clipright = None, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, verbose = False, bank_id = None, fhigh = None):
# FIXME: remove template_bank_filename when no longer needed # FIXME: remove template_bank_filename when no longer needed
# by trigger generator element # by trigger generator element
self.template_bank_filename = None self.template_bank_filename = None
...@@ -231,6 +231,27 @@ class Bank(object): ...@@ -231,6 +231,27 @@ class Bank(object):
if verbose: if verbose:
print("sum-of-squares threshold for false-alarm probability of %.16g: %.16g" % (gate_fap, self.gate_threshold), file=sys.stderr) print("sum-of-squares threshold for false-alarm probability of %.16g: %.16g" % (gate_fap, self.gate_threshold), file=sys.stderr)
# Sanity checks before cliping
clipright = len(self.sngl_inspiral_table) - clipright if clipright is not None else None
doubled_clipright = clipright * 2 if clipright is not None else None
doubled_clipleft = clipleft * 2 if clipleft is not None else None
# Apply clipping options
new_sngl_table = self.sngl_inspiral_table.copy()
for row in self.sngl_inspiral_table[clipleft:clipright]:
# FIXME need a proper id column
row.Gamma1 = int(self.bank_id.split("_")[0])
new_sngl_table.append(row)
self.sngl_inspiral_table = new_sngl_table
self.autocorrelation_bank = self.autocorrelation_bank[clipleft:clipright,:]
self.autocorrelation_mask = self.autocorrelation_mask[clipleft:clipright,:]
self.sigmasq = self.sigmasq[clipleft:clipright]
self.bank_correlation_matrix = self.bank_correlation_matrix[clipleft:clipright,clipleft:clipright]
for i, frag in enumerate(self.bank_fragments):
if frag.mix_matrix is not None:
frag.mix_matrix = frag.mix_matrix[:,doubled_clipleft:doubled_clipright]
frag.chifacs = frag.chifacs[doubled_clipleft:doubled_clipright]
def get_rates(self): def get_rates(self):
return set(bank_fragment.rate for bank_fragment in self.bank_fragments) return set(bank_fragment.rate for bank_fragment in self.bank_fragments)
...@@ -241,7 +262,7 @@ class Bank(object): ...@@ -241,7 +262,7 @@ class Bank(object):
def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_tolerance, padding = 1.5, identity_transform = False, verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = None, sample_rate = None, instrument_override = None): def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_tolerance, clipleft = None, clipright = None, padding = 1.5, identity_transform = False, verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = None, sample_rate = None, instrument_override = None):
"""! """!
Return an instance of a Bank class. Return an instance of a Bank class.
...@@ -251,6 +272,8 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_ ...@@ -251,6 +272,8 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
@param ortho_gate_fap The FAP threshold for the sum of squares threshold, see http://arxiv.org/abs/1101.0584 @param ortho_gate_fap The FAP threshold for the sum of squares threshold, see http://arxiv.org/abs/1101.0584
@param snr_threshold The SNR threshold for the search @param snr_threshold The SNR threshold for the search
@param svd_tolerance The target SNR loss of the SVD, see http://arxiv.org/abs/1005.0012 @param svd_tolerance The target SNR loss of the SVD, see http://arxiv.org/abs/1005.0012
@param clipleft The number of N poorly reconstructed templates from the left edge of each sub-bank to be removed
@param cliptright The number of N poorly reconstructed templates from the right edge of each sub-bank to be removed
@param padding The padding from Nyquist for any template time slice, e.g., if a time slice has a Nyquist of 256 Hz and the padding is set to 2, only allow the template frequency to extend to 128 Hz. @param padding The padding from Nyquist for any template time slice, e.g., if a time slice has a Nyquist of 256 Hz and the padding is set to 2, only allow the template frequency to extend to 128 Hz.
@param identity_transform Don't do the SVD, just do time slices and keep the raw waveforms @param identity_transform Don't do the SVD, just do time slices and keep the raw waveforms
@param verbose Be verbose @param verbose Be verbose
...@@ -300,6 +323,8 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_ ...@@ -300,6 +323,8 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
gate_fap = ortho_gate_fap, gate_fap = ortho_gate_fap,
snr_threshold = snr_threshold, snr_threshold = snr_threshold,
tolerance = svd_tolerance, tolerance = svd_tolerance,
clipleft = clipleft,
clipright = clipright,
flow = flow, flow = flow,
autocorrelation_length = autocorrelation_length, # samples autocorrelation_length = autocorrelation_length, # samples
identity_transform = identity_transform, identity_transform = identity_transform,
...@@ -314,31 +339,19 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_ ...@@ -314,31 +339,19 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
return bank return bank
def write_bank(filename, banks, psd_input, cliplefts = None, cliprights = None, verbose = False): def write_bank(filename, banks, psd_input, verbose = False):
"""Write SVD banks to a LIGO_LW xml file.""" """Write SVD banks to a LIGO_LW xml file."""
# Create new document # Create new document
xmldoc = ligolw.Document() xmldoc = ligolw.Document()
lw = xmldoc.appendChild(ligolw.LIGO_LW()) lw = xmldoc.appendChild(ligolw.LIGO_LW())
for bank, clipleft, clipright in zip(banks, cliplefts, cliprights): for bank in banks:
# set up root for this sub bank # set up root for this sub bank
root = lw.appendChild(ligolw.LIGO_LW(Attributes({u"Name": u"gstlal_svd_bank_Bank"}))) root = lw.appendChild(ligolw.LIGO_LW(Attributes({u"Name": u"gstlal_svd_bank_Bank"})))
# FIXME FIXME FIXME move this clipping stuff to the Bank class
# set the right clipping index
clipright = len(bank.sngl_inspiral_table) - clipright
# Apply clipping option to sngl inspiral table
# put the bank table into the output document
new_sngl_table = bank.sngl_inspiral_table.copy()
for row in bank.sngl_inspiral_table[clipleft:clipright]:
# FIXME need a proper id column
row.Gamma1 = int(bank.bank_id.split("_")[0])
new_sngl_table.append(row)
# put the possibly clipped table into the file # put the possibly clipped table into the file
root.appendChild(new_sngl_table) root.appendChild(bank.sngl_inspiral_table)
# Add root-level scalar params # Add root-level scalar params
root.appendChild(ligolw_param.Param.from_pyvalue('filter_length', bank.filter_length)) root.appendChild(ligolw_param.Param.from_pyvalue('filter_length', bank.filter_length))
...@@ -353,12 +366,6 @@ def write_bank(filename, banks, psd_input, cliplefts = None, cliprights = None, ...@@ -353,12 +366,6 @@ def write_bank(filename, banks, psd_input, cliplefts = None, cliprights = None,
root.appendChild(ligolw_param.Param.from_pyvalue('sample_rate_max', int(bank.sample_rate_max))) root.appendChild(ligolw_param.Param.from_pyvalue('sample_rate_max', int(bank.sample_rate_max)))
root.appendChild(ligolw_param.Param.from_pyvalue('gstlal_fir_whiten', os.environ['GSTLAL_FIR_WHITEN'])) root.appendChild(ligolw_param.Param.from_pyvalue('gstlal_fir_whiten', os.environ['GSTLAL_FIR_WHITEN']))
# apply clipping to autocorrelations and sigmasq
bank.autocorrelation_bank = bank.autocorrelation_bank[clipleft:clipright,:]
bank.autocorrelation_mask = bank.autocorrelation_mask[clipleft:clipright,:]
bank.sigmasq = bank.sigmasq[clipleft:clipright]
bank.bank_correlation_matrix = bank.bank_correlation_matrix[clipleft:clipright,clipleft:clipright]
# Add root-level arrays # Add root-level arrays
# FIXME: ligolw format now supports complex-valued data # FIXME: ligolw format now supports complex-valued data
root.appendChild(ligolw_array.Array.build('autocorrelation_bank_real', bank.autocorrelation_bank.real)) root.appendChild(ligolw_array.Array.build('autocorrelation_bank_real', bank.autocorrelation_bank.real))
...@@ -373,11 +380,6 @@ def write_bank(filename, banks, psd_input, cliplefts = None, cliprights = None, ...@@ -373,11 +380,6 @@ def write_bank(filename, banks, psd_input, cliplefts = None, cliprights = None,
# Start new bank fragment container # Start new bank fragment container
el = root.appendChild(ligolw.LIGO_LW()) el = root.appendChild(ligolw.LIGO_LW())
# Apply clipping option
if frag.mix_matrix is not None:
frag.mix_matrix = frag.mix_matrix[:,clipleft*2:clipright*2]
frag.chifacs = frag.chifacs[clipleft*2:clipright*2]
# Add scalar params # Add scalar params
el.appendChild(ligolw_param.Param.from_pyvalue('rate', int(frag.rate))) el.appendChild(ligolw_param.Param.from_pyvalue('rate', int(frag.rate)))
el.appendChild(ligolw_param.Param.from_pyvalue('start', frag.start)) el.appendChild(ligolw_param.Param.from_pyvalue('start', frag.start))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment