diff options
-rw-r--r-- | kvmd/aiotools.py | 49 | ||||
-rw-r--r-- | kvmd/apps/janus/runner.py | 5 | ||||
-rw-r--r-- | kvmd/apps/vnc/server.py | 61 |
3 files changed, 60 insertions, 55 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py index d085cf89..6895421c 100644 --- a/kvmd/aiotools.py +++ b/kvmd/aiotools.py @@ -44,6 +44,41 @@ from .logging import get_logger # ===== +def run(coro: Coroutine, final: Optional[Coroutine]=None) -> None: + # https://github.com/aio-libs/aiohttp/blob/a1d4dac1d/aiohttp/web.py#L515 + + def sigint_handler() -> None: + raise KeyboardInterrupt() + + def sigterm_handler() -> None: + raise SystemExit() + + loop = asyncio.get_event_loop() + loop.add_signal_handler(signal.SIGINT, sigint_handler) + loop.add_signal_handler(signal.SIGTERM, sigterm_handler) + + main_task = loop.create_task(coro) + try: + loop.run_until_complete(main_task) + except (SystemExit, KeyboardInterrupt): + pass + finally: + main_task.cancel() + loop.run_until_complete(asyncio.gather(main_task, return_exceptions=True)) + + if final is not None: + loop.run_until_complete(final) + + tasks = asyncio.all_tasks(loop) + for task in tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + +# ===== _FunctionT = TypeVar("_FunctionT", bound=Callable[..., Any]) _RetvalT = TypeVar("_RetvalT") @@ -152,20 +187,6 @@ async def close_writer(writer: asyncio.StreamWriter) -> bool: # ===== -def run(coro: Coroutine) -> None: - def sigint_handler() -> None: - raise KeyboardInterrupt() - - def sigterm_handler() -> None: - raise SystemExit() - - loop = asyncio.get_event_loop() - loop.add_signal_handler(signal.SIGINT, sigint_handler) - loop.add_signal_handler(signal.SIGTERM, sigterm_handler) - loop.run_until_complete(coro) - - -# ===== class AioNotifier: def __init__(self) -> None: self.__queue: "asyncio.Queue[None]" = asyncio.Queue() diff --git a/kvmd/apps/janus/runner.py b/kvmd/apps/janus/runner.py index a0110f47..204f8a9e 100644 --- a/kvmd/apps/janus/runner.py +++ b/kvmd/apps/janus/runner.py @@ -61,10 +61,7 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes def run(self) -> None: logger = get_logger(0) logger.info("Starting Janus Runner ...") - try: - aiotools.run(self.__run()) - except (SystemExit, KeyboardInterrupt): - aiotools.run(self.__stop_janus()) + aiotools.run(self.__run(), self.__stop_janus()) logger.info("Bye-bye") # ===== diff --git a/kvmd/apps/vnc/server.py b/kvmd/apps/vnc/server.py index 4c62f2a3..434fd4f8 100644 --- a/kvmd/apps/vnc/server.py +++ b/kvmd/apps/vnc/server.py @@ -484,6 +484,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes none_auth_only=none_auth_only, shared_params=shared_params, ).run() + except asyncio.CancelledError: + raise except Exception: logger.exception("[entry] %s: Unhandled exception in client task", remote) finally: @@ -492,41 +494,26 @@ class VncServer: # pylint: disable=too-many-instance-attributes self.__handle_client = handle_client - def run(self) -> None: - logger = get_logger(0) - loop = asyncio.get_event_loop() - try: - if not loop.run_until_complete(self.__vnc_auth_manager.read_credentials())[1]: - raise SystemExit(1) - - logger.info("Listening VNC on TCP [%s]:%d ...", self.__host, self.__port) - - (family, _, _, _, addr) = socket.getaddrinfo(self.__host, self.__port, type=socket.SOCK_STREAM)[0] - with contextlib.closing(socket.socket(family, socket.SOCK_STREAM)) as sock: - if family == socket.AF_INET6: - sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(addr) - - server_kwargs = ({"loop": loop} if sys.version_info < (3, 10) else {}) - server = loop.run_until_complete(asyncio.start_server( - client_connected_cb=self.__handle_client, - sock=sock, - backlog=self.__max_clients, - **server_kwargs, # type: ignore - )) + async def __inner_run(self) -> None: + if not (await self.__vnc_auth_manager.read_credentials())[1]: + raise SystemExit(1) + + get_logger(0).info("Listening VNC on TCP [%s]:%d ...", self.__host, self.__port) + (family, _, _, _, addr) = socket.getaddrinfo(self.__host, self.__port, type=socket.SOCK_STREAM)[0] + with contextlib.closing(socket.socket(family, socket.SOCK_STREAM)) as sock: + if family == socket.AF_INET6: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(addr) + + server = await asyncio.start_server( + client_connected_cb=self.__handle_client, + sock=sock, + backlog=self.__max_clients, + ) + async with server: + await server.serve_forever() - try: - loop.run_forever() - except (SystemExit, KeyboardInterrupt): - pass - finally: - server.close() - loop.run_until_complete(server.wait_closed()) - finally: - tasks = asyncio.all_tasks(loop) - for task in tasks: - task.cancel() - loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - loop.close() - logger.info("Bye-bye") + def run(self) -> None: + aiotools.run(self.__inner_run()) + get_logger().info("Bye-bye") |