From: Marcel van der Veldt Date: Mon, 1 Dec 2025 01:30:01 +0000 (+0100) Subject: Move remote access into its own controller X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=ca7ca36bebcaab3e45a5ea9328e365e5e45bf62a;p=music-assistant-server.git Move remote access into its own controller --- diff --git a/music_assistant/constants.py b/music_assistant/constants.py index 01b545fc..09b24eab 100644 --- a/music_assistant/constants.py +++ b/music_assistant/constants.py @@ -141,6 +141,7 @@ CONFIGURABLE_CORE_CONTROLLERS = ( "cache", "music", "player_queues", + "remote_access", ) VERBOSE_LOG_LEVEL: Final[int] = 5 PROVIDERS_WITH_SHAREABLE_URLS = ("spotify", "qobuz") diff --git a/music_assistant/controllers/remote_access/__init__.py b/music_assistant/controllers/remote_access/__init__.py new file mode 100644 index 00000000..16eb8066 --- /dev/null +++ b/music_assistant/controllers/remote_access/__init__.py @@ -0,0 +1,5 @@ +"""Remote Access controller for Music Assistant.""" + +from music_assistant.controllers.remote_access.controller import RemoteAccessController + +__all__ = ["RemoteAccessController"] diff --git a/music_assistant/controllers/remote_access/controller.py b/music_assistant/controllers/remote_access/controller.py new file mode 100644 index 00000000..c7f28913 --- /dev/null +++ b/music_assistant/controllers/remote_access/controller.py @@ -0,0 +1,313 @@ +""" +Remote Access Controller for Music Assistant. + +This controller manages WebRTC-based remote access to Music Assistant instances. +It connects to a signaling server and handles incoming WebRTC connections, +bridging them to the local WebSocket API. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from music_assistant_models.config_entries import ConfigEntry +from music_assistant_models.enums import ConfigEntryType, EventType + +from music_assistant.controllers.remote_access.gateway import WebRTCGateway, generate_remote_id +from music_assistant.helpers.api import api_command +from music_assistant.models.core_controller import CoreController + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ConfigValueType, CoreConfig + from music_assistant_models.event import MassEvent + + from music_assistant import MusicAssistant + +# Signaling server URL +SIGNALING_SERVER_URL = "wss://signaling.music-assistant.io/ws" + +# Config keys +CONF_REMOTE_ID = "remote_id" +CONF_ENABLED = "enabled" + + +class RemoteAccessController(CoreController): + """Core Controller for WebRTC-based remote access.""" + + domain: str = "remote_access" + + def __init__(self, mass: MusicAssistant) -> None: + """Initialize the remote access controller. + + :param mass: MusicAssistant instance. + """ + super().__init__(mass) + self.manifest.name = "Remote Access" + self.manifest.description = ( + "WebRTC-based remote access for connecting to Music Assistant " + "from outside your local network (requires Home Assistant Cloud subscription)" + ) + self.manifest.icon = "cloud-lock" + self.gateway: WebRTCGateway | None = None + self._remote_id: str | None = None + self._setup_done = False + + async def get_config_entries( + self, + action: str | None = None, + values: dict[str, ConfigValueType] | None = None, + ) -> tuple[ConfigEntry, ...]: + """Return all Config Entries for this core module (if any).""" + entries = [] + + # Info alert about HA Cloud requirement + entries.append( + ConfigEntry( + key="remote_access_info", + type=ConfigEntryType.ALERT, + label="Remote Access requires an active Home Assistant Cloud subscription. " + "Once detected, remote access will be automatically enabled and you will " + "receive a unique Remote ID for connecting from outside your network.", + required=False, + ) + ) + + entries.append( + ConfigEntry( + key=CONF_ENABLED, + type=ConfigEntryType.BOOLEAN, + default_value=True, + label="Enable Remote Access", + description="Enable WebRTC-based remote access when Home Assistant Cloud " + "subscription is detected. Disable this if you don't want to use remote access.", + ) + ) + + entries.append( + ConfigEntry( + key=CONF_REMOTE_ID, + type=ConfigEntryType.STRING, + label="Remote ID", + description="Unique identifier for WebRTC remote access. " + "Generated automatically and should not be changed.", + required=False, + hidden=True, + ) + ) + + return tuple(entries) + + async def setup(self, config: CoreConfig) -> None: + """Async initialize of module.""" + self.config = config + self.logger.debug("RemoteAccessController.setup() called") + + # Register API commands immediately + self._register_api_commands() + + # Subscribe to provider updates to check for Home Assistant Cloud + self.mass.subscribe(self._on_providers_updated, EventType.PROVIDERS_UPDATED) + + # Try initial setup (providers might already be loaded) + await self._try_enable_remote_access() + + async def close(self) -> None: + """Cleanup on exit.""" + if self.gateway: + await self.gateway.stop() + self.gateway = None + self.logger.info("WebRTC Remote Access stopped") + + def _on_remote_id_ready(self, remote_id: str) -> None: + """Handle Remote ID registration with signaling server. + + :param remote_id: The registered Remote ID. + """ + self.logger.info("Remote ID registered with signaling server: %s", remote_id) + self._remote_id = remote_id + + def _on_providers_updated(self, event: MassEvent) -> None: + """Handle providers updated event. + + :param event: The providers updated event. + """ + if self._setup_done: + # Already set up, no need to check again + return + # Try to enable remote access when providers are updated + self.mass.create_task(self._try_enable_remote_access()) + + async def _try_enable_remote_access(self) -> None: + """Try to enable remote access if Home Assistant Cloud is available.""" + if self._setup_done: + # Already set up + return + + # Check if remote access is enabled in config + if not self.config.get_value(CONF_ENABLED, True): + self.logger.info("Remote access is disabled in configuration") + return + + # Check if Home Assistant Cloud is available and active + cloud_status = await self._check_ha_cloud_status() + if not cloud_status: + self.logger.debug("Home Assistant Cloud not available yet") + return + + # Mark as done to prevent multiple attempts + self._setup_done = True + self.logger.info("Home Assistant Cloud subscription detected, enabling remote access") + + # Get or generate Remote ID + remote_id_value = self.config.get_value(CONF_REMOTE_ID) + if not remote_id_value: + # Generate new Remote ID and save it + remote_id_value = generate_remote_id() + # Save the Remote ID to config + self.mass.config.set_raw_core_config_value(self.domain, CONF_REMOTE_ID, remote_id_value) + self.mass.config.save(immediate=True) + self.logger.info("Generated new Remote ID: %s", remote_id_value) + + # Ensure remote_id is a string + if isinstance(remote_id_value, str): + self._remote_id = remote_id_value + else: + self.logger.error("Invalid remote_id type: %s", type(remote_id_value)) + return + + # Determine local WebSocket URL from webserver config + webserver_config = await self.mass.config.get_core_config("webserver") + bind_port_value = webserver_config.get_value("bind_port", 8095) + bind_port = int(bind_port_value) if isinstance(bind_port_value, int) else 8095 + local_ws_url = f"ws://localhost:{bind_port}/ws" + + # Get ICE servers from HA Cloud if available + ice_servers = await self._get_ha_cloud_ice_servers() + + # Initialize and start the WebRTC gateway + self.gateway = WebRTCGateway( + signaling_url=SIGNALING_SERVER_URL, + local_ws_url=local_ws_url, + remote_id=self._remote_id, + on_remote_id_ready=self._on_remote_id_ready, + ice_servers=ice_servers, + ) + + await self.gateway.start() + self.logger.info("WebRTC Remote Access enabled - Remote ID: %s", self._remote_id) + + async def _check_ha_cloud_status(self) -> bool: + """Check if Home Assistant Cloud subscription is active. + + :return: True if HA Cloud is logged in and has active subscription. + """ + # Find the Home Assistant provider + ha_provider = None + for provider in self.mass.providers: + if provider.domain == "hass" and provider.available: + ha_provider = provider + break + + if not ha_provider: + self.logger.debug("Home Assistant provider not found or not available") + return False + + try: + # Access the hass client from the provider + if not hasattr(ha_provider, "hass"): + self.logger.debug("Provider does not have hass attribute") + return False + hass_client = ha_provider.hass # type: ignore[union-attr] + if not hass_client or not hass_client.connected: + self.logger.debug("Home Assistant client not connected") + return False + + # Call cloud/status command to check subscription + result = await hass_client.send_command("cloud/status") + + # Check for logged_in and active_subscription + logged_in = result.get("logged_in", False) + active_subscription = result.get("active_subscription", False) + + if logged_in and active_subscription: + self.logger.info("Home Assistant Cloud subscription is active") + return True + + self.logger.info( + "Home Assistant Cloud not active (logged_in=%s, active_subscription=%s)", + logged_in, + active_subscription, + ) + return False + + except Exception: + self.logger.exception("Error checking Home Assistant Cloud status") + return False + + async def _get_ha_cloud_ice_servers(self) -> list[dict[str, str]] | None: + """Get ICE servers from Home Assistant Cloud. + + :return: List of ICE server configurations or None if unavailable. + """ + # Find the Home Assistant provider + ha_provider = None + for provider in self.mass.providers: + if provider.domain == "hass" and provider.available: + ha_provider = provider + break + + if not ha_provider: + return None + + try: + # Access the hass client from the provider + if not hasattr(ha_provider, "hass"): + return None + hass_client = ha_provider.hass # type: ignore[union-attr] + if not hass_client or not hass_client.connected: + return None + + # Try to get ICE servers from HA Cloud + # This might be available via a cloud API endpoint + # For now, return None and use default STUN servers + # TODO: Research if HA Cloud exposes ICE/TURN server endpoints + self.logger.debug( + "Using default STUN servers (HA Cloud ICE servers not yet implemented)" + ) + return None + + except Exception: + self.logger.exception("Error getting Home Assistant Cloud ICE servers") + return None + + @property + def is_enabled(self) -> bool: + """Return whether WebRTC remote access is enabled.""" + return self.gateway is not None and self.gateway.is_running + + @property + def is_connected(self) -> bool: + """Return whether the gateway is connected to the signaling server.""" + return self.gateway is not None and self.gateway.is_connected + + @property + def remote_id(self) -> str | None: + """Return the current Remote ID.""" + return self._remote_id + + def _register_api_commands(self) -> None: + """Register API commands for remote access.""" + + @api_command("remote_access/info") + def get_remote_access_info() -> dict[str, str | bool]: + """Get remote access information. + + Returns information about the remote access configuration including + whether it's enabled, connected status, and the Remote ID for connecting. + """ + return { + "enabled": self.is_enabled, + "connected": self.is_connected, + "remote_id": self._remote_id or "", + "signaling_url": SIGNALING_SERVER_URL, + } diff --git a/music_assistant/controllers/remote_access/gateway.py b/music_assistant/controllers/remote_access/gateway.py new file mode 100644 index 00000000..5f8b8561 --- /dev/null +++ b/music_assistant/controllers/remote_access/gateway.py @@ -0,0 +1,667 @@ +"""Music Assistant WebRTC Gateway. + +This module provides WebRTC-based remote access to Music Assistant instances. +It connects to a signaling server and handles incoming WebRTC connections, +bridging them to the local WebSocket API. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import secrets +import string +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import aiohttp +from aiortc import ( + RTCConfiguration, + RTCIceCandidate, + RTCIceServer, + RTCPeerConnection, + RTCSessionDescription, +) + +from music_assistant.constants import MASS_LOGGER_NAME + +if TYPE_CHECKING: + from collections.abc import Callable + +LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access") + +# Reduce verbose logging from aiortc/aioice +logging.getLogger("aioice").setLevel(logging.WARNING) +logging.getLogger("aiortc").setLevel(logging.WARNING) + + +def generate_remote_id() -> str: + """Generate a unique Remote ID in the format MA-XXXX-XXXX.""" + chars = string.ascii_uppercase + string.digits + part1 = "".join(secrets.choice(chars) for _ in range(4)) + part2 = "".join(secrets.choice(chars) for _ in range(4)) + return f"MA-{part1}-{part2}" + + +@dataclass +class WebRTCSession: + """Represents an active WebRTC session with a remote client.""" + + session_id: str + peer_connection: RTCPeerConnection + data_channel: Any = None + local_ws: Any = None + message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue) + forward_to_local_task: asyncio.Task[None] | None = None + forward_from_local_task: asyncio.Task[None] | None = None + + +class WebRTCGateway: + """WebRTC Gateway for Music Assistant Remote Access. + + This gateway: + 1. Connects to a signaling server + 2. Registers with a unique Remote ID + 3. Handles incoming WebRTC connections from remote PWA clients + 4. Bridges WebRTC DataChannel messages to the local WebSocket API + """ + + def __init__( + self, + signaling_url: str = "wss://signaling.music-assistant.io/ws", + local_ws_url: str = "ws://localhost:8095/ws", + ice_servers: list[dict[str, Any]] | None = None, + remote_id: str | None = None, + on_remote_id_ready: Callable[[str], None] | None = None, + ) -> None: + """Initialize the WebRTC Gateway. + + :param signaling_url: WebSocket URL of the signaling server. + :param local_ws_url: Local WebSocket URL to bridge to. + :param ice_servers: List of ICE server configurations. + :param remote_id: Optional Remote ID to use (generated if not provided). + :param on_remote_id_ready: Callback when Remote ID is registered. + """ + self.signaling_url = signaling_url + self.local_ws_url = local_ws_url + self.remote_id = remote_id or generate_remote_id() + self.on_remote_id_ready = on_remote_id_ready + self.logger = LOGGER + + self.ice_servers = ice_servers or [ + {"urls": "stun:stun.l.google.com:19302"}, + {"urls": "stun:stun1.l.google.com:19302"}, + {"urls": "stun:stun.cloudflare.com:3478"}, + ] + + self.sessions: dict[str, WebRTCSession] = {} + self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None + self._signaling_session: aiohttp.ClientSession | None = None + self._running = False + self._reconnect_delay = 5 + self._max_reconnect_delay = 60 + self._current_reconnect_delay = 5 + self._run_task: asyncio.Task[None] | None = None + self._is_connected = False + + @property + def is_running(self) -> bool: + """Return whether the gateway is running.""" + return self._running + + @property + def is_connected(self) -> bool: + """Return whether the gateway is connected to the signaling server.""" + return self._is_connected + + async def start(self) -> None: + """Start the WebRTC Gateway.""" + self.logger.info("Starting WebRTC Gateway with Remote ID: %s", self.remote_id) + self.logger.debug("Signaling URL: %s", self.signaling_url) + self.logger.debug("Local WS URL: %s", self.local_ws_url) + self._running = True + self._run_task = asyncio.create_task(self._run()) + self.logger.debug("WebRTC Gateway start task created") + + async def stop(self) -> None: + """Stop the WebRTC Gateway.""" + self.logger.info("Stopping WebRTC Gateway") + self._running = False + + # Close all sessions + for session_id in list(self.sessions.keys()): + await self._close_session(session_id) + + # Close signaling connection gracefully + if self._signaling_ws and not self._signaling_ws.closed: + try: + await self._signaling_ws.close() + except Exception: + self.logger.debug("Error closing signaling WebSocket", exc_info=True) + + if self._signaling_session and not self._signaling_session.closed: + try: + await self._signaling_session.close() + except Exception: + self.logger.debug("Error closing signaling session", exc_info=True) + + # Cancel run task + if self._run_task and not self._run_task.done(): + self._run_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._run_task + + self._signaling_ws = None + self._signaling_session = None + + async def _run(self) -> None: + """Run the main loop with reconnection logic.""" + self.logger.debug("WebRTC Gateway _run() loop starting") + while self._running: + try: + await self._connect_to_signaling() + # Connection closed gracefully or with error + self._is_connected = False + if self._running: + self.logger.warning( + "Signaling server connection lost. Reconnecting in %ss...", + self._current_reconnect_delay, + ) + except Exception: + self._is_connected = False + self.logger.exception("Signaling connection error") + if self._running: + self.logger.info( + "Reconnecting to signaling server in %ss", + self._current_reconnect_delay, + ) + + if self._running: + await asyncio.sleep(self._current_reconnect_delay) + # Exponential backoff with max limit + self._current_reconnect_delay = min( + self._current_reconnect_delay * 2, self._max_reconnect_delay + ) + + async def _connect_to_signaling(self) -> None: + """Connect to the signaling server.""" + self.logger.info("Connecting to signaling server: %s", self.signaling_url) + # Create session with increased timeout for WebSocket connection + timeout = aiohttp.ClientTimeout( + total=None, # No total timeout for WebSocket connections + connect=30, # 30 seconds to establish connection + sock_connect=30, # 30 seconds for socket connection + sock_read=None, # No timeout for reading (we handle ping/pong) + ) + self._signaling_session = aiohttp.ClientSession(timeout=timeout) + try: + self._signaling_ws = await self._signaling_session.ws_connect( + self.signaling_url, + heartbeat=30, # Send WebSocket ping every 30 seconds + autoping=True, # Automatically respond to pings + ) + self.logger.debug("WebSocket connection established") + # Small delay to let any previous connection fully close on the server side + # This helps prevent race conditions during reconnection + await asyncio.sleep(0.5) + self.logger.debug("Sending registration") + await self._register() + self._is_connected = True + # Reset reconnect delay on successful connection + self._current_reconnect_delay = self._reconnect_delay + self.logger.info("Connected and registered with signaling server") + + # Message loop + self.logger.debug("Entering message loop") + async for msg in self._signaling_ws: + self.logger.debug("Received WebSocket message type: %s", msg.type) + if msg.type == aiohttp.WSMsgType.TEXT: + try: + await self._handle_signaling_message(json.loads(msg.data)) + except Exception: + self.logger.exception("Error handling signaling message") + elif msg.type == aiohttp.WSMsgType.PING: + # WebSocket ping - autoping should handle this, just log + self.logger.debug("Received WebSocket PING") + elif msg.type == aiohttp.WSMsgType.PONG: + # WebSocket pong response - just log + self.logger.debug("Received WebSocket PONG") + elif msg.type == aiohttp.WSMsgType.CLOSE: + # Close frame received + self.logger.warning( + "Signaling server sent close frame: code=%s, reason=%s", + msg.data, + msg.extra, + ) + break + elif msg.type == aiohttp.WSMsgType.CLOSED: + self.logger.warning("Signaling server closed connection") + break + elif msg.type == aiohttp.WSMsgType.ERROR: + self.logger.error("WebSocket error: %s", self._signaling_ws.exception()) + break + else: + self.logger.warning("Unexpected WebSocket message type: %s", msg.type) + + self.logger.info( + "Message loop exited - WebSocket closed: %s", self._signaling_ws.closed + ) + except TimeoutError: + self.logger.error("Timeout connecting to signaling server") + except aiohttp.ClientError as err: + self.logger.error("Failed to connect to signaling server: %s", err) + except Exception: + self.logger.exception("Unexpected error in signaling connection") + finally: + self._is_connected = False + if self._signaling_session: + await self._signaling_session.close() + self._signaling_session = None + self._signaling_ws = None + + async def _register(self) -> None: + """Register with the signaling server.""" + if self._signaling_ws: + registration_msg = { + "type": "register-server", + "remoteId": self.remote_id, + } + self.logger.info( + "Sending registration to signaling server with Remote ID: %s", + self.remote_id, + ) + self.logger.debug("Registration message: %s", registration_msg) + await self._signaling_ws.send_json(registration_msg) + self.logger.debug("Registration message sent successfully") + else: + self.logger.warning("Cannot register: signaling websocket is not connected") + + async def _handle_signaling_message(self, message: dict[str, Any]) -> None: + """Handle incoming signaling messages. + + :param message: The signaling message. + """ + msg_type = message.get("type") + self.logger.debug("Received signaling message: %s - Full message: %s", msg_type, message) + + if msg_type == "ping": + # Respond to ping with pong + if self._signaling_ws: + await self._signaling_ws.send_json({"type": "pong"}) + elif msg_type == "pong": + # Server responded to our ping, connection is alive + pass + elif msg_type == "registered": + self.logger.info("Registered with signaling server as: %s", message.get("remoteId")) + if self.on_remote_id_ready: + self.on_remote_id_ready(self.remote_id) + elif msg_type == "error": + self.logger.error( + "Signaling server error: %s", + message.get("message", "Unknown error"), + ) + elif msg_type == "client-connected": + session_id = message.get("sessionId") + if session_id: + await self._create_session(session_id) + elif msg_type == "client-disconnected": + session_id = message.get("sessionId") + if session_id: + await self._close_session(session_id) + elif msg_type == "offer": + session_id = message.get("sessionId") + offer_data = message.get("data") + if session_id and offer_data: + await self._handle_offer(session_id, offer_data) + elif msg_type == "ice-candidate": + session_id = message.get("sessionId") + candidate_data = message.get("data") + if session_id and candidate_data: + await self._handle_ice_candidate(session_id, candidate_data) + + async def _create_session(self, session_id: str) -> None: + """Create a new WebRTC session. + + :param session_id: The session ID. + """ + config = RTCConfiguration( + iceServers=[RTCIceServer(**server) for server in self.ice_servers] + ) + pc = RTCPeerConnection(configuration=config) + session = WebRTCSession(session_id=session_id, peer_connection=pc) + self.sessions[session_id] = session + + @pc.on("datachannel") + def on_datachannel(channel: Any) -> None: + session.data_channel = channel + asyncio.create_task(self._setup_data_channel(session)) + + @pc.on("icecandidate") + async def on_icecandidate(candidate: Any) -> None: + if candidate and self._signaling_ws: + await self._signaling_ws.send_json( + { + "type": "ice-candidate", + "sessionId": session_id, + "data": { + "candidate": candidate.candidate, + "sdpMid": candidate.sdpMid, + "sdpMLineIndex": candidate.sdpMLineIndex, + }, + } + ) + + @pc.on("connectionstatechange") + async def on_connectionstatechange() -> None: + if pc.connectionState == "failed": + await self._close_session(session_id) + + async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None: + """Handle incoming WebRTC offer. + + :param session_id: The session ID. + :param offer: The offer data. + """ + session = self.sessions.get(session_id) + if not session: + return + pc = session.peer_connection + + # Check if peer connection is already closed or closing + if pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Ignoring offer for session %s - connection state: %s", + session_id, + pc.connectionState, + ) + return + + sdp = offer.get("sdp") + sdp_type = offer.get("type") + if not sdp or not sdp_type: + self.logger.error("Invalid offer data: missing sdp or type") + return + + try: + await pc.setRemoteDescription( + RTCSessionDescription( + sdp=str(sdp), + type=str(sdp_type), + ) + ) + + # Check again if session was closed during setRemoteDescription + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed during setRemoteDescription, aborting offer handling", + session_id, + ) + return + + answer = await pc.createAnswer() + + # Check again before setLocalDescription + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed during createAnswer, aborting offer handling", + session_id, + ) + return + + await pc.setLocalDescription(answer) + + # Final check before sending answer + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed during setLocalDescription, skipping answer transmission", + session_id, + ) + return + + if self._signaling_ws: + await self._signaling_ws.send_json( + { + "type": "answer", + "sessionId": session_id, + "data": { + "sdp": pc.localDescription.sdp, + "type": pc.localDescription.type, + }, + } + ) + except Exception: + self.logger.exception("Error handling offer for session %s", session_id) + # Clean up the session on error + await self._close_session(session_id) + + async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None: + """Handle incoming ICE candidate. + + :param session_id: The session ID. + :param candidate: The ICE candidate data. + """ + session = self.sessions.get(session_id) + if not session or not candidate: + return + + # Check if peer connection is already closed or closing + pc = session.peer_connection + if pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Ignoring ICE candidate for session %s - connection state: %s", + session_id, + pc.connectionState, + ) + return + + candidate_str = candidate.get("candidate") + sdp_mid = candidate.get("sdpMid") + sdp_mline_index = candidate.get("sdpMLineIndex") + + if not candidate_str: + return + + # Create RTCIceCandidate from the SDP string + try: + ice_candidate = RTCIceCandidate( + component=1, + foundation="", + ip="", + port=0, + priority=0, + protocol="udp", + type="host", + sdpMid=str(sdp_mid) if sdp_mid else None, + sdpMLineIndex=int(sdp_mline_index) if sdp_mline_index is not None else None, + ) + # Parse the candidate string to populate the fields + ice_candidate.candidate = str(candidate_str) # type: ignore[attr-defined] + + # Check if session was closed before adding candidate + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed before adding ICE candidate, skipping", + session_id, + ) + return + + await session.peer_connection.addIceCandidate(ice_candidate) + except Exception: + self.logger.exception( + "Failed to add ICE candidate for session %s: %s", session_id, candidate + ) + + async def _setup_data_channel(self, session: WebRTCSession) -> None: + """Set up data channel and bridge to local WebSocket. + + :param session: The WebRTC session. + """ + channel = session.data_channel + if not channel: + return + local_session = aiohttp.ClientSession() + try: + session.local_ws = await local_session.ws_connect(self.local_ws_url) + loop = asyncio.get_event_loop() + + # Store task references for proper cleanup + session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session)) + session.forward_from_local_task = asyncio.create_task( + self._forward_from_local(session, local_session) + ) + + @channel.on("message") # type: ignore[misc] + def on_message(message: str) -> None: + # Called from aiortc thread, use call_soon_threadsafe + # Only queue message if session is still active + if session.forward_to_local_task and not session.forward_to_local_task.done(): + loop.call_soon_threadsafe(session.message_queue.put_nowait, message) + + @channel.on("close") # type: ignore[misc] + def on_close() -> None: + # Called from aiortc thread, use call_soon_threadsafe to schedule task + asyncio.run_coroutine_threadsafe(self._close_session(session.session_id), loop) + + except Exception: + self.logger.exception("Failed to connect to local WebSocket") + await local_session.close() + + async def _forward_to_local(self, session: WebRTCSession) -> None: + """Forward messages from WebRTC DataChannel to local WebSocket. + + :param session: The WebRTC session. + """ + try: + while session.local_ws and not session.local_ws.closed: + message = await session.message_queue.get() + + # Check if this is an HTTP proxy request + try: + msg_data = json.loads(message) + if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request": + # Handle HTTP proxy request + await self._handle_http_proxy_request(session, msg_data) + continue + except (json.JSONDecodeError, ValueError): + pass + + # Regular WebSocket message + if session.local_ws and not session.local_ws.closed: + await session.local_ws.send_str(message) + except asyncio.CancelledError: + # Task was cancelled during cleanup, this is expected + self.logger.debug("Forward to local task cancelled for session %s", session.session_id) + raise + except Exception: + self.logger.exception("Error forwarding to local WebSocket") + + async def _forward_from_local( + self, session: WebRTCSession, local_session: aiohttp.ClientSession + ) -> None: + """Forward messages from local WebSocket to WebRTC DataChannel. + + :param session: The WebRTC session. + :param local_session: The aiohttp client session. + """ + try: + async for msg in session.local_ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if session.data_channel and session.data_channel.readyState == "open": + session.data_channel.send(msg.data) + elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): + break + except asyncio.CancelledError: + # Task was cancelled during cleanup, this is expected + self.logger.debug( + "Forward from local task cancelled for session %s", session.session_id + ) + raise + except Exception: + self.logger.exception("Error forwarding from local WebSocket") + finally: + await local_session.close() + + async def _handle_http_proxy_request( + self, session: WebRTCSession, request_data: dict[str, Any] + ) -> None: + """Handle HTTP proxy request from remote client. + + :param session: The WebRTC session. + :param request_data: The HTTP proxy request data. + """ + request_id = request_data.get("id") + method = request_data.get("method", "GET") + path = request_data.get("path", "/") + headers = request_data.get("headers", {}) + + # Build local HTTP URL + # Extract host and port from local_ws_url (ws://localhost:8095/ws) + ws_url_parts = self.local_ws_url.replace("ws://", "").split("/") + host_port = ws_url_parts[0] # localhost:8095 + local_http_url = f"http://{host_port}{path}" + + self.logger.debug("HTTP proxy request: %s %s", method, local_http_url) + + try: + # Create a new HTTP client session for this request + async with ( + aiohttp.ClientSession() as http_session, + http_session.request(method, local_http_url, headers=headers) as response, + ): + # Read response body + body = await response.read() + + # Prepare response data + response_data = { + "type": "http-proxy-response", + "id": request_id, + "status": response.status, + "headers": dict(response.headers), + "body": body.hex(), # Send as hex string to avoid encoding issues + } + + # Send response back through data channel + if session.data_channel and session.data_channel.readyState == "open": + session.data_channel.send(json.dumps(response_data)) + + except Exception as err: + self.logger.exception("Error handling HTTP proxy request") + # Send error response + error_response = { + "type": "http-proxy-response", + "id": request_id, + "status": 500, + "headers": {"Content-Type": "text/plain"}, + "body": str(err).encode().hex(), + } + if session.data_channel and session.data_channel.readyState == "open": + session.data_channel.send(json.dumps(error_response)) + + async def _close_session(self, session_id: str) -> None: + """Close a WebRTC session. + + :param session_id: The session ID. + """ + session = self.sessions.pop(session_id, None) + if not session: + return + + # Cancel forwarding tasks first to prevent race conditions + if session.forward_to_local_task and not session.forward_to_local_task.done(): + session.forward_to_local_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await session.forward_to_local_task + + if session.forward_from_local_task and not session.forward_from_local_task.done(): + session.forward_from_local_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await session.forward_from_local_task + + # Close connections + if session.local_ws and not session.local_ws.closed: + await session.local_ws.close() + if session.data_channel: + session.data_channel.close() + await session.peer_connection.close() diff --git a/music_assistant/controllers/webserver/controller.py b/music_assistant/controllers/webserver/controller.py index 0aa8aa8e..e46c00e3 100644 --- a/music_assistant/controllers/webserver/controller.py +++ b/music_assistant/controllers/webserver/controller.py @@ -56,7 +56,6 @@ from .helpers.auth_middleware import ( set_current_user, ) from .helpers.auth_providers import BuiltinLoginProvider -from .remote_access import RemoteAccessManager from .websocket_client import WebsocketClientHandler if TYPE_CHECKING: @@ -92,7 +91,6 @@ class WebserverController(CoreController): ) self.manifest.icon = "web-box" self.auth = AuthenticationManager(self) - self.remote_access = RemoteAccessManager(self) @property def base_url(self) -> str: @@ -199,15 +197,6 @@ class WebserverController(CoreController): category="advanced", hidden=not any(provider.domain == "hass" for provider in self.mass.providers), ), - ConfigEntry( - key="remote_id", - type=ConfigEntryType.STRING, - label="Remote ID", - description="Unique identifier for WebRTC remote access. " - "Generated automatically and should not be changed.", - required=False, - hidden=True, - ), ] ) @@ -242,6 +231,7 @@ class WebserverController(CoreController): routes.append(("OPTIONS", "/info", self._handle_cors_preflight)) # add logging routes.append(("GET", "/music-assistant.log", self._handle_application_log)) + routes.append(("OPTIONS", "/music-assistant.log", self._handle_cors_preflight)) # add websocket api routes.append(("GET", "/ws", self._handle_ws_client)) # also host the image proxy on the webserver @@ -275,10 +265,7 @@ class WebserverController(CoreController): # add first-time setup routes routes.append(("GET", "/setup", self._handle_setup_page)) routes.append(("POST", "/setup", self._handle_setup)) - # Initialize authentication manager await self.auth.setup() - # Initialize remote access manager - await self.remote_access.setup() # start the webserver all_ip_addresses = await get_ip_addresses() default_publish_ip = all_ip_addresses[0] @@ -391,7 +378,6 @@ class WebserverController(CoreController): await client.disconnect() await self._server.close() await self.auth.close() - await self.remote_access.close() def register_websocket_client(self, client: WebsocketClientHandler) -> None: """Register a WebSocket client for tracking.""" @@ -561,7 +547,15 @@ class WebserverController(CoreController): async def _handle_application_log(self, request: web.Request) -> web.Response: """Handle request to get the application log.""" log_data = await self.mass.get_application_log() - return web.Response(text=log_data, content_type="text/text") + return web.Response( + text=log_data, + content_type="text/text", + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization", + }, + ) async def _handle_api_intro(self, request: web.Request) -> web.Response: """Handle request for API introduction/documentation page.""" diff --git a/music_assistant/controllers/webserver/remote_access/__init__.py b/music_assistant/controllers/webserver/remote_access/__init__.py deleted file mode 100644 index 5948906e..00000000 --- a/music_assistant/controllers/webserver/remote_access/__init__.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Remote Access manager for Music Assistant webserver.""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -from music_assistant_models.enums import EventType - -from music_assistant.constants import MASS_LOGGER_NAME -from music_assistant.controllers.webserver.remote_access.gateway import ( - WebRTCGateway, - generate_remote_id, -) -from music_assistant.helpers.api import api_command - -if TYPE_CHECKING: - from music_assistant_models.event import MassEvent - - from music_assistant.controllers.webserver import WebserverController - -LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access") - -# Signaling server URL -SIGNALING_SERVER_URL = "wss://signaling.music-assistant.io/ws" - -# Internal config key for storing the remote ID -_CONF_REMOTE_ID = "remote_id" - - -class RemoteAccessManager: - """Manager for WebRTC-based remote access (part of webserver controller).""" - - def __init__(self, webserver: WebserverController) -> None: - """Initialize the remote access manager. - - :param webserver: WebserverController instance. - """ - self.webserver = webserver - self.mass = webserver.mass - self.logger = LOGGER - self.gateway: WebRTCGateway | None = None - self._remote_id: str | None = None - self._setup_done = False - - async def setup(self) -> None: - """Initialize the remote access manager.""" - self.logger.debug("RemoteAccessManager.setup() called") - - # Register API commands immediately - self._register_api_commands() - - # Subscribe to provider updates to check for Home Assistant Cloud - self.mass.subscribe(self._on_providers_updated, EventType.PROVIDERS_UPDATED) - - # Try initial setup (providers might already be loaded) - await self._try_enable_remote_access() - - async def close(self) -> None: - """Cleanup on exit.""" - if self.gateway: - await self.gateway.stop() - self.gateway = None - self.logger.info("WebRTC Remote Access stopped") - - def _on_remote_id_ready(self, remote_id: str) -> None: - """Handle Remote ID registration with signaling server. - - :param remote_id: The registered Remote ID. - """ - self.logger.info("Remote ID registered with signaling server: %s", remote_id) - self._remote_id = remote_id - - def _on_providers_updated(self, event: MassEvent) -> None: - """Handle providers updated event. - - :param event: The providers updated event. - """ - if self._setup_done: - # Already set up, no need to check again - return - # Try to enable remote access when providers are updated - self.mass.create_task(self._try_enable_remote_access()) - - async def _try_enable_remote_access(self) -> None: - """Try to enable remote access if Home Assistant Cloud is available.""" - if self._setup_done: - # Already set up - return - - # Check if Home Assistant Cloud is available and active - cloud_status = await self._check_ha_cloud_status() - if not cloud_status: - self.logger.debug("Home Assistant Cloud not available yet") - return - - # Mark as done to prevent multiple attempts - self._setup_done = True - self.logger.info("Home Assistant Cloud subscription detected, enabling remote access") - - # Get or generate Remote ID - remote_id_value = self.webserver.config.get_value(_CONF_REMOTE_ID) - if not remote_id_value: - # Generate new Remote ID and save it - remote_id_value = generate_remote_id() - # Save the Remote ID to config - self.mass.config.set_raw_core_config_value( - "webserver", _CONF_REMOTE_ID, remote_id_value - ) - self.mass.config.save(immediate=True) - self.logger.info("Generated new Remote ID: %s", remote_id_value) - - # Ensure remote_id is a string - if isinstance(remote_id_value, str): - self._remote_id = remote_id_value - else: - self.logger.error("Invalid remote_id type: %s", type(remote_id_value)) - return - - # Determine local WebSocket URL - bind_port_value = self.webserver.config.get_value("bind_port", 8095) - bind_port = int(bind_port_value) if isinstance(bind_port_value, int) else 8095 - local_ws_url = f"ws://localhost:{bind_port}/ws" - - # Get ICE servers from HA Cloud if available - ice_servers = await self._get_ha_cloud_ice_servers() - - # Initialize and start the WebRTC gateway - self.gateway = WebRTCGateway( - signaling_url=SIGNALING_SERVER_URL, - local_ws_url=local_ws_url, - remote_id=self._remote_id, - on_remote_id_ready=self._on_remote_id_ready, - ice_servers=ice_servers, - ) - - await self.gateway.start() - self.logger.info("WebRTC Remote Access enabled - Remote ID: %s", self._remote_id) - - async def _check_ha_cloud_status(self) -> bool: - """Check if Home Assistant Cloud subscription is active. - - :return: True if HA Cloud is logged in and has active subscription. - """ - # Find the Home Assistant provider - ha_provider = None - for provider in self.mass.providers: - if provider.domain == "hass" and provider.available: - ha_provider = provider - break - - if not ha_provider: - self.logger.debug("Home Assistant provider not found or not available") - return False - - try: - # Access the hass client from the provider - if not hasattr(ha_provider, "hass"): - self.logger.debug("Provider does not have hass attribute") - return False - hass_client = ha_provider.hass # type: ignore[union-attr] - if not hass_client or not hass_client.connected: - self.logger.debug("Home Assistant client not connected") - return False - - # Call cloud/status command to check subscription - result = await hass_client.send_command("cloud/status") - - # Check for logged_in and active_subscription - logged_in = result.get("logged_in", False) - active_subscription = result.get("active_subscription", False) - - if logged_in and active_subscription: - self.logger.info("Home Assistant Cloud subscription is active") - return True - - self.logger.info( - "Home Assistant Cloud not active (logged_in=%s, active_subscription=%s)", - logged_in, - active_subscription, - ) - return False - - except Exception: - self.logger.exception("Error checking Home Assistant Cloud status") - return False - - async def _get_ha_cloud_ice_servers(self) -> list[dict[str, str]] | None: - """Get ICE servers from Home Assistant Cloud. - - :return: List of ICE server configurations or None if unavailable. - """ - # Find the Home Assistant provider - ha_provider = None - for provider in self.mass.providers: - if provider.domain == "hass" and provider.available: - ha_provider = provider - break - - if not ha_provider: - return None - - try: - # Access the hass client from the provider - if not hasattr(ha_provider, "hass"): - return None - hass_client = ha_provider.hass # type: ignore[union-attr] - if not hass_client or not hass_client.connected: - return None - - # Try to get ICE servers from HA Cloud - # This might be available via a cloud API endpoint - # For now, return None and use default STUN servers - # TODO: Research if HA Cloud exposes ICE/TURN server endpoints - self.logger.debug( - "Using default STUN servers (HA Cloud ICE servers not yet implemented)" - ) - return None - - except Exception: - self.logger.exception("Error getting Home Assistant Cloud ICE servers") - return None - - @property - def is_enabled(self) -> bool: - """Return whether WebRTC remote access is enabled.""" - return self.gateway is not None and self.gateway.is_running - - @property - def is_connected(self) -> bool: - """Return whether the gateway is connected to the signaling server.""" - return self.gateway is not None and self.gateway.is_connected - - @property - def remote_id(self) -> str | None: - """Return the current Remote ID.""" - return self._remote_id - - def _register_api_commands(self) -> None: - """Register API commands for remote access.""" - - @api_command("remote_access/info") - def get_remote_access_info() -> dict[str, str | bool]: - """Get remote access information. - - Returns information about the remote access configuration including - whether it's enabled, connected status, and the Remote ID for connecting. - """ - return { - "enabled": self.is_enabled, - "connected": self.is_connected, - "remote_id": self._remote_id or "", - "signaling_url": SIGNALING_SERVER_URL, - } diff --git a/music_assistant/controllers/webserver/remote_access/gateway.py b/music_assistant/controllers/webserver/remote_access/gateway.py deleted file mode 100644 index 47b8ecd7..00000000 --- a/music_assistant/controllers/webserver/remote_access/gateway.py +++ /dev/null @@ -1,566 +0,0 @@ -"""Music Assistant WebRTC Gateway. - -This module provides WebRTC-based remote access to Music Assistant instances. -It connects to a signaling server and handles incoming WebRTC connections, -bridging them to the local WebSocket API. -""" - -from __future__ import annotations - -import asyncio -import contextlib -import json -import logging -import secrets -import string -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any - -import aiohttp -from aiortc import ( - RTCConfiguration, - RTCIceCandidate, - RTCIceServer, - RTCPeerConnection, - RTCSessionDescription, -) - -from music_assistant.constants import MASS_LOGGER_NAME - -if TYPE_CHECKING: - from collections.abc import Callable - -LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access") - -# Reduce verbose logging from aiortc/aioice -logging.getLogger("aioice").setLevel(logging.WARNING) -logging.getLogger("aiortc").setLevel(logging.WARNING) - - -def generate_remote_id() -> str: - """Generate a unique Remote ID in the format MA-XXXX-XXXX.""" - chars = string.ascii_uppercase + string.digits - part1 = "".join(secrets.choice(chars) for _ in range(4)) - part2 = "".join(secrets.choice(chars) for _ in range(4)) - return f"MA-{part1}-{part2}" - - -@dataclass -class WebRTCSession: - """Represents an active WebRTC session with a remote client.""" - - session_id: str - peer_connection: RTCPeerConnection - data_channel: Any = None - local_ws: Any = None - message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue) - forward_to_local_task: asyncio.Task[None] | None = None - forward_from_local_task: asyncio.Task[None] | None = None - - -class WebRTCGateway: - """WebRTC Gateway for Music Assistant Remote Access. - - This gateway: - 1. Connects to a signaling server - 2. Registers with a unique Remote ID - 3. Handles incoming WebRTC connections from remote PWA clients - 4. Bridges WebRTC DataChannel messages to the local WebSocket API - """ - - def __init__( - self, - signaling_url: str = "wss://signaling.music-assistant.io/ws", - local_ws_url: str = "ws://localhost:8095/ws", - ice_servers: list[dict[str, Any]] | None = None, - remote_id: str | None = None, - on_remote_id_ready: Callable[[str], None] | None = None, - ) -> None: - """Initialize the WebRTC Gateway. - - :param signaling_url: WebSocket URL of the signaling server. - :param local_ws_url: Local WebSocket URL to bridge to. - :param ice_servers: List of ICE server configurations. - :param remote_id: Optional Remote ID to use (generated if not provided). - :param on_remote_id_ready: Callback when Remote ID is registered. - """ - self.signaling_url = signaling_url - self.local_ws_url = local_ws_url - self.remote_id = remote_id or generate_remote_id() - self.on_remote_id_ready = on_remote_id_ready - self.logger = LOGGER - - self.ice_servers = ice_servers or [ - {"urls": "stun:stun.l.google.com:19302"}, - {"urls": "stun:stun1.l.google.com:19302"}, - {"urls": "stun:stun.cloudflare.com:3478"}, - ] - - self.sessions: dict[str, WebRTCSession] = {} - self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None - self._signaling_session: aiohttp.ClientSession | None = None - self._running = False - self._reconnect_delay = 5 - self._max_reconnect_delay = 60 - self._current_reconnect_delay = 5 - self._run_task: asyncio.Task[None] | None = None - self._is_connected = False - - @property - def is_running(self) -> bool: - """Return whether the gateway is running.""" - return self._running - - @property - def is_connected(self) -> bool: - """Return whether the gateway is connected to the signaling server.""" - return self._is_connected - - async def start(self) -> None: - """Start the WebRTC Gateway.""" - self.logger.info("Starting WebRTC Gateway with Remote ID: %s", self.remote_id) - self.logger.debug("Signaling URL: %s", self.signaling_url) - self.logger.debug("Local WS URL: %s", self.local_ws_url) - self._running = True - self._run_task = asyncio.create_task(self._run()) - self.logger.debug("WebRTC Gateway start task created") - - async def stop(self) -> None: - """Stop the WebRTC Gateway.""" - self.logger.info("Stopping WebRTC Gateway") - self._running = False - - # Close all sessions - for session_id in list(self.sessions.keys()): - await self._close_session(session_id) - - # Close signaling connection - if self._signaling_ws: - await self._signaling_ws.close() - if self._signaling_session: - await self._signaling_session.close() - - # Cancel run task - if self._run_task and not self._run_task.done(): - self._run_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._run_task - - async def _run(self) -> None: - """Run the main loop with reconnection logic.""" - self.logger.debug("WebRTC Gateway _run() loop starting") - while self._running: - try: - await self._connect_to_signaling() - # Connection closed gracefully or with error - self._is_connected = False - if self._running: - self.logger.warning( - "Signaling server connection lost. Reconnecting in %ss...", - self._current_reconnect_delay, - ) - except Exception: - self._is_connected = False - self.logger.exception("Signaling connection error") - if self._running: - self.logger.info( - "Reconnecting to signaling server in %ss", - self._current_reconnect_delay, - ) - - if self._running: - await asyncio.sleep(self._current_reconnect_delay) - # Exponential backoff with max limit - self._current_reconnect_delay = min( - self._current_reconnect_delay * 2, self._max_reconnect_delay - ) - - async def _connect_to_signaling(self) -> None: - """Connect to the signaling server.""" - self.logger.info("Connecting to signaling server: %s", self.signaling_url) - # Create session with increased timeout for WebSocket connection - timeout = aiohttp.ClientTimeout( - total=None, # No total timeout for WebSocket connections - connect=30, # 30 seconds to establish connection - sock_connect=30, # 30 seconds for socket connection - sock_read=None, # No timeout for reading (we handle ping/pong) - ) - self._signaling_session = aiohttp.ClientSession(timeout=timeout) - try: - self._signaling_ws = await self._signaling_session.ws_connect( - self.signaling_url, - heartbeat=45, # Send WebSocket ping every 45 seconds - autoping=True, # Automatically respond to pings - ) - await self._register() - self._is_connected = True - # Reset reconnect delay on successful connection - self._current_reconnect_delay = self._reconnect_delay - self.logger.info("Connected to signaling server") - - # Message loop - async for msg in self._signaling_ws: - if msg.type == aiohttp.WSMsgType.TEXT: - try: - await self._handle_signaling_message(json.loads(msg.data)) - except Exception: - self.logger.exception("Error handling signaling message") - elif msg.type == aiohttp.WSMsgType.CLOSED: - self.logger.warning("Signaling server closed connection: %s", msg.extra) - break - elif msg.type == aiohttp.WSMsgType.ERROR: - self.logger.error("WebSocket error: %s", msg.extra) - break - - self.logger.info("Message loop exited normally") - except TimeoutError: - self.logger.error("Timeout connecting to signaling server") - except aiohttp.ClientError as err: - self.logger.error("Failed to connect to signaling server: %s", err) - finally: - self._is_connected = False - if self._signaling_session: - await self._signaling_session.close() - self._signaling_session = None - self._signaling_ws = None - - async def _register(self) -> None: - """Register with the signaling server.""" - if self._signaling_ws: - registration_msg = { - "type": "register-server", - "remoteId": self.remote_id, - } - self.logger.info( - "Sending registration to signaling server with Remote ID: %s", - self.remote_id, - ) - self.logger.debug("Registration message: %s", registration_msg) - await self._signaling_ws.send_json(registration_msg) - self.logger.debug("Registration message sent successfully") - else: - self.logger.warning("Cannot register: signaling websocket is not connected") - - async def _handle_signaling_message(self, message: dict[str, Any]) -> None: - """Handle incoming signaling messages. - - :param message: The signaling message. - """ - msg_type = message.get("type") - self.logger.debug("Received signaling message: %s - Full message: %s", msg_type, message) - - if msg_type == "ping": - # Respond to ping with pong - if self._signaling_ws: - await self._signaling_ws.send_json({"type": "pong"}) - elif msg_type == "pong": - # Server responded to our ping, connection is alive - pass - elif msg_type == "registered": - self.logger.info("Registered with signaling server as: %s", message.get("remoteId")) - if self.on_remote_id_ready: - self.on_remote_id_ready(self.remote_id) - elif msg_type == "error": - self.logger.error( - "Signaling server error: %s", - message.get("message", "Unknown error"), - ) - elif msg_type == "client-connected": - session_id = message.get("sessionId") - if session_id: - await self._create_session(session_id) - elif msg_type == "client-disconnected": - session_id = message.get("sessionId") - if session_id: - await self._close_session(session_id) - elif msg_type == "offer": - session_id = message.get("sessionId") - offer_data = message.get("data") - if session_id and offer_data: - await self._handle_offer(session_id, offer_data) - elif msg_type == "ice-candidate": - session_id = message.get("sessionId") - candidate_data = message.get("data") - if session_id and candidate_data: - await self._handle_ice_candidate(session_id, candidate_data) - - async def _create_session(self, session_id: str) -> None: - """Create a new WebRTC session. - - :param session_id: The session ID. - """ - config = RTCConfiguration( - iceServers=[RTCIceServer(**server) for server in self.ice_servers] - ) - pc = RTCPeerConnection(configuration=config) - session = WebRTCSession(session_id=session_id, peer_connection=pc) - self.sessions[session_id] = session - - @pc.on("datachannel") - def on_datachannel(channel: Any) -> None: - session.data_channel = channel - asyncio.create_task(self._setup_data_channel(session)) - - @pc.on("icecandidate") - async def on_icecandidate(candidate: Any) -> None: - if candidate and self._signaling_ws: - await self._signaling_ws.send_json( - { - "type": "ice-candidate", - "sessionId": session_id, - "data": { - "candidate": candidate.candidate, - "sdpMid": candidate.sdpMid, - "sdpMLineIndex": candidate.sdpMLineIndex, - }, - } - ) - - @pc.on("connectionstatechange") - async def on_connectionstatechange() -> None: - if pc.connectionState == "failed": - await self._close_session(session_id) - - async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None: - """Handle incoming WebRTC offer. - - :param session_id: The session ID. - :param offer: The offer data. - """ - session = self.sessions.get(session_id) - if not session: - return - pc = session.peer_connection - sdp = offer.get("sdp") - sdp_type = offer.get("type") - if not sdp or not sdp_type: - self.logger.error("Invalid offer data: missing sdp or type") - return - await pc.setRemoteDescription( - RTCSessionDescription( - sdp=str(sdp), - type=str(sdp_type), - ) - ) - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - if self._signaling_ws: - await self._signaling_ws.send_json( - { - "type": "answer", - "sessionId": session_id, - "data": { - "sdp": pc.localDescription.sdp, - "type": pc.localDescription.type, - }, - } - ) - - async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None: - """Handle incoming ICE candidate. - - :param session_id: The session ID. - :param candidate: The ICE candidate data. - """ - session = self.sessions.get(session_id) - if not session or not candidate: - return - - candidate_str = candidate.get("candidate") - sdp_mid = candidate.get("sdpMid") - sdp_mline_index = candidate.get("sdpMLineIndex") - - if not candidate_str: - return - - # Create RTCIceCandidate from the SDP string - try: - ice_candidate = RTCIceCandidate( - component=1, - foundation="", - ip="", - port=0, - priority=0, - protocol="udp", - type="host", - sdpMid=str(sdp_mid) if sdp_mid else None, - sdpMLineIndex=int(sdp_mline_index) if sdp_mline_index is not None else None, - ) - # Parse the candidate string to populate the fields - ice_candidate.candidate = str(candidate_str) # type: ignore[attr-defined] - await session.peer_connection.addIceCandidate(ice_candidate) - except Exception: - self.logger.exception("Failed to add ICE candidate: %s", candidate) - - async def _setup_data_channel(self, session: WebRTCSession) -> None: - """Set up data channel and bridge to local WebSocket. - - :param session: The WebRTC session. - """ - channel = session.data_channel - if not channel: - return - local_session = aiohttp.ClientSession() - try: - session.local_ws = await local_session.ws_connect(self.local_ws_url) - loop = asyncio.get_event_loop() - - # Store task references for proper cleanup - session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session)) - session.forward_from_local_task = asyncio.create_task( - self._forward_from_local(session, local_session) - ) - - @channel.on("message") # type: ignore[misc] - def on_message(message: str) -> None: - # Called from aiortc thread, use call_soon_threadsafe - # Only queue message if session is still active - if session.forward_to_local_task and not session.forward_to_local_task.done(): - loop.call_soon_threadsafe(session.message_queue.put_nowait, message) - - @channel.on("close") # type: ignore[misc] - def on_close() -> None: - # Called from aiortc thread, use call_soon_threadsafe to schedule task - asyncio.run_coroutine_threadsafe(self._close_session(session.session_id), loop) - - except Exception: - self.logger.exception("Failed to connect to local WebSocket") - await local_session.close() - - async def _forward_to_local(self, session: WebRTCSession) -> None: - """Forward messages from WebRTC DataChannel to local WebSocket. - - :param session: The WebRTC session. - """ - try: - while session.local_ws and not session.local_ws.closed: - message = await session.message_queue.get() - - # Check if this is an HTTP proxy request - try: - msg_data = json.loads(message) - if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request": - # Handle HTTP proxy request - await self._handle_http_proxy_request(session, msg_data) - continue - except (json.JSONDecodeError, ValueError): - pass - - # Regular WebSocket message - if session.local_ws and not session.local_ws.closed: - await session.local_ws.send_str(message) - except asyncio.CancelledError: - # Task was cancelled during cleanup, this is expected - self.logger.debug("Forward to local task cancelled for session %s", session.session_id) - raise - except Exception: - self.logger.exception("Error forwarding to local WebSocket") - - async def _forward_from_local( - self, session: WebRTCSession, local_session: aiohttp.ClientSession - ) -> None: - """Forward messages from local WebSocket to WebRTC DataChannel. - - :param session: The WebRTC session. - :param local_session: The aiohttp client session. - """ - try: - async for msg in session.local_ws: - if msg.type == aiohttp.WSMsgType.TEXT: - if session.data_channel and session.data_channel.readyState == "open": - session.data_channel.send(msg.data) - elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): - break - except asyncio.CancelledError: - # Task was cancelled during cleanup, this is expected - self.logger.debug( - "Forward from local task cancelled for session %s", session.session_id - ) - raise - except Exception: - self.logger.exception("Error forwarding from local WebSocket") - finally: - await local_session.close() - - async def _handle_http_proxy_request( - self, session: WebRTCSession, request_data: dict[str, Any] - ) -> None: - """Handle HTTP proxy request from remote client. - - :param session: The WebRTC session. - :param request_data: The HTTP proxy request data. - """ - request_id = request_data.get("id") - method = request_data.get("method", "GET") - path = request_data.get("path", "/") - headers = request_data.get("headers", {}) - - # Build local HTTP URL - # Extract host and port from local_ws_url (ws://localhost:8095/ws) - ws_url_parts = self.local_ws_url.replace("ws://", "").split("/") - host_port = ws_url_parts[0] # localhost:8095 - local_http_url = f"http://{host_port}{path}" - - self.logger.debug("HTTP proxy request: %s %s", method, local_http_url) - - try: - # Create a new HTTP client session for this request - async with ( - aiohttp.ClientSession() as http_session, - http_session.request(method, local_http_url, headers=headers) as response, - ): - # Read response body - body = await response.read() - - # Prepare response data - response_data = { - "type": "http-proxy-response", - "id": request_id, - "status": response.status, - "headers": dict(response.headers), - "body": body.hex(), # Send as hex string to avoid encoding issues - } - - # Send response back through data channel - if session.data_channel and session.data_channel.readyState == "open": - session.data_channel.send(json.dumps(response_data)) - - except Exception as err: - self.logger.exception("Error handling HTTP proxy request") - # Send error response - error_response = { - "type": "http-proxy-response", - "id": request_id, - "status": 500, - "headers": {"Content-Type": "text/plain"}, - "body": str(err).encode().hex(), - } - if session.data_channel and session.data_channel.readyState == "open": - session.data_channel.send(json.dumps(error_response)) - - async def _close_session(self, session_id: str) -> None: - """Close a WebRTC session. - - :param session_id: The session ID. - """ - session = self.sessions.pop(session_id, None) - if not session: - return - - # Cancel forwarding tasks first to prevent race conditions - if session.forward_to_local_task and not session.forward_to_local_task.done(): - session.forward_to_local_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await session.forward_to_local_task - - if session.forward_from_local_task and not session.forward_from_local_task.done(): - session.forward_from_local_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await session.forward_from_local_task - - # Close connections - if session.local_ws and not session.local_ws.closed: - await session.local_ws.close() - if session.data_channel: - session.data_channel.close() - await session.peer_connection.close() diff --git a/music_assistant/mass.py b/music_assistant/mass.py index 3b04bc89..19d29bb6 100644 --- a/music_assistant/mass.py +++ b/music_assistant/mass.py @@ -43,6 +43,7 @@ from music_assistant.controllers.metadata import MetaDataController from music_assistant.controllers.music import MusicController from music_assistant.controllers.player_queues import PlayerQueuesController from music_assistant.controllers.players.player_controller import PlayerController +from music_assistant.controllers.remote_access import RemoteAccessController from music_assistant.controllers.streams import StreamsController from music_assistant.controllers.webserver import WebserverController from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user @@ -111,6 +112,7 @@ class MusicAssistant: players: PlayerController player_queues: PlayerQueuesController streams: StreamsController + remote_access: RemoteAccessController _aiobrowser: AsyncServiceBrowser def __init__(self, storage_path: str, cache_path: str, safe_mode: bool = False) -> None: @@ -166,6 +168,7 @@ class MusicAssistant: self.players = PlayerController(self) self.player_queues = PlayerQueuesController(self) self.streams = StreamsController(self) + self.remote_access = RemoteAccessController(self) # add manifests for core controllers for controller_name in CONFIGURABLE_CORE_CONTROLLERS: controller: CoreController = getattr(self, controller_name) @@ -181,6 +184,8 @@ class MusicAssistant: # not yet available while we're starting (or performing migrations) self._register_api_commands() await self.webserver.setup(await self.config.get_core_config("webserver")) + # setup remote access after webserver (it needs webserver's port) + await self.remote_access.setup(await self.config.get_core_config("remote_access")) # setup discovery await self._setup_discovery() # load providers @@ -202,6 +207,7 @@ class MusicAssistant: ) # stop core controllers await self.streams.close() + await self.remote_access.close() await self.webserver.close() await self.metadata.close() await self.music.close()