diff options
Diffstat (limited to 'kvmd/aiotools.py')
-rw-r--r-- | kvmd/aiotools.py | 77 |
1 files changed, 34 insertions, 43 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py index 24021495..18990c2c 100644 --- a/kvmd/aiotools.py +++ b/kvmd/aiotools.py @@ -24,7 +24,6 @@ import os import asyncio import asyncio.queues import functools -import contextlib import types import typing @@ -32,7 +31,6 @@ import typing from typing import List from typing import Callable from typing import Coroutine -from typing import AsyncGenerator from typing import Type from typing import TypeVar from typing import Optional @@ -59,27 +57,6 @@ def atomic(method: _MethodT) -> _MethodT: return typing.cast(_MethodT, wrapper) -def muted(msg: str) -> Callable[[_MethodT], Callable[..., None]]: - def make_wrapper(method: _MethodT) -> Callable[..., None]: - @functools.wraps(method) - async def wrapper(*args: Any, **kwargs: Any) -> None: - try: - await method(*args, **kwargs) - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except Exception: - get_logger(0).exception(msg) - return typing.cast(Callable[..., None], wrapper) - return make_wrapper - - -def tasked(method: Callable[..., Any]) -> Callable[..., asyncio.Task]: - @functools.wraps(method) - async def wrapper(*args: Any, **kwargs: Any) -> asyncio.Task: - return create_short_task(method(*args, **kwargs)) - return typing.cast(Callable[..., asyncio.Task], wrapper) - - # ===== def create_short_task(coro: Coroutine) -> asyncio.Task: task = asyncio.create_task(coro) @@ -110,17 +87,6 @@ async def wait_infinite() -> None: # ===== -async def unlock_only_on_exception(lock: asyncio.Lock) -> AsyncGenerator[None, None]: - await lock.acquire() - try: - yield - except: # noqa: E722 - lock.release() - raise - - -# ===== async def afile_write_now(afile: aiofiles.base.AiofilesContextManager, data: bytes) -> None: await afile.write(data) await afile.flush() @@ -154,6 +120,9 @@ class AioExclusiveRegion: self.__busy = False + def get_exc_type(self) -> Type[Exception]: + return self.__exc_type + def is_busy(self) -> bool: return self.__busy @@ -174,15 +143,6 @@ class AioExclusiveRegion: if self.__notifier: await self.__notifier.notify() - @contextlib.asynccontextmanager - async def exit_only_on_exception(self) -> AsyncGenerator[None, None]: - await self.enter() - try: - yield - except: # noqa: E722 - await self.exit() - raise - async def __aenter__(self) -> None: await self.enter() @@ -194,3 +154,34 @@ class AioExclusiveRegion: ) -> None: await self.exit() + + +async def run_region_task( + msg: str, + region: AioExclusiveRegion, + method: Callable[..., Coroutine[Any, Any, None]], + *args: Any, + **kwargs: Any, +) -> None: + + entered = asyncio.Future() # type: ignore + + async def wrapper() -> None: + try: + async with region: + entered.set_result(None) + await method(*args, **kwargs) + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except region.get_exc_type(): + raise + except Exception: + get_logger(0).exception(msg) + + task = create_short_task(wrapper()) + await asyncio.wait([entered, task], return_when=asyncio.FIRST_COMPLETED) + + if entered.done(): + return + if (exc := task.exception()) is not None: # noqa: E203,E231 + raise exc |