Skip to content

Commit 3335070

Browse files
committed
Issue #719 job manager WIP: start jobs in worker thread
1 parent 57c67fe commit 3335070

File tree

1 file changed

+76
-10
lines changed

openeo/extra/job_management/__init__.py

+76-10
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import datetime
66
import json
77
import logging
8+
import queue
89
import re
10+
import threading
911
import time
1012
import warnings
1113
from pathlib import Path
@@ -31,13 +33,15 @@
3133
import shapely.wkt
3234
from requests.adapters import HTTPAdapter, Retry
3335

36+
import openeo
3437
from openeo import BatchJob, Connection
3538
from openeo.internal.processes.parse import (
3639
Parameter,
3740
Process,
3841
parse_remote_process_definition,
3942
)
4043
from openeo.rest import OpenEoApiError
44+
from openeo.rest.auth.auth import BearerAuth
4145
from openeo.util import LazyLoadCache, deep_get, repr_truncate, rfc3339
4246

4347
_log = logging.getLogger(__name__)
@@ -223,6 +227,9 @@ def __init__(
223227
)
224228
self._thread = None
225229

230+
self._work_queue = queue.Queue()
231+
self._result_queue = queue.Queue()
232+
226233
def add_backend(
227234
self,
228235
name: str,
@@ -493,6 +500,10 @@ def run_jobs(
493500
# TODO: support user-provided `stats`
494501
stats = collections.defaultdict(int)
495502

503+
# TODO: multiple workers instead of a single one? Work with thread pool?
504+
worker_thread = _JobManagerWorkerThread(work_queue=self._work_queue, result_queue=self._result_queue)
505+
worker_thread.start()
506+
496507
while sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0:
497508
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
498509
stats["run_jobs loop"] += 1
@@ -502,6 +513,9 @@ def run_jobs(
502513
time.sleep(self.poll_sleep)
503514
stats["sleep"] += 1
504515

516+
worker_thread.stop_event.set()
517+
worker_thread.join()
518+
505519
return stats
506520

507521
def _job_update_loop(
@@ -542,6 +556,7 @@ def _job_update_loop(
542556
total_added += 1
543557

544558
# Act on jobs
559+
# TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
545560
for job, row in jobs_done:
546561
self.on_job_done(job, row)
547562

@@ -551,6 +566,11 @@ def _job_update_loop(
551566
for job, row in jobs_cancel:
552567
self.on_job_cancel(job, row)
553568

569+
# Check worker thread results
570+
while not self._result_queue.empty():
571+
# TODO
572+
...
573+
554574

555575
def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
556576
"""Helper method for launching jobs
@@ -584,7 +604,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
584604
connection = self._get_connection(backend_name, resilient=True)
585605

586606
stats["start_job call"] += 1
587-
job = start_job(
607+
job: BatchJob = start_job(
588608
row=row,
589609
connection_provider=self._get_connection,
590610
connection=connection,
@@ -605,14 +625,24 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
605625
if status == "created":
606626
# start job if not yet done by callback
607627
try:
608-
job.start()
609-
stats["job start"] += 1
610-
df.loc[i, "status"] = job.status()
611-
stats["job get status"] += 1
628+
job_con = job.connection
629+
self._work_queue.put(
630+
(
631+
_JobManagerWorkerThread.WORK_TYPE_START_JOB,
632+
(
633+
job_con.root_url,
634+
job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
635+
job.job_id,
636+
),
637+
)
638+
)
639+
job_status = "queued_for_start"
640+
stats[f"job {job_status}"] += 1
641+
df.loc[i, "status"] = job_status
612642
except OpenEoApiError as e:
613643
_log.error(e)
614-
df.loc[i, "status"] = "start_failed"
615-
stats["job start error"] += 1
644+
df.loc[i, "status"] = "queued_for_start_failed"
645+
stats["job queued_for_start error"] += 1
616646
else:
617647
# TODO: what is this "skipping" about actually?
618648
df.loc[i, "status"] = "skipped"
@@ -673,20 +703,20 @@ def _cancel_prolonged_job(self, job: BatchJob, row):
673703
try:
674704
# Ensure running start time is valid
675705
job_running_start_time = rfc3339.parse_datetime(row.get("running_start_time"), with_timezone=True)
676-
706+
677707
# Parse the current time into a datetime object with timezone info
678708
current_time = rfc3339.parse_datetime(rfc3339.utcnow(), with_timezone=True)
679709

680710
# Calculate the elapsed time between job start and now
681711
elapsed = current_time - job_running_start_time
682712

683713
if elapsed > self._cancel_running_job_after:
684-
714+
685715
_log.info(
686716
f"Cancelling long-running job {job.job_id} (after {elapsed}, running since {job_running_start_time})"
687717
)
688718
job.stop()
689-
719+
690720
except Exception as e:
691721
_log.error(f"Unexpected error while handling job {job.job_id}: {e}")
692722

@@ -783,6 +813,42 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
783813
return jobs_done, jobs_error, jobs_cancel
784814

785815

816+
class _JobManagerWorkerThread(threading.Thread):
817+
WORK_TYPE_START_JOB = "start_job"
818+
819+
def __init__(self, work_queue: queue.Queue, result_queue: queue.Queue):
820+
super().__init__()
821+
self.work_queue = work_queue
822+
self.result_queue = result_queue
823+
self.stop_event = threading.Event()
824+
# TODO: add customization options for timeout/sleep?
825+
826+
def run(self):
827+
while not self.stop_event.is_set():
828+
try:
829+
work_type, work_args = self.work_queue.get(timeout=5)
830+
if work_type == self.WORK_TYPE_START_JOB:
831+
self._start_job(work_args)
832+
else:
833+
raise ValueError(f"Unknown work item: {work_type!r}")
834+
except queue.Empty:
835+
time.sleep(10)
836+
837+
def _start_job(self, work_args: tuple):
838+
root_url, bearer, job_id = work_args
839+
try:
840+
connection = openeo.connect(url=root_url)
841+
if bearer:
842+
connection.authenticate_bearer_token(bearer_token=bearer)
843+
job = connection.job(job_id)
844+
job.start()
845+
status = job.status()
846+
except Exception as e:
847+
self.result_queue.put((self.WORK_TYPE_START_JOB, (job_id, "failed", repr(e))))
848+
else:
849+
self.result_queue.put((self.WORK_TYPE_START_JOB, (job_id, "started", status)))
850+
851+
786852
def _format_usage_stat(job_metadata: dict, field: str) -> str:
787853
value = deep_get(job_metadata, "usage", field, "value", default=0)
788854
unit = deep_get(job_metadata, "usage", field, "unit", default="")

0 commit comments

Comments
 (0)