Commit 79fe4e13 authored by ChiWai Chan's avatar ChiWai Chan

gstlal_inspiral_calc_snr & svd_bank_snr.py: reducing memory usages when

outputting SNRs
parent 43258090
......@@ -134,11 +134,11 @@ def parse_command_line():
#parser.add_option_group(group)
group = OptionGroup(parser, "Output Control Options", "Control SNR output")
group.add_option("--outdir", metavar = "directory", type = "str", help = "Output directory for SNR(s) (requires).")
group.add_option("--outdir", metavar = "directory", type = "str", help = "Output directory for SNR(s) (require).")
group.add_option("--mode", metavar = "method", type = "int", default = 0, help = "The method (0 = LLOID / 1 = FIR) that is used to calculate SNR (default = 0).")
group.add_option("--complex", action = "store_true", help = "Choose whether to output the complex snr or not.")
group.add_option("--start", metavar = "seconds", type = "int", help = "Start SNR time series at GPS time '--start'.")
group.add_option("--end", metavar = "seconds", type = "int", help = "End SNR time series at GPS time '--end'.")
group.add_option("--start", metavar = "seconds", type = "int", help = "Start SNR time series at GPS time '--start' (require).")
group.add_option("--end", metavar = "seconds", type = "int", help = "End SNR time series at GPS time '--end' (require).")
group.add_option("--output-width", metavar = "bits", type = "int", default = 32, help = "The size of the output data, can only be 32 or 64 bits (default = 32 bits).")
group.add_option("--instrument", metavar = "name", help = "The detector from which the --reference-psd and --frame-cache are loaded (require).")
parser.add_option_group(group)
......@@ -148,7 +148,9 @@ def parse_command_line():
options, args = parser.parse_args()
# Check SNR series output
if options.start and options.end:
if options.start is None or options.end is None:
raise ValueError("Must have --start and --end.")
else:
if options.start >= options.end:
raise ValueError("--start must less than --end.")
......@@ -237,6 +239,9 @@ if options.mode == 0:
gw_data_source_info,
bank,
options.instrument,
options.row_number,
options.start,
options.end,
psd = psd,
psd_fft_length = options.psd_fft_length,
track_psd = options.track_psd,
......@@ -245,19 +250,19 @@ if options.mode == 0:
)
if options.row_number is None:
for index, snr in enumerate(lloid_snr(COMPLEX = options.complex, row_number = options.row_number, start = options.start, end = options.end)):
for index, snr in enumerate(lloid_snr(COMPLEX = options.complex)):
snr.epoch += bank.sngl_inspiral_table[index].end
snrdict = {options.instrument: [snr]}
svd_bank_snr.write_url(svd_bank_snr.make_xmldoc(snrdict), os.path.join(options.outdir, "%s-SNR_%d-%d-%d.xml.gz" % (options.instrument, index, int(snr.epoch), int(snr.data.length * snr.deltaT))), verbose = options.verbose)
else:
lloidsnr = lloid_snr(COMPLEX = options.complex, row_number = options.row_number, start = options.start, end = options.end)
lloidsnr = lloid_snr(COMPLEX = options.complex)
lloidsnr[0].epoch += bank.sngl_inspiral_table[options.row_number].end
snrdict = {options.instrument: lloidsnr}
svd_bank_snr.write_url(svd_bank_snr.make_xmldoc(snrdict), os.path.join(options.outdir, "%s-SNR_%d-%d-%d.xml.gz" % (options.instrument, options.row_number, int(lloidsnr[0].epoch), int(lloidsnr[0].data.length * lloidsnr[0].deltaT))), verbose = options.verbose)
#
# uncomment to save all snrs in one single XML
#
#snrdict = {options.instrument : lloid_snr(COMPLEX = options.complex, row_number = options.row_number, drop_first = options.drop_first, drop_last = options.drop_last)}
#snrdict = {options.instrument : lloid_snr(COMPLEX = options.complex)}
#svd_bank_snr.write_url(svd_bank_snr.make_xmldoc(snrdict), os.path.join(options.outdir, "%s-SNR-%d-%d.xml.gz" % (options.instrument, int(snrdict.[options.instrument][0].epoch), int(snrdict[options.instrument][0].data.length * snrdict[options.instrument][0].deltaT))), verbose = options.verbose)
elif options.mode == 1:
......@@ -271,6 +276,8 @@ elif options.mode == 1:
options.instrument,
options.sample_rate,
0,
options.start,
options.end,
psd = psd,
psd_fft_length = options.psd_fft_length,
width = options.output_width,
......@@ -278,7 +285,7 @@ elif options.mode == 1:
verbose = options.verbose
)
firsnr = fir_snr(COMPLEX = options.complex, start = options.start, end = options.end)
firsnr = fir_snr(COMPLEX = options.complex)
firsnr[0].epoch += time_offset
snrdict = {options.instrument : firsnr}
svd_bank_snr.write_url(svd_bank_snr.make_xmldoc(snrdict),os.path.join(options.outdir, "%s-SNR-%d-%d.xml.gz" % (options.instrument, int(snrdict[options.instrument][0].epoch), int(snrdict[options.instrument][0].data.length * snrdict[options.instrument][0].deltaT))), verbose = options.verbose)
......@@ -40,12 +40,15 @@ class SNRContentHandler(ligolw.LIGOLWContentHandler):
pass
class SNR_Pipeline(object):
def __init__(self, name = "gstlal_inspiral_SNR", verbose = False):
def __init__(self, row_number, start, end, name = "gstlal_inspiral_SNR", verbose = False):
self.pipeline = Gst.Pipeline(name = name)
self.mainloop = GObject.MainLoop()
self.handler = simplehandler.Handler(self.mainloop, self.pipeline)
self.verbose = verbose
self.lock = threading.Lock()
self.row_number = row_number
self.start = start
self.end = end
self.snr_info = {
"epoch": None,
"instrument": None,
......@@ -76,68 +79,43 @@ class SNR_Pipeline(object):
raise RuntimeError("pipeline could not be set to NULL.")
def make_series(self, data):
para = {"name" : self.snr_info["instrument"],
"epoch" : self.snr_info["epoch"],
"deltaT" : self.snr_info["deltaT"],
"f0": 0,
"sampleUnits" : lal.DimensionlessUnit,
"length" : len(data)}
if data.dtype == numpy.float32:
tseries = lal.CreateREAL4TimeSeries(
name = self.snr_info["instrument"],
epoch = self.snr_info["epoch"],
deltaT = self.snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
tseries = lal.CreateREAL4TimeSeries(**para)
elif data.dtype == numpy.float64:
tseries = lal.CreateREAL8TimeSeries(
name = self.snr_info["instrument"],
epoch = self.snr_info["epoch"],
deltaT = self.snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
tseries = lal.CreateREAL8TimeSeries(**para)
elif data.dtype == numpy.complex64:
tseries = lal.CreateCOMPLEX8TimeSeries(
name = self.snr_info["instrument"],
epoch = self.snr_info["epoch"],
deltaT = self.snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
tseries = lal.CreateCOMPLEX8TimeSeries(**para)
elif data.dtype == numpy.complex128:
tseries = lal.CreateCOMPLEX16TimeSeries(
name = self.snr_info["instrument"],
epoch = self.snr_info["epoch"],
deltaT = self.snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
tseries = lal.CreateCOMPLEX16TimeSeries(**para)
else:
raise ValueError("unsupported type : %s " % data.dtype)
tseries.data.data = data
return tseries
def get_snr_series(self, COMPLEX = False, row_number = None, start = None, end = None):
def get_snr_series(self, COMPLEX = False):
gps_start = self.snr_info["epoch"].gpsSeconds + self.snr_info["epoch"].gpsNanoSeconds * 10.**-9
gps = gps_start + numpy.arange(len(self.snr_info["data"])) * self.snr_info["deltaT"]
if start and end:
if start >= end:
if self.start and self.end:
if self.start >= self.end:
raise ValueError("Start time must be less than end time.")
if start - gps[0] >= 0 and start - gps[-1] <= 0:
s = abs(gps - start).argmin()
if self.start - gps[0] >= 0 and self.start - gps[-1] <= 0:
s = abs(gps - self.start).argmin()
else:
raise ValueError("Invalid choice of start time %f." % start)
raise ValueError("Invalid choice of start time %f." % self.start)
if end - gps[0] >= 0 and end - gps[-1] <= 0:
e = abs(gps - end).argmin()
if self.end - gps[0] >= 0 and self.end - gps[-1] <= 0:
e = abs(gps - self.end).argmin()
else:
raise ValueError("Invalid choice of end time %f." % end)
raise ValueError("Invalid choice of end time %f." % self.end)
self.snr_info["epoch"] = gps[s]
self.snr_info["data"] = self.snr_info["data"][s:e].T
......@@ -145,7 +123,7 @@ class SNR_Pipeline(object):
self.snr_info["epoch"] = gps[0]
self.snr_info["data"] = self.snr_info["data"].T
if row_number is None:
if self.row_number is None:
temp = []
if COMPLEX:
for data in self.snr_info["data"]:
......@@ -156,7 +134,7 @@ class SNR_Pipeline(object):
temp.append(self.make_series(numpy.abs(data)))
return temp
else:
self.snr_info["data"] = self.snr_info["data"][row_number]
self.snr_info["data"] = self.snr_info["data"][self.row_number]
if COMPLEX:
return [self.make_series(self.snr_info["data"])]
else:
......@@ -184,22 +162,27 @@ class SNR_Pipeline(object):
else:
assert self.snr_info["deltaT"] == 1. / rate, "data have different sampling rate."
# record the first timestamp
if self.snr_info["epoch"] is None:
self.snr_info["epoch"] = LIGOTimeGPS(0, sample.get_buffer().pts)
buf = sample.get_buffer()
if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
return Gst.FlowReturn.OK
# FIXME: check timestamps
data = pipeio.array_from_audio_sample(sample)
if data is not None:
self.snr_info["data"].append(data)
buf = sample.get_buffer()
if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
return Gst.FlowReturn.OK
# drop snrs that are irrelevant
cur_time_stamp = LIGOTimeGPS(0, sample.get_buffer().pts)
if self.start >= cur_time_stamp and self.end > cur_time_stamp:
# record the first timestamp closet to start time
self.snr_info["epoch"] = cur_time_stamp
# FIXME: check timestamps
self.snr_info["data"] = [pipeio.array_from_audio_sample(sample)]
elif self.start <= cur_time_stamp < self.end:
self.snr_info["data"].append(pipeio.array_from_audio_sample(sample))
else:
Gst.FlowReturn.OK
return Gst.FlowReturn.OK
class LLOID_SNR(SNR_Pipeline):
def __init__(self, gw_data_source_info, bank, instrument, psd = None, psd_fft_length = 32, ht_gate_threshold = float("inf"), veto_segments = None, track_psd = False, width = 32, verbose = False):
SNR_Pipeline.__init__(self, name = "gstlal_inspiral_lloid_snr", verbose = verbose)
def __init__(self, gw_data_source_info, bank, instrument, row_number, start, end, psd = None, psd_fft_length = 32, ht_gate_threshold = float("inf"), veto_segments = None, track_psd = False, width = 32, verbose = False):
SNR_Pipeline.__init__(self, row_number, start, end, name = "gstlal_inspiral_lloid_snr", verbose = verbose)
self.snr_info["instrument"] = instrument
# sanity check
......@@ -250,12 +233,12 @@ class LLOID_SNR(SNR_Pipeline):
self.run(gw_data_source_info.seg)
self.snr_info["data"] = numpy.concatenate(numpy.array(self.snr_info["data"]), axis = 0)
def __call__(self, COMPLEX = False, row_number = 0, start = None, end = None):
return self.get_snr_series(COMPLEX, row_number, start, end)
def __call__(self, COMPLEX = False):
return self.get_snr_series(COMPLEX)
class FIR_SNR(SNR_Pipeline):
def __init__(self, gw_data_source_info, template, instrument, rate, latency, psd = None, psd_fft_length = 32, ht_gate_threshold = float("inf"), veto_segments = None, width = 32, track_psd = False, verbose = False):
SNR_Pipeline.__init__(self, name = "gstlal_inspiral_fir_snr", verbose = verbose)
def __init__(self, gw_data_source_info, template, instrument, rate, latency, start, end, psd = None, psd_fft_length = 32, ht_gate_threshold = float("inf"), veto_segments = None, width = 32, track_psd = False, verbose = False):
SNR_Pipeline.__init__(self, 0, start, end, name = "gstlal_inspiral_fir_snr", verbose = verbose)
self.snr_info["instrument"] = instrument
# sanity check
......@@ -317,8 +300,8 @@ class FIR_SNR(SNR_Pipeline):
return template, row[0].end
def __call__(self, COMPLEX = False, row_number = 0 , start = None, end = None):
return self.get_snr_series(COMPLEX, row_number, start, end)
def __call__(self, COMPLEX = False):
return self.get_snr_series(COMPLEX)
#=============================================================================================
#
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment