summaryrefslogtreecommitdiff
path: root/kvmd
diff options
context:
space:
mode:
Diffstat (limited to 'kvmd')
-rw-r--r--kvmd/aiotools.py47
1 files changed, 32 insertions, 15 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py
index f8965edd..24021495 100644
--- a/kvmd/aiotools.py
+++ b/kvmd/aiotools.py
@@ -35,6 +35,7 @@ from typing import Coroutine
from typing import AsyncGenerator
from typing import Type
from typing import TypeVar
+from typing import Optional
from typing import Any
import aiofiles
@@ -127,9 +128,30 @@ async def afile_write_now(afile: aiofiles.base.AiofilesContextManager, data: byt
# =====
+class AioNotifier:
+ def __init__(self) -> None:
+ self.__queue: asyncio.queues.Queue = asyncio.Queue()
+
+ async def notify(self) -> None:
+ await self.__queue.put(None)
+
+ async def wait(self) -> None:
+ await self.__queue.get()
+ while not self.__queue.empty():
+ await self.__queue.get()
+
+
+# =====
class AioExclusiveRegion:
- def __init__(self, exc_type: Type[Exception]) -> None:
+ def __init__(
+ self,
+ exc_type: Type[Exception],
+ notifier: Optional[AioNotifier]=None,
+ ) -> None:
+
self.__exc_type = exc_type
+ self.__notifier = notifier
+
self.__busy = False
def is_busy(self) -> bool:
@@ -138,11 +160,19 @@ class AioExclusiveRegion:
async def enter(self) -> None:
if not self.__busy:
self.__busy = True
+ try:
+ if self.__notifier:
+ await self.__notifier.notify()
+ except: # noqa: E722
+ self.__busy = False
+ raise
return
raise self.__exc_type()
async def exit(self) -> None:
self.__busy = False
+ if self.__notifier:
+ await self.__notifier.notify()
@contextlib.asynccontextmanager
async def exit_only_on_exception(self) -> AsyncGenerator[None, None]:
@@ -162,18 +192,5 @@ class AioExclusiveRegion:
_exc: BaseException,
_tb: types.TracebackType,
) -> None:
- await self.exit()
-
-
-# =====
-class AioNotifier:
- def __init__(self) -> None:
- self.__queue: asyncio.queues.Queue = asyncio.Queue()
- async def notify(self) -> None:
- await self.__queue.put(None)
-
- async def wait(self) -> None:
- await self.__queue.get()
- while not self.__queue.empty():
- await self.__queue.get()
+ await self.exit()