diff options
author | Maxim Devaev <[email protected]> | 2022-06-14 18:18:21 +0300 |
---|---|---|
committer | Maxim Devaev <[email protected]> | 2022-06-14 18:18:21 +0300 |
commit | 88c7796551010faf30e7f7f843432af919ea7ce0 (patch) | |
tree | 9ca033133bad8e10878cd5b45510499588e40262 /kvmd | |
parent | 37e5118fff1e63e5af0183e7900bb6a7bc708a34 (diff) |
common websocket code
Diffstat (limited to 'kvmd')
-rw-r--r-- | kvmd/apps/kvmd/api/hid.py | 12 | ||||
-rw-r--r-- | kvmd/apps/kvmd/server.py | 83 | ||||
-rw-r--r-- | kvmd/htserver.py | 179 |
3 files changed, 145 insertions, 129 deletions
diff --git a/kvmd/apps/kvmd/api/hid.py b/kvmd/apps/kvmd/api/hid.py index ddbaadd8..a683921c 100644 --- a/kvmd/apps/kvmd/api/hid.py +++ b/kvmd/apps/kvmd/api/hid.py @@ -32,7 +32,6 @@ from typing import Callable from aiohttp.web import Request from aiohttp.web import Response -from aiohttp.web import WebSocketResponse from ....mouse import MouseRange @@ -42,6 +41,7 @@ from ....keyboard.printer import text_to_web_keys from ....htserver import exposed_http from ....htserver import exposed_ws from ....htserver import make_json_response +from ....htserver import WsSession from ....plugins.hid import BaseHid @@ -158,7 +158,7 @@ class HidApi: # ===== @exposed_ws("key") - async def __ws_key_handler(self, _: WebSocketResponse, event: Dict) -> None: + async def __ws_key_handler(self, _: WsSession, event: Dict) -> None: try: key = valid_hid_key(event["key"]) state = valid_bool(event["state"]) @@ -168,7 +168,7 @@ class HidApi: self.__hid.send_key_events([(key, state)]) @exposed_ws("mouse_button") - async def __ws_mouse_button_handler(self, _: WebSocketResponse, event: Dict) -> None: + async def __ws_mouse_button_handler(self, _: WsSession, event: Dict) -> None: try: button = valid_hid_mouse_button(event["button"]) state = valid_bool(event["state"]) @@ -177,7 +177,7 @@ class HidApi: self.__hid.send_mouse_button_event(button, state) @exposed_ws("mouse_move") - async def __ws_mouse_move_handler(self, _: WebSocketResponse, event: Dict) -> None: + async def __ws_mouse_move_handler(self, _: WsSession, event: Dict) -> None: try: to_x = valid_hid_mouse_move(event["to"]["x"]) to_y = valid_hid_mouse_move(event["to"]["y"]) @@ -186,11 +186,11 @@ class HidApi: self.__send_mouse_move_event_remapped(to_x, to_y) @exposed_ws("mouse_relative") - async def __ws_mouse_relative_handler(self, _: WebSocketResponse, event: Dict) -> None: + async def __ws_mouse_relative_handler(self, _: WsSession, event: Dict) -> None: self.__process_delta_ws_request(event, self.__hid.send_mouse_relative_event) @exposed_ws("mouse_wheel") - async def __ws_mouse_wheel_handler(self, _: WebSocketResponse, event: Dict) -> None: + async def __ws_mouse_wheel_handler(self, _: WsSession, event: Dict) -> None: self.__process_delta_ws_request(event, self.__hid.send_mouse_wheel_event) def __process_delta_ws_request(self, event: Dict, handler: Callable[[int, int], None]) -> None: diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py index 20b2cb7e..5adc8d8d 100644 --- a/kvmd/apps/kvmd/server.py +++ b/kvmd/apps/kvmd/server.py @@ -27,7 +27,6 @@ import dataclasses from typing import Tuple from typing import List from typing import Dict -from typing import Set from typing import Callable from typing import Coroutine from typing import AsyncGenerator @@ -48,12 +47,8 @@ from ... import aioproc from ...htserver import HttpExposed from ...htserver import exposed_http from ...htserver import exposed_ws -from ...htserver import get_exposed_http -from ...htserver import get_exposed_ws from ...htserver import make_json_response -from ...htserver import send_ws_event -from ...htserver import broadcast_ws_event -from ...htserver import process_ws_messages +from ...htserver import WsSession from ...htserver import HttpServer from ...plugins import BasePlugin @@ -128,15 +123,6 @@ class _Component: # pylint: disable=too-many-instance-attributes assert self.event_type, self [email protected](frozen=True) -class _WsClient: - ws: WebSocketResponse - stream: bool - - def __str__(self) -> str: - return f"WsClient(id={id(self)}, stream={self.stream})" - - class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( # pylint: disable=too-many-arguments,too-many-locals self, @@ -160,6 +146,8 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins stream_forever: bool, ) -> None: + super().__init__() + self.__auth_manager = auth_manager self.__hid = hid self.__streamer = streamer @@ -201,11 +189,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins RedfishApi(info_manager, atx), ] - self.__ws_handlers: Dict[str, Callable] = {} - - self.__ws_clients: Set[_WsClient] = set() - self.__ws_clients_lock = asyncio.Lock() - self.__streamer_notifier = aiotools.AioNotifier() self.__reset_streamer = False self.__new_streamer_params: Dict = {} @@ -244,11 +227,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins @exposed_http("GET", "/ws") async def __ws_handler(self, request: Request) -> WebSocketResponse: stream = valid_bool(request.query.get("stream", "true")) - ws = await self._make_ws_response(request) - client = _WsClient(ws, stream) - await self.__register_ws_client(client) - - try: + async with self._ws_session(request, stream=stream) as ws: stage1 = [ ("gpio_model_state", self.__user_gpio.get_model()), ("hid_keymaps_state", self.__hid_api.get_keymaps()), @@ -266,19 +245,15 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins )) for stage in [stage1, stage2]: await asyncio.gather(*[ - send_ws_event(ws, event_type, events.pop(event_type)) + ws.send_event(event_type, events.pop(event_type)) for (event_type, _) in stage ]) - - await send_ws_event(ws, "loop", {}) - await process_ws_messages(ws, self.__ws_handlers) - return ws - finally: - await self.__remove_ws_client(client) + await ws.send_event("loop", {}) + return (await self._ws_loop(ws)) @exposed_ws("ping") - async def __ws_ping_handler(self, ws: WebSocketResponse, _: Dict) -> None: - await send_ws_event(ws, "pong", {}) + async def __ws_ping_handler(self, ws: WsSession, _: Dict) -> None: + await ws.send_event("pong", {}) # ===== SYSTEM STUFF @@ -300,26 +275,16 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins if comp.poll_state: aiotools.create_deadly_task(f"{comp.name} [poller]", self.__poll_state(comp.event_type, comp.poll_state())) aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter()) - - for api in self.__apis: - for http_exposed in get_exposed_http(api): - self._add_exposed(http_exposed) - for ws_exposed in get_exposed_ws(api): - self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler + self._add_exposed(*self.__apis) async def _on_shutdown(self) -> None: logger = get_logger(0) - logger.info("Waiting short tasks ...") await aiotools.wait_all_short_tasks() - logger.info("Stopping system tasks ...") await aiotools.stop_all_deadly_tasks() - logger.info("Disconnecting clients ...") - for client in list(self.__ws_clients): - await self.__remove_ws_client(client) - + await self._close_all_wss() logger.info("On-Shutdown complete") async def _on_cleanup(self) -> None: @@ -333,25 +298,18 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins logger.exception("Cleanup error on %s", comp.name) logger.info("On-Cleanup complete") - async def __register_ws_client(self, client: _WsClient) -> None: - async with self.__ws_clients_lock: - self.__ws_clients.add(client) - get_logger().info("Registered new client socket: %s; clients now: %d", client, len(self.__ws_clients)) + async def _on_ws_opened(self) -> None: await self.__streamer_notifier.notify() - async def __remove_ws_client(self, client: _WsClient) -> None: - async with self.__ws_clients_lock: - self.__hid.clear_events() - try: - self.__ws_clients.remove(client) - get_logger().info("Removed client socket: %s; clients now: %d", client, len(self.__ws_clients)) - await client.ws.close() - except Exception: - pass + async def _on_ws_closed(self) -> None: + self.__hid.clear_events() await self.__streamer_notifier.notify() def __has_stream_clients(self) -> bool: - return bool(sum(map(operator.attrgetter("stream"), self.__ws_clients))) + return bool(sum(map( + (lambda ws: ws.kwargs["stream"]), + self._get_wss(), + ))) # ===== SYSTEM TASKS @@ -379,10 +337,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None: async for state in poller: - await broadcast_ws_event([ - client.ws - for client in list(self.__ws_clients) - ], event_type, state) + await self._broadcast_ws_event(event_type, state) async def __stream_snapshoter(self) -> None: await self.__snapshoter.run( diff --git a/kvmd/htserver.py b/kvmd/htserver.py index c24837d2..d3919310 100644 --- a/kvmd/htserver.py +++ b/kvmd/htserver.py @@ -23,6 +23,7 @@ import os import socket import asyncio +import contextlib import dataclasses import inspect import json @@ -31,7 +32,9 @@ from typing import Tuple from typing import List from typing import Dict from typing import Callable +from typing import AsyncGenerator from typing import Optional +from typing import Any from aiohttp.web import BaseRequest from aiohttp.web import Request @@ -103,7 +106,7 @@ def exposed_http(http_method: str, path: str, auth_required: bool=True) -> Calla return set_attrs -def get_exposed_http(obj: object) -> List[HttpExposed]: +def _get_exposed_http(obj: object) -> List[HttpExposed]: return [ HttpExposed( method=getattr(handler, _HTTP_METHOD), @@ -135,7 +138,7 @@ def exposed_ws(event_type: str) -> Callable: return set_attrs -def get_exposed_ws(obj: object) -> List[WsExposed]: +def _get_exposed_ws(obj: object) -> List[WsExposed]: return [ WsExposed( event_type=getattr(handler, _WS_EVENT_TYPE), @@ -206,57 +209,6 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non # ===== -async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None: - await ws.send_str(json.dumps({ - "event_type": event_type, - "event": event, - })) - - -async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None: - if wss: - await asyncio.gather(*[ - send_ws_event(ws, event_type, event) - for ws in wss - if ( - not ws.closed - and ws._req is not None # pylint: disable=protected-access - and ws._req.transport is not None # pylint: disable=protected-access - ) - ], return_exceptions=True) - - -def _parse_ws_event(msg: str) -> Tuple[str, Dict]: - data = json.loads(msg) - if not isinstance(data, dict): - raise RuntimeError("Top-level event structure is not a dict") - event_type = data.get("event_type") - if not isinstance(event_type, str): - raise RuntimeError("event_type must be a string") - event = data["event"] - if not isinstance(event, dict): - raise RuntimeError("event must be a dict") - return (event_type, event) - - -async def process_ws_messages(ws: WebSocketResponse, handlers: Dict[str, Callable]) -> None: - logger = get_logger(1) - async for msg in ws: - if msg.type != WSMsgType.TEXT: - break - try: - (event_type, event) = _parse_ws_event(msg.data) - except Exception as err: - logger.error("Can't parse JSON event from websocket: %r", err) - else: - handler = handlers.get(event_type) - if handler: - await handler(ws, event) - else: - logger.error("Unknown websocket event: %r", msg.data) - - -# ===== _REQUEST_AUTH_INFO = "_kvmd_auth_info" @@ -272,7 +224,28 @@ def set_request_auth_info(request: BaseRequest, info: str) -> None: # ===== [email protected](frozen=True) +class WsSession: + wsr: WebSocketResponse + kwargs: Dict[str, Any] + + def __str__(self) -> str: + return f"WsSession(id={id(self)}, {self.kwargs})" + + async def send_event(self, event_type: str, event: Optional[Dict]) -> None: + await self.wsr.send_str(json.dumps({ + "event_type": event_type, + "event": event, + })) + + class HttpServer: + def __init__(self) -> None: + self.__ws_heartbeat: Optional[float] = None + self.__ws_handlers: Dict[str, Callable] = {} + self.__ws_sessions: List[WsSession] = [] + self.__ws_sessions_lock = asyncio.Lock() + def run( self, unix_path: str, @@ -282,7 +255,7 @@ class HttpServer: access_log_format: str, ) -> None: - self.__heartbeat = heartbeat # pylint: disable=attribute-defined-outside-init + self.__ws_heartbeat = heartbeat if unix_rm and os.path.exists(unix_path): os.remove(unix_path) @@ -302,7 +275,14 @@ class HttpServer: # ===== - def _add_exposed(self, exposed: HttpExposed) -> None: + def _add_exposed(self, *objs: object) -> None: + for obj in objs: + for http_exposed in _get_exposed_http(obj): + self.__add_exposed_http(http_exposed) + for ws_exposed in _get_exposed_ws(obj): + self.__add_exposed_ws(ws_exposed) + + def __add_exposed_http(self, exposed: HttpExposed) -> None: async def wrapper(request: Request) -> Response: try: await self._check_request_auth(exposed, request) @@ -315,10 +295,85 @@ class HttpServer: return make_json_exception(err) self.__app.router.add_route(exposed.method, exposed.path, wrapper) - async def _make_ws_response(self, request: Request) -> WebSocketResponse: - ws = WebSocketResponse(heartbeat=self.__heartbeat) - await ws.prepare(request) - return ws + def __add_exposed_ws(self, exposed: WsExposed) -> None: + self.__ws_handlers[exposed.event_type] = exposed.handler + + # ===== + + @contextlib.asynccontextmanager + async def _ws_session(self, request: Request, **kwargs: Any) -> AsyncGenerator[WsSession, None]: + assert self.__ws_heartbeat is not None + wsr = WebSocketResponse(heartbeat=self.__ws_heartbeat) + await wsr.prepare(request) + ws = WsSession(wsr, kwargs) + + async with self.__ws_sessions_lock: + self.__ws_sessions.append(ws) + get_logger(2).info("Registered new client session: %s; clients now: %d", ws, len(self.__ws_sessions)) + + try: + await self._on_ws_opened() + yield ws + finally: + await self.__close_ws(ws) + + async def _ws_loop(self, ws: WsSession) -> WebSocketResponse: + logger = get_logger() + async for msg in ws.wsr: + if msg.type != WSMsgType.TEXT: + break + try: + (event_type, event) = self.__parse_ws_event(msg.data) + except Exception as err: + logger.error("Can't parse JSON event from websocket: %r", err) + else: + handler = self.__ws_handlers.get(event_type) + if handler: + await handler(ws, event) + else: + logger.error("Unknown websocket event: %r", msg.data) + return ws.wsr + + async def _broadcast_ws_event(self, event_type: str, event: Optional[Dict]) -> None: + if self.__ws_sessions: + await asyncio.gather(*[ + ws.send_event(event_type, event) + for ws in self.__ws_sessions + if ( + not ws.wsr.closed + and ws.wsr._req is not None # pylint: disable=protected-access + and ws.wsr._req.transport is not None # pylint: disable=protected-access + ) + ], return_exceptions=True) + + async def _close_all_wss(self) -> None: + for ws in self._get_wss(): + await self.__close_ws(ws) + + def _get_wss(self) -> List[WsSession]: + return list(self.__ws_sessions) + + async def __close_ws(self, ws: WsSession) -> None: + async with self.__ws_sessions_lock: + try: + self.__ws_sessions.remove(ws) + get_logger(3).info("Removed client socket: %s; clients now: %d", ws, len(self.__ws_sessions)) + await ws.wsr.close() + except Exception: + pass + await self._on_ws_closed() + + def __parse_ws_event(self, msg: str) -> Tuple[str, Dict]: + data = json.loads(msg) + if not isinstance(data, dict): + raise RuntimeError("Top-level event structure is not a dict") + event_type = data.get("event_type") + if not isinstance(event_type, str): + raise RuntimeError("event_type must be a string") + event = data["event"] + if not isinstance(event, dict): + raise RuntimeError("event must be a dict") + return (event_type, event) # ===== @@ -334,6 +389,12 @@ class HttpServer: async def _on_cleanup(self) -> None: pass + async def _on_ws_opened(self) -> None: + pass + + async def _on_ws_closed(self) -> None: + pass + # ===== async def __make_app(self) -> Application: |