summaryrefslogtreecommitdiff
path: root/kvmd/plugins/msd/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'kvmd/plugins/msd/__init__.py')
-rw-r--r--kvmd/plugins/msd/__init__.py54
1 files changed, 53 insertions, 1 deletions
diff --git a/kvmd/plugins/msd/__init__.py b/kvmd/plugins/msd/__init__.py
index c8015a2a..680ccece 100644
--- a/kvmd/plugins/msd/__init__.py
+++ b/kvmd/plugins/msd/__init__.py
@@ -20,6 +20,7 @@
# ========================================================================== #
+import os
import contextlib
from typing import Dict
@@ -27,6 +28,11 @@ from typing import Type
from typing import AsyncGenerator
from typing import Optional
+import aiofiles
+import aiofiles.base
+
+from ... import aiofs
+
from ...errors import OperationError
from ...errors import IsBusyError
@@ -113,7 +119,7 @@ class BaseMsd(BasePlugin):
raise NotImplementedError()
@contextlib.asynccontextmanager
- async def write_image(self, name: str) -> AsyncGenerator[None, None]: # pylint: disable=unused-argument
+ async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]: # pylint: disable=unused-argument
if self is not None: # XXX: Vulture and pylint hack
raise NotImplementedError()
yield
@@ -128,6 +134,52 @@ class BaseMsd(BasePlugin):
raise NotImplementedError()
+class MsdImageWriter:
+ def __init__(self, path: str, size: int, sync: int) -> None:
+ self.__name = os.path.basename(path)
+ self.__path = path
+ self.__size = size
+ self.__sync = sync
+
+ self.__file: Optional[aiofiles.base.AiofilesContextManager] = None
+ self.__written = 0
+ self.__unsynced = 0
+
+ def get_file(self) -> aiofiles.base.AiofilesContextManager:
+ assert self.__file is not None
+ return self.__file
+
+ def get_state(self) -> Dict:
+ return {
+ "name": self.__name,
+ "size": self.__size,
+ "written": self.__written,
+ }
+
+ async def open(self) -> "MsdImageWriter":
+ assert self.__file is None
+ self.__file = await aiofiles.open(self.__path, mode="w+b", buffering=0) # type: ignore
+ return self
+
+ async def write(self, chunk: bytes) -> int:
+ assert self.__file is not None
+
+ await self.__file.write(chunk) # type: ignore
+ self.__written += len(chunk)
+
+ self.__unsynced += len(chunk)
+ if self.__unsynced >= self.__sync:
+ await aiofs.afile_sync(self.__file)
+ self.__unsynced = 0
+
+ return self.__written
+
+ async def close(self) -> None:
+ assert self.__file is not None
+ await aiofs.afile_sync(self.__file)
+ await self.__file.close() # type: ignore
+
+
# =====
def get_msd_class(name: str) -> Type[BaseMsd]:
return get_plugin_class("msd", name) # type: ignore