summaryrefslogtreecommitdiff
path: root/kvmd/apps/janus/stun.py
diff options
context:
space:
mode:
Diffstat (limited to 'kvmd/apps/janus/stun.py')
-rw-r--r--kvmd/apps/janus/stun.py268
1 files changed, 131 insertions, 137 deletions
diff --git a/kvmd/apps/janus/stun.py b/kvmd/apps/janus/stun.py
index 954f2d73..9a54e5df 100644
--- a/kvmd/apps/janus/stun.py
+++ b/kvmd/apps/janus/stun.py
@@ -1,3 +1,4 @@
+import asyncio
import socket
import struct
import secrets
@@ -40,144 +41,137 @@ class StunNatType:
# =====
-async def stun_get_info(
- stun_host: str,
- stun_port: int,
- src_ip: str,
- src_port: int,
- timeout: float,
-) -> Tuple[str, str]:
-
- return (await aiotools.run_async(_stun_get_info, stun_host, stun_port, src_ip, src_port, timeout))
-
-
-def _stun_get_info(
- stun_host: str,
- stun_port: int,
- src_ip: str,
- src_port: int,
- timeout: float,
-) -> Tuple[str, str]:
-
+class Stun:
# Partially based on https://github.com/JohnVillalovos/pystun
- (family, _, _, _, addr) = socket.getaddrinfo(src_ip, src_port, type=socket.SOCK_DGRAM)[0]
- with socket.socket(family, socket.SOCK_DGRAM) as sock:
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.settimeout(timeout)
- sock.bind(addr)
- (nat_type, response) = _get_nat_type(
- stun_host=stun_host,
- stun_port=stun_port,
- src_ip=src_ip,
- sock=sock,
- )
- return (nat_type, (response.ext.ip if response.ext is not None else ""))
-
-
-def _get_nat_type( # pylint: disable=too-many-return-statements
- stun_host: str,
- stun_port: int,
- src_ip: str,
- sock: socket.socket,
-) -> Tuple[str, StunResponse]:
-
- first = _stun_request("First probe", stun_host, stun_port, b"", sock)
- if not first.ok:
- return (StunNatType.BLOCKED, first)
- if first.ext is None:
- raise RuntimeError(f"Ext addr is None: {first}")
-
- request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000006) # Change-Request
- response = _stun_request("Change request [ext_ip == src_ip]", stun_host, stun_port, request, sock)
-
- if first.ext.ip == src_ip:
- if response.ok:
- return (StunNatType.OPEN_INTERNET, response)
- return (StunNatType.SYMMETRIC_UDP_FW, response)
-
- if response.ok:
- return (StunNatType.FULL_CONE_NAT, response)
-
- if first.changed is None:
- raise RuntimeError(f"Changed addr is None: {first}")
- response = _stun_request("Change request [ext_ip != src_ip]", first.changed.ip, first.changed.port, b"", sock)
- if not response.ok:
- return (StunNatType.CHANGED_ADDR_ERROR, response)
+ def __init__(
+ self,
+ host: str,
+ port: int,
+ timeout: float,
+ retries: int,
+ retries_delay: float,
+ ) -> None:
+
+ self.host = host
+ self.port = port
+ self.__timeout = timeout
+ self.__retries = retries
+ self.__retries_delay = retries_delay
+
+ self.__sock: Optional[socket.socket] = None
+
+ async def get_info(self, src_ip: str, src_port: int) -> Tuple[str, str]:
+
+ (family, _, _, _, addr) = socket.getaddrinfo(src_ip, src_port, type=socket.SOCK_DGRAM)[0]
+ try:
+ with socket.socket(family, socket.SOCK_DGRAM) as self.__sock:
+ self.__sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ self.__sock.settimeout(self.__timeout)
+ self.__sock.bind(addr)
+ (nat_type, response) = await self.__get_nat_type(src_ip)
+ return (nat_type, (response.ext.ip if response.ext is not None else ""))
+ finally:
+ self.__sock = None
+
+ async def __get_nat_type(self, src_ip: str) -> Tuple[str, StunResponse]: # pylint: disable=too-many-return-statements
+ first = await self.__make_request("First probe")
+ if not first.ok:
+ return (StunNatType.BLOCKED, first)
+
+ request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000006) # Change-Request
+ response = await self.__make_request("Change request [ext_ip == src_ip]", request)
+
+ if first.ext is not None and first.ext.ip == src_ip:
+ if response.ok:
+ return (StunNatType.OPEN_INTERNET, response)
+ return (StunNatType.SYMMETRIC_UDP_FW, response)
- if response.ext == first.ext:
- request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000002)
- response = _stun_request("Change port", first.changed.ip, stun_port, request, sock)
if response.ok:
- return (StunNatType.RESTRICTED_NAT, response)
- return (StunNatType.RESTRICTED_PORT_NAT, response)
-
- return (StunNatType.SYMMETRIC_NAT, response)
-
-
-def _stun_request( # pylint: disable=too-many-locals
- ctx: str,
- host: str,
- port: int,
- request: bytes,
- sock: socket.socket,
-) -> StunResponse:
-
- # TODO: Support IPv6 and RFC 5389
- # The first 4 bytes of the response are the Type (2) and Length (2)
- # The 5th byte is Reserved
- # The 6th byte is the Family: 0x01 = IPv4, 0x02 = IPv6
- # The remaining bytes are the IP address. 32 bits for IPv4 or 128 bits for
- # IPv6.
- # More info at: https://tools.ietf.org/html/rfc3489#section-11.2.1
- # And at: https://tools.ietf.org/html/rfc5389#section-15.1
-
- trans_id = secrets.token_bytes(16)
- request = struct.pack(">HH", 0x0001, len(request)) + trans_id + request # Bind Request
-
- try:
- sock.sendto(request, (host, port))
- except Exception as err:
- get_logger().error("%s: Can't send request: %s", ctx, tools.efmt(err))
- return StunResponse(ok=False)
- try:
- response = sock.recvfrom(2048)[0]
- except Exception as err:
- get_logger().error("%s: Can't recv response: %s", ctx, tools.efmt(err))
- return StunResponse(ok=False)
-
- (response_type, payload_len) = struct.unpack(">HH", response[:4])
- if response_type != 0x0101:
- get_logger().error("%s: Invalid response type: %#.4x", ctx, response_type)
- return StunResponse(ok=False)
- if trans_id != response[4:20]:
- get_logger().error("%s: Transaction ID mismatch")
- return StunResponse(ok=False)
-
- parsed: Dict[str, StunAddress] = {}
- base = 20
- remaining = payload_len
- while remaining > 0:
- (attr_type, attr_len) = struct.unpack(">HH", response[base:(base + 4)])
- base += 4
- field = {
- 0x0001: "ext", # MAPPED-ADDRESS
- 0x0004: "src", # SOURCE-ADDRESS
- 0x0005: "changed", # CHANGED-ADDRESS
- }.get(attr_type)
- if field is not None:
- parsed[field] = _parse_address(response[base:])
- base += attr_len
- remaining -= (4 + attr_len)
- return StunResponse(ok=True, **parsed)
-
-
-def _parse_address(data: bytes) -> StunAddress:
- family = data[1]
- if family == 1:
- parts = struct.unpack(">HBBBB", data[2:8])
- return StunAddress(
- ip=".".join(map(str, parts[1:])),
- port=parts[0],
- )
- raise RuntimeError(f"Only IPv4 supported; received: {family}")
+ return (StunNatType.FULL_CONE_NAT, response)
+
+ if first.changed is None:
+ raise RuntimeError(f"Changed addr is None: {first}")
+ response = await self.__make_request("Change request [ext_ip != src_ip]", b"", *first.changed.ip)
+ if not response.ok:
+ return (StunNatType.CHANGED_ADDR_ERROR, response)
+
+ if response.ext == first.ext:
+ request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000002)
+ response = await self.__make_request("Change port", request, first.changed.ip)
+ if response.ok:
+ return (StunNatType.RESTRICTED_NAT, response)
+ return (StunNatType.RESTRICTED_PORT_NAT, response)
+
+ return (StunNatType.SYMMETRIC_NAT, response)
+
+ async def __make_request(self, ctx: str, request: bytes=b"", host: str="", port: int=0) -> StunResponse:
+ # TODO: Support IPv6 and RFC 5389
+ # The first 4 bytes of the response are the Type (2) and Length (2)
+ # The 5th byte is Reserved
+ # The 6th byte is the Family: 0x01 = IPv4, 0x02 = IPv6
+ # The remaining bytes are the IP address. 32 bits for IPv4 or 128 bits for
+ # IPv6.
+ # More info at: https://tools.ietf.org/html/rfc3489#section-11.2.1
+ # And at: https://tools.ietf.org/html/rfc5389#section-15.1
+
+ (response, error) = (b"", "")
+ for _ in range(self.__retries):
+ (response, error) = await self.__inner_make_request(request, host, port)
+ if not error:
+ break
+ await asyncio.sleep(self.__retries_delay)
+ if error:
+ get_logger(0).error("%s: Can't perform STUN request after %d retries; last error: %s",
+ ctx, self.__retries, error)
+ return StunResponse(ok=False)
+
+ parsed: Dict[str, StunAddress] = {}
+ offset = 0
+ remaining = len(response)
+ while remaining > 0:
+ (attr_type, attr_len) = struct.unpack(">HH", response[offset : offset + 4]) # noqa: E203
+ offset += 4
+ field = {
+ 0x0001: "ext", # MAPPED-ADDRESS
+ 0x0004: "src", # SOURCE-ADDRESS
+ 0x0005: "changed", # CHANGED-ADDRESS
+ }.get(attr_type)
+ if field is not None:
+ parsed[field] = self.__parse_address(response[offset:])
+ offset += attr_len
+ remaining -= (4 + attr_len)
+ return StunResponse(ok=True, **parsed)
+
+ async def __inner_make_request(self, request: bytes, host: str, port: int) -> Tuple[bytes, str]:
+ assert self.__sock is not None
+
+ trans_id = secrets.token_bytes(16)
+ request = struct.pack(">HH", 0x0001, len(request)) + trans_id + request # Bind Request
+
+ try:
+ await aiotools.run_async(self.__sock.sendto, request, ((host or self.host), (port or self.port)))
+ except Exception as err:
+ return (b"", f"Send error: {tools.efmt(err)}")
+ try:
+ response = (await aiotools.run_async(self.__sock.recvfrom, 2048))[0]
+ except Exception as err:
+ return (b"", f"Recv error: {tools.efmt(err)}")
+
+ (response_type, payload_len) = struct.unpack(">HH", response[:4])
+ if response_type != 0x0101:
+ return (b"", f"Invalid response type: {response_type:#06x}")
+ if trans_id != response[4:20]:
+ return (b"", "Transaction ID mismatch")
+
+ return (response[20 : 20 + payload_len], "") # noqa: E203
+
+ def __parse_address(self, data: bytes) -> StunAddress:
+ family = data[1]
+ if family == 1:
+ parts = struct.unpack(">HBBBB", data[2:8])
+ return StunAddress(
+ ip=".".join(map(str, parts[1:])),
+ port=parts[0],
+ )
+ raise RuntimeError(f"Only IPv4 supported; received: {family}")