Commit 5de02649 authored by alex codoreanu's avatar alex codoreanu
Browse files

Added remote logging to bank generation class.

RELATED TO SPIIP-120.
parent d522a1e5
......@@ -23,21 +23,14 @@ import sys
import numpy
import scipy
import cmath
from scipy import integrate
from scipy import interpolate
import math
import csv
import logging
import tempfile
import lal
import lalsimulation
import logging
import tempfile
from gstlal import cbc_template_fir
from glue.ligolw import ligolw, lsctables, array, param, utils, types
from gstlal.pipeio import repack_complex_array_to_real, repack_real_array_to_complex
from gstlal import cbc_template_fir
from gstlal import chirptime
import random
import pdb
from gstlal.spiirbank.optimizer import optimize_a1
Attributes = ligolw.sax.xmlreader.AttributesImpl
......@@ -64,8 +57,7 @@ class XMLContentHandler(ligolw.LIGOLWContentHandler):
ValidApproximantsFD = set(("SpinTaylorT4", "SEOBNRv4_ROM"))
# copied from gstlal-inspiral/ templates.py
gstlal_IMR_approximants = set(
('EOBNRv2', 'IMRPhenomC', 'SEOBNRv4_ROM', 'SEOBNRv2_ROM_DoubleSpin'))
gstlal_IMR_approximants = set(('EOBNRv2', 'IMRPhenomC', 'SEOBNRv4_ROM', 'SEOBNRv2_ROM_DoubleSpin'))
def condition_imr_template(approximant, data, epoch_time, sample_rate_max,
......@@ -894,7 +886,6 @@ class Bank(object):
nround_max=10,
alpha=.99,
beta=0.25,
pnorder=4,
flower=15,
snr_cut=0.998,
all_psd=None,
......@@ -906,10 +897,25 @@ class Bank(object):
keep_track=True,
output_file=None,
remote_log=False,
remote_db_engine=None,
remote_log_table_name=None,
contenthandler=DefaultContentHandler):
"""
Build SPIIR template bank from physical parameters, e.g. mass, spin.
"""
if remote_log:
if remote_db_engine is None or remote_log_table_name is None:
print "you told me to keep a remote_log but" \
"you did not provide me a remote database engine connection " \
"or a table to write to."
exit(0)
else:
try:
from pandas import DataFrame
remote_log_df = DataFrame()
except Exception as local_exception:
print "Failed to import from pandas import DataFrame\n"
print "Exception is:\n{}".format(local_exception)
if keep_track:
if output_file is None:
......@@ -1017,10 +1023,6 @@ class Bank(object):
Bmat = {}
Dmat = {}
# TODO:
# add new definition of templates for default value
# to be list(range(len(sngl_inspiral_table)))
# templates need depend on the size of the input bank
if templates is None:
templates = list(range(len(self.sngl_inspiral_table)))
......@@ -1036,7 +1038,10 @@ class Bank(object):
if keep_track:
with open(track_file, 'w') as w:
w.writelines('{}'.format(tmp))
w.writelines('{}'.format(row))
if remote_log:
remote_log_df.loc[tmp, 'template_id'] = tmp
spiir_match = -1
epsilon = epsilon_start
......@@ -1094,6 +1099,11 @@ class Bank(object):
"spiir_match_min %s, n_filters_min %s, n_filters_max %s" %
(spiir_match_min, n_filters_min, n_filters_max))
if remote_log:
remote_log_df.loc[tmp, 'spiir_match_min'] = spiir_match_min
remote_log_df.loc[tmp, 'n_filters_min'] = n_filters_min
remote_log_df.loc[tmp, 'n_filters_max'] = n_filters_max
# h_pad is just the padded cut template
h_pad = pad_data(data, pad_length)
......@@ -1141,8 +1151,16 @@ class Bank(object):
epsilon_b,
spiir_match,
n_filters))
nround += 1
if remote_log:
remote_log_df.loc[tmp, 'nround'] = nround
remote_log_df.loc[tmp, 'epsilon_a'] = epsilon_a
remote_log_df.loc[tmp, 'epsilon'] = epsilon
remote_log_df.loc[tmp, 'epsilon_b'] = epsilon_b
remote_log_df.loc[tmp, 'spiir_match'] = spiir_match
remote_log_df.loc[tmp, 'n_filters'] = n_filters
nround += 1
epsilon_dir = 0
if n_filters_max is not None and n_filters > n_filters_max:
# we need to increase epsilon to decrease filters
......@@ -1278,6 +1296,17 @@ class Bank(object):
row.mass2, epsilon, n_filters, spiir_match,
epsilon_start, original_filters, original_match))
if remote_log:
remote_log_df.loc[tmp, 'len_sngl_inspiral_table'] = len(self.sngl_inspiral_table)
remote_log_df.loc[tmp, 'mass1'] = row.mass1
remote_log_df.loc[tmp, 'mass2'] = row.mass2
remote_log_df.loc[tmp, 'epsilon'] = epsilon
remote_log_df.loc[tmp, 'n_filters'] = n_filters
remote_log_df.loc[tmp, 'spiir_match'] = spiir_match
remote_log_df.loc[tmp, 'epsilon_start'] = epsilon_start
remote_log_df.loc[tmp, 'original_filters'] = original_filters
remote_log_df.loc[tmp, 'original_match'] = original_match
# get the filter frequencies
fs = -1. * numpy.angle(a1) / 2 / numpy.pi # Normalised freqeuncy
a1dict = {}
......@@ -1302,7 +1331,7 @@ class Bank(object):
b0dict.setdefault(sampleRate / M, []).append(
b0[i] * M**0.5 * a1[i]**(newdelay * M - delay[i]))
delaydict.setdefault(sampleRate / M, []).append(newdelay)
#logging.info("sampleRate %4.0d, filter %3.0d, M %2.0d, f %10.9f, delay %d, newdelay %d" % (sampleRate, i, M, f, delay[i], newdelay))
else:
a1dict[int(sampleRate)] = a1
b0dict[int(sampleRate)] = b0
......@@ -1326,6 +1355,13 @@ class Bank(object):
"rate %d, dmin %d, dmax %d, max_row %d, max_len %d" %
(rate, DmatMin, DmatMax, max_rows, max_len))
if remote_log:
remote_log_df.loc[tmp, 'rate'] = rate
remote_log_df.loc[tmp, 'DmatMin'] = DmatMin
remote_log_df.loc[tmp, 'DmatMax'] = DmatMax
remote_log_df.loc[tmp, 'max_rows'] = max_rows
remote_log_df.loc[tmp, 'max_len'] = max_len
self.A[rate] = numpy.zeros((max_rows, max_len),
dtype=numpy.complex128)
self.B[rate] = numpy.zeros((max_rows, max_len),
......@@ -1340,6 +1376,16 @@ class Bank(object):
for i, Dm in enumerate(Dmat[rate]):
self.D[rate][i, :len(Dm)] = Dm
if remote_log:
remote_log_df['filename'] = filename
remote_log_df.to_sql(remote_log_table_name,
con=remote_db_engine,
chunksize=remote_log_df.filename.size,
if_exists='append',
index=False)
del remote_log_df
def downsample_bank(self, flower=15, padding=1.3, verbose=True):
Amat = {}
Bmat = {}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment