diff options
author | Devaev Maxim <[email protected]> | 2020-05-24 03:00:29 +0300 |
---|---|---|
committer | Devaev Maxim <[email protected]> | 2020-05-24 03:00:29 +0300 |
commit | d61471d3a32037d9451b8e9aced117b15990a342 (patch) | |
tree | d336c2f73d0c874c293628252a990638f97bf149 /kvmd/clients/kvmd.py | |
parent | 564c67fdb7c21151d1f69d7108b9b77c1a51517a (diff) |
share ClientSession via KvmdClientSession
Diffstat (limited to 'kvmd/clients/kvmd.py')
-rw-r--r-- | kvmd/clients/kvmd.py | 220 |
1 files changed, 127 insertions, 93 deletions
diff --git a/kvmd/clients/kvmd.py b/kvmd/clients/kvmd.py index 5e44811b..fd46f7cb 100644 --- a/kvmd/clients/kvmd.py +++ b/kvmd/clients/kvmd.py @@ -21,11 +21,15 @@ import contextlib +import types from typing import Tuple from typing import Dict from typing import Set +from typing import Callable +from typing import Type from typing import AsyncGenerator +from typing import Optional import aiohttp @@ -33,106 +37,128 @@ from .. import aiotools # ===== -class _BaseClientPart: +class _BaseApiPart: def __init__( self, - host: str, - port: int, - unix_path: str, - timeout: float, - user_agent: str, + ensure_http_session: Callable[[], aiohttp.ClientSession], + make_url: Callable[[str], str], ) -> None: - assert port or unix_path - self.__host = host - self.__port = port - self.__unix_path = unix_path - self.__timeout = timeout - self.__user_agent = user_agent + self._ensure_http_session = ensure_http_session + self._make_url = make_url - def _make_session(self, user: str, passwd: str) -> aiohttp.ClientSession: - kwargs: Dict = { - "headers": { - "X-KVMD-User": user, - "X-KVMD-Passwd": passwd, - "User-Agent": self.__user_agent, - }, - "timeout": aiohttp.ClientTimeout(total=self.__timeout), - } - if self.__unix_path: - kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path) - return aiohttp.ClientSession(**kwargs) - - def _make_url(self, handle: str) -> str: - assert not handle.startswith("/"), handle - return f"http://{self.__host}:{self.__port}/{handle}" - -class _AuthClientPart(_BaseClientPart): - async def check(self, user: str, passwd: str) -> bool: +class _AuthApiPart(_BaseApiPart): + async def check(self) -> bool: + session = self._ensure_http_session() try: - async with self._make_session(user, passwd) as session: - async with session.get(self._make_url("auth/check")) as response: - aiotools.raise_not_200(response) - return True + async with session.get(self._make_url("auth/check")) as response: + aiotools.raise_not_200(response) + return True except aiohttp.ClientResponseError as err: if err.status in [401, 403]: return False raise -class _StreamerClientPart(_BaseClientPart): - async def set_params(self, user: str, passwd: str, quality: int, desired_fps: int) -> None: - async with self._make_session(user, passwd) as session: +class _StreamerApiPart(_BaseApiPart): + async def set_params(self, quality: int, desired_fps: int) -> None: + session = self._ensure_http_session() + async with session.post( + url=self._make_url("streamer/set_params"), + params={"quality": quality, "desired_fps": desired_fps}, + ) as response: + aiotools.raise_not_200(response) + + +class _HidApiPart(_BaseApiPart): + async def get_keymaps(self) -> Tuple[str, Set[str]]: + session = self._ensure_http_session() + async with session.get(self._make_url("hid/keymaps")) as response: + aiotools.raise_not_200(response) + result = (await response.json())["result"] + return (result["keymaps"]["default"], set(result["keymaps"]["available"])) + + async def print(self, text: str, limit: int, keymap_name: str) -> None: + session = self._ensure_http_session() + async with session.post( + url=self._make_url("hid/print"), + params={"limit": limit, "keymap": keymap_name}, + data=text, + ) as response: + aiotools.raise_not_200(response) + + +class _AtxApiPart(_BaseApiPart): + async def get_state(self) -> Dict: + session = self._ensure_http_session() + async with session.get(self._make_url("atx")) as response: + aiotools.raise_not_200(response) + return (await response.json())["result"] + + async def switch_power(self, action: str) -> bool: + session = self._ensure_http_session() + try: async with session.post( - url=self._make_url("streamer/set_params"), - params={"quality": quality, "desired_fps": desired_fps}, + url=self._make_url("atx/power"), + params={"action": action}, ) as response: aiotools.raise_not_200(response) + return True + except aiohttp.ClientResponseError as err: + if err.status == 409: + return False + raise -class _HidClientPart(_BaseClientPart): - async def get_keymaps(self, user: str, passwd: str) -> Tuple[str, Set[str]]: - async with self._make_session(user, passwd) as session: - async with session.get(self._make_url("hid/keymaps")) as response: - aiotools.raise_not_200(response) - result = (await response.json())["result"] - return (result["keymaps"]["default"], set(result["keymaps"]["available"])) +class KvmdClientSession: + def __init__( + self, + make_http_session: Callable[[], aiohttp.ClientSession], + make_url: Callable[[str], str], + ) -> None: - async def print(self, user: str, passwd: str, text: str, limit: int, keymap_name: str) -> None: - async with self._make_session(user, passwd) as session: - async with session.post( - url=self._make_url("hid/print"), - params={"limit": limit, "keymap": keymap_name}, - data=text, - ) as response: - aiotools.raise_not_200(response) + self.__make_http_session = make_http_session + self.__make_url = make_url + self.__http_session: Optional[aiohttp.ClientSession] = None -class _AtxClientPart(_BaseClientPart): - async def get_state(self, user: str, passwd: str) -> Dict: - async with self._make_session(user, passwd) as session: - async with session.get(self._make_url("atx")) as response: - aiotools.raise_not_200(response) - return (await response.json())["result"] + args = (self.__ensure_http_session, make_url) + self.auth = _AuthApiPart(*args) + self.streamer = _StreamerApiPart(*args) + self.hid = _HidApiPart(*args) + self.atx = _AtxApiPart(*args) - async def switch_power(self, user: str, passwd: str, action: str) -> bool: - try: - async with self._make_session(user, passwd) as session: - async with session.post( - url=self._make_url("atx/power"), - params={"action": action}, - ) as response: - aiotools.raise_not_200(response) - return True - except aiohttp.ClientResponseError as err: - if err.status == 409: - return False - raise + @contextlib.asynccontextmanager + async def ws(self) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]: + session = self.__ensure_http_session() + async with session.ws_connect(self.__make_url("ws")) as ws: + yield ws + def __ensure_http_session(self) -> aiohttp.ClientSession: + if not self.__http_session: + self.__http_session = self.__make_http_session() + return self.__http_session -# ===== -class KvmdClient(_BaseClientPart): + async def close(self) -> None: + if self.__http_session: + await self.__http_session.close() + self.__http_session = None + + async def __aenter__(self) -> "KvmdClientSession": + return self + + async def __aexit__( + self, + _exc_type: Type[BaseException], + _exc: BaseException, + _tb: types.TracebackType, + ) -> None: + + await self.close() + + +class KvmdClient: def __init__( self, host: str, @@ -142,23 +168,31 @@ class KvmdClient(_BaseClientPart): user_agent: str, ) -> None: - kwargs: Dict = { - "host": host, - "port": port, - "unix_path": unix_path, - "timeout": timeout, - "user_agent": user_agent, - } + self.__host = host + self.__port = port + self.__unix_path = unix_path + self.__timeout = timeout + self.__user_agent = user_agent - super().__init__(**kwargs) + def make_session(self, user: str, passwd: str) -> KvmdClientSession: + return KvmdClientSession( + make_http_session=(lambda: self.__make_http_session(user, passwd)), + make_url=self.__make_url, + ) - self.auth = _AuthClientPart(**kwargs) - self.streamer = _StreamerClientPart(**kwargs) - self.hid = _HidClientPart(**kwargs) - self.atx = _AtxClientPart(**kwargs) + def __make_http_session(self, user: str, passwd: str) -> aiohttp.ClientSession: + kwargs: Dict = { + "headers": { + "X-KVMD-User": user, + "X-KVMD-Passwd": passwd, + "User-Agent": self.__user_agent, + }, + "timeout": aiohttp.ClientTimeout(total=self.__timeout), + } + if self.__unix_path: + kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path) + return aiohttp.ClientSession(**kwargs) - @contextlib.asynccontextmanager - async def ws(self, user: str, passwd: str) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]: - async with self._make_session(user, passwd) as session: - async with session.ws_connect(self._make_url("ws")) as ws: - yield ws + def __make_url(self, handle: str) -> str: + assert not handle.startswith("/"), handle + return f"http://{self.__host}:{self.__port}/{handle}" |