diff options
author | Maxim Devaev <[email protected]> | 2022-01-18 09:25:17 +0300 |
---|---|---|
committer | Maxim Devaev <[email protected]> | 2022-01-18 09:25:17 +0300 |
commit | 3ab43edeb962e95402ab0e18c553665ffd6117f0 (patch) | |
tree | abd46db985273d9964ed09400dd608a23d2579a4 /kvmd | |
parent | 3ee1948f19ca3e1a3d396e29dc902aeec4402479 (diff) |
pikvm/kvmd#66: OCR API
Diffstat (limited to 'kvmd')
-rw-r--r-- | kvmd/apps/__init__.py | 4 | ||||
-rw-r--r-- | kvmd/apps/kvmd/__init__.py | 2 | ||||
-rw-r--r-- | kvmd/apps/kvmd/api/hid.py | 4 | ||||
-rw-r--r-- | kvmd/apps/kvmd/api/streamer.py | 54 | ||||
-rw-r--r-- | kvmd/apps/kvmd/server.py | 34 | ||||
-rw-r--r-- | kvmd/apps/kvmd/tesseract.py | 161 | ||||
-rw-r--r-- | kvmd/libc.py | 3 |
7 files changed, 251 insertions, 11 deletions
diff --git a/kvmd/apps/__init__.py b/kvmd/apps/__init__.py index c41ed9b9..4b7a6588 100644 --- a/kvmd/apps/__init__.py +++ b/kvmd/apps/__init__.py @@ -448,6 +448,10 @@ def _get_config_scheme() -> Dict: "cmd_append": Option([], type=valid_options), }, + "ocr": { + "langs": Option(["eng"], type=valid_string_list, unpack_as="default_langs"), + }, + "snapshot": { "idle_interval": Option(0.0, type=valid_float_f0), "live_interval": Option(0.0, type=valid_float_f0), diff --git a/kvmd/apps/kvmd/__init__.py b/kvmd/apps/kvmd/__init__.py index 759cfc55..afbe4493 100644 --- a/kvmd/apps/kvmd/__init__.py +++ b/kvmd/apps/kvmd/__init__.py @@ -37,6 +37,7 @@ from .logreader import LogReader from .ugpio import UserGpio from .streamer import Streamer from .snapshoter import Snapshoter +from .tesseract import TesseractOcr from .server import KvmdServer @@ -86,6 +87,7 @@ def main(argv: Optional[List[str]]=None) -> None: info_manager=InfoManager(global_config), log_reader=LogReader(), user_gpio=UserGpio(config.gpio, global_config.otg.udc), + ocr=TesseractOcr(**config.ocr._unpack()), hid=hid, atx=get_atx_class(config.atx.type)(**config.atx._unpack(ignore=["type"])), diff --git a/kvmd/apps/kvmd/api/hid.py b/kvmd/apps/kvmd/api/hid.py index 46f2206d..7f344e43 100644 --- a/kvmd/apps/kvmd/api/hid.py +++ b/kvmd/apps/kvmd/api/hid.py @@ -112,7 +112,7 @@ class HidApi: # ===== - def get_keymaps(self) -> Dict: # Ugly hack to generate hid_keymaps_state (see server.py) + async def get_keymaps(self) -> Dict: # Ugly hack to generate hid_keymaps_state (see server.py) keymaps: Set[str] = set() for keymap_name in os.listdir(self.__keymaps_dir_path): path = os.path.join(self.__keymaps_dir_path, keymap_name) @@ -127,7 +127,7 @@ class HidApi: @exposed_http("GET", "/hid/keymaps") async def __keymaps_handler(self, _: Request) -> Response: - return make_json_response(self.get_keymaps()) + return make_json_response(await self.get_keymaps()) @exposed_http("POST", "/hid/print") async def __print_handler(self, request: Request) -> Response: diff --git a/kvmd/apps/kvmd/api/streamer.py b/kvmd/apps/kvmd/api/streamer.py index c24a04b7..f25bc7d3 100644 --- a/kvmd/apps/kvmd/api/streamer.py +++ b/kvmd/apps/kvmd/api/streamer.py @@ -23,13 +23,19 @@ import io import functools +from typing import List +from typing import Dict + from aiohttp.web import Request from aiohttp.web import Response from PIL import Image as PilImage +from ....validators import check_string_in_list from ....validators.basic import valid_bool +from ....validators.basic import valid_number from ....validators.basic import valid_int_f0 +from ....validators.basic import valid_string_list from ....validators.kvm import valid_stream_quality from .... import aiotools @@ -41,11 +47,14 @@ from ..http import make_json_response from ..streamer import StreamerSnapshot from ..streamer import Streamer +from ..tesseract import TesseractOcr + # ===== class StreamerApi: - def __init__(self, streamer: Streamer) -> None: + def __init__(self, streamer: Streamer, ocr: TesseractOcr) -> None: self.__streamer = streamer + self.__ocr = ocr # ===== @@ -61,7 +70,25 @@ class StreamerApi: allow_offline=valid_bool(request.query.get("allow_offline", "false")), ) if snapshot: - if valid_bool(request.query.get("preview", "false")): + if valid_bool(request.query.get("ocr", "false")): + langs = await self.__ocr.get_available_langs() + return Response( + body=(await self.__ocr.recognize( + data=snapshot.data, + langs=valid_string_list( + arg=str(request.query.get("ocr_langs", "")).strip(), + subval=(lambda lang: check_string_in_list(lang, "OCR lang", langs)), + name="OCR langs list", + ), + left=int(valid_number(request.query.get("ocr_left", "-1"))), + top=int(valid_number(request.query.get("ocr_top", "-1"))), + right=int(valid_number(request.query.get("ocr_right", "-1"))), + bottom=int(valid_number(request.query.get("ocr_bottom", "-1"))), + )), + headers=dict(snapshot.headers), + content_type="text/plain", + ) + elif valid_bool(request.query.get("preview", "false")): data = await self.__make_preview( snapshot=snapshot, max_width=valid_int_f0(request.query.get("preview_max_width", "0")), @@ -84,6 +111,29 @@ class StreamerApi: # ===== + async def get_ocr(self) -> Dict: # XXX: Ugly hack + enabled = self.__ocr.is_available() + default: List[str] = [] + available: List[str] = [] + if enabled: + default = await self.__ocr.get_default_langs() + available = await self.__ocr.get_available_langs() + return { + "ocr": { + "enabled": enabled, + "langs": { + "default": default, + "available": available, + }, + }, + } + + @exposed_http("GET", "/streamer/ocr") + async def __ocr_handler(self, _: Request) -> Response: + return make_json_response(await self.get_ocr()) + + # ===== + async def __make_preview(self, snapshot: StreamerSnapshot, max_width: int, max_height: int, quality: int) -> bytes: if max_width == 0 and max_height == 0: max_width = snapshot.width // 5 diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py index ea526518..7ac422f8 100644 --- a/kvmd/apps/kvmd/server.py +++ b/kvmd/apps/kvmd/server.py @@ -32,6 +32,7 @@ from typing import List from typing import Dict from typing import Set from typing import Callable +from typing import Awaitable from typing import Coroutine from typing import AsyncGenerator from typing import Optional @@ -68,6 +69,7 @@ from .logreader import LogReader from .ugpio import UserGpio from .streamer import Streamer from .snapshoter import Snapshoter +from .tesseract import TesseractOcr from .http import HttpError from .http import HttpExposed @@ -147,6 +149,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins info_manager: InfoManager, log_reader: LogReader, user_gpio: UserGpio, + ocr: TesseractOcr, hid: BaseHid, atx: BaseAtx, @@ -192,6 +195,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins ] self.__hid_api = HidApi(hid, keymap_path, ignore_keys, mouse_x_range, mouse_y_range) # Ugly hack to get keymaps state + self.__streamer_api = StreamerApi(streamer, ocr) # Same hack to get ocr langs state self.__apis: List[object] = [ self, AuthApi(auth_manager), @@ -201,7 +205,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins self.__hid_api, AtxApi(atx), MsdApi(msd), - StreamerApi(streamer), + self.__streamer_api, ExportApi(info_manager, atx, user_gpio), RedfishApi(info_manager, atx), ] @@ -251,21 +255,27 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins @exposed_http("GET", "/ws") async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSocketResponse: logger = get_logger(0) + client = _WsClient( ws=aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat), stream=valid_bool(request.query.get("stream", "true")), ) await client.ws.prepare(request) await self.__register_ws_client(client) + try: - await self.__send_event(client.ws, "gpio_model_state", await self.__user_gpio.get_model()) - await self.__send_event(client.ws, "hid_keymaps_state", self.__hid_api.get_keymaps()) - await asyncio.gather(*[ - self.__send_event(client.ws, component.event_type, await component.get_state()) - for component in self.__components - if component.get_state + await self.__send_events_aws(client.ws, [ + ("gpio_model_state", self.__user_gpio.get_model()), + ("hid_keymaps_state", self.__hid_api.get_keymaps()), + ("streamer_ocr_state", self.__streamer_api.get_ocr()), + ]) + await self.__send_events_aws(client.ws, [ + (comp.event_type, comp.get_state()) + for comp in self.__components + if comp.get_state ]) await self.__send_event(client.ws, "loop", {}) + async for msg in client.ws: if msg.type == aiohttp.web.WSMsgType.TEXT: try: @@ -282,6 +292,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins logger.error("Unknown websocket event: %r", data) else: break + return client.ws finally: await self.__remove_ws_client(client) @@ -380,6 +391,15 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins logger.exception("Cleanup error on %s", comp.name) logger.info("On-Cleanup complete") + async def __send_events_aws(self, ws: aiohttp.web.WebSocketResponse, sources: List[Tuple[str, Awaitable]]) -> None: + await asyncio.gather(*[ + self.__send_event(ws, event_type, state) + for (event_type, state) in zip( + map(operator.itemgetter(0), sources), + await asyncio.gather(*map(operator.itemgetter(1), sources)), + ) + ]) + async def __send_event(self, ws: aiohttp.web.WebSocketResponse, event_type: str, event: Optional[Dict]) -> None: await ws.send_str(json.dumps({ "event_type": event_type, diff --git a/kvmd/apps/kvmd/tesseract.py b/kvmd/apps/kvmd/tesseract.py new file mode 100644 index 00000000..f467f027 --- /dev/null +++ b/kvmd/apps/kvmd/tesseract.py @@ -0,0 +1,161 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2022 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import io +import ctypes +import ctypes.util +import contextlib +import warnings + +from ctypes import POINTER +from ctypes import Structure +from ctypes import c_int +from ctypes import c_bool +from ctypes import c_char_p +from ctypes import c_void_p +from ctypes import c_char + +from typing import List +from typing import Set +from typing import Generator +from typing import Optional + +from PIL import Image as PilImage + +from ...errors import OperationError + +from ... import libc +from ... import aiotools + + +# ===== +class OcrError(OperationError): + pass + + +# ===== +class _TessBaseAPI(Structure): + pass + + +def _load_libtesseract() -> Optional[ctypes.CDLL]: + try: + path = ctypes.util.find_library("tesseract") + if not path: + raise RuntimeError("Can't find libtesseract") + lib = ctypes.CDLL(path) + for (name, restype, argtypes) in [ + ("TessBaseAPICreate", POINTER(_TessBaseAPI), []), + ("TessBaseAPIInit3", c_int, [POINTER(_TessBaseAPI), c_char_p, c_char_p]), + ("TessBaseAPISetImage", None, [POINTER(_TessBaseAPI), c_void_p, c_int, c_int, c_int, c_int]), + ("TessBaseAPIGetUTF8Text", POINTER(c_char), [POINTER(_TessBaseAPI)]), + ("TessBaseAPISetVariable", c_bool, [POINTER(_TessBaseAPI), c_char_p, c_char_p]), + ("TessBaseAPIGetAvailableLanguagesAsVector", POINTER(POINTER(c_char)), [POINTER(_TessBaseAPI)]), + ]: + func = getattr(lib, name) + if not func: + raise RuntimeError(f"Can't find libtesseract.{name}") + setattr(func, "restype", restype) + setattr(func, "argtypes", argtypes) + return lib + except Exception as err: + warnings.warn(f"Can't load libtesseract: {err}", RuntimeWarning) + return None + + +_libtess = _load_libtesseract() + + +def _tess_api(langs: List[str]) -> Generator[_TessBaseAPI, None, None]: + if not _libtess: + raise OcrError("Tesseract is not available") + api = _libtess.TessBaseAPICreate() + try: + if _libtess.TessBaseAPIInit3(api, None, "+".join(langs).encode()) != 0: + raise OcrError("Can't initialize Tesseract") + if not _libtess.TessBaseAPISetVariable(api, b"debug_file", b"/dev/null"): + raise OcrError("Can't set debug_file=/dev/null") + yield api + finally: + _libtess.TessBaseAPIDelete(api) + + +# ===== +class TesseractOcr: + def __init__(self, default_langs: List[str]) -> None: + self.__default_langs = default_langs + + def is_available(self) -> bool: + return bool(_libtess) + + async def get_default_langs(self) -> List[str]: + return list(self.__default_langs) + + async def get_available_langs(self) -> List[str]: + return (await aiotools.run_async(self.__inner_get_available_langs)) + + def __inner_get_available_langs(self) -> List[str]: + with _tess_api(["osd"]) as api: + assert _libtess + langs: Set[str] = set() + langs_ptr = _libtess.TessBaseAPIGetAvailableLanguagesAsVector(api) + if langs_ptr is not None: + index = 0 + while langs_ptr[index]: + lang = ctypes.cast(langs_ptr[index], c_char_p).value + if lang is not None: + langs.add(lang.decode()) + libc.free(langs_ptr[index]) + index += 1 + libc.free(langs_ptr) + return sorted(langs) + + async def recognize(self, data: bytes, langs: List[str], left: int, top: int, right: int, bottom: int) -> str: + if not langs: + langs = self.__default_langs + return (await aiotools.run_async(self.__inner_recognize, data, langs, left, top, right, bottom)) + + def __inner_recognize(self, data: bytes, langs: List[str], left: int, top: int, right: int, bottom: int) -> str: + with _tess_api(langs) as api: + assert _libtess + with io.BytesIO(data) as bio: + with PilImage.open(bio) as image: + if left >= 0 or top >= 0 or right >= 0 or bottom >= 0: + left = (0 if left < 0 else min(image.width, left)) + top = (0 if top < 0 else min(image.height, top)) + right = (image.width if right < 0 else min(image.width, right)) + bottom = (image.height if bottom < 0 else min(image.height, bottom)) + if left < right and top < bottom: + image.crop((left, top, right, bottom)) + + _libtess.TessBaseAPISetImage(api, image.tobytes("raw", "RGB"), image.width, image.height, 3, image.width * 3) + text_ptr = None + try: + text_ptr = _libtess.TessBaseAPIGetUTF8Text(api) + text = ctypes.cast(text_ptr, c_char_p).value + if text is None: + raise OcrError("Can't recognize image") + return text.decode("utf-8") + finally: + if text_ptr is not None: + libc.free(text_ptr) diff --git a/kvmd/libc.py b/kvmd/libc.py index a1892d94..46e6278c 100644 --- a/kvmd/libc.py +++ b/kvmd/libc.py @@ -28,6 +28,7 @@ import ctypes.util from ctypes import c_int from ctypes import c_uint32 from ctypes import c_char_p +from ctypes import c_void_p # ===== @@ -41,6 +42,7 @@ def _load_libc() -> ctypes.CDLL: ("inotify_init", c_int, []), ("inotify_add_watch", c_int, [c_int, c_char_p, c_uint32]), ("inotify_rm_watch", c_int, [c_int, c_uint32]), + ("free", c_int, [c_void_p]), ]: func = getattr(lib, name) if not func: @@ -57,3 +59,4 @@ _libc = _load_libc() inotify_init = _libc.inotify_init inotify_add_watch = _libc.inotify_add_watch inotify_rm_watch = _libc.inotify_rm_watch +free = _libc.free |