diff options
author | Maxim Devaev <[email protected]> | 2022-04-06 00:39:16 +0300 |
---|---|---|
committer | Maxim Devaev <[email protected]> | 2022-04-06 00:55:20 +0300 |
commit | 6f6772a6b62ecaaa166197f097f82d81c6fb9422 (patch) | |
tree | b0e1a2d8e316e8e7abe9b51f702ac7d8d27aedfa /kvmd/htserver.py | |
parent | 8ce08fb4567ed7f3cfad85da4ae1e123eef42024 (diff) |
refactoring
Diffstat (limited to 'kvmd/htserver.py')
-rw-r--r-- | kvmd/htserver.py | 264 |
1 files changed, 264 insertions, 0 deletions
diff --git a/kvmd/htserver.py b/kvmd/htserver.py new file mode 100644 index 00000000..f1b51b38 --- /dev/null +++ b/kvmd/htserver.py @@ -0,0 +1,264 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2022 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import os +import socket +import asyncio +import dataclasses +import inspect +import json + +from typing import List +from typing import Dict +from typing import Callable +from typing import Optional + +from aiohttp.web import BaseRequest +from aiohttp.web import Request +from aiohttp.web import Response +from aiohttp.web import StreamResponse +from aiohttp.web import Application +from aiohttp.web import run_app +from aiohttp.web import normalize_path_middleware + +try: + from aiohttp.web import AccessLogger # type: ignore +except ImportError: + from aiohttp.helpers import AccessLogger # type: ignore + +from .logging import get_logger + + +# ===== +class HttpError(Exception): + def __init__(self, msg: str, status: int) -> None: + super().__init__(msg) + self.status = status + + +class UnauthorizedError(HttpError): + def __init__(self) -> None: + super().__init__("Unauthorized", 401) + + +class ForbiddenError(HttpError): + def __init__(self) -> None: + super().__init__("Forbidden", 403) + + +class UnavailableError(HttpError): + def __init__(self) -> None: + super().__init__("Service Unavailable", 503) + + +# ===== [email protected](frozen=True) +class HttpExposed: + method: str + path: str + auth_required: bool + handler: Callable + + +_HTTP_EXPOSED = "_http_exposed" +_HTTP_METHOD = "_http_method" +_HTTP_PATH = "_http_path" +_HTTP_AUTH_REQUIRED = "_http_auth_required" + + +def exposed_http(http_method: str, path: str, auth_required: bool=True) -> Callable: + def set_attrs(handler: Callable) -> Callable: + setattr(handler, _HTTP_EXPOSED, True) + setattr(handler, _HTTP_METHOD, http_method) + setattr(handler, _HTTP_PATH, path) + setattr(handler, _HTTP_AUTH_REQUIRED, auth_required) + return handler + return set_attrs + + +def get_exposed_http(obj: object) -> List[HttpExposed]: + return [ + HttpExposed( + method=getattr(handler, _HTTP_METHOD), + path=getattr(handler, _HTTP_PATH), + auth_required=getattr(handler, _HTTP_AUTH_REQUIRED), + handler=handler, + ) + for handler in [getattr(obj, name) for name in dir(obj)] + if inspect.ismethod(handler) and getattr(handler, _HTTP_EXPOSED, False) + ] + + +# ===== [email protected](frozen=True) +class WsExposed: + event_type: str + handler: Callable + + +_WS_EXPOSED = "_ws_exposed" +_WS_EVENT_TYPE = "_ws_event_type" + + +def exposed_ws(event_type: str) -> Callable: + def set_attrs(handler: Callable) -> Callable: + setattr(handler, _WS_EXPOSED, True) + setattr(handler, _WS_EVENT_TYPE, event_type) + return handler + return set_attrs + + +def get_exposed_ws(obj: object) -> List[WsExposed]: + return [ + WsExposed( + event_type=getattr(handler, _WS_EVENT_TYPE), + handler=handler, + ) + for handler in [getattr(obj, name) for name in dir(obj)] + if inspect.ismethod(handler) and getattr(handler, _WS_EXPOSED, False) + ] + + +# ===== +def make_json_response( + result: Optional[Dict]=None, + status: int=200, + set_cookies: Optional[Dict[str, str]]=None, + wrap_result: bool=True, +) -> Response: + + response = Response( + text=json.dumps(({ + "ok": (status == 200), + "result": (result or {}), + } if wrap_result else result), sort_keys=True, indent=4), + status=status, + content_type="application/json", + ) + if set_cookies: + for (key, value) in set_cookies.items(): + response.set_cookie(key, value) + return response + + +def make_json_exception(err: Exception, status: Optional[int]=None) -> Response: + name = type(err).__name__ + msg = str(err) + if isinstance(err, HttpError): + status = err.status + else: + get_logger().error("API error: %s: %s", name, msg) + assert status is not None, err + return make_json_response({ + "error": name, + "error_msg": msg, + }, status=status) + + +async def start_streaming(request: Request, content_type: str="application/x-ndjson") -> StreamResponse: + response = StreamResponse(status=200, reason="OK", headers={"Content-Type": content_type}) + await response.prepare(request) + return response + + +async def stream_json(response: StreamResponse, result: Dict, ok: bool=True) -> None: + await response.write(json.dumps({ + "ok": ok, + "result": result, + }).encode("utf-8") + b"\r\n") + + +async def stream_json_exception(response: StreamResponse, err: Exception) -> None: + name = type(err).__name__ + msg = str(err) + get_logger().error("API error: %s: %s", name, msg) + await stream_json(response, { + "error": name, + "error_msg": msg, + }, False) + + +# ===== +_REQUEST_AUTH_INFO = "_kvmd_auth_info" + + +def _format_P(request: BaseRequest, *_, **__) -> str: # type: ignore # pylint: disable=invalid-name + return (getattr(request, _REQUEST_AUTH_INFO, None) or "-") + + +AccessLogger._format_P = staticmethod(_format_P) # type: ignore # pylint: disable=protected-access + + +def set_request_auth_info(request: BaseRequest, info: str) -> None: + setattr(request, _REQUEST_AUTH_INFO, info) + + +# ===== +class HttpServer: + def run( + self, + unix_path: str, + unix_rm: bool, + unix_mode: int, + access_log_format: str, + ) -> None: + + if unix_rm and os.path.exists(unix_path): + os.remove(unix_path) + server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_socket.bind(unix_path) + if unix_mode: + os.chmod(unix_path, unix_mode) + + run_app( + sock=server_socket, + app=self.__make_app(), + shutdown_timeout=1, + access_log_format=access_log_format, + print=self.__run_app_print, + loop=asyncio.get_event_loop(), + ) + + async def _init_app(self, app: Application) -> None: + raise NotImplementedError + + async def _on_shutdown(self, app: Application) -> None: + _ = app + + async def _on_cleanup(self, app: Application) -> None: + _ = app + + async def __make_app(self) -> Application: + app = Application(middlewares=[normalize_path_middleware( + append_slash=False, + remove_slash=True, + merge_slashes=True, + )]) + app.on_shutdown.append(self._on_shutdown) + app.on_cleanup.append(self._on_cleanup) + await self._init_app(app) + return app + + def __run_app_print(self, text: str) -> None: + logger = get_logger(0) + for line in text.strip().splitlines(): + logger.info(line.strip()) |