
Commit 8256190

PR #736 Code style cleanup
using `ruff format`, `ruff check --fix --select F401,I001` and `isort`
1 parent: 40095c8
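
For reproducibility, the style pass described in the commit message boils down to the following commands (run from the repository root, which is an assumption; the exact target paths are not recorded in the message). `--select F401,I001` limits the autofix to unused imports (F401) and unsorted import blocks (I001), which matches the import reshuffling visible in the diffs below.

```bash
ruff format
ruff check --fix --select F401,I001
isort
```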

4 files changed: +124 -141 lines

openeo/extra/job_management/__init__.py

Lines changed: 30 additions & 38 deletions
@@ -33,9 +33,10 @@
 from urllib3.util import Retry
 
 from openeo import BatchJob, Connection
-from openeo.extra.job_management._thread_worker import ( _JobManagerWorkerThreadPool,
-                                                         _JobStartTask)
-
+from openeo.extra.job_management._thread_worker import (
+    _JobManagerWorkerThreadPool,
+    _JobStartTask,
+)
 from openeo.internal.processes.parse import (
     Parameter,
     Process,
@@ -527,8 +528,7 @@ def run_jobs(
             time.sleep(self.poll_sleep)
             stats["sleep"] += 1
 
-
-            # TODO; run post process after shutdown once more to ensure completion?
+        # TODO; run post process after shutdown once more to ensure completion?
         self._worker_pool.shutdown()
 
         return stats
@@ -571,7 +571,7 @@ def _job_update_loop(
                 total_added += 1
 
         self._process_threadworker_updates(self._worker_pool, job_db, stats)
-
+
         # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
         for job, row in jobs_done:
             self.on_job_done(job, row)
@@ -644,7 +644,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None
             )
             _log.info(f"Submitting task {task} to thread pool")
             self._worker_pool.submit_task(task)
-
+
             stats["job_queued_for_start"] += 1
             df.loc[i, "status"] = "queued_for_start"
         except OpenEoApiError as e:
@@ -660,59 +660,55 @@ def _process_threadworker_updates(
         self,
         worker_pool: _JobManagerWorkerThreadPool,
         job_db: JobDatabaseInterface,
-        stats: dict
+        stats: dict,
     ) -> None:
-        """Processes asynchronous job updates from worker threads and applies them to the job database and statistics.
-
+        """
+        Processes asynchronous job updates from worker threads and applies them to the job database and statistics.
+
         This wrapper function is responsible for:
         1. Collecting completed results from the worker thread pool
         2. applying database updates for each job result
         3. applying statistics updates
         4. Handles errors with comprehensive logging
-
+
         :param worker_pool:
             Thread pool instance managing the asynchronous job operations.
             Should provide a `process_futures()` method returning completed job results.
-
+
         :param job_db:
             Job database implementing the :py:class:`JobDatabaseInterface` interface.
             Used to persist job status updates and metadata.
             Must support the `_update_row(job_id: str, updates: dict)` method.
-
+
         :param stats:
             Dictionary tracking operational statistics that will be updated in-place.
             Expected to handle string keys with integer values.
            Statistics will be updated with counts from completed job results.
-
-        :return:
+
+        :return:
            None: All updates are applied in-place to the job_db and stats parameters.
-            .
         """
        results = worker_pool.process_futures()
        stats_updates = collections.defaultdict(int)
-
-        for result in results:
+
+        for result in results:
            try:
                # Handle job database updates
                if result.db_update:
                    _log.debug(f"Processing update for job {result.job_id}")
                    job_db._update_row(job_id=result.job_id, updates=result.db_update)
-
+
                # Aggregate statistics updates
                if result.stats_update:
                    for key, count in result.stats_update.items():
                        stats_updates[key] += int(count)
-
-
+
            except Exception as e:
-                _log.error(
-                    f"Failed aggregating the updates for update for job {result.job_id}: {str(e)}")
-
+                _log.error(f"Failed aggregating the updates for update for job {result.job_id}: {str(e)}")
+
        # Apply all stat updates
        for key, count in stats_updates.items():
            stats[key] = stats.get(key, 0) + count
-
-
 
    def on_job_done(self, job: BatchJob, row):
        """
@@ -877,6 +873,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] = None
 
         return jobs_done, jobs_error, jobs_cancel
 
+
 def _format_usage_stat(job_metadata: dict, field: str) -> str:
     value = deep_get(job_metadata, "usage", field, "value", default=0)
     unit = deep_get(job_metadata, "usage", field, "unit", default="")
@@ -986,29 +983,29 @@ def _update_row(self, job_id: str, updates: dict):
         # Create boolean mask for target row
         mask = self._df["id"] == job_id
         match_count = mask.sum()
-
+
         # Handle row identification issues
-        #TODO: make this more robust, e.g. falling back on the row index?
+        # TODO: make this more robust, e.g. falling back on the row index?
         if match_count == 0:
             _log.error(f"Job {job_id!r} not found in database")
             return
         if match_count > 1:
             _log.error(f"Duplicate job ID {job_id!r} found in database")
             return
 
-        # Get valid columns
-        valid_columns = set(self._df.columns)
+        # Get valid columns
+        valid_columns = set(self._df.columns)
         filtered_updates = {}
-
+
         # Validate update keys s
         for key, value in updates.items():
             if key in valid_columns:
                 filtered_updates[key] = value
             else:
                 _log.warning(f"Ignoring invalid column {key!r} in update for job {job_id}")
 
-        # Bulk update
-        if not filtered_updates:
+        # Bulk update
+        if not filtered_updates:
             return
         try:
             # Update all columns in a single operation
@@ -1017,9 +1014,6 @@ def _update_row(self, job_id: str, updates: dict):
         except Exception as e:
             _log.error(f"Failed to persist row update for job {job_id}: {e}")
 
-
-
-
 
 class CsvJobDatabase(FullDataFrameJobDatabase):
     """
@@ -1075,8 +1069,6 @@ def persist(self, df: pd.DataFrame):
         self.path.parent.mkdir(parents=True, exist_ok=True)
         self.df.to_csv(self.path, index=False)
 
-
-
 
 class ParquetJobDatabase(FullDataFrameJobDatabase):
     """

openeo/extra/job_management/_thread_worker.py

Lines changed: 27 additions & 25 deletions
@@ -1,13 +1,14 @@
 import concurrent.futures
 import logging
-from dataclasses import dataclass, field
-from typing import Optional, Any, List, Dict, Tuple
-import openeo
 from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple
 
+import openeo
 
 _log = logging.getLogger(__name__)
 
+
 @dataclass
 class _TaskResult:
     """
@@ -25,10 +26,12 @@ class _TaskResult:
         Optional dictionary capturing statistical counters or metrics,
         e.g., number of successful starts or errors. Defaults to an empty dict.
     """
+
     job_id: str  # Mandatory
     db_update: Dict[str, Any] = field(default_factory=dict)  # Optional
     stats_update: Dict[str, int] = field(default_factory=dict)  # Optional
 
+
 class Task(ABC):
     """
     Abstract base class for asynchronous tasks.
@@ -38,12 +41,13 @@ class Task(ABC):
 
     Implementations must override the `execute` method to define the task logic.
     """
-
+
     @abstractmethod
     def execute(self) -> _TaskResult:
         """Execute the task and return a raw result"""
         pass
-
+
+
 @dataclass
 class _JobStartTask(Task):
     """
@@ -75,10 +79,10 @@ class _JobStartTask(Task):
     :raises ValueError:
         If any of the input parameters are invalid (e.g., empty strings).
     """
+
     job_id: str
     root_url: str
     bearer_token: Optional[str]
-
 
     def __post_init__(self) -> None:
         # Validation remains unchanged
@@ -115,10 +119,10 @@ def execute(self) -> _TaskResult:
         except Exception as e:
             _log.error(f"Failed to start job {self.job_id}: {e}")
             return _TaskResult(
-                job_id=self.job_id,
-                db_update={"status": "start_failed"},
-                stats_update={"start_job error": 1})
-
+                job_id=self.job_id, db_update={"status": "start_failed"}, stats_update={"start_job error": 1}
+            )
+
+
 class _JobManagerWorkerThreadPool:
     """
     Thread pool-based worker that manages the execution of asynchronous tasks.
@@ -130,6 +134,7 @@ class _JobManagerWorkerThreadPool:
         Maximum number of concurrent threads to use for execution.
         Defaults to 2.
     """
+
     def __init__(self, max_workers: int = 2):
         self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
         self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = []
@@ -147,7 +152,7 @@ def submit_task(self, task: Task) -> None:
         future = self._executor.submit(task.execute)
         self._future_task_pairs.append((future, task))  # Track pairs
 
-    def process_futures(self) -> List[ _TaskResult]:
+    def process_futures(self) -> List[_TaskResult]:
         """
         Process and retrieve results from completed tasks.
 
@@ -157,34 +162,31 @@ def process_futures(self) -> List[ _TaskResult]:
         :returns:
             A list of `_TaskResult` objects from completed tasks.
         """
-        results = []
-        to_keep = []
+        results = []
+        to_keep = []
 
         # Use timeout=0 to avoid blocking and check for completed futures
         done, _ = concurrent.futures.wait(
-            [f for f, _ in self._future_task_pairs], timeout=0,
-            return_when=concurrent.futures.FIRST_COMPLETED
+            [f for f, _ in self._future_task_pairs], timeout=0, return_when=concurrent.futures.FIRST_COMPLETED
         )
 
         # Process completed futures and their tasks
         for future, task in self._future_task_pairs:
             if future in done:
                 try:
                     result = future.result()
-
-                except Exception as e:
 
+                except Exception as e:
                     _log.exception(f"Error processing task: {e}")
-                    result = _TaskResult(
-                        job_id=task.job_id,
-                        db_update={"status": "start_failed"},
-                        stats_update={"start_job error": 1})
-
+                    result = _TaskResult(
+                        job_id=task.job_id, db_update={"status": "start_failed"}, stats_update={"start_job error": 1}
+                    )
+
                 results.append(result)
-            else:
-                to_keep.append((future, task))
+            else:
+                to_keep.append((future, task))
 
-        self._future_task_pairs = to_keep
+        self._future_task_pairs = to_keep
         return results
 
     def shutdown(self) -> None:
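
The reworked `process_futures` above hinges on a standard non-blocking polling idiom: `concurrent.futures.wait(..., timeout=0)` returns immediately, partitioning the futures into done and not-done, so the job manager can harvest finished tasks without stalling its main loop (with `timeout=0` the call never waits, which makes the `return_when` argument effectively moot). A stripped-down sketch of the idiom outside the job manager (the `fast`/`slow` tasks are invented for illustration):

```python
import concurrent.futures
import time


def fast() -> str:
    return "fast done"


def slow() -> str:
    time.sleep(5)
    return "slow done"


with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(fast), executor.submit(slow)]
    time.sleep(0.5)  # give the fast task a moment to finish

    # timeout=0: report whatever has completed so far, never block.
    done, pending = concurrent.futures.wait(futures, timeout=0)

    print([f.result() for f in done])  # ['fast done']
    print(len(pending))                # 1, keep tracking it on the next poll
```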
