Skip to content
Snippets Groups Projects
Commit 25a2518f authored by Chad Hanna's avatar Chad Hanna
Browse files

gstlal_inspiral_mass_model: rewrite

parent bf1fdd00
No related branches found
No related tags found
No related merge requests found
......@@ -24,8 +24,7 @@ from glue.ligolw import lsctables, param as ligolw_param, array as ligolw_array
from glue.ligolw import utils as ligolw_utils
from glue.ligolw.utils import process as ligolw_process
import lal.series
from gstlal.stats.inspiral_lr import TYPICAL_HORIZON_DISTANCE
from gstlal import svd_bank
from lal import rate
@ligolw_array.use_in
@ligolw_param.use_in
......@@ -34,43 +33,43 @@ class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
pass
parser = argparse.ArgumentParser(description = "Create analytic mass models for prior weighting of templates")
parser.add_argument("--svd-bank", metavar='name', type=str, help='The input svd bank file name. Can be specified multiple times', action="append", required = True)
parser.add_argument("--reference-psd", metavar='name', type=str, help='The input psd file name', required = True)
parser.add_argument("--instrument", metavar='name', type=str, help='The instrument to use, e.g., H1', required = True)
parser.add_argument("--template-bank", metavar='name', type=str, help='The input template bank file name.', required = True)
parser.add_argument("--output", metavar='name', type=str, help='The output file name', default = "inspiral_mass_model.h5")
parser.add_argument("--model", metavar='name', type=str, help='Mass model. Options are salpeter, uniform-in-template')
parser.add_argument("--model", metavar='name', type=str, help='Mass model. Options are: salpeter. If you want another one, submit a patch.')
parser.add_argument("--verbose", help='Be verbose', action="store_true")
options = parser.parse_args()
# Read the PSD file
psd = lal.series.read_psd_xmldoc(ligolw_utils.load_filename(options.reference_psd, verbose = True, contenthandler = lal.series.PSDContentHandler))[options.instrument]
# Read the template bank file
xmldoc = ligolw_utils.load_filename(options.template_bank, verbose = options.verbose, contenthandler = LIGOLWContentHandler)
sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)
mass1 = sngl_inspiral_table.get_column("mass1")
mass2 = sngl_inspiral_table.get_column("mass2")
num_templates = len(mass1)
num_bins = max(2, int((num_templates / 100.)**.5))
min_mass = min(min(mass1), min(mass2)) - 1.e-6
max_mass = max(max(mass1), max(mass2)) + 1.e-6
massBA = rate.BinnedDensity(rate.NDBins((rate.LogarithmicBins(min_mass, max_mass, num_bins), rate.LogarithmicBins(min_mass, max_mass, num_bins))))
print min_mass, max_mass
for m1, m2 in zip(mass1, mass2):
massBA.count[(m1, m2)] += 1
massBA.count[(m2, m1)] += 1
rate.filter_array(massBA.array, rate.gaussian_window(1.5, 1.5, sigma = 5))
# Assign the proper mass probabilities
ids = {}
for svdbank in options.svd_bank:
banks = {options.instrument:[]}
sngl_inspiral_table = []
for n, bank in enumerate(svd_bank.read_banks(svdbank, contenthandler = LIGOLWContentHandler, verbose = True)):
banks[options.instrument].append(bank)
sngl_inspiral_table.extend(bank.sngl_inspiral_table)
horizon_distance_function = svd_bank.make_horizon_distance_func(banks)
horizon_distance = horizon_distance_function(psd)[0]
for row in sngl_inspiral_table:
assert row.template_id not in ids
if options.model == "salpeter":
ids[row.template_id] = numpy.log(row.mass1**-2.35)
elif options.model == "uniform-in-template":
# the LR code has a horizon distance term to make each
# template uniform you have to undo it.
ids[row.template_id] = numpy.log((horizon_distance / TYPICAL_HORIZON_DISTANCE)**3)
else:
raise ValueError("Invalid mass model")
for row in sngl_inspiral_table:
assert row.template_id not in ids
if options.model == "salpeter":
ids[row.template_id] = numpy.log(row.mass1**-2.35 / massBA[row.mass1, row.mass2])
else:
raise ValueError("Invalid mass model")
coefficients = numpy.zeros((1, 1, max(ids)+1), dtype=float)
for tid in ids:
coefficients[0,0,tid] = ids[tid]
# Write it out
f = h5py.File(options.output, "w")
print coefficients
# put in a dummy interval for the piecewise polynomials in SNR
f.create_dataset("SNR", data = numpy.array([0., 100.]))
f.create_dataset("coefficients", data = coefficients, compression="gzip")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment