summaryrefslogtreecommitdiff
path: root/kvmd/htserver.py
diff options
context:
space:
mode:
Diffstat (limited to 'kvmd/htserver.py')
-rw-r--r--kvmd/htserver.py37
1 files changed, 26 insertions, 11 deletions
diff --git a/kvmd/htserver.py b/kvmd/htserver.py
index 351c1328..63c82fcb 100644
--- a/kvmd/htserver.py
+++ b/kvmd/htserver.py
@@ -232,6 +232,16 @@ async def send_ws_event(
}))
+async def send_ws_bin(
+ wsr: (ClientWebSocketResponse | WebSocketResponse),
+ op: int,
+ data: bytes,
+) -> None:
+
+ assert 0 <= op <= 255
+ await wsr.send_bytes(op.to_bytes() + data)
+
+
def parse_ws_event(msg: str) -> tuple[str, dict]:
data = json.loads(msg)
if not isinstance(data, dict):
@@ -264,14 +274,24 @@ def set_request_auth_info(req: BaseRequest, info: str) -> None:
@dataclasses.dataclass(frozen=True)
class WsSession:
wsr: WebSocketResponse
- kwargs: dict[str, Any]
+ kwargs: dict[str, Any] = dataclasses.field(hash=False)
def __str__(self) -> str:
return f"WsSession(id={id(self)}, {self.kwargs})"
+ def is_alive(self) -> bool:
+ return (
+ not self.wsr.closed
+ and self.wsr._req is not None # pylint: disable=protected-access
+ and self.wsr._req.transport is not None # pylint: disable=protected-access
+ )
+
async def send_event(self, event_type: str, event: (dict | None)) -> None:
await send_ws_event(self.wsr, event_type, event)
+ async def send_bin(self, op: int, data: bytes) -> None:
+ await send_ws_bin(self.wsr, op, data)
+
class HttpServer:
def __init__(self) -> None:
@@ -353,7 +373,7 @@ class HttpServer:
get_logger(2).info("Registered new client session: %s; clients now: %d", ws, len(self.__ws_sessions))
try:
- await self._on_ws_opened()
+ await self._on_ws_opened(ws)
yield ws
finally:
await aiotools.shield_fg(self.__close_ws(ws))
@@ -389,12 +409,7 @@ class HttpServer:
await asyncio.gather(*[
ws.send_event(event_type, event)
for ws in self.__ws_sessions
- if (
- not ws.wsr.closed
- and ws.wsr._req is not None # pylint: disable=protected-access
- and ws.wsr._req.transport is not None # pylint: disable=protected-access
- and (legacy is None or ws.kwargs.get("legacy") == legacy)
- )
+ if ws.is_alive() and (legacy is None or ws.kwargs.get("legacy") == legacy)
], return_exceptions=True)
async def _close_all_wss(self) -> bool:
@@ -414,7 +429,7 @@ class HttpServer:
await ws.wsr.close()
except Exception:
pass
- await self._on_ws_closed()
+ await self._on_ws_closed(ws)
# =====
@@ -430,10 +445,10 @@ class HttpServer:
async def _on_cleanup(self) -> None:
pass
- async def _on_ws_opened(self) -> None:
+ async def _on_ws_opened(self, ws: WsSession) -> None:
pass
- async def _on_ws_closed(self) -> None:
+ async def _on_ws_closed(self, ws: WsSession) -> None:
pass
# =====