From ae67a44a7f585d8d508c1630b13e15de585fa0f9 Mon Sep 17 00:00:00 2001 From: Mischa Siekmann <45062894+gnumpi@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:52:29 +0100 Subject: [PATCH] Snapcast: Stop the control scripts gracefully before shutting down the built-in snapcast server (#3092) gracefully stop control scripts before shutting down the snapcast server --- music_assistant/providers/snapcast/control.py | 26 ++++++++-- .../providers/snapcast/provider.py | 48 +++++++++++-------- .../providers/snapcast/socket_server.py | 11 +++++ 3 files changed, 62 insertions(+), 23 deletions(-) diff --git a/music_assistant/providers/snapcast/control.py b/music_assistant/providers/snapcast/control.py index 648df1d8..bfeeba61 100755 --- a/music_assistant/providers/snapcast/control.py +++ b/music_assistant/providers/snapcast/control.py @@ -14,6 +14,7 @@ import sys import threading import urllib.parse from collections.abc import Callable +from contextlib import suppress from time import sleep from typing import Any @@ -57,6 +58,7 @@ class MusicAssistantControl: self._seek_offset = 0.0 self._socket: socket.socket | None = None self._stopped = False + self._shutdown_event = threading.Event() self._socket_thread = threading.Thread(target=self._socket_loop, args=()) self._socket_thread.name = "massControl" self._socket_thread.start() @@ -65,8 +67,16 @@ class MusicAssistantControl: """Stop the socket thread.""" self._stopped = True if self._socket: - self._socket.close() - self._socket_thread.join() + with suppress(OSError): + self._socket.close() + if threading.current_thread() is not self._socket_thread: + self._socket_thread.join() + + def shutdown(self) -> None: + """Exit the control script.""" + logger.info("Shutdown requested by server") + self.stop() + self._shutdown_event.set() def handle_snapcast_request(self, request: dict[str, Any]) -> None: """Handle (JSON RPC) message from Snapcast.""" @@ -225,6 +235,10 @@ class MusicAssistantControl: logger.error(f"Invalid JSON: {e}") return + if data.get("command") == "shutdown": + self.shutdown() + return + # Request response if "message_id" in data: message_id = data["message_id"] @@ -364,7 +378,10 @@ if __name__ == "__main__": # keep listening for messages on stdin and forward them try: - for line in sys.stdin: + while not ctrl._shutdown_event.is_set(): + line = sys.stdin.readline() + if not line: # EOF + break try: ctrl.handle_snapcast_request(json.loads(line)) except Exception as e: @@ -375,5 +392,6 @@ if __name__ == "__main__": "id": id, } ) - except (SystemExit, KeyboardInterrupt): + finally: + ctrl.stop() sys.exit(0) diff --git a/music_assistant/providers/snapcast/provider.py b/music_assistant/providers/snapcast/provider.py index 7d465d8c..6ef28602 100644 --- a/music_assistant/providers/snapcast/provider.py +++ b/music_assistant/providers/snapcast/provider.py @@ -115,10 +115,7 @@ class SnapCastProvider(PlayerProvider): async def unload(self, is_removed: bool = False) -> None: """Handle close/cleanup of the provider.""" self._stop_called = True - # Stop all socket servers - for socket_server in list(self._socket_servers.values()): - await socket_server.stop() - self._socket_servers.clear() + for snap_client in self._snapserver.clients: player_id = self._get_ma_id(snap_client.identifier) if not (player := self.mass.players.get(player_id, raise_unavailable=False)): @@ -231,21 +228,34 @@ class SnapCastProvider(PlayerProvider): f"--streaming_client.initial_volume={self._snapcast_server_initial_volume}", ] async with AsyncProcess(args, stdout=True, name="snapserver") as snapserver_proc: - # keep reading from stdout until exit - async for raw_data in snapserver_proc.iter_any(): - text = raw_data.decode().strip() - for line in text.split("\n"): - logger.debug(line) - if "(Snapserver) Version 0." in line: - # delay init a small bit to prevent race conditions - # where we try to connect too soon - self.mass.loop.call_later(2, self._snapserver_started.set) - # Copy control script after snapserver starts - # (run in executor to avoid blocking) - loop = asyncio.get_running_loop() - self._controlscript_available = await loop.run_in_executor( - None, self._setup_controlscript - ) + try: + # keep reading from stdout until exit + async for raw_data in snapserver_proc.iter_any(): + text = raw_data.decode().strip() + for line in text.split("\n"): + logger.debug(line) + if "(Snapserver) Version 0." in line: + # delay init a small bit to prevent race conditions + # where we try to connect too soon + self.mass.loop.call_later(2, self._snapserver_started.set) + # Copy control script after snapserver starts + # (run in executor to avoid blocking) + loop = asyncio.get_running_loop() + self._controlscript_available = await loop.run_in_executor( + None, self._setup_controlscript + ) + except asyncio.CancelledError: + # Currently, MA doesn't guarantee a defined shutdown order; + # Make sure to close socket servers before + # shutting down the snapcast server. + # + # The snapserver doesn't always cleanup the control script processes + # properly. We do it explicitly when closing a socket server. + # Should be fixed on the server side, though. + for socket_server in list(self._socket_servers.values()): + await socket_server.stop() + self._socket_servers.clear() + raise def _get_ma_id(self, snap_client_id: str) -> str: search_dict = self._ids_map.inverse diff --git a/music_assistant/providers/snapcast/socket_server.py b/music_assistant/providers/snapcast/socket_server.py index 3f744a52..67295a87 100644 --- a/music_assistant/providers/snapcast/socket_server.py +++ b/music_assistant/providers/snapcast/socket_server.py @@ -92,6 +92,8 @@ class SnapcastSocketServer: self._unsub_callback = None if self._client_writer: + with suppress(Exception): + await self.notify_shutdown() self._client_writer.close() with suppress(Exception): await self._client_writer.wait_closed() @@ -106,6 +108,15 @@ class SnapcastSocketServer: Path(self.socket_path).unlink(missing_ok=True) self._logger.debug("Stopped Unix socket server") + async def notify_shutdown(self) -> None: + """Tell the control script to exit.""" + await self._send_message( + { + "event": "shutdown", + "object_id": self.queue_id, + } + ) + async def _handle_client( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: -- 2.34.1