From: Marcel van der Veldt Date: Sun, 14 Dec 2025 20:16:19 +0000 (+0100) Subject: Fix snapweb control script X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=d140b69306f3aa8cd70fa6d8d49c3d043f3cc738;p=music-assistant-server.git Fix snapweb control script --- diff --git a/music_assistant/providers/snapcast/constants.py b/music_assistant/providers/snapcast/constants.py index ad58fd41..eb5ecb1a 100644 --- a/music_assistant/providers/snapcast/constants.py +++ b/music_assistant/providers/snapcast/constants.py @@ -36,6 +36,10 @@ DEFAULT_SNAPSERVER_IP = "127.0.0.1" DEFAULT_SNAPSERVER_PORT = 1705 DEFAULT_SNAPSTREAM_IDLE_THRESHOLD = 60000 +# Socket path template for control script communication +# The {queue_id} placeholder will be replaced with the actual queue ID +CONTROL_SOCKET_PATH_TEMPLATE = "/tmp/ma-snapcast-{queue_id}.sock" # noqa: S108 + MASS_STREAM_PREFIX = "Music Assistant - " MASS_ANNOUNCEMENT_POSTFIX = " (announcement)" SNAPWEB_DIR = pathlib.Path(__file__).parent.resolve().joinpath("snapweb") diff --git a/music_assistant/providers/snapcast/control.py b/music_assistant/providers/snapcast/control.py index 2a7c6da6..125bdbb9 100755 --- a/music_assistant/providers/snapcast/control.py +++ b/music_assistant/providers/snapcast/control.py @@ -2,13 +2,14 @@ """ Control Music Assistant Snapcast plugin. -This script is a bridge between Music Assistant and Snapcast -It listens to the MA websocket and sends metadata to Snapcast -and listens for player commands +This script is a bridge between Music Assistant and Snapcast. +It connects to Music Assistant via a Unix socket and sends metadata to Snapcast +and listens for player commands. """ import json import logging +import socket import sys import threading import urllib.parse @@ -17,7 +18,6 @@ from time import sleep from typing import Any import shortuuid -import websocket LOOP_STATUS_MAP = { "all": "playlist", @@ -37,37 +37,36 @@ def send(json_msg: dict[str, Any]) -> None: class MusicAssistantControl: - """Music Assistant websocket remote control Snapcast plugin.""" + """Music Assistant Unix socket remote control Snapcast plugin.""" def __init__( - self, queue_id: str, streamserver_ip: str, streamserver_port: int, api_port: int + self, + queue_id: str, + socket_path: str, + streamserver_ip: str, + streamserver_port: int, ) -> None: """Initialize.""" self.queue_id = queue_id - self.api_port = api_port + self.socket_path = socket_path self.streamserver_ip = streamserver_ip self.streamserver_port = streamserver_port self._metadata: dict[str, Any] = {} self._properties: dict[str, Any] = {} self._request_callbacks: dict[str, MessageCallback] = {} self._seek_offset = 0.0 - self.websocket = websocket.WebSocketApp( - url=f"ws://localhost:{api_port}/ws", - on_message=self._on_ws_message, - on_error=self._on_ws_error, - on_open=self._on_ws_open, - on_close=self._on_ws_close, - ) + self._socket: socket.socket | None = None self._stopped = False - self.websocket_thread = threading.Thread(target=self._websocket_loop, args=()) - self.websocket_thread.name = "massControl" - self.websocket_thread.start() + self._socket_thread = threading.Thread(target=self._socket_loop, args=()) + self._socket_thread.name = "massControl" + self._socket_thread.start() def stop(self) -> None: - """Stop the websocket thread.""" + """Stop the socket thread.""" self._stopped = True - self.websocket.close() - self.websocket_thread.join() + if self._socket: + self._socket.close() + self._socket_thread.join() def handle_snapcast_request(self, request: dict[str, Any]) -> None: """Handle (JSON RPC) message from Snapcast.""" @@ -114,7 +113,7 @@ class MusicAssistantControl: self.send_request("player_queues/skip", queue_id=queue_id, seconds=seek_offset) elif cmd == "SetProperty": properties = request["params"] - logger.debug(f"SetProperty: {property}") + logger.debug(f"SetProperty: {properties}") if "shuffle" in properties: self.send_request( "player_queues/shuffle", @@ -173,14 +172,75 @@ class MusicAssistantControl: """Send stream ready notification to Snapcast.""" send({"jsonrpc": "2.0", "method": "Plugin.Stream.Ready"}) - def _websocket_loop(self) -> None: - logger.info("Started websocket loop") + def _socket_loop(self) -> None: + logger.info("Started socket loop") while not self._stopped: try: - self.websocket.run_forever() - sleep(2) + self._connect_and_read() except (Exception, KeyboardInterrupt) as e: - logger.info(f"Exception: {e!s}") + logger.info(f"Exception in socket loop: {e!s}") + if not self._stopped: + sleep(2) + + def _connect_and_read(self) -> None: + """Connect to the Unix socket and read messages.""" + logger.info("Connecting to Unix socket: %s", self.socket_path) + self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + self._socket.connect(self.socket_path) + logger.info("Connected to Unix socket") + self.send_snapcast_stream_ready_notification() + + # Read messages from socket + buffer = "" + while not self._stopped: + try: + data = self._socket.recv(4096) + if not data: + logger.info("Socket closed by server") + break + buffer += data.decode() + + # Process complete lines + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + if line.strip(): + self._handle_socket_message(line) + except TimeoutError: + continue + except OSError as e: + logger.error(f"Socket error: {e}") + break + finally: + if self._socket: + self._socket.close() + self._socket = None + + def _handle_socket_message(self, message: str) -> None: + """Handle a message from the Music Assistant socket.""" + logger.debug("Socket message received: %s", message) + try: + data = json.loads(message) + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON: {e}") + return + + # Request response + if "message_id" in data: + message_id = data["message_id"] + if callback := self._request_callbacks.pop(message_id, None): + if result := data.get("result"): + callback(result) + # TODO: handle failed requests + return + + # Event + if "event" in data and data.get("object_id") == self.queue_id: + event = data["event"] + if event == "queue_updated": + properties = self._create_properties(data["data"]) + self.send_snapcast_properties_notification(properties) + return def _create_properties(self, mass_queue_details: dict[str, Any]) -> dict[str, Any]: """Create snapcast properties from Music Assistant queue details.""" @@ -235,45 +295,14 @@ class MusicAssistantControl: return properties - def _on_ws_message(self, ws: websocket.WebSocket, message: str) -> None: - # TODO: error handling - logger.debug("websocket message received: %s", message) - data = json.loads(message) - - # Request response - if "message_id" in data: - message_id = data["message_id"] - if callback := self._request_callbacks.pop(message_id, None): - if result := data.get("result"): - callback(result) - # TODO: handle failed requests - return - - # Event - if "event" in data and data["object_id"] == self.queue_id: - event = data["event"] - if event == "queue_updated": - properties = self._create_properties(data["data"]) - self.send_snapcast_properties_notification(properties) - return - - def _on_ws_error(self, ws: websocket.WebSocket, error: Exception | str) -> None: - logger.error("Websocket error") - logger.error(error) - - def _on_ws_open(self, ws: websocket.WebSocket) -> None: - logger.info("Snapcast RPC websocket opened") - self.send_snapcast_stream_ready_notification() - - def _on_ws_close( - self, ws: websocket.WebSocket, close_status_code: int | None, close_msg: str | None - ) -> None: - logger.info("Snapcast RPC websocket closed") - def send_request( - self, command: str, callback: MessageCallback | None = None, **args: str | float + self, command: str, callback: MessageCallback | None = None, **args: str | float | bool ) -> None: - """Send request to Music Assistant.""" + """Send request to Music Assistant via Unix socket.""" + if not self._socket: + logger.warning("Cannot send request - socket not connected") + return + msg_id = shortuuid.random(10) command_msg = { "message_id": msg_id, @@ -283,13 +312,18 @@ class MusicAssistantControl: logger.debug("send_request: %s", command_msg) if callback: self._request_callbacks[msg_id] = callback - self.websocket.send(json.dumps(command_msg)) + try: + data = json.dumps(command_msg) + "\n" + self._socket.sendall(data.encode()) + except OSError as e: + logger.error(f"Failed to send request: {e}") + self._request_callbacks.pop(msg_id, None) if __name__ == "__main__": # Parse command line queue_id = None - api_port = None + socket_path: str | None = None streamserver_ip: str | None = None streamserver_port: str | None = None stream_id: str | None = None @@ -298,15 +332,15 @@ if __name__ == "__main__": stream_id = arg.split("=")[1] if arg.startswith("--queueid="): queue_id = arg.split("=")[1] + if arg.startswith("--socket="): + socket_path = arg.split("=")[1] if arg.startswith("--streamserver-ip="): streamserver_ip = arg.split("=")[1] if arg.startswith("--streamserver-port="): streamserver_port = arg.split("=")[1] - if arg.startswith("--api-port="): - api_port = arg.split("=")[1] - if not queue_id or not api_port: - print("Usage: --stream= --api_port=") # noqa: T201 + if not queue_id or not socket_path: + print("Usage: --stream= --socket=") # noqa: T201 sys.exit() log_format_stderr = "%(asctime)s %(module)s %(levelname)s: %(message)s" @@ -321,12 +355,12 @@ if __name__ == "__main__": logger.addHandler(log_handler) logger.debug( - "Initializing for stream_id %s, queue_id %s and api_port %s", stream_id, queue_id, api_port + "Initializing for stream_id %s, queue_id %s and socket %s", stream_id, queue_id, socket_path ) assert streamserver_ip is not None # for type checking assert streamserver_port is not None - ctrl = MusicAssistantControl(queue_id, streamserver_ip, int(streamserver_port), int(api_port)) + ctrl = MusicAssistantControl(queue_id, socket_path, streamserver_ip, int(streamserver_port)) # keep listening for messages on stdin and forward them try: diff --git a/music_assistant/providers/snapcast/player.py b/music_assistant/providers/snapcast/player.py index dd7e2974..011250e6 100644 --- a/music_assistant/providers/snapcast/player.py +++ b/music_assistant/providers/snapcast/player.py @@ -349,20 +349,21 @@ class SnapCastPlayer(Player): return stream # The control script is used only for music streams in the builtin server # (queue_id is None only for announcement streams). + extra_args = "" if ( self.provider._use_builtin_server and queue_id and self.provider._controlscript_available ): + # Create socket server for control script communication + socket_path = await self.provider.get_or_create_socket_server(queue_id) extra_args = ( f"&controlscript={urllib.parse.quote_plus('control.py')}" f"&controlscriptparams=--queueid={urllib.parse.quote_plus(queue_id)}%20" - f"--api-port={self.mass.webserver.publish_port}%20" + f"--socket={urllib.parse.quote_plus(socket_path)}%20" f"--streamserver-ip={self.mass.streams.publish_ip}%20" f"--streamserver-port={self.mass.streams.publish_port}" ) - else: - extra_args = "" attempts = 50 while attempts: diff --git a/music_assistant/providers/snapcast/provider.py b/music_assistant/providers/snapcast/provider.py index 9f44435b..94c197fc 100644 --- a/music_assistant/providers/snapcast/provider.py +++ b/music_assistant/providers/snapcast/provider.py @@ -32,10 +32,12 @@ from music_assistant.providers.snapcast.constants import ( CONF_STREAM_IDLE_THRESHOLD, CONF_USE_EXTERNAL_SERVER, CONTROL_SCRIPT, + CONTROL_SOCKET_PATH_TEMPLATE, DEFAULT_SNAPSERVER_PORT, SNAPWEB_DIR, ) from music_assistant.providers.snapcast.player import SnapCastPlayer +from music_assistant.providers.snapcast.socket_server import SnapcastSocketServer class SnapCastProvider(PlayerProvider): @@ -50,6 +52,7 @@ class SnapCastProvider(PlayerProvider): _use_builtin_server: bool _stop_called: bool _controlscript_available: bool + _socket_servers: dict[str, SnapcastSocketServer] # queue_id -> socket server async def handle_async_init(self) -> None: """Handle async initialization of the provider.""" @@ -58,6 +61,7 @@ class SnapCastProvider(PlayerProvider): self._use_builtin_server = not self.config.get_value(CONF_USE_EXTERNAL_SERVER) self._stop_called = False self._controlscript_available = False + self._socket_servers = {} if self._use_builtin_server: self._snapcast_server_host = "127.0.0.1" self._snapcast_server_control_port = DEFAULT_SNAPSERVER_PORT @@ -111,6 +115,10 @@ class SnapCastProvider(PlayerProvider): async def unload(self, is_removed: bool = False) -> None: """Handle close/cleanup of the provider.""" self._stop_called = True + # Stop all socket servers + for socket_server in list(self._socket_servers.values()): + await socket_server.stop() + self._socket_servers.clear() for snap_client in self._snapserver.clients: player_id = self._get_ma_id(snap_client.identifier) if not (player := self.mass.players.get(player_id, raise_unavailable=False)): @@ -327,3 +335,35 @@ class SnapCastProvider(PlayerProvider): self.logger.debug("Snapclient removed %s", player_id) else: self.logger.warning("Unable to remove snapclient %s: %s", player_id, error_msg) + + async def get_or_create_socket_server(self, queue_id: str) -> str: + """Get or create a socket server for the given queue. + + :param queue_id: The queue ID to create a socket server for. + :return: The path to the Unix socket. + """ + if queue_id in self._socket_servers: + return self._socket_servers[queue_id].socket_path + + socket_path = CONTROL_SOCKET_PATH_TEMPLATE.format(queue_id=queue_id) + socket_server = SnapcastSocketServer( + mass=self.mass, + queue_id=queue_id, + socket_path=socket_path, + streamserver_ip=str(self.mass.streams.publish_ip), + streamserver_port=cast("int", self.mass.streams.publish_port), + ) + await socket_server.start() + self._socket_servers[queue_id] = socket_server + self.logger.debug("Created socket server for queue %s at %s", queue_id, socket_path) + return socket_path + + async def stop_socket_server(self, queue_id: str) -> None: + """Stop and remove the socket server for the given queue. + + :param queue_id: The queue ID to stop the socket server for. + """ + if queue_id in self._socket_servers: + await self._socket_servers[queue_id].stop() + del self._socket_servers[queue_id] + self.logger.debug("Stopped socket server for queue %s", queue_id) diff --git a/music_assistant/providers/snapcast/socket_server.py b/music_assistant/providers/snapcast/socket_server.py new file mode 100644 index 00000000..04c5a2cf --- /dev/null +++ b/music_assistant/providers/snapcast/socket_server.py @@ -0,0 +1,242 @@ +"""Unix socket server for Snapcast control script communication. + +This module provides a secure communication channel between the Snapcast control script +and Music Assistant, avoiding the need to expose the WebSocket API to the control script. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from contextlib import suppress +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from music_assistant_models.enums import EventType + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + +LOGGER = logging.getLogger(__name__) + +LOOP_STATUS_MAP = { + "all": "playlist", + "one": "track", + "off": "none", +} +LOOP_STATUS_MAP_REVERSE = {v: k for k, v in LOOP_STATUS_MAP.items()} + + +class SnapcastSocketServer: + """Unix socket server for a single Snapcast control script connection. + + Each stream gets its own socket server instance to handle control script communication. + The socket provides a secure IPC channel that doesn't require authentication since + only local processes can connect. + """ + + def __init__( + self, + mass: MusicAssistant, + queue_id: str, + socket_path: str, + streamserver_ip: str, + streamserver_port: int, + ) -> None: + """Initialize the socket server. + + :param mass: The MusicAssistant instance. + :param queue_id: The queue ID this socket serves. + :param socket_path: Path to the Unix socket file. + :param streamserver_ip: IP address of the stream server (for image proxy). + :param streamserver_port: Port of the stream server (for image proxy). + """ + self.mass = mass + self.queue_id = queue_id + self.socket_path = socket_path + self.streamserver_ip = streamserver_ip + self.streamserver_port = streamserver_port + self._server: asyncio.AbstractServer | None = None + self._client_writer: asyncio.StreamWriter | None = None + self._unsub_callback: Any = None + self._logger = LOGGER.getChild(queue_id) + + async def start(self) -> None: + """Start the Unix socket server.""" + # Ensure the socket file doesn't exist + socket_path = Path(self.socket_path) + socket_path.unlink(missing_ok=True) + + # Create the socket server + self._server = await asyncio.start_unix_server( + self._handle_client, + path=self.socket_path, + ) + # Set permissions so only the current user can access + Path(self.socket_path).chmod(0o600) + self._logger.debug("Started Unix socket server at %s", self.socket_path) + + # Subscribe to queue events + self._unsub_callback = self.mass.subscribe( + self._handle_mass_event, + (EventType.QUEUE_UPDATED,), + self.queue_id, + ) + + async def stop(self) -> None: + """Stop the Unix socket server.""" + if self._unsub_callback: + self._unsub_callback() + self._unsub_callback = None + + if self._client_writer: + self._client_writer.close() + with suppress(Exception): + await self._client_writer.wait_closed() + self._client_writer = None + + if self._server: + self._server.close() + await self._server.wait_closed() + self._server = None + + # Clean up socket file + Path(self.socket_path).unlink(missing_ok=True) + self._logger.debug("Stopped Unix socket server") + + async def _handle_client( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + """Handle a control script connection.""" + self._logger.debug("Control script connected") + self._client_writer = writer + + try: + while True: + line = await reader.readline() + if not line: + break + + try: + message = json.loads(line.decode().strip()) + await self._handle_message(message) + except json.JSONDecodeError as err: + self._logger.warning("Invalid JSON from control script: %s", err) + except Exception as err: + self._logger.exception("Error handling control script message: %s", err) + except asyncio.CancelledError: + pass + except ConnectionResetError: + self._logger.debug("Control script connection reset") + finally: + self._client_writer = None + writer.close() + with suppress(Exception): + await writer.wait_closed() + self._logger.debug("Control script disconnected") + + async def _handle_message(self, message: dict[str, Any]) -> None: + """Handle a message from the control script. + + :param message: The JSON message from the control script. + """ + msg_id = message.get("message_id") + command = message.get("command") + args = message.get("args", {}) + + if not command: + await self._send_error(msg_id, "Missing command") + return + + try: + result = await self._execute_command(command, args) + await self._send_result(msg_id, result) + except Exception as err: + self._logger.exception("Error executing command %s: %s", command, err) + await self._send_error(msg_id, str(err)) + + async def _execute_command(self, command: str, args: dict[str, Any]) -> Any: + """Execute a Music Assistant API command. + + :param command: The API command to execute. + :param args: The arguments for the command. + :return: The result of the command. + """ + handler = self.mass.command_handlers.get(command) + if handler is None: + raise ValueError(f"Unknown command: {command}") + + # Execute the handler + result = handler.target(**args) + if asyncio.iscoroutine(result): + result = await result + return result + + async def _send_result(self, msg_id: str | None, result: Any) -> None: + """Send a success result to the control script. + + :param msg_id: The message ID from the request. + :param result: The result data. + """ + if not self._client_writer: + return + + response: dict[str, Any] = {"message_id": msg_id} + if result is not None: + # Convert result to dict if it has to_dict method + if hasattr(result, "to_dict"): + response["result"] = result.to_dict() + else: + response["result"] = result + + await self._send_message(response) + + async def _send_error(self, msg_id: str | None, error: str) -> None: + """Send an error result to the control script. + + :param msg_id: The message ID from the request. + :param error: The error message. + """ + if not self._client_writer: + return + + response = { + "message_id": msg_id, + "error": error, + } + await self._send_message(response) + + async def _send_message(self, message: dict[str, Any]) -> None: + """Send a message to the control script. + + :param message: The message to send. + """ + if not self._client_writer: + return + + try: + data = json.dumps(message) + "\n" + self._client_writer.write(data.encode()) + await self._client_writer.drain() + except (ConnectionResetError, BrokenPipeError): + self._logger.debug("Failed to send message - connection closed") + self._client_writer = None + + def _handle_mass_event(self, event: Any) -> None: + """Handle Music Assistant events and forward to control script. + + :param event: The Music Assistant event. + """ + if not self._client_writer: + return + + # Forward queue_updated events + if event.event == EventType.QUEUE_UPDATED and event.object_id == self.queue_id: + event_msg = { + "event": "queue_updated", + "object_id": event.object_id, + "data": event.data.to_dict() if hasattr(event.data, "to_dict") else event.data, + } + # Schedule the send in the event loop + asyncio.create_task(self._send_message(event_msg))