summaryrefslogtreecommitdiff
path: root/kvmd
diff options
context:
space:
mode:
Diffstat (limited to 'kvmd')
-rw-r--r--kvmd/aiotools.py5
-rw-r--r--kvmd/apps/__init__.py5
-rw-r--r--kvmd/apps/kvmd/__init__.py5
-rw-r--r--kvmd/apps/kvmd/api/switch.py164
-rw-r--r--kvmd/apps/kvmd/server.py7
-rw-r--r--kvmd/apps/kvmd/switch/__init__.py400
-rw-r--r--kvmd/apps/kvmd/switch/chain.py440
-rw-r--r--kvmd/apps/kvmd/switch/device.py196
-rw-r--r--kvmd/apps/kvmd/switch/lib.py35
-rw-r--r--kvmd/apps/kvmd/switch/proto.py295
-rw-r--r--kvmd/apps/kvmd/switch/state.py355
-rw-r--r--kvmd/apps/kvmd/switch/storage.py186
-rw-r--r--kvmd/apps/kvmd/switch/types.py308
-rw-r--r--kvmd/apps/pst/server.py12
-rw-r--r--kvmd/clients/pst.py93
-rw-r--r--kvmd/validators/__init__.py8
-rw-r--r--kvmd/validators/os.py10
-rw-r--r--kvmd/validators/switch.py67
18 files changed, 2582 insertions, 9 deletions
diff --git a/kvmd/aiotools.py b/kvmd/aiotools.py
index a47c94c6..6183690f 100644
--- a/kvmd/aiotools.py
+++ b/kvmd/aiotools.py
@@ -45,6 +45,11 @@ async def read_file(path: str) -> str:
return (await file.read())
+async def write_file(path: str, text: str) -> None:
+ async with aiofiles.open(path, "w") as file:
+ await file.write(text)
+
+
# =====
def run(coro: Coroutine, final: (Coroutine | None)=None) -> None:
# https://github.com/aio-libs/aiohttp/blob/a1d4dac1d/aiohttp/web.py#L515
diff --git a/kvmd/apps/__init__.py b/kvmd/apps/__init__.py
index cfc39499..2090e5c6 100644
--- a/kvmd/apps/__init__.py
+++ b/kvmd/apps/__init__.py
@@ -502,6 +502,11 @@ def _get_config_scheme() -> dict:
"table": Option([], type=valid_ugpio_view_table),
},
},
+
+ "switch": {
+ "device": Option("/dev/kvmd-switch", type=valid_abs_path, unpack_as="device_path"),
+ "default_edid": Option("/etc/kvmd/switch-edid.hex", type=valid_abs_path, unpack_as="default_edid_path"),
+ },
},
"pst": {
diff --git a/kvmd/apps/kvmd/__init__.py b/kvmd/apps/kvmd/__init__.py
index 495a320f..088a62ef 100644
--- a/kvmd/apps/kvmd/__init__.py
+++ b/kvmd/apps/kvmd/__init__.py
@@ -35,6 +35,7 @@ from .ugpio import UserGpio
from .streamer import Streamer
from .snapshoter import Snapshoter
from .ocr import Ocr
+from .switch import Switch
from .server import KvmdServer
@@ -90,6 +91,10 @@ def main(argv: (list[str] | None)=None) -> None:
log_reader=(LogReader() if config.log_reader.enabled else None),
user_gpio=UserGpio(config.gpio, global_config.otg),
ocr=Ocr(**config.ocr._unpack()),
+ switch=Switch(
+ pst_unix_path=global_config.pst.server.unix,
+ **config.switch._unpack(),
+ ),
hid=hid,
atx=get_atx_class(config.atx.type)(**config.atx._unpack(ignore=["type"])),
diff --git a/kvmd/apps/kvmd/api/switch.py b/kvmd/apps/kvmd/api/switch.py
new file mode 100644
index 00000000..bf91b83e
--- /dev/null
+++ b/kvmd/apps/kvmd/api/switch.py
@@ -0,0 +1,164 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+from aiohttp.web import Request
+from aiohttp.web import Response
+
+from ....htserver import exposed_http
+from ....htserver import make_json_response
+
+from ....validators.basic import valid_bool
+from ....validators.basic import valid_int_f0
+from ....validators.basic import valid_stripped_string_not_empty
+from ....validators.kvm import valid_atx_power_action
+from ....validators.kvm import valid_atx_button
+from ....validators.switch import valid_switch_port_name
+from ....validators.switch import valid_switch_edid_id
+from ....validators.switch import valid_switch_edid_data
+from ....validators.switch import valid_switch_color
+from ....validators.switch import valid_switch_atx_click_delay
+
+from ..switch import Switch
+from ..switch import Colors
+
+
+# =====
+class SwitchApi:
+ def __init__(self, switch: Switch) -> None:
+ self.__switch = switch
+
+ # =====
+
+ @exposed_http("GET", "/switch")
+ async def __state_handler(self, _: Request) -> Response:
+ return make_json_response(await self.__switch.get_state())
+
+ @exposed_http("POST", "/switch/set_active")
+ async def __set_active_port_handler(self, req: Request) -> Response:
+ port = valid_int_f0(req.query.get("port"))
+ await self.__switch.set_active_port(port)
+ return make_json_response()
+
+ @exposed_http("POST", "/switch/set_beacon")
+ async def __set_beacon_handler(self, req: Request) -> Response:
+ on = valid_bool(req.query.get("state"))
+ if "port" in req.query:
+ port = valid_int_f0(req.query.get("port"))
+ await self.__switch.set_port_beacon(port, on)
+ elif "uplink" in req.query:
+ unit = valid_int_f0(req.query.get("uplink"))
+ await self.__switch.set_uplink_beacon(unit, on)
+ else: # Downlink
+ unit = valid_int_f0(req.query.get("downlink"))
+ await self.__switch.set_downlink_beacon(unit, on)
+ return make_json_response()
+
+ @exposed_http("POST", "/switch/set_port_params")
+ async def __set_port_params(self, req: Request) -> Response:
+ port = valid_int_f0(req.query.get("port"))
+ params = {
+ param: validator(req.query.get(param))
+ for (param, validator) in [
+ ("edid_id", (lambda arg: valid_switch_edid_id(arg, allow_default=True))),
+ ("name", valid_switch_port_name),
+ ("atx_click_power_delay", valid_switch_atx_click_delay),
+ ("atx_click_power_long_delay", valid_switch_atx_click_delay),
+ ("atx_click_reset_delay", valid_switch_atx_click_delay),
+ ]
+ if req.query.get(param) is not None
+ }
+ await self.__switch.set_port_params(port, **params) # type: ignore
+ return make_json_response()
+
+ @exposed_http("POST", "/switch/set_colors")
+ async def __set_colors(self, req: Request) -> Response:
+ params = {
+ param: valid_switch_color(req.query.get(param), allow_default=True)
+ for param in Colors.ROLES
+ if req.query.get(param) is not None
+ }
+ await self.__switch.set_colors(**params)
+ return make_json_response()
+
+ # =====
+
+ @exposed_http("POST", "/switch/reset")
+ async def __reset(self, req: Request) -> Response:
+ unit = valid_int_f0(req.query.get("unit"))
+ bootloader = valid_bool(req.query.get("bootloader", False))
+ await self.__switch.reboot_unit(unit, bootloader)
+ return make_json_response()
+
+ # =====
+
+ @exposed_http("POST", "/switch/edids/create")
+ async def __create_edid(self, req: Request) -> Response:
+ name = valid_stripped_string_not_empty(req.query.get("name"))
+ data_hex = valid_switch_edid_data(req.query.get("data"))
+ edid_id = await self.__switch.create_edid(name, data_hex)
+ return make_json_response({"id": edid_id})
+
+ @exposed_http("POST", "/switch/edids/change")
+ async def __change_edid(self, req: Request) -> Response:
+ edid_id = valid_switch_edid_id(req.query.get("id"), allow_default=False)
+ params = {
+ param: validator(req.query.get(param))
+ for (param, validator) in [
+ ("name", valid_switch_port_name),
+ ("data", valid_switch_edid_data),
+ ]
+ if req.query.get(param) is not None
+ }
+ if params:
+ await self.__switch.change_edid(edid_id, **params)
+ return make_json_response()
+
+ @exposed_http("POST", "/switch/edids/remove")
+ async def __remove_edid(self, req: Request) -> Response:
+ edid_id = valid_switch_edid_id(req.query.get("id"), allow_default=False)
+ await self.__switch.remove_edid(edid_id)
+ return make_json_response()
+
+ # =====
+
+ @exposed_http("POST", "/switch/atx/power")
+ async def __power_handler(self, req: Request) -> Response:
+ port = valid_int_f0(req.query.get("port"))
+ action = valid_atx_power_action(req.query.get("action"))
+ await ({
+ "on": self.__switch.atx_power_on,
+ "off": self.__switch.atx_power_off,
+ "off_hard": self.__switch.atx_power_off_hard,
+ "reset_hard": self.__switch.atx_power_reset_hard,
+ }[action])(port)
+ return make_json_response()
+
+ @exposed_http("POST", "/switch/atx/click")
+ async def __click_handler(self, req: Request) -> Response:
+ port = valid_int_f0(req.query.get("port"))
+ button = valid_atx_button(req.query.get("button"))
+ await ({
+ "power": self.__switch.atx_click_power,
+ "power_long": self.__switch.atx_click_power_long,
+ "reset": self.__switch.atx_click_reset,
+ }[button])(port)
+ return make_json_response()
diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py
index ed85bb24..92eb496c 100644
--- a/kvmd/apps/kvmd/server.py
+++ b/kvmd/apps/kvmd/server.py
@@ -66,6 +66,7 @@ from .ugpio import UserGpio
from .streamer import Streamer
from .snapshoter import Snapshoter
from .ocr import Ocr
+from .switch import Switch
from .api.auth import AuthApi
from .api.auth import check_request_auth
@@ -77,6 +78,7 @@ from .api.hid import HidApi
from .api.atx import AtxApi
from .api.msd import MsdApi
from .api.streamer import StreamerApi
+from .api.switch import SwitchApi
from .api.export import ExportApi
from .api.redfish import RedfishApi
@@ -125,7 +127,6 @@ class _Subsystem:
cleanup=getattr(obj, "cleanup", None),
trigger_state=getattr(obj, "trigger_state", None),
poll_state=getattr(obj, "poll_state", None),
-
)
@@ -137,6 +138,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
__EV_STREAMER_STATE = "streamer_state"
__EV_OCR_STATE = "ocr_state"
__EV_INFO_STATE = "info_state"
+ __EV_SWITCH_STATE = "switch_state"
def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
@@ -145,6 +147,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
log_reader: (LogReader | None),
user_gpio: UserGpio,
ocr: Ocr,
+ switch: Switch,
hid: BaseHid,
atx: BaseAtx,
@@ -177,6 +180,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
AtxApi(atx),
MsdApi(msd),
StreamerApi(streamer, ocr),
+ SwitchApi(switch),
ExportApi(info_manager, atx, user_gpio),
RedfishApi(info_manager, atx),
]
@@ -189,6 +193,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
_Subsystem.make(streamer, "Streamer", self.__EV_STREAMER_STATE),
_Subsystem.make(ocr, "OCR", self.__EV_OCR_STATE),
_Subsystem.make(info_manager, "Info manager", self.__EV_INFO_STATE),
+ _Subsystem.make(switch, "Switch", self.__EV_SWITCH_STATE),
]
self.__streamer_notifier = aiotools.AioNotifier()
diff --git a/kvmd/apps/kvmd/switch/__init__.py b/kvmd/apps/kvmd/switch/__init__.py
new file mode 100644
index 00000000..49bfbd7d
--- /dev/null
+++ b/kvmd/apps/kvmd/switch/__init__.py
@@ -0,0 +1,400 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+import os
+import asyncio
+
+from typing import AsyncGenerator
+
+from .lib import OperationError
+from .lib import get_logger
+from .lib import aiotools
+from .lib import Inotify
+
+from .types import Edid
+from .types import Edids
+from .types import Color
+from .types import Colors
+from .types import PortNames
+from .types import AtxClickPowerDelays
+from .types import AtxClickPowerLongDelays
+from .types import AtxClickResetDelays
+
+from .chain import DeviceFoundEvent
+from .chain import ChainTruncatedEvent
+from .chain import PortActivatedEvent
+from .chain import UnitStateEvent
+from .chain import UnitAtxLedsEvent
+from .chain import Chain
+
+from .state import StateCache
+
+from .storage import Storage
+
+
+# =====
+class SwitchError(Exception):
+ pass
+
+
+class SwitchOperationError(OperationError, SwitchError):
+ pass
+
+
+class SwitchUnknownEdidError(SwitchOperationError):
+ def __init__(self) -> None:
+ super().__init__("No specified EDID ID found")
+
+
+# =====
+class Switch: # pylint: disable=too-many-public-methods
+ __X_EDIDS = "edids"
+ __X_COLORS = "colors"
+ __X_PORT_NAMES = "port_names"
+ __X_ATX_CP_DELAYS = "atx_cp_delays"
+ __X_ATX_CPL_DELAYS = "atx_cpl_delays"
+ __X_ATX_CR_DELAYS = "atx_cr_delays"
+
+ __X_ALL = frozenset([
+ __X_EDIDS, __X_COLORS, __X_PORT_NAMES,
+ __X_ATX_CP_DELAYS, __X_ATX_CPL_DELAYS, __X_ATX_CR_DELAYS,
+ ])
+
+ def __init__(
+ self,
+ device_path: str,
+ default_edid_path: str,
+ pst_unix_path: str,
+ ) -> None:
+
+ self.__default_edid_path = default_edid_path
+
+ self.__chain = Chain(device_path)
+ self.__cache = StateCache()
+ self.__storage = Storage(pst_unix_path)
+
+ self.__lock = asyncio.Lock()
+
+ self.__save_notifier = aiotools.AioNotifier()
+
+ # =====
+
+ def __x_set_edids(self, edids: Edids, save: bool=True) -> None:
+ self.__chain.set_edids(edids)
+ self.__cache.set_edids(edids)
+ if save:
+ self.__save_notifier.notify()
+
+ def __x_set_colors(self, colors: Colors, save: bool=True) -> None:
+ self.__chain.set_colors(colors)
+ self.__cache.set_colors(colors)
+ if save:
+ self.__save_notifier.notify()
+
+ def __x_set_port_names(self, port_names: PortNames, save: bool=True) -> None:
+ self.__cache.set_port_names(port_names)
+ if save:
+ self.__save_notifier.notify()
+
+ def __x_set_atx_cp_delays(self, delays: AtxClickPowerDelays, save: bool=True) -> None:
+ self.__cache.set_atx_cp_delays(delays)
+ if save:
+ self.__save_notifier.notify()
+
+ def __x_set_atx_cpl_delays(self, delays: AtxClickPowerLongDelays, save: bool=True) -> None:
+ self.__cache.set_atx_cpl_delays(delays)
+ if save:
+ self.__save_notifier.notify()
+
+ def __x_set_atx_cr_delays(self, delays: AtxClickResetDelays, save: bool=True) -> None:
+ self.__cache.set_atx_cr_delays(delays)
+ if save:
+ self.__save_notifier.notify()
+
+ # =====
+
+ async def set_active_port(self, port: int) -> None:
+ self.__chain.set_active_port(port)
+
+ # =====
+
+ async def set_port_beacon(self, port: int, on: bool) -> None:
+ self.__chain.set_port_beacon(port, on)
+
+ async def set_uplink_beacon(self, unit: int, on: bool) -> None:
+ self.__chain.set_uplink_beacon(unit, on)
+
+ async def set_downlink_beacon(self, unit: int, on: bool) -> None:
+ self.__chain.set_downlink_beacon(unit, on)
+
+ # =====
+
+ async def atx_power_on(self, port: int) -> None:
+ self.__inner_atx_cp(port, False, self.__X_ATX_CP_DELAYS)
+
+ async def atx_power_off(self, port: int) -> None:
+ self.__inner_atx_cp(port, True, self.__X_ATX_CP_DELAYS)
+
+ async def atx_power_off_hard(self, port: int) -> None:
+ self.__inner_atx_cp(port, True, self.__X_ATX_CPL_DELAYS)
+
+ async def atx_power_reset_hard(self, port: int) -> None:
+ self.__inner_atx_cr(port, True)
+
+ async def atx_click_power(self, port: int) -> None:
+ self.__inner_atx_cp(port, None, self.__X_ATX_CP_DELAYS)
+
+ async def atx_click_power_long(self, port: int) -> None:
+ self.__inner_atx_cp(port, None, self.__X_ATX_CPL_DELAYS)
+
+ async def atx_click_reset(self, port: int) -> None:
+ self.__inner_atx_cr(port, None)
+
+ def __inner_atx_cp(self, port: int, if_powered: (bool | None), x_delay: str) -> None:
+ assert x_delay in [self.__X_ATX_CP_DELAYS, self.__X_ATX_CPL_DELAYS]
+ delay = getattr(self.__cache, f"get_{x_delay}")()[port]
+ self.__chain.click_power(port, delay, if_powered)
+
+ def __inner_atx_cr(self, port: int, if_powered: (bool | None)) -> None:
+ delay = self.__cache.get_atx_cr_delays()[port]
+ self.__chain.click_reset(port, delay, if_powered)
+
+ # =====
+
+ async def create_edid(self, name: str, data_hex: str) -> str:
+ async with self.__lock:
+ edids = self.__cache.get_edids()
+ edid_id = edids.add(Edid.from_data(name, data_hex))
+ self.__x_set_edids(edids)
+ return edid_id
+
+ async def change_edid(
+ self,
+ edid_id: str,
+ name: (str | None)=None,
+ data_hex: (str | None)=None,
+ ) -> None:
+
+ assert edid_id != Edids.DEFAULT_ID
+ async with self.__lock:
+ edids = self.__cache.get_edids()
+ if not edids.has_id(edid_id):
+ raise SwitchUnknownEdidError()
+ old = edids.get(edid_id)
+ name = (name or old.name)
+ data_hex = (data_hex or old.as_text())
+ edids.set(edid_id, Edid.from_data(name, data_hex))
+ self.__x_set_edids(edids)
+
+ async def remove_edid(self, edid_id: str) -> None:
+ assert edid_id != Edids.DEFAULT_ID
+ async with self.__lock:
+ edids = self.__cache.get_edids()
+ if not edids.has_id(edid_id):
+ raise SwitchUnknownEdidError()
+ edids.remove(edid_id)
+ self.__x_set_edids(edids)
+
+ # =====
+
+ async def set_colors(self, **values: str) -> None:
+ async with self.__lock:
+ old = self.__cache.get_colors()
+ new = {}
+ for role in Colors.ROLES:
+ if role in values:
+ if values[role] != "default":
+ new[role] = Color.from_text(values[role])
+ # else reset to default
+ else:
+ new[role] = getattr(old, role)
+ self.__x_set_colors(Colors(**new)) # type: ignore
+
+ # =====
+
+ async def set_port_params(
+ self,
+ port: int,
+ edid_id: (str | None)=None,
+ name: (str | None)=None,
+ atx_click_power_delay: (float | None)=None,
+ atx_click_power_long_delay: (float | None)=None,
+ atx_click_reset_delay: (float | None)=None,
+ ) -> None:
+
+ async with self.__lock:
+ if edid_id is not None:
+ edids = self.__cache.get_edids()
+ if not edids.has_id(edid_id):
+ raise SwitchUnknownEdidError()
+ edids.assign(port, edid_id)
+ self.__x_set_edids(edids)
+
+ for (key, value) in [
+ (self.__X_PORT_NAMES, name),
+ (self.__X_ATX_CP_DELAYS, atx_click_power_delay),
+ (self.__X_ATX_CPL_DELAYS, atx_click_power_long_delay),
+ (self.__X_ATX_CR_DELAYS, atx_click_reset_delay),
+ ]:
+ if value is not None:
+ new = getattr(self.__cache, f"get_{key}")()
+ new[port] = (value or None) # None == reset to default
+ getattr(self, f"_Switch__x_set_{key}")(new)
+
+ # =====
+
+ async def reboot_unit(self, unit: int, bootloader: bool) -> None:
+ self.__chain.reboot_unit(unit, bootloader)
+
+ # =====
+
+ async def get_state(self) -> dict:
+ return self.__cache.get_state()
+
+ async def trigger_state(self) -> None:
+ await self.__cache.trigger_state()
+
+ async def poll_state(self) -> AsyncGenerator[dict, None]:
+ async for state in self.__cache.poll_state():
+ yield state
+
+ # =====
+
+ async def systask(self) -> None:
+ tasks = [
+ asyncio.create_task(self.__systask_events()),
+ asyncio.create_task(self.__systask_default_edid()),
+ asyncio.create_task(self.__systask_storage()),
+ ]
+ try:
+ await asyncio.gather(*tasks)
+ except Exception:
+ for task in tasks:
+ task.cancel()
+ await asyncio.gather(*tasks, return_exceptions=True)
+ raise
+
+ async def __systask_events(self) -> None:
+ async for event in self.__chain.poll_events():
+ match event:
+ case DeviceFoundEvent():
+ await self.__load_configs()
+ case ChainTruncatedEvent():
+ self.__cache.truncate(event.units)
+ case PortActivatedEvent():
+ self.__cache.update_active_port(event.port)
+ case UnitStateEvent():
+ self.__cache.update_unit_state(event.unit, event.state)
+ case UnitAtxLedsEvent():
+ self.__cache.update_unit_atx_leds(event.unit, event.atx_leds)
+
+ async def __load_configs(self) -> None:
+ async with self.__lock:
+ try:
+ async with self.__storage.readable() as ctx:
+ values = {
+ key: await getattr(ctx, f"read_{key}")()
+ for key in self.__X_ALL
+ }
+ data_hex = await aiotools.read_file(self.__default_edid_path)
+ values["edids"].set_default(data_hex)
+ except Exception:
+ get_logger(0).exception("Can't load configs")
+ else:
+ for (key, value) in values.items():
+ func = getattr(self, f"_Switch__x_set_{key}")
+ if isinstance(value, tuple):
+ func(*value, save=False)
+ else:
+ func(value, save=False)
+ self.__chain.set_actual(True)
+
+ async def __systask_default_edid(self) -> None:
+ logger = get_logger(0)
+ async for _ in self.__poll_default_edid():
+ async with self.__lock:
+ edids = self.__cache.get_edids()
+ try:
+ data_hex = await aiotools.read_file(self.__default_edid_path)
+ edids.set_default(data_hex)
+ except Exception:
+ logger.exception("Can't read default EDID, ignoring ...")
+ else:
+ self.__x_set_edids(edids, save=False)
+
+ async def __poll_default_edid(self) -> AsyncGenerator[None, None]:
+ logger = get_logger(0)
+ while True:
+ while not os.path.exists(self.__default_edid_path):
+ await asyncio.sleep(5)
+ try:
+ with Inotify() as inotify:
+ await inotify.watch_all_changes(self.__default_edid_path)
+ if os.path.islink(self.__default_edid_path):
+ await inotify.watch_all_changes(os.path.realpath(self.__default_edid_path))
+ yield None
+ while True:
+ need_restart = False
+ need_notify = False
+ for event in (await inotify.get_series(timeout=1)):
+ need_notify = True
+ if event.restart:
+ logger.warning("Got fatal inotify event: %s; reinitializing ...", event)
+ need_restart = True
+ break
+ if need_restart:
+ break
+ if need_notify:
+ yield None
+ except Exception:
+ logger.exception("Unexpected watcher error")
+ await asyncio.sleep(1)
+
+ async def __systask_storage(self) -> None:
+ # При остановке KVMD можем не успеть записать, ну да пофиг
+ prevs = dict.fromkeys(self.__X_ALL)
+ while True:
+ await self.__save_notifier.wait()
+ while (await self.__save_notifier.wait(5)):
+ pass
+ while True:
+ try:
+ async with self.__lock:
+ write = {
+ key: new
+ for (key, old) in prevs.items()
+ if (new := getattr(self.__cache, f"get_{key}")()) != old
+ }
+ if write:
+ async with self.__storage.writable() as ctx:
+ for (key, new) in write.items():
+ func = getattr(ctx, f"write_{key}")
+ if isinstance(new, tuple):
+ await func(*new)
+ else:
+ await func(new)
+ prevs[key] = new
+ except Exception:
+ get_logger(0).exception("Unexpected storage error")
+ await asyncio.sleep(5)
+ else:
+ break
diff --git a/kvmd/apps/kvmd/switch/chain.py b/kvmd/apps/kvmd/switch/chain.py
new file mode 100644
index 00000000..8e4d94eb
--- /dev/null
+++ b/kvmd/apps/kvmd/switch/chain.py
@@ -0,0 +1,440 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+import multiprocessing
+import queue
+import select
+import dataclasses
+import time
+
+from typing import AsyncGenerator
+
+from .lib import get_logger
+from .lib import tools
+from .lib import aiotools
+from .lib import aioproc
+
+from .types import Edids
+from .types import Colors
+
+from .proto import Response
+from .proto import UnitState
+from .proto import UnitAtxLeds
+
+from .device import Device
+from .device import DeviceError
+
+
+# =====
+class _BaseCmd:
+ pass
+
+
[email protected](frozen=True)
+class _CmdSetActual(_BaseCmd):
+ actual: bool
+
+
[email protected](frozen=True)
+class _CmdSetActivePort(_BaseCmd):
+ port: int
+
+ def __post_init__(self) -> None:
+ assert self.port >= 0
+
+
[email protected](frozen=True)
+class _CmdSetPortBeacon(_BaseCmd):
+ port: int
+ on: bool
+
+
[email protected](frozen=True)
+class _CmdSetUnitBeacon(_BaseCmd):
+ unit: int
+ on: bool
+ downlink: bool
+
+
[email protected](frozen=True)
+class _CmdSetEdids(_BaseCmd):
+ edids: Edids
+
+
[email protected](frozen=True)
+class _CmdSetColors(_BaseCmd):
+ colors: Colors
+
+
[email protected](frozen=True)
+class _CmdAtxClick(_BaseCmd):
+ port: int
+ delay: float
+ reset: bool
+ if_powered: (bool | None)
+
+ def __post_init__(self) -> None:
+ assert self.port >= 0
+ assert 0.001 <= self.delay <= (0xFFFF / 1000)
+
+
[email protected](frozen=True)
+class _CmdRebootUnit(_BaseCmd):
+ unit: int
+ bootloader: bool
+
+ def __post_init__(self) -> None:
+ assert self.unit >= 0
+
+
+class _UnitContext:
+ __TIMEOUT = 5.0
+
+ def __init__(self) -> None:
+ self.state: (UnitState | None) = None
+ self.atx_leds: (UnitAtxLeds | None) = None
+ self.__rid = -1
+ self.__deadline_ts = -1.0
+
+ def can_be_changed(self) -> bool:
+ return (
+ self.state is not None
+ and not self.state.flags.changing_busy
+ and self.changing_rid < 0
+ )
+
+ # =====
+
+ @property
+ def changing_rid(self) -> int:
+ if self.__deadline_ts >= 0 and self.__deadline_ts < time.monotonic():
+ self.__rid = -1
+ self.__deadline_ts = -1
+ return self.__rid
+
+ @changing_rid.setter
+ def changing_rid(self, rid: int) -> None:
+ self.__rid = rid
+ self.__deadline_ts = ((time.monotonic() + self.__TIMEOUT) if rid >= 0 else -1)
+
+ # =====
+
+ def is_atx_allowed(self, ch: int) -> tuple[bool, bool]: # (allowed, power_led)
+ if self.state is None or self.atx_leds is None:
+ return (False, False)
+ return ((not self.state.atx_busy[ch]), self.atx_leds.power[ch])
+
+
+# =====
+class BaseEvent:
+ pass
+
+
+class DeviceFoundEvent(BaseEvent):
+ pass
+
+
[email protected](frozen=True)
+class ChainTruncatedEvent(BaseEvent):
+ units: int
+
+
[email protected](frozen=True)
+class PortActivatedEvent(BaseEvent):
+ port: int
+
+
[email protected](frozen=True)
+class UnitStateEvent(BaseEvent):
+ unit: int
+ state: UnitState
+
+
[email protected](frozen=True)
+class UnitAtxLedsEvent(BaseEvent):
+ unit: int
+ atx_leds: UnitAtxLeds
+
+
+# =====
+class Chain: # pylint: disable=too-many-instance-attributes
+ def __init__(self, device_path: str) -> None:
+ self.__device = Device(device_path)
+
+ self.__actual = False
+
+ self.__edids = Edids()
+
+ self.__colors = Colors()
+
+ self.__units: list[_UnitContext] = []
+ self.__active_port = -1
+
+ self.__cmd_queue: "multiprocessing.Queue[_BaseCmd]" = multiprocessing.Queue()
+ self.__events_queue: "multiprocessing.Queue[BaseEvent]" = multiprocessing.Queue()
+
+ self.__stop_event = multiprocessing.Event()
+
+ def set_actual(self, actual: bool) -> None:
+ # Флаг разрешения синхронизации EDID и прочих чувствительных вещей
+ self.__queue_cmd(_CmdSetActual(actual))
+
+ # =====
+
+ def set_active_port(self, port: int) -> None:
+ self.__queue_cmd(_CmdSetActivePort(port))
+
+ # =====
+
+ def set_port_beacon(self, port: int, on: bool) -> None:
+ self.__queue_cmd(_CmdSetPortBeacon(port, on))
+
+ def set_uplink_beacon(self, unit: int, on: bool) -> None:
+ self.__queue_cmd(_CmdSetUnitBeacon(unit, on, downlink=False))
+
+ def set_downlink_beacon(self, unit: int, on: bool) -> None:
+ self.__queue_cmd(_CmdSetUnitBeacon(unit, on, downlink=True))
+
+ # =====
+
+ def set_edids(self, edids: Edids) -> None:
+ self.__queue_cmd(_CmdSetEdids(edids)) # Will be copied because of multiprocessing.Queue()
+
+ def set_colors(self, colors: Colors) -> None:
+ self.__queue_cmd(_CmdSetColors(colors))
+
+ # =====
+
+ def click_power(self, port: int, delay: float, if_powered: (bool | None)) -> None:
+ self.__queue_cmd(_CmdAtxClick(port, delay, reset=False, if_powered=if_powered))
+
+ def click_reset(self, port: int, delay: float, if_powered: (bool | None)) -> None:
+ self.__queue_cmd(_CmdAtxClick(port, delay, reset=True, if_powered=if_powered))
+
+ # =====
+
+ def reboot_unit(self, unit: int, bootloader: bool) -> None:
+ self.__queue_cmd(_CmdRebootUnit(unit, bootloader))
+
+ # =====
+
+ async def poll_events(self) -> AsyncGenerator[BaseEvent, None]:
+ proc = multiprocessing.Process(target=self.__subprocess, daemon=True)
+ try:
+ proc.start()
+ while True:
+ try:
+ yield (await aiotools.run_async(self.__events_queue.get, True, 0.1))
+ except queue.Empty:
+ pass
+ finally:
+ if proc.is_alive():
+ self.__stop_event.set()
+ if proc.is_alive() or proc.exitcode is not None:
+ await aiotools.run_async(proc.join)
+
+ # =====
+
+ def __queue_cmd(self, cmd: _BaseCmd) -> None:
+ if not self.__stop_event.is_set():
+ self.__cmd_queue.put_nowait(cmd)
+
+ def __queue_event(self, event: BaseEvent) -> None:
+ if not self.__stop_event.is_set():
+ self.__events_queue.put_nowait(event)
+
+ def __subprocess(self) -> None:
+ logger = aioproc.settle("Switch", "switch")
+ no_device_reported = False
+ while True:
+ try:
+ if self.__device.has_device():
+ no_device_reported = False
+ with self.__device:
+ logger.info("Switch found")
+ self.__queue_event(DeviceFoundEvent())
+ self.__main_loop()
+ elif not no_device_reported:
+ self.__queue_event(ChainTruncatedEvent(0))
+ logger.info("Switch is missing")
+ no_device_reported = True
+ except DeviceError as ex:
+ logger.error("%s", tools.efmt(ex))
+ except Exception:
+ logger.exception("Unexpected error in the Switch loop")
+ tools.clear_queue(self.__cmd_queue)
+ if self.__stop_event.is_set():
+ break
+ time.sleep(1)
+
+ def __main_loop(self) -> None:
+ self.__device.request_state()
+ self.__device.request_atx_leds()
+ while not self.__stop_event.is_set():
+ if self.__select():
+ for resp in self.__device.read_all():
+ self.__update_units(resp)
+ self.__adjust_start_port()
+ self.__finish_changing_request(resp)
+ self.__consume_commands()
+ self.__ensure_config()
+
+ def __select(self) -> bool:
+ try:
+ return bool(select.select([
+ self.__device.get_fd(),
+ self.__cmd_queue._reader, # type: ignore # pylint: disable=protected-access
+ ], [], [], 1)[0])
+ except Exception as ex:
+ raise DeviceError(ex)
+
+ def __consume_commands(self) -> None:
+ while not self.__cmd_queue.empty():
+ cmd = self.__cmd_queue.get()
+ match cmd:
+ case _CmdSetActual():
+ self.__actual = cmd.actual
+
+ case _CmdSetActivePort():
+ # Может быть вызвано изнутри при синхронизации
+ self.__active_port = cmd.port
+ self.__queue_event(PortActivatedEvent(self.__active_port))
+
+ case _CmdSetPortBeacon():
+ (unit, ch) = self.get_real_unit_channel(cmd.port)
+ self.__device.request_beacon(unit, ch, cmd.on)
+
+ case _CmdSetUnitBeacon():
+ ch = (4 if cmd.downlink else 5)
+ self.__device.request_beacon(cmd.unit, ch, cmd.on)
+
+ case _CmdAtxClick():
+ (unit, ch) = self.get_real_unit_channel(cmd.port)
+ if unit < len(self.__units):
+ (allowed, powered) = self.__units[unit].is_atx_allowed(ch)
+ if allowed and (cmd.if_powered is None or cmd.if_powered == powered):
+ delay_ms = min(int(cmd.delay * 1000), 0xFFFF)
+ if cmd.reset:
+ self.__device.request_atx_cr(unit, ch, delay_ms)
+ else:
+ self.__device.request_atx_cp(unit, ch, delay_ms)
+
+ case _CmdSetEdids():
+ self.__edids = cmd.edids
+
+ case _CmdSetColors():
+ self.__colors = cmd.colors
+
+ case _CmdRebootUnit():
+ self.__device.request_reboot(cmd.unit, cmd.bootloader)
+
+ def __update_units(self, resp: Response) -> None:
+ units = resp.header.unit + 1
+ while len(self.__units) < units:
+ self.__units.append(_UnitContext())
+
+ match resp.body:
+ case UnitState():
+ if not resp.body.flags.has_downlink and len(self.__units) > units:
+ del self.__units[units:]
+ self.__queue_event(ChainTruncatedEvent(units))
+ self.__units[resp.header.unit].state = resp.body
+ self.__queue_event(UnitStateEvent(resp.header.unit, resp.body))
+
+ case UnitAtxLeds():
+ self.__units[resp.header.unit].atx_leds = resp.body
+ self.__queue_event(UnitAtxLedsEvent(resp.header.unit, resp.body))
+
+ def __adjust_start_port(self) -> None:
+ if self.__active_port < 0:
+ for (unit, ctx) in enumerate(self.__units):
+ if ctx.state is not None and ctx.state.ch < 4:
+ # Trigger queue select()
+ port = self.get_virtual_port(unit, ctx.state.ch)
+ get_logger().info("Found an active port %d on [%d:%d]: Syncing ...",
+ port, unit, ctx.state.ch)
+ self.set_active_port(port)
+ break
+
+ def __finish_changing_request(self, resp: Response) -> None:
+ if self.__units[resp.header.unit].changing_rid == resp.header.rid:
+ self.__units[resp.header.unit].changing_rid = -1
+
+ # =====
+
+ def __ensure_config(self) -> None:
+ for (unit, ctx) in enumerate(self.__units):
+ if ctx.state is not None:
+ self.__ensure_config_port(unit, ctx)
+ if self.__actual:
+ self.__ensure_config_edids(unit, ctx)
+ self.__ensure_config_colors(unit, ctx)
+
+ def __ensure_config_port(self, unit: int, ctx: _UnitContext) -> None:
+ assert ctx.state is not None
+ if self.__active_port >= 0 and ctx.can_be_changed():
+ ch = self.get_unit_target_channel(unit, self.__active_port)
+ if ctx.state.ch != ch:
+ get_logger().info("Switching for active port %d: [%d:%d] -> [%d:%d] ...",
+ self.__active_port, unit, ctx.state.ch, unit, ch)
+ ctx.changing_rid = self.__device.request_switch(unit, ch)
+
+ def __ensure_config_edids(self, unit: int, ctx: _UnitContext) -> None:
+ assert self.__actual
+ assert ctx.state is not None
+ if ctx.can_be_changed():
+ for ch in range(4):
+ port = self.get_virtual_port(unit, ch)
+ edid = self.__edids.get_edid_for_port(port)
+ if not ctx.state.compare_edid(ch, edid):
+ get_logger().info("Changing EDID on port %d on [%d:%d]: %d/%d -> %d/%d (%s) ...",
+ port, unit, ch,
+ ctx.state.video_crc[ch], ctx.state.video_edid[ch],
+ edid.crc, edid.valid, edid.name)
+ ctx.changing_rid = self.__device.request_set_edid(unit, ch, edid)
+ break # Busy globally
+
+ def __ensure_config_colors(self, unit: int, ctx: _UnitContext) -> None:
+ assert self.__actual
+ assert ctx.state is not None
+ for np in range(6):
+ if self.__colors.crc != ctx.state.np_crc[np]:
+ # get_logger().info("Changing colors on NP [%d:%d]: %d -> %d ...",
+ # unit, np, ctx.state.np_crc[np], self.__colors.crc)
+ self.__device.request_set_colors(unit, np, self.__colors)
+
+ # =====
+
+ @classmethod
+ def get_real_unit_channel(cls, port: int) -> tuple[int, int]:
+ return (port // 4, port % 4)
+
+ @classmethod
+ def get_unit_target_channel(cls, unit: int, port: int) -> int:
+ (t_unit, t_ch) = cls.get_real_unit_channel(port)
+ if unit != t_unit:
+ t_ch = 4
+ return t_ch
+
+ @classmethod
+ def get_virtual_port(cls, unit: int, ch: int) -> int:
+ return (unit * 4) + ch
diff --git a/kvmd/apps/kvmd/switch/device.py b/kvmd/apps/kvmd/switch/device.py
new file mode 100644
index 00000000..b56cc406
--- /dev/null
+++ b/kvmd/apps/kvmd/switch/device.py
@@ -0,0 +1,196 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+import os
+import random
+import types
+
+import serial
+
+from .lib import tools
+
+from .types import Edid
+from .types import Colors
+
+from .proto import Packable
+from .proto import Request
+from .proto import Response
+from .proto import Header
+
+from .proto import BodySwitch
+from .proto import BodySetBeacon
+from .proto import BodyAtxClick
+from .proto import BodySetEdid
+from .proto import BodyClearEdid
+from .proto import BodySetColors
+
+
+# =====
+class DeviceError(Exception):
+ def __init__(self, ex: Exception):
+ super().__init__(tools.efmt(ex))
+
+
+class Device:
+ __SPEED = 115200
+ __TIMEOUT = 5.0
+
+ def __init__(self, device_path: str) -> None:
+ self.__device_path = device_path
+ self.__rid = random.randint(1, 0xFFFF)
+ self.__tty: (serial.Serial | None) = None
+ self.__buf: bytes = b""
+
+ def __enter__(self) -> "Device":
+ try:
+ self.__tty = serial.Serial(
+ self.__device_path,
+ baudrate=self.__SPEED,
+ timeout=self.__TIMEOUT,
+ )
+ except Exception as ex:
+ raise DeviceError(ex)
+ return self
+
+ def __exit__(
+ self,
+ _exc_type: type[BaseException],
+ _exc: BaseException,
+ _tb: types.TracebackType,
+ ) -> None:
+
+ if self.__tty is not None:
+ try:
+ self.__tty.close()
+ except Exception:
+ pass
+ self.__tty = None
+
+ def has_device(self) -> bool:
+ return os.path.exists(self.__device_path)
+
+ def get_fd(self) -> int:
+ assert self.__tty is not None
+ return self.__tty.fd
+
+ def read_all(self) -> list[Response]:
+ assert self.__tty is not None
+ try:
+ if not self.__tty.in_waiting:
+ return []
+ self.__buf += self.__tty.read_all()
+ except Exception as ex:
+ raise DeviceError(ex)
+
+ results: list[Response] = []
+ while True:
+ try:
+ begin = self.__buf.index(0xF1)
+ except ValueError:
+ break
+ try:
+ end = self.__buf.index(0xF2, begin)
+ except ValueError:
+ break
+ msg = self.__buf[begin + 1:end]
+ if 0xF1 in msg:
+ # raise RuntimeError(f"Found 0xF1 inside the message: {msg!r}")
+ break
+ self.__buf = self.__buf[end + 1:]
+ msg = self.__unescape(msg)
+ resp = Response.unpack(msg)
+ if resp is not None:
+ results.append(resp)
+ return results
+
+ def __unescape(self, msg: bytes) -> bytes:
+ if 0xF0 not in msg:
+ return msg
+ unesc: list[int] = []
+ esc = False
+ for ch in msg:
+ if ch == 0xF0:
+ esc = True
+ else:
+ if esc:
+ ch ^= 0xFF
+ esc = False
+ unesc.append(ch)
+ return bytes(unesc)
+
+ def request_reboot(self, unit: int, bootloader: bool) -> int:
+ return self.__send_request((Header.BOOTLOADER if bootloader else Header.REBOOT), unit, None)
+
+ def request_state(self) -> int:
+ return self.__send_request(Header.STATE, 0xFF, None)
+
+ def request_switch(self, unit: int, ch: int) -> int:
+ return self.__send_request(Header.SWITCH, unit, BodySwitch(ch))
+
+ def request_beacon(self, unit: int, ch: int, on: bool) -> int:
+ return self.__send_request(Header.BEACON, unit, BodySetBeacon(ch, on))
+
+ def request_atx_leds(self) -> int:
+ return self.__send_request(Header.ATX_LEDS, 0xFF, None)
+
+ def request_atx_cp(self, unit: int, ch: int, delay_ms: int) -> int:
+ return self.__send_request(Header.ATX_CLICK, unit, BodyAtxClick(ch, BodyAtxClick.POWER, delay_ms))
+
+ def request_atx_cr(self, unit: int, ch: int, delay_ms: int) -> int:
+ return self.__send_request(Header.ATX_CLICK, unit, BodyAtxClick(ch, BodyAtxClick.RESET, delay_ms))
+
+ def request_set_edid(self, unit: int, ch: int, edid: Edid) -> int:
+ if edid.valid:
+ return self.__send_request(Header.SET_EDID, unit, BodySetEdid(ch, edid))
+ return self.__send_request(Header.CLEAR_EDID, unit, BodyClearEdid(ch))
+
+ def request_set_colors(self, unit: int, ch: int, colors: Colors) -> int:
+ return self.__send_request(Header.SET_COLORS, unit, BodySetColors(ch, colors))
+
+ def __send_request(self, op: int, unit: int, body: (Packable | None)) -> int:
+ assert self.__tty is not None
+ req = Request(Header(
+ proto=1,
+ rid=self.__get_next_rid(),
+ op=op,
+ unit=unit,
+ ), body)
+ data: list[int] = [0xF1]
+ for ch in req.pack():
+ if 0xF0 <= ch <= 0xF2:
+ data.append(0xF0)
+ ch ^= 0xFF
+ data.append(ch)
+ data.append(0xF2)
+ try:
+ self.__tty.write(bytes(data))
+ self.__tty.flush()
+ except Exception as ex:
+ raise DeviceError(ex)
+ return req.header.rid
+
+ def __get_next_rid(self) -> int:
+ rid = self.__rid
+ self.__rid += 1
+ if self.__rid > 0xFFFF:
+ self.__rid = 1
+ return rid
diff --git a/kvmd/apps/kvmd/switch/lib.py b/kvmd/apps/kvmd/switch/lib.py
new file mode 100644
index 00000000..4ef2647e
--- /dev/null
+++ b/kvmd/apps/kvmd/switch/lib.py
@@ -0,0 +1,35 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+# pylint: disable=unused-import
+
+from ....logging import get_logger # noqa: F401
+
+from .... import tools # noqa: F401
+from .... import aiotools # noqa: F401
+from .... import aioproc # noqa: F401
+from .... import bitbang # noqa: F401
+from .... import htclient # noqa: F401
+from ....inotify import Inotify # noqa: F401
+from ....errors import OperationError # noqa: F401
+from ....edid import EdidNoBlockError as ParsedEdidNoBlockError # noqa: F401
+from ....edid import Edid as ParsedEdid # noqa: F401
diff --git a/kvmd/apps/kvmd/switch/proto.py b/kvmd/apps/kvmd/switch/proto.py
new file mode 100644
index 00000000..d4f43f84
--- /dev/null
+++ b/kvmd/apps/kvmd/switch/proto.py
@@ -0,0 +1,295 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+import struct
+import dataclasses
+
+from typing import Optional
+
+from .types import Edid
+from .types import Colors
+
+
+# =====
+class Packable:
+ def pack(self) -> bytes:
+ raise NotImplementedError()
+
+
+class Unpackable:
+ @classmethod
+ def unpack(cls, data: bytes, offset: int=0) -> "Unpackable":
+ raise NotImplementedError()
+
+
+# =====
[email protected](frozen=True)
+class Header(Packable, Unpackable):
+ proto: int
+ rid: int
+ op: int
+ unit: int
+
+ NAK = 0
+ BOOTLOADER = 2
+ REBOOT = 3
+ STATE = 4
+ SWITCH = 5
+ BEACON = 6
+ ATX_LEDS = 7
+ ATX_CLICK = 8
+ SET_EDID = 9
+ CLEAR_EDID = 10
+ SET_COLORS = 12
+
+ __struct = struct.Struct("<BHBB")
+
+ SIZE = __struct.size
+
+ def pack(self) -> bytes:
+ return self.__struct.pack(self.proto, self.rid, self.op, self.unit)
+
+ @classmethod
+ def unpack(cls, data: bytes, offset: int=0) -> "Header":
+ return Header(*cls.__struct.unpack_from(data, offset=offset))
+
+
[email protected](frozen=True)
+class Nak(Unpackable):
+ reason: int
+
+ INVALID_COMMAND = 0
+ BUSY = 1
+ NO_DOWNLINK = 2
+ DOWNLINK_OVERFLOW = 3
+
+ __struct = struct.Struct("<B")
+
+ @classmethod
+ def unpack(cls, data: bytes, offset: int=0) -> "Nak":
+ return Nak(*cls.__struct.unpack_from(data, offset=offset))
+
+
[email protected](frozen=True)
+class UnitFlags:
+ changing_busy: bool
+ flashing_busy: bool
+ has_downlink: bool
+
+
[email protected](frozen=True)
+class UnitState(Unpackable): # pylint: disable=too-many-instance-attributes
+ sw_version: int
+ hw_version: int
+ flags: UnitFlags
+ ch: int
+ beacons: tuple[bool, bool, bool, bool, bool, bool]
+ np_crc: tuple[int, int, int, int, int, int]
+ video_5v_sens: tuple[bool, bool, bool, bool, bool]
+ video_hpd: tuple[bool, bool, bool, bool, bool]
+ video_edid: tuple[bool, bool, bool, bool]
+ video_crc: tuple[int, int, int, int]
+ usb_5v_sens: tuple[bool, bool, bool, bool]
+ atx_busy: tuple[bool, bool, bool, bool]
+
+ __struct = struct.Struct("<HHHBBHHHHHHBBBHHHHBxB30x")
+
+ def compare_edid(self, ch: int, edid: Optional["Edid"]) -> bool:
+ if edid is None:
+ # Сойдет любой невалидный EDID
+ return (not self.video_edid[ch])
+ return (
+ self.video_edid[ch] == edid.valid
+ and self.video_crc[ch] == edid.crc
+ )
+
+ @classmethod
+ def unpack(cls, data: bytes, offset: int=0) -> "UnitState": # pylint: disable=too-many-locals
+ (
+ sw_version, hw_version, flags, ch,
+ beacons, nc0, nc1, nc2, nc3, nc4, nc5,
+ video_5v_sens, video_hpd, video_edid, vc0, vc1, vc2, vc3,
+ usb_5v_sens, atx_busy,
+ ) = cls.__struct.unpack_from(data, offset=offset)
+ return UnitState(
+ sw_version,
+ hw_version,
+ flags=UnitFlags(
+ changing_busy=bool(flags & 0x80),
+ flashing_busy=bool(flags & 0x40),
+ has_downlink=bool(flags & 0x02),
+ ),
+ ch=ch,
+ beacons=cls.__make_flags6(beacons),
+ np_crc=(nc0, nc1, nc2, nc3, nc4, nc5),
+ video_5v_sens=cls.__make_flags5(video_5v_sens),
+ video_hpd=cls.__make_flags5(video_hpd),
+ video_edid=cls.__make_flags4(video_edid),
+ video_crc=(vc0, vc1, vc2, vc3),
+ usb_5v_sens=cls.__make_flags4(usb_5v_sens),
+ atx_busy=cls.__make_flags4(atx_busy),
+ )
+
+ @classmethod
+ def __make_flags6(cls, mask: int) -> tuple[bool, bool, bool, bool, bool, bool]:
+ return (
+ bool(mask & 0x01), bool(mask & 0x02), bool(mask & 0x04),
+ bool(mask & 0x08), bool(mask & 0x10), bool(mask & 0x20),
+ )
+
+ @classmethod
+ def __make_flags5(cls, mask: int) -> tuple[bool, bool, bool, bool, bool]:
+ return (
+ bool(mask & 0x01), bool(mask & 0x02), bool(mask & 0x04),
+ bool(mask & 0x08), bool(mask & 0x10),
+ )
+
+ @classmethod
+ def __make_flags4(cls, mask: int) -> tuple[bool, bool, bool, bool]:
+ return (bool(mask & 0x01), bool(mask & 0x02), bool(mask & 0x04), bool(mask & 0x08))
+
+
[email protected](frozen=True)
+class UnitAtxLeds(Unpackable):
+ power: tuple[bool, bool, bool, bool]
+ hdd: tuple[bool, bool, bool, bool]
+
+ __struct = struct.Struct("<B")
+
+ @classmethod
+ def unpack(cls, data: bytes, offset: int=0) -> "UnitAtxLeds":
+ (mask,) = cls.__struct.unpack_from(data, offset=offset)
+ return UnitAtxLeds(
+ power=(bool(mask & 0x01), bool(mask & 0x02), bool(mask & 0x04), bool(mask & 0x08)),
+ hdd=(bool(mask & 0x10), bool(mask & 0x20), bool(mask & 0x40), bool(mask & 0x80)),
+ )
+
+
+# =====
[email protected](frozen=True)
+class BodySwitch(Packable):
+ ch: int
+
+ def __post_init__(self) -> None:
+ assert 0 <= self.ch <= 4
+
+ def pack(self) -> bytes:
+ return self.ch.to_bytes()
+
+
[email protected](frozen=True)
+class BodySetBeacon(Packable):
+ ch: int
+ on: bool
+
+ def __post_init__(self) -> None:
+ assert 0 <= self.ch <= 5
+
+ def pack(self) -> bytes:
+ return self.ch.to_bytes() + self.on.to_bytes()
+
+
[email protected](frozen=True)
+class BodyAtxClick(Packable):
+ ch: int
+ action: int
+ delay_ms: int
+
+ POWER = 0
+ RESET = 1
+
+ __struct = struct.Struct("<BBH")
+
+ def __post_init__(self) -> None:
+ assert 0 <= self.ch <= 3
+ assert self.action in [self.POWER, self.RESET]
+ assert 1 <= self.delay_ms <= 0xFFFF
+
+ def pack(self) -> bytes:
+ return self.__struct.pack(self.ch, self.action, self.delay_ms)
+
+
[email protected](frozen=True)
+class BodySetEdid(Packable):
+ ch: int
+ edid: Edid
+
+ def __post_init__(self) -> None:
+ assert 0 <= self.ch <= 3
+
+ def pack(self) -> bytes:
+ return self.ch.to_bytes() + self.edid.pack()
+
+
[email protected](frozen=True)
+class BodyClearEdid(Packable):
+ ch: int
+
+ def __post_init__(self) -> None:
+ assert 0 <= self.ch <= 3
+
+ def pack(self) -> bytes:
+ return self.ch.to_bytes()
+
+
[email protected](frozen=True)
+class BodySetColors(Packable):
+ ch: int
+ colors: Colors
+
+ def __post_init__(self) -> None:
+ assert 0 <= self.ch <= 5
+
+ def pack(self) -> bytes:
+ return self.ch.to_bytes() + self.colors.pack()
+
+
+# =====
[email protected](frozen=True)
+class Request:
+ header: Header
+ body: (Packable | None) = dataclasses.field(default=None)
+
+ def pack(self) -> bytes:
+ msg = self.header.pack()
+ if self.body is not None:
+ msg += self.body.pack()
+ return msg
+
+
[email protected](frozen=True)
+class Response:
+ header: Header
+ body: Unpackable
+
+ @classmethod
+ def unpack(cls, msg: bytes) -> Optional["Response"]:
+ header = Header.unpack(msg)
+ match header.op:
+ case Header.NAK:
+ return Response(header, Nak.unpack(msg, Header.SIZE))
+ case Header.STATE:
+ return Response(header, UnitState.unpack(msg, Header.SIZE))
+ case Header.ATX_LEDS:
+ return Response(header, UnitAtxLeds.unpack(msg, Header.SIZE))
+ # raise RuntimeError(f"Unknown OP in the header: {header!r}")
+ return None
diff --git a/kvmd/apps/kvmd/switch/state.py b/kvmd/apps/kvmd/switch/state.py
new file mode 100644
index 00000000..626cdfe1
--- /dev/null
+++ b/kvmd/apps/kvmd/switch/state.py
@@ -0,0 +1,355 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+import asyncio
+import dataclasses
+import time
+
+from typing import AsyncGenerator
+
+from .types import Edids
+from .types import Color
+from .types import Colors
+from .types import PortNames
+from .types import AtxClickPowerDelays
+from .types import AtxClickPowerLongDelays
+from .types import AtxClickResetDelays
+
+from .proto import UnitState
+from .proto import UnitAtxLeds
+
+from .chain import Chain
+
+
+# =====
+class _UnitInfo:
+ state: (UnitState | None) = dataclasses.field(default=None)
+ atx_leds: (UnitAtxLeds | None) = dataclasses.field(default=None)
+
+
+# =====
+class StateCache: # pylint: disable=too-many-instance-attributes
+ __FULL = 0xFFFF
+ __SUMMARY = 0x01
+ __EDIDS = 0x02
+ __COLORS = 0x04
+ __VIDEO = 0x08
+ __USB = 0x10
+ __BEACONS = 0x20
+ __ATX = 0x40
+
+ def __init__(self) -> None:
+ self.__edids = Edids()
+ self.__colors = Colors()
+ self.__port_names = PortNames({})
+ self.__atx_cp_delays = AtxClickPowerDelays({})
+ self.__atx_cpl_delays = AtxClickPowerLongDelays({})
+ self.__atx_cr_delays = AtxClickResetDelays({})
+
+ self.__units: list[_UnitInfo] = []
+ self.__active_port = -1
+ self.__synced = True
+
+ self.__queue: "asyncio.Queue[int]" = asyncio.Queue()
+
+ def get_edids(self) -> Edids:
+ return self.__edids.copy()
+
+ def get_colors(self) -> Colors:
+ return self.__colors
+
+ def get_port_names(self) -> PortNames:
+ return self.__port_names.copy()
+
+ def get_atx_cp_delays(self) -> AtxClickPowerDelays:
+ return self.__atx_cp_delays.copy()
+
+ def get_atx_cpl_delays(self) -> AtxClickPowerLongDelays:
+ return self.__atx_cpl_delays.copy()
+
+ def get_atx_cr_delays(self) -> AtxClickResetDelays:
+ return self.__atx_cr_delays.copy()
+
+ # =====
+
+ def get_state(self) -> dict:
+ return self.__inner_get_state(self.__FULL)
+
+ async def trigger_state(self) -> None:
+ self.__bump_state(self.__FULL)
+
+ async def poll_state(self) -> AsyncGenerator[dict, None]:
+ atx_ts: float = 0
+ while True:
+ try:
+ mask = await asyncio.wait_for(self.__queue.get(), timeout=0.1)
+ except TimeoutError:
+ mask = 0
+
+ if mask == self.__ATX:
+ # Откладываем единичное новое событие ATX, чтобы аккумулировать с нескольких свичей
+ if atx_ts == 0:
+ atx_ts = time.monotonic() + 0.2
+ continue
+ elif atx_ts >= time.monotonic():
+ continue
+ # ... Ну или разрешаем отправить, если оно уже достаточно мариновалось
+ elif mask == 0 and atx_ts > time.monotonic():
+ # Разрешаем отправить отложенное
+ mask = self.__ATX
+ atx_ts = 0
+ elif mask & self.__ATX:
+ # Комплексное событие всегда должно обрабатываться сразу
+ atx_ts = 0
+
+ if mask != 0:
+ yield self.__inner_get_state(mask)
+
+ def __inner_get_state(self, mask: int) -> dict: # pylint: disable=too-many-branches,too-many-statements,too-many-locals
+ assert mask != 0
+ x_model = (mask == self.__FULL)
+ x_summary = (mask & self.__SUMMARY)
+ x_edids = (mask & self.__EDIDS)
+ x_colors = (mask & self.__COLORS)
+ x_video = (mask & self.__VIDEO)
+ x_usb = (mask & self.__USB)
+ x_beacons = (mask & self.__BEACONS)
+ x_atx = (mask & self.__ATX)
+
+ state: dict = {}
+ if x_model:
+ state["model"] = {
+ "units": [],
+ "ports": [],
+ "limits": {
+ "atx": {
+ "click_delays": {
+ key: {"default": value, "min": 0, "max": 10}
+ for (key, value) in [
+ ("power", self.__atx_cp_delays.default),
+ ("power_long", self.__atx_cpl_delays.default),
+ ("reset", self.__atx_cr_delays.default),
+ ]
+ },
+ },
+ },
+ }
+ if x_summary:
+ state["summary"] = {"active_port": self.__active_port, "synced": self.__synced}
+ if x_edids:
+ state["edids"] = {
+ "all": {
+ edid_id: {
+ "name": edid.name,
+ "data": edid.as_text(),
+ "parsed": (dataclasses.asdict(edid.info) if edid.info is not None else None),
+ }
+ for (edid_id, edid) in self.__edids.all.items()
+ },
+ "used": [],
+ }
+ if x_colors:
+ state["colors"] = {
+ role: {
+ comp: getattr(getattr(self.__colors, role), comp)
+ for comp in Color.COMPONENTS
+ }
+ for role in Colors.ROLES
+ }
+ if x_video:
+ state["video"] = {"links": []}
+ if x_usb:
+ state["usb"] = {"links": []}
+ if x_beacons:
+ state["beacons"] = {"uplinks": [], "downlinks": [], "ports": []}
+ if x_atx:
+ state["atx"] = {"busy": [], "leds": {"power": [], "hdd": []}}
+
+ if not self.__is_units_ready():
+ return state
+
+ for (unit, ui) in enumerate(self.__units):
+ assert ui.state is not None
+ assert ui.atx_leds is not None
+ if x_model:
+ state["model"]["units"].append({"firmware": {"version": ui.state.sw_version}})
+ if x_video:
+ state["video"]["links"].extend(ui.state.video_5v_sens[:4])
+ if x_usb:
+ state["usb"]["links"].extend(ui.state.usb_5v_sens)
+ if x_beacons:
+ state["beacons"]["uplinks"].append(ui.state.beacons[5])
+ state["beacons"]["downlinks"].append(ui.state.beacons[4])
+ state["beacons"]["ports"].extend(ui.state.beacons[:4])
+ if x_atx:
+ state["atx"]["busy"].extend(ui.state.atx_busy)
+ state["atx"]["leds"]["power"].extend(ui.atx_leds.power)
+ state["atx"]["leds"]["hdd"].extend(ui.atx_leds.hdd)
+ if x_model or x_edids:
+ for ch in range(4):
+ port = Chain.get_virtual_port(unit, ch)
+ if x_model:
+ state["model"]["ports"].append({
+ "unit": unit,
+ "channel": ch,
+ "name": self.__port_names[port],
+ "atx": {
+ "click_delays": {
+ "power": self.__atx_cp_delays[port],
+ "power_long": self.__atx_cpl_delays[port],
+ "reset": self.__atx_cr_delays[port],
+ },
+ },
+ })
+ if x_edids:
+ state["edids"]["used"].append(self.__edids.get_id_for_port(port))
+ return state
+
+ def __inner_check_synced(self) -> bool:
+ for (unit, ui) in enumerate(self.__units):
+ if ui.state is None or ui.state.flags.changing_busy:
+ return False
+ if (
+ self.__active_port >= 0
+ and ui.state.ch != Chain.get_unit_target_channel(unit, self.__active_port)
+ ):
+ return False
+ for ch in range(4):
+ port = Chain.get_virtual_port(unit, ch)
+ edid = self.__edids.get_edid_for_port(port)
+ if not ui.state.compare_edid(ch, edid):
+ return False
+ for ch in range(6):
+ if ui.state.np_crc[ch] != self.__colors.crc:
+ return False
+ return True
+
+ def __recache_synced(self) -> bool:
+ synced = self.__inner_check_synced()
+ if self.__synced != synced:
+ self.__synced = synced
+ return True
+ return False
+
+ def truncate(self, units: int) -> None:
+ if len(self.__units) > units:
+ del self.__units[units:]
+ self.__bump_state(self.__FULL)
+
+ def update_active_port(self, port: int) -> None:
+ changed = (self.__active_port != port)
+ self.__active_port = port
+ changed = (self.__recache_synced() or changed)
+ if changed:
+ self.__bump_state(self.__SUMMARY)
+
+ def update_unit_state(self, unit: int, new: UnitState) -> None:
+ ui = self.__ensure_unit(unit)
+ (prev, ui.state) = (ui.state, new)
+ if not self.__is_units_ready():
+ return
+ mask = 0
+ if prev is None:
+ mask = self.__FULL
+ else:
+ if self.__recache_synced():
+ mask |= self.__SUMMARY
+ if prev.video_5v_sens != new.video_5v_sens:
+ mask |= self.__VIDEO
+ if prev.usb_5v_sens != new.usb_5v_sens:
+ mask |= self.__USB
+ if prev.beacons != new.beacons:
+ mask |= self.__BEACONS
+ if prev.atx_busy != new.atx_busy:
+ mask |= self.__ATX
+ if mask:
+ self.__bump_state(mask)
+
+ def update_unit_atx_leds(self, unit: int, new: UnitAtxLeds) -> None:
+ ui = self.__ensure_unit(unit)
+ (prev, ui.atx_leds) = (ui.atx_leds, new)
+ if not self.__is_units_ready():
+ return
+ if prev is None:
+ self.__bump_state(self.__FULL)
+ elif prev != new:
+ self.__bump_state(self.__ATX)
+
+ def __is_units_ready(self) -> bool:
+ for ui in self.__units:
+ if ui.state is None or ui.atx_leds is None:
+ return False
+ return True
+
+ def __ensure_unit(self, unit: int) -> _UnitInfo:
+ while len(self.__units) < unit + 1:
+ self.__units.append(_UnitInfo())
+ return self.__units[unit]
+
+ def __bump_state(self, mask: int) -> None:
+ assert mask != 0
+ self.__queue.put_nowait(mask)
+
+ # =====
+
+ def set_edids(self, edids: Edids) -> None:
+ changed = (
+ self.__edids.all != edids.all
+ or not self.__edids.compare_on_ports(edids, self.__get_ports())
+ )
+ self.__edids = edids.copy()
+ if changed:
+ self.__bump_state(self.__EDIDS)
+
+ def set_colors(self, colors: Colors) -> None:
+ changed = (self.__colors != colors)
+ self.__colors = colors
+ if changed:
+ self.__bump_state(self.__COLORS)
+
+ def set_port_names(self, port_names: PortNames) -> None:
+ changed = (not self.__port_names.compare_on_ports(port_names, self.__get_ports()))
+ self.__port_names = port_names.copy()
+ if changed:
+ self.__bump_state(self.__FULL)
+
+ def set_atx_cp_delays(self, delays: AtxClickPowerDelays) -> None:
+ changed = (not self.__atx_cp_delays.compare_on_ports(delays, self.__get_ports()))
+ self.__atx_cp_delays = delays.copy()
+ if changed:
+ self.__bump_state(self.__FULL)
+
+ def set_atx_cpl_delays(self, delays: AtxClickPowerLongDelays) -> None:
+ changed = (not self.__atx_cpl_delays.compare_on_ports(delays, self.__get_ports()))
+ self.__atx_cpl_delays = delays.copy()
+ if changed:
+ self.__bump_state(self.__FULL)
+
+ def set_atx_cr_delays(self, delays: AtxClickResetDelays) -> None:
+ changed = (not self.__atx_cr_delays.compare_on_ports(delays, self.__get_ports()))
+ self.__atx_cr_delays = delays.copy()
+ if changed:
+ self.__bump_state(self.__FULL)
+
+ def __get_ports(self) -> int:
+ return (len(self.__units) * 4)
diff --git a/kvmd/apps/kvmd/switch/storage.py b/kvmd/apps/kvmd/switch/storage.py
new file mode 100644
index 00000000..6e3a0a76
--- /dev/null
+++ b/kvmd/apps/kvmd/switch/storage.py
@@ -0,0 +1,186 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+import os
+import asyncio
+import json
+import contextlib
+
+from typing import AsyncGenerator
+
+try:
+ from ....clients.pst import PstClient
+except ImportError:
+ PstClient = None # type: ignore
+
+# from .lib import get_logger
+from .lib import aiotools
+from .lib import htclient
+from .lib import get_logger
+
+from .types import Edid
+from .types import Edids
+from .types import Color
+from .types import Colors
+from .types import PortNames
+from .types import AtxClickPowerDelays
+from .types import AtxClickPowerLongDelays
+from .types import AtxClickResetDelays
+
+
+# =====
+class StorageContext:
+ __F_EDIDS_ALL = "edids_all.json"
+ __F_EDIDS_PORT = "edids_port.json"
+
+ __F_COLORS = "colors.json"
+
+ __F_PORT_NAMES = "port_names.json"
+
+ __F_ATX_CP_DELAYS = "atx_click_power_delays.json"
+ __F_ATX_CPL_DELAYS = "atx_click_power_long_delays.json"
+ __F_ATX_CR_DELAYS = "atx_click_reset_delays.json"
+
+ def __init__(self, path: str, rw: bool) -> None:
+ self.__path = path
+ self.__rw = rw
+
+ # =====
+
+ async def write_edids(self, edids: Edids) -> None:
+ await self.__write_json_keyvals(self.__F_EDIDS_ALL, {
+ edid_id.lower(): {"name": edid.name, "data": edid.as_text()}
+ for (edid_id, edid) in edids.all.items()
+ if edid_id != Edids.DEFAULT_ID
+ })
+ await self.__write_json_keyvals(self.__F_EDIDS_PORT, edids.port)
+
+ async def write_colors(self, colors: Colors) -> None:
+ await self.__write_json_keyvals(self.__F_COLORS, {
+ role: {
+ comp: getattr(getattr(colors, role), comp)
+ for comp in Color.COMPONENTS
+ }
+ for role in Colors.ROLES
+ })
+
+ async def write_port_names(self, port_names: PortNames) -> None:
+ await self.__write_json_keyvals(self.__F_PORT_NAMES, port_names.kvs)
+
+ async def write_atx_cp_delays(self, delays: AtxClickPowerDelays) -> None:
+ await self.__write_json_keyvals(self.__F_ATX_CP_DELAYS, delays.kvs)
+
+ async def write_atx_cpl_delays(self, delays: AtxClickPowerLongDelays) -> None:
+ await self.__write_json_keyvals(self.__F_ATX_CPL_DELAYS, delays.kvs)
+
+ async def write_atx_cr_delays(self, delays: AtxClickResetDelays) -> None:
+ await self.__write_json_keyvals(self.__F_ATX_CR_DELAYS, delays.kvs)
+
+ async def __write_json_keyvals(self, name: str, kvs: dict) -> None:
+ if len(self.__path) == 0:
+ return
+ assert self.__rw
+ kvs = {str(key): value for (key, value) in kvs.items()}
+ if (await self.__read_json_keyvals(name)) == kvs:
+ return # Don't write the same data
+ path = os.path.join(self.__path, name)
+ get_logger(0).info("Writing '%s' ...", name)
+ await aiotools.write_file(path, json.dumps(kvs))
+
+ # =====
+
+ async def read_edids(self) -> Edids:
+ all_edids = {
+ edid_id.lower(): Edid.from_data(edid["name"], edid["data"])
+ for (edid_id, edid) in (await self.__read_json_keyvals(self.__F_EDIDS_ALL)).items()
+ }
+ port_edids = await self.__read_json_keyvals_int(self.__F_EDIDS_PORT)
+ return Edids(all_edids, port_edids)
+
+ async def read_colors(self) -> Colors:
+ raw = await self.__read_json_keyvals(self.__F_COLORS)
+ return Colors(**{ # type: ignore
+ role: Color(**{comp: raw[role][comp] for comp in Color.COMPONENTS})
+ for role in Colors.ROLES
+ if role in raw
+ })
+
+ async def read_port_names(self) -> PortNames:
+ return PortNames(await self.__read_json_keyvals_int(self.__F_PORT_NAMES))
+
+ async def read_atx_cp_delays(self) -> AtxClickPowerDelays:
+ return AtxClickPowerDelays(await self.__read_json_keyvals_int(self.__F_ATX_CP_DELAYS))
+
+ async def read_atx_cpl_delays(self) -> AtxClickPowerLongDelays:
+ return AtxClickPowerLongDelays(await self.__read_json_keyvals_int(self.__F_ATX_CPL_DELAYS))
+
+ async def read_atx_cr_delays(self) -> AtxClickResetDelays:
+ return AtxClickResetDelays(await self.__read_json_keyvals_int(self.__F_ATX_CR_DELAYS))
+
+ async def __read_json_keyvals_int(self, name: str) -> dict:
+ return (await self.__read_json_keyvals(name, int_keys=True))
+
+ async def __read_json_keyvals(self, name: str, int_keys: bool=False) -> dict:
+ if len(self.__path) == 0:
+ return {}
+ path = os.path.join(self.__path, name)
+ try:
+ kvs: dict = json.loads(await aiotools.read_file(path))
+ except FileNotFoundError:
+ kvs = {}
+ if int_keys:
+ kvs = {int(key): value for (key, value) in kvs.items()}
+ return kvs
+
+
+class Storage:
+ __SUBDIR = "__switch__"
+ __TIMEOUT = 5.0
+
+ def __init__(self, unix_path: str) -> None:
+ self.__pst: (PstClient | None) = None
+ if len(unix_path) > 0 and PstClient is not None:
+ self.__pst = PstClient(
+ subdir=self.__SUBDIR,
+ unix_path=unix_path,
+ timeout=self.__TIMEOUT,
+ user_agent=htclient.make_user_agent("KVMD"),
+ )
+ self.__lock = asyncio.Lock()
+
+ @contextlib.asynccontextmanager
+ async def readable(self) -> AsyncGenerator[StorageContext, None]:
+ async with self.__lock:
+ if self.__pst is None:
+ yield StorageContext("", False)
+ else:
+ path = await self.__pst.get_path()
+ yield StorageContext(path, False)
+
+ @contextlib.asynccontextmanager
+ async def writable(self) -> AsyncGenerator[StorageContext, None]:
+ async with self.__lock:
+ if self.__pst is None:
+ yield StorageContext("", True)
+ else:
+ async with self.__pst.writable() as path:
+ yield StorageContext(path, True)
diff --git a/kvmd/apps/kvmd/switch/types.py b/kvmd/apps/kvmd/switch/types.py
new file mode 100644
index 00000000..32225f06
--- /dev/null
+++ b/kvmd/apps/kvmd/switch/types.py
@@ -0,0 +1,308 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+import re
+import struct
+import uuid
+import dataclasses
+
+from typing import TypeVar
+from typing import Generic
+
+from .lib import bitbang
+from .lib import ParsedEdidNoBlockError
+from .lib import ParsedEdid
+
+
+# =====
[email protected](frozen=True)
+class EdidInfo:
+ mfc_id: str
+ product_id: int
+ serial: int
+ monitor_name: (str | None)
+ monitor_serial: (str | None)
+ audio: bool
+
+ @classmethod
+ def from_data(cls, data: bytes) -> "EdidInfo":
+ parsed = ParsedEdid(data)
+
+ monitor_name: (str | None) = None
+ try:
+ monitor_name = parsed.get_monitor_name()
+ except ParsedEdidNoBlockError:
+ pass
+
+ monitor_serial: (str | None) = None
+ try:
+ monitor_serial = parsed.get_monitor_serial()
+ except ParsedEdidNoBlockError:
+ pass
+
+ return EdidInfo(
+ mfc_id=parsed.get_mfc_id(),
+ product_id=parsed.get_product_id(),
+ serial=parsed.get_serial(),
+ monitor_name=monitor_name,
+ monitor_serial=monitor_serial,
+ audio=parsed.get_audio(),
+ )
+
+
[email protected](frozen=True)
+class Edid:
+ name: str
+ data: bytes
+ crc: int = dataclasses.field(default=0)
+ valid: bool = dataclasses.field(default=False)
+ info: (EdidInfo | None) = dataclasses.field(default=None)
+
+ __HEADER = b"\x00\xFF\xFF\xFF\xFF\xFF\xFF\x00"
+
+ def __post_init__(self) -> None:
+ assert len(self.name) > 0
+ assert len(self.data) == 256
+ object.__setattr__(self, "crc", bitbang.make_crc16(self.data))
+ object.__setattr__(self, "valid", self.data.startswith(self.__HEADER))
+ try:
+ object.__setattr__(self, "info", EdidInfo.from_data(self.data))
+ except Exception:
+ pass
+
+ def as_text(self) -> str:
+ return "".join(f"{item:0{2}X}" for item in self.data)
+
+ def pack(self) -> bytes:
+ return self.data
+
+ @classmethod
+ def from_data(cls, name: str, data: (str | bytes | None)) -> "Edid":
+ if data is None: # Пустой едид
+ return Edid(name, b"\x00" * 256)
+
+ if isinstance(data, bytes):
+ if data.startswith(cls.__HEADER):
+ return Edid(name, data) # Бинарный едид
+ data_hex = data.decode() # Текстовый едид, прочитанный как бинарный из файла
+ else: # isinstance(data, str)
+ data_hex = str(data) # Текстовый едид
+
+ data_hex = re.sub(r"\s", "", data_hex)
+ assert len(data_hex) == 512
+ data = bytes([
+ int(data_hex[index:index + 2], 16)
+ for index in range(0, len(data_hex), 2)
+ ])
+ return Edid(name, data)
+
+
+class Edids:
+ DEFAULT_NAME = "Default"
+ DEFAULT_ID = "default"
+
+ all: dict[str, Edid] = dataclasses.field(default_factory=dict)
+ port: dict[int, str] = dataclasses.field(default_factory=dict)
+
+ def __post_init__(self) -> None:
+ if self.DEFAULT_ID not in self.all:
+ self.set_default(None)
+
+ def set_default(self, data: (str | bytes | None)) -> None:
+ self.all[self.DEFAULT_ID] = Edid.from_data(self.DEFAULT_NAME, data)
+
+ def copy(self) -> "Edids":
+ return Edids(dict(self.all), dict(self.port))
+
+ def compare_on_ports(self, other: "Edids", ports: int) -> bool:
+ for port in range(ports):
+ if self.get_id_for_port(port) != other.get_id_for_port(port):
+ return False
+ return True
+
+ def add(self, edid: Edid) -> str:
+ edid_id = str(uuid.uuid4()).lower()
+ self.all[edid_id] = edid
+ return edid_id
+
+ def set(self, edid_id: str, edid: Edid) -> None:
+ assert edid_id in self.all
+ self.all[edid_id] = edid
+
+ def get(self, edid_id: str) -> Edid:
+ return self.all[edid_id]
+
+ def remove(self, edid_id: str) -> None:
+ assert edid_id in self.all
+ self.all.pop(edid_id)
+ for port in list(self.port):
+ if self.port[port] == edid_id:
+ self.port.pop(port)
+
+ def has_id(self, edid_id: str) -> bool:
+ return (edid_id in self.all)
+
+ def assign(self, port: int, edid_id: str) -> None:
+ assert edid_id in self.all
+ if edid_id == Edids.DEFAULT_ID:
+ self.port.pop(port, None)
+ else:
+ self.port[port] = edid_id
+
+ def get_id_for_port(self, port: int) -> str:
+ return self.port.get(port, self.DEFAULT_ID)
+
+ def get_edid_for_port(self, port: int) -> Edid:
+ return self.all[self.get_id_for_port(port)]
+
+
+# =====
[email protected](frozen=True)
+class Color:
+ COMPONENTS = frozenset(["red", "green", "blue", "brightness", "blink_ms"])
+
+ red: int
+ green: int
+ blue: int
+ brightness: int
+ blink_ms: int
+ crc: int = dataclasses.field(default=0)
+ _packed: bytes = dataclasses.field(default=b"")
+
+ __struct = struct.Struct("<BBBBH")
+ __rx = re.compile(r"^([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2}):([0-9a-fA-F]{2}):([0-9a-fA-F]{4})$")
+
+ def __post_init__(self) -> None:
+ assert 0 <= self.red <= 0xFF
+ assert 0 <= self.green <= 0xFF
+ assert 0 <= self.blue <= 0xFF
+ assert 0 <= self.brightness <= 0xFF
+ assert 0 <= self.blink_ms <= 0xFFFF
+ data = self.__struct.pack(self.red, self.green, self.blue, self.brightness, self.blink_ms)
+ object.__setattr__(self, "crc", bitbang.make_crc16(data))
+ object.__setattr__(self, "_packed", data)
+
+ def pack(self) -> bytes:
+ return self._packed
+
+ @classmethod
+ def from_text(cls, text: str) -> "Color":
+ match = cls.__rx.match(text)
+ assert match is not None, text
+ return Color(
+ red=int(match.group(1), 16),
+ green=int(match.group(2), 16),
+ blue=int(match.group(3), 16),
+ brightness=int(match.group(4), 16),
+ blink_ms=int(match.group(5), 16),
+ )
+
+
[email protected](frozen=True)
+class Colors:
+ ROLES = frozenset(["inactive", "active", "flashing", "beacon", "bootloader"])
+
+ inactive: Color = dataclasses.field(default=Color(255, 0, 0, 64, 0))
+ active: Color = dataclasses.field(default=Color(0, 255, 0, 128, 0))
+ flashing: Color = dataclasses.field(default=Color(0, 170, 255, 128, 0))
+ beacon: Color = dataclasses.field(default=Color(228, 44, 156, 255, 250))
+ bootloader: Color = dataclasses.field(default=Color(255, 170, 0, 128, 0))
+ crc: int = dataclasses.field(default=0)
+ _packed: bytes = dataclasses.field(default=b"")
+
+ __crc_struct = struct.Struct("<HHHHH")
+
+ def __post_init__(self) -> None:
+ crcs: list[int] = []
+ packed: bytes = b""
+ for color in [self.inactive, self.active, self.flashing, self.beacon, self.bootloader]:
+ crcs.append(color.crc)
+ packed += color.pack()
+ object.__setattr__(self, "crc", bitbang.make_crc16(self.__crc_struct.pack(*crcs)))
+ object.__setattr__(self, "_packed", packed)
+
+ def pack(self) -> bytes:
+ return self._packed
+
+
+# =====
+_T = TypeVar("_T")
+
+
+class _PortsDict(Generic[_T]):
+ def __init__(self, default: _T, kvs: dict[int, _T]) -> None:
+ self.default = default
+ self.kvs = {
+ port: value
+ for (port, value) in kvs.items()
+ if value != default
+ }
+
+ def compare_on_ports(self, other: "_PortsDict[_T]", ports: int) -> bool:
+ for port in range(ports):
+ if self[port] != other[port]:
+ return False
+ return True
+
+ def __getitem__(self, port: int) -> _T:
+ return self.kvs.get(port, self.default)
+
+ def __setitem__(self, port: int, value: (_T | None)) -> None:
+ if value is None:
+ value = self.default
+ if value == self.default:
+ self.kvs.pop(port, None)
+ else:
+ self.kvs[port] = value
+
+
+class PortNames(_PortsDict[str]):
+ def __init__(self, kvs: dict[int, str]) -> None:
+ super().__init__("", kvs)
+
+ def copy(self) -> "PortNames":
+ return PortNames(self.kvs)
+
+
+class AtxClickPowerDelays(_PortsDict[float]):
+ def __init__(self, kvs: dict[int, float]) -> None:
+ super().__init__(0.5, kvs)
+
+ def copy(self) -> "AtxClickPowerDelays":
+ return AtxClickPowerDelays(self.kvs)
+
+
+class AtxClickPowerLongDelays(_PortsDict[float]):
+ def __init__(self, kvs: dict[int, float]) -> None:
+ super().__init__(5.5, kvs)
+
+ def copy(self) -> "AtxClickPowerLongDelays":
+ return AtxClickPowerLongDelays(self.kvs)
+
+
+class AtxClickResetDelays(_PortsDict[float]):
+ def __init__(self, kvs: dict[int, float]) -> None:
+ super().__init__(0.5, kvs)
+
+ def copy(self) -> "AtxClickResetDelays":
+ return AtxClickResetDelays(self.kvs)
diff --git a/kvmd/apps/pst/server.py b/kvmd/apps/pst/server.py
index 79bbf7c8..8d8bf9d4 100644
--- a/kvmd/apps/pst/server.py
+++ b/kvmd/apps/pst/server.py
@@ -24,6 +24,7 @@ import os
import asyncio
from aiohttp.web import Request
+from aiohttp.web import Response
from aiohttp.web import WebSocketResponse
from ...logging import get_logger
@@ -35,6 +36,7 @@ from ... import fstab
from ...htserver import exposed_http
from ...htserver import exposed_ws
+from ...htserver import make_json_response
from ...htserver import WsSession
from ...htserver import HttpServer
@@ -65,6 +67,16 @@ class PstServer(HttpServer): # pylint: disable=too-many-arguments,too-many-inst
await ws.send_event("loop", {})
return (await self._ws_loop(ws))
+ @exposed_http("GET", "/state")
+ async def __state_handler(self, _: Request) -> Response:
+ return make_json_response({
+ "clients": len(self._get_wss()),
+ "data": {
+ "path": self.__data_path,
+ "write_allowed": self.__is_write_available(),
+ },
+ })
+
@exposed_ws("ping")
async def __ws_ping_handler(self, ws: WsSession, _: dict) -> None:
await ws.send_event("pong", {})
diff --git a/kvmd/clients/pst.py b/kvmd/clients/pst.py
new file mode 100644
index 00000000..6b9f5234
--- /dev/null
+++ b/kvmd/clients/pst.py
@@ -0,0 +1,93 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2020 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+import os
+import contextlib
+
+from typing import AsyncGenerator
+
+import aiohttp
+
+from .. import htclient
+from .. import htserver
+
+
+# =====
+class PstError(Exception):
+ pass
+
+
+# =====
+class PstClient:
+ def __init__(
+ self,
+ subdir: str,
+ unix_path: str,
+ timeout: float,
+ user_agent: str,
+ ) -> None:
+
+ self.__subdir = subdir
+ self.__unix_path = unix_path
+ self.__timeout = timeout
+ self.__user_agent = user_agent
+
+ async def get_path(self) -> str:
+ async with self.__make_http_session() as session:
+ async with session.get("http://localhost:0/state") as resp:
+ htclient.raise_not_200(resp)
+ path = (await resp.json())["result"]["data"]["path"]
+ return os.path.join(path, self.__subdir)
+
+ @contextlib.asynccontextmanager
+ async def writable(self) -> AsyncGenerator[str, None]:
+ async with self.__inner_writable() as path:
+ path = os.path.join(path, self.__subdir)
+ if not os.path.exists(path):
+ os.mkdir(path)
+ yield path
+
+ @contextlib.asynccontextmanager
+ async def __inner_writable(self) -> AsyncGenerator[str, None]:
+ async with self.__make_http_session() as session:
+ async with session.ws_connect("http://localhost:0/ws") as ws:
+ path = ""
+ async for msg in ws:
+ if msg.type != aiohttp.WSMsgType.TEXT:
+ raise PstError(f"Unexpected message type: {msg!r}")
+ (event_type, event) = htserver.parse_ws_event(msg.data)
+ if event_type == "storage_state":
+ if not event["data"]["write_allowed"]:
+ raise PstError("Write is not allowed")
+ path = event["data"]["path"]
+ break
+ if not path:
+ raise PstError("WS loop broken without write_allowed=True flag")
+ # TODO: Actually we should follow ws events, but for fast writing we can safely ignore them
+ yield path
+
+ def __make_http_session(self) -> aiohttp.ClientSession:
+ return aiohttp.ClientSession(
+ headers={"User-Agent": self.__user_agent},
+ connector=aiohttp.UnixConnector(path=self.__unix_path),
+ timeout=aiohttp.ClientTimeout(total=self.__timeout),
+ )
diff --git a/kvmd/validators/__init__.py b/kvmd/validators/__init__.py
index 39ff60aa..aa997ab9 100644
--- a/kvmd/validators/__init__.py
+++ b/kvmd/validators/__init__.py
@@ -99,3 +99,11 @@ def check_any(arg: Any, name: str, validators: list[Callable[[Any], Any]]) -> An
except Exception:
pass
raise_error(arg, name)
+
+
+# =====
+def filter_printable(arg: str, replace: str, limit: int) -> str:
+ return "".join(
+ (ch if ch.isprintable() else replace)
+ for ch in arg[:limit]
+ )
diff --git a/kvmd/validators/os.py b/kvmd/validators/os.py
index 94d3a40f..b2381d0b 100644
--- a/kvmd/validators/os.py
+++ b/kvmd/validators/os.py
@@ -26,6 +26,7 @@ import stat
from typing import Any
from . import raise_error
+from . import filter_printable
from .basic import valid_number
from .basic import valid_string_list
@@ -75,9 +76,7 @@ def valid_abs_dir(arg: Any, name: str="") -> str:
def valid_printable_filename(arg: Any, name: str="") -> str:
if not name:
name = "printable filename"
-
arg = valid_stripped_string_not_empty(arg, name)
-
if (
"/" in arg
or "\0" in arg
@@ -85,12 +84,7 @@ def valid_printable_filename(arg: Any, name: str="") -> str:
or arg == "lost+found"
):
raise_error(arg, name)
-
- arg = "".join(
- (ch if ch.isprintable() else "_")
- for ch in arg[:255]
- )
- return arg
+ return filter_printable(arg, "_", 255)
# =====
diff --git a/kvmd/validators/switch.py b/kvmd/validators/switch.py
new file mode 100644
index 00000000..d4f3ab2f
--- /dev/null
+++ b/kvmd/validators/switch.py
@@ -0,0 +1,67 @@
+# ========================================================================== #
+# #
+# KVMD - The main PiKVM daemon. #
+# #
+# Copyright (C) 2018-2024 Maxim Devaev <[email protected]> #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <https://www.gnu.org/licenses/>. #
+# #
+# ========================================================================== #
+
+
+import re
+
+from typing import Any
+
+from . import filter_printable
+from . import check_re_match
+
+from .basic import valid_stripped_string
+from .basic import valid_number
+
+
+# =====
+def valid_switch_port_name(arg: Any) -> str:
+ arg = valid_stripped_string(arg, name="switch port name")
+ arg = filter_printable(arg, " ", 255)
+ arg = re.sub(r"\s+", " ", arg)
+ return arg.strip()
+
+
+def valid_switch_edid_id(arg: Any, allow_default: bool) -> str:
+ pattern = "(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
+ if allow_default:
+ pattern += "|^default$"
+ return check_re_match(arg, "switch EDID ID", pattern).lower()
+
+
+def valid_switch_edid_data(arg: Any) -> str:
+ name = "switch EDID data"
+ arg = valid_stripped_string(arg, name=name)
+ arg = re.sub(r"\s", "", arg)
+ return check_re_match(arg, name, "(?i)^[0-9a-f]{512}$").upper()
+
+
+def valid_switch_color(arg: Any, allow_default: bool) -> str:
+ pattern = "(?i)^[0-9a-f]{6}:[0-9a-f]{2}:[0-9a-f]{4}$"
+ if allow_default:
+ pattern += "|^default$"
+ arg = check_re_match(arg, "switch color", pattern).upper()
+ if arg == "DEFAULT":
+ arg = "default"
+ return arg
+
+
+def valid_switch_atx_click_delay(arg: Any) -> float:
+ return valid_number(arg, min=0, max=10, type=float, name="ATX delay")