From ff70b08036670a7cdfd9e80d85416859cdce2b3e Mon Sep 17 00:00:00 2001
From: Daichi Tsuna <daichi.tsuna@ligo.org>
Date: Sun, 24 Feb 2019 02:48:26 -0800
Subject: [PATCH] add autocorrelation chi2 calculation

Work in progress: this currently compiles, but segfaults partway through execution.
---
 gstlal-burst/bin/gstlal_cs_triggergen         |  52 +++--
 .../gst/lal/gstlal_string_triggergen.c        | 186 +++++++++++++++++-
 .../gst/lal/gstlal_string_triggergen.h        |  16 +-
 3 files changed, 229 insertions(+), 25 deletions(-)

diff --git a/gstlal-burst/bin/gstlal_cs_triggergen b/gstlal-burst/bin/gstlal_cs_triggergen
index 8d2b3a9ca8..e343940c72 100755
--- a/gstlal-burst/bin/gstlal_cs_triggergen
+++ b/gstlal-burst/bin/gstlal_cs_triggergen
@@ -85,6 +85,7 @@ class PSDHandler(simplehandler.Handler):
 	def __init__(self, mainloop, pipeline, firbank):
 		simplehandler.Handler.__init__(self,mainloop, pipeline)
 		self.firbank = firbank
+		self.triggergen = triggergen
 
 	def do_on_message(self, bus, message):
 		if message.type == Gst.MessageType.ELEMENT and  message.get_structure().get_name() == "spectrum":
@@ -94,57 +95,84 @@ class PSDHandler(simplehandler.Handler):
 			stability = float(message.src.get_property("n-samples")) / message.src.get_property("average-samples")
 
 			if stability > 0.3:
-				template_bank = [None] * len(template_bank_table)
+				print >> sys.stderr, "At GPS time", timestamp, "PSD stable"
 				template_t = [None] * len(template_bank_table)
 				autocorr = [None] * len(template_bank_table)
 				# make templates, whiten, put into firbank
 				for i, row in enumerate(template_bank_table):
 					# linearly polarized, so just use plus mode time series
-					template_bank[i], _ = lalsimulation.GenerateStringCusp(1.0,row.central_freq,1.0/options.sample_rate)
+					template_t[i], _ = lalsimulation.GenerateStringCusp(1.0,row.central_freq,1.0/options.sample_rate)
 					# zero-pad it to 32 seconds to obtain same deltaF as the PSD
-					template_bank[i] = lal.ResizeREAL8TimeSeries(template_bank[i],-int(32*options.sample_rate - template_bank[i].data.length)//2,int(32*options.sample_rate))
+					template_t[i] = lal.ResizeREAL8TimeSeries(template_t[i],-int(32*options.sample_rate - template_t[i].data.length)//2,int(32*options.sample_rate))
 					# setup of frequency domain
-					length = template_bank[i].data.length
+					length = template_t[i].data.length
 					duration = float(length) / options.sample_rate
 					epoch = -(length-1) // 2 /options.sample_rate
 					template_f = lal.CreateCOMPLEX16FrequencySeries("template_freq", LIGOTimeGPS(epoch), psd.f0, 1.0/duration, lal.Unit("strain s"), length // 2 + 1)
 					fplan = lal.CreateForwardREAL8FFTPlan(length,0)
 					# FFT to frequency domain
-					lal.REAL8TimeFreqFFT(template_f,template_bank[i],fplan)
+					lal.REAL8TimeFreqFFT(template_f,template_t[i],fplan)
 					# set DC and Nyquist to zero
 					template_f.data.data[0] = 0.0
 					template_f.data.data[template_f.data.length-1] = 0.0
 					# whiten
 					template_f = lal.WhitenCOMPLEX16FrequencySeries(template_f,psd)
-					# obtain autocorrelation time series by
-					# squaring the template and inverse FFTing it
+					# obtain the autocorrelation time series by squaring the template and inverse-FFTing it
 					template_f_squared = lal.CreateCOMPLEX16FrequencySeries("whitened template_freq squared", LIGOTimeGPS(epoch), psd.f0, 1.0/duration, lal.Unit("s"), length // 2 + 1)
-					autocorr[i] = lal.CreateREAL8TimeSeries("autocorr_time", LIGOTimeGPS(epoch), psd.f0, 1.0 / options.sample_rate, lal.Unit("strain"), length)
+					autocorr_t = lal.CreateREAL8TimeSeries("autocorr_time", LIGOTimeGPS(epoch), psd.f0, 1.0 / options.sample_rate, lal.Unit("strain"), length)
 					rplan = lal.CreateReverseREAL8FFTPlan(length,0)
 					template_f_squared.data.data = abs(template_f.data.data)**2
-					lal.REAL8FreqTimeFFT(autocorr[i],template_f_squared,rplan)
+					lal.REAL8FreqTimeFFT(autocorr_t,template_f_squared,rplan)
 					# normalize autocorrelation by central (maximum) value
-					autocorr[i].data.data /= numpy.max(autocorr[i].data.data)
+					autocorr_t.data.data /= numpy.max(autocorr_t.data.data)
+					autocorr_t = autocorr_t.data.data
+					max_index = numpy.argmax(autocorr_t)
+					# find the index of the third extremum (counting the central peak as the first) for the first template
+					# FIXME we only do this for the first template; this relies on it being the template with the
+					# lowest high-f cutoff, which will have the largest chi2_index.
+					if i == 0:
+						extr_ctr = 0
+						chi2_index = 0
+						for j in range(max_index+1, len(autocorr_t)):
+							slope1 = autocorr_t[j+1] - autocorr_t[j]
+							slope0 = autocorr_t[j] - autocorr_t[j-1]
+							chi2_index += 1
+							if(slope1 * slope0 < 0):
+								extr_ctr += 1
+								if(extr_ctr == 2):
+									break
+					assert extr_ctr == 2, 'could not find 3rd extremum'
+					# extract the part within the third extremum, setting the peak to be the center.
+					autocorr[i] = numpy.concatenate((autocorr_t[1:(chi2_index+1)][::-1], autocorr_t[:(chi2_index+1)]))
+					assert len(autocorr[i])%2==1, 'autocorr must have odd number of samples'
 					# Inverse FFT template bank back to time domain
 					template_t[i] = lal.CreateREAL8TimeSeries("whitened template time", LIGOTimeGPS(epoch), psd.f0, 1.0 / options.sample_rate, lal.Unit("strain"), length)
 					lal.REAL8FreqTimeFFT(template_t[i],template_f,rplan)
 					# normalize
 					template_t[i].data.data /= numpy.sqrt(numpy.dot(template_t[i].data.data, template_t[i].data.data))
 					template_t[i] = template_t[i].data.data
-				firbank.set_property("latency",-(len(template_t[0]) - 1) // 2)
+				firbank.set_property("latency", -(len(template_t[0]) - 1) // 2)
 				firbank.set_property("fir_matrix", template_t)
+				triggergen.set_property("autocorrelation_matrix", autocorr)
+				print >> sys.stderr, "finished setting all the properties" 
 				self.firbank = firbank
+				self.triggergen = triggergen
 			else:
 				# use templates with all zeros during burn-in period, that way we won't get any triggers.
+				print >> sys.stderr, "At GPS time", timestamp, "burn in period"
 				template = [None] * len(template_bank_table)
+				autocorr = [None] * len(template_bank_table)
 				for i, row in enumerate(template_bank_table):
 					template[i], _ = lalsimulation.GenerateStringCusp(1.0,30,1.0/options.sample_rate)
 					template[i] = lal.ResizeREAL8TimeSeries(template[i], -int(32*options.sample_rate - template[i].data.length)//2 ,int(32*options.sample_rate))
 					template[i] = template[i].data.data
 					template[i] *= 0.0
+					autocorr[i] = template[i]
 				firbank.set_property("latency",-(len(template[0]) - 1) // 2)
 				firbank.set_property("fir_matrix", template)
+				triggergen.set_property("autocorrelation_matrix", autocorr)
 				self.firbank = firbank
+				self.triggergen = triggergen
 			return True
 		return False
 
@@ -228,7 +256,7 @@ xmldoc.childNodes[-1].appendChild(sngl_burst_table)
 # trigger generator
 #
 
-head = pipeparts.mkgeneric(pipeline, head, "lal_string_triggergen", threshold = options.threshold, cluster = options.cluster_events, bank_filename = options.template_bank)
+head = triggergen = pipeparts.mkgeneric(pipeline, head, "lal_string_triggergen", threshold = options.threshold, cluster = options.cluster_events, bank_filename = options.template_bank)
 
 
 #
diff --git a/gstlal-burst/gst/lal/gstlal_string_triggergen.c b/gstlal-burst/gst/lal/gstlal_string_triggergen.c
index 8eecfa1b3d..2cf2e6d1f2 100644
--- a/gstlal-burst/gst/lal/gstlal_string_triggergen.c
+++ b/gstlal-burst/gst/lal/gstlal_string_triggergen.c
@@ -1,3 +1,8 @@
+/*
+ * Cosmic string trigger generator and autocorrelation chisq plugin element
+ */
+
+
 /*
  *======================================================
  *
@@ -24,7 +29,9 @@
 #include <glib.h>
 #include <gst/gst.h>
 #include <gst/audio/audio.h>
+#include <gst/base/gstadapter.h>
 #include <gst/base/gstbasetransform.h>
+#include <gsl/gsl_errno.h>
 
 
 /*
@@ -43,9 +50,13 @@
  */
 
 
+#include <gstlal/gstlal.h>
 #include <gstlal_string_triggergen.h>
 #include <gstlal/gstaudioadapter.h>
+#include <gstlal/gstlal_autocorrelation_chi2.h>
 #include <gstlal/gstlal_debug.h>
+#include <gstlal/gstlal_peakfinder.h>
+
 
 /*
  *======================================================
@@ -75,7 +86,7 @@ G_DEFINE_TYPE_WITH_CODE(
  */
 
 
-#define DEFAULT_THRES 5.5
+#define DEFAULT_THRES 4.0
 #define DEFAULT_CLUSTER 0.1
 
 
@@ -88,6 +99,19 @@ G_DEFINE_TYPE_WITH_CODE(
  */
 
 
+static unsigned autocorrelation_length(const GSTLALStringTriggergen *element)
+{
+	return gstlal_autocorrelation_chi2_autocorrelation_length(element->autocorrelation_matrix);
+}
+
+
+static guint64 output_num_bytes(GSTLALStringTriggergen *element)
+{
+	// FIXME: 8192 below is a hard-coded sample rate; don't hard-code it
+        return (guint64) 8192 * element->adapter->unit_size;
+}
+
+
 static void free_bankfile(GSTLALStringTriggergen *element)
 {
 	g_free(element->bank_filename);
@@ -146,10 +170,6 @@ static int setup_bankfile_input(GSTLALStringTriggergen *element, char *bank_file
 
 static GstFlowReturn trigger_generator(GSTLALStringTriggergen *element, GstBuffer *inbuf, GstBuffer *outbuf)
 {
-	/*
-	 * find events. we do not use chisq.
-	 */
-
 	GstMapInfo inmap;
 	float *snrdata;
 	SnglBurst *triggers = NULL;
@@ -167,6 +187,17 @@ static GstFlowReturn trigger_generator(GSTLALStringTriggergen *element, GstBuffe
 	t0 = GST_BUFFER_PTS(inbuf);
 	length = GST_BUFFER_OFFSET_END(inbuf) - GST_BUFFER_OFFSET(inbuf);
 
+	/* copy samples */
+	gst_audioadapter_copy_samples(element->adapter, element->data, length, NULL, NULL);
+
+	/* compute the chisq norm if it doesn't exist */
+	if (!element->autocorrelation_norm)
+		element->autocorrelation_norm = gstlal_autocorrelation_chi2_compute_norms(element->autocorrelation_matrix, NULL);
+	
+	/* check that autocorrelation vector has odd number of samples */
+	g_assert(autocorrelation_length(element) & 1);
+	
+	/* find events */
 	GST_DEBUG_OBJECT(element, "searching %" G_GUINT64_FORMAT " samples at %" GST_TIME_SECONDS_FORMAT " for events with SNR greater than %f", length, GST_TIME_SECONDS_ARGS(t0),element->threshold);
 	for(sample = 0; sample < length; sample++){
 		LIGOTimeGPS t;
@@ -183,6 +214,17 @@ static GstFlowReturn trigger_generator(GSTLALStringTriggergen *element, GstBuffe
 					 */
 					element->bank[channel].snr = snr;
 					element->bank[channel].peak_time = t;
+					element->bank[channel].chisq_dof = 1.0;
+					/*
+					 * Recalculate chisq each time this update occurs, treating the current sample as the peak.
+					 */
+					element->maxdata->values.as_double[channel] = *snrdata;
+					element->maxdata->samples[channel] = sample;
+					/* extract data around peak for chisq calculation */
+					/* offset the data pointer by one pad length */
+					gstlal_double_series_around_peak(element->maxdata, ((double *) element->data) + element->maxdata->pad * element->num_templates, (double *) element->snr_mat, element->maxdata->pad);
+					/* calculate chisq */
+					gstlal_autocorrelation_chi2(&element->bank[channel].chisq, (double complex *) element->snr_mat, autocorrelation_length(element), -((int) autocorrelation_length(element)) / 2, element->threshold, &element->autocorrelation_matrix[channel], NULL, &element->autocorrelation_norm[channel]);
 				}
 			} else if(element->bank[channel].snr != 0. && XLALGPSDiff(&t, &element->last_time[channel]) > element->cluster) {
 				/*
@@ -192,6 +234,8 @@ static GstFlowReturn trigger_generator(GSTLALStringTriggergen *element, GstBuffe
 				triggers = g_renew(SnglBurst, triggers, ntriggers + 1);
 				triggers[ntriggers++] = element->bank[channel];
 				element->bank[channel].snr = 0.0;
+				element->bank[channel].chisq = 0.0;
+				element->bank[channel].chisq_dof = 0.0;
 			}
 		}
 	}
@@ -291,7 +335,22 @@ static gboolean set_caps(GstBaseTransform *trans, GstCaps *incaps, GstCaps *outc
 	GSTLALStringTriggergen *element = GSTLAL_STRING_TRIGGERGEN(trans);
 	gboolean success = gst_audio_info_from_caps(&element->audio_info, incaps);
 
-	g_object_set(element->adapter, "unit-size", GST_AUDIO_INFO_WIDTH(&element->audio_info) / 8, NULL);
+	g_object_set(element->adapter, "unit-size", GST_AUDIO_INFO_WIDTH(&element->audio_info) / 8 * element->num_templates, NULL);
+
+	if (element->maxdata)
+		gstlal_peak_state_free(element->maxdata);
+	element->maxdata = gstlal_peak_state_new(element->num_templates, GSTLAL_PEAK_DOUBLE_COMPLEX);
+	/* Update padding any time the autocorrelation property is updated */
+	if (element->autocorrelation_matrix) {
+		element->maxdata->pad = autocorrelation_length(element) / 2;
+		if (element->snr_mat)
+			free(element->snr_mat);
+		element->snr_mat = calloc(element->num_templates * autocorrelation_length(element), element->maxdata->unit);
+	}
+
+	/*
+	 * done
+	 */
 
 	return success;
 }
@@ -316,8 +375,16 @@ static gboolean start(GstBaseTransform *trans)
 		XLALINT8NSToGPS(&element->bank[i].peak_time, 0);
 		element->bank[i].snr = 0;
 
-		/* initialize the last time array, too */
+		/*
+		 * Initialize the chisq and chisq_dof, too.
+		 * We follow the definition of the previous string search pipeline:
+		 * the actual chi^2 is then chisq/chisq_dof. We can revisit
+		 * this definition later if we have to.
+		 */
+		element->bank[i].chisq = 0;
+		element->bank[i].chisq_dof = 0;
 
+		/* initialize the last time array, too */
 		XLALINT8NSToGPS(&element->last_time[i], 0);
 	}
 
@@ -347,6 +414,17 @@ static GstFlowReturn transform(GstBaseTransform *trans, GstBuffer *inbuf, GstBuf
 {
 	GSTLALStringTriggergen *element = GSTLAL_STRING_TRIGGERGEN(trans);
 	GstFlowReturn result;
+	guint64 maxsize;
+
+	/* The max size to copy from an adapter is the typical output size plus the padding */
+	maxsize = output_num_bytes(element) + element->adapter->unit_size * element->maxdata->pad * 2;
+
+	/* if we haven't allocated storage yet, do it now; we should never try to copy from an adapter with a larger buffer than this */
+	if (!element->data)
+		element->data = malloc(maxsize);
+	
+	/* put the incoming buffer into an adapter */
+	gst_audioadapter_push(element->adapter, inbuf);
 
 	result = trigger_generator(element,inbuf,outbuf);
 
@@ -373,7 +451,8 @@ static GstFlowReturn transform(GstBaseTransform *trans, GstBuffer *inbuf, GstBuf
 enum property {
 	ARG_THRES = 1,
 	ARG_CLUSTER,
-	ARG_BANK_FILENAME
+	ARG_BANK_FILENAME,
+	ARG_AUTOCORRELATION_MATRIX
 };
 
 
@@ -398,6 +477,33 @@ static void set_property(GObject *object, enum property prop_id, const GValue *v
 		g_mutex_unlock(&element->bank_lock);
 		break;
 
+	case ARG_AUTOCORRELATION_MATRIX:
+		g_mutex_lock(&element->bank_lock);
+		if(element->autocorrelation_matrix)
+			gsl_matrix_complex_free(element->autocorrelation_matrix);
+		element->autocorrelation_matrix = gstlal_gsl_matrix_complex_from_g_value_array(g_value_get_boxed(value));
+		fprintf(stderr,"done setting autocorrelation matrix\n");
+
+		/* This should be called any time caps change too */
+		if(element->maxdata && element->autocorrelation_matrix){
+			fprintf(stderr, "autocorrelation length = %d\n", autocorrelation_length(element));
+			element->maxdata->pad = autocorrelation_length(element) / 2;
+			if (element->snr_mat)
+				free(element->snr_mat);
+			element->snr_mat = calloc(element->num_templates * autocorrelation_length(element), element->maxdata->unit);
+		}
+		
+		/*
+		 * force the norms to be recomputed
+		 */
+		if(element->autocorrelation_norm) {
+			gsl_vector_free(element->autocorrelation_norm);
+			element->autocorrelation_norm = NULL;
+		}
+
+		g_mutex_unlock(&element->bank_lock);
+		break;
+
 	default:
 		G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec);
 		break;
@@ -427,6 +533,18 @@ static void get_property(GObject *object, enum property prop_id, GValue *value,
 		g_value_set_string(value, element->bank_filename);
 		g_mutex_unlock(&element->bank_lock);
 		break;
+	
+	case ARG_AUTOCORRELATION_MATRIX:
+		g_mutex_lock(&element->bank_lock);
+		if(element->autocorrelation_matrix)
+			g_value_take_boxed(value, gstlal_g_value_array_from_gsl_matrix_complex(element->autocorrelation_matrix));
+		else {
+			GST_WARNING_OBJECT(element, "no autocorrelation matrix");
+			/* FIXME: GValueArray (g_value_array_new) is deprecated */
+			g_value_take_boxed(value, g_value_array_new(0)); 
+			}
+		g_mutex_unlock(&element->bank_lock);
+		break;
 
 	default:
 		G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec);
@@ -442,12 +560,35 @@ static void finalize(GObject *object)
 	GSTLALStringTriggergen *element = GSTLAL_STRING_TRIGGERGEN(object);
 	g_mutex_clear(&element->bank_lock);
 	free_bankfile(element);
+	if(element->maxdata) {
+		gstlal_peak_state_free(element->maxdata);
+		element->maxdata = NULL;
+	}
+	if(element->data){
+		free(element->data);
+		element->data = NULL;
+	}
+
 	g_free(element->instrument);
 	element->instrument = NULL;
 	g_free(element->channel_name);
 	element->channel_name = NULL;
+
 	gst_audioadapter_clear(element->adapter);
 	g_object_unref(element->adapter);
+
+	if(element->snr_mat) {
+		free(element->snr_mat);
+		element->snr_mat = NULL;
+	}
+	if(element->autocorrelation_matrix) {
+		gsl_matrix_complex_free(element->autocorrelation_matrix);
+		element->autocorrelation_matrix = NULL;
+	}
+	if(element->autocorrelation_norm) {
+		gsl_vector_free(element->autocorrelation_norm);
+		element->autocorrelation_norm = NULL;
+	}
 	G_OBJECT_CLASS(gstlal_string_triggergen_parent_class)->finalize(object);
 }
 
@@ -538,6 +679,30 @@ static void gstlal_string_triggergen_class_init(GSTLALStringTriggergenClass *kla
 			G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS | G_PARAM_CONSTRUCT
 		)
 	);
+	g_object_class_install_property(
+		gobject_class,
+		ARG_AUTOCORRELATION_MATRIX,
+		g_param_spec_value_array(
+			"autocorrelation-matrix",
+			"Autocorrelation Matrix",
+			"Array of autocorrelation vectors.  Number of vectors (rows) in matrix sets number of channels.  All vectors must have the same length.",
+			g_param_spec_value_array(
+				"autocorrelation",
+				"Autocorrelation",
+				"Array of autocorrelation samples.",
+				/* FIXME:  should be complex */
+				g_param_spec_double(
+					"sample",
+					"Sample",
+					"Autocorrelation sample",
+					-G_MAXDOUBLE, G_MAXDOUBLE, 0.0,
+					G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS
+				),
+				G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS
+			),
+			G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS
+		)
+	);
 }
 
 /*
@@ -551,10 +716,15 @@ static void gstlal_string_triggergen_init(GSTLALStringTriggergen *element)
 	element->adapter = g_object_new(GST_TYPE_AUDIOADAPTER, NULL);
 	element->bank_filename = NULL;
 	element->bank = NULL;
+	element->data = NULL;
+	element->maxdata = NULL;
 	element->instrument = NULL;
 	element->channel_name = NULL;
 	element->num_templates = 0;
 	element->last_time = NULL;
+	element->snr_mat = NULL;
 	element->audio_info.bpf = 0;	/* impossible value */
+	element->autocorrelation_matrix = NULL;
+	element->autocorrelation_norm = NULL;
 	gst_base_transform_set_gap_aware(GST_BASE_TRANSFORM(element), TRUE);
 }
diff --git a/gstlal-burst/gst/lal/gstlal_string_triggergen.h b/gstlal-burst/gst/lal/gstlal_string_triggergen.h
index f33056f3cf..4926a6fe93 100644
--- a/gstlal-burst/gst/lal/gstlal_string_triggergen.h
+++ b/gstlal-burst/gst/lal/gstlal_string_triggergen.h
@@ -5,9 +5,14 @@
 #include <glib.h>
 #include <gst/gst.h>
 #include <gst/audio/audio.h>
+#include <gst/base/gstadapter.h>
 #include <gst/base/gstbasetransform.h>
 #include <gstlal/gstaudioadapter.h>
+#include <gstlal/gstlal_peakfinder.h>
 #include <lal/LIGOMetadataTables.h>
+#include <gsl/gsl_matrix.h>
+#include <gsl/gsl_matrix_float.h>
+
 
 G_BEGIN_DECLS
 
@@ -40,20 +45,21 @@ typedef struct {
 	
 	GstAudioInfo audio_info;
 
-	/*
-	 * extracting triggers above threshold
-	 */
-
 	float threshold;
 	float cluster;
 
 	GMutex bank_lock;
+	gsl_matrix_complex *autocorrelation_matrix;
+	gsl_vector *autocorrelation_norm;
 	char *bank_filename;
+	SnglBurst *bank;
+	void *data;
+	struct gstlal_peak_state *maxdata;
 	gchar *instrument;
 	gchar *channel_name;
-	SnglBurst *bank;
 	gint num_templates;
 	LIGOTimeGPS *last_time;
+	void *snr_mat;
 } GSTLALStringTriggergen;
 
 
-- 
GitLab