From: Artur Pragacz <49985303+arturpragacz@users.noreply.github.com> Date: Fri, 12 Dec 2025 17:02:04 +0000 (+0100) Subject: Add DTLS pinning (#2796) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=13c939e6835e9e089102dc04b07f990c372e0aa7;p=music-assistant-server.git Add DTLS pinning (#2796) Add persistent DTLS certificate management for WebRTC, enabling client-side certificate pinning, which significantly improves security. --- diff --git a/music_assistant/controllers/webserver/remote_access/__init__.py b/music_assistant/controllers/webserver/remote_access/__init__.py index 3410b2b9..f295e701 100644 --- a/music_assistant/controllers/webserver/remote_access/__init__.py +++ b/music_assistant/controllers/webserver/remote_access/__init__.py @@ -16,12 +16,14 @@ from mashumaro import DataClassDictMixin from music_assistant_models.enums import EventType from music_assistant.constants import CONF_CORE -from music_assistant.controllers.webserver.remote_access.gateway import ( - WebRTCGateway, - generate_remote_id, +from music_assistant.controllers.webserver.remote_access.gateway import WebRTCGateway +from music_assistant.helpers.webrtc_certificate import ( + get_or_create_webrtc_certificate, + get_remote_id_from_certificate, ) if TYPE_CHECKING: + from aiortc.rtcdtlstransport import RTCCertificate from music_assistant_models.event import MassEvent from music_assistant.controllers.webserver import WebserverController @@ -31,7 +33,6 @@ if TYPE_CHECKING: SIGNALING_SERVER_URL = "wss://signaling.music-assistant.io/ws" CONF_KEY_MAIN = "remote_access" -CONF_REMOTE_ID = "remote_id" CONF_ENABLED = "enabled" TASK_ID_START_GATEWAY = "remote_access_start_gateway" @@ -59,7 +60,8 @@ class RemoteAccessManager: self.mass = webserver.mass self.logger = webserver.logger.getChild("remote_access") self.gateway: WebRTCGateway | None = None - self._remote_id: str | None = None + self._remote_id: str + self._certificate: RTCCertificate self._enabled: bool = False self._using_ha_cloud: bool = False self._starting: bool = False # Prevents concurrent gateway starts @@ -67,16 +69,13 @@ class RemoteAccessManager: async def setup(self) -> None: """Initialize the remote access manager.""" + self._certificate = get_or_create_webrtc_certificate(self.mass.storage_path) + + self._remote_id = get_remote_id_from_certificate(self._certificate) + self.logger.info("WebRTC certificate remote_id: %s", self._remote_id) + enabled_value = self.mass.config.get(f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_ENABLED}", False) self._enabled = bool(enabled_value) - remote_id_value = self.mass.config.get( - f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_REMOTE_ID}", None - ) - if not remote_id_value: - remote_id_value = generate_remote_id() - self.mass.config.set(f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_REMOTE_ID}", remote_id_value) - - self._remote_id = str(remote_id_value) self._register_api_commands() self.mass.subscribe(self._on_providers_updated, EventType.PROVIDERS_UPDATED) if self._enabled: @@ -123,9 +122,10 @@ class RemoteAccessManager: self.gateway = WebRTCGateway( http_session=self.mass.http_session, + remote_id=self._remote_id, + certificate=self._certificate, signaling_url=SIGNALING_SERVER_URL, local_ws_url=local_ws_url, - remote_id=self._remote_id, ice_servers=ice_servers, # Pass callback to get fresh ICE servers for each client connection # This ensures TURN credentials are always valid @@ -236,10 +236,15 @@ class RemoteAccessManager: return self.gateway is not None and self.gateway.is_connected @property - def remote_id(self) -> str | None: + def remote_id(self) -> str: """Return the current Remote ID.""" return self._remote_id + @property + def certificate(self) -> RTCCertificate: + """Return the persistent WebRTC DTLS certificate.""" + return self._certificate + def _register_api_commands(self) -> None: """Register API commands for remote access.""" @@ -249,7 +254,7 @@ class RemoteAccessManager: enabled=self.is_enabled, running=self.is_running, connected=self.is_connected, - remote_id=self._remote_id or "", + remote_id=self._remote_id, using_ha_cloud=self._using_ha_cloud, signaling_url=SIGNALING_SERVER_URL, ) diff --git a/music_assistant/controllers/webserver/remote_access/gateway.py b/music_assistant/controllers/webserver/remote_access/gateway.py index 6e75a3ad..18f247ed 100644 --- a/music_assistant/controllers/webserver/remote_access/gateway.py +++ b/music_assistant/controllers/webserver/remote_access/gateway.py @@ -11,17 +11,19 @@ import asyncio import contextlib import json import logging -import secrets -import string from collections.abc import Awaitable, Callable from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING, Any import aiohttp from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection, RTCSessionDescription from aiortc.sdp import candidate_from_sdp from music_assistant.constants import MASS_LOGGER_NAME, VERBOSE_LOG_LEVEL +from music_assistant.helpers.webrtc_certificate import create_peer_connection_with_certificate + +if TYPE_CHECKING: + from aiortc.rtcdtlstransport import RTCCertificate LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access") @@ -30,14 +32,6 @@ logging.getLogger("aioice").setLevel(logging.WARNING) logging.getLogger("aiortc").setLevel(logging.WARNING) -def generate_remote_id() -> str: - """Generate a unique Remote ID in the format MA-XXXX-XXXX.""" - chars = string.ascii_uppercase + string.digits - part1 = "".join(secrets.choice(chars) for _ in range(4)) - part2 = "".join(secrets.choice(chars) for _ in range(4)) - return f"MA-{part1}-{part2}" - - @dataclass class WebRTCSession: """Represents an active WebRTC session with a remote client.""" @@ -72,19 +66,21 @@ class WebRTCGateway: def __init__( self, http_session: aiohttp.ClientSession, + remote_id: str, + certificate: RTCCertificate, signaling_url: str = "wss://signaling.music-assistant.io/ws", local_ws_url: str = "ws://localhost:8095/ws", ice_servers: list[dict[str, Any]] | None = None, - remote_id: str | None = None, ice_servers_callback: Callable[[], Awaitable[list[dict[str, Any]]]] | None = None, ) -> None: """Initialize the WebRTC Gateway. :param http_session: Shared aiohttp ClientSession to use for HTTP/WebSocket connections. + :param remote_id: Remote ID for this server instance. + :param certificate: Persistent RTCCertificate for DTLS, enabling client-side pinning. :param signaling_url: WebSocket URL of the signaling server. :param local_ws_url: Local WebSocket URL to bridge to. :param ice_servers: List of ICE server configurations (used at registration time). - :param remote_id: Optional Remote ID to use (generated if not provided). :param ice_servers_callback: Optional callback to fetch fresh ICE servers for each session. If provided, this will be called for each client connection to get fresh TURN credentials. If not provided, the static ice_servers will be used. @@ -92,7 +88,8 @@ class WebRTCGateway: self.http_session = http_session self.signaling_url = signaling_url self.local_ws_url = local_ws_url - self.remote_id = remote_id or generate_remote_id() + self._remote_id = remote_id + self._certificate = certificate self.logger = LOGGER self._ice_servers_callback = ice_servers_callback @@ -271,7 +268,7 @@ class WebRTCGateway: await self._signaling_ws.send_json( { "type": "register-server", - "remoteId": self.remote_id, + "remoteId": self._remote_id, "iceServers": self.ice_servers, } ) @@ -332,7 +329,7 @@ class WebRTCGateway: config = RTCConfiguration( iceServers=[RTCIceServer(**server) for server in session_ice_servers] ) - pc = RTCPeerConnection(configuration=config) + pc = create_peer_connection_with_certificate(self._certificate, configuration=config) session = WebRTCSession(session_id=session_id, peer_connection=pc) self.sessions[session_id] = session diff --git a/music_assistant/helpers/webrtc_certificate.py b/music_assistant/helpers/webrtc_certificate.py new file mode 100644 index 00000000..4d188322 --- /dev/null +++ b/music_assistant/helpers/webrtc_certificate.py @@ -0,0 +1,220 @@ +"""WebRTC DTLS Certificate Management. + +This module provides persistent DTLS certificate management for WebRTC connections. +The certificate is generated once and stored persistently, enabling client-side +certificate pinning for authentication. +""" + +from __future__ import annotations + +import base64 +import logging +import stat +from datetime import UTC, datetime, timedelta +from pathlib import Path + +from aiortc import RTCConfiguration, RTCPeerConnection +from aiortc.rtcdtlstransport import RTCCertificate +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.x509.oid import NameOID + +LOGGER = logging.getLogger(__name__) + +CERT_FILENAME = "webrtc_certificate.pem" +KEY_FILENAME = "webrtc_private_key.pem" + +CERT_VALIDITY_DAYS = 3650 # 10 years + +CERT_RENEWAL_THRESHOLD_DAYS = 30 + + +def _generate_certificate() -> tuple[ec.EllipticCurvePrivateKey, x509.Certificate]: + """Generate a new ECDSA certificate for WebRTC DTLS. + + :return: Tuple of (private_key, certificate). + """ + # Generate ECDSA key (SECP256R1 - same as aiortc default) + private_key = ec.generate_private_key(ec.SECP256R1()) + + now = datetime.now(UTC) + not_before = now - timedelta(days=1) + not_after = now + timedelta(days=CERT_VALIDITY_DAYS) + + subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Music Assistant WebRTC")]) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(subject) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(not_before) + .not_valid_after(not_after) + .sign(private_key, hashes.SHA256()) + ) + + return private_key, cert + + +def _save_certificate( + storage_path: str, + private_key: ec.EllipticCurvePrivateKey, + cert: x509.Certificate, +) -> None: + """Save certificate and private key to disk. + + :param storage_path: Directory to store the files. + :param private_key: The EC private key. + :param cert: The X.509 certificate. + """ + cert_path = Path(storage_path) / CERT_FILENAME + key_path = Path(storage_path) / KEY_FILENAME + + cert_pem = cert.public_bytes(serialization.Encoding.PEM) + cert_path.write_bytes(cert_pem) + + key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + key_path.write_bytes(key_pem) + + # Set restrictive permissions on private key (owner read/write only) + key_path.chmod(stat.S_IRUSR | stat.S_IWUSR) + + LOGGER.info("Saved WebRTC certificate to %s", cert_path) + + +def _load_certificate( + storage_path: str, +) -> tuple[ec.EllipticCurvePrivateKey, x509.Certificate] | None: + """Load certificate and private key from disk. + + :param storage_path: Directory containing the files. + :return: Tuple of (private_key, certificate) or None if files don't exist. + """ + cert_path = Path(storage_path) / CERT_FILENAME + key_path = Path(storage_path) / KEY_FILENAME + + if not cert_path.exists() or not key_path.exists(): + return None + + try: + cert_pem = cert_path.read_bytes() + cert = x509.load_pem_x509_certificate(cert_pem) + + key_pem = key_path.read_bytes() + private_key = serialization.load_pem_private_key(key_pem, password=None) + + if not isinstance(private_key, ec.EllipticCurvePrivateKey): + LOGGER.warning("WebRTC private key is not an EC key, will regenerate") + return None + + return private_key, cert + except Exception as err: + LOGGER.warning("Failed to load WebRTC certificate: %s", err) + return None + + +def _is_certificate_valid(cert: x509.Certificate) -> bool: + """Check if certificate is still valid with enough time remaining. + + :param cert: The X.509 certificate to check. + :return: True if certificate is valid and has sufficient time remaining. + """ + now = datetime.now(UTC) + not_after = cert.not_valid_after_utc + + if now >= not_after: + LOGGER.info("WebRTC certificate has expired") + return False + + days_remaining = (not_after - now).days + if days_remaining < CERT_RENEWAL_THRESHOLD_DAYS: + LOGGER.info( + "WebRTC certificate expires in %d days, will regenerate", + days_remaining, + ) + return False + + return True + + +def get_or_create_webrtc_certificate(storage_path: str) -> RTCCertificate: + """Get or create a persistent WebRTC DTLS certificate. + + Loads an existing certificate from disk if available and valid. + Otherwise, generates a new certificate and saves it. + + :param storage_path: Directory to store/load the certificate files. + :return: RTCCertificate instance for use with WebRTC. + """ + loaded = _load_certificate(storage_path) + + if loaded is not None: + private_key, cert = loaded + + if _is_certificate_valid(cert): + return RTCCertificate(key=private_key, cert=cert) + + LOGGER.info("Generating new WebRTC DTLS certificate (valid for %d days)", CERT_VALIDITY_DAYS) + private_key, cert = _generate_certificate() + _save_certificate(storage_path, private_key, cert) + + return RTCCertificate(key=private_key, cert=cert) + + +def _get_certificate_fingerprint(certificate: RTCCertificate) -> str: + """Get the SHA-256 fingerprint of a certificate. + + :param certificate: The RTCCertificate to get the fingerprint for. + :return: SHA-256 fingerprint as colon-separated hex string (e.g., "A1:B2:C3:..."). + """ + fingerprints = certificate.getFingerprints() + for fp in fingerprints: + if fp.algorithm == "sha-256": + return fp.value + raise ValueError("SHA-256 fingerprint not found in certificate") + + +def get_remote_id_from_certificate(certificate: RTCCertificate) -> str: + """Generate a remote ID from the certificate fingerprint. + + Uses base32-encoded 128-bit truncation of the SHA-256 fingerprint. + This creates a deterministic remote ID tied to the certificate. + + :param certificate: The RTCCertificate to derive the remote ID from. + :return: Custom base32-encoded (with 9s instead of 2s) remote ID string + (26 characters, uppercase, no-padding). + """ + fingerprint = _get_certificate_fingerprint(certificate) + + # Parse the colon-separated hex fingerprint to bytes + # Format: "A1:B2:C3:D4:..." -> bytes + fingerprint_bytes = bytes.fromhex(fingerprint.replace(":", "")) + + # Take first 128 bits (16 bytes) of SHA-256 + truncated = fingerprint_bytes[:16] + + # Base32 encode (with 9s instead of 2s) and return (uppercase) without padding + return base64.b32encode(truncated).decode("ascii").rstrip("=").replace("2", "9") + + +def create_peer_connection_with_certificate( + certificate: RTCCertificate, + configuration: RTCConfiguration | None = None, +) -> RTCPeerConnection: + """Create an RTCPeerConnection with a custom persistent certificate. + + :param certificate: The RTCCertificate to use for DTLS. + :param configuration: Optional RTCConfiguration with ICE servers. + :return: RTCPeerConnection configured with the provided certificate. + """ + pc = RTCPeerConnection(configuration=configuration) + # Replace the auto-generated certificate with our persistent one + # Uses name-mangled private attribute access + pc._RTCPeerConnection__certificates = [certificate] # type: ignore[attr-defined] + return pc diff --git a/music_assistant/providers/sendspin/provider.py b/music_assistant/providers/sendspin/provider.py index 2045e990..bcfb16c4 100644 --- a/music_assistant/providers/sendspin/provider.py +++ b/music_assistant/providers/sendspin/provider.py @@ -16,6 +16,7 @@ from aiosendspin.server import ClientAddedEvent, ClientRemovedEvent, SendspinEve from music_assistant_models.enums import ProviderFeature from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user +from music_assistant.helpers.webrtc_certificate import create_peer_connection_with_certificate from music_assistant.mass import MusicAssistant from music_assistant.models.player_provider import PlayerProvider from music_assistant.providers.sendspin.player import SendspinPlayer @@ -156,9 +157,10 @@ class SendspinProvider(PlayerProvider): len(ice_servers), ) - # Create peer connection with ICE servers + # Create peer connection with ICE servers and persistent certificate config = RTCConfiguration(iceServers=[RTCIceServer(**server) for server in ice_servers]) - pc = RTCPeerConnection(configuration=config) + certificate = self.mass.webserver.remote_access.certificate + pc = create_peer_connection_with_certificate(certificate, configuration=config) session = SendspinWebRTCSession( session_id=session_id, diff --git a/tests/test_remote_access.py b/tests/test_remote_access.py index 9e6b514f..202ba148 100644 --- a/tests/test_remote_access.py +++ b/tests/test_remote_access.py @@ -2,28 +2,45 @@ from unittest.mock import AsyncMock, Mock, patch -from aiortc import RTCConfiguration, RTCPeerConnection +import pytest +from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection +from aiortc.rtcdtlstransport import RTCCertificate from music_assistant.controllers.webserver.remote_access import RemoteAccessInfo from music_assistant.controllers.webserver.remote_access.gateway import ( WebRTCGateway, WebRTCSession, - generate_remote_id, ) +from music_assistant.helpers.webrtc_certificate import ( + _generate_certificate, + create_peer_connection_with_certificate, + get_remote_id_from_certificate, +) + + +@pytest.fixture +def mock_certificate() -> Mock: + """Create a mock RTCCertificate for testing.""" + cert = Mock() + mock_fingerprint = Mock() + mock_fingerprint.algorithm = "sha-256" + mock_fingerprint.value = ( + "AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99:" + "AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99" + ) + cert.getFingerprints.return_value = [mock_fingerprint] + return cert -async def test_generate_remote_id() -> None: - """Test remote ID generation format.""" - remote_id = generate_remote_id() - assert remote_id.startswith("MA-") - parts = remote_id.split("-") - assert len(parts) == 3 - assert parts[0] == "MA" - assert len(parts[1]) == 4 - assert len(parts[2]) == 4 - # Ensure it's alphanumeric - assert parts[1].isalnum() - assert parts[2].isalnum() +async def test_get_remote_id_from_certificate(mock_certificate: Mock) -> None: + """Test remote ID generation from certificate fingerprint.""" + remote_id = get_remote_id_from_certificate(mock_certificate) + + # Should be base32 encoded, uppercase, no padding + assert remote_id.isalnum() + assert remote_id == remote_id.upper() + # 128 bits = 16 bytes -> 26 base32 chars (without padding) + assert len(remote_id) == 26 async def test_remote_access_info_dataclass() -> None: @@ -32,7 +49,7 @@ async def test_remote_access_info_dataclass() -> None: enabled=True, running=True, connected=False, - remote_id="MA-TEST-1234", + remote_id="VVPN3TLP34YMGIZDINCEKQKSIR", using_ha_cloud=False, signaling_url="wss://signaling.music-assistant.io/ws", ) @@ -40,22 +57,23 @@ async def test_remote_access_info_dataclass() -> None: assert info.enabled is True assert info.running is True assert info.connected is False - assert info.remote_id == "MA-TEST-1234" + assert info.remote_id == "VVPN3TLP34YMGIZDINCEKQKSIR" assert info.using_ha_cloud is False assert info.signaling_url == "wss://signaling.music-assistant.io/ws" -async def test_webrtc_gateway_initialization() -> None: +async def test_webrtc_gateway_initialization(mock_certificate: Mock) -> None: """Test WebRTCGateway initializes correctly.""" mock_session = Mock() gateway = WebRTCGateway( http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, signaling_url="wss://test.example.com/ws", local_ws_url="ws://localhost:8095/ws", - remote_id="MA-TEST-1234", ) - assert gateway.remote_id == "MA-TEST-1234" + assert gateway._remote_id == "TEST-REMOTE-ID" assert gateway.signaling_url == "wss://test.example.com/ws" assert gateway.local_ws_url == "ws://localhost:8095/ws" assert gateway.is_running is False @@ -63,7 +81,7 @@ async def test_webrtc_gateway_initialization() -> None: assert len(gateway.ice_servers) > 0 -async def test_webrtc_gateway_custom_ice_servers() -> None: +async def test_webrtc_gateway_custom_ice_servers(mock_certificate: Mock) -> None: """Test WebRTCGateway accepts custom ICE servers.""" mock_session = Mock() custom_ice_servers = [ @@ -73,16 +91,22 @@ async def test_webrtc_gateway_custom_ice_servers() -> None: gateway = WebRTCGateway( http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, ice_servers=custom_ice_servers, ) assert gateway.ice_servers == custom_ice_servers -async def test_webrtc_gateway_start_stop() -> None: +async def test_webrtc_gateway_start_stop(mock_certificate: Mock) -> None: """Test WebRTCGateway start and stop.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) # Mock the _run method to avoid actual connection with patch.object(gateway, "_run", new_callable=AsyncMock): @@ -94,33 +118,32 @@ async def test_webrtc_gateway_start_stop() -> None: assert gateway.is_running is False -async def test_webrtc_gateway_generate_remote_id() -> None: - """Test that WebRTCGateway generates a remote ID if not provided.""" - mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) - - assert gateway.remote_id is not None - assert gateway.remote_id.startswith("MA-") - - -async def test_webrtc_gateway_handle_registration_message() -> None: +async def test_webrtc_gateway_handle_registration_message(mock_certificate: Mock) -> None: """Test WebRTCGateway handles registration confirmation.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session, remote_id="MA-TEST-1234") + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) # Mock signaling WebSocket gateway._signaling_ws = Mock() - message = {"type": "registered", "remoteId": "MA-TEST-1234"} + message = {"type": "registered", "remoteId": "TEST-REMOTE-ID"} await gateway._handle_signaling_message(message) # Should log but not crash -async def test_webrtc_gateway_handle_ping_pong() -> None: +async def test_webrtc_gateway_handle_ping_pong(mock_certificate: Mock) -> None: """Test WebRTCGateway handles ping/pong messages.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) # Mock signaling WebSocket mock_ws = AsyncMock() @@ -137,20 +160,28 @@ async def test_webrtc_gateway_handle_ping_pong() -> None: mock_ws.send_json.assert_not_called() -async def test_webrtc_gateway_handle_error_message() -> None: +async def test_webrtc_gateway_handle_error_message(mock_certificate: Mock) -> None: """Test WebRTCGateway handles error messages.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) message = {"type": "error", "message": "Test error"} # Should log error but not crash await gateway._handle_signaling_message(message) -async def test_webrtc_gateway_create_session() -> None: +async def test_webrtc_gateway_create_session(mock_certificate: Mock) -> None: """Test WebRTCGateway creates sessions for clients.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) session_id = "test-session-123" await gateway._create_session(session_id) @@ -163,10 +194,14 @@ async def test_webrtc_gateway_create_session() -> None: await gateway._close_session(session_id) -async def test_webrtc_gateway_close_session() -> None: +async def test_webrtc_gateway_close_session(mock_certificate: Mock) -> None: """Test WebRTCGateway closes sessions properly.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) session_id = "test-session-456" await gateway._create_session(session_id) @@ -176,29 +211,41 @@ async def test_webrtc_gateway_close_session() -> None: assert session_id not in gateway.sessions -async def test_webrtc_gateway_close_nonexistent_session() -> None: +async def test_webrtc_gateway_close_nonexistent_session(mock_certificate: Mock) -> None: """Test WebRTCGateway handles closing non-existent session gracefully.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) # Should not raise an error await gateway._close_session("nonexistent-session") -async def test_webrtc_gateway_default_ice_servers() -> None: +async def test_webrtc_gateway_default_ice_servers(mock_certificate: Mock) -> None: """Test WebRTCGateway uses default ICE servers.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) assert len(gateway.ice_servers) > 0 # Should have at least one STUN server assert any("stun:" in server["urls"] for server in gateway.ice_servers) -async def test_webrtc_gateway_handle_client_connected() -> None: +async def test_webrtc_gateway_handle_client_connected(mock_certificate: Mock) -> None: """Test WebRTCGateway handles client-connected message.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) message = {"type": "client-connected", "sessionId": "test-session"} await gateway._handle_signaling_message(message) @@ -210,10 +257,14 @@ async def test_webrtc_gateway_handle_client_connected() -> None: await gateway._close_session("test-session") -async def test_webrtc_gateway_handle_client_disconnected() -> None: +async def test_webrtc_gateway_handle_client_disconnected(mock_certificate: Mock) -> None: """Test WebRTCGateway handles client-disconnected message.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) # Create a session first session_id = "test-disconnect-session" @@ -228,10 +279,14 @@ async def test_webrtc_gateway_handle_client_disconnected() -> None: assert session_id not in gateway.sessions -async def test_webrtc_gateway_reconnection_logic() -> None: +async def test_webrtc_gateway_reconnection_logic(mock_certificate: Mock) -> None: """Test WebRTCGateway has proper reconnection backoff.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) # Check initial reconnect delay assert gateway._current_reconnect_delay == 5 @@ -272,10 +327,14 @@ async def test_webrtc_gateway_session_data_structures() -> None: await pc.close() -async def test_webrtc_gateway_handle_offer_without_session() -> None: +async def test_webrtc_gateway_handle_offer_without_session(mock_certificate: Mock) -> None: """Test WebRTCGateway handles offer for non-existent session gracefully.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) # Try to handle offer for non-existent session offer_data = {"sdp": "test-sdp", "type": "offer"} @@ -284,10 +343,14 @@ async def test_webrtc_gateway_handle_offer_without_session() -> None: # Should not crash -async def test_webrtc_gateway_handle_ice_candidate_without_session() -> None: +async def test_webrtc_gateway_handle_ice_candidate_without_session(mock_certificate: Mock) -> None: """Test WebRTCGateway handles ICE candidate for non-existent session gracefully.""" mock_session = Mock() - gateway = WebRTCGateway(http_session=mock_session) + gateway = WebRTCGateway( + http_session=mock_session, + remote_id="TEST-REMOTE-ID", + certificate=mock_certificate, + ) # Try to handle ICE candidate for non-existent session candidate_data = { @@ -298,3 +361,33 @@ async def test_webrtc_gateway_handle_ice_candidate_without_session() -> None: await gateway._handle_ice_candidate("nonexistent-session", candidate_data) # Should not crash + + +async def test_create_peer_connection_with_certificate() -> None: + """Test that create_peer_connection_with_certificate correctly sets the custom certificate. + + This verifies the fragile name-mangled private attribute access works correctly + and that our custom certificate fully replaces the auto-generated one, which is + critical for DTLS pinning. + """ + # First verify the name-mangled attribute exists on RTCPeerConnection. + # If aiortc changes its internals, this will fail and alert us to update our code. + pc = RTCPeerConnection() + try: + assert hasattr(pc, "_RTCPeerConnection__certificates") + finally: + await pc.close() + + # Now test our function correctly sets the certificate + private_key, cert = _generate_certificate() + certificate = RTCCertificate(key=private_key, cert=cert) + config = RTCConfiguration(iceServers=[RTCIceServer(urls="stun:stun.example.com:3478")]) + + pc = create_peer_connection_with_certificate(certificate, configuration=config) + + try: + certificates = pc._RTCPeerConnection__certificates # type: ignore[attr-defined] + assert len(certificates) == 1 + assert certificates[0] is certificate + finally: + await pc.close()