diff options
-rw-r--r-- | kvmd/aiotools.py | 41 | ||||
-rw-r--r-- | testenv/tests/test_aiotools.py | 34 |
2 files changed, 75 insertions, 0 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py index 6895421c..2a798e2a 100644 --- a/kvmd/aiotools.py +++ b/kvmd/aiotools.py @@ -84,6 +84,47 @@ _RetvalT = TypeVar("_RetvalT") # ===== +class _ArmoredFuture(asyncio.Future): + def cancel(self, *_, **__) -> bool: # type: ignore + # FIXME: Выяснить, почему это работает + return False + + def forced_cancel(self) -> bool: + return super().cancel() + + +def shield_fg(aw: Awaitable): # type: ignore + # XXX: Копия asyncio.shield() с небольшими изменениями + + inner = asyncio.ensure_future(aw) + if inner.done(): + return inner + outer = _ArmoredFuture(loop=inner.get_loop()) + + def inner_done(_) -> None: # type: ignore + if outer.cancelled(): + if not inner.cancelled(): + inner.exception() # Mark inner's result as retrieved + return + + if inner.cancelled(): + outer.forced_cancel() + else: + err = inner.exception() + if err is not None: + outer.set_exception(err) + else: + outer.set_result(inner.result()) + + def outer_done(_) -> None: # type: ignore + if not inner.done(): + inner.remove_done_callback(inner_done) + + inner.add_done_callback(inner_done) + outer.add_done_callback(outer_done) + return outer + + def atomic(func: _FunctionT) -> _FunctionT: @functools.wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: diff --git a/testenv/tests/test_aiotools.py b/testenv/tests/test_aiotools.py index 23bc99c0..9ccaee0a 100644 --- a/testenv/tests/test_aiotools.py +++ b/testenv/tests/test_aiotools.py @@ -22,9 +22,12 @@ import asyncio +from typing import List + import pytest from kvmd.aiotools import AioExclusiveRegion +from kvmd.aiotools import shield_fg # ===== @@ -115,3 +118,34 @@ async def test_fail__region__access_two() -> None: assert not region.is_busy() await region.exit() assert not region.is_busy() + + +# ===== +async def test_ok__shield_fg() -> None: + ops: List[str] = [] + + async def foo(op: str, delay: float) -> None: # pylint: disable=disallowed-name + await asyncio.sleep(delay) + ops.append(op) + + async def bar() -> None: # pylint: disable=disallowed-name + try: + try: + try: + raise RuntimeError() + finally: + await shield_fg(foo("foo1", 2.0)) + ops.append("foo1-noexc") + finally: + await shield_fg(foo("foo2", 1.0)) + ops.append("foo2-noexc") + finally: + ops.append("done") + + task = asyncio.create_task(bar()) + await asyncio.sleep(0.1) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + assert ops == ["foo1", "foo2", "foo2-noexc", "done"] |