diff options
Diffstat (limited to 'kvmd/plugins/msd/__init__.py')
-rw-r--r-- | kvmd/plugins/msd/__init__.py | 54 |
1 files changed, 53 insertions, 1 deletions
diff --git a/kvmd/plugins/msd/__init__.py b/kvmd/plugins/msd/__init__.py index c8015a2a..680ccece 100644 --- a/kvmd/plugins/msd/__init__.py +++ b/kvmd/plugins/msd/__init__.py @@ -20,6 +20,7 @@ # ========================================================================== # +import os import contextlib from typing import Dict @@ -27,6 +28,11 @@ from typing import Type from typing import AsyncGenerator from typing import Optional +import aiofiles +import aiofiles.base + +from ... import aiofs + from ...errors import OperationError from ...errors import IsBusyError @@ -113,7 +119,7 @@ class BaseMsd(BasePlugin): raise NotImplementedError() @contextlib.asynccontextmanager - async def write_image(self, name: str) -> AsyncGenerator[None, None]: # pylint: disable=unused-argument + async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]: # pylint: disable=unused-argument if self is not None: # XXX: Vulture and pylint hack raise NotImplementedError() yield @@ -128,6 +134,52 @@ class BaseMsd(BasePlugin): raise NotImplementedError() +class MsdImageWriter: + def __init__(self, path: str, size: int, sync: int) -> None: + self.__name = os.path.basename(path) + self.__path = path + self.__size = size + self.__sync = sync + + self.__file: Optional[aiofiles.base.AiofilesContextManager] = None + self.__written = 0 + self.__unsynced = 0 + + def get_file(self) -> aiofiles.base.AiofilesContextManager: + assert self.__file is not None + return self.__file + + def get_state(self) -> Dict: + return { + "name": self.__name, + "size": self.__size, + "written": self.__written, + } + + async def open(self) -> "MsdImageWriter": + assert self.__file is None + self.__file = await aiofiles.open(self.__path, mode="w+b", buffering=0) # type: ignore + return self + + async def write(self, chunk: bytes) -> int: + assert self.__file is not None + + await self.__file.write(chunk) # type: ignore + self.__written += len(chunk) + + self.__unsynced += len(chunk) + if self.__unsynced >= self.__sync: + await aiofs.afile_sync(self.__file) + self.__unsynced = 0 + + return self.__written + + async def close(self) -> None: + assert self.__file is not None + await aiofs.afile_sync(self.__file) + await self.__file.close() # type: ignore + + # ===== def get_msd_class(name: str) -> Type[BaseMsd]: return get_plugin_class("msd", name) # type: ignore |