summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--kvmd/aiotools.py49
-rw-r--r--kvmd/apps/janus/runner.py5
-rw-r--r--kvmd/apps/vnc/server.py61
3 files changed, 60 insertions, 55 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py
index d085cf89..6895421c 100644
--- a/kvmd/aiotools.py
+++ b/kvmd/aiotools.py
@@ -44,6 +44,41 @@ from .logging import get_logger
# =====
+def run(coro: Coroutine, final: Optional[Coroutine]=None) -> None:
+ # https://github.com/aio-libs/aiohttp/blob/a1d4dac1d/aiohttp/web.py#L515
+
+ def sigint_handler() -> None:
+ raise KeyboardInterrupt()
+
+ def sigterm_handler() -> None:
+ raise SystemExit()
+
+ loop = asyncio.get_event_loop()
+ loop.add_signal_handler(signal.SIGINT, sigint_handler)
+ loop.add_signal_handler(signal.SIGTERM, sigterm_handler)
+
+ main_task = loop.create_task(coro)
+ try:
+ loop.run_until_complete(main_task)
+ except (SystemExit, KeyboardInterrupt):
+ pass
+ finally:
+ main_task.cancel()
+ loop.run_until_complete(asyncio.gather(main_task, return_exceptions=True))
+
+ if final is not None:
+ loop.run_until_complete(final)
+
+ tasks = asyncio.all_tasks(loop)
+ for task in tasks:
+ task.cancel()
+ loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
+
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ loop.close()
+
+
+# =====
_FunctionT = TypeVar("_FunctionT", bound=Callable[..., Any])
_RetvalT = TypeVar("_RetvalT")
@@ -152,20 +187,6 @@ async def close_writer(writer: asyncio.StreamWriter) -> bool:
# =====
-def run(coro: Coroutine) -> None:
- def sigint_handler() -> None:
- raise KeyboardInterrupt()
-
- def sigterm_handler() -> None:
- raise SystemExit()
-
- loop = asyncio.get_event_loop()
- loop.add_signal_handler(signal.SIGINT, sigint_handler)
- loop.add_signal_handler(signal.SIGTERM, sigterm_handler)
- loop.run_until_complete(coro)
-
-
-# =====
class AioNotifier:
def __init__(self) -> None:
self.__queue: "asyncio.Queue[None]" = asyncio.Queue()
diff --git a/kvmd/apps/janus/runner.py b/kvmd/apps/janus/runner.py
index a0110f47..204f8a9e 100644
--- a/kvmd/apps/janus/runner.py
+++ b/kvmd/apps/janus/runner.py
@@ -61,10 +61,7 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
def run(self) -> None:
logger = get_logger(0)
logger.info("Starting Janus Runner ...")
- try:
- aiotools.run(self.__run())
- except (SystemExit, KeyboardInterrupt):
- aiotools.run(self.__stop_janus())
+ aiotools.run(self.__run(), self.__stop_janus())
logger.info("Bye-bye")
# =====
diff --git a/kvmd/apps/vnc/server.py b/kvmd/apps/vnc/server.py
index 4c62f2a3..434fd4f8 100644
--- a/kvmd/apps/vnc/server.py
+++ b/kvmd/apps/vnc/server.py
@@ -484,6 +484,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes
none_auth_only=none_auth_only,
shared_params=shared_params,
).run()
+ except asyncio.CancelledError:
+ raise
except Exception:
logger.exception("[entry] %s: Unhandled exception in client task", remote)
finally:
@@ -492,41 +494,26 @@ class VncServer: # pylint: disable=too-many-instance-attributes
self.__handle_client = handle_client
- def run(self) -> None:
- logger = get_logger(0)
- loop = asyncio.get_event_loop()
- try:
- if not loop.run_until_complete(self.__vnc_auth_manager.read_credentials())[1]:
- raise SystemExit(1)
-
- logger.info("Listening VNC on TCP [%s]:%d ...", self.__host, self.__port)
-
- (family, _, _, _, addr) = socket.getaddrinfo(self.__host, self.__port, type=socket.SOCK_STREAM)[0]
- with contextlib.closing(socket.socket(family, socket.SOCK_STREAM)) as sock:
- if family == socket.AF_INET6:
- sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.bind(addr)
-
- server_kwargs = ({"loop": loop} if sys.version_info < (3, 10) else {})
- server = loop.run_until_complete(asyncio.start_server(
- client_connected_cb=self.__handle_client,
- sock=sock,
- backlog=self.__max_clients,
- **server_kwargs, # type: ignore
- ))
+ async def __inner_run(self) -> None:
+ if not (await self.__vnc_auth_manager.read_credentials())[1]:
+ raise SystemExit(1)
+
+ get_logger(0).info("Listening VNC on TCP [%s]:%d ...", self.__host, self.__port)
+ (family, _, _, _, addr) = socket.getaddrinfo(self.__host, self.__port, type=socket.SOCK_STREAM)[0]
+ with contextlib.closing(socket.socket(family, socket.SOCK_STREAM)) as sock:
+ if family == socket.AF_INET6:
+ sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(addr)
+
+ server = await asyncio.start_server(
+ client_connected_cb=self.__handle_client,
+ sock=sock,
+ backlog=self.__max_clients,
+ )
+ async with server:
+ await server.serve_forever()
- try:
- loop.run_forever()
- except (SystemExit, KeyboardInterrupt):
- pass
- finally:
- server.close()
- loop.run_until_complete(server.wait_closed())
- finally:
- tasks = asyncio.all_tasks(loop)
- for task in tasks:
- task.cancel()
- loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
- loop.close()
- logger.info("Bye-bye")
+ def run(self) -> None:
+ aiotools.run(self.__inner_run())
+ get_logger().info("Bye-bye")