summaryrefslogtreecommitdiff
path: root/kvmd
diff options
context:
space:
mode:
authorMaxim Devaev <[email protected]>2021-07-27 05:25:54 +0300
committerMaxim Devaev <[email protected]>2021-07-27 05:25:54 +0300
commit6b07a80834879970639b651b1a225f8d1f5f7c2e (patch)
tree51f5c91b3b384852ff5c8558b097c02f7a5fd120 /kvmd
parent3c421fa94cbd1e479a51d66eaa4fb33d74277009 (diff)
/msd/write_remote handle
Diffstat (limited to 'kvmd')
-rw-r--r--kvmd/apps/kvmd/api/msd.py101
-rw-r--r--kvmd/apps/kvmd/http.py4
-rw-r--r--kvmd/htclient.py53
-rw-r--r--kvmd/plugins/msd/__init__.py17
-rw-r--r--kvmd/plugins/msd/disabled.py7
-rw-r--r--kvmd/plugins/msd/otg/__init__.py7
-rw-r--r--kvmd/plugins/msd/relay/__init__.py7
-rw-r--r--kvmd/validators/net.py5
8 files changed, 163 insertions, 38 deletions
diff --git a/kvmd/apps/kvmd/api/msd.py b/kvmd/apps/kvmd/api/msd.py
index 98c7957e..4e1a7f14 100644
--- a/kvmd/apps/kvmd/api/msd.py
+++ b/kvmd/apps/kvmd/api/msd.py
@@ -20,19 +20,34 @@
# ========================================================================== #
+import time
+
+from typing import Dict
+from typing import Optional
+
+import aiohttp
+
from aiohttp.web import Request
from aiohttp.web import Response
+from aiohttp.web import StreamResponse
from ....logging import get_logger
+from .... import htclient
+
from ....plugins.msd import BaseMsd
from ....validators.basic import valid_bool
from ....validators.basic import valid_int_f0
+from ....validators.basic import valid_float_f01
+from ....validators.net import valid_url
from ....validators.kvm import valid_msd_image_name
from ..http import exposed_http
from ..http import make_json_response
+from ..http import make_json_exception
+from ..http import start_streaming
+from ..http import stream_json
from ..http import get_multipart_reader
from ..http import get_multipart_reader_str
from ..http import get_multipart_reader_field
@@ -67,29 +82,79 @@ class MsdApi:
await self.__msd.set_connected(valid_bool(request.query.get("connected")))
return make_json_response()
+ # =====
+
@exposed_http("POST", "/msd/write")
async def __write_handler(self, request: Request) -> Response:
- logger = get_logger(0)
reader = await get_multipart_reader(request)
- name = ""
+ name = valid_msd_image_name(await get_multipart_reader_str(reader, "image"))
+ size = valid_int_f0(await get_multipart_reader_str(reader, "size"))
+ data_field = await get_multipart_reader_field(reader, "data")
+
written = 0
+ async with self.__msd.write_image(name, size) as chunk_size:
+ while True:
+ chunk = await data_field.read_chunk(chunk_size)
+ if not chunk:
+ break
+ written = await self.__msd.write_image_chunk(chunk)
+
+ return make_json_response(self.__make_write_info(name, size, written))
+
+ @exposed_http("POST", "/msd/write_remote")
+ async def __write_remote_handler(self, request: Request) -> StreamResponse: # pylint: disable=too-many-locals
+ url = valid_url(request.query.get("url"))
+ insecure = valid_bool(request.query.get("insecure", "0"))
+ timeout = valid_float_f01(request.query.get("timeout", 10.0))
+
+ name = ""
+ size = written = 0
+ response: Optional[StreamResponse] = None
+
+ async def stream_write_info() -> None:
+ assert response is not None
+ await stream_json(response, self.__make_write_info(name, size, written))
+
try:
- name = valid_msd_image_name(await get_multipart_reader_str(reader, "image"))
- size = valid_int_f0(await get_multipart_reader_str(reader, "size"))
-
- data_field = await get_multipart_reader_field(reader, "data")
-
- async with self.__msd.write_image(name, size):
- logger.info("Writing image %r to MSD ...", name)
- while True:
- chunk = await data_field.read_chunk(self.__msd.get_upload_chunk_size())
- if not chunk:
- break
- written = await self.__msd.write_image_chunk(chunk)
- finally:
- if written != 0:
- logger.info("Written image %r with size=%d bytes to MSD", name, written)
- return make_json_response({"image": {"name": name, "size": written}})
+ async with htclient.download(
+ url=url,
+ verify=(not insecure),
+ timeout=timeout,
+ read_timeout=(7 * 24 * 3600),
+ ) as remote:
+
+ name = str(request.query.get("image", "")).strip()
+ if len(name) == 0:
+ name = htclient.get_filename(remote)
+ name = valid_msd_image_name(name)
+
+ size = htclient.get_content_length(remote)
+
+ get_logger(0).info("Downloading image %r as %r to MSD ...", url, name)
+ async with self.__msd.write_image(name, size) as chunk_size:
+ response = await start_streaming(request, "application/stream+json")
+ last_report_ts = 0
+ async for chunk in remote.content.iter_chunked(chunk_size):
+ written = await self.__msd.write_image_chunk(chunk)
+ now = int(time.time())
+ if last_report_ts + 1 < now:
+ await stream_write_info()
+ last_report_ts = now
+
+ await stream_write_info()
+ return response
+
+ except Exception as err:
+ if response is not None:
+ await stream_write_info()
+ elif isinstance(err, aiohttp.ClientError):
+ return make_json_exception(err, 400)
+ raise
+
+ def __make_write_info(self, name: str, size: int, written: int) -> Dict:
+ return {"image": {"name": name, "size": size, "written": written}}
+
+ # =====
@exposed_http("POST", "/msd/remove")
async def __remove_handler(self, request: Request) -> Response:
diff --git a/kvmd/apps/kvmd/http.py b/kvmd/apps/kvmd/http.py
index e436844a..ae2cebc0 100644
--- a/kvmd/apps/kvmd/http.py
+++ b/kvmd/apps/kvmd/http.py
@@ -176,6 +176,10 @@ async def start_streaming(request: aiohttp.web.Request, content_type: str) -> ai
return response
+async def stream_json(response: aiohttp.web.StreamResponse, result: Dict) -> None:
+ await response.write(json.dumps(result).encode("utf-8") + b"\r\n")
+
+
# =====
async def get_multipart_reader(request: aiohttp.web.Request) -> aiohttp.MultipartReader:
try:
diff --git a/kvmd/htclient.py b/kvmd/htclient.py
index 368bf668..1927ede2 100644
--- a/kvmd/htclient.py
+++ b/kvmd/htclient.py
@@ -20,7 +20,15 @@
# ========================================================================== #
+import os
+import contextlib
+
+from typing import Dict
+from typing import AsyncGenerator
+from typing import Optional
+
import aiohttp
+import aiohttp.multipart
from . import __version__
@@ -41,3 +49,48 @@ def raise_not_200(response: aiohttp.ClientResponse) -> None:
message=response.reason,
headers=response.headers,
)
+
+
+def get_content_length(response: aiohttp.ClientResponse) -> int:
+ try:
+ value = int(response.headers["Content-Length"])
+ except Exception:
+ raise aiohttp.ClientError("Empty or invalid Content-Length")
+ if value < 0:
+ raise aiohttp.ClientError("Negative Content-Length")
+ return value
+
+
+def get_filename(response: aiohttp.ClientResponse) -> str:
+ try:
+ disp = response.headers["Content-Disposition"]
+ parsed = aiohttp.multipart.parse_content_disposition(disp)
+ return str(parsed[1]["filename"])
+ except Exception:
+ try:
+ return os.path.basename(response.url.path)
+ except Exception:
+ raise aiohttp.ClientError("Can't determine filename")
+
+
+async def download(
+ url: str,
+ verify: bool=True,
+ timeout: float=10.0,
+ read_timeout: Optional[float]=None,
+ app: str="KVMD",
+) -> AsyncGenerator[aiohttp.ClientResponse, None]:
+
+ kwargs: Dict = {
+ "headers": {"User-Agent": make_user_agent(app)},
+ "timeout": aiohttp.ClientTimeout(
+ connect=timeout,
+ sock_connect=timeout,
+ sock_read=(read_timeout if read_timeout is not None else timeout),
+ ),
+ }
+ async with aiohttp.ClientSession(**kwargs) as session:
+ async with session.get(url, verify_ssl=verify) as response:
+ raise_not_200(response)
+ yield response
diff --git a/kvmd/plugins/msd/__init__.py b/kvmd/plugins/msd/__init__.py
index 680ccece..547d778b 100644
--- a/kvmd/plugins/msd/__init__.py
+++ b/kvmd/plugins/msd/__init__.py
@@ -31,6 +31,8 @@ from typing import Optional
import aiofiles
import aiofiles.base
+from ...logging import get_logger
+
from ... import aiofs
from ...errors import OperationError
@@ -119,13 +121,10 @@ class BaseMsd(BasePlugin):
raise NotImplementedError()
@contextlib.asynccontextmanager
- async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]: # pylint: disable=unused-argument
+ async def write_image(self, name: str, size: int) -> AsyncGenerator[int, None]: # pylint: disable=unused-argument
if self is not None: # XXX: Vulture and pylint hack
raise NotImplementedError()
- yield
-
- def get_upload_chunk_size(self) -> int:
- raise NotImplementedError()
+ yield 1
async def write_image_chunk(self, chunk: bytes) -> int:
raise NotImplementedError()
@@ -158,6 +157,7 @@ class MsdImageWriter:
async def open(self) -> "MsdImageWriter":
assert self.__file is None
+ get_logger(1).info("Writing %r image (%d bytes) to MSD ...", self.__name, self.__size)
self.__file = await aiofiles.open(self.__path, mode="w+b", buffering=0) # type: ignore
return self
@@ -176,6 +176,13 @@ class MsdImageWriter:
async def close(self) -> None:
assert self.__file is not None
+ if self.__written == self.__size:
+ (log, result) = (get_logger().info, "OK")
+ elif self.__written < self.__size:
+ (log, result) = (get_logger().error, "INCOMPLETE")
+ else: # written > size
+ (log, result) = (get_logger().warning, "OVERFLOW")
+ log("Written %d of %d bytes to MSD image %r: %s", self.__written, self.__size, self.__name, result)
await aiofs.afile_sync(self.__file)
await self.__file.close() # type: ignore
diff --git a/kvmd/plugins/msd/disabled.py b/kvmd/plugins/msd/disabled.py
index fc1979e0..220eaf7a 100644
--- a/kvmd/plugins/msd/disabled.py
+++ b/kvmd/plugins/msd/disabled.py
@@ -70,13 +70,10 @@ class Plugin(BaseMsd):
raise MsdDisabledError()
@contextlib.asynccontextmanager
- async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]:
+ async def write_image(self, name: str, size: int) -> AsyncGenerator[int, None]:
if self is not None: # XXX: Vulture and pylint hack
raise MsdDisabledError()
- yield
-
- def get_upload_chunk_size(self) -> int:
- raise MsdDisabledError()
+ yield 1
async def write_image_chunk(self, chunk: bytes) -> int:
raise MsdDisabledError()
diff --git a/kvmd/plugins/msd/otg/__init__.py b/kvmd/plugins/msd/otg/__init__.py
index 03e06650..4289db2a 100644
--- a/kvmd/plugins/msd/otg/__init__.py
+++ b/kvmd/plugins/msd/otg/__init__.py
@@ -306,7 +306,7 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
self.__state.vd.connected = connected
@contextlib.asynccontextmanager
- async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]:
+ async def write_image(self, name: str, size: int) -> AsyncGenerator[int, None]:
try:
async with self.__state._region: # pylint: disable=protected-access
try:
@@ -328,7 +328,7 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
self.__new_writer = await MsdImageWriter(path, size, self.__sync_chunk_size).open()
await self.__notifier.notify()
- yield
+ yield self.__upload_chunk_size
self.__set_image_complete(name, True)
finally:
@@ -343,9 +343,6 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
await self.__reload_state()
await self.__notifier.notify()
- def get_upload_chunk_size(self) -> int:
- return self.__upload_chunk_size
-
async def write_image_chunk(self, chunk: bytes) -> int:
assert self.__new_writer
written = await self.__new_writer.write(chunk)
diff --git a/kvmd/plugins/msd/relay/__init__.py b/kvmd/plugins/msd/relay/__init__.py
index 82de6596..a608e535 100644
--- a/kvmd/plugins/msd/relay/__init__.py
+++ b/kvmd/plugins/msd/relay/__init__.py
@@ -208,7 +208,7 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
self.__connected = connected
@contextlib.asynccontextmanager
- async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]:
+ async def write_image(self, name: str, size: int) -> AsyncGenerator[int, None]:
async with self.__working():
async with self.__region:
try:
@@ -220,15 +220,12 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
await self.__write_image_info(False)
await self.__notifier.notify()
- yield
+ yield self.__upload_chunk_size
await self.__write_image_info(True)
finally:
await self.__close_device_writer()
await self.__load_device_info()
- def get_upload_chunk_size(self) -> int:
- return self.__upload_chunk_size
-
async def write_image_chunk(self, chunk: bytes) -> int:
assert self.__device_writer
return (await self.__device_writer.write(chunk))
diff --git a/kvmd/validators/net.py b/kvmd/validators/net.py
index 991540bd..c45b3aed 100644
--- a/kvmd/validators/net.py
+++ b/kvmd/validators/net.py
@@ -116,3 +116,8 @@ def valid_ssl_ciphers(arg: Any) -> str:
except Exception as err:
raise ValidatorError(f"The argument {arg!r} is not a valid {name}: {err}")
return arg
+
+
+def valid_url(arg: Any) -> str:
+ # XXX: VERY primitive
+ return check_re_match(arg, "HTTP(S) URL", r"^https?://[\[\w]+\S*")