summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMaxim Devaev <[email protected]>2022-06-14 18:18:21 +0300
committerMaxim Devaev <[email protected]>2022-06-14 18:18:21 +0300
commit88c7796551010faf30e7f7f843432af919ea7ce0 (patch)
tree9ca033133bad8e10878cd5b45510499588e40262
parent37e5118fff1e63e5af0183e7900bb6a7bc708a34 (diff)
common websocket code
-rw-r--r--kvmd/apps/kvmd/api/hid.py12
-rw-r--r--kvmd/apps/kvmd/server.py83
-rw-r--r--kvmd/htserver.py179
3 files changed, 145 insertions, 129 deletions
diff --git a/kvmd/apps/kvmd/api/hid.py b/kvmd/apps/kvmd/api/hid.py
index ddbaadd8..a683921c 100644
--- a/kvmd/apps/kvmd/api/hid.py
+++ b/kvmd/apps/kvmd/api/hid.py
@@ -32,7 +32,6 @@ from typing import Callable
from aiohttp.web import Request
from aiohttp.web import Response
-from aiohttp.web import WebSocketResponse
from ....mouse import MouseRange
@@ -42,6 +41,7 @@ from ....keyboard.printer import text_to_web_keys
from ....htserver import exposed_http
from ....htserver import exposed_ws
from ....htserver import make_json_response
+from ....htserver import WsSession
from ....plugins.hid import BaseHid
@@ -158,7 +158,7 @@ class HidApi:
# =====
@exposed_ws("key")
- async def __ws_key_handler(self, _: WebSocketResponse, event: Dict) -> None:
+ async def __ws_key_handler(self, _: WsSession, event: Dict) -> None:
try:
key = valid_hid_key(event["key"])
state = valid_bool(event["state"])
@@ -168,7 +168,7 @@ class HidApi:
self.__hid.send_key_events([(key, state)])
@exposed_ws("mouse_button")
- async def __ws_mouse_button_handler(self, _: WebSocketResponse, event: Dict) -> None:
+ async def __ws_mouse_button_handler(self, _: WsSession, event: Dict) -> None:
try:
button = valid_hid_mouse_button(event["button"])
state = valid_bool(event["state"])
@@ -177,7 +177,7 @@ class HidApi:
self.__hid.send_mouse_button_event(button, state)
@exposed_ws("mouse_move")
- async def __ws_mouse_move_handler(self, _: WebSocketResponse, event: Dict) -> None:
+ async def __ws_mouse_move_handler(self, _: WsSession, event: Dict) -> None:
try:
to_x = valid_hid_mouse_move(event["to"]["x"])
to_y = valid_hid_mouse_move(event["to"]["y"])
@@ -186,11 +186,11 @@ class HidApi:
self.__send_mouse_move_event_remapped(to_x, to_y)
@exposed_ws("mouse_relative")
- async def __ws_mouse_relative_handler(self, _: WebSocketResponse, event: Dict) -> None:
+ async def __ws_mouse_relative_handler(self, _: WsSession, event: Dict) -> None:
self.__process_delta_ws_request(event, self.__hid.send_mouse_relative_event)
@exposed_ws("mouse_wheel")
- async def __ws_mouse_wheel_handler(self, _: WebSocketResponse, event: Dict) -> None:
+ async def __ws_mouse_wheel_handler(self, _: WsSession, event: Dict) -> None:
self.__process_delta_ws_request(event, self.__hid.send_mouse_wheel_event)
def __process_delta_ws_request(self, event: Dict, handler: Callable[[int, int], None]) -> None:
diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py
index 20b2cb7e..5adc8d8d 100644
--- a/kvmd/apps/kvmd/server.py
+++ b/kvmd/apps/kvmd/server.py
@@ -27,7 +27,6 @@ import dataclasses
from typing import Tuple
from typing import List
from typing import Dict
-from typing import Set
from typing import Callable
from typing import Coroutine
from typing import AsyncGenerator
@@ -48,12 +47,8 @@ from ... import aioproc
from ...htserver import HttpExposed
from ...htserver import exposed_http
from ...htserver import exposed_ws
-from ...htserver import get_exposed_http
-from ...htserver import get_exposed_ws
from ...htserver import make_json_response
-from ...htserver import send_ws_event
-from ...htserver import broadcast_ws_event
-from ...htserver import process_ws_messages
+from ...htserver import WsSession
from ...htserver import HttpServer
from ...plugins import BasePlugin
@@ -128,15 +123,6 @@ class _Component: # pylint: disable=too-many-instance-attributes
assert self.event_type, self
[email protected](frozen=True)
-class _WsClient:
- ws: WebSocketResponse
- stream: bool
-
- def __str__(self) -> str:
- return f"WsClient(id={id(self)}, stream={self.stream})"
-
-
class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes
def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
@@ -160,6 +146,8 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
stream_forever: bool,
) -> None:
+ super().__init__()
+
self.__auth_manager = auth_manager
self.__hid = hid
self.__streamer = streamer
@@ -201,11 +189,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
RedfishApi(info_manager, atx),
]
- self.__ws_handlers: Dict[str, Callable] = {}
-
- self.__ws_clients: Set[_WsClient] = set()
- self.__ws_clients_lock = asyncio.Lock()
-
self.__streamer_notifier = aiotools.AioNotifier()
self.__reset_streamer = False
self.__new_streamer_params: Dict = {}
@@ -244,11 +227,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
@exposed_http("GET", "/ws")
async def __ws_handler(self, request: Request) -> WebSocketResponse:
stream = valid_bool(request.query.get("stream", "true"))
- ws = await self._make_ws_response(request)
- client = _WsClient(ws, stream)
- await self.__register_ws_client(client)
-
- try:
+ async with self._ws_session(request, stream=stream) as ws:
stage1 = [
("gpio_model_state", self.__user_gpio.get_model()),
("hid_keymaps_state", self.__hid_api.get_keymaps()),
@@ -266,19 +245,15 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
))
for stage in [stage1, stage2]:
await asyncio.gather(*[
- send_ws_event(ws, event_type, events.pop(event_type))
+ ws.send_event(event_type, events.pop(event_type))
for (event_type, _) in stage
])
-
- await send_ws_event(ws, "loop", {})
- await process_ws_messages(ws, self.__ws_handlers)
- return ws
- finally:
- await self.__remove_ws_client(client)
+ await ws.send_event("loop", {})
+ return (await self._ws_loop(ws))
@exposed_ws("ping")
- async def __ws_ping_handler(self, ws: WebSocketResponse, _: Dict) -> None:
- await send_ws_event(ws, "pong", {})
+ async def __ws_ping_handler(self, ws: WsSession, _: Dict) -> None:
+ await ws.send_event("pong", {})
# ===== SYSTEM STUFF
@@ -300,26 +275,16 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
if comp.poll_state:
aiotools.create_deadly_task(f"{comp.name} [poller]", self.__poll_state(comp.event_type, comp.poll_state()))
aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter())
-
- for api in self.__apis:
- for http_exposed in get_exposed_http(api):
- self._add_exposed(http_exposed)
- for ws_exposed in get_exposed_ws(api):
- self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler
+ self._add_exposed(*self.__apis)
async def _on_shutdown(self) -> None:
logger = get_logger(0)
-
logger.info("Waiting short tasks ...")
await aiotools.wait_all_short_tasks()
-
logger.info("Stopping system tasks ...")
await aiotools.stop_all_deadly_tasks()
-
logger.info("Disconnecting clients ...")
- for client in list(self.__ws_clients):
- await self.__remove_ws_client(client)
-
+ await self._close_all_wss()
logger.info("On-Shutdown complete")
async def _on_cleanup(self) -> None:
@@ -333,25 +298,18 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
logger.exception("Cleanup error on %s", comp.name)
logger.info("On-Cleanup complete")
- async def __register_ws_client(self, client: _WsClient) -> None:
- async with self.__ws_clients_lock:
- self.__ws_clients.add(client)
- get_logger().info("Registered new client socket: %s; clients now: %d", client, len(self.__ws_clients))
+ async def _on_ws_opened(self) -> None:
await self.__streamer_notifier.notify()
- async def __remove_ws_client(self, client: _WsClient) -> None:
- async with self.__ws_clients_lock:
- self.__hid.clear_events()
- try:
- self.__ws_clients.remove(client)
- get_logger().info("Removed client socket: %s; clients now: %d", client, len(self.__ws_clients))
- await client.ws.close()
- except Exception:
- pass
+ async def _on_ws_closed(self) -> None:
+ self.__hid.clear_events()
await self.__streamer_notifier.notify()
def __has_stream_clients(self) -> bool:
- return bool(sum(map(operator.attrgetter("stream"), self.__ws_clients)))
+ return bool(sum(map(
+ (lambda ws: ws.kwargs["stream"]),
+ self._get_wss(),
+ )))
# ===== SYSTEM TASKS
@@ -379,10 +337,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None:
async for state in poller:
- await broadcast_ws_event([
- client.ws
- for client in list(self.__ws_clients)
- ], event_type, state)
+ await self._broadcast_ws_event(event_type, state)
async def __stream_snapshoter(self) -> None:
await self.__snapshoter.run(
diff --git a/kvmd/htserver.py b/kvmd/htserver.py
index c24837d2..d3919310 100644
--- a/kvmd/htserver.py
+++ b/kvmd/htserver.py
@@ -23,6 +23,7 @@
import os
import socket
import asyncio
+import contextlib
import dataclasses
import inspect
import json
@@ -31,7 +32,9 @@ from typing import Tuple
from typing import List
from typing import Dict
from typing import Callable
+from typing import AsyncGenerator
from typing import Optional
+from typing import Any
from aiohttp.web import BaseRequest
from aiohttp.web import Request
@@ -103,7 +106,7 @@ def exposed_http(http_method: str, path: str, auth_required: bool=True) -> Calla
return set_attrs
-def get_exposed_http(obj: object) -> List[HttpExposed]:
+def _get_exposed_http(obj: object) -> List[HttpExposed]:
return [
HttpExposed(
method=getattr(handler, _HTTP_METHOD),
@@ -135,7 +138,7 @@ def exposed_ws(event_type: str) -> Callable:
return set_attrs
-def get_exposed_ws(obj: object) -> List[WsExposed]:
+def _get_exposed_ws(obj: object) -> List[WsExposed]:
return [
WsExposed(
event_type=getattr(handler, _WS_EVENT_TYPE),
@@ -206,57 +209,6 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non
# =====
-async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
- await ws.send_str(json.dumps({
- "event_type": event_type,
- "event": event,
- }))
-
-
-async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None:
- if wss:
- await asyncio.gather(*[
- send_ws_event(ws, event_type, event)
- for ws in wss
- if (
- not ws.closed
- and ws._req is not None # pylint: disable=protected-access
- and ws._req.transport is not None # pylint: disable=protected-access
- )
- ], return_exceptions=True)
-
-
-def _parse_ws_event(msg: str) -> Tuple[str, Dict]:
- data = json.loads(msg)
- if not isinstance(data, dict):
- raise RuntimeError("Top-level event structure is not a dict")
- event_type = data.get("event_type")
- if not isinstance(event_type, str):
- raise RuntimeError("event_type must be a string")
- event = data["event"]
- if not isinstance(event, dict):
- raise RuntimeError("event must be a dict")
- return (event_type, event)
-
-
-async def process_ws_messages(ws: WebSocketResponse, handlers: Dict[str, Callable]) -> None:
- logger = get_logger(1)
- async for msg in ws:
- if msg.type != WSMsgType.TEXT:
- break
- try:
- (event_type, event) = _parse_ws_event(msg.data)
- except Exception as err:
- logger.error("Can't parse JSON event from websocket: %r", err)
- else:
- handler = handlers.get(event_type)
- if handler:
- await handler(ws, event)
- else:
- logger.error("Unknown websocket event: %r", msg.data)
-
-
-# =====
_REQUEST_AUTH_INFO = "_kvmd_auth_info"
@@ -272,7 +224,28 @@ def set_request_auth_info(request: BaseRequest, info: str) -> None:
# =====
[email protected](frozen=True)
+class WsSession:
+ wsr: WebSocketResponse
+ kwargs: Dict[str, Any]
+
+ def __str__(self) -> str:
+ return f"WsSession(id={id(self)}, {self.kwargs})"
+
+ async def send_event(self, event_type: str, event: Optional[Dict]) -> None:
+ await self.wsr.send_str(json.dumps({
+ "event_type": event_type,
+ "event": event,
+ }))
+
+
class HttpServer:
+ def __init__(self) -> None:
+ self.__ws_heartbeat: Optional[float] = None
+ self.__ws_handlers: Dict[str, Callable] = {}
+ self.__ws_sessions: List[WsSession] = []
+ self.__ws_sessions_lock = asyncio.Lock()
+
def run(
self,
unix_path: str,
@@ -282,7 +255,7 @@ class HttpServer:
access_log_format: str,
) -> None:
- self.__heartbeat = heartbeat # pylint: disable=attribute-defined-outside-init
+ self.__ws_heartbeat = heartbeat
if unix_rm and os.path.exists(unix_path):
os.remove(unix_path)
@@ -302,7 +275,14 @@ class HttpServer:
# =====
- def _add_exposed(self, exposed: HttpExposed) -> None:
+ def _add_exposed(self, *objs: object) -> None:
+ for obj in objs:
+ for http_exposed in _get_exposed_http(obj):
+ self.__add_exposed_http(http_exposed)
+ for ws_exposed in _get_exposed_ws(obj):
+ self.__add_exposed_ws(ws_exposed)
+
+ def __add_exposed_http(self, exposed: HttpExposed) -> None:
async def wrapper(request: Request) -> Response:
try:
await self._check_request_auth(exposed, request)
@@ -315,10 +295,85 @@ class HttpServer:
return make_json_exception(err)
self.__app.router.add_route(exposed.method, exposed.path, wrapper)
- async def _make_ws_response(self, request: Request) -> WebSocketResponse:
- ws = WebSocketResponse(heartbeat=self.__heartbeat)
- await ws.prepare(request)
- return ws
+ def __add_exposed_ws(self, exposed: WsExposed) -> None:
+ self.__ws_handlers[exposed.event_type] = exposed.handler
+
+ # =====
+
+ @contextlib.asynccontextmanager
+ async def _ws_session(self, request: Request, **kwargs: Any) -> AsyncGenerator[WsSession, None]:
+ assert self.__ws_heartbeat is not None
+ wsr = WebSocketResponse(heartbeat=self.__ws_heartbeat)
+ await wsr.prepare(request)
+ ws = WsSession(wsr, kwargs)
+
+ async with self.__ws_sessions_lock:
+ self.__ws_sessions.append(ws)
+ get_logger(2).info("Registered new client session: %s; clients now: %d", ws, len(self.__ws_sessions))
+
+ try:
+ await self._on_ws_opened()
+ yield ws
+ finally:
+ await self.__close_ws(ws)
+
+ async def _ws_loop(self, ws: WsSession) -> WebSocketResponse:
+ logger = get_logger()
+ async for msg in ws.wsr:
+ if msg.type != WSMsgType.TEXT:
+ break
+ try:
+ (event_type, event) = self.__parse_ws_event(msg.data)
+ except Exception as err:
+ logger.error("Can't parse JSON event from websocket: %r", err)
+ else:
+ handler = self.__ws_handlers.get(event_type)
+ if handler:
+ await handler(ws, event)
+ else:
+ logger.error("Unknown websocket event: %r", msg.data)
+ return ws.wsr
+
+ async def _broadcast_ws_event(self, event_type: str, event: Optional[Dict]) -> None:
+ if self.__ws_sessions:
+ await asyncio.gather(*[
+ ws.send_event(event_type, event)
+ for ws in self.__ws_sessions
+ if (
+ not ws.wsr.closed
+ and ws.wsr._req is not None # pylint: disable=protected-access
+ and ws.wsr._req.transport is not None # pylint: disable=protected-access
+ )
+ ], return_exceptions=True)
+
+ async def _close_all_wss(self) -> None:
+ for ws in self._get_wss():
+ await self.__close_ws(ws)
+
+ def _get_wss(self) -> List[WsSession]:
+ return list(self.__ws_sessions)
+
+ async def __close_ws(self, ws: WsSession) -> None:
+ async with self.__ws_sessions_lock:
+ try:
+ self.__ws_sessions.remove(ws)
+ get_logger(3).info("Removed client socket: %s; clients now: %d", ws, len(self.__ws_sessions))
+ await ws.wsr.close()
+ except Exception:
+ pass
+ await self._on_ws_closed()
+
+ def __parse_ws_event(self, msg: str) -> Tuple[str, Dict]:
+ data = json.loads(msg)
+ if not isinstance(data, dict):
+ raise RuntimeError("Top-level event structure is not a dict")
+ event_type = data.get("event_type")
+ if not isinstance(event_type, str):
+ raise RuntimeError("event_type must be a string")
+ event = data["event"]
+ if not isinstance(event, dict):
+ raise RuntimeError("event must be a dict")
+ return (event_type, event)
# =====
@@ -334,6 +389,12 @@ class HttpServer:
async def _on_cleanup(self) -> None:
pass
+ async def _on_ws_opened(self) -> None:
+ pass
+
+ async def _on_ws_closed(self) -> None:
+ pass
+
# =====
async def __make_app(self) -> Application: