From b2bb1088d1d75e8f517f27eaba03ce4c04ae6a15 Mon Sep 17 00:00:00 2001 From: Patrick Godwin <patrick.godwin@ligo.org> Date: Thu, 29 Aug 2019 18:37:28 -0700 Subject: [PATCH] inspiral_pipe.py: add load_analysis_output() utility for reranking dag, allow more flexibility in get_bank_params() and marginalize_layer() --- gstlal-inspiral/python/inspiral_pipe.py | 56 +++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/gstlal-inspiral/python/inspiral_pipe.py b/gstlal-inspiral/python/inspiral_pipe.py index bbeafafaf0..289955b353 100644 --- a/gstlal-inspiral/python/inspiral_pipe.py +++ b/gstlal-inspiral/python/inspiral_pipe.py @@ -601,7 +601,8 @@ def marginalize_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, optio # FIXME, the svd nodes list has to be the same as the sorted keys of # lloid_output. svd nodes should be made into a dictionary much # earlier in the code to prevent a mishap - one_ifo_svd_nodes = dict(("%04d" % n, node) for n, node in enumerate( svd_nodes.values()[0])) + if svd_nodes: + one_ifo_svd_nodes = dict(("%04d" % n, [node]) for n, node in enumerate( svd_nodes.values()[0])) # Here n counts the bins # FIXME - this is broken for injection dags right now because of marg nodes @@ -613,11 +614,18 @@ def marginalize_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, optio parents = dagparts.flatten([o[1] for o in outputs]) rankfile = functools.partial(get_rank_file, instruments, boundary_seg, bin_key) + if svd_nodes: + parent_nodes = [one_ifo_svd_nodes[bin_key]] + model_node + svd_file = one_ifo_svd_nodes[bin_key].output_files["write-svd"] + else: + parent_nodes = model_node + svd_file = dagparts.T050017_filename(ifo, '%s_SVD' % bin_key, boundary_arg, '.xml.gz', path = jobs['svd'].output_path) + # FIXME we keep this here in case we someday want to have a # mass bin dependent prior, but it really doesn't matter for # the time being. priornode = dagparts.DAGNode(jobs['createPriorDistStats'], dag, - parent_nodes = [one_ifo_svd_nodes[bin_key]] + model_node or [], + parent_nodes = parent_nodes, opts = { "instrument": instrument_set, "background-prior": 1, @@ -626,7 +634,7 @@ def marginalize_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, optio "df": "bandwidth" }, input_files = { - "svd-file": one_ifo_svd_nodes[bin_key].output_files["write-svd"], + "svd-file": svd_file, "mass-model-file": model_file, "dtdphi-file": svd_dtdphi_map[bin_key], "psd-xml": ref_psd @@ -1023,6 +1031,40 @@ def webserver_url(): # +def load_analysis_output(options): + # load triggers + bgbin_lloid_map = {} + for ce in map(CacheEntry, open(options.lloid_cache)): + try: + bgbin_idx, _, inj = ce.description.split('_', 2) + except: + bgbin_idx, _ = ce.description.split('_', 1) + inj = None + finally: + bgbin_lloid_map.setdefault(sim_tag_from_inj_file(inj), []).append(ce.path) + + # load dist stats + lloid_diststats = {} + boundary_seg = None + for ce in map(CacheEntry, open(options.dist_stats_cache)): + if 'DIST_STATS' in ce.description: + lloid_diststats[ce.description.split("_")[0]] = [ce.path] + if not boundary_seg: + boundary_seg = ce.segment + + # load svd dtdphi map + svd_dtdphi_map = {} + bank_cache = load_bank_cache(options) + instrument_set = bank_cache.keys() + for ifo, list_of_svd_caches in bank_cache.items(): + bin_offset = 0 + for j, svd_caches in enumerate(list_of_svd_caches): + for i, individual_svd_cache in enumerate(ce.path for ce in map(CacheEntry, open(svd_caches))): + svd_dtdphi_map["%04d" % (i+bin_offset)] = options.dtdphi_file[j] + + return bgbin_lloid_map, lloid_diststats, svd_dtdphi_map, instrument_set, boundary_seg + + def get_threshold_values(template_mchirp_dict, bgbin_indices, svd_bank_strings, options): """Calculate the appropriate ht-gate-threshold values according to the scale given """ @@ -1173,7 +1215,7 @@ def sim_tag_from_inj_file(injections): return injections.replace('.xml', '').replace('.gz', '').replace('-','_') -def get_bank_params(options, verbose = False): +def load_bank_cache(options): bank_cache = {} for bank_cache_str in options.bank_cache: for c in bank_cache_str.split(','): @@ -1181,6 +1223,12 @@ def get_bank_params(options, verbose = False): cache = c.replace(ifo+"=","") bank_cache.setdefault(ifo, []).append(cache) + return bank_cache + + +def get_bank_params(options, verbose = False): + bank_cache = load_bank_cache(options) + max_time = 0 template_mchirp_dict = {} for n, cache in enumerate(bank_cache.values()[0]): -- GitLab