summaryrefslogtreecommitdiff
path: root/kvmd/apps
diff options
context:
space:
mode:
authorDevaev Maxim <[email protected]>2021-06-08 03:12:24 +0300
committerDevaev Maxim <[email protected]>2021-06-08 03:12:24 +0300
commitb5ab5699c439a88c17eafc2800a5c9e7213aa3c4 (patch)
tree2f1160f781b013f45e12612d2ae9dcc301e87d73 /kvmd/apps
parentcf08c04e55742beefbe5d642ff5bfa2fd7d3dff8 (diff)
pikvm/pikvm#321: server-side uploading counters
Diffstat (limited to 'kvmd/apps')
-rw-r--r--kvmd/apps/kvmd/api/msd.py8
-rw-r--r--kvmd/apps/kvmd/http.py5
2 files changed, 10 insertions, 3 deletions
diff --git a/kvmd/apps/kvmd/api/msd.py b/kvmd/apps/kvmd/api/msd.py
index 0b8fa6e2..315a86f1 100644
--- a/kvmd/apps/kvmd/api/msd.py
+++ b/kvmd/apps/kvmd/api/msd.py
@@ -28,10 +28,12 @@ from ....logging import get_logger
from ....plugins.msd import BaseMsd
from ....validators.basic import valid_bool
+from ....validators.basic import valid_int_f0
from ....validators.kvm import valid_msd_image_name
from ..http import exposed_http
from ..http import make_json_response
+from ..http import get_field_value
from ..http import get_multipart_field
@@ -71,12 +73,12 @@ class MsdApi:
name = ""
written = 0
try:
- name_field = await get_multipart_field(reader, "image")
- name = valid_msd_image_name((await name_field.read()).decode("utf-8"))
+ name = valid_msd_image_name(await get_field_value(reader, "image"))
+ size = valid_int_f0(await get_field_value(reader, "size"))
data_field = await get_multipart_field(reader, "data")
- async with self.__msd.write_image(name):
+ async with self.__msd.write_image(name, size):
logger.info("Writing image %r to MSD ...", name)
while True:
chunk = await data_field.read_chunk(self.__msd.get_upload_chunk_size())
diff --git a/kvmd/apps/kvmd/http.py b/kvmd/apps/kvmd/http.py
index e46be2fa..941205b2 100644
--- a/kvmd/apps/kvmd/http.py
+++ b/kvmd/apps/kvmd/http.py
@@ -171,6 +171,11 @@ def make_json_exception(err: Exception, status: Optional[int]=None) -> aiohttp.w
# =====
+async def get_field_value(reader: aiohttp.MultipartReader, name: str) -> str:
+ field = await get_multipart_field(reader, name)
+ return (await field.read()).decode("utf-8")
+
+
async def get_multipart_field(reader: aiohttp.MultipartReader, name: str) -> aiohttp.BodyPartReader:
field = await reader.next()
if not isinstance(field, aiohttp.BodyPartReader):