diff options
Diffstat (limited to 'kvmd')
-rw-r--r-- | kvmd/apps/kvmd/server.py | 58 | ||||
-rw-r--r-- | kvmd/htserver.py | 36 |
2 files changed, 57 insertions, 37 deletions
diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py index f33412d9..e7bc744d 100644 --- a/kvmd/apps/kvmd/server.py +++ b/kvmd/apps/kvmd/server.py @@ -25,7 +25,6 @@ import signal import asyncio import operator import dataclasses -import json from typing import Tuple from typing import List @@ -56,6 +55,9 @@ from ...htserver import get_exposed_http from ...htserver import get_exposed_ws from ...htserver import make_json_response from ...htserver import make_json_exception +from ...htserver import send_ws_event +from ...htserver import broadcast_ws_event +from ...htserver import parse_ws_event from ...htserver import HttpServer from ...plugins import BasePlugin @@ -279,28 +281,25 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins )) for stage in [stage1, stage2]: await asyncio.gather(*[ - self.__send_event(client.ws, event_type, events.pop(event_type)) + send_ws_event(client.ws, event_type, events.pop(event_type)) for (event_type, _) in stage ]) - await self.__send_event(client.ws, "loop", {}) + await send_ws_event(client.ws, "loop", {}) async for msg in client.ws: - if msg.type == aiohttp.web.WSMsgType.TEXT: - try: - data = json.loads(msg.data) - event_type = data.get("event_type") - event = data["event"] - except Exception as err: - logger.error("Can't parse JSON event from websocket: %r", err) - else: - handler = self.__ws_handlers.get(event_type) - if handler: - await handler(client.ws, event) - else: - logger.error("Unknown websocket event: %r", data) - else: + if msg.type != aiohttp.web.WSMsgType.TEXT: break + try: + (event_type, event) = parse_ws_event(msg.data) + except Exception as err: + logger.error("Can't parse JSON event from websocket: %r", err) + else: + handler = self.__ws_handlers.get(event_type) + if handler: + await handler(client.ws, event) + else: + logger.error("Unknown websocket event: %r", msg.data) return client.ws finally: @@ -308,7 +307,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins @exposed_ws("ping") async def __ws_ping_handler(self, ws: aiohttp.web.WebSocketResponse, _: Dict) -> None: - await self.__send_event(ws, "pong", {}) + await send_ws_event(ws, "pong", {}) # ===== SYSTEM STUFF @@ -390,24 +389,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins logger.exception("Cleanup error on %s", comp.name) logger.info("On-Cleanup complete") - async def __send_event(self, ws: aiohttp.web.WebSocketResponse, event_type: str, event: Optional[Dict]) -> None: - await ws.send_str(json.dumps({ - "event_type": event_type, - "event": event, - })) - - async def __broadcast_event(self, event_type: str, event: Optional[Dict]) -> None: - if self.__ws_clients: - await asyncio.gather(*[ - self.__send_event(client.ws, event_type, event) - for client in list(self.__ws_clients) - if ( - not client.ws.closed - and client.ws._req is not None # pylint: disable=protected-access - and client.ws._req.transport is not None # pylint: disable=protected-access - ) - ], return_exceptions=True) - async def __register_ws_client(self, client: _WsClient) -> None: async with self.__ws_clients_lock: self.__ws_clients.add(client) @@ -454,7 +435,10 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None: async for state in poller: - await self.__broadcast_event(event_type, state) + await broadcast_ws_event([ + client.ws + for client in list(self.__ws_clients) + ], event_type, state) async def __stream_snapshoter(self) -> None: await self.__snapshoter.run( diff --git a/kvmd/htserver.py b/kvmd/htserver.py index f1b51b38..6196b7a1 100644 --- a/kvmd/htserver.py +++ b/kvmd/htserver.py @@ -27,6 +27,7 @@ import dataclasses import inspect import json +from typing import Tuple from typing import List from typing import Dict from typing import Callable @@ -36,6 +37,7 @@ from aiohttp.web import BaseRequest from aiohttp.web import Request from aiohttp.web import Response from aiohttp.web import StreamResponse +from aiohttp.web import WebSocketResponse from aiohttp.web import Application from aiohttp.web import run_app from aiohttp.web import normalize_path_middleware @@ -198,6 +200,40 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non # ===== +async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None: + await ws.send_str(json.dumps({ + "event_type": event_type, + "event": event, + })) + + +async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None: + if wss: + await asyncio.gather(*[ + send_ws_event(ws, event_type, event) + for ws in wss + if ( + not ws.closed + and ws._req is not None # pylint: disable=protected-access + and ws._req.transport is not None # pylint: disable=protected-access + ) + ], return_exceptions=True) + + +def parse_ws_event(msg: str) -> Tuple[str, Dict]: + data = json.loads(msg) + if not isinstance(data, dict): + raise RuntimeError("Top-level event structure is not a dict") + event_type = data.get("event_type") + if not isinstance(event_type, str): + raise RuntimeError("event_type must be a string") + event = data["event"] + if not isinstance(event, dict): + raise RuntimeError("event must be a dict") + return (event_type, event) + + +# ===== _REQUEST_AUTH_INFO = "_kvmd_auth_info" |