22
33import asyncio
44import logging
5+ import sys
6+ from concurrent .futures import ThreadPoolExecutor
7+ from functools import wraps
58from pathlib import PurePath
6- from typing import Any , Mapping , Optional , Sequence , Union
9+ from typing import (
10+ Any ,
11+ Callable ,
12+ Coroutine ,
13+ Mapping ,
14+ Optional ,
15+ Sequence ,
16+ TypeVar ,
17+ Union ,
18+ )
719
820from attrs import asdict
921
1022from .config import Config
11- from .db import DB , TaskModel , TaskStatus
23+ from .db import DB , TaskStatus
1224from .scheduler import Scheduler
1325from .variables import CONFIG_FILE
1426
27+ if sys .version_info < (3 , 10 ):
28+ from typing_extensions import ParamSpec
29+ else :
30+ from typing import ParamSpec
31+
32+ ReturnT_co = TypeVar ("ReturnT_co" , covariant = True )
33+ ParamT = ParamSpec ("ParamT" )
34+
35+
36+ def to_sync (
37+ func : Callable [ParamT , Coroutine [Any , Any , ReturnT_co ]],
38+ ) -> Callable [ParamT , ReturnT_co ]:
39+ """
40+ Wraps async function and run it sync in thread.
41+ """
42+
43+ @wraps (func )
44+ def outer (* args : ParamT .args , ** kwargs : ParamT .kwargs ):
45+ """
46+ Execute the async method synchronously in sync and async runtime.
47+ """
48+ coro = func (* args , ** kwargs )
49+ try :
50+ asyncio .get_running_loop () # Triggers RuntimeError if no running event loop
51+
52+ # Create a separate thread so we can block before returning
53+ with ThreadPoolExecutor (1 ) as pool :
54+ return pool .submit (lambda : asyncio .run (coro )).result ()
55+ except RuntimeError :
56+ return asyncio .run (coro )
57+
58+ return outer
59+
1560
1661class Yascheduler :
1762 """Yascheduler client"""
@@ -31,30 +76,36 @@ def __init__(
3176 self .config = Config .from_config_parser (config_path )
3277 self ._logger = logger
3378
34- def queue_submit_task (
79+ async def queue_submit_task_async (
3580 self ,
3681 label : str ,
3782 metadata : Mapping [str , Any ],
3883 engine_name : str ,
3984 webhook_onsubmit = False ,
4085 ) -> int :
4186 """Submit new task"""
42-
43- async def async_fn () -> TaskModel :
44- yac = await Scheduler .create (config = self .config , log = self ._logger )
45- task = await yac .create_new_task (
46- label = label ,
47- metadata = metadata ,
48- engine_name = engine_name ,
49- webhook_onsubmit = webhook_onsubmit ,
50- )
51- await yac .stop ()
52- return task
53-
54- task = asyncio .run (async_fn ())
87+ yac = await Scheduler .create (config = self .config , log = self ._logger )
88+ task = await yac .create_new_task (
89+ label = label ,
90+ metadata = metadata ,
91+ engine_name = engine_name ,
92+ webhook_onsubmit = webhook_onsubmit ,
93+ )
94+ await yac .stop ()
5595 return task .task_id
5696
57- def queue_get_tasks (
97+ def queue_submit_task (
98+ self ,
99+ label : str ,
100+ metadata : Mapping [str , Any ],
101+ engine_name : str ,
102+ webhook_onsubmit = False ,
103+ ) -> int :
104+ """Submit new task"""
105+ fn = to_sync (self .queue_submit_task_async )
106+ return fn (label , metadata , engine_name , webhook_onsubmit )
107+
108+ async def queue_get_tasks_async (
58109 self ,
59110 jobs : Optional [Sequence [int ]] = None ,
60111 status : Optional [Sequence [int ]] = None ,
@@ -64,24 +115,28 @@ def queue_get_tasks(
64115 raise ValueError ("jobs can be selected only by status or by task ids" )
65116 # raise ValueError if unknown task status
66117 status = [TaskStatus (x ) for x in status ] if status else None
67-
68- async def fn_get_by_statuses (statuses : Sequence [TaskStatus ]):
69- db = await DB .create (self .config .db )
70- return await db .get_tasks_by_status (statuses )
71-
72- async def fn_get_by_ids (ids : Sequence [int ]):
73- db = await DB .create (self .config .db )
74- return await db .get_tasks_by_jobs (ids )
75-
118+ db = await DB .create (self .config .db )
76119 if status :
77- tasks = asyncio . run ( fn_get_by_statuses ( status ) )
120+ tasks = await db . get_tasks_by_status ( status )
78121 elif jobs :
79- tasks = asyncio . run ( fn_get_by_ids ( jobs ) )
122+ tasks = await db . get_tasks_by_jobs ( jobs )
80123 else :
81124 return []
82-
83125 return [asdict (t ) for t in tasks ]
84126
127+ def queue_get_tasks (
128+ self ,
129+ jobs : Optional [Sequence [int ]] = None ,
130+ status : Optional [Sequence [int ]] = None ,
131+ ) -> Sequence [Mapping [str , Any ]]:
132+ """Get tasks by ids or statuses"""
133+ return to_sync (self .queue_get_tasks_async )(jobs , status )
134+
135+ async def queue_get_task_async (self , task_id : int ) -> Optional [Mapping [str , Any ]]:
136+ """Get task by id"""
137+ for task_dict in await self .queue_get_tasks_async (jobs = [task_id ]):
138+ return task_dict
139+
85140 def queue_get_task (self , task_id : int ) -> Optional [Mapping [str , Any ]]:
86141 """Get task by id"""
87142 for task_dict in self .queue_get_tasks (jobs = [task_id ]):
0 commit comments