Fix snapweb control script
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Sun, 14 Dec 2025 20:16:19 +0000 (21:16 +0100)
committerMarcel van der Veldt <m.vanderveldt@outlook.com>
Sun, 14 Dec 2025 20:16:19 +0000 (21:16 +0100)
music_assistant/providers/snapcast/constants.py
music_assistant/providers/snapcast/control.py
music_assistant/providers/snapcast/player.py
music_assistant/providers/snapcast/provider.py
music_assistant/providers/snapcast/socket_server.py [new file with mode: 0644]

index ad58fd41dfc2c1eb4c6c42c9e3bfe6b91fac1182..eb5ecb1ab3f84841b97c816333d5d45fa38e471f 100644 (file)
@@ -36,6 +36,10 @@ DEFAULT_SNAPSERVER_IP = "127.0.0.1"
 DEFAULT_SNAPSERVER_PORT = 1705
 DEFAULT_SNAPSTREAM_IDLE_THRESHOLD = 60000
 
+# Socket path template for control script communication
+# The {queue_id} placeholder will be replaced with the actual queue ID
+CONTROL_SOCKET_PATH_TEMPLATE = "/tmp/ma-snapcast-{queue_id}.sock"  # noqa: S108
+
 MASS_STREAM_PREFIX = "Music Assistant - "
 MASS_ANNOUNCEMENT_POSTFIX = " (announcement)"
 SNAPWEB_DIR = pathlib.Path(__file__).parent.resolve().joinpath("snapweb")
index 2a7c6da6826932138a119fdc3c89745d51c7e860..125bdbb94ccb7d8632f414f886ddbfebb58dc5de 100755 (executable)
@@ -2,13 +2,14 @@
 """
 Control Music Assistant Snapcast plugin.
 
-This script is a bridge between Music Assistant and Snapcast
-It listens to the MA websocket and sends metadata to Snapcast
-and listens for player commands
+This script is a bridge between Music Assistant and Snapcast.
+It connects to Music Assistant via a Unix socket and sends metadata to Snapcast
+and listens for player commands.
 """
 
 import json
 import logging
+import socket
 import sys
 import threading
 import urllib.parse
@@ -17,7 +18,6 @@ from time import sleep
 from typing import Any
 
 import shortuuid
-import websocket
 
 LOOP_STATUS_MAP = {
     "all": "playlist",
@@ -37,37 +37,36 @@ def send(json_msg: dict[str, Any]) -> None:
 
 
 class MusicAssistantControl:
-    """Music Assistant websocket remote control Snapcast plugin."""
+    """Music Assistant Unix socket remote control Snapcast plugin."""
 
     def __init__(
-        self, queue_id: str, streamserver_ip: str, streamserver_port: int, api_port: int
+        self,
+        queue_id: str,
+        socket_path: str,
+        streamserver_ip: str,
+        streamserver_port: int,
     ) -> None:
         """Initialize."""
         self.queue_id = queue_id
-        self.api_port = api_port
+        self.socket_path = socket_path
         self.streamserver_ip = streamserver_ip
         self.streamserver_port = streamserver_port
         self._metadata: dict[str, Any] = {}
         self._properties: dict[str, Any] = {}
         self._request_callbacks: dict[str, MessageCallback] = {}
         self._seek_offset = 0.0
-        self.websocket = websocket.WebSocketApp(
-            url=f"ws://localhost:{api_port}/ws",
-            on_message=self._on_ws_message,
-            on_error=self._on_ws_error,
-            on_open=self._on_ws_open,
-            on_close=self._on_ws_close,
-        )
+        self._socket: socket.socket | None = None
         self._stopped = False
-        self.websocket_thread = threading.Thread(target=self._websocket_loop, args=())
-        self.websocket_thread.name = "massControl"
-        self.websocket_thread.start()
+        self._socket_thread = threading.Thread(target=self._socket_loop, args=())
+        self._socket_thread.name = "massControl"
+        self._socket_thread.start()
 
     def stop(self) -> None:
-        """Stop the websocket thread."""
+        """Stop the socket thread."""
         self._stopped = True
-        self.websocket.close()
-        self.websocket_thread.join()
+        if self._socket:
+            self._socket.close()
+        self._socket_thread.join()
 
     def handle_snapcast_request(self, request: dict[str, Any]) -> None:
         """Handle (JSON RPC) message from Snapcast."""
@@ -114,7 +113,7 @@ class MusicAssistantControl:
                 self.send_request("player_queues/skip", queue_id=queue_id, seconds=seek_offset)
         elif cmd == "SetProperty":
             properties = request["params"]
-            logger.debug(f"SetProperty: {property}")
+            logger.debug(f"SetProperty: {properties}")
             if "shuffle" in properties:
                 self.send_request(
                     "player_queues/shuffle",
@@ -173,14 +172,75 @@ class MusicAssistantControl:
         """Send stream ready notification to Snapcast."""
         send({"jsonrpc": "2.0", "method": "Plugin.Stream.Ready"})
 
-    def _websocket_loop(self) -> None:
-        logger.info("Started websocket loop")
+    def _socket_loop(self) -> None:
+        logger.info("Started socket loop")
         while not self._stopped:
             try:
-                self.websocket.run_forever()
-                sleep(2)
+                self._connect_and_read()
             except (Exception, KeyboardInterrupt) as e:
-                logger.info(f"Exception: {e!s}")
+                logger.info(f"Exception in socket loop: {e!s}")
+                if not self._stopped:
+                    sleep(2)
+
+    def _connect_and_read(self) -> None:
+        """Connect to the Unix socket and read messages."""
+        logger.info("Connecting to Unix socket: %s", self.socket_path)
+        self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        try:
+            self._socket.connect(self.socket_path)
+            logger.info("Connected to Unix socket")
+            self.send_snapcast_stream_ready_notification()
+
+            # Read messages from socket
+            buffer = ""
+            while not self._stopped:
+                try:
+                    data = self._socket.recv(4096)
+                    if not data:
+                        logger.info("Socket closed by server")
+                        break
+                    buffer += data.decode()
+
+                    # Process complete lines
+                    while "\n" in buffer:
+                        line, buffer = buffer.split("\n", 1)
+                        if line.strip():
+                            self._handle_socket_message(line)
+                except TimeoutError:
+                    continue
+                except OSError as e:
+                    logger.error(f"Socket error: {e}")
+                    break
+        finally:
+            if self._socket:
+                self._socket.close()
+                self._socket = None
+
+    def _handle_socket_message(self, message: str) -> None:
+        """Handle a message from the Music Assistant socket."""
+        logger.debug("Socket message received: %s", message)
+        try:
+            data = json.loads(message)
+        except json.JSONDecodeError as e:
+            logger.error(f"Invalid JSON: {e}")
+            return
+
+        # Request response
+        if "message_id" in data:
+            message_id = data["message_id"]
+            if callback := self._request_callbacks.pop(message_id, None):
+                if result := data.get("result"):
+                    callback(result)
+                # TODO: handle failed requests
+            return
+
+        # Event
+        if "event" in data and data.get("object_id") == self.queue_id:
+            event = data["event"]
+            if event == "queue_updated":
+                properties = self._create_properties(data["data"])
+                self.send_snapcast_properties_notification(properties)
+                return
 
     def _create_properties(self, mass_queue_details: dict[str, Any]) -> dict[str, Any]:
         """Create snapcast properties from Music Assistant queue details."""
@@ -235,45 +295,14 @@ class MusicAssistantControl:
 
         return properties
 
-    def _on_ws_message(self, ws: websocket.WebSocket, message: str) -> None:
-        # TODO: error handling
-        logger.debug("websocket message received: %s", message)
-        data = json.loads(message)
-
-        # Request response
-        if "message_id" in data:
-            message_id = data["message_id"]
-            if callback := self._request_callbacks.pop(message_id, None):
-                if result := data.get("result"):
-                    callback(result)
-                # TODO: handle failed requests
-            return
-
-        # Event
-        if "event" in data and data["object_id"] == self.queue_id:
-            event = data["event"]
-            if event == "queue_updated":
-                properties = self._create_properties(data["data"])
-                self.send_snapcast_properties_notification(properties)
-                return
-
-    def _on_ws_error(self, ws: websocket.WebSocket, error: Exception | str) -> None:
-        logger.error("Websocket error")
-        logger.error(error)
-
-    def _on_ws_open(self, ws: websocket.WebSocket) -> None:
-        logger.info("Snapcast RPC websocket opened")
-        self.send_snapcast_stream_ready_notification()
-
-    def _on_ws_close(
-        self, ws: websocket.WebSocket, close_status_code: int | None, close_msg: str | None
-    ) -> None:
-        logger.info("Snapcast RPC websocket closed")
-
     def send_request(
-        self, command: str, callback: MessageCallback | None = None, **args: str | float
+        self, command: str, callback: MessageCallback | None = None, **args: str | float | bool
     ) -> None:
-        """Send request to Music Assistant."""
+        """Send request to Music Assistant via Unix socket."""
+        if not self._socket:
+            logger.warning("Cannot send request - socket not connected")
+            return
+
         msg_id = shortuuid.random(10)
         command_msg = {
             "message_id": msg_id,
@@ -283,13 +312,18 @@ class MusicAssistantControl:
         logger.debug("send_request: %s", command_msg)
         if callback:
             self._request_callbacks[msg_id] = callback
-        self.websocket.send(json.dumps(command_msg))
+        try:
+            data = json.dumps(command_msg) + "\n"
+            self._socket.sendall(data.encode())
+        except OSError as e:
+            logger.error(f"Failed to send request: {e}")
+            self._request_callbacks.pop(msg_id, None)
 
 
 if __name__ == "__main__":
     # Parse command line
     queue_id = None
-    api_port = None
+    socket_path: str | None = None
     streamserver_ip: str | None = None
     streamserver_port: str | None = None
     stream_id: str | None = None
@@ -298,15 +332,15 @@ if __name__ == "__main__":
             stream_id = arg.split("=")[1]
         if arg.startswith("--queueid="):
             queue_id = arg.split("=")[1]
+        if arg.startswith("--socket="):
+            socket_path = arg.split("=")[1]
         if arg.startswith("--streamserver-ip="):
             streamserver_ip = arg.split("=")[1]
         if arg.startswith("--streamserver-port="):
             streamserver_port = arg.split("=")[1]
-        if arg.startswith("--api-port="):
-            api_port = arg.split("=")[1]
 
-    if not queue_id or not api_port:
-        print("Usage: --stream=<stream_id> --api_port=<api_port>")  # noqa: T201
+    if not queue_id or not socket_path:
+        print("Usage: --stream=<stream_id> --socket=<socket_path>")  # noqa: T201
         sys.exit()
 
     log_format_stderr = "%(asctime)s %(module)s %(levelname)s: %(message)s"
@@ -321,12 +355,12 @@ if __name__ == "__main__":
     logger.addHandler(log_handler)
 
     logger.debug(
-        "Initializing for stream_id %s, queue_id %s and api_port %s", stream_id, queue_id, api_port
+        "Initializing for stream_id %s, queue_id %s and socket %s", stream_id, queue_id, socket_path
     )
 
     assert streamserver_ip is not None  # for type checking
     assert streamserver_port is not None
-    ctrl = MusicAssistantControl(queue_id, streamserver_ip, int(streamserver_port), int(api_port))
+    ctrl = MusicAssistantControl(queue_id, socket_path, streamserver_ip, int(streamserver_port))
 
     # keep listening for messages on stdin and forward them
     try:
index dd7e2974707117f9bfa0492f7336b3758d916fcf..011250e6373f2300390487f1519e4068476b7c5a 100644 (file)
@@ -349,20 +349,21 @@ class SnapCastPlayer(Player):
             return stream
         # The control script is used only for music streams in the builtin server
         # (queue_id is None only for announcement streams).
+        extra_args = ""
         if (
             self.provider._use_builtin_server
             and queue_id
             and self.provider._controlscript_available
         ):
+            # Create socket server for control script communication
+            socket_path = await self.provider.get_or_create_socket_server(queue_id)
             extra_args = (
                 f"&controlscript={urllib.parse.quote_plus('control.py')}"
                 f"&controlscriptparams=--queueid={urllib.parse.quote_plus(queue_id)}%20"
-                f"--api-port={self.mass.webserver.publish_port}%20"
+                f"--socket={urllib.parse.quote_plus(socket_path)}%20"
                 f"--streamserver-ip={self.mass.streams.publish_ip}%20"
                 f"--streamserver-port={self.mass.streams.publish_port}"
             )
-        else:
-            extra_args = ""
 
         attempts = 50
         while attempts:
index 9f44435b8bbe0d368d11f65fc71966f8ddc853d4..94c197fc7bcfd1b84e58cb102b81e56cdffefe3a 100644 (file)
@@ -32,10 +32,12 @@ from music_assistant.providers.snapcast.constants import (
     CONF_STREAM_IDLE_THRESHOLD,
     CONF_USE_EXTERNAL_SERVER,
     CONTROL_SCRIPT,
+    CONTROL_SOCKET_PATH_TEMPLATE,
     DEFAULT_SNAPSERVER_PORT,
     SNAPWEB_DIR,
 )
 from music_assistant.providers.snapcast.player import SnapCastPlayer
+from music_assistant.providers.snapcast.socket_server import SnapcastSocketServer
 
 
 class SnapCastProvider(PlayerProvider):
@@ -50,6 +52,7 @@ class SnapCastProvider(PlayerProvider):
     _use_builtin_server: bool
     _stop_called: bool
     _controlscript_available: bool
+    _socket_servers: dict[str, SnapcastSocketServer]  # queue_id -> socket server
 
     async def handle_async_init(self) -> None:
         """Handle async initialization of the provider."""
@@ -58,6 +61,7 @@ class SnapCastProvider(PlayerProvider):
         self._use_builtin_server = not self.config.get_value(CONF_USE_EXTERNAL_SERVER)
         self._stop_called = False
         self._controlscript_available = False
+        self._socket_servers = {}
         if self._use_builtin_server:
             self._snapcast_server_host = "127.0.0.1"
             self._snapcast_server_control_port = DEFAULT_SNAPSERVER_PORT
@@ -111,6 +115,10 @@ class SnapCastProvider(PlayerProvider):
     async def unload(self, is_removed: bool = False) -> None:
         """Handle close/cleanup of the provider."""
         self._stop_called = True
+        # Stop all socket servers
+        for socket_server in list(self._socket_servers.values()):
+            await socket_server.stop()
+        self._socket_servers.clear()
         for snap_client in self._snapserver.clients:
             player_id = self._get_ma_id(snap_client.identifier)
             if not (player := self.mass.players.get(player_id, raise_unavailable=False)):
@@ -327,3 +335,35 @@ class SnapCastProvider(PlayerProvider):
             self.logger.debug("Snapclient removed %s", player_id)
         else:
             self.logger.warning("Unable to remove snapclient %s: %s", player_id, error_msg)
+
+    async def get_or_create_socket_server(self, queue_id: str) -> str:
+        """Get or create a socket server for the given queue.
+
+        :param queue_id: The queue ID to create a socket server for.
+        :return: The path to the Unix socket.
+        """
+        if queue_id in self._socket_servers:
+            return self._socket_servers[queue_id].socket_path
+
+        socket_path = CONTROL_SOCKET_PATH_TEMPLATE.format(queue_id=queue_id)
+        socket_server = SnapcastSocketServer(
+            mass=self.mass,
+            queue_id=queue_id,
+            socket_path=socket_path,
+            streamserver_ip=str(self.mass.streams.publish_ip),
+            streamserver_port=cast("int", self.mass.streams.publish_port),
+        )
+        await socket_server.start()
+        self._socket_servers[queue_id] = socket_server
+        self.logger.debug("Created socket server for queue %s at %s", queue_id, socket_path)
+        return socket_path
+
+    async def stop_socket_server(self, queue_id: str) -> None:
+        """Stop and remove the socket server for the given queue.
+
+        :param queue_id: The queue ID to stop the socket server for.
+        """
+        if queue_id in self._socket_servers:
+            await self._socket_servers[queue_id].stop()
+            del self._socket_servers[queue_id]
+            self.logger.debug("Stopped socket server for queue %s", queue_id)
diff --git a/music_assistant/providers/snapcast/socket_server.py b/music_assistant/providers/snapcast/socket_server.py
new file mode 100644 (file)
index 0000000..04c5a2c
--- /dev/null
@@ -0,0 +1,242 @@
+"""Unix socket server for Snapcast control script communication.
+
+This module provides a secure communication channel between the Snapcast control script
+and Music Assistant, avoiding the need to expose the WebSocket API to the control script.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+from contextlib import suppress
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+from music_assistant_models.enums import EventType
+
+if TYPE_CHECKING:
+    from music_assistant.mass import MusicAssistant
+
+LOGGER = logging.getLogger(__name__)
+
+LOOP_STATUS_MAP = {
+    "all": "playlist",
+    "one": "track",
+    "off": "none",
+}
+LOOP_STATUS_MAP_REVERSE = {v: k for k, v in LOOP_STATUS_MAP.items()}
+
+
+class SnapcastSocketServer:
+    """Unix socket server for a single Snapcast control script connection.
+
+    Each stream gets its own socket server instance to handle control script communication.
+    The socket provides a secure IPC channel that doesn't require authentication since
+    only local processes can connect.
+    """
+
+    def __init__(
+        self,
+        mass: MusicAssistant,
+        queue_id: str,
+        socket_path: str,
+        streamserver_ip: str,
+        streamserver_port: int,
+    ) -> None:
+        """Initialize the socket server.
+
+        :param mass: The MusicAssistant instance.
+        :param queue_id: The queue ID this socket serves.
+        :param socket_path: Path to the Unix socket file.
+        :param streamserver_ip: IP address of the stream server (for image proxy).
+        :param streamserver_port: Port of the stream server (for image proxy).
+        """
+        self.mass = mass
+        self.queue_id = queue_id
+        self.socket_path = socket_path
+        self.streamserver_ip = streamserver_ip
+        self.streamserver_port = streamserver_port
+        self._server: asyncio.AbstractServer | None = None
+        self._client_writer: asyncio.StreamWriter | None = None
+        self._unsub_callback: Any = None
+        self._logger = LOGGER.getChild(queue_id)
+
+    async def start(self) -> None:
+        """Start the Unix socket server."""
+        # Ensure the socket file doesn't exist
+        socket_path = Path(self.socket_path)
+        socket_path.unlink(missing_ok=True)
+
+        # Create the socket server
+        self._server = await asyncio.start_unix_server(
+            self._handle_client,
+            path=self.socket_path,
+        )
+        # Set permissions so only the current user can access
+        Path(self.socket_path).chmod(0o600)
+        self._logger.debug("Started Unix socket server at %s", self.socket_path)
+
+        # Subscribe to queue events
+        self._unsub_callback = self.mass.subscribe(
+            self._handle_mass_event,
+            (EventType.QUEUE_UPDATED,),
+            self.queue_id,
+        )
+
+    async def stop(self) -> None:
+        """Stop the Unix socket server."""
+        if self._unsub_callback:
+            self._unsub_callback()
+            self._unsub_callback = None
+
+        if self._client_writer:
+            self._client_writer.close()
+            with suppress(Exception):
+                await self._client_writer.wait_closed()
+            self._client_writer = None
+
+        if self._server:
+            self._server.close()
+            await self._server.wait_closed()
+            self._server = None
+
+        # Clean up socket file
+        Path(self.socket_path).unlink(missing_ok=True)
+        self._logger.debug("Stopped Unix socket server")
+
+    async def _handle_client(
+        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
+    ) -> None:
+        """Handle a control script connection."""
+        self._logger.debug("Control script connected")
+        self._client_writer = writer
+
+        try:
+            while True:
+                line = await reader.readline()
+                if not line:
+                    break
+
+                try:
+                    message = json.loads(line.decode().strip())
+                    await self._handle_message(message)
+                except json.JSONDecodeError as err:
+                    self._logger.warning("Invalid JSON from control script: %s", err)
+                except Exception as err:
+                    self._logger.exception("Error handling control script message: %s", err)
+        except asyncio.CancelledError:
+            pass
+        except ConnectionResetError:
+            self._logger.debug("Control script connection reset")
+        finally:
+            self._client_writer = None
+            writer.close()
+            with suppress(Exception):
+                await writer.wait_closed()
+            self._logger.debug("Control script disconnected")
+
+    async def _handle_message(self, message: dict[str, Any]) -> None:
+        """Handle a message from the control script.
+
+        :param message: The JSON message from the control script.
+        """
+        msg_id = message.get("message_id")
+        command = message.get("command")
+        args = message.get("args", {})
+
+        if not command:
+            await self._send_error(msg_id, "Missing command")
+            return
+
+        try:
+            result = await self._execute_command(command, args)
+            await self._send_result(msg_id, result)
+        except Exception as err:
+            self._logger.exception("Error executing command %s: %s", command, err)
+            await self._send_error(msg_id, str(err))
+
+    async def _execute_command(self, command: str, args: dict[str, Any]) -> Any:
+        """Execute a Music Assistant API command.
+
+        :param command: The API command to execute.
+        :param args: The arguments for the command.
+        :return: The result of the command.
+        """
+        handler = self.mass.command_handlers.get(command)
+        if handler is None:
+            raise ValueError(f"Unknown command: {command}")
+
+        # Execute the handler
+        result = handler.target(**args)
+        if asyncio.iscoroutine(result):
+            result = await result
+        return result
+
+    async def _send_result(self, msg_id: str | None, result: Any) -> None:
+        """Send a success result to the control script.
+
+        :param msg_id: The message ID from the request.
+        :param result: The result data.
+        """
+        if not self._client_writer:
+            return
+
+        response: dict[str, Any] = {"message_id": msg_id}
+        if result is not None:
+            # Convert result to dict if it has to_dict method
+            if hasattr(result, "to_dict"):
+                response["result"] = result.to_dict()
+            else:
+                response["result"] = result
+
+        await self._send_message(response)
+
+    async def _send_error(self, msg_id: str | None, error: str) -> None:
+        """Send an error result to the control script.
+
+        :param msg_id: The message ID from the request.
+        :param error: The error message.
+        """
+        if not self._client_writer:
+            return
+
+        response = {
+            "message_id": msg_id,
+            "error": error,
+        }
+        await self._send_message(response)
+
+    async def _send_message(self, message: dict[str, Any]) -> None:
+        """Send a message to the control script.
+
+        :param message: The message to send.
+        """
+        if not self._client_writer:
+            return
+
+        try:
+            data = json.dumps(message) + "\n"
+            self._client_writer.write(data.encode())
+            await self._client_writer.drain()
+        except (ConnectionResetError, BrokenPipeError):
+            self._logger.debug("Failed to send message - connection closed")
+            self._client_writer = None
+
+    def _handle_mass_event(self, event: Any) -> None:
+        """Handle Music Assistant events and forward to control script.
+
+        :param event: The Music Assistant event.
+        """
+        if not self._client_writer:
+            return
+
+        # Forward queue_updated events
+        if event.event == EventType.QUEUE_UPDATED and event.object_id == self.queue_id:
+            event_msg = {
+                "event": "queue_updated",
+                "object_id": event.object_id,
+                "data": event.data.to_dict() if hasattr(event.data, "to_dict") else event.data,
+            }
+            # Schedule the send in the event loop
+            asyncio.create_task(self._send_message(event_msg))