Add persistent DTLS certificate management for WebRTC, enabling client-side certificate pinning, which significantly improves security.
from music_assistant_models.enums import EventType
from music_assistant.constants import CONF_CORE
-from music_assistant.controllers.webserver.remote_access.gateway import (
- WebRTCGateway,
- generate_remote_id,
+from music_assistant.controllers.webserver.remote_access.gateway import WebRTCGateway
+from music_assistant.helpers.webrtc_certificate import (
+ get_or_create_webrtc_certificate,
+ get_remote_id_from_certificate,
)
if TYPE_CHECKING:
+ from aiortc.rtcdtlstransport import RTCCertificate
from music_assistant_models.event import MassEvent
from music_assistant.controllers.webserver import WebserverController
SIGNALING_SERVER_URL = "wss://signaling.music-assistant.io/ws"
CONF_KEY_MAIN = "remote_access"
-CONF_REMOTE_ID = "remote_id"
CONF_ENABLED = "enabled"
TASK_ID_START_GATEWAY = "remote_access_start_gateway"
self.mass = webserver.mass
self.logger = webserver.logger.getChild("remote_access")
self.gateway: WebRTCGateway | None = None
- self._remote_id: str | None = None
+ self._remote_id: str
+ self._certificate: RTCCertificate
self._enabled: bool = False
self._using_ha_cloud: bool = False
self._starting: bool = False # Prevents concurrent gateway starts
async def setup(self) -> None:
"""Initialize the remote access manager."""
+ self._certificate = get_or_create_webrtc_certificate(self.mass.storage_path)
+
+ self._remote_id = get_remote_id_from_certificate(self._certificate)
+ self.logger.info("WebRTC certificate remote_id: %s", self._remote_id)
+
enabled_value = self.mass.config.get(f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_ENABLED}", False)
self._enabled = bool(enabled_value)
- remote_id_value = self.mass.config.get(
- f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_REMOTE_ID}", None
- )
- if not remote_id_value:
- remote_id_value = generate_remote_id()
- self.mass.config.set(f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_REMOTE_ID}", remote_id_value)
-
- self._remote_id = str(remote_id_value)
self._register_api_commands()
self.mass.subscribe(self._on_providers_updated, EventType.PROVIDERS_UPDATED)
if self._enabled:
self.gateway = WebRTCGateway(
http_session=self.mass.http_session,
+ remote_id=self._remote_id,
+ certificate=self._certificate,
signaling_url=SIGNALING_SERVER_URL,
local_ws_url=local_ws_url,
- remote_id=self._remote_id,
ice_servers=ice_servers,
# Pass callback to get fresh ICE servers for each client connection
# This ensures TURN credentials are always valid
return self.gateway is not None and self.gateway.is_connected
@property
- def remote_id(self) -> str | None:
+ def remote_id(self) -> str:
"""Return the current Remote ID."""
return self._remote_id
+ @property
+ def certificate(self) -> RTCCertificate:
+ """Return the persistent WebRTC DTLS certificate."""
+ return self._certificate
+
def _register_api_commands(self) -> None:
"""Register API commands for remote access."""
enabled=self.is_enabled,
running=self.is_running,
connected=self.is_connected,
- remote_id=self._remote_id or "",
+ remote_id=self._remote_id,
using_ha_cloud=self._using_ha_cloud,
signaling_url=SIGNALING_SERVER_URL,
)
import contextlib
import json
import logging
-import secrets
-import string
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
-from typing import Any
+from typing import TYPE_CHECKING, Any
import aiohttp
from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection, RTCSessionDescription
from aiortc.sdp import candidate_from_sdp
from music_assistant.constants import MASS_LOGGER_NAME, VERBOSE_LOG_LEVEL
+from music_assistant.helpers.webrtc_certificate import create_peer_connection_with_certificate
+
+if TYPE_CHECKING:
+ from aiortc.rtcdtlstransport import RTCCertificate
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access")
logging.getLogger("aiortc").setLevel(logging.WARNING)
-def generate_remote_id() -> str:
- """Generate a unique Remote ID in the format MA-XXXX-XXXX."""
- chars = string.ascii_uppercase + string.digits
- part1 = "".join(secrets.choice(chars) for _ in range(4))
- part2 = "".join(secrets.choice(chars) for _ in range(4))
- return f"MA-{part1}-{part2}"
-
-
@dataclass
class WebRTCSession:
"""Represents an active WebRTC session with a remote client."""
def __init__(
self,
http_session: aiohttp.ClientSession,
+ remote_id: str,
+ certificate: RTCCertificate,
signaling_url: str = "wss://signaling.music-assistant.io/ws",
local_ws_url: str = "ws://localhost:8095/ws",
ice_servers: list[dict[str, Any]] | None = None,
- remote_id: str | None = None,
ice_servers_callback: Callable[[], Awaitable[list[dict[str, Any]]]] | None = None,
) -> None:
"""Initialize the WebRTC Gateway.
:param http_session: Shared aiohttp ClientSession to use for HTTP/WebSocket connections.
+ :param remote_id: Remote ID for this server instance.
+ :param certificate: Persistent RTCCertificate for DTLS, enabling client-side pinning.
:param signaling_url: WebSocket URL of the signaling server.
:param local_ws_url: Local WebSocket URL to bridge to.
:param ice_servers: List of ICE server configurations (used at registration time).
- :param remote_id: Optional Remote ID to use (generated if not provided).
:param ice_servers_callback: Optional callback to fetch fresh ICE servers for each session.
If provided, this will be called for each client connection to get fresh TURN
credentials. If not provided, the static ice_servers will be used.
self.http_session = http_session
self.signaling_url = signaling_url
self.local_ws_url = local_ws_url
- self.remote_id = remote_id or generate_remote_id()
+ self._remote_id = remote_id
+ self._certificate = certificate
self.logger = LOGGER
self._ice_servers_callback = ice_servers_callback
await self._signaling_ws.send_json(
{
"type": "register-server",
- "remoteId": self.remote_id,
+ "remoteId": self._remote_id,
"iceServers": self.ice_servers,
}
)
config = RTCConfiguration(
iceServers=[RTCIceServer(**server) for server in session_ice_servers]
)
- pc = RTCPeerConnection(configuration=config)
+ pc = create_peer_connection_with_certificate(self._certificate, configuration=config)
session = WebRTCSession(session_id=session_id, peer_connection=pc)
self.sessions[session_id] = session
--- /dev/null
+"""WebRTC DTLS Certificate Management.
+
+This module provides persistent DTLS certificate management for WebRTC connections.
+The certificate is generated once and stored persistently, enabling client-side
+certificate pinning for authentication.
+"""
+
+from __future__ import annotations
+
+import base64
+import logging
+import stat
+from datetime import UTC, datetime, timedelta
+from pathlib import Path
+
+from aiortc import RTCConfiguration, RTCPeerConnection
+from aiortc.rtcdtlstransport import RTCCertificate
+from cryptography import x509
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import ec
+from cryptography.x509.oid import NameOID
+
+LOGGER = logging.getLogger(__name__)
+
+CERT_FILENAME = "webrtc_certificate.pem"
+KEY_FILENAME = "webrtc_private_key.pem"
+
+CERT_VALIDITY_DAYS = 3650 # 10 years
+
+CERT_RENEWAL_THRESHOLD_DAYS = 30
+
+
+def _generate_certificate() -> tuple[ec.EllipticCurvePrivateKey, x509.Certificate]:
+ """Generate a new ECDSA certificate for WebRTC DTLS.
+
+ :return: Tuple of (private_key, certificate).
+ """
+ # Generate ECDSA key (SECP256R1 - same as aiortc default)
+ private_key = ec.generate_private_key(ec.SECP256R1())
+
+ now = datetime.now(UTC)
+ not_before = now - timedelta(days=1)
+ not_after = now + timedelta(days=CERT_VALIDITY_DAYS)
+
+ subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Music Assistant WebRTC")])
+
+ cert = (
+ x509.CertificateBuilder()
+ .subject_name(subject)
+ .issuer_name(subject)
+ .public_key(private_key.public_key())
+ .serial_number(x509.random_serial_number())
+ .not_valid_before(not_before)
+ .not_valid_after(not_after)
+ .sign(private_key, hashes.SHA256())
+ )
+
+ return private_key, cert
+
+
+def _save_certificate(
+ storage_path: str,
+ private_key: ec.EllipticCurvePrivateKey,
+ cert: x509.Certificate,
+) -> None:
+ """Save certificate and private key to disk.
+
+ :param storage_path: Directory to store the files.
+ :param private_key: The EC private key.
+ :param cert: The X.509 certificate.
+ """
+ cert_path = Path(storage_path) / CERT_FILENAME
+ key_path = Path(storage_path) / KEY_FILENAME
+
+ cert_pem = cert.public_bytes(serialization.Encoding.PEM)
+ cert_path.write_bytes(cert_pem)
+
+ key_pem = private_key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=serialization.NoEncryption(),
+ )
+ key_path.write_bytes(key_pem)
+
+ # Set restrictive permissions on private key (owner read/write only)
+ key_path.chmod(stat.S_IRUSR | stat.S_IWUSR)
+
+ LOGGER.info("Saved WebRTC certificate to %s", cert_path)
+
+
+def _load_certificate(
+ storage_path: str,
+) -> tuple[ec.EllipticCurvePrivateKey, x509.Certificate] | None:
+ """Load certificate and private key from disk.
+
+ :param storage_path: Directory containing the files.
+ :return: Tuple of (private_key, certificate) or None if files don't exist.
+ """
+ cert_path = Path(storage_path) / CERT_FILENAME
+ key_path = Path(storage_path) / KEY_FILENAME
+
+ if not cert_path.exists() or not key_path.exists():
+ return None
+
+ try:
+ cert_pem = cert_path.read_bytes()
+ cert = x509.load_pem_x509_certificate(cert_pem)
+
+ key_pem = key_path.read_bytes()
+ private_key = serialization.load_pem_private_key(key_pem, password=None)
+
+ if not isinstance(private_key, ec.EllipticCurvePrivateKey):
+ LOGGER.warning("WebRTC private key is not an EC key, will regenerate")
+ return None
+
+ return private_key, cert
+ except Exception as err:
+ LOGGER.warning("Failed to load WebRTC certificate: %s", err)
+ return None
+
+
+def _is_certificate_valid(cert: x509.Certificate) -> bool:
+ """Check if certificate is still valid with enough time remaining.
+
+ :param cert: The X.509 certificate to check.
+ :return: True if certificate is valid and has sufficient time remaining.
+ """
+ now = datetime.now(UTC)
+ not_after = cert.not_valid_after_utc
+
+ if now >= not_after:
+ LOGGER.info("WebRTC certificate has expired")
+ return False
+
+ days_remaining = (not_after - now).days
+ if days_remaining < CERT_RENEWAL_THRESHOLD_DAYS:
+ LOGGER.info(
+ "WebRTC certificate expires in %d days, will regenerate",
+ days_remaining,
+ )
+ return False
+
+ return True
+
+
+def get_or_create_webrtc_certificate(storage_path: str) -> RTCCertificate:
+ """Get or create a persistent WebRTC DTLS certificate.
+
+ Loads an existing certificate from disk if available and valid.
+ Otherwise, generates a new certificate and saves it.
+
+ :param storage_path: Directory to store/load the certificate files.
+ :return: RTCCertificate instance for use with WebRTC.
+ """
+ loaded = _load_certificate(storage_path)
+
+ if loaded is not None:
+ private_key, cert = loaded
+
+ if _is_certificate_valid(cert):
+ return RTCCertificate(key=private_key, cert=cert)
+
+ LOGGER.info("Generating new WebRTC DTLS certificate (valid for %d days)", CERT_VALIDITY_DAYS)
+ private_key, cert = _generate_certificate()
+ _save_certificate(storage_path, private_key, cert)
+
+ return RTCCertificate(key=private_key, cert=cert)
+
+
+def _get_certificate_fingerprint(certificate: RTCCertificate) -> str:
+ """Get the SHA-256 fingerprint of a certificate.
+
+ :param certificate: The RTCCertificate to get the fingerprint for.
+ :return: SHA-256 fingerprint as colon-separated hex string (e.g., "A1:B2:C3:...").
+ """
+ fingerprints = certificate.getFingerprints()
+ for fp in fingerprints:
+ if fp.algorithm == "sha-256":
+ return fp.value
+ raise ValueError("SHA-256 fingerprint not found in certificate")
+
+
+def get_remote_id_from_certificate(certificate: RTCCertificate) -> str:
+ """Generate a remote ID from the certificate fingerprint.
+
+ Uses base32-encoded 128-bit truncation of the SHA-256 fingerprint.
+ This creates a deterministic remote ID tied to the certificate.
+
+ :param certificate: The RTCCertificate to derive the remote ID from.
+ :return: Custom base32-encoded (with 9s instead of 2s) remote ID string
+ (26 characters, uppercase, no-padding).
+ """
+ fingerprint = _get_certificate_fingerprint(certificate)
+
+ # Parse the colon-separated hex fingerprint to bytes
+ # Format: "A1:B2:C3:D4:..." -> bytes
+ fingerprint_bytes = bytes.fromhex(fingerprint.replace(":", ""))
+
+ # Take first 128 bits (16 bytes) of SHA-256
+ truncated = fingerprint_bytes[:16]
+
+ # Base32 encode (with 9s instead of 2s) and return (uppercase) without padding
+ return base64.b32encode(truncated).decode("ascii").rstrip("=").replace("2", "9")
+
+
+def create_peer_connection_with_certificate(
+ certificate: RTCCertificate,
+ configuration: RTCConfiguration | None = None,
+) -> RTCPeerConnection:
+ """Create an RTCPeerConnection with a custom persistent certificate.
+
+ :param certificate: The RTCCertificate to use for DTLS.
+ :param configuration: Optional RTCConfiguration with ICE servers.
+ :return: RTCPeerConnection configured with the provided certificate.
+ """
+ pc = RTCPeerConnection(configuration=configuration)
+ # Replace the auto-generated certificate with our persistent one
+ # Uses name-mangled private attribute access
+ pc._RTCPeerConnection__certificates = [certificate] # type: ignore[attr-defined]
+ return pc
from music_assistant_models.enums import ProviderFeature
from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
+from music_assistant.helpers.webrtc_certificate import create_peer_connection_with_certificate
from music_assistant.mass import MusicAssistant
from music_assistant.models.player_provider import PlayerProvider
from music_assistant.providers.sendspin.player import SendspinPlayer
len(ice_servers),
)
- # Create peer connection with ICE servers
+ # Create peer connection with ICE servers and persistent certificate
config = RTCConfiguration(iceServers=[RTCIceServer(**server) for server in ice_servers])
- pc = RTCPeerConnection(configuration=config)
+ certificate = self.mass.webserver.remote_access.certificate
+ pc = create_peer_connection_with_certificate(certificate, configuration=config)
session = SendspinWebRTCSession(
session_id=session_id,
from unittest.mock import AsyncMock, Mock, patch
-from aiortc import RTCConfiguration, RTCPeerConnection
+import pytest
+from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection
+from aiortc.rtcdtlstransport import RTCCertificate
from music_assistant.controllers.webserver.remote_access import RemoteAccessInfo
from music_assistant.controllers.webserver.remote_access.gateway import (
WebRTCGateway,
WebRTCSession,
- generate_remote_id,
)
+from music_assistant.helpers.webrtc_certificate import (
+ _generate_certificate,
+ create_peer_connection_with_certificate,
+ get_remote_id_from_certificate,
+)
+
+
+@pytest.fixture
+def mock_certificate() -> Mock:
+ """Create a mock RTCCertificate for testing."""
+ cert = Mock()
+ mock_fingerprint = Mock()
+ mock_fingerprint.algorithm = "sha-256"
+ mock_fingerprint.value = (
+ "AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99:"
+ "AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99"
+ )
+ cert.getFingerprints.return_value = [mock_fingerprint]
+ return cert
-async def test_generate_remote_id() -> None:
- """Test remote ID generation format."""
- remote_id = generate_remote_id()
- assert remote_id.startswith("MA-")
- parts = remote_id.split("-")
- assert len(parts) == 3
- assert parts[0] == "MA"
- assert len(parts[1]) == 4
- assert len(parts[2]) == 4
- # Ensure it's alphanumeric
- assert parts[1].isalnum()
- assert parts[2].isalnum()
+async def test_get_remote_id_from_certificate(mock_certificate: Mock) -> None:
+ """Test remote ID generation from certificate fingerprint."""
+ remote_id = get_remote_id_from_certificate(mock_certificate)
+
+ # Should be base32 encoded, uppercase, no padding
+ assert remote_id.isalnum()
+ assert remote_id == remote_id.upper()
+ # 128 bits = 16 bytes -> 26 base32 chars (without padding)
+ assert len(remote_id) == 26
async def test_remote_access_info_dataclass() -> None:
enabled=True,
running=True,
connected=False,
- remote_id="MA-TEST-1234",
+ remote_id="VVPN3TLP34YMGIZDINCEKQKSIR",
using_ha_cloud=False,
signaling_url="wss://signaling.music-assistant.io/ws",
)
assert info.enabled is True
assert info.running is True
assert info.connected is False
- assert info.remote_id == "MA-TEST-1234"
+ assert info.remote_id == "VVPN3TLP34YMGIZDINCEKQKSIR"
assert info.using_ha_cloud is False
assert info.signaling_url == "wss://signaling.music-assistant.io/ws"
-async def test_webrtc_gateway_initialization() -> None:
+async def test_webrtc_gateway_initialization(mock_certificate: Mock) -> None:
"""Test WebRTCGateway initializes correctly."""
mock_session = Mock()
gateway = WebRTCGateway(
http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
signaling_url="wss://test.example.com/ws",
local_ws_url="ws://localhost:8095/ws",
- remote_id="MA-TEST-1234",
)
- assert gateway.remote_id == "MA-TEST-1234"
+ assert gateway._remote_id == "TEST-REMOTE-ID"
assert gateway.signaling_url == "wss://test.example.com/ws"
assert gateway.local_ws_url == "ws://localhost:8095/ws"
assert gateway.is_running is False
assert len(gateway.ice_servers) > 0
-async def test_webrtc_gateway_custom_ice_servers() -> None:
+async def test_webrtc_gateway_custom_ice_servers(mock_certificate: Mock) -> None:
"""Test WebRTCGateway accepts custom ICE servers."""
mock_session = Mock()
custom_ice_servers = [
gateway = WebRTCGateway(
http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
ice_servers=custom_ice_servers,
)
assert gateway.ice_servers == custom_ice_servers
-async def test_webrtc_gateway_start_stop() -> None:
+async def test_webrtc_gateway_start_stop(mock_certificate: Mock) -> None:
"""Test WebRTCGateway start and stop."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
# Mock the _run method to avoid actual connection
with patch.object(gateway, "_run", new_callable=AsyncMock):
assert gateway.is_running is False
-async def test_webrtc_gateway_generate_remote_id() -> None:
- """Test that WebRTCGateway generates a remote ID if not provided."""
- mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
-
- assert gateway.remote_id is not None
- assert gateway.remote_id.startswith("MA-")
-
-
-async def test_webrtc_gateway_handle_registration_message() -> None:
+async def test_webrtc_gateway_handle_registration_message(mock_certificate: Mock) -> None:
"""Test WebRTCGateway handles registration confirmation."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session, remote_id="MA-TEST-1234")
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
# Mock signaling WebSocket
gateway._signaling_ws = Mock()
- message = {"type": "registered", "remoteId": "MA-TEST-1234"}
+ message = {"type": "registered", "remoteId": "TEST-REMOTE-ID"}
await gateway._handle_signaling_message(message)
# Should log but not crash
-async def test_webrtc_gateway_handle_ping_pong() -> None:
+async def test_webrtc_gateway_handle_ping_pong(mock_certificate: Mock) -> None:
"""Test WebRTCGateway handles ping/pong messages."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
# Mock signaling WebSocket
mock_ws = AsyncMock()
mock_ws.send_json.assert_not_called()
-async def test_webrtc_gateway_handle_error_message() -> None:
+async def test_webrtc_gateway_handle_error_message(mock_certificate: Mock) -> None:
"""Test WebRTCGateway handles error messages."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
message = {"type": "error", "message": "Test error"}
# Should log error but not crash
await gateway._handle_signaling_message(message)
-async def test_webrtc_gateway_create_session() -> None:
+async def test_webrtc_gateway_create_session(mock_certificate: Mock) -> None:
"""Test WebRTCGateway creates sessions for clients."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
session_id = "test-session-123"
await gateway._create_session(session_id)
await gateway._close_session(session_id)
-async def test_webrtc_gateway_close_session() -> None:
+async def test_webrtc_gateway_close_session(mock_certificate: Mock) -> None:
"""Test WebRTCGateway closes sessions properly."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
session_id = "test-session-456"
await gateway._create_session(session_id)
assert session_id not in gateway.sessions
-async def test_webrtc_gateway_close_nonexistent_session() -> None:
+async def test_webrtc_gateway_close_nonexistent_session(mock_certificate: Mock) -> None:
"""Test WebRTCGateway handles closing non-existent session gracefully."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
# Should not raise an error
await gateway._close_session("nonexistent-session")
-async def test_webrtc_gateway_default_ice_servers() -> None:
+async def test_webrtc_gateway_default_ice_servers(mock_certificate: Mock) -> None:
"""Test WebRTCGateway uses default ICE servers."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
assert len(gateway.ice_servers) > 0
# Should have at least one STUN server
assert any("stun:" in server["urls"] for server in gateway.ice_servers)
-async def test_webrtc_gateway_handle_client_connected() -> None:
+async def test_webrtc_gateway_handle_client_connected(mock_certificate: Mock) -> None:
"""Test WebRTCGateway handles client-connected message."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
message = {"type": "client-connected", "sessionId": "test-session"}
await gateway._handle_signaling_message(message)
await gateway._close_session("test-session")
-async def test_webrtc_gateway_handle_client_disconnected() -> None:
+async def test_webrtc_gateway_handle_client_disconnected(mock_certificate: Mock) -> None:
"""Test WebRTCGateway handles client-disconnected message."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
# Create a session first
session_id = "test-disconnect-session"
assert session_id not in gateway.sessions
-async def test_webrtc_gateway_reconnection_logic() -> None:
+async def test_webrtc_gateway_reconnection_logic(mock_certificate: Mock) -> None:
"""Test WebRTCGateway has proper reconnection backoff."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
# Check initial reconnect delay
assert gateway._current_reconnect_delay == 5
await pc.close()
-async def test_webrtc_gateway_handle_offer_without_session() -> None:
+async def test_webrtc_gateway_handle_offer_without_session(mock_certificate: Mock) -> None:
"""Test WebRTCGateway handles offer for non-existent session gracefully."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
# Try to handle offer for non-existent session
offer_data = {"sdp": "test-sdp", "type": "offer"}
# Should not crash
-async def test_webrtc_gateway_handle_ice_candidate_without_session() -> None:
+async def test_webrtc_gateway_handle_ice_candidate_without_session(mock_certificate: Mock) -> None:
"""Test WebRTCGateway handles ICE candidate for non-existent session gracefully."""
mock_session = Mock()
- gateway = WebRTCGateway(http_session=mock_session)
+ gateway = WebRTCGateway(
+ http_session=mock_session,
+ remote_id="TEST-REMOTE-ID",
+ certificate=mock_certificate,
+ )
# Try to handle ICE candidate for non-existent session
candidate_data = {
await gateway._handle_ice_candidate("nonexistent-session", candidate_data)
# Should not crash
+
+
+async def test_create_peer_connection_with_certificate() -> None:
+ """Test that create_peer_connection_with_certificate correctly sets the custom certificate.
+
+ This verifies the fragile name-mangled private attribute access works correctly
+ and that our custom certificate fully replaces the auto-generated one, which is
+ critical for DTLS pinning.
+ """
+ # First verify the name-mangled attribute exists on RTCPeerConnection.
+ # If aiortc changes its internals, this will fail and alert us to update our code.
+ pc = RTCPeerConnection()
+ try:
+ assert hasattr(pc, "_RTCPeerConnection__certificates")
+ finally:
+ await pc.close()
+
+ # Now test our function correctly sets the certificate
+ private_key, cert = _generate_certificate()
+ certificate = RTCCertificate(key=private_key, cert=cert)
+ config = RTCConfiguration(iceServers=[RTCIceServer(urls="stun:stun.example.com:3478")])
+
+ pc = create_peer_connection_with_certificate(certificate, configuration=config)
+
+ try:
+ certificates = pc._RTCPeerConnection__certificates # type: ignore[attr-defined]
+ assert len(certificates) == 1
+ assert certificates[0] is certificate
+ finally:
+ await pc.close()