From: Marcel van der Veldt Date: Mon, 1 Dec 2025 08:50:33 +0000 (+0100) Subject: remote access fixes X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=4d292b24d34f8666d5155048fcb8c7d492708ce9;p=music-assistant-server.git remote access fixes --- diff --git a/music_assistant/constants.py b/music_assistant/constants.py index 09b24eab..bf20bb79 100644 --- a/music_assistant/constants.py +++ b/music_assistant/constants.py @@ -101,7 +101,7 @@ CONF_USE_SSL: Final[str] = "use_ssl" CONF_VERIFY_SSL: Final[str] = "verify_ssl" CONF_SSL_FINGERPRINT: Final[str] = "ssl_fingerprint" CONF_AUTH_ALLOW_SELF_REGISTRATION: Final[str] = "auth_allow_self_registration" - +CONF_ENABLED: Final[str] = "enabled" # config default values DEFAULT_HOST: Final[str] = "0.0.0.0" @@ -141,7 +141,6 @@ CONFIGURABLE_CORE_CONTROLLERS = ( "cache", "music", "player_queues", - "remote_access", ) VERBOSE_LOG_LEVEL: Final[int] = 5 PROVIDERS_WITH_SHAREABLE_URLS = ("spotify", "qobuz") diff --git a/music_assistant/controllers/config.py b/music_assistant/controllers/config.py index 2d617d64..3a8267a6 100644 --- a/music_assistant/controllers/config.py +++ b/music_assistant/controllers/config.py @@ -426,9 +426,6 @@ class ConfigController: config = await self._update_provider_config(instance_id, values) else: config = await self._add_provider_config(provider_domain, values) - # mark onboard done whenever the (first) provider is added - # this will be replaced later by a more sophisticated onboarding process - self.set(CONF_ONBOARD_DONE, True) # return full config, just in case return await self.get_provider_config(config.instance_id) diff --git a/music_assistant/controllers/remote_access/__init__.py b/music_assistant/controllers/remote_access/__init__.py deleted file mode 100644 index 16eb8066..00000000 --- a/music_assistant/controllers/remote_access/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Remote Access controller for Music Assistant.""" - -from music_assistant.controllers.remote_access.controller import RemoteAccessController - -__all__ = ["RemoteAccessController"] diff --git a/music_assistant/controllers/remote_access/controller.py b/music_assistant/controllers/remote_access/controller.py deleted file mode 100644 index d67ce6e4..00000000 --- a/music_assistant/controllers/remote_access/controller.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -Remote Access Controller for Music Assistant. - -This controller manages WebRTC-based remote access to Music Assistant instances. -It connects to a signaling server and handles incoming WebRTC connections, -bridging them to the local WebSocket API. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, cast - -from music_assistant_models.config_entries import ConfigEntry -from music_assistant_models.enums import ConfigEntryType, EventType - -from music_assistant.controllers.remote_access.gateway import WebRTCGateway, generate_remote_id -from music_assistant.helpers.api import api_command -from music_assistant.models.core_controller import CoreController - -if TYPE_CHECKING: - from music_assistant_models.config_entries import ConfigValueType, CoreConfig - from music_assistant_models.event import MassEvent - - from music_assistant import MusicAssistant - from music_assistant.providers.hass import HomeAssistantProvider - -# Signaling server URL -SIGNALING_SERVER_URL = "wss://signaling.music-assistant.io/ws" - -# Config keys -CONF_REMOTE_ID = "remote_id" -CONF_ENABLED = "enable_remote_access" - - -class RemoteAccessController(CoreController): - """Core Controller for WebRTC-based remote access.""" - - domain: str = "remote_access" - - def __init__(self, mass: MusicAssistant) -> None: - """Initialize the remote access controller. - - :param mass: MusicAssistant instance. - """ - super().__init__(mass) - self.manifest.name = "Remote Access" - self.manifest.description = ( - "WebRTC-based remote access for connecting to Music Assistant " - "from outside your local network (requires Home Assistant Cloud subscription)" - ) - self.manifest.icon = "cloud-lock" - self.gateway: WebRTCGateway | None = None - self._remote_id: str | None = None - self._setup_done = False - - async def get_config_entries( - self, - action: str | None = None, - values: dict[str, ConfigValueType] | None = None, - ) -> tuple[ConfigEntry, ...]: - """Return all Config Entries for this core module (if any).""" - entries = [] - has_ha_cloud = await self._check_ha_cloud_status() - - # Info alert about HA Cloud requirement - entries.append( - ConfigEntry( - key="remote_access_info", - type=ConfigEntryType.ALERT, - label="Remote Access requires an active Home Assistant Cloud subscription.", - required=False, - hidden=has_ha_cloud, - ) - ) - - entries.append( - ConfigEntry( - key=CONF_ENABLED, - type=ConfigEntryType.BOOLEAN, - default_value=False, - label="Enable Remote Access", - description="Enable WebRTC-based (encrypted), secure remote access to your " - "Music Assistant instance via https://app.music-assistant.io " - "or supported mobile apps. ", - hidden=not has_ha_cloud, - ) - ) - entries.append( - ConfigEntry( - key="remote_access_id_label", - type=ConfigEntryType.LABEL, - label=f"Remote Access ID: {self._remote_id}", - hidden=self._remote_id is None or not has_ha_cloud, - depends_on=CONF_ENABLED, - ) - ) - - entries.append( - ConfigEntry( - key=CONF_REMOTE_ID, - type=ConfigEntryType.STRING, - label="Remote ID", - description="Unique identifier for WebRTC remote access. " - "Generated automatically and should not be changed.", - required=False, - hidden=True, - ) - ) - - return tuple(entries) - - async def setup(self, config: CoreConfig) -> None: - """Async initialize of module.""" - self.config = config - self.logger.debug("RemoteAccessController.setup() called") - - # Get or generate Remote ID - remote_id_value = self.config.get_value(CONF_REMOTE_ID) - if not remote_id_value: - # Generate new Remote ID and save it - remote_id_value = generate_remote_id() - self._remote_id = remote_id_value - # Save the Remote ID to config - self.mass.config.set_raw_core_config_value(self.domain, CONF_REMOTE_ID, remote_id_value) - self.mass.config.save(immediate=True) - self.logger.debug("Generated new Remote ID: %s", remote_id_value) - - # Register API commands immediately - self._register_api_commands() - - # Subscribe to provider updates to check for Home Assistant Cloud - self.mass.subscribe(self._on_providers_updated, EventType.PROVIDERS_UPDATED) - - # Try initial setup (providers might already be loaded) - await self._try_enable_remote_access() - - async def close(self) -> None: - """Cleanup on exit.""" - if self.gateway: - await self.gateway.stop() - self.gateway = None - self.logger.debug("WebRTC Remote Access stopped") - - def _on_remote_id_ready(self, remote_id: str) -> None: - """Handle Remote ID registration with signaling server. - - :param remote_id: The registered Remote ID. - """ - self.logger.debug("Remote ID registered with signaling server: %s", remote_id) - self._remote_id = remote_id - - def _on_providers_updated(self, event: MassEvent) -> None: - """Handle providers updated event. - - :param event: The providers updated event. - """ - if self._setup_done: - # Already set up, no need to check again - return - # Try to enable remote access when providers are updated - self.mass.create_task(self._try_enable_remote_access()) - - async def _try_enable_remote_access(self) -> None: - """Try to enable remote access if Home Assistant Cloud is available.""" - if self._setup_done: - # Already set up - return - - # Check if remote access is enabled in config - if not self.config.get_value(CONF_ENABLED, True): - return - - # Check if Home Assistant Cloud is available and active - cloud_status = await self._check_ha_cloud_status() - if not cloud_status: - self.logger.debug("Home Assistant Cloud not available") - return - - # Mark as done to prevent multiple attempts - self._setup_done = True - self.logger.info("Home Assistant Cloud subscription detected, enabling remote access") - - # Determine local WebSocket URL from webserver config - base_url = self.mass.webserver.base_url - local_ws_url = base_url.replace("http", "ws") - if not local_ws_url.endswith("/"): - local_ws_url += "/" - local_ws_url += "ws" - - # Get ICE servers from HA Cloud if available - ice_servers = await self._get_ha_cloud_ice_servers() - - # Initialize and start the WebRTC gateway - self.gateway = WebRTCGateway( - signaling_url=SIGNALING_SERVER_URL, - local_ws_url=local_ws_url, - remote_id=self._remote_id, - on_remote_id_ready=self._on_remote_id_ready, - ice_servers=ice_servers, - ) - - await self.gateway.start() - self.logger.info("WebRTC Remote Access enabled - Remote ID: %s", self._remote_id) - - async def _check_ha_cloud_status(self) -> bool: - """Check if Home Assistant Cloud subscription is active. - - :return: True if HA Cloud is logged in and has active subscription. - """ - # Find the Home Assistant provider - ha_provider = cast("HomeAssistantProvider | None", self.mass.get_provider("hass")) - if not ha_provider: - self.logger.debug("Home Assistant provider not found or not available") - return False - - try: - # Access the hass client from the provider - hass_client = ha_provider.hass - if not hass_client or not hass_client.connected: - self.logger.debug("Home Assistant client not connected") - return False - - # Call cloud/status command to check subscription - result = await hass_client.send_command("cloud/status") - - # Check for logged_in and active_subscription - logged_in = result.get("logged_in", False) - active_subscription = result.get("active_subscription", False) - - if logged_in and active_subscription: - self.logger.debug("Home Assistant Cloud subscription is active") - return True - - self.logger.debug( - "Home Assistant Cloud not active (logged_in=%s, active_subscription=%s)", - logged_in, - active_subscription, - ) - return False - - except Exception: - self.logger.exception("Error checking Home Assistant Cloud status") - return False - - async def _get_ha_cloud_ice_servers(self) -> list[dict[str, str]] | None: - """Get ICE servers from Home Assistant Cloud. - - :return: List of ICE server configurations or None if unavailable. - """ - # Find the Home Assistant provider - ha_provider = cast("HomeAssistantProvider | None", self.mass.get_provider("hass")) - if not ha_provider: - return None - - try: - # Access the hass client from the provider - if not hasattr(ha_provider, "hass"): - return None - hass_client = ha_provider.hass - if not hass_client or not hass_client.connected: - return None - - # Try to get ICE servers from HA Cloud - # This might be available via a cloud API endpoint - # For now, return None and use default STUN servers - # TODO: Research if HA Cloud exposes ICE/TURN server endpoints - self.logger.debug( - "Using default STUN servers (HA Cloud ICE servers not yet implemented)" - ) - return None - - except Exception: - self.logger.exception("Error getting Home Assistant Cloud ICE servers") - return None - - @property - def is_enabled(self) -> bool: - """Return whether WebRTC remote access is enabled.""" - return self.gateway is not None and self.gateway.is_running - - @property - def is_connected(self) -> bool: - """Return whether the gateway is connected to the signaling server.""" - return self.gateway is not None and self.gateway.is_connected - - @property - def remote_id(self) -> str | None: - """Return the current Remote ID.""" - return self._remote_id - - def _register_api_commands(self) -> None: - """Register API commands for remote access.""" - - @api_command("remote_access/info") - def get_remote_access_info() -> dict[str, str | bool]: - """Get remote access information. - - Returns information about the remote access configuration including - whether it's enabled, connected status, and the Remote ID for connecting. - """ - return { - "enabled": self.is_enabled, - "connected": self.is_connected, - "remote_id": self._remote_id or "", - "signaling_url": SIGNALING_SERVER_URL, - } diff --git a/music_assistant/controllers/remote_access/gateway.py b/music_assistant/controllers/remote_access/gateway.py deleted file mode 100644 index 5f8b8561..00000000 --- a/music_assistant/controllers/remote_access/gateway.py +++ /dev/null @@ -1,667 +0,0 @@ -"""Music Assistant WebRTC Gateway. - -This module provides WebRTC-based remote access to Music Assistant instances. -It connects to a signaling server and handles incoming WebRTC connections, -bridging them to the local WebSocket API. -""" - -from __future__ import annotations - -import asyncio -import contextlib -import json -import logging -import secrets -import string -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any - -import aiohttp -from aiortc import ( - RTCConfiguration, - RTCIceCandidate, - RTCIceServer, - RTCPeerConnection, - RTCSessionDescription, -) - -from music_assistant.constants import MASS_LOGGER_NAME - -if TYPE_CHECKING: - from collections.abc import Callable - -LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access") - -# Reduce verbose logging from aiortc/aioice -logging.getLogger("aioice").setLevel(logging.WARNING) -logging.getLogger("aiortc").setLevel(logging.WARNING) - - -def generate_remote_id() -> str: - """Generate a unique Remote ID in the format MA-XXXX-XXXX.""" - chars = string.ascii_uppercase + string.digits - part1 = "".join(secrets.choice(chars) for _ in range(4)) - part2 = "".join(secrets.choice(chars) for _ in range(4)) - return f"MA-{part1}-{part2}" - - -@dataclass -class WebRTCSession: - """Represents an active WebRTC session with a remote client.""" - - session_id: str - peer_connection: RTCPeerConnection - data_channel: Any = None - local_ws: Any = None - message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue) - forward_to_local_task: asyncio.Task[None] | None = None - forward_from_local_task: asyncio.Task[None] | None = None - - -class WebRTCGateway: - """WebRTC Gateway for Music Assistant Remote Access. - - This gateway: - 1. Connects to a signaling server - 2. Registers with a unique Remote ID - 3. Handles incoming WebRTC connections from remote PWA clients - 4. Bridges WebRTC DataChannel messages to the local WebSocket API - """ - - def __init__( - self, - signaling_url: str = "wss://signaling.music-assistant.io/ws", - local_ws_url: str = "ws://localhost:8095/ws", - ice_servers: list[dict[str, Any]] | None = None, - remote_id: str | None = None, - on_remote_id_ready: Callable[[str], None] | None = None, - ) -> None: - """Initialize the WebRTC Gateway. - - :param signaling_url: WebSocket URL of the signaling server. - :param local_ws_url: Local WebSocket URL to bridge to. - :param ice_servers: List of ICE server configurations. - :param remote_id: Optional Remote ID to use (generated if not provided). - :param on_remote_id_ready: Callback when Remote ID is registered. - """ - self.signaling_url = signaling_url - self.local_ws_url = local_ws_url - self.remote_id = remote_id or generate_remote_id() - self.on_remote_id_ready = on_remote_id_ready - self.logger = LOGGER - - self.ice_servers = ice_servers or [ - {"urls": "stun:stun.l.google.com:19302"}, - {"urls": "stun:stun1.l.google.com:19302"}, - {"urls": "stun:stun.cloudflare.com:3478"}, - ] - - self.sessions: dict[str, WebRTCSession] = {} - self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None - self._signaling_session: aiohttp.ClientSession | None = None - self._running = False - self._reconnect_delay = 5 - self._max_reconnect_delay = 60 - self._current_reconnect_delay = 5 - self._run_task: asyncio.Task[None] | None = None - self._is_connected = False - - @property - def is_running(self) -> bool: - """Return whether the gateway is running.""" - return self._running - - @property - def is_connected(self) -> bool: - """Return whether the gateway is connected to the signaling server.""" - return self._is_connected - - async def start(self) -> None: - """Start the WebRTC Gateway.""" - self.logger.info("Starting WebRTC Gateway with Remote ID: %s", self.remote_id) - self.logger.debug("Signaling URL: %s", self.signaling_url) - self.logger.debug("Local WS URL: %s", self.local_ws_url) - self._running = True - self._run_task = asyncio.create_task(self._run()) - self.logger.debug("WebRTC Gateway start task created") - - async def stop(self) -> None: - """Stop the WebRTC Gateway.""" - self.logger.info("Stopping WebRTC Gateway") - self._running = False - - # Close all sessions - for session_id in list(self.sessions.keys()): - await self._close_session(session_id) - - # Close signaling connection gracefully - if self._signaling_ws and not self._signaling_ws.closed: - try: - await self._signaling_ws.close() - except Exception: - self.logger.debug("Error closing signaling WebSocket", exc_info=True) - - if self._signaling_session and not self._signaling_session.closed: - try: - await self._signaling_session.close() - except Exception: - self.logger.debug("Error closing signaling session", exc_info=True) - - # Cancel run task - if self._run_task and not self._run_task.done(): - self._run_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._run_task - - self._signaling_ws = None - self._signaling_session = None - - async def _run(self) -> None: - """Run the main loop with reconnection logic.""" - self.logger.debug("WebRTC Gateway _run() loop starting") - while self._running: - try: - await self._connect_to_signaling() - # Connection closed gracefully or with error - self._is_connected = False - if self._running: - self.logger.warning( - "Signaling server connection lost. Reconnecting in %ss...", - self._current_reconnect_delay, - ) - except Exception: - self._is_connected = False - self.logger.exception("Signaling connection error") - if self._running: - self.logger.info( - "Reconnecting to signaling server in %ss", - self._current_reconnect_delay, - ) - - if self._running: - await asyncio.sleep(self._current_reconnect_delay) - # Exponential backoff with max limit - self._current_reconnect_delay = min( - self._current_reconnect_delay * 2, self._max_reconnect_delay - ) - - async def _connect_to_signaling(self) -> None: - """Connect to the signaling server.""" - self.logger.info("Connecting to signaling server: %s", self.signaling_url) - # Create session with increased timeout for WebSocket connection - timeout = aiohttp.ClientTimeout( - total=None, # No total timeout for WebSocket connections - connect=30, # 30 seconds to establish connection - sock_connect=30, # 30 seconds for socket connection - sock_read=None, # No timeout for reading (we handle ping/pong) - ) - self._signaling_session = aiohttp.ClientSession(timeout=timeout) - try: - self._signaling_ws = await self._signaling_session.ws_connect( - self.signaling_url, - heartbeat=30, # Send WebSocket ping every 30 seconds - autoping=True, # Automatically respond to pings - ) - self.logger.debug("WebSocket connection established") - # Small delay to let any previous connection fully close on the server side - # This helps prevent race conditions during reconnection - await asyncio.sleep(0.5) - self.logger.debug("Sending registration") - await self._register() - self._is_connected = True - # Reset reconnect delay on successful connection - self._current_reconnect_delay = self._reconnect_delay - self.logger.info("Connected and registered with signaling server") - - # Message loop - self.logger.debug("Entering message loop") - async for msg in self._signaling_ws: - self.logger.debug("Received WebSocket message type: %s", msg.type) - if msg.type == aiohttp.WSMsgType.TEXT: - try: - await self._handle_signaling_message(json.loads(msg.data)) - except Exception: - self.logger.exception("Error handling signaling message") - elif msg.type == aiohttp.WSMsgType.PING: - # WebSocket ping - autoping should handle this, just log - self.logger.debug("Received WebSocket PING") - elif msg.type == aiohttp.WSMsgType.PONG: - # WebSocket pong response - just log - self.logger.debug("Received WebSocket PONG") - elif msg.type == aiohttp.WSMsgType.CLOSE: - # Close frame received - self.logger.warning( - "Signaling server sent close frame: code=%s, reason=%s", - msg.data, - msg.extra, - ) - break - elif msg.type == aiohttp.WSMsgType.CLOSED: - self.logger.warning("Signaling server closed connection") - break - elif msg.type == aiohttp.WSMsgType.ERROR: - self.logger.error("WebSocket error: %s", self._signaling_ws.exception()) - break - else: - self.logger.warning("Unexpected WebSocket message type: %s", msg.type) - - self.logger.info( - "Message loop exited - WebSocket closed: %s", self._signaling_ws.closed - ) - except TimeoutError: - self.logger.error("Timeout connecting to signaling server") - except aiohttp.ClientError as err: - self.logger.error("Failed to connect to signaling server: %s", err) - except Exception: - self.logger.exception("Unexpected error in signaling connection") - finally: - self._is_connected = False - if self._signaling_session: - await self._signaling_session.close() - self._signaling_session = None - self._signaling_ws = None - - async def _register(self) -> None: - """Register with the signaling server.""" - if self._signaling_ws: - registration_msg = { - "type": "register-server", - "remoteId": self.remote_id, - } - self.logger.info( - "Sending registration to signaling server with Remote ID: %s", - self.remote_id, - ) - self.logger.debug("Registration message: %s", registration_msg) - await self._signaling_ws.send_json(registration_msg) - self.logger.debug("Registration message sent successfully") - else: - self.logger.warning("Cannot register: signaling websocket is not connected") - - async def _handle_signaling_message(self, message: dict[str, Any]) -> None: - """Handle incoming signaling messages. - - :param message: The signaling message. - """ - msg_type = message.get("type") - self.logger.debug("Received signaling message: %s - Full message: %s", msg_type, message) - - if msg_type == "ping": - # Respond to ping with pong - if self._signaling_ws: - await self._signaling_ws.send_json({"type": "pong"}) - elif msg_type == "pong": - # Server responded to our ping, connection is alive - pass - elif msg_type == "registered": - self.logger.info("Registered with signaling server as: %s", message.get("remoteId")) - if self.on_remote_id_ready: - self.on_remote_id_ready(self.remote_id) - elif msg_type == "error": - self.logger.error( - "Signaling server error: %s", - message.get("message", "Unknown error"), - ) - elif msg_type == "client-connected": - session_id = message.get("sessionId") - if session_id: - await self._create_session(session_id) - elif msg_type == "client-disconnected": - session_id = message.get("sessionId") - if session_id: - await self._close_session(session_id) - elif msg_type == "offer": - session_id = message.get("sessionId") - offer_data = message.get("data") - if session_id and offer_data: - await self._handle_offer(session_id, offer_data) - elif msg_type == "ice-candidate": - session_id = message.get("sessionId") - candidate_data = message.get("data") - if session_id and candidate_data: - await self._handle_ice_candidate(session_id, candidate_data) - - async def _create_session(self, session_id: str) -> None: - """Create a new WebRTC session. - - :param session_id: The session ID. - """ - config = RTCConfiguration( - iceServers=[RTCIceServer(**server) for server in self.ice_servers] - ) - pc = RTCPeerConnection(configuration=config) - session = WebRTCSession(session_id=session_id, peer_connection=pc) - self.sessions[session_id] = session - - @pc.on("datachannel") - def on_datachannel(channel: Any) -> None: - session.data_channel = channel - asyncio.create_task(self._setup_data_channel(session)) - - @pc.on("icecandidate") - async def on_icecandidate(candidate: Any) -> None: - if candidate and self._signaling_ws: - await self._signaling_ws.send_json( - { - "type": "ice-candidate", - "sessionId": session_id, - "data": { - "candidate": candidate.candidate, - "sdpMid": candidate.sdpMid, - "sdpMLineIndex": candidate.sdpMLineIndex, - }, - } - ) - - @pc.on("connectionstatechange") - async def on_connectionstatechange() -> None: - if pc.connectionState == "failed": - await self._close_session(session_id) - - async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None: - """Handle incoming WebRTC offer. - - :param session_id: The session ID. - :param offer: The offer data. - """ - session = self.sessions.get(session_id) - if not session: - return - pc = session.peer_connection - - # Check if peer connection is already closed or closing - if pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Ignoring offer for session %s - connection state: %s", - session_id, - pc.connectionState, - ) - return - - sdp = offer.get("sdp") - sdp_type = offer.get("type") - if not sdp or not sdp_type: - self.logger.error("Invalid offer data: missing sdp or type") - return - - try: - await pc.setRemoteDescription( - RTCSessionDescription( - sdp=str(sdp), - type=str(sdp_type), - ) - ) - - # Check again if session was closed during setRemoteDescription - if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Session %s closed during setRemoteDescription, aborting offer handling", - session_id, - ) - return - - answer = await pc.createAnswer() - - # Check again before setLocalDescription - if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Session %s closed during createAnswer, aborting offer handling", - session_id, - ) - return - - await pc.setLocalDescription(answer) - - # Final check before sending answer - if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Session %s closed during setLocalDescription, skipping answer transmission", - session_id, - ) - return - - if self._signaling_ws: - await self._signaling_ws.send_json( - { - "type": "answer", - "sessionId": session_id, - "data": { - "sdp": pc.localDescription.sdp, - "type": pc.localDescription.type, - }, - } - ) - except Exception: - self.logger.exception("Error handling offer for session %s", session_id) - # Clean up the session on error - await self._close_session(session_id) - - async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None: - """Handle incoming ICE candidate. - - :param session_id: The session ID. - :param candidate: The ICE candidate data. - """ - session = self.sessions.get(session_id) - if not session or not candidate: - return - - # Check if peer connection is already closed or closing - pc = session.peer_connection - if pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Ignoring ICE candidate for session %s - connection state: %s", - session_id, - pc.connectionState, - ) - return - - candidate_str = candidate.get("candidate") - sdp_mid = candidate.get("sdpMid") - sdp_mline_index = candidate.get("sdpMLineIndex") - - if not candidate_str: - return - - # Create RTCIceCandidate from the SDP string - try: - ice_candidate = RTCIceCandidate( - component=1, - foundation="", - ip="", - port=0, - priority=0, - protocol="udp", - type="host", - sdpMid=str(sdp_mid) if sdp_mid else None, - sdpMLineIndex=int(sdp_mline_index) if sdp_mline_index is not None else None, - ) - # Parse the candidate string to populate the fields - ice_candidate.candidate = str(candidate_str) # type: ignore[attr-defined] - - # Check if session was closed before adding candidate - if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): - self.logger.debug( - "Session %s closed before adding ICE candidate, skipping", - session_id, - ) - return - - await session.peer_connection.addIceCandidate(ice_candidate) - except Exception: - self.logger.exception( - "Failed to add ICE candidate for session %s: %s", session_id, candidate - ) - - async def _setup_data_channel(self, session: WebRTCSession) -> None: - """Set up data channel and bridge to local WebSocket. - - :param session: The WebRTC session. - """ - channel = session.data_channel - if not channel: - return - local_session = aiohttp.ClientSession() - try: - session.local_ws = await local_session.ws_connect(self.local_ws_url) - loop = asyncio.get_event_loop() - - # Store task references for proper cleanup - session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session)) - session.forward_from_local_task = asyncio.create_task( - self._forward_from_local(session, local_session) - ) - - @channel.on("message") # type: ignore[misc] - def on_message(message: str) -> None: - # Called from aiortc thread, use call_soon_threadsafe - # Only queue message if session is still active - if session.forward_to_local_task and not session.forward_to_local_task.done(): - loop.call_soon_threadsafe(session.message_queue.put_nowait, message) - - @channel.on("close") # type: ignore[misc] - def on_close() -> None: - # Called from aiortc thread, use call_soon_threadsafe to schedule task - asyncio.run_coroutine_threadsafe(self._close_session(session.session_id), loop) - - except Exception: - self.logger.exception("Failed to connect to local WebSocket") - await local_session.close() - - async def _forward_to_local(self, session: WebRTCSession) -> None: - """Forward messages from WebRTC DataChannel to local WebSocket. - - :param session: The WebRTC session. - """ - try: - while session.local_ws and not session.local_ws.closed: - message = await session.message_queue.get() - - # Check if this is an HTTP proxy request - try: - msg_data = json.loads(message) - if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request": - # Handle HTTP proxy request - await self._handle_http_proxy_request(session, msg_data) - continue - except (json.JSONDecodeError, ValueError): - pass - - # Regular WebSocket message - if session.local_ws and not session.local_ws.closed: - await session.local_ws.send_str(message) - except asyncio.CancelledError: - # Task was cancelled during cleanup, this is expected - self.logger.debug("Forward to local task cancelled for session %s", session.session_id) - raise - except Exception: - self.logger.exception("Error forwarding to local WebSocket") - - async def _forward_from_local( - self, session: WebRTCSession, local_session: aiohttp.ClientSession - ) -> None: - """Forward messages from local WebSocket to WebRTC DataChannel. - - :param session: The WebRTC session. - :param local_session: The aiohttp client session. - """ - try: - async for msg in session.local_ws: - if msg.type == aiohttp.WSMsgType.TEXT: - if session.data_channel and session.data_channel.readyState == "open": - session.data_channel.send(msg.data) - elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): - break - except asyncio.CancelledError: - # Task was cancelled during cleanup, this is expected - self.logger.debug( - "Forward from local task cancelled for session %s", session.session_id - ) - raise - except Exception: - self.logger.exception("Error forwarding from local WebSocket") - finally: - await local_session.close() - - async def _handle_http_proxy_request( - self, session: WebRTCSession, request_data: dict[str, Any] - ) -> None: - """Handle HTTP proxy request from remote client. - - :param session: The WebRTC session. - :param request_data: The HTTP proxy request data. - """ - request_id = request_data.get("id") - method = request_data.get("method", "GET") - path = request_data.get("path", "/") - headers = request_data.get("headers", {}) - - # Build local HTTP URL - # Extract host and port from local_ws_url (ws://localhost:8095/ws) - ws_url_parts = self.local_ws_url.replace("ws://", "").split("/") - host_port = ws_url_parts[0] # localhost:8095 - local_http_url = f"http://{host_port}{path}" - - self.logger.debug("HTTP proxy request: %s %s", method, local_http_url) - - try: - # Create a new HTTP client session for this request - async with ( - aiohttp.ClientSession() as http_session, - http_session.request(method, local_http_url, headers=headers) as response, - ): - # Read response body - body = await response.read() - - # Prepare response data - response_data = { - "type": "http-proxy-response", - "id": request_id, - "status": response.status, - "headers": dict(response.headers), - "body": body.hex(), # Send as hex string to avoid encoding issues - } - - # Send response back through data channel - if session.data_channel and session.data_channel.readyState == "open": - session.data_channel.send(json.dumps(response_data)) - - except Exception as err: - self.logger.exception("Error handling HTTP proxy request") - # Send error response - error_response = { - "type": "http-proxy-response", - "id": request_id, - "status": 500, - "headers": {"Content-Type": "text/plain"}, - "body": str(err).encode().hex(), - } - if session.data_channel and session.data_channel.readyState == "open": - session.data_channel.send(json.dumps(error_response)) - - async def _close_session(self, session_id: str) -> None: - """Close a WebRTC session. - - :param session_id: The session ID. - """ - session = self.sessions.pop(session_id, None) - if not session: - return - - # Cancel forwarding tasks first to prevent race conditions - if session.forward_to_local_task and not session.forward_to_local_task.done(): - session.forward_to_local_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await session.forward_to_local_task - - if session.forward_from_local_task and not session.forward_from_local_task.done(): - session.forward_from_local_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await session.forward_from_local_task - - # Close connections - if session.local_ws and not session.local_ws.closed: - await session.local_ws.close() - if session.data_channel: - session.data_channel.close() - await session.peer_connection.close() diff --git a/music_assistant/mass.py b/music_assistant/mass.py index 19d29bb6..3b04bc89 100644 --- a/music_assistant/mass.py +++ b/music_assistant/mass.py @@ -43,7 +43,6 @@ from music_assistant.controllers.metadata import MetaDataController from music_assistant.controllers.music import MusicController from music_assistant.controllers.player_queues import PlayerQueuesController from music_assistant.controllers.players.player_controller import PlayerController -from music_assistant.controllers.remote_access import RemoteAccessController from music_assistant.controllers.streams import StreamsController from music_assistant.controllers.webserver import WebserverController from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user @@ -112,7 +111,6 @@ class MusicAssistant: players: PlayerController player_queues: PlayerQueuesController streams: StreamsController - remote_access: RemoteAccessController _aiobrowser: AsyncServiceBrowser def __init__(self, storage_path: str, cache_path: str, safe_mode: bool = False) -> None: @@ -168,7 +166,6 @@ class MusicAssistant: self.players = PlayerController(self) self.player_queues = PlayerQueuesController(self) self.streams = StreamsController(self) - self.remote_access = RemoteAccessController(self) # add manifests for core controllers for controller_name in CONFIGURABLE_CORE_CONTROLLERS: controller: CoreController = getattr(self, controller_name) @@ -184,8 +181,6 @@ class MusicAssistant: # not yet available while we're starting (or performing migrations) self._register_api_commands() await self.webserver.setup(await self.config.get_core_config("webserver")) - # setup remote access after webserver (it needs webserver's port) - await self.remote_access.setup(await self.config.get_core_config("remote_access")) # setup discovery await self._setup_discovery() # load providers @@ -207,7 +202,6 @@ class MusicAssistant: ) # stop core controllers await self.streams.close() - await self.remote_access.close() await self.webserver.close() await self.metadata.close() await self.music.close() diff --git a/music_assistant/providers/remote_access/__init__.py b/music_assistant/providers/remote_access/__init__.py new file mode 100644 index 00000000..1ec87564 --- /dev/null +++ b/music_assistant/providers/remote_access/__init__.py @@ -0,0 +1,290 @@ +""" +Remote Access Plugin Provider for Music Assistant. + +This plugin manages WebRTC-based remote access to Music Assistant instances. +It connects to a signaling server and handles incoming WebRTC connections, +bridging them to the local WebSocket API. + +Requires an active Home Assistant Cloud subscription for the best experience. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from music_assistant_models.config_entries import ConfigEntry +from music_assistant_models.enums import ConfigEntryType, EventType, ProviderFeature + +from music_assistant.constants import CONF_ENABLED +from music_assistant.helpers.api import api_command +from music_assistant.models.plugin import PluginProvider +from music_assistant.providers.remote_access.gateway import WebRTCGateway, generate_remote_id + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ConfigValueType, ProviderConfig + from music_assistant_models.event import MassEvent + from music_assistant_models.provider import ProviderManifest + + from music_assistant import MusicAssistant + from music_assistant.models import ProviderInstanceType + from music_assistant.providers.hass import HomeAssistantProvider + +# Signaling server URL +SIGNALING_SERVER_URL = "wss://signaling.music-assistant.io/ws" + +# Config keys +CONF_REMOTE_ID = "remote_id" + +SUPPORTED_FEATURES: set[ProviderFeature] = set() + + +async def setup( + mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig +) -> ProviderInstanceType: + """Initialize provider(instance) with given configuration.""" + return RemoteAccessProvider(mass, manifest, config) + + +async def get_config_entries( + mass: MusicAssistant, + instance_id: str | None = None, + action: str | None = None, # noqa: ARG001 + values: dict[str, ConfigValueType] | None = None, # noqa: ARG001 +) -> tuple[ConfigEntry, ...]: + """ + Return Config entries to setup this provider. + + :param mass: MusicAssistant instance. + :param instance_id: id of an existing provider instance (None if new instance setup). + :param action: [optional] action key called from config entries UI. + :param values: the (intermediate) raw values for config entries sent with the action. + """ + entries: list[ConfigEntry] = [ + ConfigEntry( + key=CONF_REMOTE_ID, + type=ConfigEntryType.STRING, + label="Remote ID", + description="Unique identifier for WebRTC remote access. " + "Generated automatically and should not be changed.", + required=False, + hidden=True, + ) + # TODO: Add a message that optimal experience requires Home Assistant Cloud subscription + ] + # Get the remote ID if instance exists + remote_id: str | None = None + if instance_id: + if remote_id_value := mass.config.get_raw_provider_config_value( + instance_id, CONF_REMOTE_ID + ): + remote_id = str(remote_id_value) + entries += [ + ConfigEntry( + key="remote_access_id_intro", + type=ConfigEntryType.LABEL, + label="Remote access is enabled. You can securely connect to your " + "Music Assistant instance from https://app.music-assistant.io or supported " + "(mobile) apps using the Remote ID below.", + hidden=False, + ), + ConfigEntry( + key="remote_access_id_label", + type=ConfigEntryType.LABEL, + label=f"Remote Access ID: {remote_id}", + hidden=False, + ), + ] + + return tuple(entries) + + +async def _check_ha_cloud_status(mass: MusicAssistant) -> bool: + """Check if Home Assistant Cloud subscription is active. + + :param mass: MusicAssistant instance. + :return: True if HA Cloud is logged in and has active subscription. + """ + # Find the Home Assistant provider + ha_provider = cast("HomeAssistantProvider | None", mass.get_provider("hass")) + if not ha_provider: + return False + + try: + # Access the hass client from the provider + hass_client = ha_provider.hass + if not hass_client or not hass_client.connected: + return False + + # Call cloud/status command to check subscription + result = await hass_client.send_command("cloud/status") + + # Check for logged_in and active_subscription + logged_in = result.get("logged_in", False) + active_subscription = result.get("active_subscription", False) + + return bool(logged_in and active_subscription) + + except Exception: + return False + + +class RemoteAccessProvider(PluginProvider): + """Plugin Provider for WebRTC-based remote access.""" + + gateway: WebRTCGateway | None = None + _remote_id: str | None = None + + async def loaded_in_mass(self) -> None: + """Call after the provider has been loaded.""" + remote_id_value = self.config.get_value(CONF_REMOTE_ID) + if not remote_id_value: + # First time setup, generate a new Remote ID + remote_id_value = generate_remote_id() + self._remote_id = remote_id_value + self.logger.debug("Generated new Remote ID: %s", remote_id_value) + await self.mass.config.save_provider_config( + self.domain, + { + CONF_ENABLED: True, + CONF_REMOTE_ID: remote_id_value, + }, + instance_id=self.instance_id, + ) + + else: + self._remote_id = str(remote_id_value) + + # Register API commands + self._register_api_commands() + + # Subscribe to provider updates to check for Home Assistant Cloud + self.mass.subscribe(self._on_providers_updated, EventType.PROVIDERS_UPDATED) + + # Try initial setup (providers might already be loaded) + await self._try_enable_remote_access() + + async def unload(self, is_removed: bool = False) -> None: + """Handle unload/close of the provider. + + :param is_removed: True when the provider is removed from the configuration. + """ + if self.gateway: + await self.gateway.stop() + self.gateway = None + self.logger.debug("WebRTC Remote Access stopped") + + def _on_remote_id_ready(self, remote_id: str) -> None: + """Handle Remote ID registration with signaling server. + + :param remote_id: The registered Remote ID. + """ + self.logger.debug("Remote ID registered with signaling server: %s", remote_id) + self._remote_id = remote_id + + def _on_providers_updated(self, event: MassEvent) -> None: + """Handle providers updated event. + + :param event: The providers updated event. + """ + if self.gateway is not None: + # Already set up, no need to check again + return + # Try to enable remote access when providers are updated + self.mass.create_task(self._try_enable_remote_access()) + + async def _try_enable_remote_access(self) -> None: + """Try to enable remote access if Home Assistant Cloud is available.""" + if self.gateway is not None: + # Already set up + return + + # Determine local WebSocket URL from webserver config + base_url = self.mass.webserver.base_url + local_ws_url = base_url.replace("http", "ws") + if not local_ws_url.endswith("/"): + local_ws_url += "/" + local_ws_url += "ws" + + # Get ICE servers from HA Cloud if available + ice_servers: list[dict[str, str]] | None = None + if await _check_ha_cloud_status(self.mass): + self.logger.info( + "Home Assistant Cloud subscription detected, using HA cloud ICE servers" + ) + ice_servers = await self._get_ha_cloud_ice_servers() + else: + self.logger.info( + "Home Assistant Cloud subscription not detected, using default STUN servers" + ) + + # Initialize and start the WebRTC gateway + self.gateway = WebRTCGateway( + http_session=self.mass.http_session, + signaling_url=SIGNALING_SERVER_URL, + local_ws_url=local_ws_url, + remote_id=self._remote_id, + on_remote_id_ready=self._on_remote_id_ready, + ice_servers=ice_servers, + ) + + await self.gateway.start() + self.logger.info("WebRTC Remote Access enabled - Remote ID: %s", self._remote_id) + + async def _get_ha_cloud_ice_servers(self) -> list[dict[str, str]] | None: + """Get ICE servers from Home Assistant Cloud. + + :return: List of ICE server configurations or None if unavailable. + """ + # Find the Home Assistant provider + ha_provider = cast("HomeAssistantProvider | None", self.mass.get_provider("hass")) + if not ha_provider: + return None + try: + hass_client = ha_provider.hass + if not hass_client or not hass_client.connected: + return None + + # Try to get ICE servers from HA Cloud + # This might be available via a cloud API endpoint + # For now, return None and use default STUN servers + # TODO: Research if HA Cloud exposes ICE/TURN server endpoints + self.logger.debug( + "Using default STUN servers (HA Cloud ICE servers not yet implemented)" + ) + return None + + except Exception: + self.logger.exception("Error getting Home Assistant Cloud ICE servers") + return None + + @property + def is_enabled(self) -> bool: + """Return whether WebRTC remote access is enabled.""" + return self.gateway is not None and self.gateway.is_running + + @property + def is_connected(self) -> bool: + """Return whether the gateway is connected to the signaling server.""" + return self.gateway is not None and self.gateway.is_connected + + @property + def remote_id(self) -> str | None: + """Return the current Remote ID.""" + return self._remote_id + + def _register_api_commands(self) -> None: + """Register API commands for remote access.""" + + @api_command("remote_access/info") + def get_remote_access_info() -> dict[str, str | bool]: + """Get remote access information. + + Returns information about the remote access configuration including + whether it's enabled, connected status, and the Remote ID for connecting. + """ + return { + "enabled": self.is_enabled, + "connected": self.is_connected, + "remote_id": self._remote_id or "", + "signaling_url": SIGNALING_SERVER_URL, + } diff --git a/music_assistant/providers/remote_access/gateway.py b/music_assistant/providers/remote_access/gateway.py new file mode 100644 index 00000000..8f9197ba --- /dev/null +++ b/music_assistant/providers/remote_access/gateway.py @@ -0,0 +1,641 @@ +"""Music Assistant WebRTC Gateway. + +This module provides WebRTC-based remote access to Music Assistant instances. +It connects to a signaling server and handles incoming WebRTC connections, +bridging them to the local WebSocket API. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import secrets +import string +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import aiohttp +from aiortc import ( + RTCConfiguration, + RTCIceCandidate, + RTCIceServer, + RTCPeerConnection, + RTCSessionDescription, +) + +from music_assistant.constants import MASS_LOGGER_NAME + +if TYPE_CHECKING: + from collections.abc import Callable + +LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access") + +# Reduce verbose logging from aiortc/aioice +logging.getLogger("aioice").setLevel(logging.WARNING) +logging.getLogger("aiortc").setLevel(logging.WARNING) + + +def generate_remote_id() -> str: + """Generate a unique Remote ID in the format MA-XXXX-XXXX.""" + chars = string.ascii_uppercase + string.digits + part1 = "".join(secrets.choice(chars) for _ in range(4)) + part2 = "".join(secrets.choice(chars) for _ in range(4)) + return f"MA-{part1}-{part2}" + + +@dataclass +class WebRTCSession: + """Represents an active WebRTC session with a remote client.""" + + session_id: str + peer_connection: RTCPeerConnection + data_channel: Any = None + local_ws: Any = None + message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue) + forward_to_local_task: asyncio.Task[None] | None = None + forward_from_local_task: asyncio.Task[None] | None = None + + +class WebRTCGateway: + """WebRTC Gateway for Music Assistant Remote Access. + + This gateway: + 1. Connects to a signaling server + 2. Registers with a unique Remote ID + 3. Handles incoming WebRTC connections from remote PWA clients + 4. Bridges WebRTC DataChannel messages to the local WebSocket API + """ + + def __init__( + self, + http_session: aiohttp.ClientSession, + signaling_url: str = "wss://signaling.music-assistant.io/ws", + local_ws_url: str = "ws://localhost:8095/ws", + ice_servers: list[dict[str, Any]] | None = None, + remote_id: str | None = None, + on_remote_id_ready: Callable[[str], None] | None = None, + ) -> None: + """Initialize the WebRTC Gateway. + + :param http_session: Shared aiohttp ClientSession to use for HTTP/WebSocket connections. + :param signaling_url: WebSocket URL of the signaling server. + :param local_ws_url: Local WebSocket URL to bridge to. + :param ice_servers: List of ICE server configurations. + :param remote_id: Optional Remote ID to use (generated if not provided). + :param on_remote_id_ready: Callback when Remote ID is registered. + """ + self.http_session = http_session + self.signaling_url = signaling_url + self.local_ws_url = local_ws_url + self.remote_id = remote_id or generate_remote_id() + self.on_remote_id_ready = on_remote_id_ready + self.logger = LOGGER + + self.ice_servers = ice_servers or [ + {"urls": "stun:stun.l.google.com:19302"}, + {"urls": "stun:stun1.l.google.com:19302"}, + {"urls": "stun:stun.cloudflare.com:3478"}, + ] + + self.sessions: dict[str, WebRTCSession] = {} + self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None + self._running = False + self._reconnect_delay = 5 + self._max_reconnect_delay = 60 + self._current_reconnect_delay = 5 + self._run_task: asyncio.Task[None] | None = None + self._is_connected = False + + @property + def is_running(self) -> bool: + """Return whether the gateway is running.""" + return self._running + + @property + def is_connected(self) -> bool: + """Return whether the gateway is connected to the signaling server.""" + return self._is_connected + + async def start(self) -> None: + """Start the WebRTC Gateway.""" + self.logger.info("Starting WebRTC Gateway with Remote ID: %s", self.remote_id) + self.logger.debug("Signaling URL: %s", self.signaling_url) + self.logger.debug("Local WS URL: %s", self.local_ws_url) + self._running = True + self._run_task = asyncio.create_task(self._run()) + self.logger.debug("WebRTC Gateway start task created") + + async def stop(self) -> None: + """Stop the WebRTC Gateway.""" + self.logger.info("Stopping WebRTC Gateway") + self._running = False + + # Close all sessions + for session_id in list(self.sessions.keys()): + await self._close_session(session_id) + + # Close signaling connection gracefully + if self._signaling_ws and not self._signaling_ws.closed: + try: + await self._signaling_ws.close() + except Exception: + self.logger.debug("Error closing signaling WebSocket", exc_info=True) + + # Cancel run task + if self._run_task and not self._run_task.done(): + self._run_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._run_task + + self._signaling_ws = None + + async def _run(self) -> None: + """Run the main loop with reconnection logic.""" + self.logger.debug("WebRTC Gateway _run() loop starting") + while self._running: + try: + await self._connect_to_signaling() + # Connection closed gracefully or with error + self._is_connected = False + if self._running: + self.logger.warning( + "Signaling server connection lost. Reconnecting in %ss...", + self._current_reconnect_delay, + ) + except Exception: + self._is_connected = False + self.logger.exception("Signaling connection error") + if self._running: + self.logger.info( + "Reconnecting to signaling server in %ss", + self._current_reconnect_delay, + ) + + if self._running: + await asyncio.sleep(self._current_reconnect_delay) + # Exponential backoff with max limit + self._current_reconnect_delay = min( + self._current_reconnect_delay * 2, self._max_reconnect_delay + ) + + async def _connect_to_signaling(self) -> None: + """Connect to the signaling server.""" + self.logger.info("Connecting to signaling server: %s", self.signaling_url) + try: + self._signaling_ws = await self.http_session.ws_connect( + self.signaling_url, + heartbeat=30, # Send WebSocket ping every 30 seconds + autoping=True, # Automatically respond to pings + ) + self.logger.debug("WebSocket connection established") + # Small delay to let any previous connection fully close on the server side + # This helps prevent race conditions during reconnection + await asyncio.sleep(0.5) + self.logger.debug("Sending registration") + await self._register() + self._is_connected = True + # Reset reconnect delay on successful connection + self._current_reconnect_delay = self._reconnect_delay + self.logger.info("Connected and registered with signaling server") + + # Message loop + self.logger.debug("Entering message loop") + async for msg in self._signaling_ws: + self.logger.debug("Received WebSocket message type: %s", msg.type) + if msg.type == aiohttp.WSMsgType.TEXT: + try: + await self._handle_signaling_message(json.loads(msg.data)) + except Exception: + self.logger.exception("Error handling signaling message") + elif msg.type == aiohttp.WSMsgType.PING: + # WebSocket ping - autoping should handle this, just log + self.logger.debug("Received WebSocket PING") + elif msg.type == aiohttp.WSMsgType.PONG: + # WebSocket pong response - just log + self.logger.debug("Received WebSocket PONG") + elif msg.type == aiohttp.WSMsgType.CLOSE: + # Close frame received + self.logger.warning( + "Signaling server sent close frame: code=%s, reason=%s", + msg.data, + msg.extra, + ) + break + elif msg.type == aiohttp.WSMsgType.CLOSED: + self.logger.warning("Signaling server closed connection") + break + elif msg.type == aiohttp.WSMsgType.ERROR: + self.logger.error("WebSocket error: %s", self._signaling_ws.exception()) + break + else: + self.logger.warning("Unexpected WebSocket message type: %s", msg.type) + + self.logger.info( + "Message loop exited - WebSocket closed: %s", self._signaling_ws.closed + ) + except TimeoutError: + self.logger.error("Timeout connecting to signaling server") + except aiohttp.ClientError as err: + self.logger.error("Failed to connect to signaling server: %s", err) + except Exception: + self.logger.exception("Unexpected error in signaling connection") + finally: + self._is_connected = False + self._signaling_ws = None + + async def _register(self) -> None: + """Register with the signaling server.""" + if self._signaling_ws: + registration_msg = { + "type": "register-server", + "remoteId": self.remote_id, + } + self.logger.info( + "Sending registration to signaling server with Remote ID: %s", + self.remote_id, + ) + self.logger.debug("Registration message: %s", registration_msg) + await self._signaling_ws.send_json(registration_msg) + self.logger.debug("Registration message sent successfully") + else: + self.logger.warning("Cannot register: signaling websocket is not connected") + + async def _handle_signaling_message(self, message: dict[str, Any]) -> None: + """Handle incoming signaling messages. + + :param message: The signaling message. + """ + msg_type = message.get("type") + self.logger.debug("Received signaling message: %s - Full message: %s", msg_type, message) + + if msg_type == "ping": + # Respond to ping with pong + if self._signaling_ws: + await self._signaling_ws.send_json({"type": "pong"}) + elif msg_type == "pong": + # Server responded to our ping, connection is alive + pass + elif msg_type == "registered": + self.logger.info("Registered with signaling server as: %s", message.get("remoteId")) + if self.on_remote_id_ready: + self.on_remote_id_ready(self.remote_id) + elif msg_type == "error": + self.logger.error( + "Signaling server error: %s", + message.get("message", "Unknown error"), + ) + elif msg_type == "client-connected": + session_id = message.get("sessionId") + if session_id: + await self._create_session(session_id) + elif msg_type == "client-disconnected": + session_id = message.get("sessionId") + if session_id: + await self._close_session(session_id) + elif msg_type == "offer": + session_id = message.get("sessionId") + offer_data = message.get("data") + if session_id and offer_data: + await self._handle_offer(session_id, offer_data) + elif msg_type == "ice-candidate": + session_id = message.get("sessionId") + candidate_data = message.get("data") + if session_id and candidate_data: + await self._handle_ice_candidate(session_id, candidate_data) + + async def _create_session(self, session_id: str) -> None: + """Create a new WebRTC session. + + :param session_id: The session ID. + """ + config = RTCConfiguration( + iceServers=[RTCIceServer(**server) for server in self.ice_servers] + ) + pc = RTCPeerConnection(configuration=config) + session = WebRTCSession(session_id=session_id, peer_connection=pc) + self.sessions[session_id] = session + + @pc.on("datachannel") + def on_datachannel(channel: Any) -> None: + session.data_channel = channel + asyncio.create_task(self._setup_data_channel(session)) + + @pc.on("icecandidate") + async def on_icecandidate(candidate: Any) -> None: + if candidate and self._signaling_ws: + await self._signaling_ws.send_json( + { + "type": "ice-candidate", + "sessionId": session_id, + "data": { + "candidate": candidate.candidate, + "sdpMid": candidate.sdpMid, + "sdpMLineIndex": candidate.sdpMLineIndex, + }, + } + ) + + @pc.on("connectionstatechange") + async def on_connectionstatechange() -> None: + if pc.connectionState == "failed": + await self._close_session(session_id) + + async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None: + """Handle incoming WebRTC offer. + + :param session_id: The session ID. + :param offer: The offer data. + """ + session = self.sessions.get(session_id) + if not session: + return + pc = session.peer_connection + + # Check if peer connection is already closed or closing + if pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Ignoring offer for session %s - connection state: %s", + session_id, + pc.connectionState, + ) + return + + sdp = offer.get("sdp") + sdp_type = offer.get("type") + if not sdp or not sdp_type: + self.logger.error("Invalid offer data: missing sdp or type") + return + + try: + await pc.setRemoteDescription( + RTCSessionDescription( + sdp=str(sdp), + type=str(sdp_type), + ) + ) + + # Check again if session was closed during setRemoteDescription + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed during setRemoteDescription, aborting offer handling", + session_id, + ) + return + + answer = await pc.createAnswer() + + # Check again before setLocalDescription + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed during createAnswer, aborting offer handling", + session_id, + ) + return + + await pc.setLocalDescription(answer) + + # Final check before sending answer + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed during setLocalDescription, skipping answer transmission", + session_id, + ) + return + + if self._signaling_ws: + await self._signaling_ws.send_json( + { + "type": "answer", + "sessionId": session_id, + "data": { + "sdp": pc.localDescription.sdp, + "type": pc.localDescription.type, + }, + } + ) + except Exception: + self.logger.exception("Error handling offer for session %s", session_id) + # Clean up the session on error + await self._close_session(session_id) + + async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None: + """Handle incoming ICE candidate. + + :param session_id: The session ID. + :param candidate: The ICE candidate data. + """ + session = self.sessions.get(session_id) + if not session or not candidate: + return + + # Check if peer connection is already closed or closing + pc = session.peer_connection + if pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Ignoring ICE candidate for session %s - connection state: %s", + session_id, + pc.connectionState, + ) + return + + candidate_str = candidate.get("candidate") + sdp_mid = candidate.get("sdpMid") + sdp_mline_index = candidate.get("sdpMLineIndex") + + if not candidate_str: + return + + # Create RTCIceCandidate from the SDP string + try: + ice_candidate = RTCIceCandidate( + component=1, + foundation="", + ip="", + port=0, + priority=0, + protocol="udp", + type="host", + sdpMid=str(sdp_mid) if sdp_mid else None, + sdpMLineIndex=int(sdp_mline_index) if sdp_mline_index is not None else None, + ) + # Parse the candidate string to populate the fields + ice_candidate.candidate = str(candidate_str) # type: ignore[attr-defined] + + # Check if session was closed before adding candidate + if session_id not in self.sessions or pc.connectionState in ("closed", "failed"): + self.logger.debug( + "Session %s closed before adding ICE candidate, skipping", + session_id, + ) + return + + await session.peer_connection.addIceCandidate(ice_candidate) + except Exception: + self.logger.exception( + "Failed to add ICE candidate for session %s: %s", session_id, candidate + ) + + async def _setup_data_channel(self, session: WebRTCSession) -> None: + """Set up data channel and bridge to local WebSocket. + + :param session: The WebRTC session. + """ + channel = session.data_channel + if not channel: + return + try: + session.local_ws = await self.http_session.ws_connect(self.local_ws_url) + loop = asyncio.get_event_loop() + + # Store task references for proper cleanup + session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session)) + session.forward_from_local_task = asyncio.create_task(self._forward_from_local(session)) + + @channel.on("message") # type: ignore[misc] + def on_message(message: str) -> None: + # Called from aiortc thread, use call_soon_threadsafe + # Only queue message if session is still active + if session.forward_to_local_task and not session.forward_to_local_task.done(): + loop.call_soon_threadsafe(session.message_queue.put_nowait, message) + + @channel.on("close") # type: ignore[misc] + def on_close() -> None: + # Called from aiortc thread, use call_soon_threadsafe to schedule task + asyncio.run_coroutine_threadsafe(self._close_session(session.session_id), loop) + + except Exception: + self.logger.exception("Failed to connect to local WebSocket") + + async def _forward_to_local(self, session: WebRTCSession) -> None: + """Forward messages from WebRTC DataChannel to local WebSocket. + + :param session: The WebRTC session. + """ + try: + while session.local_ws and not session.local_ws.closed: + message = await session.message_queue.get() + + # Check if this is an HTTP proxy request + try: + msg_data = json.loads(message) + if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request": + # Handle HTTP proxy request + await self._handle_http_proxy_request(session, msg_data) + continue + except (json.JSONDecodeError, ValueError): + pass + + # Regular WebSocket message + if session.local_ws and not session.local_ws.closed: + await session.local_ws.send_str(message) + except asyncio.CancelledError: + # Task was cancelled during cleanup, this is expected + self.logger.debug("Forward to local task cancelled for session %s", session.session_id) + raise + except Exception: + self.logger.exception("Error forwarding to local WebSocket") + + async def _forward_from_local(self, session: WebRTCSession) -> None: + """Forward messages from local WebSocket to WebRTC DataChannel. + + :param session: The WebRTC session. + """ + try: + async for msg in session.local_ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if session.data_channel and session.data_channel.readyState == "open": + session.data_channel.send(msg.data) + elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): + break + except asyncio.CancelledError: + # Task was cancelled during cleanup, this is expected + self.logger.debug( + "Forward from local task cancelled for session %s", session.session_id + ) + raise + except Exception: + self.logger.exception("Error forwarding from local WebSocket") + + async def _handle_http_proxy_request( + self, session: WebRTCSession, request_data: dict[str, Any] + ) -> None: + """Handle HTTP proxy request from remote client. + + :param session: The WebRTC session. + :param request_data: The HTTP proxy request data. + """ + request_id = request_data.get("id") + method = request_data.get("method", "GET") + path = request_data.get("path", "/") + headers = request_data.get("headers", {}) + + # Build local HTTP URL + # Extract host and port from local_ws_url (ws://localhost:8095/ws) + ws_url_parts = self.local_ws_url.replace("ws://", "").split("/") + host_port = ws_url_parts[0] # localhost:8095 + local_http_url = f"http://{host_port}{path}" + + self.logger.debug("HTTP proxy request: %s %s", method, local_http_url) + + try: + # Use shared HTTP session for this request + async with self.http_session.request( + method, local_http_url, headers=headers + ) as response: + # Read response body + body = await response.read() + + # Prepare response data + response_data = { + "type": "http-proxy-response", + "id": request_id, + "status": response.status, + "headers": dict(response.headers), + "body": body.hex(), # Send as hex string to avoid encoding issues + } + + # Send response back through data channel + if session.data_channel and session.data_channel.readyState == "open": + session.data_channel.send(json.dumps(response_data)) + + except Exception as err: + self.logger.exception("Error handling HTTP proxy request") + # Send error response + error_response = { + "type": "http-proxy-response", + "id": request_id, + "status": 500, + "headers": {"Content-Type": "text/plain"}, + "body": str(err).encode().hex(), + } + if session.data_channel and session.data_channel.readyState == "open": + session.data_channel.send(json.dumps(error_response)) + + async def _close_session(self, session_id: str) -> None: + """Close a WebRTC session. + + :param session_id: The session ID. + """ + session = self.sessions.pop(session_id, None) + if not session: + return + + # Cancel forwarding tasks first to prevent race conditions + if session.forward_to_local_task and not session.forward_to_local_task.done(): + session.forward_to_local_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await session.forward_to_local_task + + if session.forward_from_local_task and not session.forward_from_local_task.done(): + session.forward_from_local_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await session.forward_from_local_task + + # Close connections + if session.local_ws and not session.local_ws.closed: + await session.local_ws.close() + if session.data_channel: + session.data_channel.close() + await session.peer_connection.close() diff --git a/music_assistant/providers/remote_access/manifest.json b/music_assistant/providers/remote_access/manifest.json new file mode 100644 index 00000000..a6600c87 --- /dev/null +++ b/music_assistant/providers/remote_access/manifest.json @@ -0,0 +1,13 @@ +{ + "type": "plugin", + "domain": "remote_access", + "stage": "alpha", + "name": "Remote Access", + "description": "WebRTC-based encrypted, secure remote access for connecting to Music Assistant from outside your local network (requires Home Assistant Cloud subscription for the best experience).", + "codeowners": ["@music-assistant"], + "icon": "cloud-lock", + "multi_instance": false, + "builtin": true, + "allow_disable": true, + "requirements": ["aiortc>=1.6.0"] +} diff --git a/pyproject.toml b/pyproject.toml index 6cb01d94..36a557c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ dependencies = [ "librosa==0.11.0", "gql[all]==4.0.0", "aiovban>=0.6.3", - "aiortc>=1.6.0", ] description = "Music Assistant" license = {text = "Apache-2.0"}