Skip to content
Snippets Groups Projects
Commit 6a760352 authored by Patrick Godwin's avatar Patrick Godwin
Browse files

add missing type annotations throughout

parent ab5ad1ae
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@ from collections import defaultdict
import os
from pathlib import Path
import re
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple
from htcondor import dags
......@@ -23,14 +23,14 @@ class DAG(dags.DAG):
"""
def __init__(self, config=None, *args, **kwargs):
def __init__(self, config=None, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.config = config
self._node_layers = {}
self._layers = {}
self._provides = {}
self._node_layers: Dict[str, dags.NodeLayer] = {}
self._layers: Dict[str, Layer] = {}
self._provides: Dict[str, Tuple[str, int]] = {}
def attach(self, layer: Layer):
def attach(self, layer: Layer) -> None:
"""Attach a layer of related job nodes to this DAG.
Parameters
......@@ -41,7 +41,7 @@ class DAG(dags.DAG):
"""
key = layer.name
if key in self._layers:
return KeyError(f"{key} layer already added to DAG")
raise KeyError(f"{key} layer already added to DAG")
self._layers[layer.name] = layer
# determine parent-child relationships and connect accordingly
......@@ -77,14 +77,14 @@ class DAG(dags.DAG):
for output in node.provides:
self._provides[output] = (key, idx)
def create_log_dir(self, log_dir: Path = Path("logs")):
def create_log_dir(self, log_dir: Path = Path("logs")) -> None:
"""Create the log directory where job logs are stored.
If not specified, creates a log directory in ./logs
"""
os.makedirs(log_dir, exist_ok=True)
def write_dag(self, filename: str, path: Path = Path.cwd(), **kwargs):
def write_dag(self, filename: str, path: Path = Path.cwd(), **kwargs) -> None:
"""Write out the given DAG to the given directory.
This includes the DAG description file itself, as well as any
......@@ -97,7 +97,7 @@ class DAG(dags.DAG):
filename: str,
path: Path = Path.cwd(),
formatter: Optional[dags.NodeNameFormatter] = None,
):
) -> None:
if not formatter:
formatter = HexFormatter()
......@@ -116,7 +116,7 @@ class DAG(dags.DAG):
print(f"# Job {node_name}", file=f)
print(executable + " " + args.format(**node_vars) + "\n", file=f)
def _get_edge_type(self, parent_name, child_name, edges):
def _get_edge_type(self, parent_name, child_name, edges) -> dags.BaseEdge:
parent = self._layers[parent_name]
child = self._layers[child_name]
edges = sorted(list(edges))
......@@ -137,7 +137,7 @@ class DAG(dags.DAG):
class HexFormatter(dags.SimpleFormatter):
"""A hex-based node formatter that produces names like LayerName_000C."""
def __init__(self, offset: int = 0):
def __init__(self, offset: int = 0) -> None:
self.separator = "."
self.index_format = "{:05X}"
self.offset = offset
......@@ -151,7 +151,7 @@ class HexFormatter(dags.SimpleFormatter):
class EdgeConnector(dags.BaseEdge):
"""This edge connects individual nodes in layers given an explicit mapping."""
def __init__(self, indices):
def __init__(self, indices) -> None:
self.indices = indices
def get_edges(self, parent, child, join_factory):
......@@ -164,7 +164,7 @@ def write_dag(
dag_dir: Path = Path.cwd(),
formatter: Optional[dags.NodeNameFormatter] = None,
**kwargs,
):
) -> None:
"""Write out the given DAG to the given directory.
This includes the DAG description file itself, as well as any associated
......
......@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
import itertools
import os
import shutil
from typing import Union
from typing import Any, Dict, List, Union
import htcondor
......@@ -67,16 +67,16 @@ class Layer:
outputs: dict = field(default_factory=dict)
nodes: list = field(default_factory=list)
def __post_init__(self):
def __post_init__(self) -> None:
if not self.name:
self.name = os.path.basename(self.executable)
def config(self):
def config(self) -> Dict[str, Any]:
# check that nodes are valid
self.validate()
# add base submit opts + requirements
submit_options = {
submit_options: Dict[str, Any] = {
"universe": self.universe,
"executable": shutil.which(self.executable),
"arguments": self._arguments(),
......@@ -135,7 +135,7 @@ class Layer:
"retries": self.retries,
}
def append(self, node: Node):
def append(self, node: Node) -> None:
"""Append a node to this layer."""
assert isinstance(node.inputs, list)
assert isinstance(node.outputs, list)
......@@ -145,19 +145,19 @@ class Layer:
self.outputs.setdefault(output.name, []).append(output.argument)
self.nodes.append(node)
def extend(self, nodes: Iterable[Node]):
def extend(self, nodes: Iterable[Node]) -> None:
"""Append multiple nodes to this layer."""
for node in nodes:
self.append(node)
def __iadd__(self, nodes):
def __iadd__(self, nodes) -> Layer:
if isinstance(nodes, Iterable):
self.extend(nodes)
else:
self.append(nodes)
return self
def validate(self):
def validate(self) -> None:
"""Ensure all nodes in this layer are consistent with each other."""
assert self.nodes, "at least one node must be connected to this layer"
......@@ -175,11 +175,11 @@ class Layer:
assert outputs == [arg.name for arg in node.outputs]
@property
def has_dependencies(self):
def has_dependencies(self) -> bool:
"""Check if any of the nodes in this layer have dependencies."""
return any([node.requires for node in self.nodes])
def _arguments(self):
def _arguments(self) -> str:
args = [f"$({arg.condor_name})" for arg in self.nodes[0].arguments]
io_args = []
io_opts = []
......@@ -191,20 +191,20 @@ class Layer:
io_opts.append(f"$({arg.condor_name})")
return " ".join(itertools.chain(args, io_opts, io_args))
def _inputs(self):
def _inputs(self) -> str:
return ",".join([f"$(input_{arg.condor_name})" for arg in self.nodes[0].inputs])
def _outputs(self):
def _outputs(self) -> str:
return ",".join(
[f"$(output_{arg.condor_name})" for arg in self.nodes[0].outputs]
)
def _output_remaps(self):
def _output_remaps(self) -> str:
return ";".join(
[f"$(output_{arg.condor_name}_remap)" for arg in self.nodes[0].outputs]
)
def _vars(self):
def _vars(self) -> List[Dict[str, str]]:
allvars = []
for i, node in enumerate(self.nodes):
nodevars = {"nodename": f"{self.name}_{i:05X}"}
......@@ -297,7 +297,7 @@ class Node:
inputs: Union[Argument, Option, list] = field(default_factory=list)
outputs: Union[Argument, Option, list] = field(default_factory=list)
def __post_init__(self):
def __post_init__(self) -> None:
if isinstance(self.arguments, Argument) or isinstance(self.arguments, Option):
self.arguments = [self.arguments]
if isinstance(self.inputs, Argument) or isinstance(self.inputs, Option):
......@@ -306,7 +306,7 @@ class Node:
self.outputs = [self.outputs]
@property
def requires(self):
def requires(self) -> List[str]:
"""
Returns
-------
......@@ -320,7 +320,7 @@ class Node:
)
@property
def provides(self):
def provides(self) -> List[str]:
"""
Returns
-------
......
......@@ -64,7 +64,7 @@ class Argument:
suppress_with_remap: bool = False
_args: List[str] = field(init=False)
def __post_init__(self):
def __post_init__(self) -> None:
# check against list of protected condor names/characters,
# rename condor variables name to avoid issues
self.condor_name = self.name.replace("-", "_")
......@@ -174,7 +174,7 @@ class Option:
suppress_with_remap: Optional[bool] = False
_args: List[str] = field(init=False)
def __post_init__(self):
def __post_init__(self) -> None:
# check against list of protected condor names/characters,
# rename condor variables name to avoid issues
self.condor_name = self.name.replace("-", "_")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment