Skip to content
Snippets Groups Projects
Forked from lscsoft / GstLAL
2167 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
construct_skymap_test_dag 8.73 KiB
#!/usr/bin/env python
'''
./construct_skymap_test_dag path/to/injection/database path/to/tmp/space max_number_inspiral_jobs
'''

import sqlite3
import sys
import os
import glob

from ligo.lw import ligolw
from ligo.lw import lsctables
from ligo.lw import utils as ligolw_utils
from ligo.lw.utils import process as ligolw_process
from ligo.lw import dbtables
from gstlal import dagparts

# copied from gstlal_inspiral_plotsummary
def create_sim_coinc_view(connection):
	"""
	Construct a sim_inspiral --> best matching coinc_event mapping.
	Only injections that match at least one coinc get an entry in this
	table.

	The following temporary tables are created on connection:

	* sim_coinc_map_helper: (sid, cid, lr), one row per zero-lag coinc
	  linked to an injection through a sim <-> coinc coinc_event.
	* sim_coinc_map: (simulation_id, coinc_event_id), keeping for each
	  injection only the linked coinc with the largest likelihood.
	* sim_id_combined_far: (far, sim_id), the combined_far of that coinc.
	* sim_id_sngl_id: (sim_id, sngl_id), the single-detector triggers of
	  that coinc.
	* sim_sngl_far: (pid, Gamma1, far, simulation_id, sim_inspiral.*), one
	  row per (injection, single-detector trigger) pair.

	connection: sqlite3 connection to a database holding the sim_inspiral,
	coinc_event_map, coinc_event, time_slide, coinc_inspiral and
	sngl_inspiral tables.  Calling this twice on the same connection fails
	because the temporary tables already exist.
	"""
	#
	# the log likelihood ratio stored in the likelihood column of the
	# coinc_event table is the ranking statistic.  the "best match" is
	# the coinc with the highest value in this column.  although it has
	# not been true in the past, there is now a one-to-one relationship
	# between the value of this ranking statistic and false-alarm rate,
	# therefore it is OK to order by log likelihood ratio and then,
	# later, impose a "detection" threshold based on false-alarm rate.
	#
	cursor = connection.cursor()

	# String literals are single-quoted: SQLite accepts double-quoted
	# strings only as a legacy fallback for unknown identifiers, and builds
	# compiled with SQLITE_DQS=0 reject them outright.
	cursor.execute("""
CREATE TEMPORARY TABLE
	sim_coinc_map_helper
AS
	SELECT
		a.event_id AS sid,
		coinc_event.coinc_event_id AS cid,
		coinc_event.likelihood AS lr
	FROM
		coinc_event_map AS a
		JOIN coinc_event_map AS b ON (b.coinc_event_id == a.coinc_event_id)
		JOIN coinc_event ON (coinc_event.coinc_event_id == b.event_id)
	WHERE
		a.table_name == 'sim_inspiral'
		AND b.table_name == 'coinc_event'
		AND NOT EXISTS (SELECT * FROM time_slide WHERE time_slide.time_slide_id == coinc_event.time_slide_id AND time_slide.offset != 0);
	""")

	cursor.execute("CREATE INDEX IF NOT EXISTS sim_coinc_map_helper_index ON sim_coinc_map_helper (sid, cid);")

	# For each injection keep only the highest-likelihood zero-lag coinc;
	# injections with no such coinc get NULL and are dropped.
	cursor.execute("""
CREATE TEMPORARY TABLE
	sim_coinc_map
AS
	SELECT
		sim_inspiral.simulation_id AS simulation_id,
		(
			SELECT
				cid
			FROM
				sim_coinc_map_helper
			WHERE
				sid = simulation_id
			ORDER BY
				lr DESC
			LIMIT 1
		) AS coinc_event_id
	FROM
		sim_inspiral
	WHERE
		coinc_event_id IS NOT NULL;
	""")

	cursor.execute("DROP INDEX sim_coinc_map_helper_index;")

	cursor.execute("""
CREATE TEMPORARY TABLE
	sim_id_combined_far
AS
	SELECT
		coinc_inspiral.combined_far AS far,
		sim_coinc_map.simulation_id AS sim_id
	FROM
		sim_coinc_map
		JOIN coinc_inspiral ON (coinc_inspiral.coinc_event_id == sim_coinc_map.coinc_event_id)
	""")

	cursor.execute("""
CREATE TEMPORARY TABLE
	sim_id_sngl_id
AS
	SELECT
		sim_coinc_map.simulation_id AS sim_id,
		sngl_inspiral.event_id AS sngl_id
	FROM
		sim_coinc_map
		JOIN coinc_event_map AS mapA ON (mapA.coinc_event_id == sim_coinc_map.coinc_event_id)
		JOIN sngl_inspiral ON (sngl_inspiral.event_id == mapA.event_id)
	""")

	# Both indexes lead with sim_id because sim_sngl_far below joins on it;
	# an index leading with far cannot serve that lookup.
	cursor.execute("CREATE INDEX IF NOT EXISTS sim_id_combined_far_index ON sim_id_combined_far (sim_id, far)")
	cursor.execute("CREATE INDEX IF NOT EXISTS sim_id_sngl_id_index ON sim_id_sngl_id (sim_id, sngl_id)")

	# Column order (pid, Gamma1, far, simulation_id, sim_inspiral.*) is
	# relied upon by callers that unpack rows positionally.
	cursor.execute("""
CREATE TEMPORARY TABLE
	sim_sngl_far
AS
	SELECT
		sngl_inspiral.process_id AS pid,
		sngl_inspiral.Gamma1 AS Gamma1,
		sim_id_combined_far.far AS far,
		sim_inspiral.simulation_id AS simulation_id,
		sim_inspiral.*
	FROM
		sim_inspiral
		JOIN sim_id_sngl_id ON (
			sim_inspiral.simulation_id == sim_id_sngl_id.sim_id
		)
		JOIN sngl_inspiral ON (
			sngl_inspiral.event_id == sim_id_sngl_id.sngl_id
		)
		JOIN sim_id_combined_far ON (
			sim_id_combined_far.sim_id == sim_id_sngl_id.sim_id
		)
	""")

	cursor.execute("DROP INDEX sim_id_combined_far_index")
	cursor.execute("DROP INDEX sim_id_sngl_id_index")

if len(sys.argv) != 4:
	sys.exit("usage: %s path/to/injection/database path/to/tmp/space max_number_inspiral_jobs" % sys.argv[0])
inj_db = sys.argv[1]
tmp_space = sys.argv[2]
num_inspiral_jobs = int(sys.argv[3])
# The original analysis products (PSDs, SVD banks, ranking statistics, ...)
# are looked up relative to the directory holding the injection database.
analysis_dir = os.path.dirname(inj_db)

# Work on a (possibly scratch-space) copy of the database; the copy is
# released at the end of the script.
working_filename = dbtables.get_connection_filename(inj_db, tmp_path = tmp_space, verbose = True)
connection = sqlite3.connect(working_filename)

create_sim_coinc_view(connection)

# Only injections recovered at or below this FAR are re-analysed.
# 3.86e-7 ~= 1 / (30 * 86400), i.e. about one per 30 days if far is in Hz.
FAR_THRESHOLD = 3.86e-7

# (bank id from sngl_inspiral.Gamma1, process_id, simulation_id) -> sim_inspiral row
sim_row = {}
xmldoc = dbtables.get_xml(connection)
sim_inspiral_table = lsctables.SimInspiralTable.get_table(xmldoc)
# sim_sngl_far has one row per single-detector trigger of the best coinc, so
# the rows for one injection are exact duplicates.  DISTINCT makes LIMIT
# count jobs (dict keys below) rather than triggers.
for record in connection.cursor().execute("""
SELECT DISTINCT
	*
FROM
	sim_sngl_far
WHERE
	far <= ?
ORDER BY
	far ASC
LIMIT ?
""", (FAR_THRESHOLD, num_inspiral_jobs)):
	# columns: pid, Gamma1, far, simulation_id, then every sim_inspiral column
	process_id, bank_id, far, simid = record[:4]
	sim_row[(bank_id, process_id, simid)] = sim_inspiral_table.row_from_cols(record[4:])

