diff options
Diffstat (limited to 'kvmd/server.py')
-rw-r--r-- | kvmd/server.py | 424 |
1 files changed, 424 insertions, 0 deletions
diff --git a/kvmd/server.py b/kvmd/server.py new file mode 100644 index 00000000..19282fc2 --- /dev/null +++ b/kvmd/server.py @@ -0,0 +1,424 @@ +import os +import signal +import asyncio +import platform +import functools +import json +import time + +from typing import List +from typing import Dict +from typing import Set +from typing import Callable +from typing import Optional + +import aiohttp.web +import setproctitle + +from .aioregion import RegionIsBusyError + +from .hid import Hid + +from .atx import Atx + +from .msd import MsdOperationError +from .msd import MassStorageDevice + +from .streamer import Streamer + +from .logging import get_logger + + +# ===== +__version__ = "0.66" + + [email protected]_cache() +def _get_system_info() -> Dict[str, Dict[str, str]]: + return { + "version": { + "platform": platform.platform(), + "python": platform.python_version(), + "kvmd": __version__, + }, + } + + +# ===== +def _system_task(method: Callable) -> Callable: + async def wrap(self: "Server") -> None: + try: + await method(self) + except asyncio.CancelledError: + pass + except Exception: + get_logger().exception("Unhandled exception, killing myself ...") + os.kill(os.getpid(), signal.SIGTERM) + return wrap + + +def _json(result: Optional[Dict]=None, status: int=200) -> aiohttp.web.Response: + return aiohttp.web.Response( + text=json.dumps({ + "ok": (True if status == 200 else False), + "result": (result or {}), + }, sort_keys=True, indent=4), + status=status, + content_type="application/json", + ) + + +def _json_exception(msg: str, err: Exception, status: int) -> aiohttp.web.Response: + msg = "%s: %s" % (msg, err) + get_logger().error(msg) + return _json({ + "error": type(err).__name__, + "error_msg": msg, + }, status=status) + + +class BadRequest(Exception): + pass + + +def _wrap_exceptions_for_web(msg: str) -> Callable: + def make_wrapper(method: Callable) -> Callable: + async def wrap(self: "Server", request: aiohttp.web.Request) -> aiohttp.web.Response: + try: + return (await method(self, request)) + except RegionIsBusyError as err: + return _json_exception(msg, err, 409) + except (BadRequest, MsdOperationError) as err: + return _json_exception(msg, err, 400) + return wrap + return make_wrapper + + +class Server: # pylint: disable=too-many-instance-attributes + def __init__( + self, + hid: Hid, + atx: Atx, + msd: MassStorageDevice, + streamer: Streamer, + + heartbeat: float, + atx_state_poll: float, + streamer_shutdown_delay: float, + msd_chunk_size: int, + + loop: asyncio.AbstractEventLoop, + ) -> None: + + self.__hid = hid + self.__atx = atx + self.__msd = msd + self.__streamer = streamer + + self.__heartbeat = heartbeat + self.__streamer_shutdown_delay = streamer_shutdown_delay + self.__atx_state_poll = atx_state_poll + self.__msd_chunk_size = msd_chunk_size + + self.__loop = loop + + self.__sockets: Set[aiohttp.web.WebSocketResponse] = set() + self.__sockets_lock = asyncio.Lock() + + self.__system_tasks: List[asyncio.Task] = [] + + self.__reset_streamer = False + self.__streamer_quality = streamer.get_current_quality() + + def run(self, host: str, port: int) -> None: + self.__hid.start() + + setproctitle.setproctitle("[main] " + setproctitle.getproctitle()) + + app = aiohttp.web.Application(loop=self.__loop) + + app.router.add_get("/info", self.__info_handler) + + app.router.add_get("/ws", self.__ws_handler) + + app.router.add_get("/atx", self.__atx_state_handler) + app.router.add_post("/atx/click", self.__atx_click_handler) + + app.router.add_get("/msd", self.__msd_state_handler) + app.router.add_post("/msd/connect", self.__msd_connect_handler) + app.router.add_post("/msd/write", self.__msd_write_handler) + + app.router.add_get("/streamer", self.__streamer_state_handler) + app.router.add_post("/streamer/set_params", self.__streamer_set_params_handler) + app.router.add_post("/streamer/reset", self.__streamer_reset_handler) + + app.on_shutdown.append(self.__on_shutdown) + app.on_cleanup.append(self.__on_cleanup) + + self.__system_tasks.extend([ + self.__loop.create_task(self.__hid_watchdog()), + self.__loop.create_task(self.__stream_controller()), + self.__loop.create_task(self.__poll_dead_sockets()), + self.__loop.create_task(self.__poll_atx_state()), + ]) + + aiohttp.web.run_app(app, host=host, port=port, print=self.__run_app_print) + + # ===== INFO + + async def __info_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response: + return _json(_get_system_info()) + + # ===== WEBSOCKET + + async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSocketResponse: + logger = get_logger(0) + ws = aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat) + await ws.prepare(request) + await self.__register_socket(ws) + async for msg in ws: + if msg.type == aiohttp.web.WSMsgType.TEXT: + try: + event = json.loads(msg.data) + except Exception as err: + logger.error("Can't parse JSON event from websocket: %s", err) + else: + event_type = event.get("event_type") + if event_type == "ping": + await ws.send_str(json.dumps({"msg_type": "pong"})) + elif event_type == "key": + await self.__handle_ws_key_event(event) + elif event_type == "mouse_move": + await self.__handle_ws_mouse_move_event(event) + elif event_type == "mouse_button": + await self.__handle_ws_mouse_button_event(event) + elif event_type == "mouse_wheel": + await self.__handle_ws_mouse_wheel_event(event) + else: + logger.error("Unknown websocket event: %r", event) + else: + break + return ws + + async def __handle_ws_key_event(self, event: Dict) -> None: + key = str(event.get("key", ""))[:64].strip() + state = event.get("state") + if key and state in [True, False]: + await self.__hid.send_key_event(key, state) + + async def __handle_ws_mouse_move_event(self, event: Dict) -> None: + try: + to_x = int(event["to"]["x"]) + to_y = int(event["to"]["y"]) + except Exception: + return + await self.__hid.send_mouse_move_event(to_x, to_y) + + async def __handle_ws_mouse_button_event(self, event: Dict) -> None: + button = str(event.get("button", ""))[:64].strip() + state = event.get("state") + if button and state in [True, False]: + await self.__hid.send_mouse_button_event(button, state) + + async def __handle_ws_mouse_wheel_event(self, event: Dict) -> None: + try: + delta_y = int(event["delta"]["y"]) + except Exception: + return + await self.__hid.send_mouse_wheel_event(delta_y) + + # ===== ATX + + async def __atx_state_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response: + return _json(self.__atx.get_state()) + + @_wrap_exceptions_for_web("Click error") + async def __atx_click_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response: + button = request.query.get("button") + clicker = { + "power": self.__atx.click_power, + "power_long": self.__atx.click_power_long, + "reset": self.__atx.click_reset, + }.get(button) + if not clicker: + raise BadRequest("Missing or invalid 'button=%s'" % (button)) + await self.__broadcast_event("atx_click", button=button) # type: ignore + await clicker() + await self.__broadcast_event("atx_click", button=None) # type: ignore + return _json({"clicked": button}) + + # ===== MSD + + async def __msd_state_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response: + return _json(self.__msd.get_state()) + + @_wrap_exceptions_for_web("Mass-storage error") + async def __msd_connect_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response: + to = request.query.get("to") + if to == "kvm": + await self.__msd.connect_to_kvm() + state = self.__msd.get_state() + await self.__broadcast_event("msd_state", **state) + elif to == "server": + await self.__msd.connect_to_pc() + state = self.__msd.get_state() + await self.__broadcast_event("msd_state", **state) + else: + raise BadRequest("Missing or invalid 'to=%s'" % (to)) + return _json(state) + + @_wrap_exceptions_for_web("Can't write data to mass-storage device") + async def __msd_write_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response: + logger = get_logger(0) + reader = await request.multipart() + written = 0 + try: + field = await reader.next() + if not field or field.name != "image_name": + raise BadRequest("Missing 'image_name' field") + image_name = (await field.read()).decode("utf-8")[:256] + + field = await reader.next() + if not field or field.name != "image_data": + raise BadRequest("Missing 'image_data' field") + + async with self.__msd: + await self.__broadcast_event("msd_state", **self.__msd.get_state()) + logger.info("Writing image %r to mass-storage device ...", image_name) + await self.__msd.write_image_info(image_name, False) + while True: + chunk = await field.read_chunk(self.__msd_chunk_size) + if not chunk: + break + written = await self.__msd.write_image_chunk(chunk) + await self.__msd.write_image_info(image_name, True) + finally: + await self.__broadcast_event("msd_state", **self.__msd.get_state()) + if written != 0: + logger.info("Written %d bytes to mass-storage device", written) + return _json({"written": written}) + + # ===== STREAMER + + async def __streamer_state_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response: + return _json(self.__streamer.get_state()) + + @_wrap_exceptions_for_web("Can't set stream params") + async def __streamer_set_params_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response: + quality = request.query.get("quality") + if quality: + try: + quality_int = int(quality) + if not (1 <= quality_int <= 100): + raise ValueError() + except Exception: + raise BadRequest("Invalid quality %r" % (quality)) + self.__streamer_quality = quality_int + return _json() + + async def __streamer_reset_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response: + self.__reset_streamer = True + return _json() + + # ===== + + def __run_app_print(self, text: str) -> None: + logger = get_logger() + for line in text.strip().splitlines(): + logger.info(line.strip()) + + async def __on_shutdown(self, _: aiohttp.web.Application) -> None: + logger = get_logger(0) + + logger.info("Cancelling system tasks ...") + for task in self.__system_tasks: + task.cancel() + await asyncio.gather(*self.__system_tasks) + + logger.info("Disconnecting clients ...") + for ws in list(self.__sockets): + await self.__remove_socket(ws) + + async def __on_cleanup(self, _: aiohttp.web.Application) -> None: + await self.__hid.cleanup() + await self.__streamer.cleanup() + await self.__msd.cleanup() + + @_system_task + async def __hid_watchdog(self) -> None: + while self.__hid.is_alive(): + await asyncio.sleep(0.1) + raise RuntimeError("HID is dead") + + @_system_task + async def __stream_controller(self) -> None: + prev = 0 + shutdown_at = 0.0 + + while True: + cur = len(self.__sockets) + if prev == 0 and cur > 0: + if not self.__streamer.is_running(): + await self.__streamer.start(self.__streamer_quality) + await self.__broadcast_event("streamer_state", **self.__streamer.get_state()) + elif prev > 0 and cur == 0: + shutdown_at = time.time() + self.__streamer_shutdown_delay + elif prev == 0 and cur == 0 and time.time() > shutdown_at: + if self.__streamer.is_running(): + await self.__streamer.stop() + await self.__broadcast_event("streamer_state", **self.__streamer.get_state()) + + if self.__reset_streamer or self.__streamer_quality != self.__streamer.get_current_quality(): + if self.__streamer.is_running(): + await self.__streamer.stop() + await self.__streamer.start(self.__streamer_quality, no_init_restart=True) + await self.__broadcast_event("streamer_state", **self.__streamer.get_state()) + self.__reset_streamer = False + + prev = cur + await asyncio.sleep(0.1) + + @_system_task + async def __poll_dead_sockets(self) -> None: + while True: + for ws in list(self.__sockets): + if ws.closed or not ws._req.transport: # pylint: disable=protected-access + await self.__remove_socket(ws) + await asyncio.sleep(0.1) + + @_system_task + async def __poll_atx_state(self) -> None: + while True: + if self.__sockets: + await self.__broadcast_event("atx_state", **self.__atx.get_state()) + await asyncio.sleep(self.__atx_state_poll) + + async def __broadcast_event(self, event: str, **kwargs: Dict) -> None: + await asyncio.gather(*[ + ws.send_str(json.dumps({ + "msg_type": "event", + "msg": { + "event": event, + "event_attrs": kwargs, + }, + })) + for ws in list(self.__sockets) + if not ws.closed and ws._req.transport # pylint: disable=protected-access + ], return_exceptions=True) + + async def __register_socket(self, ws: aiohttp.web.WebSocketResponse) -> None: + async with self.__sockets_lock: + self.__sockets.add(ws) + get_logger().info("Registered new client socket: remote=%s; id=%d; active=%d", + ws._req.remote, id(ws), len(self.__sockets)) # pylint: disable=protected-access + + async def __remove_socket(self, ws: aiohttp.web.WebSocketResponse) -> None: + async with self.__sockets_lock: + await self.__hid.clear_events() + try: + self.__sockets.remove(ws) + get_logger().info("Removed client socket: remote=%s; id=%d; active=%d", + ws._req.remote, id(ws), len(self.__sockets)) # pylint: disable=protected-access + await ws.close() + except Exception: + pass |