
Commit 2c4572a

Add formatter (#89)
Add ruff formatter and format
1 parent 18dcdd7 commit 2c4572a

30 files changed: +524 -415 lines

dplutils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@
 try:
     from ._version import __version__
 except ImportError:
-    __version__ = ''
+    __version__ = ""

dplutils/cli.py

Lines changed: 6 additions & 5 deletions
@@ -1,5 +1,6 @@
 import json
 from argparse import ArgumentParser, Namespace
+
 from dplutils.pipeline import PipelineExecutor


@@ -16,9 +17,9 @@ def add_generic_args(argparser):
         argparser: The :class:`ArgumentParser<argparse.ArgumentParser>` instance
             to add args to.
     """
-    argparser.add_argument('-c', '--set-context', action='append', default=[], help='set context parameter')
-    argparser.add_argument('-s', '--set-config', action='append', default=[], help='set configuration parameter')
-    argparser.add_argument('-o', '--out-dir', default='.', help='write results to directory')
+    argparser.add_argument("-c", "--set-context", action="append", default=[], help="set context parameter")
+    argparser.add_argument("-s", "--set-config", action="append", default=[], help="set configuration parameter")
+    argparser.add_argument("-o", "--out-dir", default=".", help="write results to directory")


 def get_argparser(**kwargs):
@@ -40,7 +41,7 @@ def get_argparser(**kwargs):


 def parse_config_element(conf):
-    k,v = conf.split('=', 1)
+    k, v = conf.split("=", 1)
     try:
         v = json.loads(v)
     except json.decoder.JSONDecodeError:
@@ -67,7 +68,7 @@ def set_config_from_args(pipeline: PipelineExecutor, args: Namespace):
         pipeline.set_config(*parse_config_element(conf))


-def cli_run(pipeline: PipelineExecutor, args: Namespace|None = None, **argparse_kwargs):
+def cli_run(pipeline: PipelineExecutor, args: Namespace | None = None, **argparse_kwargs):
     """Run pipeline from cli args

     If ``args`` is None, this function runs the pipeline for the standard set of
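
For orientation, a minimal sketch of the option format these helpers accept. The return shape of parse_config_element is inferred from its use in set_config_from_args, and the coordinate names are illustrative:

    from dplutils.cli import parse_config_element

    # Values are parsed as JSON where possible, falling back to plain strings.
    parse_config_element("loader.batch_size=32")  # -> ("loader.batch_size", 32)
    parse_config_element("loader.label=nightly")  # -> ("loader.label", "nightly")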

dplutils/observer/__init__.py

Lines changed: 14 additions & 10 deletions
@@ -17,6 +17,7 @@ class Timer:
         with observer.timer('calltime'):
             <<do something>>
     """
+
     def __init__(self, observer, name, **kwargs):
         self.observer = observer
         self.name = name
@@ -36,7 +37,7 @@ def stop(self):

     def complete(self):
         if not self.started:
-            raise ValueError('Timer not started!')
+            raise ValueError("Timer not started!")
         self.stop()
         self.observer.observe(self.name, self.accum, **self.kwargs)

@@ -59,6 +60,7 @@ class Observer(ABC):
     While implementations are required to implement ``observe``, ``increment``
     and ``param``, there may be legitimit cases where the recording of
     """
+
     @abstractmethod
     def observe(self, name, value, **kwargs):
         """Observe a metric value
@@ -130,6 +132,7 @@ class NoOpObserver(Observer):
     This is akin to the ``NullHandler<logging.NullHandler>`` in the logging
     module and is the default upon initialization.
     """
+
     def observe(*args):
         """This method does nothing"""
         pass
@@ -150,6 +153,7 @@ class InMemoryObserver(Observer):
     each element in the list is a tuple (recorded_unix_time, value). Params are
     stored in a separate dict keyed by the parameter ``name``.
     """
+
     def __init__(self):
         self.metrics = defaultdict(list)
         self.params = {}
@@ -168,39 +172,39 @@ def param(self, name, value, **kwargs):
         self.params[name] = value

     def dump(self):
-        return {'params': self.params, 'metrics': self.metrics}
+        return {"params": self.params, "metrics": self.metrics}


 observer_map = {
-    'root': NoOpObserver(),
+    "root": NoOpObserver(),
 }


-def set_observer(obs, key='root'):
+def set_observer(obs, key="root"):
     """Set the global observer at ``key``"""
     observer_map[key] = obs


-def get_observer(key='root'):
+def get_observer(key="root"):
     """Get the global observer at ``key``"""
-    return observer_map.get(key, observer_map['root'])
+    return observer_map.get(key, observer_map["root"])


 def observe(*args, **kwargs):
     """call observe on the root observer"""
-    observer_map['root'].observe(*args, **kwargs)
+    observer_map["root"].observe(*args, **kwargs)


 def increment(*args, **kwargs):
     """call increment on the root observer"""
-    observer_map['root'].increment(*args, **kwargs)
+    observer_map["root"].increment(*args, **kwargs)


 def param(*args, **kwargs):
     """call param on the root observer"""
-    observer_map['root'].param(*args, **kwargs)
+    observer_map["root"].param(*args, **kwargs)


 def timer(*args, **kwargs):
     """call timer on the root observer"""
-    return observer_map['root'].timer(*args, **kwargs)
+    return observer_map["root"].timer(*args, **kwargs)
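
Taken together, the module-level helpers form a small global registry keyed by name, with a NoOpObserver at "root" by default. A minimal sketch of the intended flow, using only names visible in this diff (the metric names and workload are illustrative):

    from dplutils.observer import InMemoryObserver, set_observer, get_observer, increment, timer

    set_observer(InMemoryObserver())  # replace the default NoOpObserver at "root"

    increment("rows_processed")       # counter-style metric on the root observer
    with timer("calltime"):           # Timer records elapsed time on exit
        do_work()                     # hypothetical workload

    print(get_observer().dump())      # {"params": {...}, "metrics": {...}}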

dplutils/observer/aim.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ class AimObserver(Observer):
     Aim does not track the time with metric, only the step and this
     implementation uses the default auto-increment step counter.
     """
+
     def __init__(self, run=None, **aim_kwargs):
         if run is not None:
             self.run = run

dplutils/observer/mlflow.py

Lines changed: 3 additions & 2 deletions
@@ -24,12 +24,13 @@ class MlflowObserver(Observer):
     will be passed to its instantiation, using
     ``mlflow.MlflowClient.create_run``.
     """
+
     def __init__(self, run=None, experiment=None, tracking_uri=None, **mlflow_kwargs):
         if mlflow is None:
             raise ImportError("mlflow must be installed to create observer run!")

         tracking_uri = tracking_uri or mlflow.get_tracking_uri()
-        self.mlflow_client = mlflow.MlflowClient(tracking_uri = tracking_uri)
+        self.mlflow_client = mlflow.MlflowClient(tracking_uri=tracking_uri)

         if run is not None:
             self.run = run
@@ -41,7 +42,7 @@ def __init__(self, run=None, experiment=None, tracking_uri=None, **mlflow_kwargs
             expid = exp.experiment_id
         else:
             expid = self.mlflow_client.create_experiment(experiment)
-        self.run = self.mlflow_client.create_run(experiment_id = expid, **mlflow_kwargs)
+        self.run = self.mlflow_client.create_run(experiment_id=expid, **mlflow_kwargs)

         self.run_id = self.run.info.run_id
         self._countercache = {}
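
As a usage sketch (the experiment name and tracking URI below are illustrative; both are optional, falling back to an existing run or mlflow defaults):

    from dplutils.observer import set_observer
    from dplutils.observer.mlflow import MlflowObserver

    obs = MlflowObserver(experiment="my-pipeline", tracking_uri="http://localhost:5000")
    set_observer(obs)  # subsequent observe/increment/param calls log to mlflow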

dplutils/observer/ray.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
 import ray
 from ray.util.metrics import Counter, Gauge
+
 from dplutils.observer import Observer


@@ -15,6 +16,7 @@ class RayActorWrappedObserver(Observer):
         *args: Args to pass to ``cls`` instantiation
         **kwargs: Keyword args to pass to ``cls`` instantiation
     """
+
     def __init__(self, cls, *args, **kwargs):
         self.actor = ray.remote(cls).remote(*args, **kwargs)
         self._wait = False  # for testing purposes. If true wait instead of fire-and-forget
@@ -41,14 +43,15 @@ class RayMetricsObserver(Observer):
     objects, this can be used directly having copies per worker (so does not
     need to be wrapped in actor).
     """
+
     def __init__(self):
         self.mmap = {}

     def _get_or_set_as(self, name, kind):
         if name in self.mmap:
             metric = self.mmap[name]
             if not isinstance(metric, kind):
-                raise TypeError(f'setting metric requires {kind}, but {name} is {type(metric)}')
+                raise TypeError(f"setting metric requires {kind}, but {name} is {type(metric)}")
         else:
             metric = kind(name)
             self.mmap[name] = metric
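
A sketch of the wrapping pattern: RayActorWrappedObserver takes the observer class itself (plus its constructor args) and instantiates it inside a Ray actor, so metric calls from many workers funnel to one process. Wrapping InMemoryObserver here is illustrative:

    import ray
    from dplutils.observer import InMemoryObserver, set_observer
    from dplutils.observer.ray import RayActorWrappedObserver

    ray.init()
    set_observer(RayActorWrappedObserver(InMemoryObserver))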

dplutils/pipeline/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
-from .task import PipelineTask
-from .executor import PipelineExecutor, OutputBatch
+from .executor import OutputBatch, PipelineExecutor
 from .graph import PipelineGraph
+from .task import PipelineTask

-__all__ = ['PipelineTask', 'PipelineExecutor', 'OutputBatch', 'PipelineGraph']
+__all__ = ["PipelineTask", "PipelineExecutor", "OutputBatch", "PipelineGraph"]

dplutils/pipeline/executor.py

Lines changed: 25 additions & 20 deletions
@@ -1,12 +1,14 @@
 import uuid
-import pandas as pd
-import yaml
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from copy import deepcopy
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
-from collections.abc import Iterable
+
+import pandas as pd
+import yaml
+
 from dplutils.pipeline.graph import PipelineGraph
 from dplutils.pipeline.utils import dict_from_coord

@@ -27,6 +29,7 @@ class PipelineExecutor(ABC):
     to execute the pipeline and return and generator of dataframes of the final
     tasks in the graph.
     """
+
     def __init__(self, graph: PipelineGraph):
         if isinstance(graph, list):
             self.graph = PipelineGraph(deepcopy(graph))
@@ -36,21 +39,21 @@ def __init__(self, graph: PipelineGraph):
         self._run_id = None

     @classmethod
-    def from_graph(cls, graph: PipelineGraph) -> 'PipelineExecutor':
+    def from_graph(cls, graph: PipelineGraph) -> "PipelineExecutor":
         return cls(graph)

     @property
     def tasks_idx(self):  # for back compat
         return self.graph.task_map

-    def set_context(self, key, value) -> 'PipelineExecutor':
+    def set_context(self, key, value) -> "PipelineExecutor":
         self.ctx[key] = value
         return self

-    def set_config_from_dict(self, config) -> 'PipelineExecutor':
+    def set_config_from_dict(self, config) -> "PipelineExecutor":
         for task_name, confs in config.items():
             if task_name not in self.tasks_idx:
-                raise ValueError(f'no such task: {task_name}')
+                raise ValueError(f"no such task: {task_name}")
             for key, value in confs.items():
                 task = self.tasks_idx[task_name]
                 task_val = getattr(task, key)
@@ -61,11 +64,11 @@ def set_config_from_dict(self, config) -> 'PipelineExecutor':
         return self

     def set_config(
-            self,
-            coord: str|dict|None = None,
-            value: Any|None = None,
-            from_yaml: str|Path|None = None,
-    ) -> 'PipelineExecutor':
+        self,
+        coord: str | dict | None = None,
+        value: Any | None = None,
+        from_yaml: str | Path | None = None,
+    ) -> "PipelineExecutor":
         """Set task configuration options for this instance.

         This applies configurations to :class:`PipelineTask
@@ -90,8 +93,8 @@ def set_config(
         """
         if coord is None:
             if from_yaml is None:
-                raise ValueError('one of dict/string coordinate and value/file input is required')
-            with open(from_yaml, 'r') as f:
+                raise ValueError("one of dict/string coordinate and value/file input is required")
+            with open(from_yaml, "r") as f:
                 return self.set_config_from_dict(yaml.load(f, yaml.SafeLoader))
         if isinstance(coord, dict):
             return self.set_config_from_dict(coord)
@@ -106,7 +109,7 @@ def validate(self) -> None:
             except ValueError as e:
                 excs.append(str(e))
         if len(excs) > 0:
-            raise ValueError('Errors in validation:\n - ' + '\n - '.join(excs))
+            raise ValueError("Errors in validation:\n - " + "\n - ".join(excs))

     @property
     def run_id(self) -> str:
@@ -147,7 +150,9 @@ def run(self) -> Iterable[OutputBatch]:
         self._run_id = None  # force reallocation
         return self.execute()

-    def writeto(self, outdir: Path|str, partition_by_task: bool|None = None, task_partition_name: str = 'task') -> None:
+    def writeto(
+        self, outdir: Path | str, partition_by_task: bool | None = None, task_partition_name: str = "task"
+    ) -> None:
         """Run pipeline, writing results to parquet table.

         args:
@@ -166,10 +171,10 @@ def writeto(self, outdir: Path|str, partition_by_task: bool|None = None, task_pa
         Path(outdir).mkdir(parents=True, exist_ok=True)
         for c, batch in enumerate(self.run()):
             if partition_by_task:
-                part_name = batch.task or '__HIVE_DEFAULT_PARTITION__'
-                part_path = Path(outdir) / f'{task_partition_name}={part_name}'
+                part_name = batch.task or "__HIVE_DEFAULT_PARTITION__"
+                part_path = Path(outdir) / f"{task_partition_name}={part_name}"
                 part_path.mkdir(exist_ok=True)
-                outfile = part_path / f'{self.run_id}-{c}.parquet'
+                outfile = part_path / f"{self.run_id}-{c}.parquet"
             else:
-                outfile = Path(outdir) / f'{self.run_id}-{c}.parquet'
+                outfile = Path(outdir) / f"{self.run_id}-{c}.parquet"
             batch.data.to_parquet(outfile, index=False)
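
A sketch of the configuration surface shown in this diff. ConcreteExecutor stands in for any PipelineExecutor subclass implementing execute(), and the dotted "task.key" coordinate format is an assumption based on the dict_from_coord import:

    executor = (
        ConcreteExecutor(graph)                # hypothetical concrete subclass
        .set_context("run_label", "demo")
        .set_config("mytask.num_rows", 100)    # string coordinate + value
    )
    executor.set_config(from_yaml="conf.yaml")  # or a YAML mapping of {task: {key: value}}
    executor.writeto("results/", partition_by_task=True)  # parquet output, one partition per task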

dplutils/pipeline/graph.py

Lines changed: 14 additions & 10 deletions
@@ -1,11 +1,13 @@
 from enum import Enum
-from networkx import DiGraph, path_graph, all_simple_paths, is_directed_acyclic_graph, bfs_edges
+
+from networkx import DiGraph, all_simple_paths, bfs_edges, is_directed_acyclic_graph, path_graph
+
 from dplutils.pipeline.task import PipelineTask


 class TRM(Enum):
-    sink = 'sink'
-    source = 'source'
+    sink = "sink"
+    source = "source"


 class PipelineGraph(DiGraph):
@@ -18,37 +20,37 @@ class PipelineGraph(DiGraph):
     graph: This is either a list of :class:`PipelineTask` objects representing a
         simple-graph, or anything that is legal input to :class:`networkx.DiGraph`.
     """
+
     def __init__(self, graph=None):
         if isinstance(graph, list) and isinstance(graph[0], PipelineTask):
             graph = path_graph(graph, DiGraph)
         super().__init__(graph)
         if not is_directed_acyclic_graph(self):
-            raise ValueError('cycles detected in graph')
+            raise ValueError("cycles detected in graph")

     @property
     def task_map(self):
         return {i.name: i for i in self}

     @property
     def source_tasks(self):
-        return [n for n,d in self.in_degree() if d == 0]
+        return [n for n, d in self.in_degree() if d == 0]

     @property
     def sink_tasks(self):
-        return [n for n,d in self.out_degree() if d == 0]
+        return [n for n, d in self.out_degree() if d == 0]

     def to_list(self):
-        """Return list representation of task iff it is a simple-path graph
-        """
+        """Return list representation of task iff it is a simple-path graph"""
         if len(self.source_tasks) != 1 or len(self.sink_tasks) != 1:
-            raise ValueError('to_list requires a graph with only one start and end task')
+            raise ValueError("to_list requires a graph with only one start and end task")
         source = self.source_tasks[0]
         sink = self.sink_tasks[0]
         if source == sink:
             return [source]
         paths = list(all_simple_paths(self, source, sink))
         if len(paths) != 1:
-            raise ValueError('to_list requires a single path from start to end task, found {len(paths)}')
+            raise ValueError("to_list requires a single path from start to end task, found {len(paths)}")
         return paths[0]

     def with_terminals(self):
@@ -59,11 +61,13 @@ def with_terminals(self):

     def _walk(self, source, back=False, sort_key=None):
         graph = self.with_terminals()
+
         # doubly wrap the sort key function for conveneince (since bfs search
         # takes list, not sort key) and to inject the ignoring of terminal
         # nodes. This makes the walk sort key behave a bit more like `sorted()`
         def _sort_key(x):
             return 0 if isinstance(x, TRM) else sort_key(x)
+
         sorter = (lambda x: sorted(x, key=_sort_key)) if sort_key else None
         for _, node in bfs_edges(graph, source, reverse=back, sort_neighbors=sorter):
             if not isinstance(node, TRM):
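
A short sketch of the list-to-graph promotion in __init__: a list of PipelineTask objects becomes a simple path graph via networkx path_graph, and to_list round-trips it back. load_task and score_task are assumed to be PipelineTask instances defined elsewhere:

    from dplutils.pipeline import PipelineGraph

    graph = PipelineGraph([load_task, score_task])  # promoted to a path DiGraph
    graph.source_tasks  # -> [load_task]
    graph.sink_tasks    # -> [score_task]
    graph.to_list()     # -> [load_task, score_task]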
