Commit 136f2e50 authored by Leo P. Singer's avatar Leo P. Singer
Browse files

Parallelize bayestar_aggregate_found_injections

Original: c785967d7c8b7a849bfd7c463f704f85f8f48f48
parent 26f8ac50
......@@ -35,31 +35,35 @@ arguments, and may also be provided as globs (such as '*.fits.gz').
__author__ = "Leo Singer <leo.singer@ligo.org>"
# Command line interface.
from optparse import Option, OptionParser
from lalinference.bayestar import command
parser = OptionParser(
formatter=command.NewlinePreservingHelpFormatter(),
description=__doc__,
usage="%prog [-o OUTPUT] DATABASE.sqlite FILE1.fits[.gz] FILE2.fits[.gz] ...",
option_list=[
Option("-o", "--output", default="/dev/stdout",
help="Name of output file [default=%default]")
]
)
opts, args = parser.parse_args()
try:
dbfilename = args[0]
fitsfileglobs = args[1:]
except IndexError:
parser.error("not enough command line arguments")
outfile = open(opts.output, "w")
if __name__ == '__main__':
# Command line interface.
from optparse import Option, OptionParser
from lalinference.bayestar import command
parser = OptionParser(
formatter=command.NewlinePreservingHelpFormatter(),
description=__doc__,
usage="%prog [-o OUTPUT] DATABASE.sqlite FILE1.fits[.gz] FILE2.fits[.gz] ...",
option_list=[
Option("-o", "--output", default="/dev/stdout",
help="Name of output file [default=%default]"),
Option("-j", "--jobs", default=1, type=int,
help="Number of threads [default=%default]")
]
)
opts, args = parser.parse_args()
try:
dbfilename = args[0]
fitsfileglobs = args[1:]
except IndexError:
parser.error("not enough command line arguments")
outfile = open(opts.output, "w")
# Imports.
import glob
import functools
import itertools
import os
import numpy as np
......@@ -69,10 +73,6 @@ from pylal.progress import ProgressBar
from lalinference.bayestar import fits
fitsfilenames = itertools.chain.from_iterable(glob.iglob(fitsfileglob)
for fitsfileglob in fitsfileglobs)
sql = """
SELECT DISTINCT sim.longitude AS ra, sim.latitude AS dec, ci.combined_far AS far
FROM coinc_event_map AS cem1 INNER JOIN coinc_event_map AS cem2
......@@ -143,15 +143,13 @@ def find_injection(sky_map, true_ra, true_dec):
# Done.
return searched_area, searched_prob, offset
progress = ProgressBar()
progress.update(-1, 'opening database')
db = sqlite3.connect(dbfilename)
def startup(dbfilename):
global db
db = sqlite3.connect(dbfilename)
print('objid', 'far', 'searched_area', 'searched_prob', 'offset', 'runtime',
sep=',', file=outfile)
for fitsfilename in progress.iterate(fitsfilenames):
def process(fitsfilename):
sky_map, metadata = fits.read_sky_map(fitsfilename)
coinc_event_id = metadata['objid']
......@@ -163,5 +161,29 @@ for fitsfilename in progress.iterate(fitsfilenames):
true_ra, true_dec, far = db.execute(sql, (coinc_event_id,)).fetchone()
searched_area, searched_prob, offset = find_injection(sky_map, true_ra, true_dec)
print(coinc_event_id, far, searched_area, searched_prob, offset, runtime,
return coinc_event_id, far, searched_area, searched_prob, offset, runtime
if __name__ == '__main__':
if opts.jobs == 1:
from itertools import imap
startup(dbfilename)
else:
import multiprocessing
imap = multiprocessing.Pool(opts.jobs, startup, (dbfilename,)).imap_unordered
progress = ProgressBar()
progress.update(-1, 'obtaining filenames of sky maps')
fitsfilenames = tuple(itertools.chain.from_iterable(glob.iglob(fitsfileglob)
for fitsfileglob in fitsfileglobs))
print('objid', 'far', 'searched_area', 'searched_prob', 'offset', 'runtime',
sep=',', file=outfile)
count_records = 0
progress.max = len(fitsfilenames)
for record in imap(functools.partial(process), fitsfilenames):
count_records += 1
progress.update(count_records, record[0])
print(*record, sep=',', file=outfile)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment