summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMaxim Devaev <[email protected]>2022-08-07 18:42:00 +0300
committerMaxim Devaev <[email protected]>2022-08-07 18:42:00 +0300
commitaa630988cc09f31d412a62c5480d4bec1a7c626e (patch)
tree7e61031827ddcbfc56d049ffa336ac664d3c5c8b
parentd995349b6311bf06b77ac0e8ccc4a424112feeb0 (diff)
aiotools.shield_fg()
-rw-r--r--kvmd/aiotools.py41
-rw-r--r--testenv/tests/test_aiotools.py34
2 files changed, 75 insertions, 0 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py
index 6895421c..2a798e2a 100644
--- a/kvmd/aiotools.py
+++ b/kvmd/aiotools.py
@@ -84,6 +84,47 @@ _RetvalT = TypeVar("_RetvalT")
# =====
+class _ArmoredFuture(asyncio.Future):
+ def cancel(self, *_, **__) -> bool: # type: ignore
+ # FIXME: Выяснить, почему это работает
+ return False
+
+ def forced_cancel(self) -> bool:
+ return super().cancel()
+
+
+def shield_fg(aw: Awaitable): # type: ignore
+ # XXX: Копия asyncio.shield() с небольшими изменениями
+
+ inner = asyncio.ensure_future(aw)
+ if inner.done():
+ return inner
+ outer = _ArmoredFuture(loop=inner.get_loop())
+
+ def inner_done(_) -> None: # type: ignore
+ if outer.cancelled():
+ if not inner.cancelled():
+ inner.exception() # Mark inner's result as retrieved
+ return
+
+ if inner.cancelled():
+ outer.forced_cancel()
+ else:
+ err = inner.exception()
+ if err is not None:
+ outer.set_exception(err)
+ else:
+ outer.set_result(inner.result())
+
+ def outer_done(_) -> None: # type: ignore
+ if not inner.done():
+ inner.remove_done_callback(inner_done)
+
+ inner.add_done_callback(inner_done)
+ outer.add_done_callback(outer_done)
+ return outer
+
+
def atomic(func: _FunctionT) -> _FunctionT:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
diff --git a/testenv/tests/test_aiotools.py b/testenv/tests/test_aiotools.py
index 23bc99c0..9ccaee0a 100644
--- a/testenv/tests/test_aiotools.py
+++ b/testenv/tests/test_aiotools.py
@@ -22,9 +22,12 @@
import asyncio
+from typing import List
+
import pytest
from kvmd.aiotools import AioExclusiveRegion
+from kvmd.aiotools import shield_fg
# =====
@@ -115,3 +118,34 @@ async def test_fail__region__access_two() -> None:
assert not region.is_busy()
await region.exit()
assert not region.is_busy()
+
+
+# =====
+async def test_ok__shield_fg() -> None:
+ ops: List[str] = []
+
+ async def foo(op: str, delay: float) -> None: # pylint: disable=disallowed-name
+ await asyncio.sleep(delay)
+ ops.append(op)
+
+ async def bar() -> None: # pylint: disable=disallowed-name
+ try:
+ try:
+ try:
+ raise RuntimeError()
+ finally:
+ await shield_fg(foo("foo1", 2.0))
+ ops.append("foo1-noexc")
+ finally:
+ await shield_fg(foo("foo2", 1.0))
+ ops.append("foo2-noexc")
+ finally:
+ ops.append("done")
+
+ task = asyncio.create_task(bar())
+ await asyncio.sleep(0.1)
+ task.cancel()
+ with pytest.raises(asyncio.CancelledError):
+ await task
+ assert ops == ["foo1", "foo2", "foo2-noexc", "done"]