diff options
Diffstat (limited to 'kvmd')
-rw-r--r-- | kvmd/aiotools.py | 5 | ||||
-rw-r--r-- | kvmd/apps/__init__.py | 5 | ||||
-rw-r--r-- | kvmd/apps/kvmd/__init__.py | 5 | ||||
-rw-r--r-- | kvmd/apps/kvmd/api/switch.py | 164 | ||||
-rw-r--r-- | kvmd/apps/kvmd/server.py | 7 | ||||
-rw-r--r-- | kvmd/apps/kvmd/switch/__init__.py | 400 | ||||
-rw-r--r-- | kvmd/apps/kvmd/switch/chain.py | 440 | ||||
-rw-r--r-- | kvmd/apps/kvmd/switch/device.py | 196 | ||||
-rw-r--r-- | kvmd/apps/kvmd/switch/lib.py | 35 | ||||
-rw-r--r-- | kvmd/apps/kvmd/switch/proto.py | 295 | ||||
-rw-r--r-- | kvmd/apps/kvmd/switch/state.py | 355 | ||||
-rw-r--r-- | kvmd/apps/kvmd/switch/storage.py | 186 | ||||
-rw-r--r-- | kvmd/apps/kvmd/switch/types.py | 308 | ||||
-rw-r--r-- | kvmd/apps/pst/server.py | 12 | ||||
-rw-r--r-- | kvmd/clients/pst.py | 93 | ||||
-rw-r--r-- | kvmd/validators/__init__.py | 8 | ||||
-rw-r--r-- | kvmd/validators/os.py | 10 | ||||
-rw-r--r-- | kvmd/validators/switch.py | 67 |
18 files changed, 2582 insertions, 9 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py index a47c94c6..6183690f 100644 --- a/kvmd/aiotools.py +++ b/kvmd/aiotools.py @@ -45,6 +45,11 @@ async def read_file(path: str) -> str: return (await file.read()) +async def write_file(path: str, text: str) -> None: + async with aiofiles.open(path, "w") as file: + await file.write(text) + + # ===== def run(coro: Coroutine, final: (Coroutine | None)=None) -> None: # https://github.com/aio-libs/aiohttp/blob/a1d4dac1d/aiohttp/web.py#L515 diff --git a/kvmd/apps/__init__.py b/kvmd/apps/__init__.py index cfc39499..2090e5c6 100644 --- a/kvmd/apps/__init__.py +++ b/kvmd/apps/__init__.py @@ -502,6 +502,11 @@ def _get_config_scheme() -> dict: "table": Option([], type=valid_ugpio_view_table), }, }, + + "switch": { + "device": Option("/dev/kvmd-switch", type=valid_abs_path, unpack_as="device_path"), + "default_edid": Option("/etc/kvmd/switch-edid.hex", type=valid_abs_path, unpack_as="default_edid_path"), + }, }, "pst": { diff --git a/kvmd/apps/kvmd/__init__.py b/kvmd/apps/kvmd/__init__.py index 495a320f..088a62ef 100644 --- a/kvmd/apps/kvmd/__init__.py +++ b/kvmd/apps/kvmd/__init__.py @@ -35,6 +35,7 @@ from .ugpio import UserGpio from .streamer import Streamer from .snapshoter import Snapshoter from .ocr import Ocr +from .switch import Switch from .server import KvmdServer @@ -90,6 +91,10 @@ def main(argv: (list[str] | None)=None) -> None: log_reader=(LogReader() if config.log_reader.enabled else None), user_gpio=UserGpio(config.gpio, global_config.otg), ocr=Ocr(**config.ocr._unpack()), + switch=Switch( + pst_unix_path=global_config.pst.server.unix, + **config.switch._unpack(), + ), hid=hid, atx=get_atx_class(config.atx.type)(**config.atx._unpack(ignore=["type"])), diff --git a/kvmd/apps/kvmd/api/switch.py b/kvmd/apps/kvmd/api/switch.py new file mode 100644 index 00000000..bf91b83e --- /dev/null +++ b/kvmd/apps/kvmd/api/switch.py @@ -0,0 +1,164 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +from aiohttp.web import Request +from aiohttp.web import Response + +from ....htserver import exposed_http +from ....htserver import make_json_response + +from ....validators.basic import valid_bool +from ....validators.basic import valid_int_f0 +from ....validators.basic import valid_stripped_string_not_empty +from ....validators.kvm import valid_atx_power_action +from ....validators.kvm import valid_atx_button +from ....validators.switch import valid_switch_port_name +from ....validators.switch import valid_switch_edid_id +from ....validators.switch import valid_switch_edid_data +from ....validators.switch import valid_switch_color +from ....validators.switch import valid_switch_atx_click_delay + +from ..switch import Switch +from ..switch import Colors + + +# ===== +class SwitchApi: + def __init__(self, switch: Switch) -> None: + self.__switch = switch + + # ===== + + @exposed_http("GET", "/switch") + async def __state_handler(self, _: Request) -> Response: + return make_json_response(await self.__switch.get_state()) + + @exposed_http("POST", "/switch/set_active") + async def __set_active_port_handler(self, req: Request) -> Response: + port = valid_int_f0(req.query.get("port")) + await self.__switch.set_active_port(port) + return make_json_response() + + @exposed_http("POST", "/switch/set_beacon") + async def __set_beacon_handler(self, req: Request) -> Response: + on = valid_bool(req.query.get("state")) + if "port" in req.query: + port = valid_int_f0(req.query.get("port")) + await self.__switch.set_port_beacon(port, on) + elif "uplink" in req.query: + unit = valid_int_f0(req.query.get("uplink")) + await self.__switch.set_uplink_beacon(unit, on) + else: # Downlink + unit = valid_int_f0(req.query.get("downlink")) + await self.__switch.set_downlink_beacon(unit, on) + return make_json_response() + + @exposed_http("POST", "/switch/set_port_params") + async def __set_port_params(self, req: Request) -> Response: + port = valid_int_f0(req.query.get("port")) + params = { + param: validator(req.query.get(param)) + for (param, validator) in [ + ("edid_id", (lambda arg: valid_switch_edid_id(arg, allow_default=True))), + ("name", valid_switch_port_name), + ("atx_click_power_delay", valid_switch_atx_click_delay), + ("atx_click_power_long_delay", valid_switch_atx_click_delay), + ("atx_click_reset_delay", valid_switch_atx_click_delay), + ] + if req.query.get(param) is not None + } + await self.__switch.set_port_params(port, **params) # type: ignore + return make_json_response() + + @exposed_http("POST", "/switch/set_colors") + async def __set_colors(self, req: Request) -> Response: + params = { + param: valid_switch_color(req.query.get(param), allow_default=True) + for param in Colors.ROLES + if req.query.get(param) is not None + } + await self.__switch.set_colors(**params) + return make_json_response() + + # ===== + + @exposed_http("POST", "/switch/reset") + async def __reset(self, req: Request) -> Response: + unit = valid_int_f0(req.query.get("unit")) + bootloader = valid_bool(req.query.get("bootloader", False)) + await self.__switch.reboot_unit(unit, bootloader) + return make_json_response() + + # ===== + + @exposed_http("POST", "/switch/edids/create") + async def __create_edid(self, req: Request) -> Response: + name = valid_stripped_string_not_empty(req.query.get("name")) + data_hex = valid_switch_edid_data(req.query.get("data")) + edid_id = await self.__switch.create_edid(name, data_hex) + return make_json_response({"id": edid_id}) + + @exposed_http("POST", "/switch/edids/change") + async def __change_edid(self, req: Request) -> Response: + edid_id = valid_switch_edid_id(req.query.get("id"), allow_default=False) + params = { + param: validator(req.query.get(param)) + for (param, validator) in [ + ("name", valid_switch_port_name), + ("data", valid_switch_edid_data), + ] + if req.query.get(param) is not None + } + if params: + await self.__switch.change_edid(edid_id, **params) + return make_json_response() + + @exposed_http("POST", "/switch/edids/remove") + async def __remove_edid(self, req: Request) -> Response: + edid_id = valid_switch_edid_id(req.query.get("id"), allow_default=False) + await self.__switch.remove_edid(edid_id) + return make_json_response() + + # ===== + + @exposed_http("POST", "/switch/atx/power") + async def __power_handler(self, req: Request) -> Response: + port = valid_int_f0(req.query.get("port")) + action = valid_atx_power_action(req.query.get("action")) + await ({ + "on": self.__switch.atx_power_on, + "off": self.__switch.atx_power_off, + "off_hard": self.__switch.atx_power_off_hard, + "reset_hard": self.__switch.atx_power_reset_hard, + }[action])(port) + return make_json_response() + + @exposed_http("POST", "/switch/atx/click") + async def __click_handler(self, req: Request) -> Response: + port = valid_int_f0(req.query.get("port")) + button = valid_atx_button(req.query.get("button")) + await ({ + "power": self.__switch.atx_click_power, + "power_long": self.__switch.atx_click_power_long, + "reset": self.__switch.atx_click_reset, + }[button])(port) + return make_json_response() diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py index ed85bb24..92eb496c 100644 --- a/kvmd/apps/kvmd/server.py +++ b/kvmd/apps/kvmd/server.py @@ -66,6 +66,7 @@ from .ugpio import UserGpio from .streamer import Streamer from .snapshoter import Snapshoter from .ocr import Ocr +from .switch import Switch from .api.auth import AuthApi from .api.auth import check_request_auth @@ -77,6 +78,7 @@ from .api.hid import HidApi from .api.atx import AtxApi from .api.msd import MsdApi from .api.streamer import StreamerApi +from .api.switch import SwitchApi from .api.export import ExportApi from .api.redfish import RedfishApi @@ -125,7 +127,6 @@ class _Subsystem: cleanup=getattr(obj, "cleanup", None), trigger_state=getattr(obj, "trigger_state", None), poll_state=getattr(obj, "poll_state", None), - ) @@ -137,6 +138,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins __EV_STREAMER_STATE = "streamer_state" __EV_OCR_STATE = "ocr_state" __EV_INFO_STATE = "info_state" + __EV_SWITCH_STATE = "switch_state" def __init__( # pylint: disable=too-many-arguments,too-many-locals self, @@ -145,6 +147,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins log_reader: (LogReader | None), user_gpio: UserGpio, ocr: Ocr, + switch: Switch, hid: BaseHid, atx: BaseAtx, @@ -177,6 +180,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins AtxApi(atx), MsdApi(msd), StreamerApi(streamer, ocr), + SwitchApi(switch), ExportApi(info_manager, atx, user_gpio), RedfishApi(info_manager, atx), ] @@ -189,6 +193,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins _Subsystem.make(streamer, "Streamer", self.__EV_STREAMER_STATE), _Subsystem.make(ocr, "OCR", self.__EV_OCR_STATE), _Subsystem.make(info_manager, "Info manager", self.__EV_INFO_STATE), + _Subsystem.make(switch, "Switch", self.__EV_SWITCH_STATE), ] self.__streamer_notifier = aiotools.AioNotifier() diff --git a/kvmd/apps/kvmd/switch/__init__.py b/kvmd/apps/kvmd/switch/__init__.py new file mode 100644 index 00000000..49bfbd7d --- /dev/null +++ b/kvmd/apps/kvmd/switch/__init__.py @@ -0,0 +1,400 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import os +import asyncio + +from typing import AsyncGenerator + +from .lib import OperationError +from .lib import get_logger +from .lib import aiotools +from .lib import Inotify + +from .types import Edid +from .types import Edids +from .types import Color +from .types import Colors +from .types import PortNames +from .types import AtxClickPowerDelays +from .types import AtxClickPowerLongDelays +from .types import AtxClickResetDelays + +from .chain import DeviceFoundEvent +from .chain import ChainTruncatedEvent +from .chain import PortActivatedEvent +from .chain import UnitStateEvent +from .chain import UnitAtxLedsEvent +from .chain import Chain + +from .state import StateCache + +from .storage import Storage + + +# ===== +class SwitchError(Exception): + pass + + +class SwitchOperationError(OperationError, SwitchError): + pass + + +class SwitchUnknownEdidError(SwitchOperationError): + def __init__(self) -> None: + super().__init__("No specified EDID ID found") + + +# ===== +class Switch: # pylint: disable=too-many-public-methods + __X_EDIDS = "edids" + __X_COLORS = "colors" + __X_PORT_NAMES = "port_names" + __X_ATX_CP_DELAYS = "atx_cp_delays" + __X_ATX_CPL_DELAYS = "atx_cpl_delays" + __X_ATX_CR_DELAYS = "atx_cr_delays" + + __X_ALL = frozenset([ + __X_EDIDS, __X_COLORS, __X_PORT_NAMES, + __X_ATX_CP_DELAYS, __X_ATX_CPL_DELAYS, __X_ATX_CR_DELAYS, + ]) + + def __init__( + self, + device_path: str, + default_edid_path: str, + pst_unix_path: str, + ) -> None: + + self.__default_edid_path = default_edid_path + + self.__chain = Chain(device_path) + self.__cache = StateCache() + self.__storage = Storage(pst_unix_path) + + self.__lock = asyncio.Lock() + + self.__save_notifier = aiotools.AioNotifier() + + # ===== + + def __x_set_edids(self, edids: Edids, save: bool=True) -> None: + self.__chain.set_edids(edids) + self.__cache.set_edids(edids) + if save: + self.__save_notifier.notify() + + def __x_set_colors(self, colors: Colors, save: bool=True) -> None: + self.__chain.set_colors(colors) + self.__cache.set_colors(colors) + if save: + self.__save_notifier.notify() + + def __x_set_port_names(self, port_names: PortNames, save: bool=True) -> None: + self.__cache.set_port_names(port_names) + if save: + self.__save_notifier.notify() + + def __x_set_atx_cp_delays(self, delays: AtxClickPowerDelays, save: bool=True) -> None: + self.__cache.set_atx_cp_delays(delays) + if save: + self.__save_notifier.notify() + + def __x_set_atx_cpl_delays(self, delays: AtxClickPowerLongDelays, save: bool=True) -> None: + self.__cache.set_atx_cpl_delays(delays) + if save: + self.__save_notifier.notify() + + def __x_set_atx_cr_delays(self, delays: AtxClickResetDelays, save: bool=True) -> None: + self.__cache.set_atx_cr_delays(delays) + if save: + self.__save_notifier.notify() + + # ===== + + async def set_active_port(self, port: int) -> None: + self.__chain.set_active_port(port) + + # ===== + + async def set_port_beacon(self, port: int, on: bool) -> None: + self.__chain.set_port_beacon(port, on) + + async def set_uplink_beacon(self, unit: int, on: bool) -> None: + self.__chain.set_uplink_beacon(unit, on) + + async def set_downlink_beacon(self, unit: int, on: bool) -> None: + self.__chain.set_downlink_beacon(unit, on) + + # ===== + + async def atx_power_on(self, port: int) -> None: + self.__inner_atx_cp(port, False, self.__X_ATX_CP_DELAYS) + + async def atx_power_off(self, port: int) -> None: + self.__inner_atx_cp(port, True, self.__X_ATX_CP_DELAYS) + + async def atx_power_off_hard(self, port: int) -> None: + self.__inner_atx_cp(port, True, self.__X_ATX_CPL_DELAYS) + + async def atx_power_reset_hard(self, port: int) -> None: + self.__inner_atx_cr(port, True) + + async def atx_click_power(self, port: int) -> None: + self.__inner_atx_cp(port, None, self.__X_ATX_CP_DELAYS) + + async def atx_click_power_long(self, port: int) -> None: + self.__inner_atx_cp(port, None, self.__X_ATX_CPL_DELAYS) + + async def atx_click_reset(self, port: int) -> None: + self.__inner_atx_cr(port, None) + + def __inner_atx_cp(self, port: int, if_powered: (bool | None), x_delay: str) -> None: + assert x_delay in [self.__X_ATX_CP_DELAYS, self.__X_ATX_CPL_DELAYS] + delay = getattr(self.__cache, f"get_{x_delay}")()[port] + self.__chain.click_power(port, delay, if_powered) + + def __inner_atx_cr(self, port: int, if_powered: (bool | None)) -> None: + delay = self.__cache.get_atx_cr_delays()[port] + self.__chain.click_reset(port, delay, if_powered) + + # ===== + + async def create_edid(self, name: str, data_hex: str) -> str: + async with self.__lock: + edids = self.__cache.get_edids() + edid_id = edids.add(Edid.from_data(name, data_hex)) + self.__x_set_edids(edids) + return edid_id + + async def change_edid( + self, + edid_id: str, + name: (str | None)=None, + data_hex: (str | None)=None, + ) -> None: + + assert edid_id != Edids.DEFAULT_ID + async with self.__lock: + edids = self.__cache.get_edids() + if not edids.has_id(edid_id): + raise SwitchUnknownEdidError() + old = edids.get(edid_id) + name = (name or old.name) + data_hex = (data_hex or old.as_text()) + edids.set(edid_id, Edid.from_data(name, data_hex)) + self.__x_set_edids(edids) + + async def remove_edid(self, edid_id: str) -> None: + assert edid_id != Edids.DEFAULT_ID + async with self.__lock: + edids = self.__cache.get_edids() + if not edids.has_id(edid_id): + raise SwitchUnknownEdidError() + edids.remove(edid_id) + self.__x_set_edids(edids) + + # ===== + + async def set_colors(self, **values: str) -> None: + async with self.__lock: + old = self.__cache.get_colors() + new = {} + for role in Colors.ROLES: + if role in values: + if values[role] != "default": + new[role] = Color.from_text(values[role]) + # else reset to default + else: + new[role] = getattr(old, role) + self.__x_set_colors(Colors(**new)) # type: ignore + + # ===== + + async def set_port_params( + self, + port: int, + edid_id: (str | None)=None, + name: (str | None)=None, + atx_click_power_delay: (float | None)=None, + atx_click_power_long_delay: (float | None)=None, + atx_click_reset_delay: (float | None)=None, + ) -> None: + + async with self.__lock: + if edid_id is not None: + edids = self.__cache.get_edids() + if not edids.has_id(edid_id): + raise SwitchUnknownEdidError() + edids.assign(port, edid_id) + self.__x_set_edids(edids) + + for (key, value) in [ + (self.__X_PORT_NAMES, name), + (self.__X_ATX_CP_DELAYS, atx_click_power_delay), + (self.__X_ATX_CPL_DELAYS, atx_click_power_long_delay), + (self.__X_ATX_CR_DELAYS, atx_click_reset_delay), + ]: + if value is not None: + new = getattr(self.__cache, f"get_{key}")() + new[port] = (value or None) # None == reset to default + getattr(self, f"_Switch__x_set_{key}")(new) + + # ===== + + async def reboot_unit(self, unit: int, bootloader: bool) -> None: + self.__chain.reboot_unit(unit, bootloader) + + # ===== + + async def get_state(self) -> dict: + return self.__cache.get_state() + + async def trigger_state(self) -> None: + await self.__cache.trigger_state() + + async def poll_state(self) -> AsyncGenerator[dict, None]: + async for state in self.__cache.poll_state(): + yield state + + # ===== + + async def systask(self) -> None: + tasks = [ + asyncio.create_task(self.__systask_events()), + asyncio.create_task(self.__systask_default_edid()), + asyncio.create_task(self.__systask_storage()), + ] + try: + await asyncio.gather(*tasks) + except Exception: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + async def __systask_events(self) -> None: + async for event in self.__chain.poll_events(): + match event: + case DeviceFoundEvent(): + await self.__load_configs() + case ChainTruncatedEvent(): + self.__cache.truncate(event.units) + case PortActivatedEvent(): + self.__cache.update_active_port(event.port) + case UnitStateEvent(): + self.__cache.update_unit_state(event.unit, event.state) + case UnitAtxLedsEvent(): + self.__cache.update_unit_atx_leds(event.unit, event.atx_leds) + + async def __load_configs(self) -> None: + async with self.__lock: + try: + async with self.__storage.readable() as ctx: + values = { + key: await getattr(ctx, f"read_{key}")() + for key in self.__X_ALL + } + data_hex = await aiotools.read_file(self.__default_edid_path) + values["edids"].set_default(data_hex) + except Exception: + get_logger(0).exception("Can't load configs") + else: + for (key, value) in values.items(): + func = getattr(self, f"_Switch__x_set_{key}") + if isinstance(value, tuple): + func(*value, save=False) + else: + func(value, save=False) + self.__chain.set_actual(True) + + async def __systask_default_edid(self) -> None: + logger = get_logger(0) + async for _ in self.__poll_default_edid(): + async with self.__lock: + edids = self.__cache.get_edids() + try: + data_hex = await aiotools.read_file(self.__default_edid_path) + edids.set_default(data_hex) + except Exception: + logger.exception("Can't read default EDID, ignoring ...") + else: + self.__x_set_edids(edids, save=False) + + async def __poll_default_edid(self) -> AsyncGenerator[None, None]: + logger = get_logger(0) + while True: + while not os.path.exists(self.__default_edid_path): + await asyncio.sleep(5) + try: + with Inotify() as inotify: + await inotify.watch_all_changes(self.__default_edid_path) + if os.path.islink(self.__default_edid_path): + await inotify.watch_all_changes(os.path.realpath(self.__default_edid_path)) + yield None + while True: + need_restart = False + need_notify = False + for event in (await inotify.get_series(timeout=1)): + need_notify = True + if event.restart: + logger.warning("Got fatal inotify event: %s; reinitializing ...", event) + need_restart = True + break + if need_restart: + break + if need_notify: + yield None + except Exception: + logger.exception("Unexpected watcher error") + await asyncio.sleep(1) + + async def __systask_storage(self) -> None: + # При остановке KVMD можем не успеть записать, ну да пофиг + prevs = dict.fromkeys(self.__X_ALL) + while True: + await self.__save_notifier.wait() + while (await self.__save_notifier.wait(5)): + pass + while True: + try: + async with self.__lock: + write = { + key: new + for (key, old) in prevs.items() + if (new := getattr(self.__cache, f"get_{key}")()) != old + } + if write: + async with self.__storage.writable() as ctx: + for (key, new) in write.items(): + func = getattr(ctx, f"write_{key}") + if isinstance(new, tuple): + await func(*new) + else: + await func(new) + prevs[key] = new + except Exception: + get_logger(0).exception("Unexpected storage error") + await asyncio.sleep(5) + else: + break diff --git a/kvmd/apps/kvmd/switch/chain.py b/kvmd/apps/kvmd/switch/chain.py new file mode 100644 index 00000000..8e4d94eb --- /dev/null +++ b/kvmd/apps/kvmd/switch/chain.py @@ -0,0 +1,440 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import multiprocessing +import queue +import select +import dataclasses +import time + +from typing import AsyncGenerator + +from .lib import get_logger +from .lib import tools +from .lib import aiotools +from .lib import aioproc + +from .types import Edids +from .types import Colors + +from .proto import Response +from .proto import UnitState +from .proto import UnitAtxLeds + +from .device import Device +from .device import DeviceError + + +# ===== +class _BaseCmd: + pass + + [email protected](frozen=True) +class _CmdSetActual(_BaseCmd): + actual: bool + + [email protected](frozen=True) +class _CmdSetActivePort(_BaseCmd): + port: int + + def __post_init__(self) -> None: + assert self.port >= 0 + + [email protected](frozen=True) +class _CmdSetPortBeacon(_BaseCmd): + port: int + on: bool + + [email protected](frozen=True) +class _CmdSetUnitBeacon(_BaseCmd): + unit: int + on: bool + downlink: bool + + [email protected](frozen=True) +class _CmdSetEdids(_BaseCmd): + edids: Edids + + [email protected](frozen=True) +class _CmdSetColors(_BaseCmd): + colors: Colors + + [email protected](frozen=True) +class _CmdAtxClick(_BaseCmd): + port: int + delay: float + reset: bool + if_powered: (bool | None) + + def __post_init__(self) -> None: + assert self.port >= 0 + assert 0.001 <= self.delay <= (0xFFFF / 1000) + + [email protected](frozen=True) +class _CmdRebootUnit(_BaseCmd): + unit: int + bootloader: bool + + def __post_init__(self) -> None: + assert self.unit >= 0 + + +class _UnitContext: + __TIMEOUT = 5.0 + + def __init__(self) -> None: + self.state: (UnitState | None) = None + self.atx_leds: (UnitAtxLeds | None) = None + self.__rid = -1 + self.__deadline_ts = -1.0 + + def can_be_changed(self) -> bool: + return ( + self.state is not None + and not self.state.flags.changing_busy + and self.changing_rid < 0 + ) + + # ===== + + @property + def changing_rid(self) -> int: + if self.__deadline_ts >= 0 and self.__deadline_ts < time.monotonic(): + self.__rid = -1 + self.__deadline_ts = -1 + return self.__rid + + @changing_rid.setter + def changing_rid(self, rid: int) -> None: + self.__rid = rid + self.__deadline_ts = ((time.monotonic() + self.__TIMEOUT) if rid >= 0 else -1) + + # ===== + + def is_atx_allowed(self, ch: int) -> tuple[bool, bool]: # (allowed, power_led) + if self.state is None or self.atx_leds is None: + return (False, False) + return ((not self.state.atx_busy[ch]), self.atx_leds.power[ch]) + + +# ===== +class BaseEvent: + pass + + +class DeviceFoundEvent(BaseEvent): + pass + + [email protected](frozen=True) +class ChainTruncatedEvent(BaseEvent): + units: int + + [email protected](frozen=True) +class PortActivatedEvent(BaseEvent): + port: int + + [email protected](frozen=True) +class UnitStateEvent(BaseEvent): + unit: int + state: UnitState + + [email protected](frozen=True) +class UnitAtxLedsEvent(BaseEvent): + unit: int + atx_leds: UnitAtxLeds + + +# ===== +class Chain: # pylint: disable=too-many-instance-attributes + def __init__(self, device_path: str) -> None: + self.__device = Device(device_path) + + self.__actual = False + + self.__edids = Edids() + + self.__colors = Colors() + + self.__units: list[_UnitContext] = [] + self.__active_port = -1 + + self.__cmd_queue: "multiprocessing.Queue[_BaseCmd]" = multiprocessing.Queue() + self.__events_queue: "multiprocessing.Queue[BaseEvent]" = multiprocessing.Queue() + + self.__stop_event = multiprocessing.Event() + + def set_actual(self, actual: bool) -> None: + # Флаг разрешения синхронизации EDID и прочих чувствительных вещей + self.__queue_cmd(_CmdSetActual(actual)) + + # ===== + + def set_active_port(self, port: int) -> None: + self.__queue_cmd(_CmdSetActivePort(port)) + + # ===== + + def set_port_beacon(self, port: int, on: bool) -> None: + self.__queue_cmd(_CmdSetPortBeacon(port, on)) + + def set_uplink_beacon(self, unit: int, on: bool) -> None: + self.__queue_cmd(_CmdSetUnitBeacon(unit, on, downlink=False)) + + def set_downlink_beacon(self, unit: int, on: bool) -> None: + self.__queue_cmd(_CmdSetUnitBeacon(unit, on, downlink=True)) + + # ===== + + def set_edids(self, edids: Edids) -> None: + self.__queue_cmd(_CmdSetEdids(edids)) # Will be copied because of multiprocessing.Queue() + + def set_colors(self, colors: Colors) -> None: + self.__queue_cmd(_CmdSetColors(colors)) + + # ===== + + def click_power(self, port: int, delay: float, if_powered: (bool | None)) -> None: + self.__queue_cmd(_CmdAtxClick(port, delay, reset=False, if_powered=if_powered)) + + def click_reset(self, port: int, delay: float, if_powered: (bool | None)) -> None: + self.__queue_cmd(_CmdAtxClick(port, delay, reset=True, if_powered=if_powered)) + + # ===== + + def reboot_unit(self, unit: int, bootloader: bool) -> None: + self.__queue_cmd(_CmdRebootUnit(unit, bootloader)) + + # ===== + + async def poll_events(self) -> AsyncGenerator[BaseEvent, None]: + proc = multiprocessing.Process(target=self.__subprocess, daemon=True) + try: + proc.start() + while True: + try: + yield (await aiotools.run_async(self.__events_queue.get, True, 0.1)) + except queue.Empty: + pass + finally: + if proc.is_alive(): + self.__stop_event.set() + if proc.is_alive() or proc.exitcode is not None: + await aiotools.run_async(proc.join) + + # ===== + + def __queue_cmd(self, cmd: _BaseCmd) -> None: + if not self.__stop_event.is_set(): + self.__cmd_queue.put_nowait(cmd) + + def __queue_event(self, event: BaseEvent) -> None: + if not self.__stop_event.is_set(): + self.__events_queue.put_nowait(event) + + def __subprocess(self) -> None: + logger = aioproc.settle("Switch", "switch") + no_device_reported = False + while True: + try: + if self.__device.has_device(): + no_device_reported = False + with self.__device: + logger.info("Switch found") + self.__queue_event(DeviceFoundEvent()) + self.__main_loop() + elif not no_device_reported: + self.__queue_event(ChainTruncatedEvent(0)) + logger.info("Switch is missing") + no_device_reported = True + except DeviceError as ex: + logger.error("%s", tools.efmt(ex)) + except Exception: + logger.exception("Unexpected error in the Switch loop") + tools.clear_queue(self.__cmd_queue) + if self.__stop_event.is_set(): + break + time.sleep(1) + + def __main_loop(self) -> None: + self.__device.request_state() + self.__device.request_atx_leds() + while not self.__stop_event.is_set(): + if self.__select(): + for resp in self.__device.read_all(): + self.__update_units(resp) + self.__adjust_start_port() + self.__finish_changing_request(resp) + self.__consume_commands() + self.__ensure_config() + + def __select(self) -> bool: + try: + return bool(select.select([ + self.__device.get_fd(), + self.__cmd_queue._reader, # type: ignore # pylint: disable=protected-access + ], [], [], 1)[0]) + except Exception as ex: + raise DeviceError(ex) + + def __consume_commands(self) -> None: + while not self.__cmd_queue.empty(): + cmd = self.__cmd_queue.get() + match cmd: + case _CmdSetActual(): + self.__actual = cmd.actual + + case _CmdSetActivePort(): + # Может быть вызвано изнутри при синхронизации + self.__active_port = cmd.port + self.__queue_event(PortActivatedEvent(self.__active_port)) + + case _CmdSetPortBeacon(): + (unit, ch) = self.get_real_unit_channel(cmd.port) + self.__device.request_beacon(unit, ch, cmd.on) + + case _CmdSetUnitBeacon(): + ch = (4 if cmd.downlink else 5) + self.__device.request_beacon(cmd.unit, ch, cmd.on) + + case _CmdAtxClick(): + (unit, ch) = self.get_real_unit_channel(cmd.port) + if unit < len(self.__units): + (allowed, powered) = self.__units[unit].is_atx_allowed(ch) + if allowed and (cmd.if_powered is None or cmd.if_powered == powered): + delay_ms = min(int(cmd.delay * 1000), 0xFFFF) + if cmd.reset: + self.__device.request_atx_cr(unit, ch, delay_ms) + else: + self.__device.request_atx_cp(unit, ch, delay_ms) + + case _CmdSetEdids(): + self.__edids = cmd.edids + + case _CmdSetColors(): + self.__colors = cmd.colors + + case _CmdRebootUnit(): + self.__device.request_reboot(cmd.unit, cmd.bootloader) + + def __update_units(self, resp: Response) -> None: + units = resp.header.unit + 1 + while len(self.__units) < units: + self.__units.append(_UnitContext()) + + match resp.body: + case UnitState(): + if not resp.body.flags.has_downlink and len(self.__units) > units: + del self.__units[units:] + self.__queue_event(ChainTruncatedEvent(units)) + self.__units[resp.header.unit].state = resp.body + self.__queue_event(UnitStateEvent(resp.header.unit, resp.body)) + + case UnitAtxLeds(): + self.__units[resp.header.unit].atx_leds = resp.body + self.__queue_event(UnitAtxLedsEvent(resp.header.unit, resp.body)) + + def __adjust_start_port(self) -> None: + if self.__active_port < 0: + for (unit, ctx) in enumerate(self.__units): + if ctx.state is not None and ctx.state.ch < 4: + # Trigger queue select() + port = self.get_virtual_port(unit, ctx.state.ch) + get_logger().info("Found an active port %d on [%d:%d]: Syncing ...", + port, unit, ctx.state.ch) + self.set_active_port(port) + break + + def __finish_changing_request(self, resp: Response) -> None: + if self.__units[resp.header.unit].changing_rid == resp.header.rid: + self.__units[resp.header.unit].changing_rid = -1 + + # ===== + + def __ensure_config(self) -> None: + for (unit, ctx) in enumerate(self.__units): + if ctx.state is not None: + self.__ensure_config_port(unit, ctx) + if self.__actual: + self.__ensure_config_edids(unit, ctx) + self.__ensure_config_colors(unit, ctx) + + def __ensure_config_port(self, unit: int, ctx: _UnitContext) -> None: + assert ctx.state is not None + if self.__active_port >= 0 and ctx.can_be_changed(): + ch = self.get_unit_target_channel(unit, self.__active_port) + if ctx.state.ch != ch: + get_logger().info("Switching for active port %d: [%d:%d] -> [%d:%d] ...", + self.__active_port, unit, ctx.state.ch, unit, ch) + ctx.changing_rid = self.__device.request_switch(unit, ch) + + def __ensure_config_edids(self, unit: int, ctx: _UnitContext) -> None: + assert self.__actual + assert ctx.state is not None + if ctx.can_be_changed(): + for ch in range(4): + port = self.get_virtual_port(unit, ch) + edid = self.__edids.get_edid_for_port(port) + if not ctx.state.compare_edid(ch, edid): + get_logger().info("Changing EDID on port %d on [%d:%d]: %d/%d -> %d/%d (%s) ...", + port, unit, ch, + ctx.state.video_crc[ch], ctx.state.video_edid[ch], + edid.crc, edid.valid, edid.name) + ctx.changing_rid = self.__device.request_set_edid(unit, ch, edid) + break # Busy globally + + def __ensure_config_colors(self, unit: int, ctx: _UnitContext) -> None: + assert self.__actual + assert ctx.state is not None + for np in range(6): + if self.__colors.crc != ctx.state.np_crc[np]: + # get_logger().info("Changing colors on NP [%d:%d]: %d -> %d ...", + # unit, np, ctx.state.np_crc[np], self.__colors.crc) + self.__device.request_set_colors(unit, np, self.__colors) + + # ===== + + @classmethod + def get_real_unit_channel(cls, port: int) -> tuple[int, int]: + return (port // 4, port % 4) + + @classmethod + def get_unit_target_channel(cls, unit: int, port: int) -> int: + (t_unit, t_ch) = cls.get_real_unit_channel(port) + if unit != t_unit: + t_ch = 4 + return t_ch + + @classmethod + def get_virtual_port(cls, unit: int, ch: int) -> int: + return (unit * 4) + ch diff --git a/kvmd/apps/kvmd/switch/device.py b/kvmd/apps/kvmd/switch/device.py new file mode 100644 index 00000000..b56cc406 --- /dev/null +++ b/kvmd/apps/kvmd/switch/device.py @@ -0,0 +1,196 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import os +import random +import types + +import serial + +from .lib import tools + +from .types import Edid +from .types import Colors + +from .proto import Packable +from .proto import Request +from .proto import Response +from .proto import Header + +from .proto import BodySwitch +from .proto import BodySetBeacon +from .proto import BodyAtxClick +from .proto import BodySetEdid +from .proto import BodyClearEdid +from .proto import BodySetColors + + +# ===== +class DeviceError(Exception): + def __init__(self, ex: Exception): + super().__init__(tools.efmt(ex)) + + +class Device: + __SPEED = 115200 + __TIMEOUT = 5.0 + + def __init__(self, device_path: str) -> None: + self.__device_path = device_path + self.__rid = random.randint(1, 0xFFFF) + self.__tty: (serial.Serial | None) = None + self.__buf: bytes = b"" + + def __enter__(self) -> "Device": + try: + self.__tty = serial.Serial( + self.__device_path, + baudrate=self.__SPEED, + timeout=self.__TIMEOUT, + ) + except Exception as ex: + raise DeviceError(ex) + return self + + def __exit__( + self, + _exc_type: type[BaseException], + _exc: BaseException, + _tb: types.TracebackType, + ) -> None: + + if self.__tty is not None: + try: + self.__tty.close() + except Exception: + pass + self.__tty = None + + def has_device(self) -> bool: + return os.path.exists(self.__device_path) + + def get_fd(self) -> int: + assert self.__tty is not None + return self.__tty.fd + + def read_all(self) -> list[Response]: + assert self.__tty is not None + try: + if not self.__tty.in_waiting: + return [] + self.__buf += self.__tty.read_all() + except Exception as ex: + raise DeviceError(ex) + + results: list[Response] = [] + while True: + try: + begin = self.__buf.index(0xF1) + except ValueError: + break + try: + end = self.__buf.index(0xF2, begin) + except ValueError: + break + msg = self.__buf[begin + 1:end] + if 0xF1 in msg: + # raise RuntimeError(f"Found 0xF1 inside the message: {msg!r}") + break + self.__buf = self.__buf[end + 1:] + msg = self.__unescape(msg) + resp = Response.unpack(msg) + if resp is not None: + results.append(resp) + return results + + def __unescape(self, msg: bytes) -> bytes: + if 0xF0 not in msg: + return msg + unesc: list[int] = [] + esc = False + for ch in msg: + if ch == 0xF0: + esc = True + else: + if esc: + ch ^= 0xFF + esc = False + unesc.append(ch) + return bytes(unesc) + + def request_reboot(self, unit: int, bootloader: bool) -> int: + return self.__send_request((Header.BOOTLOADER if bootloader else Header.REBOOT), unit, None) + + def request_state(self) -> int: + return self.__send_request(Header.STATE, 0xFF, None) + + def request_switch(self, unit: int, ch: int) -> int: + return self.__send_request(Header.SWITCH, unit, BodySwitch(ch)) + + def request_beacon(self, unit: int, ch: int, on: bool) -> int: + return self.__send_request(Header.BEACON, unit, BodySetBeacon(ch, on)) + + def request_atx_leds(self) -> int: + return self.__send_request(Header.ATX_LEDS, 0xFF, None) + + def request_atx_cp(self, unit: int, ch: int, delay_ms: int) -> int: + return self.__send_request(Header.ATX_CLICK, unit, BodyAtxClick(ch, BodyAtxClick.POWER, delay_ms)) + + def request_atx_cr(self, unit: int, ch: int, delay_ms: int) -> int: + return self.__send_request(Header.ATX_CLICK, unit, BodyAtxClick(ch, BodyAtxClick.RESET, delay_ms)) + + def request_set_edid(self, unit: int, ch: int, edid: Edid) -> int: + if edid.valid: + return self.__send_request(Header.SET_EDID, unit, BodySetEdid(ch, edid)) + return self.__send_request(Header.CLEAR_EDID, unit, BodyClearEdid(ch)) + + def request_set_colors(self, unit: int, ch: int, colors: Colors) -> int: + return self.__send_request(Header.SET_COLORS, unit, BodySetColors(ch, colors)) + + def __send_request(self, op: int, unit: int, body: (Packable | None)) -> int: + assert self.__tty is not None + req = Request(Header( + proto=1, + rid=self.__get_next_rid(), + op=op, + unit=unit, + ), body) + data: list[int] = [0xF1] + for ch in req.pack(): + if 0xF0 <= ch <= 0xF2: + data.append(0xF0) + ch ^= 0xFF + data.append(ch) + data.append(0xF2) + try: + self.__tty.write(bytes(data)) + self.__tty.flush() + except Exception as ex: + raise DeviceError(ex) + return req.header.rid + + def __get_next_rid(self) -> int: + rid = self.__rid + self.__rid += 1 + if self.__rid > 0xFFFF: + self.__rid = 1 + return rid diff --git a/kvmd/apps/kvmd/switch/lib.py b/kvmd/apps/kvmd/switch/lib.py new file mode 100644 index 00000000..4ef2647e --- /dev/null +++ b/kvmd/apps/kvmd/switch/lib.py @@ -0,0 +1,35 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +# pylint: disable=unused-import + +from ....logging import get_logger # noqa: F401 + +from .... import tools # noqa: F401 +from .... import aiotools # noqa: F401 +from .... import aioproc # noqa: F401 +from .... import bitbang # noqa: F401 +from .... import htclient # noqa: F401 +from ....inotify import Inotify # noqa: F401 +from ....errors import OperationError # noqa: F401 +from ....edid import EdidNoBlockError as ParsedEdidNoBlockError # noqa: F401 +from ....edid import Edid as ParsedEdid # noqa: F401 diff --git a/kvmd/apps/kvmd/switch/proto.py b/kvmd/apps/kvmd/switch/proto.py new file mode 100644 index 00000000..d4f43f84 --- /dev/null +++ b/kvmd/apps/kvmd/switch/proto.py @@ -0,0 +1,295 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import struct +import dataclasses + +from typing import Optional + +from .types import Edid +from .types import Colors + + +# ===== +class Packable: + def pack(self) -> bytes: + raise NotImplementedError() + + +class Unpackable: + @classmethod + def unpack(cls, data: bytes, offset: int=0) -> "Unpackable": + raise NotImplementedError() + + +# ===== [email protected](frozen=True) +class Header(Packable, Unpackable): + proto: int + rid: int + op: int + unit: int + + NAK = 0 + BOOTLOADER = 2 + REBOOT = 3 + STATE = 4 + SWITCH = 5 + BEACON = 6 + ATX_LEDS = 7 + ATX_CLICK = 8 + SET_EDID = 9 + CLEAR_EDID = 10 + SET_COLORS = 12 + + __struct = struct.Struct("<BHBB") + + SIZE = __struct.size + + def pack(self) -> bytes: + return self.__struct.pack(self.proto, self.rid, self.op, self.unit) + + @classmethod + def unpack(cls, data: bytes, offset: int=0) -> "Header": + return Header(*cls.__struct.unpack_from(data, offset=offset)) + + [email protected](frozen=True) +class Nak(Unpackable): + reason: int + + INVALID_COMMAND = 0 + BUSY = 1 + NO_DOWNLINK = 2 + DOWNLINK_OVERFLOW = 3 + + __struct = struct.Struct("<B") + + @classmethod + def unpack(cls, data: bytes, offset: int=0) -> "Nak": + return Nak(*cls.__struct.unpack_from(data, offset=offset)) + + [email protected](frozen=True) +class UnitFlags: + changing_busy: bool + flashing_busy: bool + has_downlink: bool + + [email protected](frozen=True) +class UnitState(Unpackable): # pylint: disable=too-many-instance-attributes + sw_version: int + hw_version: int + flags: UnitFlags + ch: int + beacons: tuple[bool, bool, bool, bool, bool, bool] + np_crc: tuple[int, int, int, int, int, int] + video_5v_sens: tuple[bool, bool, bool, bool, bool] + video_hpd: tuple[bool, bool, bool, bool, bool] + video_edid: tuple[bool, bool, bool, bool] + video_crc: tuple[int, int, int, int] + usb_5v_sens: tuple[bool, bool, bool, bool] + atx_busy: tuple[bool, bool, bool, bool] + + __struct = struct.Struct("<HHHBBHHHHHHBBBHHHHBxB30x") + + def compare_edid(self, ch: int, edid: Optional["Edid"]) -> bool: + if edid is None: + # Сойдет любой невалидный EDID + return (not self.video_edid[ch]) + return ( + self.video_edid[ch] == edid.valid + and self.video_crc[ch] == edid.crc + ) + + @classmethod + def unpack(cls, data: bytes, offset: int=0) -> "UnitState": # pylint: disable=too-many-locals + ( + sw_version, hw_version, flags, ch, + beacons, nc0, nc1, nc2, nc3, nc4, nc5, + video_5v_sens, video_hpd, video_edid, vc0, vc1, vc2, vc3, + usb_5v_sens, atx_busy, + ) = cls.__struct.unpack_from(data, offset=offset) + return UnitState( + sw_version, + hw_version, + flags=UnitFlags( + changing_busy=bool(flags & 0x80), + flashing_busy=bool(flags & 0x40), + has_downlink=bool(flags & 0x02), + ), + ch=ch, + beacons=cls.__make_flags6(beacons), + np_crc=(nc0, nc1, nc2, nc3, nc4, nc5), + video_5v_sens=cls.__make_flags5(video_5v_sens), + video_hpd=cls.__make_flags5(video_hpd), + video_edid=cls.__make_flags4(video_edid), + video_crc=(vc0, vc1, vc2, vc3), + usb_5v_sens=cls.__make_flags4(usb_5v_sens), + atx_busy=cls.__make_flags4(atx_busy), + ) + + @classmethod + def __make_flags6(cls, mask: int) -> tuple[bool, bool, bool, bool, bool, bool]: + return ( + bool(mask & 0x01), bool(mask & 0x02), bool(mask & 0x04), + bool(mask & 0x08), bool(mask & 0x10), bool(mask & 0x20), + ) + + @classmethod + def __make_flags5(cls, mask: int) -> tuple[bool, bool, bool, bool, bool]: + return ( + bool(mask & 0x01), bool(mask & 0x02), bool(mask & 0x04), + bool(mask & 0x08), bool(mask & 0x10), + ) + + @classmethod + def __make_flags4(cls, mask: int) -> tuple[bool, bool, bool, bool]: + return (bool(mask & 0x01), bool(mask & 0x02), bool(mask & 0x04), bool(mask & 0x08)) + + [email protected](frozen=True) +class UnitAtxLeds(Unpackable): + power: tuple[bool, bool, bool, bool] + hdd: tuple[bool, bool, bool, bool] + + __struct = struct.Struct("<B") + + @classmethod + def unpack(cls, data: bytes, offset: int=0) -> "UnitAtxLeds": + (mask,) = cls.__struct.unpack_from(data, offset=offset) + return UnitAtxLeds( + power=(bool(mask & 0x01), bool(mask & 0x02), bool(mask & 0x04), bool(mask & 0x08)), + hdd=(bool(mask & 0x10), bool(mask & 0x20), bool(mask & 0x40), bool(mask & 0x80)), + ) + + +# ===== [email protected](frozen=True) +class BodySwitch(Packable): + ch: int + + def __post_init__(self) -> None: + assert 0 <= self.ch <= 4 + + def pack(self) -> bytes: + return self.ch.to_bytes() + + [email protected](frozen=True) +class BodySetBeacon(Packable): + ch: int + on: bool + + def __post_init__(self) -> None: + assert 0 <= self.ch <= 5 + + def pack(self) -> bytes: + return self.ch.to_bytes() + self.on.to_bytes() + + [email protected](frozen=True) +class BodyAtxClick(Packable): + ch: int + action: int + delay_ms: int + + POWER = 0 + RESET = 1 + + __struct = struct.Struct("<BBH") + + def __post_init__(self) -> None: + assert 0 <= self.ch <= 3 + assert self.action in [self.POWER, self.RESET] + assert 1 <= self.delay_ms <= 0xFFFF + + def pack(self) -> bytes: + return self.__struct.pack(self.ch, self.action, self.delay_ms) + + [email protected](frozen=True) +class BodySetEdid(Packable): + ch: int + edid: Edid + + def __post_init__(self) -> None: + assert 0 <= self.ch <= 3 + + def pack(self) -> bytes: + return self.ch.to_bytes() + self.edid.pack() + + [email protected](frozen=True) +class BodyClearEdid(Packable): + ch: int + + def __post_init__(self) -> None: + assert 0 <= self.ch <= 3 + + def pack(self) -> bytes: + return self.ch.to_bytes() + + [email protected](frozen=True) +class BodySetColors(Packable): + ch: int + colors: Colors + + def __post_init__(self) -> None: + assert 0 <= self.ch <= 5 + + def pack(self) -> bytes: + return self.ch.to_bytes() + self.colors.pack() + + +# ===== [email protected](frozen=True) +class Request: + header: Header + body: (Packable | None) = dataclasses.field(default=None) + + def pack(self) -> bytes: + msg = self.header.pack() + if self.body is not None: + msg += self.body.pack() + return msg + + [email protected](frozen=True) +class Response: + header: Header + body: Unpackable + + @classmethod + def unpack(cls, msg: bytes) -> Optional["Response"]: + header = Header.unpack(msg) + match header.op: + case Header.NAK: + return Response(header, Nak.unpack(msg, Header.SIZE)) + case Header.STATE: + return Response(header, UnitState.unpack(msg, Header.SIZE)) + case Header.ATX_LEDS: + return Response(header, UnitAtxLeds.unpack(msg, Header.SIZE)) + # raise RuntimeError(f"Unknown OP in the header: {header!r}") + return None diff --git a/kvmd/apps/kvmd/switch/state.py b/kvmd/apps/kvmd/switch/state.py new file mode 100644 index 00000000..626cdfe1 --- /dev/null +++ b/kvmd/apps/kvmd/switch/state.py @@ -0,0 +1,355 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import asyncio +import dataclasses +import time + +from typing import AsyncGenerator + +from .types import Edids +from .types import Color +from .types import Colors +from .types import PortNames +from .types import AtxClickPowerDelays +from .types import AtxClickPowerLongDelays +from .types import AtxClickResetDelays + +from .proto import UnitState +from .proto import UnitAtxLeds + +from .chain import Chain + + +# ===== +class _UnitInfo: + state: (UnitState | None) = dataclasses.field(default=None) + atx_leds: (UnitAtxLeds | None) = dataclasses.field(default=None) + + +# ===== +class StateCache: # pylint: disable=too-many-instance-attributes + __FULL = 0xFFFF + __SUMMARY = 0x01 + __EDIDS = 0x02 + __COLORS = 0x04 + __VIDEO = 0x08 + __USB = 0x10 + __BEACONS = 0x20 + __ATX = 0x40 + + def __init__(self) -> None: + self.__edids = Edids() + self.__colors = Colors() + self.__port_names = PortNames({}) + self.__atx_cp_delays = AtxClickPowerDelays({}) + self.__atx_cpl_delays = AtxClickPowerLongDelays({}) + self.__atx_cr_delays = AtxClickResetDelays({}) + + self.__units: list[_UnitInfo] = [] + self.__active_port = -1 + self.__synced = True + + self.__queue: "asyncio.Queue[int]" = asyncio.Queue() + + def get_edids(self) -> Edids: + return self.__edids.copy() + + def get_colors(self) -> Colors: + return self.__colors + + def get_port_names(self) -> PortNames: + return self.__port_names.copy() + + def get_atx_cp_delays(self) -> AtxClickPowerDelays: + return self.__atx_cp_delays.copy() + + def get_atx_cpl_delays(self) -> AtxClickPowerLongDelays: + return self.__atx_cpl_delays.copy() + + def get_atx_cr_delays(self) -> AtxClickResetDelays: + return self.__atx_cr_delays.copy() + + # ===== + + def get_state(self) -> dict: + return self.__inner_get_state(self.__FULL) + + async def trigger_state(self) -> None: + self.__bump_state(self.__FULL) + + async def poll_state(self) -> AsyncGenerator[dict, None]: + atx_ts: float = 0 + while True: + try: + mask = await asyncio.wait_for(self.__queue.get(), timeout=0.1) + except TimeoutError: + mask = 0 + + if mask == self.__ATX: + # Откладываем единичное новое событие ATX, чтобы аккумулировать с нескольких свичей + if atx_ts == 0: + atx_ts = time.monotonic() + 0.2 + continue + elif atx_ts >= time.monotonic(): + continue + # ... Ну или разрешаем отправить, если оно уже достаточно мариновалось + elif mask == 0 and atx_ts > time.monotonic(): + # Разрешаем отправить отложенное + mask = self.__ATX + atx_ts = 0 + elif mask & self.__ATX: + # Комплексное событие всегда должно обрабатываться сразу + atx_ts = 0 + + if mask != 0: + yield self.__inner_get_state(mask) + + def __inner_get_state(self, mask: int) -> dict: # pylint: disable=too-many-branches,too-many-statements,too-many-locals + assert mask != 0 + x_model = (mask == self.__FULL) + x_summary = (mask & self.__SUMMARY) + x_edids = (mask & self.__EDIDS) + x_colors = (mask & self.__COLORS) + x_video = (mask & self.__VIDEO) + x_usb = (mask & self.__USB) + x_beacons = (mask & self.__BEACONS) + x_atx = (mask & self.__ATX) + + state: dict = {} + if x_model: + state["model"] = { + "units": [], + "ports": [], + "limits": { + "atx": { + "click_delays": { + key: {"default": value, "min": 0, "max": 10} + for (key, value) in [ + ("power", self.__atx_cp_delays.default), + ("power_long", self.__atx_cpl_delays.default), + ("reset", self.__atx_cr_delays.default), + ] + }, + }, + }, + } + if x_summary: + state["summary"] = {"active_port": self.__active_port, "synced": self.__synced} + if x_edids: + state["edids"] = { + "all": { + edid_id: { + "name": edid.name, + "data": edid.as_text(), + "parsed": (dataclasses.asdict(edid.info) if edid.info is not None else None), + } + for (edid_id, edid) in self.__edids.all.items() + }, + "used": [], + } + if x_colors: + state["colors"] = { + role: { + comp: getattr(getattr(self.__colors, role), comp) + for comp in Color.COMPONENTS + } + for role in Colors.ROLES + } + if x_video: + state["video"] = {"links": []} + if x_usb: + state["usb"] = {"links": []} + if x_beacons: + state["beacons"] = {"uplinks": [], "downlinks": [], "ports": []} + if x_atx: + state["atx"] = {"busy": [], "leds": {"power": [], "hdd": []}} + + if not self.__is_units_ready(): + return state + + for (unit, ui) in enumerate(self.__units): + assert ui.state is not None + assert ui.atx_leds is not None + if x_model: + state["model"]["units"].append({"firmware": {"version": ui.state.sw_version}}) + if x_video: + state["video"]["links"].extend(ui.state.video_5v_sens[:4]) + if x_usb: + state["usb"]["links"].extend(ui.state.usb_5v_sens) + if x_beacons: + state["beacons"]["uplinks"].append(ui.state.beacons[5]) + state["beacons"]["downlinks"].append(ui.state.beacons[4]) + state["beacons"]["ports"].extend(ui.state.beacons[:4]) + if x_atx: + state["atx"]["busy"].extend(ui.state.atx_busy) + state["atx"]["leds"]["power"].extend(ui.atx_leds.power) + state["atx"]["leds"]["hdd"].extend(ui.atx_leds.hdd) + if x_model or x_edids: + for ch in range(4): + port = Chain.get_virtual_port(unit, ch) + if x_model: + state["model"]["ports"].append({ + "unit": unit, + "channel": ch, + "name": self.__port_names[port], + "atx": { + "click_delays": { + "power": self.__atx_cp_delays[port], + "power_long": self.__atx_cpl_delays[port], + "reset": self.__atx_cr_delays[port], + }, + }, + }) + if x_edids: + state["edids"]["used"].append(self.__edids.get_id_for_port(port)) + return state + + def __inner_check_synced(self) -> bool: + for (unit, ui) in enumerate(self.__units): + if ui.state is None or ui.state.flags.changing_busy: + return False + if ( + self.__active_port >= 0 + and ui.state.ch != Chain.get_unit_target_channel(unit, self.__active_port) + ): + return False + for ch in range(4): + port = Chain.get_virtual_port(unit, ch) + edid = self.__edids.get_edid_for_port(port) + if not ui.state.compare_edid(ch, edid): + return False + for ch in range(6): + if ui.state.np_crc[ch] != self.__colors.crc: + return False + return True + + def __recache_synced(self) -> bool: + synced = self.__inner_check_synced() + if self.__synced != synced: + self.__synced = synced + return True + return False + + def truncate(self, units: int) -> None: + if len(self.__units) > units: + del self.__units[units:] + self.__bump_state(self.__FULL) + + def update_active_port(self, port: int) -> None: + changed = (self.__active_port != port) + self.__active_port = port + changed = (self.__recache_synced() or changed) + if changed: + self.__bump_state(self.__SUMMARY) + + def update_unit_state(self, unit: int, new: UnitState) -> None: + ui = self.__ensure_unit(unit) + (prev, ui.state) = (ui.state, new) + if not self.__is_units_ready(): + return + mask = 0 + if prev is None: + mask = self.__FULL + else: + if self.__recache_synced(): + mask |= self.__SUMMARY + if prev.video_5v_sens != new.video_5v_sens: + mask |= self.__VIDEO + if prev.usb_5v_sens != new.usb_5v_sens: + mask |= self.__USB + if prev.beacons != new.beacons: + mask |= self.__BEACONS + if prev.atx_busy != new.atx_busy: + mask |= self.__ATX + if mask: + self.__bump_state(mask) + + def update_unit_atx_leds(self, unit: int, new: UnitAtxLeds) -> None: + ui = self.__ensure_unit(unit) + (prev, ui.atx_leds) = (ui.atx_leds, new) + if not self.__is_units_ready(): + return + if prev is None: + self.__bump_state(self.__FULL) + elif prev != new: + self.__bump_state(self.__ATX) + + def __is_units_ready(self) -> bool: + for ui in self.__units: + if ui.state is None or ui.atx_leds is None: + return False + return True + + def __ensure_unit(self, unit: int) -> _UnitInfo: + while len(self.__units) < unit + 1: + self.__units.append(_UnitInfo()) + return self.__units[unit] + + def __bump_state(self, mask: int) -> None: + assert mask != 0 + self.__queue.put_nowait(mask) + + # ===== + + def set_edids(self, edids: Edids) -> None: + changed = ( + self.__edids.all != edids.all + or not self.__edids.compare_on_ports(edids, self.__get_ports()) + ) + self.__edids = edids.copy() + if changed: + self.__bump_state(self.__EDIDS) + + def set_colors(self, colors: Colors) -> None: + changed = (self.__colors != colors) + self.__colors = colors + if changed: + self.__bump_state(self.__COLORS) + + def set_port_names(self, port_names: PortNames) -> None: + changed = (not self.__port_names.compare_on_ports(port_names, self.__get_ports())) + self.__port_names = port_names.copy() + if changed: + self.__bump_state(self.__FULL) + + def set_atx_cp_delays(self, delays: AtxClickPowerDelays) -> None: + changed = (not self.__atx_cp_delays.compare_on_ports(delays, self.__get_ports())) + self.__atx_cp_delays = delays.copy() + if changed: + self.__bump_state(self.__FULL) + + def set_atx_cpl_delays(self, delays: AtxClickPowerLongDelays) -> None: + changed = (not self.__atx_cpl_delays.compare_on_ports(delays, self.__get_ports())) + self.__atx_cpl_delays = delays.copy() + if changed: + self.__bump_state(self.__FULL) + + def set_atx_cr_delays(self, delays: AtxClickResetDelays) -> None: + changed = (not self.__atx_cr_delays.compare_on_ports(delays, self.__get_ports())) + self.__atx_cr_delays = delays.copy() + if changed: + self.__bump_state(self.__FULL) + + def __get_ports(self) -> int: + return (len(self.__units) * 4) diff --git a/kvmd/apps/kvmd/switch/storage.py b/kvmd/apps/kvmd/switch/storage.py new file mode 100644 index 00000000..6e3a0a76 --- /dev/null +++ b/kvmd/apps/kvmd/switch/storage.py @@ -0,0 +1,186 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import os +import asyncio +import json +import contextlib + +from typing import AsyncGenerator + +try: + from ....clients.pst import PstClient +except ImportError: + PstClient = None # type: ignore + +# from .lib import get_logger +from .lib import aiotools +from .lib import htclient +from .lib import get_logger + +from .types import Edid +from .types import Edids +from .types import Color +from .types import Colors +from .types import PortNames +from .types import AtxClickPowerDelays +from .types import AtxClickPowerLongDelays +from .types import AtxClickResetDelays + + +# ===== +class StorageContext: + __F_EDIDS_ALL = "edids_all.json" + __F_EDIDS_PORT = "edids_port.json" + + __F_COLORS = "colors.json" + + __F_PORT_NAMES = "port_names.json" + + __F_ATX_CP_DELAYS = "atx_click_power_delays.json" + __F_ATX_CPL_DELAYS = "atx_click_power_long_delays.json" + __F_ATX_CR_DELAYS = "atx_click_reset_delays.json" + + def __init__(self, path: str, rw: bool) -> None: + self.__path = path + self.__rw = rw + + # ===== + + async def write_edids(self, edids: Edids) -> None: + await self.__write_json_keyvals(self.__F_EDIDS_ALL, { + edid_id.lower(): {"name": edid.name, "data": edid.as_text()} + for (edid_id, edid) in edids.all.items() + if edid_id != Edids.DEFAULT_ID + }) + await self.__write_json_keyvals(self.__F_EDIDS_PORT, edids.port) + + async def write_colors(self, colors: Colors) -> None: + await self.__write_json_keyvals(self.__F_COLORS, { + role: { + comp: getattr(getattr(colors, role), comp) + for comp in Color.COMPONENTS + } + for role in Colors.ROLES + }) + + async def write_port_names(self, port_names: PortNames) -> None: + await self.__write_json_keyvals(self.__F_PORT_NAMES, port_names.kvs) + + async def write_atx_cp_delays(self, delays: AtxClickPowerDelays) -> None: + await self.__write_json_keyvals(self.__F_ATX_CP_DELAYS, delays.kvs) + + async def write_atx_cpl_delays(self, delays: AtxClickPowerLongDelays) -> None: + await self.__write_json_keyvals(self.__F_ATX_CPL_DELAYS, delays.kvs) + + async def write_atx_cr_delays(self, delays: AtxClickResetDelays) -> None: + await self.__write_json_keyvals(self.__F_ATX_CR_DELAYS, delays.kvs) + + async def __write_json_keyvals(self, name: str, kvs: dict) -> None: + if len(self.__path) == 0: + return + assert self.__rw + kvs = {str(key): value for (key, value) in kvs.items()} + if (await self.__read_json_keyvals(name)) == kvs: + return # Don't write the same data + path = os.path.join(self.__path, name) + get_logger(0).info("Writing '%s' ...", name) + await aiotools.write_file(path, json.dumps(kvs)) + + # ===== + + async def read_edids(self) -> Edids: + all_edids = { + edid_id.lower(): Edid.from_data(edid["name"], edid["data"]) + for (edid_id, edid) in (await self.__read_json_keyvals(self.__F_EDIDS_ALL)).items() + } + port_edids = await self.__read_json_keyvals_int(self.__F_EDIDS_PORT) + return Edids(all_edids, port_edids) + + async def read_colors(self) -> Colors: + raw = await self.__read_json_keyvals(self.__F_COLORS) + return Colors(**{ # type: ignore + role: Color(**{comp: raw[role][comp] for comp in Color.COMPONENTS}) + for role in Colors.ROLES + if role in raw + }) + + async def read_port_names(self) -> PortNames: + return PortNames(await self.__read_json_keyvals_int(self.__F_PORT_NAMES)) + + async def read_atx_cp_delays(self) -> AtxClickPowerDelays: + return AtxClickPowerDelays(await self.__read_json_keyvals_int(self.__F_ATX_CP_DELAYS)) + + async def read_atx_cpl_delays(self) -> AtxClickPowerLongDelays: + return AtxClickPowerLongDelays(await self.__read_json_keyvals_int(self.__F_ATX_CPL_DELAYS)) + + async def read_atx_cr_delays(self) -> AtxClickResetDelays: + return AtxClickResetDelays(await self.__read_json_keyvals_int(self.__F_ATX_CR_DELAYS)) + + async def __read_json_keyvals_int(self, name: str) -> dict: + return (await self.__read_json_keyvals(name, int_keys=True)) + + async def __read_json_keyvals(self, name: str, int_keys: bool=False) -> dict: + if len(self.__path) == 0: + return {} + path = os.path.join(self.__path, name) + try: + kvs: dict = json.loads(await aiotools.read_file(path)) + except FileNotFoundError: + kvs = {} + if int_keys: + kvs = {int(key): value for (key, value) in kvs.items()} + return kvs + + +class Storage: + __SUBDIR = "__switch__" + __TIMEOUT = 5.0 + + def __init__(self, unix_path: str) -> None: + self.__pst: (PstClient | None) = None + if len(unix_path) > 0 and PstClient is not None: + self.__pst = PstClient( + subdir=self.__SUBDIR, + unix_path=unix_path, + timeout=self.__TIMEOUT, + user_agent=htclient.make_user_agent("KVMD"), + ) + self.__lock = asyncio.Lock() + + @contextlib.asynccontextmanager + async def readable(self) -> AsyncGenerator[StorageContext, None]: + async with self.__lock: + if self.__pst is None: + yield StorageContext("", False) + else: + path = await self.__pst.get_path() + yield StorageContext(path, False) + + @contextlib.asynccontextmanager + async def writable(self) -> AsyncGenerator[StorageContext, None]: + async with self.__lock: + if self.__pst is None: + yield StorageContext("", True) + else: + async with self.__pst.writable() as path: + yield StorageContext(path, True) diff --git a/kvmd/apps/kvmd/switch/types.py b/kvmd/apps/kvmd/switch/types.py new file mode 100644 index 00000000..32225f06 --- /dev/null +++ b/kvmd/apps/kvmd/switch/types.py @@ -0,0 +1,308 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import re +import struct +import uuid +import dataclasses + +from typing import TypeVar +from typing import Generic + +from .lib import bitbang +from .lib import ParsedEdidNoBlockError +from .lib import ParsedEdid + + +# ===== [email protected](frozen=True) +class EdidInfo: + mfc_id: str + product_id: int + serial: int + monitor_name: (str | None) + monitor_serial: (str | None) + audio: bool + + @classmethod + def from_data(cls, data: bytes) -> "EdidInfo": + parsed = ParsedEdid(data) + + monitor_name: (str | None) = None + try: + monitor_name = parsed.get_monitor_name() + except ParsedEdidNoBlockError: + pass + + monitor_serial: (str | None) = None + try: + monitor_serial = parsed.get_monitor_serial() + except ParsedEdidNoBlockError: + pass + + return EdidInfo( + mfc_id=parsed.get_mfc_id(), + product_id=parsed.get_product_id(), + serial=parsed.get_serial(), + monitor_name=monitor_name, + monitor_serial=monitor_serial, + audio=parsed.get_audio(), + ) + + [email protected](frozen=True) +class Edid: + name: str + data: bytes + crc: int = dataclasses.field(default=0) + valid: bool = dataclasses.field(default=False) + info: (EdidInfo | None) = dataclasses.field(default=None) + + __HEADER = b"\x00\xFF\xFF\xFF\xFF\xFF\xFF\x00" + + def __post_init__(self) -> None: + assert len(self.name) > 0 + assert len(self.data) == 256 + object.__setattr__(self, "crc", bitbang.make_crc16(self.data)) + object.__setattr__(self, "valid", self.data.startswith(self.__HEADER)) + try: + object.__setattr__(self, "info", EdidInfo.from_data(self.data)) + except Exception: + pass + + def as_text(self) -> str: + return "".join(f"{item:0{2}X}" for item in self.data) + + def pack(self) -> bytes: + return self.data + + @classmethod + def from_data(cls, name: str, data: (str | bytes | None)) -> "Edid": + if data is None: # Пустой едид + return Edid(name, b"\x00" * 256) + + if isinstance(data, bytes): + if data.startswith(cls.__HEADER): + return Edid(name, data) # Бинарный едид + data_hex = data.decode() # Текстовый едид, прочитанный как бинарный из файла + else: # isinstance(data, str) + data_hex = str(data) # Текстовый едид + + data_hex = re.sub(r"\s", "", data_hex) + assert len(data_hex) == 512 + data = bytes([ + int(data_hex[index:index + 2], 16) + for index in range(0, len(data_hex), 2) + ]) + return Edid(name, data) + + +class Edids: + DEFAULT_NAME = "Default" + DEFAULT_ID = "default" + + all: dict[str, Edid] = dataclasses.field(default_factory=dict) + port: dict[int, str] = dataclasses.field(default_factory=dict) + + def __post_init__(self) -> None: + if self.DEFAULT_ID not in self.all: + self.set_default(None) + + def set_default(self, data: (str | bytes | None)) -> None: + self.all[self.DEFAULT_ID] = Edid.from_data(self.DEFAULT_NAME, data) + + def copy(self) -> "Edids": + return Edids(dict(self.all), dict(self.port)) + + def compare_on_ports(self, other: "Edids", ports: int) -> bool: + for port in range(ports): + if self.get_id_for_port(port) != other.get_id_for_port(port): + return False + return True + + def add(self, edid: Edid) -> str: + edid_id = str(uuid.uuid4()).lower() + self.all[edid_id] = edid + return edid_id + + def set(self, edid_id: str, edid: Edid) -> None: + assert edid_id in self.all + self.all[edid_id] = edid + + def get(self, edid_id: str) -> Edid: + return self.all[edid_id] + + def remove(self, edid_id: str) -> None: + assert edid_id in self.all + self.all.pop(edid_id) + for port in list(self.port): + if self.port[port] == edid_id: + self.port.pop(port) + + def has_id(self, edid_id: str) -> bool: + return (edid_id in self.all) + + def assign(self, port: int, edid_id: str) -> None: + assert edid_id in self.all + if edid_id == Edids.DEFAULT_ID: + self.port.pop(port, None) + else: + self.port[port] = edid_id + + def get_id_for_port(self, port: int) -> str: + return self.port.get(port, self.DEFAULT_ID) + + def get_edid_for_port(self, port: int) -> Edid: + return self.all[self.get_id_for_port(port)] + + +# ===== [email protected](frozen=True) +class Color: + COMPONENTS = frozenset(["red", "green", "blue", "brightness", "blink_ms"]) + + red: int + green: int + blue: int + brightness: int + blink_ms: int + crc: int = dataclasses.field(default=0) + _packed: bytes = dataclasses.field(default=b"") + + __struct = struct.Struct("<BBBBH") + __rx = re.compile(r"^([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2}):([0-9a-fA-F]{2}):([0-9a-fA-F]{4})$") + + def __post_init__(self) -> None: + assert 0 <= self.red <= 0xFF + assert 0 <= self.green <= 0xFF + assert 0 <= self.blue <= 0xFF + assert 0 <= self.brightness <= 0xFF + assert 0 <= self.blink_ms <= 0xFFFF + data = self.__struct.pack(self.red, self.green, self.blue, self.brightness, self.blink_ms) + object.__setattr__(self, "crc", bitbang.make_crc16(data)) + object.__setattr__(self, "_packed", data) + + def pack(self) -> bytes: + return self._packed + + @classmethod + def from_text(cls, text: str) -> "Color": + match = cls.__rx.match(text) + assert match is not None, text + return Color( + red=int(match.group(1), 16), + green=int(match.group(2), 16), + blue=int(match.group(3), 16), + brightness=int(match.group(4), 16), + blink_ms=int(match.group(5), 16), + ) + + [email protected](frozen=True) +class Colors: + ROLES = frozenset(["inactive", "active", "flashing", "beacon", "bootloader"]) + + inactive: Color = dataclasses.field(default=Color(255, 0, 0, 64, 0)) + active: Color = dataclasses.field(default=Color(0, 255, 0, 128, 0)) + flashing: Color = dataclasses.field(default=Color(0, 170, 255, 128, 0)) + beacon: Color = dataclasses.field(default=Color(228, 44, 156, 255, 250)) + bootloader: Color = dataclasses.field(default=Color(255, 170, 0, 128, 0)) + crc: int = dataclasses.field(default=0) + _packed: bytes = dataclasses.field(default=b"") + + __crc_struct = struct.Struct("<HHHHH") + + def __post_init__(self) -> None: + crcs: list[int] = [] + packed: bytes = b"" + for color in [self.inactive, self.active, self.flashing, self.beacon, self.bootloader]: + crcs.append(color.crc) + packed += color.pack() + object.__setattr__(self, "crc", bitbang.make_crc16(self.__crc_struct.pack(*crcs))) + object.__setattr__(self, "_packed", packed) + + def pack(self) -> bytes: + return self._packed + + +# ===== +_T = TypeVar("_T") + + +class _PortsDict(Generic[_T]): + def __init__(self, default: _T, kvs: dict[int, _T]) -> None: + self.default = default + self.kvs = { + port: value + for (port, value) in kvs.items() + if value != default + } + + def compare_on_ports(self, other: "_PortsDict[_T]", ports: int) -> bool: + for port in range(ports): + if self[port] != other[port]: + return False + return True + + def __getitem__(self, port: int) -> _T: + return self.kvs.get(port, self.default) + + def __setitem__(self, port: int, value: (_T | None)) -> None: + if value is None: + value = self.default + if value == self.default: + self.kvs.pop(port, None) + else: + self.kvs[port] = value + + +class PortNames(_PortsDict[str]): + def __init__(self, kvs: dict[int, str]) -> None: + super().__init__("", kvs) + + def copy(self) -> "PortNames": + return PortNames(self.kvs) + + +class AtxClickPowerDelays(_PortsDict[float]): + def __init__(self, kvs: dict[int, float]) -> None: + super().__init__(0.5, kvs) + + def copy(self) -> "AtxClickPowerDelays": + return AtxClickPowerDelays(self.kvs) + + +class AtxClickPowerLongDelays(_PortsDict[float]): + def __init__(self, kvs: dict[int, float]) -> None: + super().__init__(5.5, kvs) + + def copy(self) -> "AtxClickPowerLongDelays": + return AtxClickPowerLongDelays(self.kvs) + + +class AtxClickResetDelays(_PortsDict[float]): + def __init__(self, kvs: dict[int, float]) -> None: + super().__init__(0.5, kvs) + + def copy(self) -> "AtxClickResetDelays": + return AtxClickResetDelays(self.kvs) diff --git a/kvmd/apps/pst/server.py b/kvmd/apps/pst/server.py index 79bbf7c8..8d8bf9d4 100644 --- a/kvmd/apps/pst/server.py +++ b/kvmd/apps/pst/server.py @@ -24,6 +24,7 @@ import os import asyncio from aiohttp.web import Request +from aiohttp.web import Response from aiohttp.web import WebSocketResponse from ...logging import get_logger @@ -35,6 +36,7 @@ from ... import fstab from ...htserver import exposed_http from ...htserver import exposed_ws +from ...htserver import make_json_response from ...htserver import WsSession from ...htserver import HttpServer @@ -65,6 +67,16 @@ class PstServer(HttpServer): # pylint: disable=too-many-arguments,too-many-inst await ws.send_event("loop", {}) return (await self._ws_loop(ws)) + @exposed_http("GET", "/state") + async def __state_handler(self, _: Request) -> Response: + return make_json_response({ + "clients": len(self._get_wss()), + "data": { + "path": self.__data_path, + "write_allowed": self.__is_write_available(), + }, + }) + @exposed_ws("ping") async def __ws_ping_handler(self, ws: WsSession, _: dict) -> None: await ws.send_event("pong", {}) diff --git a/kvmd/clients/pst.py b/kvmd/clients/pst.py new file mode 100644 index 00000000..6b9f5234 --- /dev/null +++ b/kvmd/clients/pst.py @@ -0,0 +1,93 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2020 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import os +import contextlib + +from typing import AsyncGenerator + +import aiohttp + +from .. import htclient +from .. import htserver + + +# ===== +class PstError(Exception): + pass + + +# ===== +class PstClient: + def __init__( + self, + subdir: str, + unix_path: str, + timeout: float, + user_agent: str, + ) -> None: + + self.__subdir = subdir + self.__unix_path = unix_path + self.__timeout = timeout + self.__user_agent = user_agent + + async def get_path(self) -> str: + async with self.__make_http_session() as session: + async with session.get("http://localhost:0/state") as resp: + htclient.raise_not_200(resp) + path = (await resp.json())["result"]["data"]["path"] + return os.path.join(path, self.__subdir) + + @contextlib.asynccontextmanager + async def writable(self) -> AsyncGenerator[str, None]: + async with self.__inner_writable() as path: + path = os.path.join(path, self.__subdir) + if not os.path.exists(path): + os.mkdir(path) + yield path + + @contextlib.asynccontextmanager + async def __inner_writable(self) -> AsyncGenerator[str, None]: + async with self.__make_http_session() as session: + async with session.ws_connect("http://localhost:0/ws") as ws: + path = "" + async for msg in ws: + if msg.type != aiohttp.WSMsgType.TEXT: + raise PstError(f"Unexpected message type: {msg!r}") + (event_type, event) = htserver.parse_ws_event(msg.data) + if event_type == "storage_state": + if not event["data"]["write_allowed"]: + raise PstError("Write is not allowed") + path = event["data"]["path"] + break + if not path: + raise PstError("WS loop broken without write_allowed=True flag") + # TODO: Actually we should follow ws events, but for fast writing we can safely ignore them + yield path + + def __make_http_session(self) -> aiohttp.ClientSession: + return aiohttp.ClientSession( + headers={"User-Agent": self.__user_agent}, + connector=aiohttp.UnixConnector(path=self.__unix_path), + timeout=aiohttp.ClientTimeout(total=self.__timeout), + ) diff --git a/kvmd/validators/__init__.py b/kvmd/validators/__init__.py index 39ff60aa..aa997ab9 100644 --- a/kvmd/validators/__init__.py +++ b/kvmd/validators/__init__.py @@ -99,3 +99,11 @@ def check_any(arg: Any, name: str, validators: list[Callable[[Any], Any]]) -> An except Exception: pass raise_error(arg, name) + + +# ===== +def filter_printable(arg: str, replace: str, limit: int) -> str: + return "".join( + (ch if ch.isprintable() else replace) + for ch in arg[:limit] + ) diff --git a/kvmd/validators/os.py b/kvmd/validators/os.py index 94d3a40f..b2381d0b 100644 --- a/kvmd/validators/os.py +++ b/kvmd/validators/os.py @@ -26,6 +26,7 @@ import stat from typing import Any from . import raise_error +from . import filter_printable from .basic import valid_number from .basic import valid_string_list @@ -75,9 +76,7 @@ def valid_abs_dir(arg: Any, name: str="") -> str: def valid_printable_filename(arg: Any, name: str="") -> str: if not name: name = "printable filename" - arg = valid_stripped_string_not_empty(arg, name) - if ( "/" in arg or "\0" in arg @@ -85,12 +84,7 @@ def valid_printable_filename(arg: Any, name: str="") -> str: or arg == "lost+found" ): raise_error(arg, name) - - arg = "".join( - (ch if ch.isprintable() else "_") - for ch in arg[:255] - ) - return arg + return filter_printable(arg, "_", 255) # ===== diff --git a/kvmd/validators/switch.py b/kvmd/validators/switch.py new file mode 100644 index 00000000..d4f3ab2f --- /dev/null +++ b/kvmd/validators/switch.py @@ -0,0 +1,67 @@ +# ========================================================================== # +# # +# KVMD - The main PiKVM daemon. # +# # +# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see <https://www.gnu.org/licenses/>. # +# # +# ========================================================================== # + + +import re + +from typing import Any + +from . import filter_printable +from . import check_re_match + +from .basic import valid_stripped_string +from .basic import valid_number + + +# ===== +def valid_switch_port_name(arg: Any) -> str: + arg = valid_stripped_string(arg, name="switch port name") + arg = filter_printable(arg, " ", 255) + arg = re.sub(r"\s+", " ", arg) + return arg.strip() + + +def valid_switch_edid_id(arg: Any, allow_default: bool) -> str: + pattern = "(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" + if allow_default: + pattern += "|^default$" + return check_re_match(arg, "switch EDID ID", pattern).lower() + + +def valid_switch_edid_data(arg: Any) -> str: + name = "switch EDID data" + arg = valid_stripped_string(arg, name=name) + arg = re.sub(r"\s", "", arg) + return check_re_match(arg, name, "(?i)^[0-9a-f]{512}$").upper() + + +def valid_switch_color(arg: Any, allow_default: bool) -> str: + pattern = "(?i)^[0-9a-f]{6}:[0-9a-f]{2}:[0-9a-f]{4}$" + if allow_default: + pattern += "|^default$" + arg = check_re_match(arg, "switch color", pattern).upper() + if arg == "DEFAULT": + arg = "default" + return arg + + +def valid_switch_atx_click_delay(arg: Any) -> float: + return valid_number(arg, min=0, max=10, type=float, name="ATX delay") |