# Options shared by every injection job.  Entries that are None are filled
# in per job from the process_params of the original gstlal_inspiral job
# (see updatedict); entries with a value here override the original job.
master_opts_dict = {
	"gps-start-time": None,
	"gps-end-time": None,
	"psd-fft-length": 32,
	"likelihood-snapshot-interval": 100000.0,
	"track-psd": "",
	"min-instruments": None,
	"gracedb-far-threshold": 1e-6,
	"gracedb-service-url": None,  # set per job to a file:// URL under lloid_files/
	"ht-gate-threshold": 50.0,
	"veto-segments-name": "vetoes",
	"fir-stride": 0.25,
	"gracedb-group": "CBC",
	"coincidence-threshold": 0.005,
	"control-peak-time": 0,
	"gracedb-pipeline": "gstlal",
	"data-source": None,
	"frame-segments-name": None,
	"tmp-space": None,
	"gracedb-search": "AllSky",
	"channel-name": None,
	"singles-threshold": "inf",
	"verbose": ""
}

# Input files.  Paths copied from the original job are made relative to
# analysis_dir per job; os.path.join keeps the result relative when the
# database was given without a directory component.
master_input_dict = {
	"reference-psd": None,
	"svd-bank": None, # FIXME THIS ONE IS TRICKY: set per job from the bank id
	"ranking-stat-pdf": os.path.join(analysis_dir, "post_marginalized_likelihood.xml.gz"),
	"ranking-stat-input": None, # FIXME THIS ONE IS TRICKY: set per job from the bank id
	"veto-segments-file": None,
	"frame-segments-file": None,
	"frame-cache": None,
	"time-slide-file": None,
	"injections": None, # set per job to a file holding just the one injection
	}

# Outputs.  The ranking-statistic outputs are placeholders that are not used
# afterwards; "output" is set per job.
master_output_dict = {
	"ranking-stat-output": "not_used.xml.gz",
	"zerolag-rankingstat-pdf": "notused2.xml.gz",
	"output": None,
}

# Directory for Condor logs (presumably used by the dagparts submit files);
# it is fine if it already exists.  Only OSError is ignored so that e.g.
# KeyboardInterrupt is not silently swallowed.
try:
	os.mkdir("logs")
except OSError:
	pass
dag = dagparts.DAG("trigger_pipe")

# A single Condor job class shared by every re-run gstlal_inspiral node.
gstlalInspiralInjJob = dagparts.DAGJob("gstlal_inspiral",
	tag_base = "gstlal_inspiral_inj",
	condor_commands = {
		"request_memory": "5gb",
		"request_cpus": "2",
		"want_graceful_removal": "True",
		"kill_sig": "15",
	}
	)

def updatedict(x, y):
	"""
	Fill the unset entries of x from y, in place.

	Each key of x whose value is None takes y's value for that key when y
	has one.  Keys absent from y stay None, keys of x that already have a
	value are left alone, and keys that exist only in y are ignored.
	Returns None.
	"""
	for key, current in list(x.items()):
		if current is not None or key not in y:
			continue
		x[key] = y[key]

def fixrelpath(x, ys, base = None):
	"""
	Rewrite selected entries of x as paths under the analysis directory.

	x: dict mapping option name -> list of values, as built from the
	   process_params table; for each selected key the FIRST value is
	   joined onto base and stored back as a plain string.
	ys: iterable of keys of x to rewrite.
	base: directory to join onto; defaults to the module-level analysis_dir.

	Entries that are None (option not given to the original job) are left
	as None instead of raising TypeError.  os.path.join keeps an absolute
	value unchanged and keeps the path relative when base is empty.
	"""
	if base is None:
		base = analysis_dir
	for key in ys:
		values = x[key]
		if values is None:
			continue
		x[key] = os.path.join(base, values[0])

