diff options
author | Devaev Maxim <[email protected]> | 2020-05-24 03:00:29 +0300 |
---|---|---|
committer | Devaev Maxim <[email protected]> | 2020-05-24 03:00:29 +0300 |
commit | d61471d3a32037d9451b8e9aced117b15990a342 (patch) | |
tree | d336c2f73d0c874c293628252a990638f97bf149 /kvmd | |
parent | 564c67fdb7c21151d1f69d7108b9b77c1a51517a (diff) |
share ClientSession via KvmdClientSession
Diffstat (limited to 'kvmd')
-rw-r--r-- | kvmd/apps/ipmi/server.py | 15 | ||||
-rw-r--r-- | kvmd/apps/vnc/server.py | 38 | ||||
-rw-r--r-- | kvmd/clients/kvmd.py | 220 |
3 files changed, 156 insertions, 117 deletions
diff --git a/kvmd/apps/ipmi/server.py b/kvmd/apps/ipmi/server.py index 5f1a47dd..7ef99670 100644 --- a/kvmd/apps/ipmi/server.py +++ b/kvmd/apps/ipmi/server.py @@ -21,9 +21,9 @@ import asyncio +import functools from typing import Dict -from typing import Callable import aiohttp @@ -95,7 +95,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute session.send_ipmi_response(code=0xC1) def __get_chassis_status_handler(self, _: Dict, session: IpmiServerSession) -> None: - result = self.__make_request(session, "atx.get_state()", self.__kvmd.atx.get_state) + result = self.__make_request(session, "atx.get_state()", "atx.get_state") data = [int(result["leds"]["power"]), 0, 0] session.send_ipmi_response(data=data) @@ -107,7 +107,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute 5: "off", }.get(request["data"][0], "") if action: - if not self.__make_request(session, f"atx.switch_power({action})", self.__kvmd.atx.switch_power, action=action): + if not self.__make_request(session, f"atx.switch_power({action})", "atx.switch_power", action=action): code = 0xC0 # Try again later else: code = 0 @@ -117,19 +117,18 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute # ===== - def __make_request(self, session: IpmiServerSession, name: str, method: Callable, **kwargs): # type: ignore + def __make_request(self, session: IpmiServerSession, name: str, method_path: str, **kwargs): # type: ignore async def runner(): # type: ignore logger = get_logger(0) credentials = self.__auth_manager.get_credentials(session.username.decode()) logger.info("Performing request %s from user %r (IPMI) as %r (KVMD)", name, credentials.ipmi_user, credentials.kvmd_user) try: - return (await method(credentials.kvmd_user, credentials.kvmd_passwd, **kwargs)) + async with self.__kvmd.make_session(credentials.kvmd_user, credentials.kvmd_passwd) as kvmd_session: + method = functools.reduce(getattr, method_path.split("."), kvmd_session) + return (await method(**kwargs)) except (aiohttp.ClientError, asyncio.TimeoutError) as err: logger.error("Can't perform request %s: %s", name, str(err)) raise - except Exception: - logger.exception("Unexpected exception while performing request %s", name) - raise return aiotools.run_sync(runner()) diff --git a/kvmd/apps/vnc/server.py b/kvmd/apps/vnc/server.py index 189332c7..270072db 100644 --- a/kvmd/apps/vnc/server.py +++ b/kvmd/apps/vnc/server.py @@ -38,6 +38,7 @@ from ...logging import get_logger from ...keyboard.keysym import SymmapWebKey from ...keyboard.keysym import build_symmap +from ...clients.kvmd import KvmdClientSession from ...clients.kvmd import KvmdClient from ...clients.streamer import StreamerError @@ -105,6 +106,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes self.__shared_params = shared_params + self.__kvmd_session: Optional[KvmdClientSession] = None self.__authorized = asyncio.Future() # type: ignore self.__ws_connected = asyncio.Future() # type: ignore self.__ws_writer_queue: asyncio.queues.Queue = asyncio.Queue() @@ -123,10 +125,14 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes # ===== async def run(self) -> None: - await self._run( - kvmd=self.__kvmd_task_loop(), - streamer=self.__streamer_task_loop(), - ) + try: + await self._run( + kvmd=self.__kvmd_task_loop(), + streamer=self.__streamer_task_loop(), + ) + finally: + if self.__kvmd_session: + await self.__kvmd_session.close() # ===== @@ -134,9 +140,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes logger = get_logger(0) await self.__authorized - (user, passwd) = self.__authorized.result() + assert self.__kvmd_session - async with self.__kvmd.ws(user, passwd) as ws: + async with self.__kvmd_session.ws() as ws: logger.info("[kvmd] Client %s: Connected to KVMD websocket", self._remote) self.__ws_connected.set_result(None) @@ -238,8 +244,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes # ===== async def _authorize_userpass(self, user: str, passwd: str) -> bool: - if (await self.__kvmd.auth.check(user, passwd)): - self.__authorized.set_result((user, passwd)) + self.__kvmd_session = self.__kvmd.make_session(user, passwd) + if (await self.__kvmd_session.auth.check()): + self.__authorized.set_result(None) return True return False @@ -285,14 +292,12 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes async def _on_cut_event(self, text: str) -> None: assert self.__authorized.done() - (user, passwd) = self.__authorized.result() + assert self.__kvmd_session logger = get_logger(0) logger.info("[main] Client %s: Printing %d characters ...", self._remote, len(text)) try: - (default, available) = await self.__kvmd.hid.get_keymaps(user, passwd) - await self.__kvmd.hid.print( - user=user, - passwd=passwd, + (default, available) = await self.__kvmd_session.hid.get_keymaps() + await self.__kvmd_session.hid.print( text=text, limit=0, keymap_name=(self.__keymap_name if self.__keymap_name in available else default), @@ -302,10 +307,10 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes async def _on_set_encodings(self) -> None: assert self.__authorized.done() - (user, passwd) = self.__authorized.result() + assert self.__kvmd_session get_logger(0).info("[main] Client %s: Applying streamer params: quality=%d%%; desired_fps=%d ...", self._remote, self._encodings.tight_jpeg_quality, self.__desired_fps) - await self.__kvmd.streamer.set_params(user, passwd, self._encodings.tight_jpeg_quality, self.__desired_fps) + await self.__kvmd_session.streamer.set_params(self._encodings.tight_jpeg_quality, self.__desired_fps) async def _on_fb_update_request(self) -> None: async with self.__lock: @@ -348,7 +353,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes logger.info("Preparing client %s ...", remote) try: try: - none_auth_only = await kvmd.auth.check("", "") + async with kvmd.make_session("", "") as kvmd_session: + none_auth_only = await kvmd_session.auth.check() except (aiohttp.ClientError, asyncio.TimeoutError) as err: logger.error("Client %s: Can't check KVMD auth mode: %s: %s", remote, type(err).__name__, err) return diff --git a/kvmd/clients/kvmd.py b/kvmd/clients/kvmd.py index 5e44811b..fd46f7cb 100644 --- a/kvmd/clients/kvmd.py +++ b/kvmd/clients/kvmd.py @@ -21,11 +21,15 @@ import contextlib +import types from typing import Tuple from typing import Dict from typing import Set +from typing import Callable +from typing import Type from typing import AsyncGenerator +from typing import Optional import aiohttp @@ -33,106 +37,128 @@ from .. import aiotools # ===== -class _BaseClientPart: +class _BaseApiPart: def __init__( self, - host: str, - port: int, - unix_path: str, - timeout: float, - user_agent: str, + ensure_http_session: Callable[[], aiohttp.ClientSession], + make_url: Callable[[str], str], ) -> None: - assert port or unix_path - self.__host = host - self.__port = port - self.__unix_path = unix_path - self.__timeout = timeout - self.__user_agent = user_agent + self._ensure_http_session = ensure_http_session + self._make_url = make_url - def _make_session(self, user: str, passwd: str) -> aiohttp.ClientSession: - kwargs: Dict = { - "headers": { - "X-KVMD-User": user, - "X-KVMD-Passwd": passwd, - "User-Agent": self.__user_agent, - }, - "timeout": aiohttp.ClientTimeout(total=self.__timeout), - } - if self.__unix_path: - kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path) - return aiohttp.ClientSession(**kwargs) - - def _make_url(self, handle: str) -> str: - assert not handle.startswith("/"), handle - return f"http://{self.__host}:{self.__port}/{handle}" - -class _AuthClientPart(_BaseClientPart): - async def check(self, user: str, passwd: str) -> bool: +class _AuthApiPart(_BaseApiPart): + async def check(self) -> bool: + session = self._ensure_http_session() try: - async with self._make_session(user, passwd) as session: - async with session.get(self._make_url("auth/check")) as response: - aiotools.raise_not_200(response) - return True + async with session.get(self._make_url("auth/check")) as response: + aiotools.raise_not_200(response) + return True except aiohttp.ClientResponseError as err: if err.status in [401, 403]: return False raise -class _StreamerClientPart(_BaseClientPart): - async def set_params(self, user: str, passwd: str, quality: int, desired_fps: int) -> None: - async with self._make_session(user, passwd) as session: +class _StreamerApiPart(_BaseApiPart): + async def set_params(self, quality: int, desired_fps: int) -> None: + session = self._ensure_http_session() + async with session.post( + url=self._make_url("streamer/set_params"), + params={"quality": quality, "desired_fps": desired_fps}, + ) as response: + aiotools.raise_not_200(response) + + +class _HidApiPart(_BaseApiPart): + async def get_keymaps(self) -> Tuple[str, Set[str]]: + session = self._ensure_http_session() + async with session.get(self._make_url("hid/keymaps")) as response: + aiotools.raise_not_200(response) + result = (await response.json())["result"] + return (result["keymaps"]["default"], set(result["keymaps"]["available"])) + + async def print(self, text: str, limit: int, keymap_name: str) -> None: + session = self._ensure_http_session() + async with session.post( + url=self._make_url("hid/print"), + params={"limit": limit, "keymap": keymap_name}, + data=text, + ) as response: + aiotools.raise_not_200(response) + + +class _AtxApiPart(_BaseApiPart): + async def get_state(self) -> Dict: + session = self._ensure_http_session() + async with session.get(self._make_url("atx")) as response: + aiotools.raise_not_200(response) + return (await response.json())["result"] + + async def switch_power(self, action: str) -> bool: + session = self._ensure_http_session() + try: async with session.post( - url=self._make_url("streamer/set_params"), - params={"quality": quality, "desired_fps": desired_fps}, + url=self._make_url("atx/power"), + params={"action": action}, ) as response: aiotools.raise_not_200(response) + return True + except aiohttp.ClientResponseError as err: + if err.status == 409: + return False + raise -class _HidClientPart(_BaseClientPart): - async def get_keymaps(self, user: str, passwd: str) -> Tuple[str, Set[str]]: - async with self._make_session(user, passwd) as session: - async with session.get(self._make_url("hid/keymaps")) as response: - aiotools.raise_not_200(response) - result = (await response.json())["result"] - return (result["keymaps"]["default"], set(result["keymaps"]["available"])) +class KvmdClientSession: + def __init__( + self, + make_http_session: Callable[[], aiohttp.ClientSession], + make_url: Callable[[str], str], + ) -> None: - async def print(self, user: str, passwd: str, text: str, limit: int, keymap_name: str) -> None: - async with self._make_session(user, passwd) as session: - async with session.post( - url=self._make_url("hid/print"), - params={"limit": limit, "keymap": keymap_name}, - data=text, - ) as response: - aiotools.raise_not_200(response) + self.__make_http_session = make_http_session + self.__make_url = make_url + self.__http_session: Optional[aiohttp.ClientSession] = None -class _AtxClientPart(_BaseClientPart): - async def get_state(self, user: str, passwd: str) -> Dict: - async with self._make_session(user, passwd) as session: - async with session.get(self._make_url("atx")) as response: - aiotools.raise_not_200(response) - return (await response.json())["result"] + args = (self.__ensure_http_session, make_url) + self.auth = _AuthApiPart(*args) + self.streamer = _StreamerApiPart(*args) + self.hid = _HidApiPart(*args) + self.atx = _AtxApiPart(*args) - async def switch_power(self, user: str, passwd: str, action: str) -> bool: - try: - async with self._make_session(user, passwd) as session: - async with session.post( - url=self._make_url("atx/power"), - params={"action": action}, - ) as response: - aiotools.raise_not_200(response) - return True - except aiohttp.ClientResponseError as err: - if err.status == 409: - return False - raise + @contextlib.asynccontextmanager + async def ws(self) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]: + session = self.__ensure_http_session() + async with session.ws_connect(self.__make_url("ws")) as ws: + yield ws + def __ensure_http_session(self) -> aiohttp.ClientSession: + if not self.__http_session: + self.__http_session = self.__make_http_session() + return self.__http_session -# ===== -class KvmdClient(_BaseClientPart): + async def close(self) -> None: + if self.__http_session: + await self.__http_session.close() + self.__http_session = None + + async def __aenter__(self) -> "KvmdClientSession": + return self + + async def __aexit__( + self, + _exc_type: Type[BaseException], + _exc: BaseException, + _tb: types.TracebackType, + ) -> None: + + await self.close() + + +class KvmdClient: def __init__( self, host: str, @@ -142,23 +168,31 @@ class KvmdClient(_BaseClientPart): user_agent: str, ) -> None: - kwargs: Dict = { - "host": host, - "port": port, - "unix_path": unix_path, - "timeout": timeout, - "user_agent": user_agent, - } + self.__host = host + self.__port = port + self.__unix_path = unix_path + self.__timeout = timeout + self.__user_agent = user_agent - super().__init__(**kwargs) + def make_session(self, user: str, passwd: str) -> KvmdClientSession: + return KvmdClientSession( + make_http_session=(lambda: self.__make_http_session(user, passwd)), + make_url=self.__make_url, + ) - self.auth = _AuthClientPart(**kwargs) - self.streamer = _StreamerClientPart(**kwargs) - self.hid = _HidClientPart(**kwargs) - self.atx = _AtxClientPart(**kwargs) + def __make_http_session(self, user: str, passwd: str) -> aiohttp.ClientSession: + kwargs: Dict = { + "headers": { + "X-KVMD-User": user, + "X-KVMD-Passwd": passwd, + "User-Agent": self.__user_agent, + }, + "timeout": aiohttp.ClientTimeout(total=self.__timeout), + } + if self.__unix_path: + kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path) + return aiohttp.ClientSession(**kwargs) - @contextlib.asynccontextmanager - async def ws(self, user: str, passwd: str) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]: - async with self._make_session(user, passwd) as session: - async with session.ws_connect(self._make_url("ws")) as ws: - yield ws + def __make_url(self, handle: str) -> str: + assert not handle.startswith("/"), handle + return f"http://{self.__host}:{self.__port}/{handle}" |