From 3b16242cfa4b656ba7d396a600b452184aac6c76 Mon Sep 17 00:00:00 2001 From: Devaev Maxim Date: Mon, 2 Mar 2020 02:13:47 +0300 Subject: region: notify about enter/exit, unregion on exception --- kvmd/aiotools.py | 47 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 15 deletions(-) (limited to 'kvmd') diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py index f8965edd..24021495 100644 --- a/kvmd/aiotools.py +++ b/kvmd/aiotools.py @@ -35,6 +35,7 @@ from typing import Coroutine from typing import AsyncGenerator from typing import Type from typing import TypeVar +from typing import Optional from typing import Any import aiofiles @@ -126,10 +127,31 @@ async def afile_write_now(afile: aiofiles.base.AiofilesContextManager, data: byt await run_async(os.fsync, afile.fileno()) +# ===== +class AioNotifier: + def __init__(self) -> None: + self.__queue: asyncio.queues.Queue = asyncio.Queue() + + async def notify(self) -> None: + await self.__queue.put(None) + + async def wait(self) -> None: + await self.__queue.get() + while not self.__queue.empty(): + await self.__queue.get() + + # ===== class AioExclusiveRegion: - def __init__(self, exc_type: Type[Exception]) -> None: + def __init__( + self, + exc_type: Type[Exception], + notifier: Optional[AioNotifier]=None, + ) -> None: + self.__exc_type = exc_type + self.__notifier = notifier + self.__busy = False def is_busy(self) -> bool: @@ -138,11 +160,19 @@ class AioExclusiveRegion: async def enter(self) -> None: if not self.__busy: self.__busy = True + try: + if self.__notifier: + await self.__notifier.notify() + except: # noqa: E722 + self.__busy = False + raise return raise self.__exc_type() async def exit(self) -> None: self.__busy = False + if self.__notifier: + await self.__notifier.notify() @contextlib.asynccontextmanager async def exit_only_on_exception(self) -> AsyncGenerator[None, None]: @@ -162,18 +192,5 @@ class AioExclusiveRegion: _exc: BaseException, _tb: types.TracebackType, ) -> None: - await self.exit() - - -# ===== -class AioNotifier: - def __init__(self) -> None: - self.__queue: asyncio.queues.Queue = asyncio.Queue() - async def notify(self) -> None: - await self.__queue.put(None) - - async def wait(self) -> None: - await self.__queue.get() - while not self.__queue.empty(): - await self.__queue.get() + await self.exit() -- cgit v1.2.3