diff options
author | Maxim Devaev <[email protected]> | 2022-06-14 11:23:04 +0300 |
---|---|---|
committer | Maxim Devaev <[email protected]> | 2022-06-14 16:44:59 +0300 |
commit | e050bbd725d2670cacf595632c82f9b84e555ac4 (patch) | |
tree | 26fbb27d55443880f7358633ceae17f85485f50b /kvmd/aiotools.py | |
parent | 6caeb2ce8211af2637c17a0ba2c561f591765321 (diff) |
refactoring
Diffstat (limited to 'kvmd/aiotools.py')
-rw-r--r-- | kvmd/aiotools.py | 49 |
1 files changed, 43 insertions, 6 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py index 85684758..add405f8 100644 --- a/kvmd/aiotools.py +++ b/kvmd/aiotools.py @@ -20,8 +20,9 @@ # ========================================================================== # -import asyncio +import os import signal +import asyncio import functools import types @@ -42,8 +43,6 @@ from .logging import get_logger # ===== -_ATTR_SHORT_TASK = "_aiotools_short_task" - _MethodT = TypeVar("_MethodT", bound=Callable[..., Any]) _RetvalT = TypeVar("_RetvalT") @@ -57,18 +56,56 @@ def atomic(method: _MethodT) -> _MethodT: # ===== +_ATTR_SHORT_TASK = "_aiotools_short_task" + + def create_short_task(coro: Coroutine) -> asyncio.Task: task = asyncio.create_task(coro) setattr(task, _ATTR_SHORT_TASK, True) return task -def get_short_tasks() -> List[asyncio.Task]: - return [ +async def wait_all_short_tasks() -> None: + await asyncio.gather(*[ task for task in asyncio.all_tasks() if getattr(task, _ATTR_SHORT_TASK, False) - ] + ], return_exceptions=True) + + +# ===== +_ATTR_DEADLY_TASK = "_aiotools_deadly_task" + + +def create_deadly_task(name: str, coro: Coroutine) -> asyncio.Task: + logger = get_logger() + + async def wrapper() -> None: + try: + await coro + raise RuntimeError(f"Deadly task is dead: {name}") + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Unhandled exception in deadly task, killing myself ...") + pid = os.getpid() + if pid == 1: + os._exit(1) # Docker workaround # pylint: disable=protected-access + else: + os.kill(pid, signal.SIGTERM) + + task = asyncio.create_task(wrapper()) + setattr(task, _ATTR_DEADLY_TASK, True) + return task + + +async def stop_all_deadly_tasks() -> None: + tasks: List[asyncio.Task] = [] + for task in asyncio.all_tasks(): + if getattr(task, _ATTR_DEADLY_TASK, False): + tasks.append(task) + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) # ===== |