diff options
author | Devaev Maxim <[email protected]> | 2020-06-11 09:09:25 +0300 |
---|---|---|
committer | Devaev Maxim <[email protected]> | 2020-06-11 09:09:25 +0300 |
commit | 595209c470b09eb4d97f6af1440e84e41f37191d (patch) | |
tree | 46420d25318959d77ed8b50721a78ba5d90e5525 /kvmd | |
parent | aaea8fef24325fd96897d33ee4f6ed9f0ce263b4 (diff) |
no-stream mode for /ws
Diffstat (limited to 'kvmd')
-rw-r--r-- | kvmd/apps/kvmd/server.py | 76 |
1 files changed, 48 insertions, 28 deletions
diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py index 5d027dcd..31fb8b88 100644 --- a/kvmd/apps/kvmd/server.py +++ b/kvmd/apps/kvmd/server.py @@ -23,6 +23,7 @@ import os import signal import asyncio +import operator import dataclasses import json @@ -51,6 +52,8 @@ from ...plugins.msd import BaseMsd from ...validators import ValidatorError +from ...validators.basic import valid_bool + from ...validators.kvm import valid_stream_quality from ...validators.kvm import valid_stream_fps @@ -106,6 +109,15 @@ class _Component: assert self.event_type, self [email protected](frozen=True) +class _WsClient: + ws: aiohttp.web.WebSocketResponse + stream: bool + + def __str__(self) -> str: + return f"WsClient(id={id(self)}, stream={self.stream})" + + class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( # pylint: disable=too-many-arguments self, @@ -157,8 +169,8 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins self.__ws_handlers: Dict[str, Callable] = {} - self.__sockets: Set[aiohttp.web.WebSocketResponse] = set() - self.__sockets_lock = asyncio.Lock() + self.__ws_clients: Set[_WsClient] = set() + self.__ws_clients_lock = asyncio.Lock() self.__system_tasks: List[asyncio.Task] = [] @@ -193,16 +205,19 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins @exposed_http("GET", "/ws") async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSocketResponse: logger = get_logger(0) - ws = aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat) - await ws.prepare(request) - await self.__register_socket(ws) + client = _WsClient( + ws=aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat), + stream=valid_bool(request.query.get("stream", "true")), + ) + await client.ws.prepare(request) + await self.__register_ws_client(client) try: await asyncio.gather(*[ self.__broadcast_event(component.event_type, await component.get_state()) for component in self.__components if component.get_state ]) - async for msg in ws: + async for msg in client.ws: if msg.type == aiohttp.web.WSMsgType.TEXT: try: data = json.loads(msg.data) @@ -213,14 +228,14 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins else: handler = self.__ws_handlers.get(event_type) if handler: - await handler(ws, event) + await handler(client.ws, event) else: logger.error("Unknown websocket event: %r", data) else: break - return ws + return client.ws finally: - await self.__remove_socket(ws) + await self.__remove_ws_client(client) @exposed_ws("ping") async def __ws_ping_handler(self, ws: aiohttp.web.WebSocketResponse, _: Dict) -> None: @@ -292,8 +307,8 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins await asyncio.gather(*self.__system_tasks, return_exceptions=True) logger.info("Disconnecting clients ...") - for ws in list(self.__sockets): - await self.__remove_socket(ws) + for client in list(self.__ws_clients): + await self.__remove_ws_client(client) async def __on_cleanup(self, _: aiohttp.web.Application) -> None: logger = get_logger(0) @@ -306,41 +321,46 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins logger.exception("Cleanup error on %s", component.name) async def __broadcast_event(self, event_type: str, event: Dict) -> None: - if self.__sockets: + if self.__ws_clients: await asyncio.gather(*[ - ws.send_str(json.dumps({ + client.ws.send_str(json.dumps({ "event_type": event_type, "event": event, })) - for ws in list(self.__sockets) - if not ws.closed and ws._req is not None and ws._req.transport is not None # pylint: disable=protected-access + for client in list(self.__ws_clients) + if ( + not client.ws.closed + and client.ws._req is not None # pylint: disable=protected-access + and client.ws._req.transport is not None # pylint: disable=protected-access + ) ], return_exceptions=True) - async def __register_socket(self, ws: aiohttp.web.WebSocketResponse) -> None: - async with self.__sockets_lock: - self.__sockets.add(ws) - remote: Optional[str] = (ws._req.remote if ws._req is not None else None) # pylint: disable=protected-access - get_logger().info("Registered new client socket: remote=%s; id=%d; active=%d", remote, id(ws), len(self.__sockets)) + async def __register_ws_client(self, client: _WsClient) -> None: + async with self.__ws_clients_lock: + self.__ws_clients.add(client) + get_logger().info("Registered new client socket: %s; clients now: %d", client, len(self.__ws_clients)) await self.__streamer_notifier.notify() - async def __remove_socket(self, ws: aiohttp.web.WebSocketResponse) -> None: - async with self.__sockets_lock: + async def __remove_ws_client(self, client: _WsClient) -> None: + async with self.__ws_clients_lock: self.__hid.clear_events() try: - self.__sockets.remove(ws) - remote: Optional[str] = (ws._req.remote if ws._req is not None else None) # pylint: disable=protected-access - get_logger().info("Removed client socket: remote=%s; id=%d; active=%d", remote, id(ws), len(self.__sockets)) - await ws.close() + self.__ws_clients.remove(client) + get_logger().info("Removed client socket: %s; clients now: %d", client, len(self.__ws_clients)) + await client.ws.close() except Exception: pass await self.__streamer_notifier.notify() + def __has_stream_clients(self) -> bool: + return bool(sum(map(operator.attrgetter("stream"), self.__ws_clients))) + # ===== SYSTEM TASKS async def __stream_controller(self) -> None: prev = False while True: - cur = (bool(self.__sockets) or self.__snapshoter.snapshoting()) + cur = (self.__has_stream_clients() or self.__snapshoter.snapshoting()) if not prev and cur: await self.__streamer.ensure_start(init_restart=True) elif prev and not cur: @@ -365,6 +385,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins async def __stream_snapshoter(self) -> None: await self.__snapshoter.run( - is_live=(lambda: bool(self.__sockets)), + is_live=self.__has_stream_clients, notifier=self.__streamer_notifier, ) |