diff options
Diffstat (limited to 'kvmd/htserver.py')
-rw-r--r-- | kvmd/htserver.py | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/kvmd/htserver.py b/kvmd/htserver.py index 351c1328..63c82fcb 100644 --- a/kvmd/htserver.py +++ b/kvmd/htserver.py @@ -232,6 +232,16 @@ async def send_ws_event( })) +async def send_ws_bin( + wsr: (ClientWebSocketResponse | WebSocketResponse), + op: int, + data: bytes, +) -> None: + + assert 0 <= op <= 255 + await wsr.send_bytes(op.to_bytes() + data) + + def parse_ws_event(msg: str) -> tuple[str, dict]: data = json.loads(msg) if not isinstance(data, dict): @@ -264,14 +274,24 @@ def set_request_auth_info(req: BaseRequest, info: str) -> None: @dataclasses.dataclass(frozen=True) class WsSession: wsr: WebSocketResponse - kwargs: dict[str, Any] + kwargs: dict[str, Any] = dataclasses.field(hash=False) def __str__(self) -> str: return f"WsSession(id={id(self)}, {self.kwargs})" + def is_alive(self) -> bool: + return ( + not self.wsr.closed + and self.wsr._req is not None # pylint: disable=protected-access + and self.wsr._req.transport is not None # pylint: disable=protected-access + ) + async def send_event(self, event_type: str, event: (dict | None)) -> None: await send_ws_event(self.wsr, event_type, event) + async def send_bin(self, op: int, data: bytes) -> None: + await send_ws_bin(self.wsr, op, data) + class HttpServer: def __init__(self) -> None: @@ -353,7 +373,7 @@ class HttpServer: get_logger(2).info("Registered new client session: %s; clients now: %d", ws, len(self.__ws_sessions)) try: - await self._on_ws_opened() + await self._on_ws_opened(ws) yield ws finally: await aiotools.shield_fg(self.__close_ws(ws)) @@ -389,12 +409,7 @@ class HttpServer: await asyncio.gather(*[ ws.send_event(event_type, event) for ws in self.__ws_sessions - if ( - not ws.wsr.closed - and ws.wsr._req is not None # pylint: disable=protected-access - and ws.wsr._req.transport is not None # pylint: disable=protected-access - and (legacy is None or ws.kwargs.get("legacy") == legacy) - ) + if ws.is_alive() and (legacy is None or ws.kwargs.get("legacy") == legacy) ], return_exceptions=True) async def _close_all_wss(self) -> bool: @@ -414,7 +429,7 @@ class HttpServer: await ws.wsr.close() except Exception: pass - await self._on_ws_closed() + await self._on_ws_closed(ws) # ===== @@ -430,10 +445,10 @@ class HttpServer: async def _on_cleanup(self) -> None: pass - async def _on_ws_opened(self) -> None: + async def _on_ws_opened(self, ws: WsSession) -> None: pass - async def _on_ws_closed(self) -> None: + async def _on_ws_closed(self, ws: WsSession) -> None: pass # ===== |