|
2 | 2 | import signal
|
3 | 3 | import sys
|
4 | 4 | from abc import ABC, abstractmethod
|
| 5 | +from datetime import timedelta |
5 | 6 | from typing import Any, Coroutine, Type
|
6 | 7 |
|
| 8 | +import pycron # type: ignore[import-untyped] |
7 | 9 | import quattro
|
8 | 10 | from ape import chain
|
9 | 11 | from ape.logging import logger
|
|
33 | 35 | from .main import SilverbackBot, TaskData
|
34 | 36 | from .recorder import BaseRecorder, TaskResult
|
35 | 37 | from .state import Datastore, StateSnapshot
|
36 |
| -from .types import TaskType |
| 38 | +from .types import TaskType, utc_now |
37 | 39 | from .utils import async_wrap_iter
|
38 | 40 |
|
39 | 41 | if sys.version_info < (3, 11):
|
@@ -123,6 +125,31 @@ async def _checkpoint(
|
123 | 125 | ):
|
124 | 126 | await self.datastore.save(snapshot)
|
125 | 127 |
|
| 128 | + async def _cron_tasks(self, cron_tasks: list[TaskData]): |
| 129 | + """ |
| 130 | + Handle all cron tasks |
| 131 | + """ |
| 132 | + |
| 133 | + while True: |
| 134 | + # NOTE: Sleep until next exact time boundary (every minute) |
| 135 | + current_time = utc_now() |
| 136 | + wait_time = timedelta( |
| 137 | + seconds=60 - 1 - current_time.second, |
| 138 | + microseconds=int(1e6) - current_time.microsecond, |
| 139 | + ) |
| 140 | + await asyncio.sleep(wait_time.total_seconds()) |
| 141 | + current_time += wait_time |
| 142 | + |
| 143 | + for task_data in cron_tasks: |
| 144 | + if not (cron := task_data.labels.get("cron")): |
| 145 | + logger.warning(f"Cron task missing `cron` label: '{task_data.name}'") |
| 146 | + continue |
| 147 | + |
| 148 | + if pycron.is_now(cron, dt=current_time): |
| 149 | + self._runtime_task_group.create_task(self.run_task(task_data, current_time)) |
| 150 | + |
| 151 | + # NOTE: TaskGroup waits for all tasks to complete before continuing |
| 152 | + |
126 | 153 | @abstractmethod
|
127 | 154 | async def _block_task(self, task_data: TaskData) -> None:
|
128 | 155 | """
|
@@ -209,6 +236,13 @@ async def startup(self) -> list[Coroutine]:
|
209 | 236 | # NOTE: No need to handle results otherwise
|
210 | 237 |
|
211 | 238 | # Create our long-running event listeners
|
| 239 | + cron_tasks_taskdata = ( |
| 240 | + await self.run_system_task(TaskType.SYSTEM_USER_TASKDATA, TaskType.CRON_JOB) |
| 241 | + if Version(config.sdk_version) >= Version("0.7.15") |
| 242 | + # NOTE: Not supported in prior versions |
| 243 | + else [] |
| 244 | + ) |
| 245 | + |
212 | 246 | new_block_tasks_taskdata = await self.run_system_task(
|
213 | 247 | TaskType.SYSTEM_USER_TASKDATA, TaskType.NEW_BLOCK
|
214 | 248 | )
|
@@ -309,6 +343,9 @@ async def wait_for_graceful_shutdown():
|
309 | 343 |
|
310 | 344 | try:
|
311 | 345 | async with quattro.TaskGroup() as tg:
|
| 346 | + # NOTE: Our runtime tasks can use this to spawn more tasks |
| 347 | + self._runtime_task_group = tg |
| 348 | + |
312 | 349 | # NOTE: User tasks that should run forever
|
313 | 350 | for coro in user_tasks:
|
314 | 351 | tg.create_task(coro)
|
|
0 commit comments