summaryrefslogtreecommitdiff
path: root/kvmd/aiotools.py
diff options
context:
space:
mode:
Diffstat (limited to 'kvmd/aiotools.py')
-rw-r--r--kvmd/aiotools.py77
1 files changed, 34 insertions, 43 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py
index 24021495..18990c2c 100644
--- a/kvmd/aiotools.py
+++ b/kvmd/aiotools.py
@@ -24,7 +24,6 @@ import os
import asyncio
import asyncio.queues
import functools
-import contextlib
import types
import typing
@@ -32,7 +31,6 @@ import typing
from typing import List
from typing import Callable
from typing import Coroutine
-from typing import AsyncGenerator
from typing import Type
from typing import TypeVar
from typing import Optional
@@ -59,27 +57,6 @@ def atomic(method: _MethodT) -> _MethodT:
return typing.cast(_MethodT, wrapper)
-def muted(msg: str) -> Callable[[_MethodT], Callable[..., None]]:
- def make_wrapper(method: _MethodT) -> Callable[..., None]:
- @functools.wraps(method)
- async def wrapper(*args: Any, **kwargs: Any) -> None:
- try:
- await method(*args, **kwargs)
- except asyncio.CancelledError: # pylint: disable=try-except-raise
- raise
- except Exception:
- get_logger(0).exception(msg)
- return typing.cast(Callable[..., None], wrapper)
- return make_wrapper
-
-
-def tasked(method: Callable[..., Any]) -> Callable[..., asyncio.Task]:
- @functools.wraps(method)
- async def wrapper(*args: Any, **kwargs: Any) -> asyncio.Task:
- return create_short_task(method(*args, **kwargs))
- return typing.cast(Callable[..., asyncio.Task], wrapper)
-
-
# =====
def create_short_task(coro: Coroutine) -> asyncio.Task:
task = asyncio.create_task(coro)
@@ -110,17 +87,6 @@ async def wait_infinite() -> None:
# =====
-async def unlock_only_on_exception(lock: asyncio.Lock) -> AsyncGenerator[None, None]:
- await lock.acquire()
- try:
- yield
- except: # noqa: E722
- lock.release()
- raise
-
-
-# =====
async def afile_write_now(afile: aiofiles.base.AiofilesContextManager, data: bytes) -> None:
await afile.write(data)
await afile.flush()
@@ -154,6 +120,9 @@ class AioExclusiveRegion:
self.__busy = False
+ def get_exc_type(self) -> Type[Exception]:
+ return self.__exc_type
+
def is_busy(self) -> bool:
return self.__busy
@@ -174,15 +143,6 @@ class AioExclusiveRegion:
if self.__notifier:
await self.__notifier.notify()
- @contextlib.asynccontextmanager
- async def exit_only_on_exception(self) -> AsyncGenerator[None, None]:
- await self.enter()
- try:
- yield
- except: # noqa: E722
- await self.exit()
- raise
-
async def __aenter__(self) -> None:
await self.enter()
@@ -194,3 +154,34 @@ class AioExclusiveRegion:
) -> None:
await self.exit()
+
+
+async def run_region_task(
+ msg: str,
+ region: AioExclusiveRegion,
+ method: Callable[..., Coroutine[Any, Any, None]],
+ *args: Any,
+ **kwargs: Any,
+) -> None:
+
+ entered = asyncio.Future() # type: ignore
+
+ async def wrapper() -> None:
+ try:
+ async with region:
+ entered.set_result(None)
+ await method(*args, **kwargs)
+ except asyncio.CancelledError: # pylint: disable=try-except-raise
+ raise
+ except region.get_exc_type():
+ raise
+ except Exception:
+ get_logger(0).exception(msg)
+
+ task = create_short_task(wrapper())
+ await asyncio.wait([entered, task], return_when=asyncio.FIRST_COMPLETED)
+
+ if entered.done():
+ return
+ if (exc := task.exception()) is not None: # noqa: E203,E231
+ raise exc