Skip to content
Snippets Groups Projects
Commit 636135b6 authored by Patrick Godwin's avatar Patrick Godwin Committed by ChiWai Chan
Browse files

datafind.py: allow DataCache.generate() to use custom file extensions

parent 27022c0b
No related branches found
No related tags found
No related merge requests found
......@@ -716,10 +716,11 @@ def cluster_layer(config, dag, trigger_cache):
root="rank",
)
trigger_db_cache = DataCache.generate(
DataType.TRIGGER_DATABASE,
DataType.TRIGGERS,
config.all_ifos,
config.time_boundaries,
subtype=inj_types,
extension="sqlite",
root="rank"
)
......
......@@ -42,10 +42,12 @@ class DataFileMixin:
description.append(subtype.upper())
return "_".join(description)
def filename(self, ifos, span=None, svd_bin=None, subtype=None):
def filename(self, ifos, span=None, svd_bin=None, subtype=None, extension=None):
if not span:
span = segment(0, 0)
return T050017_filename(ifos, self.description(svd_bin, subtype), span, self.extension)
if not extension:
extension = self.extension
return T050017_filename(ifos, self.description(svd_bin, subtype), span, extension)
def file_pattern(self, svd_bin=None, subtype=None):
return f"*-{self.description(svd_bin, subtype)}-*-*{self.extension}"
......@@ -64,7 +66,6 @@ class DataType(DataFileMixin, Enum):
MEDIAN_PSD = (1, "xml.gz")
SMOOTH_PSD = (2, "xml.gz")
TRIGGERS = (10, "xml.gz")
TRIGGER_DATABASE = (11, "sqlite")
DIST_STATS = (20, "xml.gz")
PRIOR_DIST_STATS = (21, "xml.gz")
MARG_DIST_STATS = (22, "xml.gz")
......@@ -167,7 +168,17 @@ class DataCache:
return DataCache.from_files(self.name, cache_paths)
@classmethod
def generate(cls, name, ifos, time_bins=None, svd_bins=None, subtype=None, root=None, create_dirs=True):
def generate(
cls,
name,
ifos,
time_bins=None,
svd_bins=None,
subtype=None,
extension=None,
root=None,
create_dirs=True
):
# format args
if isinstance(ifos, str) or isinstance(ifos, frozenset):
ifos = [ifos]
......@@ -194,11 +205,13 @@ class DataCache:
if svd_bins:
for svd_bin in svd_bins:
for stype in subtype:
filename = name.filename(ifo, span, svd_bin=svd_bin, subtype=stype)
filename = name.filename(
ifo, span, svd_bin=svd_bin, subtype=stype, extension=extension
)
cache.append(os.path.join(path, filename))
else:
for stype in subtype:
filename = name.filename(ifo, span, subtype=stype)
filename = name.filename(ifo, span, subtype=stype, extension=extension)
cache.append(os.path.join(path, filename))
return cls(name, [CacheEntry.from_T050017(entry) for entry in cache])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment