From: Marcel van der Veldt Date: Tue, 2 Dec 2025 15:56:02 +0000 (+0100) Subject: Remote access changes X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=5ee70bb503db8434181fc97404e348d0debc304e;p=music-assistant-server.git Remote access changes --- diff --git a/music_assistant/controllers/webserver/controller.py b/music_assistant/controllers/webserver/controller.py index e46c00e3..34b562ab 100644 --- a/music_assistant/controllers/webserver/controller.py +++ b/music_assistant/controllers/webserver/controller.py @@ -56,6 +56,7 @@ from .helpers.auth_middleware import ( set_current_user, ) from .helpers.auth_providers import BuiltinLoginProvider +from .remote_access import RemoteAccessManager from .websocket_client import WebsocketClientHandler if TYPE_CHECKING: @@ -91,6 +92,7 @@ class WebserverController(CoreController): ) self.manifest.icon = "web-box" self.auth = AuthenticationManager(self) + self.remote_access = RemoteAccessManager(self) @property def base_url(self) -> str: @@ -372,8 +374,12 @@ class WebserverController(CoreController): # announce to HA supervisor await self._announce_to_homeassistant() + # Setup remote access after webserver is running + await self.remote_access.setup() + async def close(self) -> None: """Cleanup on exit.""" + await self.remote_access.close() for client in set(self.clients): await client.disconnect() await self._server.close() diff --git a/music_assistant/controllers/webserver/remote_access/__init__.py b/music_assistant/controllers/webserver/remote_access/__init__.py new file mode 100644 index 00000000..281bd07d --- /dev/null +++ b/music_assistant/controllers/webserver/remote_access/__init__.py @@ -0,0 +1,287 @@ +""" +Remote Access subcomponent for the Webserver Controller. + +This module manages WebRTC-based remote access to Music Assistant instances. +It connects to a signaling server and handles incoming WebRTC connections, +bridging them to the local WebSocket API. + +Requires an active Home Assistant Cloud subscription due to STUN/TURN/SIGNALING server usage. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, cast + +from mashumaro import DataClassDictMixin +from music_assistant_models.enums import EventType +from music_assistant_models.errors import UnsupportedFeaturedException + +from music_assistant.constants import CONF_CORE +from music_assistant.controllers.webserver.remote_access.gateway import ( + WebRTCGateway, + generate_remote_id, +) +from music_assistant.helpers.api import api_command + +if TYPE_CHECKING: + from music_assistant_models.event import MassEvent + + from music_assistant.controllers.webserver import WebserverController + from music_assistant.providers.hass import HomeAssistantProvider + +# Signaling server URL +SIGNALING_SERVER_URL = "wss://signaling.music-assistant.io/ws" + +# Storage keys +CONF_KEY_MAIN = "remote_access" +CONF_REMOTE_ID = "remote_id" +CONF_ENABLED = "enabled" + + +@dataclass +class RemoteAccessInfo(DataClassDictMixin): + """Remote Access information dataclass.""" + + enabled: bool + running: bool + connected: bool + remote_id: str + ha_cloud_available: bool + signaling_url: str + + +class RemoteAccessManager: + """Manages WebRTC-based remote access for the webserver.""" + + def __init__(self, webserver: WebserverController) -> None: + """Initialize the remote access manager.""" + self.webserver = webserver + self.mass = webserver.mass + self.logger = webserver.logger.getChild("remote_access") + self.gateway: WebRTCGateway | None = None + self._remote_id: str | None = None + self._enabled: bool = False + self._ha_cloud_available: bool = False + + async def setup(self) -> None: + """Initialize the remote access manager.""" + # Load config from storage + enabled_value = self.mass.config.get(f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_ENABLED}", False) + self._enabled = bool(enabled_value) + remote_id_value = self.mass.config.get( + f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_REMOTE_ID}", None + ) + if not remote_id_value: + remote_id_value = generate_remote_id() + self.mass.config.set(f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_REMOTE_ID}", remote_id_value) + self.logger.debug("Generated new Remote ID: %s", remote_id_value) + + self._remote_id = str(remote_id_value) + self._register_api_commands() + # Subscribe to provider updates to check for Home Assistant Cloud + self._ha_cloud_available = await self._check_ha_cloud_status() + self.mass.subscribe( + self._on_providers_updated, EventType.PROVIDERS_UPDATED, id_filter="hass" + ) + if self._enabled and self._ha_cloud_available: + await self.start() + + async def close(self) -> None: + """Cleanup on exit.""" + await self.stop() + + async def start(self) -> None: + """Start the remote access gateway.""" + if self.is_running: + self.logger.debug("Remote access already running") + return + if not self._ha_cloud_available: + raise UnsupportedFeaturedException( + "Home Assistant Cloud subscription is required for remote access" + ) + if not self._enabled: + # should not happen, but guard anyway + self.logger.debug("Remote access is disabled in configuration") + return + + self.logger.info("Starting remote access with Remote ID: %s", self._remote_id) + + # Determine local WebSocket URL from webserver config + base_url = self.mass.webserver.base_url + local_ws_url = base_url.replace("http", "ws") + if not local_ws_url.endswith("/"): + local_ws_url += "/" + local_ws_url += "ws" + + # Get ICE servers from HA Cloud if available + ice_servers: list[dict[str, str]] | None = None + if await self._check_ha_cloud_status(): + self.logger.info( + "Home Assistant Cloud subscription detected, using HA cloud ICE servers" + ) + ice_servers = await self._get_ha_cloud_ice_servers() + else: + self.logger.info( + "Home Assistant Cloud subscription not detected, using default STUN servers" + ) + + # Initialize and start the WebRTC gateway + self.gateway = WebRTCGateway( + http_session=self.mass.http_session, + signaling_url=SIGNALING_SERVER_URL, + local_ws_url=local_ws_url, + remote_id=self._remote_id, + ice_servers=ice_servers, + ) + + await self.gateway.start() + self._enabled = True + + async def stop(self) -> None: + """Stop the remote access gateway.""" + if self.gateway: + await self.gateway.stop() + self.gateway = None + self.logger.debug("WebRTC Remote Access stopped") + + async def _on_providers_updated(self, event: MassEvent) -> None: + """ + Handle providers updated event. + + :param event: The providers updated event. + """ + last_ha_cloud_available = self._ha_cloud_available + self._ha_cloud_available = await self._check_ha_cloud_status() + if self._ha_cloud_available == last_ha_cloud_available: + return # No change in HA Cloud status + if self.is_running and not self._ha_cloud_available: + self.logger.warning( + "Home Assistant Cloud subscription is no longer active, stopping remote access" + ) + await self.stop() + return + allow_start = self._ha_cloud_available and self._enabled + if allow_start and self.is_running: + return # Already running + if allow_start: + self.mass.create_task(self.start()) + + async def _check_ha_cloud_status(self) -> bool: + """Check if Home Assistant Cloud subscription is active. + + :return: True if HA Cloud is logged in and has active subscription. + """ + # Find the Home Assistant provider + ha_provider = cast("HomeAssistantProvider | None", self.mass.get_provider("hass")) + if not ha_provider: + return False + + try: + # Access the hass client from the provider + hass_client = ha_provider.hass + if not hass_client or not hass_client.connected: + return False + + # Call cloud/status command to check subscription + result = await hass_client.send_command("cloud/status") + + # Check for logged_in and active_subscription + logged_in = result.get("logged_in", False) + active_subscription = result.get("active_subscription", False) + + return bool(logged_in and active_subscription) + + except Exception: + return False + + async def _get_ha_cloud_ice_servers(self) -> list[dict[str, str]] | None: + """Get ICE servers from Home Assistant Cloud. + + :return: List of ICE server configurations or None if unavailable. + """ + # Find the Home Assistant provider + ha_provider = cast("HomeAssistantProvider | None", self.mass.get_provider("hass")) + if not ha_provider: + return None + + try: + hass_client = ha_provider.hass + if not hass_client or not hass_client.connected: + return None + + # Try to get ICE servers from HA Cloud + # This might be available via a cloud API endpoint + # For now, return None and use default STUN servers + # TODO: Research if HA Cloud exposes ICE/TURN server endpoints + self.logger.debug( + "Using default STUN servers (HA Cloud ICE servers not yet implemented)" + ) + return None + + except Exception: + self.logger.exception("Error getting Home Assistant Cloud ICE servers") + return None + + @property + def is_enabled(self) -> bool: + """Return whether WebRTC remote access is enabled.""" + return self._enabled + + @property + def is_running(self) -> bool: + """Return whether the gateway is running.""" + return self.gateway is not None and self.gateway.is_running + + @property + def is_connected(self) -> bool: + """Return whether the gateway is connected to the signaling server.""" + return self.gateway is not None and self.gateway.is_connected + + @property + def remote_id(self) -> str | None: + """Return the current Remote ID.""" + return self._remote_id + + def _register_api_commands(self) -> None: + """Register API commands for remote access.""" + + @api_command("remote_access/info") + def get_remote_access_info() -> RemoteAccessInfo: + """Get remote access information. + + Returns information about the remote access configuration including + whether it's enabled, running status, connected status, and the Remote ID. + """ + return RemoteAccessInfo( + enabled=self.is_enabled, + running=self.is_running, + connected=self.is_connected, + remote_id=self._remote_id or "", + ha_cloud_available=self._ha_cloud_available, + signaling_url=SIGNALING_SERVER_URL, + ) + + @api_command("remote_access/configure", required_role="admin") + async def configure_remote_access( + enabled: bool, + ) -> RemoteAccessInfo: + """ + Configure remote access settings. + + :param enabled: Enable or disable remote access. + + Starts or stops the WebRTC gateway based on the enabled parameter. + Returns the updated remote access info. + """ + # Save configuration + self._enabled = enabled + self.mass.config.set(f"{CONF_CORE}/{CONF_KEY_MAIN}/{CONF_ENABLED}", enabled) + allow_start = self._ha_cloud_available and self._enabled + + # Start or stop the gateway based on enabled flag + if allow_start and not self.is_running: + await self.start() + elif not allow_start and self.is_running: + await self.stop() + return get_remote_access_info() diff --git a/music_assistant/controllers/webserver/remote_access/gateway.py b/music_assistant/controllers/webserver/remote_access/gateway.py new file mode 100644 index 00000000..80d82b8b --- /dev/null +++ b/music_assistant/controllers/webserver/remote_access/gateway.py @@ -0,0 +1,633 @@ +"""Music Assistant WebRTC Gateway. + +This module provides WebRTC-based remote access to Music Assistant instances. +It connects to a signaling server and handles incoming WebRTC connections, +bridging them to the local WebSocket API. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import secrets +import string +from dataclasses import dataclass, field +from typing import Any + +import aiohttp +from aiortc import ( + RTCConfiguration, + RTCIceCandidate, + RTCIceServer, + RTCPeerConnection, + RTCSessionDescription, +) + +from music_assistant.constants import MASS_LOGGER_NAME + +LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access") + +# Reduce verbose logging from aiortc/aioice +logging.getLogger("aioice").setLevel(logging.WARNING) +logging.getLogger("aiortc").setLevel(logging.WARNING) + + +def generate_remote_id() -> str: + """Generate a unique Remote ID in the format MA-XXXX-XXXX.""" + chars = string.ascii_uppercase + string.digits + part1 = "".join(secrets.choice(chars) for _ in range(4)) + part2 = "".join(secrets.choice(chars) for _ in range(4)) + return f"MA-{part1}-{part2}" + + +@dataclass +class WebRTCSession: + """Represents an active WebRTC session with a remote client.""" + + session_id: str + peer_connection: RTCPeerConnection + data_channel: Any = None + local_ws: Any = None + message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue) + forward_to_local_task: asyncio.Task[None] | None = None + forward_from_local_task: asyncio.Task[None] | None = None + + +class WebRTCGateway: + """WebRTC Gateway for Music Assistant Remote Access. + + This gateway: + 1. Connects to a signaling server + 2. Registers with a unique Remote ID + 3. Handles incoming WebRTC connections from remote PWA clients + 4. Bridges WebRTC DataChannel messages to the local WebSocket API + """ + + def __init__( + self, + http_session: aiohttp.ClientSession, + signaling_url: str = "wss://signaling.music-assistant.io/ws", + local_ws_url: str = "ws://localhost:8095/ws", + ice_servers: list[dict[str, Any]] | None = None, + remote_id: str | None = None, + ) -> None: + """Initialize the WebRTC Gateway. + + :param http_session: Shared aiohttp ClientSession to use for HTTP/WebSocket connections. + :param signaling_url: WebSocket URL of the signaling server. + :param local_ws_url: Local WebSocket URL to bridge to. + :param ice_servers: List of ICE server configurations. + :param remote_id: Optional Remote ID to use (generated if not provided). + """ + self.http_session = http_session + self.signaling_url = signaling_url + self.local_ws_url = local_ws_url + self.remote_id = remote_id or generate_remote_id() + self.logger = LOGGER + + self.ice_servers = ice_servers or [ + {"urls": "stun:stun.l.google.com:19302"}, + {"urls": "stun:stun1.l.google.com:19302"}, + {"urls": "stun:stun.cloudflare.com:3478"}, + ] + + self.sessions: dict[str, WebRTCSession] = {} + self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None + self._running = False + self._reconnect_delay = 5 + self._max_reconnect_delay = 60 + self._current_reconnect_delay = 5 + self._run_task: asyncio.Task[None] | None = None + self._is_connected = False + + @property + def is_running(self) -> bool: + """Return whether the gateway is running.""" + return self._running + + @property + def is_connected(self) -> bool: + """Return whether the gateway is connected to the signaling server.""" + return self._is_connected + + async def start(self) -> None: + """Start the WebRTC Gateway.""" + self.logger.info("Starting WebRTC Gateway with Remote ID: %s", self.remote_id) + self.logger.debug("Signaling URL: %s", self.signaling_url) + self.logger.debug("Local WS URL: %s", self.local_ws_url) + self._running = True + self._run_task = asyncio.create_task(self._run()) + self.logger.debug("WebRTC Gateway start task created") + + async def stop(self) -> None: + """Stop the WebRTC Gateway.""" + self.logger.info("Stopping WebRTC Gateway") + self._running = False + + # Close all sessions + for session_id in list(self.sessions.keys()): + await self._close_session(session_id) + + # Close signaling connection gracefully + if self._signaling_ws and not self._signaling_ws.closed: + try: + await self._signaling_ws.close() + except Exception: + self.logger.debug("Error closing signaling WebSocket", exc_info=True) + + # Cancel run task + if self._run_task and not self._run_task.done(): + self._run_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._run_task + + self._signaling_ws = None + + async def _run(self) -> None: + """Run the main loop with reconnection logic.""" + self.logger.debug("WebRTC Gateway _run() loop starting") + while self._running: + try: + await self._connect_to_signaling() + # Connection closed gracefully or with error + self._is_connected = False + if self._running: + self.logger.warning( + "Signaling server connection lost. Reconnecting in %ss...", + self._current_reconnect_delay, + ) + except Exception: + self._is_connected = False + self.logger.exception("Signaling connection error") + if self._running: + self.logger.info( + "Reconnecting to signaling server in %ss", + self._current_reconnect_delay, + ) + + if self._running: + await asyncio.sleep(self._current_reconnect_delay) + # Exponential backoff with max limit + self._current_reconnect_delay = min( + self._current_reconnect_delay * 2, self._max_reconnect_delay + ) + + async def _connect_to_signaling(self) -> None: + """Connect to the signaling server.""" + self.logger.info("Connecting to signaling server: %s", self.signaling_url) + try: + self._signaling_ws = await self.http_session.ws_connect( + self.signaling_url, + heartbeat=30, # Send WebSocket ping every 30 seconds + autoping=True, # Automatically respond to pings + ) + self.logger.debug("WebSocket connection established") + # Small delay to let any previous connection fully close on the server side + # This helps prevent race conditions during reconnection + await asyncio.sleep(0.5) + self.logger.debug("Sending registration") + await self._register() + self._is_connected = True + # Reset reconnect delay on successful connection + self._current_reconnect_delay = self._reconnect_delay + self.logger.info("Connected and registered with signaling server") + + # Message loop + self.logger.debug("Entering message loop") + async for msg in self._signaling_ws: + self.logger.debug("Received WebSocket message type: %s", msg.type) + if msg.type == aiohttp.WSMsgType.TEXT: + try: + await self._handle_signaling_message(json.loads(msg.data)) + except Exception: + self.logger.exception("Error handling signaling message") + elif msg.type == aiohttp.WSMsgType.PING: + # WebSocket ping - autoping should handle this, just log + self.logger.debug("Received WebSocket PING") + elif msg.type == aiohttp.WSMsgType.PONG: + # WebSocket pong response - just log + self.logger.debug("Received WebSocket PONG") + elif msg.type == aiohttp.WSMsgType.CLOSE: + # Close frame received + self.logger.warning( + "Signaling server sent close frame: code=%s, reason=%s", + msg.data, + msg.extra, + ) + break + elif msg.type == aiohttp.WSMsgType.CLOSED: + self.logger.warning("Signaling server closed connection") + break + elif msg.type == aiohttp.WSMsgType.ERROR: + self.logger.error("WebSocket error: %s", self._signaling_ws.exception()) + break + else: + self.logger.warning("Unexpected WebSocket message type: %s", msg.type) + + self.logger.info( + "Message loop exited - WebSocket closed: %s", self._signaling_ws.closed + ) + except TimeoutError: + self.logger.error("Timeout connecting to signaling server") + except aiohttp.ClientError as err: + self.logger.error("Failed to connect to signaling server: %s", err) + except Exception: + self.logger.exception("Unexpected error in signaling connection") + finally: + self._is_connected = False + self._signaling_ws = None + + async def _register(self) -> None: + """Register with the signaling server.""" + if self._signaling_ws: + registration_msg = { + "type": "register-server", + "remoteId": self.remote_id, + } + self.logger.info( + "Sending registration to signaling server with Remote ID: %s", + self.remote_id, + ) + self.logger.debug("Registration message: %s", registration_msg) + await self._signaling_ws.send_json(registration_msg) + self.logger.debug("Registration message sent successfully") + else: + self.logger.warning("Cannot register: signaling websocket is not connected") + + async def _handle_signaling_message(self, message: dict[str, Any]) -> None: + """Handle incoming signaling messages. + + :param message: The signaling message. + """ + msg_type = message.get("type") + self.logger.debug("Received signaling message: %s - Full message: %s", msg_type, message) + + if msg_type == "ping": + # Respond to ping with pong + if self._signaling_ws: + await self._signaling_ws.send_json({"type": "pong"}) + elif msg_type == "pong": + # Server responded to our ping, connection is alive + pass + elif msg_type == "registered": + self.logger.info("Registered with signaling server as: %s", message.get("remoteId")) + elif msg_type == "error": + self.logger.error( + "Signaling server error: %s", + message.get("message", "Unknown error"), + ) + elif msg_type == "client-connected": + session_id = message.get("sessionId") + if session_id: + await self._create_session(session_id) + elif msg_type == "client-disconnected": + session_id = message.get("sessionId") + if session_id: + await self._close_session(session_id) + elif msg_type == "offer": + session_id = message.get("sessionId") + offer_data = message.get("data") + if session_id and offer_data: + await self._handle_offer(session_id, offer_data) + elif msg_type == "ice-candidate": + session_id = message.get("sessionId") + candidate_data = message.get("data") + if session_id and candidate_data: + await self._handle_ice_candidate(session_id, candidate_data) + + async def _create_session(self, session_id: str) -> None: + """Create a new WebRTC session. + + :param session_id: The session ID. + """ + config = RTCConfiguration( + iceServers=[RTCIceServer(**server) for server in self.ice_servers] + ) + pc = RTCPeerConnection(configuration=config) + session = WebRTCSession(session_id=session_id, peer_connection=pc) + self.sessions[session_id] = session + + @pc.on("datachannel") + def on_datachannel(channel: Any) -> None: + session.data_channel = channel + asyncio.create_task(self._setup_data_channel(session)) + + @pc.on("icecandidate") + async def on_icecandidate(candidate: Any) -> None: + if candidate and self._signaling_ws: + await self._signaling_ws.send_json( + { + "type": "ice-candidate", + "sessionId": session_id, + "data": { + "candidate": candidate.candidate, + "sdpMid": candidate.sdpMid, + "sdpMLineIndex": candidate.sdpMLineIndex, + }, + } + ) + + @pc.on("connectionstatechange") + async def on_connectionstatechange() -> None: + if pc.connectionState == "failed": + await self._close_session(session_id) + + async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None: + """Handle incoming WebRTC offer. + + :param session_id: The session ID. + :param offer: The offer data. + """ + session = self.sessions.get(session_id) + if not session: + return + pc = session.peer_connection + + # Check if peer connection is already closed or closing + if pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Ignoring offer for session %s - connection state: %s", + session_id, + pc.connectionState, + ) + return + + sdp = offer.get("sdp") + sdp_type = offer.get("type") + if not sdp or not sdp_type: + self.logger.error("Invalid offer data: missing sdp or type") + return + + try: + await pc.setRemoteDescription( + RTCSessionDescription( + sdp=str(sdp), + type=str(sdp_type), + ) + ) + + # Check again if session was closed during setRemoteDescription + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed during setRemoteDescription, aborting offer handling", + session_id, + ) + return + + answer = await pc.createAnswer() + + # Check again before setLocalDescription + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed during createAnswer, aborting offer handling", + session_id, + ) + return + + await pc.setLocalDescription(answer) + + # Final check before sending answer + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed during setLocalDescription, skipping answer transmission", + session_id, + ) + return + + if self._signaling_ws: + await self._signaling_ws.send_json( + { + "type": "answer", + "sessionId": session_id, + "data": { + "sdp": pc.localDescription.sdp, + "type": pc.localDescription.type, + }, + } + ) + except Exception: + self.logger.exception("Error handling offer for session %s", session_id) + # Clean up the session on error + await self._close_session(session_id) + + async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None: + """Handle incoming ICE candidate. + + :param session_id: The session ID. + :param candidate: The ICE candidate data. + """ + session = self.sessions.get(session_id) + if not session or not candidate: + return + + # Check if peer connection is already closed or closing + pc = session.peer_connection + if pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Ignoring ICE candidate for session %s - connection state: %s", + session_id, + pc.connectionState, + ) + return + + candidate_str = candidate.get("candidate") + sdp_mid = candidate.get("sdpMid") + sdp_mline_index = candidate.get("sdpMLineIndex") + + if not candidate_str: + return + + # Create RTCIceCandidate from the SDP string + try: + ice_candidate = RTCIceCandidate( + component=1, + foundation="", + ip="", + port=0, + priority=0, + protocol="udp", + type="host", + sdpMid=str(sdp_mid) if sdp_mid else None, + sdpMLineIndex=int(sdp_mline_index) if sdp_mline_index is not None else None, + ) + # Parse the candidate string to populate the fields + ice_candidate.candidate = str(candidate_str) # type: ignore[attr-defined] + + # Check if session was closed before adding candidate + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed before adding ICE candidate, skipping", + session_id, + ) + return + + await session.peer_connection.addIceCandidate(ice_candidate) + except Exception: + self.logger.exception( + "Failed to add ICE candidate for session %s: %s", session_id, candidate + ) + + async def _setup_data_channel(self, session: WebRTCSession) -> None: + """Set up data channel and bridge to local WebSocket. + + :param session: The WebRTC session. + """ + channel = session.data_channel + if not channel: + return + try: + session.local_ws = await self.http_session.ws_connect(self.local_ws_url) + loop = asyncio.get_event_loop() + + # Store task references for proper cleanup + session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session)) + session.forward_from_local_task = asyncio.create_task(self._forward_from_local(session)) + + @channel.on("message") # type: ignore[misc] + def on_message(message: str) -> None: + # Called from aiortc thread, use call_soon_threadsafe + # Only queue message if session is still active + if session.forward_to_local_task and not session.forward_to_local_task.done(): + loop.call_soon_threadsafe(session.message_queue.put_nowait, message) + + @channel.on("close") # type: ignore[misc] + def on_close() -> None: + # Called from aiortc thread, use call_soon_threadsafe to schedule task + asyncio.run_coroutine_threadsafe(self._close_session(session.session_id), loop) + + except Exception: + self.logger.exception("Failed to connect to local WebSocket") + + async def _forward_to_local(self, session: WebRTCSession) -> None: + """Forward messages from WebRTC DataChannel to local WebSocket. + + :param session: The WebRTC session. + """ + try: + while session.local_ws and not session.local_ws.closed: + message = await session.message_queue.get() + + # Check if this is an HTTP proxy request + try: + msg_data = json.loads(message) + if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request": + # Handle HTTP proxy request + await self._handle_http_proxy_request(session, msg_data) + continue + except (json.JSONDecodeError, ValueError): + pass + + # Regular WebSocket message + if session.local_ws and not session.local_ws.closed: + await session.local_ws.send_str(message) + except asyncio.CancelledError: + # Task was cancelled during cleanup, this is expected + self.logger.debug("Forward to local task cancelled for session %s", session.session_id) + raise + except Exception: + self.logger.exception("Error forwarding to local WebSocket") + + async def _forward_from_local(self, session: WebRTCSession) -> None: + """Forward messages from local WebSocket to WebRTC DataChannel. + + :param session: The WebRTC session. + """ + try: + async for msg in session.local_ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if session.data_channel and session.data_channel.readyState == "open": + session.data_channel.send(msg.data) + elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): + break + except asyncio.CancelledError: + # Task was cancelled during cleanup, this is expected + self.logger.debug( + "Forward from local task cancelled for session %s", session.session_id + ) + raise + except Exception: + self.logger.exception("Error forwarding from local WebSocket") + + async def _handle_http_proxy_request( + self, session: WebRTCSession, request_data: dict[str, Any] + ) -> None: + """Handle HTTP proxy request from remote client. + + :param session: The WebRTC session. + :param request_data: The HTTP proxy request data. + """ + request_id = request_data.get("id") + method = request_data.get("method", "GET") + path = request_data.get("path", "/") + headers = request_data.get("headers", {}) + + # Build local HTTP URL + # Extract host and port from local_ws_url (ws://localhost:8095/ws) + ws_url_parts = self.local_ws_url.replace("ws://", "").split("/") + host_port = ws_url_parts[0] # localhost:8095 + local_http_url = f"http://{host_port}{path}" + + self.logger.debug("HTTP proxy request: %s %s", method, local_http_url) + + try: + # Use shared HTTP session for this request + async with self.http_session.request( + method, local_http_url, headers=headers + ) as response: + # Read response body + body = await response.read() + + # Prepare response data + response_data = { + "type": "http-proxy-response", + "id": request_id, + "status": response.status, + "headers": dict(response.headers), + "body": body.hex(), # Send as hex string to avoid encoding issues + } + + # Send response back through data channel + if session.data_channel and session.data_channel.readyState == "open": + session.data_channel.send(json.dumps(response_data)) + + except Exception as err: + self.logger.exception("Error handling HTTP proxy request") + # Send error response + error_response = { + "type": "http-proxy-response", + "id": request_id, + "status": 500, + "headers": {"Content-Type": "text/plain"}, + "body": str(err).encode().hex(), + } + if session.data_channel and session.data_channel.readyState == "open": + session.data_channel.send(json.dumps(error_response)) + + async def _close_session(self, session_id: str) -> None: + """Close a WebRTC session. + + :param session_id: The session ID. + """ + session = self.sessions.pop(session_id, None) + if not session: + return + + # Cancel forwarding tasks first to prevent race conditions + if session.forward_to_local_task and not session.forward_to_local_task.done(): + session.forward_to_local_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await session.forward_to_local_task + + if session.forward_from_local_task and not session.forward_from_local_task.done(): + session.forward_from_local_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await session.forward_from_local_task + + # Close connections + if session.local_ws and not session.local_ws.closed: + await session.local_ws.close() + if session.data_channel: + session.data_channel.close() + await session.peer_connection.close() diff --git a/music_assistant/providers/remote_access/__init__.py b/music_assistant/providers/remote_access/__init__.py deleted file mode 100644 index 1ec87564..00000000 --- a/music_assistant/providers/remote_access/__init__.py +++ /dev/null @@ -1,290 +0,0 @@ -""" -Remote Access Plugin Provider for Music Assistant. - -This plugin manages WebRTC-based remote access to Music Assistant instances. -It connects to a signaling server and handles incoming WebRTC connections, -bridging them to the local WebSocket API. - -Requires an active Home Assistant Cloud subscription for the best experience. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, cast - -from music_assistant_models.config_entries import ConfigEntry -from music_assistant_models.enums import ConfigEntryType, EventType, ProviderFeature - -from music_assistant.constants import CONF_ENABLED -from music_assistant.helpers.api import api_command -from music_assistant.models.plugin import PluginProvider -from music_assistant.providers.remote_access.gateway import WebRTCGateway, generate_remote_id - -if TYPE_CHECKING: - from music_assistant_models.config_entries import ConfigValueType, ProviderConfig - from music_assistant_models.event import MassEvent - from music_assistant_models.provider import ProviderManifest - - from music_assistant import MusicAssistant - from music_assistant.models import ProviderInstanceType - from music_assistant.providers.hass import HomeAssistantProvider - -# Signaling server URL -SIGNALING_SERVER_URL = "wss://signaling.music-assistant.io/ws" - -# Config keys -CONF_REMOTE_ID = "remote_id" - -SUPPORTED_FEATURES: set[ProviderFeature] = set() - - -async def setup( - mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig -) -> ProviderInstanceType: - """Initialize provider(instance) with given configuration.""" - return RemoteAccessProvider(mass, manifest, config) - - -async def get_config_entries( - mass: MusicAssistant, - instance_id: str | None = None, - action: str | None = None, # noqa: ARG001 - values: dict[str, ConfigValueType] | None = None, # noqa: ARG001 -) -> tuple[ConfigEntry, ...]: - """ - Return Config entries to setup this provider. - - :param mass: MusicAssistant instance. - :param instance_id: id of an existing provider instance (None if new instance setup). - :param action: [optional] action key called from config entries UI. - :param values: the (intermediate) raw values for config entries sent with the action. - """ - entries: list[ConfigEntry] = [ - ConfigEntry( - key=CONF_REMOTE_ID, - type=ConfigEntryType.STRING, - label="Remote ID", - description="Unique identifier for WebRTC remote access. " - "Generated automatically and should not be changed.", - required=False, - hidden=True, - ) - # TODO: Add a message that optimal experience requires Home Assistant Cloud subscription - ] - # Get the remote ID if instance exists - remote_id: str | None = None - if instance_id: - if remote_id_value := mass.config.get_raw_provider_config_value( - instance_id, CONF_REMOTE_ID - ): - remote_id = str(remote_id_value) - entries += [ - ConfigEntry( - key="remote_access_id_intro", - type=ConfigEntryType.LABEL, - label="Remote access is enabled. You can securely connect to your " - "Music Assistant instance from https://app.music-assistant.io or supported " - "(mobile) apps using the Remote ID below.", - hidden=False, - ), - ConfigEntry( - key="remote_access_id_label", - type=ConfigEntryType.LABEL, - label=f"Remote Access ID: {remote_id}", - hidden=False, - ), - ] - - return tuple(entries) - - -async def _check_ha_cloud_status(mass: MusicAssistant) -> bool: - """Check if Home Assistant Cloud subscription is active. - - :param mass: MusicAssistant instance. - :return: True if HA Cloud is logged in and has active subscription. - """ - # Find the Home Assistant provider - ha_provider = cast("HomeAssistantProvider | None", mass.get_provider("hass")) - if not ha_provider: - return False - - try: - # Access the hass client from the provider - hass_client = ha_provider.hass - if not hass_client or not hass_client.connected: - return False - - # Call cloud/status command to check subscription - result = await hass_client.send_command("cloud/status") - - # Check for logged_in and active_subscription - logged_in = result.get("logged_in", False) - active_subscription = result.get("active_subscription", False) - - return bool(logged_in and active_subscription) - - except Exception: - return False - - -class RemoteAccessProvider(PluginProvider): - """Plugin Provider for WebRTC-based remote access.""" - - gateway: WebRTCGateway | None = None - _remote_id: str | None = None - - async def loaded_in_mass(self) -> None: - """Call after the provider has been loaded.""" - remote_id_value = self.config.get_value(CONF_REMOTE_ID) - if not remote_id_value: - # First time setup, generate a new Remote ID - remote_id_value = generate_remote_id() - self._remote_id = remote_id_value - self.logger.debug("Generated new Remote ID: %s", remote_id_value) - await self.mass.config.save_provider_config( - self.domain, - { - CONF_ENABLED: True, - CONF_REMOTE_ID: remote_id_value, - }, - instance_id=self.instance_id, - ) - - else: - self._remote_id = str(remote_id_value) - - # Register API commands - self._register_api_commands() - - # Subscribe to provider updates to check for Home Assistant Cloud - self.mass.subscribe(self._on_providers_updated, EventType.PROVIDERS_UPDATED) - - # Try initial setup (providers might already be loaded) - await self._try_enable_remote_access() - - async def unload(self, is_removed: bool = False) -> None: - """Handle unload/close of the provider. - - :param is_removed: True when the provider is removed from the configuration. - """ - if self.gateway: - await self.gateway.stop() - self.gateway = None - self.logger.debug("WebRTC Remote Access stopped") - - def _on_remote_id_ready(self, remote_id: str) -> None: - """Handle Remote ID registration with signaling server. - - :param remote_id: The registered Remote ID. - """ - self.logger.debug("Remote ID registered with signaling server: %s", remote_id) - self._remote_id = remote_id - - def _on_providers_updated(self, event: MassEvent) -> None: - """Handle providers updated event. - - :param event: The providers updated event. - """ - if self.gateway is not None: - # Already set up, no need to check again - return - # Try to enable remote access when providers are updated - self.mass.create_task(self._try_enable_remote_access()) - - async def _try_enable_remote_access(self) -> None: - """Try to enable remote access if Home Assistant Cloud is available.""" - if self.gateway is not None: - # Already set up - return - - # Determine local WebSocket URL from webserver config - base_url = self.mass.webserver.base_url - local_ws_url = base_url.replace("http", "ws") - if not local_ws_url.endswith("/"): - local_ws_url += "/" - local_ws_url += "ws" - - # Get ICE servers from HA Cloud if available - ice_servers: list[dict[str, str]] | None = None - if await _check_ha_cloud_status(self.mass): - self.logger.info( - "Home Assistant Cloud subscription detected, using HA cloud ICE servers" - ) - ice_servers = await self._get_ha_cloud_ice_servers() - else: - self.logger.info( - "Home Assistant Cloud subscription not detected, using default STUN servers" - ) - - # Initialize and start the WebRTC gateway - self.gateway = WebRTCGateway( - http_session=self.mass.http_session, - signaling_url=SIGNALING_SERVER_URL, - local_ws_url=local_ws_url, - remote_id=self._remote_id, - on_remote_id_ready=self._on_remote_id_ready, - ice_servers=ice_servers, - ) - - await self.gateway.start() - self.logger.info("WebRTC Remote Access enabled - Remote ID: %s", self._remote_id) - - async def _get_ha_cloud_ice_servers(self) -> list[dict[str, str]] | None: - """Get ICE servers from Home Assistant Cloud. - - :return: List of ICE server configurations or None if unavailable. - """ - # Find the Home Assistant provider - ha_provider = cast("HomeAssistantProvider | None", self.mass.get_provider("hass")) - if not ha_provider: - return None - try: - hass_client = ha_provider.hass - if not hass_client or not hass_client.connected: - return None - - # Try to get ICE servers from HA Cloud - # This might be available via a cloud API endpoint - # For now, return None and use default STUN servers - # TODO: Research if HA Cloud exposes ICE/TURN server endpoints - self.logger.debug( - "Using default STUN servers (HA Cloud ICE servers not yet implemented)" - ) - return None - - except Exception: - self.logger.exception("Error getting Home Assistant Cloud ICE servers") - return None - - @property - def is_enabled(self) -> bool: - """Return whether WebRTC remote access is enabled.""" - return self.gateway is not None and self.gateway.is_running - - @property - def is_connected(self) -> bool: - """Return whether the gateway is connected to the signaling server.""" - return self.gateway is not None and self.gateway.is_connected - - @property - def remote_id(self) -> str | None: - """Return the current Remote ID.""" - return self._remote_id - - def _register_api_commands(self) -> None: - """Register API commands for remote access.""" - - @api_command("remote_access/info") - def get_remote_access_info() -> dict[str, str | bool]: - """Get remote access information. - - Returns information about the remote access configuration including - whether it's enabled, connected status, and the Remote ID for connecting. - """ - return { - "enabled": self.is_enabled, - "connected": self.is_connected, - "remote_id": self._remote_id or "", - "signaling_url": SIGNALING_SERVER_URL, - } diff --git a/music_assistant/providers/remote_access/gateway.py b/music_assistant/providers/remote_access/gateway.py deleted file mode 100644 index 8f9197ba..00000000 --- a/music_assistant/providers/remote_access/gateway.py +++ /dev/null @@ -1,641 +0,0 @@ -"""Music Assistant WebRTC Gateway. - -This module provides WebRTC-based remote access to Music Assistant instances. -It connects to a signaling server and handles incoming WebRTC connections, -bridging them to the local WebSocket API. -""" - -from __future__ import annotations - -import asyncio -import contextlib -import json -import logging -import secrets -import string -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any - -import aiohttp -from aiortc import ( - RTCConfiguration, - RTCIceCandidate, - RTCIceServer, - RTCPeerConnection, - RTCSessionDescription, -) - -from music_assistant.constants import MASS_LOGGER_NAME - -if TYPE_CHECKING: - from collections.abc import Callable - -LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access") - -# Reduce verbose logging from aiortc/aioice -logging.getLogger("aioice").setLevel(logging.WARNING) -logging.getLogger("aiortc").setLevel(logging.WARNING) - - -def generate_remote_id() -> str: - """Generate a unique Remote ID in the format MA-XXXX-XXXX.""" - chars = string.ascii_uppercase + string.digits - part1 = "".join(secrets.choice(chars) for _ in range(4)) - part2 = "".join(secrets.choice(chars) for _ in range(4)) - return f"MA-{part1}-{part2}" - - -@dataclass -class WebRTCSession: - """Represents an active WebRTC session with a remote client.""" - - session_id: str - peer_connection: RTCPeerConnection - data_channel: Any = None - local_ws: Any = None - message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue) - forward_to_local_task: asyncio.Task[None] | None = None - forward_from_local_task: asyncio.Task[None] | None = None - - -class WebRTCGateway: - """WebRTC Gateway for Music Assistant Remote Access. - - This gateway: - 1. Connects to a signaling server - 2. Registers with a unique Remote ID - 3. Handles incoming WebRTC connections from remote PWA clients - 4. Bridges WebRTC DataChannel messages to the local WebSocket API - """ - - def __init__( - self, - http_session: aiohttp.ClientSession, - signaling_url: str = "wss://signaling.music-assistant.io/ws", - local_ws_url: str = "ws://localhost:8095/ws", - ice_servers: list[dict[str, Any]] | None = None, - remote_id: str | None = None, - on_remote_id_ready: Callable[[str], None] | None = None, - ) -> None: - """Initialize the WebRTC Gateway. - - :param http_session: Shared aiohttp ClientSession to use for HTTP/WebSocket connections. - :param signaling_url: WebSocket URL of the signaling server. - :param local_ws_url: Local WebSocket URL to bridge to. - :param ice_servers: List of ICE server configurations. - :param remote_id: Optional Remote ID to use (generated if not provided). - :param on_remote_id_ready: Callback when Remote ID is registered. - """ - self.http_session = http_session - self.signaling_url = signaling_url - self.local_ws_url = local_ws_url - self.remote_id = remote_id or generate_remote_id() - self.on_remote_id_ready = on_remote_id_ready - self.logger = LOGGER - - self.ice_servers = ice_servers or [ - {"urls": "stun:stun.l.google.com:19302"}, - {"urls": "stun:stun1.l.google.com:19302"}, - {"urls": "stun:stun.cloudflare.com:3478"}, - ] - - self.sessions: dict[str, WebRTCSession] = {} - self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None - self._running = False - self._reconnect_delay = 5 - self._max_reconnect_delay = 60 - self._current_reconnect_delay = 5 - self._run_task: asyncio.Task[None] | None = None - self._is_connected = False - - @property - def is_running(self) -> bool: - """Return whether the gateway is running.""" - return self._running - - @property - def is_connected(self) -> bool: - """Return whether the gateway is connected to the signaling server.""" - return self._is_connected - - async def start(self) -> None: - """Start the WebRTC Gateway.""" - self.logger.info("Starting WebRTC Gateway with Remote ID: %s", self.remote_id) - self.logger.debug("Signaling URL: %s", self.signaling_url) - self.logger.debug("Local WS URL: %s", self.local_ws_url) - self._running = True - self._run_task = asyncio.create_task(self._run()) - self.logger.debug("WebRTC Gateway start task created") - - async def stop(self) -> None: - """Stop the WebRTC Gateway.""" - self.logger.info("Stopping WebRTC Gateway") - self._running = False - - # Close all sessions - for session_id in list(self.sessions.keys()): - await self._close_session(session_id) - - # Close signaling connection gracefully - if self._signaling_ws and not self._signaling_ws.closed: - try: - await self._signaling_ws.close() - except Exception: - self.logger.debug("Error closing signaling WebSocket", exc_info=True) - - # Cancel run task - if self._run_task and not self._run_task.done(): - self._run_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._run_task - - self._signaling_ws = None - - async def _run(self) -> None: - """Run the main loop with reconnection logic.""" - self.logger.debug("WebRTC Gateway _run() loop starting") - while self._running: - try: - await self._connect_to_signaling() - # Connection closed gracefully or with error - self._is_connected = False - if self._running: - self.logger.warning( - "Signaling server connection lost. Reconnecting in %ss...", - self._current_reconnect_delay, - ) - except Exception: - self._is_connected = False - self.logger.exception("Signaling connection error") - if self._running: - self.logger.info( - "Reconnecting to signaling server in %ss", - self._current_reconnect_delay, - ) - - if self._running: - await asyncio.sleep(self._current_reconnect_delay) - # Exponential backoff with max limit - self._current_reconnect_delay = min( - self._current_reconnect_delay * 2, self._max_reconnect_delay - ) - - async def _connect_to_signaling(self) -> None: - """Connect to the signaling server.""" - self.logger.info("Connecting to signaling server: %s", self.signaling_url) - try: - self._signaling_ws = await self.http_session.ws_connect( - self.signaling_url, - heartbeat=30, # Send WebSocket ping every 30 seconds - autoping=True, # Automatically respond to pings - ) - self.logger.debug("WebSocket connection established") - # Small delay to let any previous connection fully close on the server side - # This helps prevent race conditions during reconnection - await asyncio.sleep(0.5) - self.logger.debug("Sending registration") - await self._register() - self._is_connected = True - # Reset reconnect delay on successful connection - self._current_reconnect_delay = self._reconnect_delay - self.logger.info("Connected and registered with signaling server") - - # Message loop - self.logger.debug("Entering message loop") - async for msg in self._signaling_ws: - self.logger.debug("Received WebSocket message type: %s", msg.type) - if msg.type == aiohttp.WSMsgType.TEXT: - try: - await self._handle_signaling_message(json.loads(msg.data)) - except Exception: - self.logger.exception("Error handling signaling message") - elif msg.type == aiohttp.WSMsgType.PING: - # WebSocket ping - autoping should handle this, just log - self.logger.debug("Received WebSocket PING") - elif msg.type == aiohttp.WSMsgType.PONG: - # WebSocket pong response - just log - self.logger.debug("Received WebSocket PONG") - elif msg.type == aiohttp.WSMsgType.CLOSE: - # Close frame received - self.logger.warning( - "Signaling server sent close frame: code=%s, reason=%s", - msg.data, - msg.extra, - ) - break - elif msg.type == aiohttp.WSMsgType.CLOSED: - self.logger.warning("Signaling server closed connection") - break - elif msg.type == aiohttp.WSMsgType.ERROR: - self.logger.error("WebSocket error: %s", self._signaling_ws.exception()) - break - else: - self.logger.warning("Unexpected WebSocket message type: %s", msg.type) - - self.logger.info( - "Message loop exited - WebSocket closed: %s", self._signaling_ws.closed - ) - except TimeoutError: - self.logger.error("Timeout connecting to signaling server") - except aiohttp.ClientError as err: - self.logger.error("Failed to connect to signaling server: %s", err) - except Exception: - self.logger.exception("Unexpected error in signaling connection") - finally: - self._is_connected = False - self._signaling_ws = None - - async def _register(self) -> None: - """Register with the signaling server.""" - if self._signaling_ws: - registration_msg = { - "type": "register-server", - "remoteId": self.remote_id, - } - self.logger.info( - "Sending registration to signaling server with Remote ID: %s", - self.remote_id, - ) - self.logger.debug("Registration message: %s", registration_msg) - await self._signaling_ws.send_json(registration_msg) - self.logger.debug("Registration message sent successfully") - else: - self.logger.warning("Cannot register: signaling websocket is not connected") - - async def _handle_signaling_message(self, message: dict[str, Any]) -> None: - """Handle incoming signaling messages. - - :param message: The signaling message. - """ - msg_type = message.get("type") - self.logger.debug("Received signaling message: %s - Full message: %s", msg_type, message) - - if msg_type == "ping": - # Respond to ping with pong - if self._signaling_ws: - await self._signaling_ws.send_json({"type": "pong"}) - elif msg_type == "pong": - # Server responded to our ping, connection is alive - pass - elif msg_type == "registered": - self.logger.info("Registered with signaling server as: %s", message.get("remoteId")) - if self.on_remote_id_ready: - self.on_remote_id_ready(self.remote_id) - elif msg_type == "error": - self.logger.error( - "Signaling server error: %s", - message.get("message", "Unknown error"), - ) - elif msg_type == "client-connected": - session_id = message.get("sessionId") - if session_id: - await self._create_session(session_id) - elif msg_type == "client-disconnected": - session_id = message.get("sessionId") - if session_id: - await self._close_session(session_id) - elif msg_type == "offer": - session_id = message.get("sessionId") - offer_data = message.get("data") - if session_id and offer_data: - await self._handle_offer(session_id, offer_data) - elif msg_type == "ice-candidate": - session_id = message.get("sessionId") - candidate_data = message.get("data") - if session_id and candidate_data: - await self._handle_ice_candidate(session_id, candidate_data) - - async def _create_session(self, session_id: str) -> None: - """Create a new WebRTC session. - - :param session_id: The session ID. - """ - config = RTCConfiguration( - iceServers=[RTCIceServer(**server) for server in self.ice_servers] - ) - pc = RTCPeerConnection(configuration=config) - session = WebRTCSession(session_id=session_id, peer_connection=pc) - self.sessions[session_id] = session - - @pc.on("datachannel") - def on_datachannel(channel: Any) -> None: - session.data_channel = channel - asyncio.create_task(self._setup_data_channel(session)) - - @pc.on("icecandidate") - async def on_icecandidate(candidate: Any) -> None: - if candidate and self._signaling_ws: - await self._signaling_ws.send_json( - { - "type": "ice-candidate", - "sessionId": session_id, - "data": { - "candidate": candidate.candidate, - "sdpMid": candidate.sdpMid, - "sdpMLineIndex": candidate.sdpMLineIndex, - }, - } - ) - - @pc.on("connectionstatechange") - async def on_connectionstatechange() -> None: - if pc.connectionState == "failed": - await self._close_session(session_id) - - async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None: - """Handle incoming WebRTC offer. - - :param session_id: The session ID. - :param offer: The offer data. - """ - session = self.sessions.get(session_id) - if not session: - return - pc = session.peer_connection - - # Check if peer connection is already closed or closing - if pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Ignoring offer for session %s - connection state: %s", - session_id, - pc.connectionState, - ) - return - - sdp = offer.get("sdp") - sdp_type = offer.get("type") - if not sdp or not sdp_type: - self.logger.error("Invalid offer data: missing sdp or type") - return - - try: - await pc.setRemoteDescription( - RTCSessionDescription( - sdp=str(sdp), - type=str(sdp_type), - ) - ) - - # Check again if session was closed during setRemoteDescription - if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Session %s closed during setRemoteDescription, aborting offer handling", - session_id, - ) - return - - answer = await pc.createAnswer() - - # Check again before setLocalDescription - if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Session %s closed during createAnswer, aborting offer handling", - session_id, - ) - return - - await pc.setLocalDescription(answer) - - # Final check before sending answer - if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Session %s closed during setLocalDescription, skipping answer transmission", - session_id, - ) - return - - if self._signaling_ws: - await self._signaling_ws.send_json( - { - "type": "answer", - "sessionId": session_id, - "data": { - "sdp": pc.localDescription.sdp, - "type": pc.localDescription.type, - }, - } - ) - except Exception: - self.logger.exception("Error handling offer for session %s", session_id) - # Clean up the session on error - await self._close_session(session_id) - - async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None: - """Handle incoming ICE candidate. - - :param session_id: The session ID. - :param candidate: The ICE candidate data. - """ - session = self.sessions.get(session_id) - if not session or not candidate: - return - - # Check if peer connection is already closed or closing - pc = session.peer_connection - if pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Ignoring ICE candidate for session %s - connection state: %s", - session_id, - pc.connectionState, - ) - return - - candidate_str = candidate.get("candidate") - sdp_mid = candidate.get("sdpMid") - sdp_mline_index = candidate.get("sdpMLineIndex") - - if not candidate_str: - return - - # Create RTCIceCandidate from the SDP string - try: - ice_candidate = RTCIceCandidate( - component=1, - foundation="", - ip="", - port=0, - priority=0, - protocol="udp", - type="host", - sdpMid=str(sdp_mid) if sdp_mid else None, - sdpMLineIndex=int(sdp_mline_index) if sdp_mline_index is not None else None, - ) - # Parse the candidate string to populate the fields - ice_candidate.candidate = str(candidate_str) # type: ignore[attr-defined] - - # Check if session was closed before adding candidate - if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Session %s closed before adding ICE candidate, skipping", - session_id, - ) - return - - await session.peer_connection.addIceCandidate(ice_candidate) - except Exception: - self.logger.exception( - "Failed to add ICE candidate for session %s: %s", session_id, candidate - ) - - async def _setup_data_channel(self, session: WebRTCSession) -> None: - """Set up data channel and bridge to local WebSocket. - - :param session: The WebRTC session. - """ - channel = session.data_channel - if not channel: - return - try: - session.local_ws = await self.http_session.ws_connect(self.local_ws_url) - loop = asyncio.get_event_loop() - - # Store task references for proper cleanup - session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session)) - session.forward_from_local_task = asyncio.create_task(self._forward_from_local(session)) - - @channel.on("message") # type: ignore[misc] - def on_message(message: str) -> None: - # Called from aiortc thread, use call_soon_threadsafe - # Only queue message if session is still active - if session.forward_to_local_task and not session.forward_to_local_task.done(): - loop.call_soon_threadsafe(session.message_queue.put_nowait, message) - - @channel.on("close") # type: ignore[misc] - def on_close() -> None: - # Called from aiortc thread, use call_soon_threadsafe to schedule task - asyncio.run_coroutine_threadsafe(self._close_session(session.session_id), loop) - - except Exception: - self.logger.exception("Failed to connect to local WebSocket") - - async def _forward_to_local(self, session: WebRTCSession) -> None: - """Forward messages from WebRTC DataChannel to local WebSocket. - - :param session: The WebRTC session. - """ - try: - while session.local_ws and not session.local_ws.closed: - message = await session.message_queue.get() - - # Check if this is an HTTP proxy request - try: - msg_data = json.loads(message) - if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request": - # Handle HTTP proxy request - await self._handle_http_proxy_request(session, msg_data) - continue - except (json.JSONDecodeError, ValueError): - pass - - # Regular WebSocket message - if session.local_ws and not session.local_ws.closed: - await session.local_ws.send_str(message) - except asyncio.CancelledError: - # Task was cancelled during cleanup, this is expected - self.logger.debug("Forward to local task cancelled for session %s", session.session_id) - raise - except Exception: - self.logger.exception("Error forwarding to local WebSocket") - - async def _forward_from_local(self, session: WebRTCSession) -> None: - """Forward messages from local WebSocket to WebRTC DataChannel. - - :param session: The WebRTC session. - """ - try: - async for msg in session.local_ws: - if msg.type == aiohttp.WSMsgType.TEXT: - if session.data_channel and session.data_channel.readyState == "open": - session.data_channel.send(msg.data) - elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): - break - except asyncio.CancelledError: - # Task was cancelled during cleanup, this is expected - self.logger.debug( - "Forward from local task cancelled for session %s", session.session_id - ) - raise - except Exception: - self.logger.exception("Error forwarding from local WebSocket") - - async def _handle_http_proxy_request( - self, session: WebRTCSession, request_data: dict[str, Any] - ) -> None: - """Handle HTTP proxy request from remote client. - - :param session: The WebRTC session. - :param request_data: The HTTP proxy request data. - """ - request_id = request_data.get("id") - method = request_data.get("method", "GET") - path = request_data.get("path", "/") - headers = request_data.get("headers", {}) - - # Build local HTTP URL - # Extract host and port from local_ws_url (ws://localhost:8095/ws) - ws_url_parts = self.local_ws_url.replace("ws://", "").split("/") - host_port = ws_url_parts[0] # localhost:8095 - local_http_url = f"http://{host_port}{path}" - - self.logger.debug("HTTP proxy request: %s %s", method, local_http_url) - - try: - # Use shared HTTP session for this request - async with self.http_session.request( - method, local_http_url, headers=headers - ) as response: - # Read response body - body = await response.read() - - # Prepare response data - response_data = { - "type": "http-proxy-response", - "id": request_id, - "status": response.status, - "headers": dict(response.headers), - "body": body.hex(), # Send as hex string to avoid encoding issues - } - - # Send response back through data channel - if session.data_channel and session.data_channel.readyState == "open": - session.data_channel.send(json.dumps(response_data)) - - except Exception as err: - self.logger.exception("Error handling HTTP proxy request") - # Send error response - error_response = { - "type": "http-proxy-response", - "id": request_id, - "status": 500, - "headers": {"Content-Type": "text/plain"}, - "body": str(err).encode().hex(), - } - if session.data_channel and session.data_channel.readyState == "open": - session.data_channel.send(json.dumps(error_response)) - - async def _close_session(self, session_id: str) -> None: - """Close a WebRTC session. - - :param session_id: The session ID. - """ - session = self.sessions.pop(session_id, None) - if not session: - return - - # Cancel forwarding tasks first to prevent race conditions - if session.forward_to_local_task and not session.forward_to_local_task.done(): - session.forward_to_local_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await session.forward_to_local_task - - if session.forward_from_local_task and not session.forward_from_local_task.done(): - session.forward_from_local_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await session.forward_from_local_task - - # Close connections - if session.local_ws and not session.local_ws.closed: - await session.local_ws.close() - if session.data_channel: - session.data_channel.close() - await session.peer_connection.close() diff --git a/music_assistant/providers/remote_access/manifest.json b/music_assistant/providers/remote_access/manifest.json deleted file mode 100644 index a6600c87..00000000 --- a/music_assistant/providers/remote_access/manifest.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "type": "plugin", - "domain": "remote_access", - "stage": "alpha", - "name": "Remote Access", - "description": "WebRTC-based encrypted, secure remote access for connecting to Music Assistant from outside your local network (requires Home Assistant Cloud subscription for the best experience).", - "codeowners": ["@music-assistant"], - "icon": "cloud-lock", - "multi_instance": false, - "builtin": true, - "allow_disable": true, - "requirements": ["aiortc>=1.6.0"] -} diff --git a/pyproject.toml b/pyproject.toml index c26830cd..33c2c926 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "librosa==0.11.0", "gql[all]==4.0.0", "aiovban>=0.6.3", + "aiortc>=1.6.0", ] description = "Music Assistant" license = {text = "Apache-2.0"}