diff options
author | Devaev Maxim <[email protected]> | 2021-01-28 20:36:46 +0300 |
---|---|---|
committer | Devaev Maxim <[email protected]> | 2021-01-28 20:36:46 +0300 |
commit | 0538a6828f67f3879b037fbb552fa4a65231d403 (patch) | |
tree | 67b3432a516f27cd33f2cad54dcd78eb242452f9 | |
parent | 1442515e5ca27f48d29d0c9d864823ced5e893ea (diff) |
refactoring
-rw-r--r-- | kvmd/aiotools.py | 24 | ||||
-rw-r--r-- | kvmd/apps/vnc/server.py | 37 |
2 files changed, 41 insertions, 20 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py index d6511568..e18ce935 100644 --- a/kvmd/aiotools.py +++ b/kvmd/aiotools.py @@ -106,6 +106,30 @@ class AioNotifier: # ===== +class AioStage: + def __init__(self) -> None: + self.__fut = asyncio.Future() # type: ignore + + def set_passed(self, multi: bool=False) -> None: + if multi and self.__fut.done(): + return + self.__fut.set_result(None) + + def is_passed(self) -> bool: + return self.__fut.done() + + async def wait_passed(self, timeout: float=-1) -> bool: + if timeout >= 0: + try: + await asyncio.wait_for(self.__fut, timeout=timeout) + except asyncio.TimeoutError: + return False + else: + await self.__fut + return True + + +# ===== class AioExclusiveRegion: def __init__( self, diff --git a/kvmd/apps/vnc/server.py b/kvmd/apps/vnc/server.py index 769d9002..fe340e65 100644 --- a/kvmd/apps/vnc/server.py +++ b/kvmd/apps/vnc/server.py @@ -52,6 +52,7 @@ from ...clients.streamer import StreamFormats from ...clients.streamer import BaseStreamerClient from ... import tools +from ... import aiotools from .rfb import RfbClient from .rfb.stream import rfb_format_remote @@ -113,9 +114,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes self.__shared_params = shared_params - self.__stage1_authorized = asyncio.Future() # type: ignore - self.__stage2_encodings_accepted = asyncio.Future() # type: ignore - self.__stage3_ws_connected = asyncio.Future() # type: ignore + self.__stage1_authorized = aiotools.AioStage() + self.__stage2_encodings_accepted = aiotools.AioStage() + self.__stage3_ws_connected = aiotools.AioStage() self.__kvmd_session: Optional[KvmdClientSession] = None self.__kvmd_ws: Optional[KvmdClientWs] = None @@ -149,19 +150,17 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes async def __kvmd_task_loop(self) -> None: logger = get_logger(0) - await self.__stage1_authorized + await self.__stage1_authorized.wait_passed() logger.info("[kvmd] %s: Waiting for the SetEncodings message ...", self._remote) - try: - await asyncio.wait_for(self.__stage2_encodings_accepted, timeout=5) - except asyncio.TimeoutError: + if not (await self.__stage2_encodings_accepted.wait_passed(timeout=5)): raise RfbError("No SetEncodings message recieved from the client in 5 secs") assert self.__kvmd_session try: async with self.__kvmd_session.ws() as self.__kvmd_ws: logger.info("[kvmd] %s: Connected to KVMD websocket", self._remote) - self.__stage3_ws_connected.set_result(None) + self.__stage3_ws_connected.set_passed() async for event in self.__kvmd_ws.communicate(): await self.__process_ws_event(event) raise RfbError("KVMD closes the websocket (the server may have been stopped)") @@ -191,7 +190,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes async def __streamer_task_loop(self) -> None: logger = get_logger(0) - await self.__stage3_ws_connected + await self.__stage3_ws_connected.wait_passed() streamer = self.__get_preferred_streamer() while True: try: @@ -272,7 +271,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes async def _authorize_userpass(self, user: str, passwd: str) -> bool: self.__kvmd_session = self.__kvmd.make_session(user, passwd) if (await self.__kvmd_session.auth.check()): - self.__stage1_authorized.set_result(None) + self.__stage1_authorized.set_passed() return True return False @@ -339,25 +338,23 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes self.__mouse_move = move async def _on_cut_event(self, text: str) -> None: - assert self.__stage1_authorized.done() + assert self.__stage1_authorized.is_passed() assert self.__kvmd_session logger = get_logger(0) logger.info("[main] %s: Printing %d characters ...", self._remote, len(text)) try: - (default, available) = await self.__kvmd_session.hid.get_keymaps() - await self.__kvmd_session.hid.print( - text=text, - limit=0, - keymap_name=(self.__keymap_name if self.__keymap_name in available else default), - ) + (keymap_name, available) = await self.__kvmd_session.hid.get_keymaps() + if self.__keymap_name in available: + keymap_name = self.__keymap_name + await self.__kvmd_session.hid.print(text, 0, keymap_name) except Exception: logger.exception("[main] %s: Can't print characters", self._remote) async def _on_set_encodings(self) -> None: - assert self.__stage1_authorized.done() + assert self.__stage1_authorized.is_passed() assert self.__kvmd_session - if not self.__stage2_encodings_accepted.done(): - self.__stage2_encodings_accepted.set_result(None) + self.__stage2_encodings_accepted.set_passed(multi=True) + has_quality = (await self.__kvmd_session.streamer.get_state())["features"]["quality"] quality = (self._encodings.tight_jpeg_quality if has_quality else None) get_logger(0).info("[main] %s: Applying streamer params: jpeg_quality=%s; desired_fps=%d ...", |