diff options
-rw-r--r-- | kvmd/aiotools.py | 49 | ||||
-rw-r--r-- | kvmd/apps/kvmd/server.py | 35 |
2 files changed, 50 insertions, 34 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py index 85684758..add405f8 100644 --- a/kvmd/aiotools.py +++ b/kvmd/aiotools.py @@ -20,8 +20,9 @@ # ========================================================================== # -import asyncio +import os import signal +import asyncio import functools import types @@ -42,8 +43,6 @@ from .logging import get_logger # ===== -_ATTR_SHORT_TASK = "_aiotools_short_task" - _MethodT = TypeVar("_MethodT", bound=Callable[..., Any]) _RetvalT = TypeVar("_RetvalT") @@ -57,18 +56,56 @@ def atomic(method: _MethodT) -> _MethodT: # ===== +_ATTR_SHORT_TASK = "_aiotools_short_task" + + def create_short_task(coro: Coroutine) -> asyncio.Task: task = asyncio.create_task(coro) setattr(task, _ATTR_SHORT_TASK, True) return task -def get_short_tasks() -> List[asyncio.Task]: - return [ +async def wait_all_short_tasks() -> None: + await asyncio.gather(*[ task for task in asyncio.all_tasks() if getattr(task, _ATTR_SHORT_TASK, False) - ] + ], return_exceptions=True) + + +# ===== +_ATTR_DEADLY_TASK = "_aiotools_deadly_task" + + +def create_deadly_task(name: str, coro: Coroutine) -> asyncio.Task: + logger = get_logger() + + async def wrapper() -> None: + try: + await coro + raise RuntimeError(f"Deadly task is dead: {name}") + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Unhandled exception in deadly task, killing myself ...") + pid = os.getpid() + if pid == 1: + os._exit(1) # Docker workaround # pylint: disable=protected-access + else: + os.kill(pid, signal.SIGTERM) + + task = asyncio.create_task(wrapper()) + setattr(task, _ATTR_DEADLY_TASK, True) + return task + + +async def stop_all_deadly_tasks() -> None: + tasks: List[asyncio.Task] = [] + for task in asyncio.all_tasks(): + if getattr(task, _ATTR_DEADLY_TASK, False): + tasks.append(task) + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) # ===== diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py index f3dc222b..20b2cb7e 100644 --- a/kvmd/apps/kvmd/server.py +++ b/kvmd/apps/kvmd/server.py @@ -20,8 +20,6 @@ # ========================================================================== # -import os -import signal import asyncio import operator import dataclasses @@ -208,8 +206,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins self.__ws_clients: Set[_WsClient] = set() self.__ws_clients_lock = asyncio.Lock() - self.__system_tasks: List[asyncio.Task] = [] - self.__streamer_notifier = aiotools.AioNotifier() self.__reset_streamer = False self.__new_streamer_params: Dict = {} @@ -297,13 +293,13 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins await check_request_auth(self.__auth_manager, exposed, request) async def _init_app(self) -> None: - self.__run_system_task(self.__stream_controller) + aiotools.create_deadly_task("Stream controller", self.__stream_controller()) for comp in self.__components: if comp.systask: - self.__run_system_task(comp.systask) + aiotools.create_deadly_task(comp.name, comp.systask()) if comp.poll_state: - self.__run_system_task(self.__poll_state, comp.event_type, comp.poll_state()) - self.__run_system_task(self.__stream_snapshoter) + aiotools.create_deadly_task(f"{comp.name} [poller]", self.__poll_state(comp.event_type, comp.poll_state())) + aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter()) for api in self.__apis: for http_exposed in get_exposed_http(api): @@ -311,31 +307,14 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins for ws_exposed in get_exposed_ws(api): self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler - def __run_system_task(self, method: Callable, *args: Any) -> None: - async def wrapper() -> None: - try: - await method(*args) - raise RuntimeError(f"Dead system task: {method}" - f"({', '.join(getattr(arg, '__name__', str(arg)) for arg in args)})") - except asyncio.CancelledError: - pass - except Exception: - get_logger().exception("Unhandled exception, killing myself ...") - os.kill(os.getpid(), signal.SIGTERM) - self.__system_tasks.append(asyncio.create_task(wrapper())) - async def _on_shutdown(self) -> None: logger = get_logger(0) logger.info("Waiting short tasks ...") - await asyncio.gather(*aiotools.get_short_tasks(), return_exceptions=True) - - logger.info("Cancelling system tasks ...") - for task in self.__system_tasks: - task.cancel() + await aiotools.wait_all_short_tasks() - logger.info("Waiting system tasks ...") - await asyncio.gather(*self.__system_tasks, return_exceptions=True) + logger.info("Stopping system tasks ...") + await aiotools.stop_all_deadly_tasks() logger.info("Disconnecting clients ...") for client in list(self.__ws_clients): |