 import datetime
 import json
 import logging
+import queue
 import re
+import threading
 import time
 import warnings
 from pathlib import Path
@@ -31,13 +33,15 @@
 import shapely.wkt
 from requests.adapters import HTTPAdapter, Retry

+import openeo
 from openeo import BatchJob, Connection
 from openeo.internal.processes.parse import (
     Parameter,
     Process,
     parse_remote_process_definition,
 )
 from openeo.rest import OpenEoApiError
+from openeo.rest.auth.auth import BearerAuth
 from openeo.util import LazyLoadCache, deep_get, repr_truncate, rfc3339

 _log = logging.getLogger(__name__)
@@ -223,6 +227,9 @@ def __init__(
         )
         self._thread = None

+        self._work_queue = queue.Queue()
+        self._result_queue = queue.Queue()
+
     def add_backend(
         self,
         name: str,
@@ -493,6 +500,10 @@ def run_jobs(
         # TODO: support user-provided `stats`
         stats = collections.defaultdict(int)

+        # TODO: multiple workers instead of a single one? Work with thread pool?
+        worker_thread = _JobManagerWorkerThread(work_queue=self._work_queue, result_queue=self._result_queue)
+        worker_thread.start()
+
         while sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0:
             self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
             stats["run_jobs loop"] += 1
@@ -502,6 +513,9 @@ def run_jobs(
             time.sleep(self.poll_sleep)
             stats["sleep"] += 1

+        worker_thread.stop_event.set()
+        worker_thread.join()
+
         return stats

     def _job_update_loop(
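The two run_jobs hunks above wire up the worker's whole lifecycle: the thread is started before the polling loop, and once no unfinished jobs remain it is told to stop via stop_event and joined. A minimal standalone sketch of that queue-plus-event pattern (the work item values here are placeholders, not taken from this PR):

import queue
import threading
import time

work_queue: queue.Queue = queue.Queue()
stop_event = threading.Event()

def worker():
    # Same loop shape as _JobManagerWorkerThread.run: poll with a timeout so
    # the stop_event is re-checked even when the queue stays empty.
    while not stop_event.is_set():
        try:
            item = work_queue.get(timeout=5)
        except queue.Empty:
            continue
        print(f"processing {item!r}")

thread = threading.Thread(target=worker)
thread.start()
work_queue.put(("start_job", ("https://openeo.example", None, "job-123")))
time.sleep(1)  # give the worker a chance to pick up the item
stop_event.set()  # run_jobs does the same once its polling loop exits
thread.join()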
@@ -542,6 +556,7 @@ def _job_update_loop(
                         total_added += 1

         # Act on jobs
+        # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
         for job, row in jobs_done:
             self.on_job_done(job, row)

@@ -551,6 +566,11 @@ def _job_update_loop(
         for job, row in jobs_cancel:
             self.on_job_cancel(job, row)

+        # Check worker thread results
+        while not self._result_queue.empty():
+            # TODO
+            ...
+

     def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
         """Helper method for launching jobs
@@ -584,7 +604,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
             connection = self._get_connection(backend_name, resilient=True)

             stats["start_job call"] += 1
-            job = start_job(
+            job: BatchJob = start_job(
                 row=row,
                 connection_provider=self._get_connection,
                 connection=connection,
@@ -605,14 +625,24 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
                     if status == "created":
                         # start job if not yet done by callback
                         try:
-                            job.start()
-                            stats["job start"] += 1
-                            df.loc[i, "status"] = job.status()
-                            stats["job get status"] += 1
+                            job_con = job.connection
+                            self._work_queue.put(
+                                (
+                                    _JobManagerWorkerThread.WORK_TYPE_START_JOB,
+                                    (
+                                        job_con.root_url,
+                                        job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
+                                        job.job_id,
+                                    ),
+                                )
+                            )
+                            job_status = "queued_for_start"
+                            stats[f"job {job_status}"] += 1
+                            df.loc[i, "status"] = job_status
                         except OpenEoApiError as e:
                             _log.error(e)
-                            df.loc[i, "status"] = "start_failed"
-                            stats["job start error"] += 1
+                            df.loc[i, "status"] = "queued_for_start_failed"
+                            stats["job queued_for_start error"] += 1
             else:
                 # TODO: what is this "skipping" about actually?
                 df.loc[i, "status"] = "skipped"
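Rather than calling job.start() inline, _launch_job now serializes just enough state for another thread to rebuild the connection: the backend root URL, the bearer token when one is available, and the job id. A standalone sketch of that round trip, using the same openeo calls as the worker's _start_job (the URL and job id are placeholders):

import openeo

# Placeholder values; in the PR these come from job.connection and job.job_id.
root_url = "https://openeo.example"
bearer = None  # only set when the original connection used BearerAuth
job_id = "job-123"

# Rebuild the connection from scratch, as _start_job does in the worker thread.
connection = openeo.connect(url=root_url)
if bearer:
    connection.authenticate_bearer_token(bearer_token=bearer)
job = connection.job(job_id)  # re-attach to the existing batch job
job.start()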
@@ -673,20 +703,20 @@ def _cancel_prolonged_job(self, job: BatchJob, row):
         try:
             # Ensure running start time is valid
             job_running_start_time = rfc3339.parse_datetime(row.get("running_start_time"), with_timezone=True)
-
+
             # Parse the current time into a datetime object with timezone info
             current_time = rfc3339.parse_datetime(rfc3339.utcnow(), with_timezone=True)

             # Calculate the elapsed time between job start and now
             elapsed = current_time - job_running_start_time

             if elapsed > self._cancel_running_job_after:
-
+
                 _log.info(
                     f"Cancelling long-running job {job.job_id} (after {elapsed}, running since {job_running_start_time})"
                 )
                 job.stop()
-
+
         except Exception as e:
             _log.error(f"Unexpected error while handling job {job.job_id}: {e}")

@@ -783,6 +813,42 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
         return jobs_done, jobs_error, jobs_cancel


+class _JobManagerWorkerThread(threading.Thread):
+    WORK_TYPE_START_JOB = "start_job"
+
+    def __init__(self, work_queue: queue.Queue, result_queue: queue.Queue):
+        super().__init__()
+        self.work_queue = work_queue
+        self.result_queue = result_queue
+        self.stop_event = threading.Event()
+        # TODO: add customization options for timeout/sleep?
+
+    def run(self):
+        while not self.stop_event.is_set():
+            try:
+                work_type, work_args = self.work_queue.get(timeout=5)
+                if work_type == self.WORK_TYPE_START_JOB:
+                    self._start_job(work_args)
+                else:
+                    raise ValueError(f"Unknown work item: {work_type!r}")
+            except queue.Empty:
+                time.sleep(10)
+
+    def _start_job(self, work_args: tuple):
+        root_url, bearer, job_id = work_args
+        try:
+            connection = openeo.connect(url=root_url)
+            if bearer:
+                connection.authenticate_bearer_token(bearer_token=bearer)
+            job = connection.job(job_id)
+            job.start()
+            status = job.status()
+        except Exception as e:
+            self.result_queue.put((self.WORK_TYPE_START_JOB, (job_id, "failed", repr(e))))
+        else:
+            self.result_queue.put((self.WORK_TYPE_START_JOB, (job_id, "started", status)))
+
+
 def _format_usage_stat(job_metadata: dict, field: str) -> str:
     value = deep_get(job_metadata, "usage", field, "value", default=0)
     unit = deep_get(job_metadata, "usage", field, "unit", default="")
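For reference, a minimal usage sketch of the worker class defined above. The backend URL and job id are placeholders, so this run would deposit a "failed" result on result_queue rather than start a real job:

import queue

work_queue: queue.Queue = queue.Queue()
result_queue: queue.Queue = queue.Queue()

worker = _JobManagerWorkerThread(work_queue=work_queue, result_queue=result_queue)
worker.start()

# One start request: (work type, (root URL, bearer token or None, job id)).
work_queue.put((_JobManagerWorkerThread.WORK_TYPE_START_JOB, ("https://openeo.example", None, "job-123")))

# Results arrive as (work_type, (job_id, "started" | "failed", detail)) tuples.
print(result_queue.get())  # blocks until the worker reports back
worker.stop_event.set()
worker.join()  # may take a few seconds due to the 5s poll timeout in run()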