diff options
Diffstat (limited to 'kvmd/apps')
-rw-r--r-- | kvmd/apps/kvmd/api/msd.py | 13 | ||||
-rw-r--r-- | kvmd/apps/kvmd/http.py | 13 |
2 files changed, 17 insertions, 9 deletions
diff --git a/kvmd/apps/kvmd/api/msd.py b/kvmd/apps/kvmd/api/msd.py index 315a86f1..98c7957e 100644 --- a/kvmd/apps/kvmd/api/msd.py +++ b/kvmd/apps/kvmd/api/msd.py @@ -33,8 +33,9 @@ from ....validators.kvm import valid_msd_image_name from ..http import exposed_http from ..http import make_json_response -from ..http import get_field_value -from ..http import get_multipart_field +from ..http import get_multipart_reader +from ..http import get_multipart_reader_str +from ..http import get_multipart_reader_field # ====== @@ -69,14 +70,14 @@ class MsdApi: @exposed_http("POST", "/msd/write") async def __write_handler(self, request: Request) -> Response: logger = get_logger(0) - reader = await request.multipart() + reader = await get_multipart_reader(request) name = "" written = 0 try: - name = valid_msd_image_name(await get_field_value(reader, "image")) - size = valid_int_f0(await get_field_value(reader, "size")) + name = valid_msd_image_name(await get_multipart_reader_str(reader, "image")) + size = valid_int_f0(await get_multipart_reader_str(reader, "size")) - data_field = await get_multipart_field(reader, "data") + data_field = await get_multipart_reader_field(reader, "data") async with self.__msd.write_image(name, size): logger.info("Writing image %r to MSD ...", name) diff --git a/kvmd/apps/kvmd/http.py b/kvmd/apps/kvmd/http.py index daa0452a..c1525046 100644 --- a/kvmd/apps/kvmd/http.py +++ b/kvmd/apps/kvmd/http.py @@ -177,12 +177,19 @@ async def start_streaming(request: aiohttp.web.Request, content_type: str) -> ai # ===== -async def get_field_value(reader: aiohttp.MultipartReader, name: str) -> str: - field = await get_multipart_field(reader, name) +async def get_multipart_reader(request: aiohttp.web.Request) -> aiohttp.MultipartReader: + try: + return (await request.multipart()) + except Exception: + raise ValidatorError("Expected multipart") + + +async def get_multipart_reader_str(reader: aiohttp.MultipartReader, name: str) -> str: + field = await get_multipart_reader_field(reader, name) return (await field.read()).decode("utf-8") -async def get_multipart_field(reader: aiohttp.MultipartReader, name: str) -> aiohttp.BodyPartReader: +async def get_multipart_reader_field(reader: aiohttp.MultipartReader, name: str) -> aiohttp.BodyPartReader: field = await reader.next() if not isinstance(field, aiohttp.BodyPartReader): raise ValidatorError(f"Expected body part as {name!r} field") |