summaryrefslogtreecommitdiff
path: root/kvmd
diff options
context:
space:
mode:
Diffstat (limited to 'kvmd')
-rw-r--r--kvmd/apps/kvmd/server.py28
1 files changed, 19 insertions, 9 deletions
diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py
index 2721189b..1bd63446 100644
--- a/kvmd/apps/kvmd/server.py
+++ b/kvmd/apps/kvmd/server.py
@@ -155,8 +155,14 @@ _HEADER_AUTH_PASSWD = "X-KVMD-Passwd"
_COOKIE_AUTH_TOKEN = "auth_token"
-def _exposed(http_method: str, path: str, auth_required: bool=True) -> Callable:
- def make_wrapper(method: Callable) -> Callable:
+def _atomic(handler: Callable) -> Callable:
+ async def wrap(self: "Server", request: aiohttp.web.Request) -> aiohttp.web.Response:
+ return (await asyncio.shield(handler(self, request)))
+ return wrap
+
+
+def _exposed(http_method: str, path: str, atomic: bool=False, auth_required: bool=True) -> Callable:
+ def make_wrapper(handler: Callable) -> Callable:
async def wrap(self: "Server", request: aiohttp.web.Request) -> aiohttp.web.Response:
try:
if auth_required:
@@ -180,7 +186,7 @@ def _exposed(http_method: str, path: str, auth_required: bool=True) -> Callable:
else:
raise UnauthorizedError("Unauthorized")
- return (await method(self, request))
+ return (await handler(self, request))
except RegionIsBusyError as err:
return _json_exception(err, 409)
@@ -191,6 +197,9 @@ def _exposed(http_method: str, path: str, auth_required: bool=True) -> Callable:
except ForbiddenError as err:
return _json_exception(err, 403)
+ if atomic:
+ wrap = _atomic(wrap)
+
setattr(wrap, _ATTR_EXPOSED, True)
setattr(wrap, _ATTR_EXPOSED_METHOD, http_method)
setattr(wrap, _ATTR_EXPOSED_PATH, path)
@@ -302,7 +311,7 @@ class Server: # pylint: disable=too-many-instance-attributes
# ===== AUTH
- @_exposed("POST", "/auth/login", auth_required=False)
+ @_exposed("POST", "/auth/login", atomic=True, auth_required=False)
async def __auth_login_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
credentials = await request.post()
token = await self._auth_manager.login(
@@ -420,7 +429,7 @@ class Server: # pylint: disable=too-many-instance-attributes
async def __hid_state_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
return _json(self.__hid.get_state())
- @_exposed("POST", "/hid/reset")
+ @_exposed("POST", "/hid/reset", atomic=True)
async def __hid_reset_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
await self.__hid.reset()
return _json()
@@ -431,7 +440,7 @@ class Server: # pylint: disable=too-many-instance-attributes
async def __atx_state_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
return _json(self.__atx.get_state())
- @_exposed("POST", "/atx/power")
+ @_exposed("POST", "/atx/power", atomic=True)
async def __atx_power_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
action = valid_atx_power_action(request.query.get("action"))
done = await ({
@@ -442,7 +451,7 @@ class Server: # pylint: disable=too-many-instance-attributes
}[action])()
return _json({"action": action, "done": done})
- @_exposed("POST", "/atx/click")
+ @_exposed("POST", "/atx/click", atomic=True)
async def __atx_click_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
button = valid_atx_button(request.query.get("button"))
await ({
@@ -458,7 +467,7 @@ class Server: # pylint: disable=too-many-instance-attributes
async def __msd_state_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
return _json(self.__msd.get_state())
- @_exposed("POST", "/msd/connect")
+ @_exposed("POST", "/msd/connect", atomic=True)
async def __msd_connect_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
to = valid_kvm_target(request.query.get("to"))
return _json(await ({
@@ -491,7 +500,7 @@ class Server: # pylint: disable=too-many-instance-attributes
logger.info("Written %d bytes to mass-storage device", written)
return _json({"written": written})
- @_exposed("POST", "/msd/reset")
+ @_exposed("POST", "/msd/reset", atomic=True)
async def __msd_reset_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
await self.__msd.reset()
return _json()
@@ -538,6 +547,7 @@ class Server: # pylint: disable=too-many-instance-attributes
getattr(method, _ATTR_EXPOSED_PATH),
method,
)
+
return app
def __run_app_print(self, text: str) -> None: