DEFAULT_SNAPSERVER_PORT = 1705
DEFAULT_SNAPSTREAM_IDLE_THRESHOLD = 60000
+# Socket path template for control script communication
+# The {queue_id} placeholder will be replaced with the actual queue ID
+CONTROL_SOCKET_PATH_TEMPLATE = "/tmp/ma-snapcast-{queue_id}.sock" # noqa: S108
+
MASS_STREAM_PREFIX = "Music Assistant - "
MASS_ANNOUNCEMENT_POSTFIX = " (announcement)"
SNAPWEB_DIR = pathlib.Path(__file__).parent.resolve().joinpath("snapweb")
"""
Control Music Assistant Snapcast plugin.
-This script is a bridge between Music Assistant and Snapcast
-It listens to the MA websocket and sends metadata to Snapcast
-and listens for player commands
+This script is a bridge between Music Assistant and Snapcast.
+It connects to Music Assistant via a Unix socket and sends metadata to Snapcast
+and listens for player commands.
"""
import json
import logging
+import socket
import sys
import threading
import urllib.parse
from typing import Any
import shortuuid
-import websocket
LOOP_STATUS_MAP = {
"all": "playlist",
class MusicAssistantControl:
- """Music Assistant websocket remote control Snapcast plugin."""
+ """Music Assistant Unix socket remote control Snapcast plugin."""
def __init__(
- self, queue_id: str, streamserver_ip: str, streamserver_port: int, api_port: int
+ self,
+ queue_id: str,
+ socket_path: str,
+ streamserver_ip: str,
+ streamserver_port: int,
) -> None:
"""Initialize."""
self.queue_id = queue_id
- self.api_port = api_port
+ self.socket_path = socket_path
self.streamserver_ip = streamserver_ip
self.streamserver_port = streamserver_port
self._metadata: dict[str, Any] = {}
self._properties: dict[str, Any] = {}
self._request_callbacks: dict[str, MessageCallback] = {}
self._seek_offset = 0.0
- self.websocket = websocket.WebSocketApp(
- url=f"ws://localhost:{api_port}/ws",
- on_message=self._on_ws_message,
- on_error=self._on_ws_error,
- on_open=self._on_ws_open,
- on_close=self._on_ws_close,
- )
+ self._socket: socket.socket | None = None
self._stopped = False
- self.websocket_thread = threading.Thread(target=self._websocket_loop, args=())
- self.websocket_thread.name = "massControl"
- self.websocket_thread.start()
+ self._socket_thread = threading.Thread(target=self._socket_loop, args=())
+ self._socket_thread.name = "massControl"
+ self._socket_thread.start()
def stop(self) -> None:
- """Stop the websocket thread."""
+ """Stop the socket thread."""
self._stopped = True
- self.websocket.close()
- self.websocket_thread.join()
+ if self._socket:
+ self._socket.close()
+ self._socket_thread.join()
def handle_snapcast_request(self, request: dict[str, Any]) -> None:
"""Handle (JSON RPC) message from Snapcast."""
self.send_request("player_queues/skip", queue_id=queue_id, seconds=seek_offset)
elif cmd == "SetProperty":
properties = request["params"]
- logger.debug(f"SetProperty: {property}")
+ logger.debug(f"SetProperty: {properties}")
if "shuffle" in properties:
self.send_request(
"player_queues/shuffle",
"""Send stream ready notification to Snapcast."""
send({"jsonrpc": "2.0", "method": "Plugin.Stream.Ready"})
- def _websocket_loop(self) -> None:
- logger.info("Started websocket loop")
+ def _socket_loop(self) -> None:
+ logger.info("Started socket loop")
while not self._stopped:
try:
- self.websocket.run_forever()
- sleep(2)
+ self._connect_and_read()
except (Exception, KeyboardInterrupt) as e:
- logger.info(f"Exception: {e!s}")
+ logger.info(f"Exception in socket loop: {e!s}")
+ if not self._stopped:
+ sleep(2)
+
+ def _connect_and_read(self) -> None:
+ """Connect to the Unix socket and read messages."""
+ logger.info("Connecting to Unix socket: %s", self.socket_path)
+ self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ try:
+ self._socket.connect(self.socket_path)
+ logger.info("Connected to Unix socket")
+ self.send_snapcast_stream_ready_notification()
+
+ # Read messages from socket
+ buffer = ""
+ while not self._stopped:
+ try:
+ data = self._socket.recv(4096)
+ if not data:
+ logger.info("Socket closed by server")
+ break
+ buffer += data.decode()
+
+ # Process complete lines
+ while "\n" in buffer:
+ line, buffer = buffer.split("\n", 1)
+ if line.strip():
+ self._handle_socket_message(line)
+ except TimeoutError:
+ continue
+ except OSError as e:
+ logger.error(f"Socket error: {e}")
+ break
+ finally:
+ if self._socket:
+ self._socket.close()
+ self._socket = None
+
+ def _handle_socket_message(self, message: str) -> None:
+ """Handle a message from the Music Assistant socket."""
+ logger.debug("Socket message received: %s", message)
+ try:
+ data = json.loads(message)
+ except json.JSONDecodeError as e:
+ logger.error(f"Invalid JSON: {e}")
+ return
+
+ # Request response
+ if "message_id" in data:
+ message_id = data["message_id"]
+ if callback := self._request_callbacks.pop(message_id, None):
+ if result := data.get("result"):
+ callback(result)
+ # TODO: handle failed requests
+ return
+
+ # Event
+ if "event" in data and data.get("object_id") == self.queue_id:
+ event = data["event"]
+ if event == "queue_updated":
+ properties = self._create_properties(data["data"])
+ self.send_snapcast_properties_notification(properties)
+ return
def _create_properties(self, mass_queue_details: dict[str, Any]) -> dict[str, Any]:
"""Create snapcast properties from Music Assistant queue details."""
return properties
- def _on_ws_message(self, ws: websocket.WebSocket, message: str) -> None:
- # TODO: error handling
- logger.debug("websocket message received: %s", message)
- data = json.loads(message)
-
- # Request response
- if "message_id" in data:
- message_id = data["message_id"]
- if callback := self._request_callbacks.pop(message_id, None):
- if result := data.get("result"):
- callback(result)
- # TODO: handle failed requests
- return
-
- # Event
- if "event" in data and data["object_id"] == self.queue_id:
- event = data["event"]
- if event == "queue_updated":
- properties = self._create_properties(data["data"])
- self.send_snapcast_properties_notification(properties)
- return
-
- def _on_ws_error(self, ws: websocket.WebSocket, error: Exception | str) -> None:
- logger.error("Websocket error")
- logger.error(error)
-
- def _on_ws_open(self, ws: websocket.WebSocket) -> None:
- logger.info("Snapcast RPC websocket opened")
- self.send_snapcast_stream_ready_notification()
-
- def _on_ws_close(
- self, ws: websocket.WebSocket, close_status_code: int | None, close_msg: str | None
- ) -> None:
- logger.info("Snapcast RPC websocket closed")
-
def send_request(
- self, command: str, callback: MessageCallback | None = None, **args: str | float
+ self, command: str, callback: MessageCallback | None = None, **args: str | float | bool
) -> None:
- """Send request to Music Assistant."""
+ """Send request to Music Assistant via Unix socket."""
+ if not self._socket:
+ logger.warning("Cannot send request - socket not connected")
+ return
+
msg_id = shortuuid.random(10)
command_msg = {
"message_id": msg_id,
logger.debug("send_request: %s", command_msg)
if callback:
self._request_callbacks[msg_id] = callback
- self.websocket.send(json.dumps(command_msg))
+ try:
+ data = json.dumps(command_msg) + "\n"
+ self._socket.sendall(data.encode())
+ except OSError as e:
+ logger.error(f"Failed to send request: {e}")
+ self._request_callbacks.pop(msg_id, None)
if __name__ == "__main__":
# Parse command line
queue_id = None
- api_port = None
+ socket_path: str | None = None
streamserver_ip: str | None = None
streamserver_port: str | None = None
stream_id: str | None = None
stream_id = arg.split("=")[1]
if arg.startswith("--queueid="):
queue_id = arg.split("=")[1]
+ if arg.startswith("--socket="):
+ socket_path = arg.split("=")[1]
if arg.startswith("--streamserver-ip="):
streamserver_ip = arg.split("=")[1]
if arg.startswith("--streamserver-port="):
streamserver_port = arg.split("=")[1]
- if arg.startswith("--api-port="):
- api_port = arg.split("=")[1]
- if not queue_id or not api_port:
- print("Usage: --stream=<stream_id> --api_port=<api_port>") # noqa: T201
+ if not queue_id or not socket_path:
+ print("Usage: --stream=<stream_id> --socket=<socket_path>") # noqa: T201
sys.exit()
log_format_stderr = "%(asctime)s %(module)s %(levelname)s: %(message)s"
logger.addHandler(log_handler)
logger.debug(
- "Initializing for stream_id %s, queue_id %s and api_port %s", stream_id, queue_id, api_port
+ "Initializing for stream_id %s, queue_id %s and socket %s", stream_id, queue_id, socket_path
)
assert streamserver_ip is not None # for type checking
assert streamserver_port is not None
- ctrl = MusicAssistantControl(queue_id, streamserver_ip, int(streamserver_port), int(api_port))
+ ctrl = MusicAssistantControl(queue_id, socket_path, streamserver_ip, int(streamserver_port))
# keep listening for messages on stdin and forward them
try:
return stream
# The control script is used only for music streams in the builtin server
# (queue_id is None only for announcement streams).
+ extra_args = ""
if (
self.provider._use_builtin_server
and queue_id
and self.provider._controlscript_available
):
+ # Create socket server for control script communication
+ socket_path = await self.provider.get_or_create_socket_server(queue_id)
extra_args = (
f"&controlscript={urllib.parse.quote_plus('control.py')}"
f"&controlscriptparams=--queueid={urllib.parse.quote_plus(queue_id)}%20"
- f"--api-port={self.mass.webserver.publish_port}%20"
+ f"--socket={urllib.parse.quote_plus(socket_path)}%20"
f"--streamserver-ip={self.mass.streams.publish_ip}%20"
f"--streamserver-port={self.mass.streams.publish_port}"
)
- else:
- extra_args = ""
attempts = 50
while attempts:
CONF_STREAM_IDLE_THRESHOLD,
CONF_USE_EXTERNAL_SERVER,
CONTROL_SCRIPT,
+ CONTROL_SOCKET_PATH_TEMPLATE,
DEFAULT_SNAPSERVER_PORT,
SNAPWEB_DIR,
)
from music_assistant.providers.snapcast.player import SnapCastPlayer
+from music_assistant.providers.snapcast.socket_server import SnapcastSocketServer
class SnapCastProvider(PlayerProvider):
_use_builtin_server: bool
_stop_called: bool
_controlscript_available: bool
+ _socket_servers: dict[str, SnapcastSocketServer] # queue_id -> socket server
async def handle_async_init(self) -> None:
"""Handle async initialization of the provider."""
self._use_builtin_server = not self.config.get_value(CONF_USE_EXTERNAL_SERVER)
self._stop_called = False
self._controlscript_available = False
+ self._socket_servers = {}
if self._use_builtin_server:
self._snapcast_server_host = "127.0.0.1"
self._snapcast_server_control_port = DEFAULT_SNAPSERVER_PORT
async def unload(self, is_removed: bool = False) -> None:
"""Handle close/cleanup of the provider."""
self._stop_called = True
+ # Stop all socket servers
+ for socket_server in list(self._socket_servers.values()):
+ await socket_server.stop()
+ self._socket_servers.clear()
for snap_client in self._snapserver.clients:
player_id = self._get_ma_id(snap_client.identifier)
if not (player := self.mass.players.get(player_id, raise_unavailable=False)):
self.logger.debug("Snapclient removed %s", player_id)
else:
self.logger.warning("Unable to remove snapclient %s: %s", player_id, error_msg)
+
+ async def get_or_create_socket_server(self, queue_id: str) -> str:
+ """Get or create a socket server for the given queue.
+
+ :param queue_id: The queue ID to create a socket server for.
+ :return: The path to the Unix socket.
+ """
+ if queue_id in self._socket_servers:
+ return self._socket_servers[queue_id].socket_path
+
+ socket_path = CONTROL_SOCKET_PATH_TEMPLATE.format(queue_id=queue_id)
+ socket_server = SnapcastSocketServer(
+ mass=self.mass,
+ queue_id=queue_id,
+ socket_path=socket_path,
+ streamserver_ip=str(self.mass.streams.publish_ip),
+ streamserver_port=cast("int", self.mass.streams.publish_port),
+ )
+ await socket_server.start()
+ self._socket_servers[queue_id] = socket_server
+ self.logger.debug("Created socket server for queue %s at %s", queue_id, socket_path)
+ return socket_path
+
+ async def stop_socket_server(self, queue_id: str) -> None:
+ """Stop and remove the socket server for the given queue.
+
+ :param queue_id: The queue ID to stop the socket server for.
+ """
+ if queue_id in self._socket_servers:
+ await self._socket_servers[queue_id].stop()
+ del self._socket_servers[queue_id]
+ self.logger.debug("Stopped socket server for queue %s", queue_id)
--- /dev/null
+"""Unix socket server for Snapcast control script communication.
+
+This module provides a secure communication channel between the Snapcast control script
+and Music Assistant, avoiding the need to expose the WebSocket API to the control script.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+from contextlib import suppress
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+from music_assistant_models.enums import EventType
+
+if TYPE_CHECKING:
+ from music_assistant.mass import MusicAssistant
+
+LOGGER = logging.getLogger(__name__)
+
+LOOP_STATUS_MAP = {
+ "all": "playlist",
+ "one": "track",
+ "off": "none",
+}
+LOOP_STATUS_MAP_REVERSE = {v: k for k, v in LOOP_STATUS_MAP.items()}
+
+
+class SnapcastSocketServer:
+ """Unix socket server for a single Snapcast control script connection.
+
+ Each stream gets its own socket server instance to handle control script communication.
+ The socket provides a secure IPC channel that doesn't require authentication since
+ only local processes can connect.
+ """
+
+ def __init__(
+ self,
+ mass: MusicAssistant,
+ queue_id: str,
+ socket_path: str,
+ streamserver_ip: str,
+ streamserver_port: int,
+ ) -> None:
+ """Initialize the socket server.
+
+ :param mass: The MusicAssistant instance.
+ :param queue_id: The queue ID this socket serves.
+ :param socket_path: Path to the Unix socket file.
+ :param streamserver_ip: IP address of the stream server (for image proxy).
+ :param streamserver_port: Port of the stream server (for image proxy).
+ """
+ self.mass = mass
+ self.queue_id = queue_id
+ self.socket_path = socket_path
+ self.streamserver_ip = streamserver_ip
+ self.streamserver_port = streamserver_port
+ self._server: asyncio.AbstractServer | None = None
+ self._client_writer: asyncio.StreamWriter | None = None
+ self._unsub_callback: Any = None
+ self._logger = LOGGER.getChild(queue_id)
+
+ async def start(self) -> None:
+ """Start the Unix socket server."""
+ # Ensure the socket file doesn't exist
+ socket_path = Path(self.socket_path)
+ socket_path.unlink(missing_ok=True)
+
+ # Create the socket server
+ self._server = await asyncio.start_unix_server(
+ self._handle_client,
+ path=self.socket_path,
+ )
+ # Set permissions so only the current user can access
+ Path(self.socket_path).chmod(0o600)
+ self._logger.debug("Started Unix socket server at %s", self.socket_path)
+
+ # Subscribe to queue events
+ self._unsub_callback = self.mass.subscribe(
+ self._handle_mass_event,
+ (EventType.QUEUE_UPDATED,),
+ self.queue_id,
+ )
+
+ async def stop(self) -> None:
+ """Stop the Unix socket server."""
+ if self._unsub_callback:
+ self._unsub_callback()
+ self._unsub_callback = None
+
+ if self._client_writer:
+ self._client_writer.close()
+ with suppress(Exception):
+ await self._client_writer.wait_closed()
+ self._client_writer = None
+
+ if self._server:
+ self._server.close()
+ await self._server.wait_closed()
+ self._server = None
+
+ # Clean up socket file
+ Path(self.socket_path).unlink(missing_ok=True)
+ self._logger.debug("Stopped Unix socket server")
+
+ async def _handle_client(
+ self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
+ ) -> None:
+ """Handle a control script connection."""
+ self._logger.debug("Control script connected")
+ self._client_writer = writer
+
+ try:
+ while True:
+ line = await reader.readline()
+ if not line:
+ break
+
+ try:
+ message = json.loads(line.decode().strip())
+ await self._handle_message(message)
+ except json.JSONDecodeError as err:
+ self._logger.warning("Invalid JSON from control script: %s", err)
+ except Exception as err:
+ self._logger.exception("Error handling control script message: %s", err)
+ except asyncio.CancelledError:
+ pass
+ except ConnectionResetError:
+ self._logger.debug("Control script connection reset")
+ finally:
+ self._client_writer = None
+ writer.close()
+ with suppress(Exception):
+ await writer.wait_closed()
+ self._logger.debug("Control script disconnected")
+
+ async def _handle_message(self, message: dict[str, Any]) -> None:
+ """Handle a message from the control script.
+
+ :param message: The JSON message from the control script.
+ """
+ msg_id = message.get("message_id")
+ command = message.get("command")
+ args = message.get("args", {})
+
+ if not command:
+ await self._send_error(msg_id, "Missing command")
+ return
+
+ try:
+ result = await self._execute_command(command, args)
+ await self._send_result(msg_id, result)
+ except Exception as err:
+ self._logger.exception("Error executing command %s: %s", command, err)
+ await self._send_error(msg_id, str(err))
+
+ async def _execute_command(self, command: str, args: dict[str, Any]) -> Any:
+ """Execute a Music Assistant API command.
+
+ :param command: The API command to execute.
+ :param args: The arguments for the command.
+ :return: The result of the command.
+ """
+ handler = self.mass.command_handlers.get(command)
+ if handler is None:
+ raise ValueError(f"Unknown command: {command}")
+
+ # Execute the handler
+ result = handler.target(**args)
+ if asyncio.iscoroutine(result):
+ result = await result
+ return result
+
+ async def _send_result(self, msg_id: str | None, result: Any) -> None:
+ """Send a success result to the control script.
+
+ :param msg_id: The message ID from the request.
+ :param result: The result data.
+ """
+ if not self._client_writer:
+ return
+
+ response: dict[str, Any] = {"message_id": msg_id}
+ if result is not None:
+ # Convert result to dict if it has to_dict method
+ if hasattr(result, "to_dict"):
+ response["result"] = result.to_dict()
+ else:
+ response["result"] = result
+
+ await self._send_message(response)
+
+ async def _send_error(self, msg_id: str | None, error: str) -> None:
+ """Send an error result to the control script.
+
+ :param msg_id: The message ID from the request.
+ :param error: The error message.
+ """
+ if not self._client_writer:
+ return
+
+ response = {
+ "message_id": msg_id,
+ "error": error,
+ }
+ await self._send_message(response)
+
+ async def _send_message(self, message: dict[str, Any]) -> None:
+ """Send a message to the control script.
+
+ :param message: The message to send.
+ """
+ if not self._client_writer:
+ return
+
+ try:
+ data = json.dumps(message) + "\n"
+ self._client_writer.write(data.encode())
+ await self._client_writer.drain()
+ except (ConnectionResetError, BrokenPipeError):
+ self._logger.debug("Failed to send message - connection closed")
+ self._client_writer = None
+
+ def _handle_mass_event(self, event: Any) -> None:
+ """Handle Music Assistant events and forward to control script.
+
+ :param event: The Music Assistant event.
+ """
+ if not self._client_writer:
+ return
+
+ # Forward queue_updated events
+ if event.event == EventType.QUEUE_UPDATED and event.object_id == self.queue_id:
+ event_msg = {
+ "event": "queue_updated",
+ "object_id": event.object_id,
+ "data": event.data.to_dict() if hasattr(event.data, "to_dict") else event.data,
+ }
+ # Schedule the send in the event loop
+ asyncio.create_task(self._send_message(event_msg))