def new_inj_file(row, output):
	"""
	Write a LIGO_LW XML document containing a single sim_inspiral row.

	row: the sim_inspiral row to write.
	output: destination file name; gzip output is requested when the name
	        ends in "gz".
	"""
	table = lsctables.New(lsctables.SimInspiralTable)
	table.append(row)
	doc = ligolw.Document()
	doc.appendChild(ligolw.LIGO_LW()).appendChild(table)
	ligolw_utils.write_filename(doc, output, gz = output.endswith('gz'))


# Per-job injection files and job outputs live in these directories; they
# may already exist from a previous run.
for scratch_dir in ("inj_files", "lloid_files"):
	try:
		os.mkdir(scratch_dir)
	except OSError:
		pass

def _first_match(pattern):
	"""
	Return the first (in sorted order) path matching the glob pattern.

	Sorting makes the choice deterministic; an empty match raises a
	RuntimeError naming the pattern instead of a bare IndexError.
	"""
	matches = sorted(glob.glob(pattern))
	if not matches:
		raise RuntimeError("no file matches %s" % pattern)
	return matches[0]

# One gstlal_inspiral node per selected injection: re-run the original job
# (reconstructed from its process_params) on a single-injection file.
# Iterating items() uses this job's own key for the row lookup.
for job_id, ((bankid, process_id, simid), row) in enumerate(sim_row.items(), start=1):
	# FIXME Need to add option for dist stats output
	print("++ job_id: %s ++" % job_id)

	# Command line of the original job: option name (without "--") -> list of values.
	job_dict = {}
	for param, value in connection.cursor().execute("SELECT param, value FROM process_params WHERE process_id == ?", (process_id,)):
		job_dict.setdefault(param.replace("--",""), []).append(value)
	this_opts_dict = master_opts_dict.copy()
	updatedict(this_opts_dict, job_dict)
	this_input_dict = master_input_dict.copy()
	updatedict(this_input_dict, job_dict)
	this_output_dict = master_output_dict.copy()
	updatedict(this_output_dict, job_dict)

	# Input files recorded by the original job are relative to the analysis directory.
	fixrelpath(this_input_dict, ("reference-psd", "frame-cache", "time-slide-file", "veto-segments-file", "frame-segments-file"))

	# make a custom injection file containing only this injection
	inj_file_name = "inj_files/%d_%d_%d_inj.xml.gz" % (job_id, bankid, process_id)
	new_inj_file(row, inj_file_name)
	this_input_dict["injections"] = inj_file_name

	# FIXME hacks for the svd: one bank per instrument named in --channel-name IFO=CHANNEL
	instruments = [x.split("=")[0] for x in this_opts_dict["channel-name"]]
	banks = ["%s:%s" % (ifo, _first_match(os.path.join(analysis_dir, "gstlal_svd_bank", "%s-%04d_SVD*" % (ifo, bankid)))) for ifo in instruments]
	this_input_dict["svd-bank"] = ",".join(banks)

	# FIXME don't hardcode H1L1V1
	ranking_stat_pdf = _first_match(os.path.join(analysis_dir, "gstlal_inspiral_marginalize_likelihood", "H1L1V1-%04d_MARG_DIST_STATS*" % bankid))
	this_input_dict["ranking-stat-input"] = ranking_stat_pdf

	# just name the output the same as the input
	outdir = "lloid_files/%d_%d_%d" % (job_id, bankid, process_id)
	try:
		os.mkdir(outdir)
	except OSError:
		pass

	output_file_name = "%s/%d_%d_%d_lloid.xml.gz" % (outdir, job_id, bankid, process_id)
	this_output_dict["output"] = output_file_name

	# Uploads go to a local directory instead of a real GraceDB server.
	this_opts_dict["gracedb-service-url"] = "file://%s/%s" % (os.getcwd(), outdir)

	dagparts.DAGNode(gstlalInspiralInjJob, dag, parent_nodes = [], opts = this_opts_dict, input_files = this_input_dict, output_files = this_output_dict)

# Emit the Condor submit files, the DAG and the accompanying script.
dag.write_sub_files()
dag.write_dag()
dag.write_script()

# Done with the injection database: close it and release the working copy
# made by dbtables.get_connection_filename() in the scratch space.
connection.close()
dbtables.discard_connection_filename(inj_db, working_filename, verbose = True)