Snapcast: Reload built-in server in case of connection loss (#1442)
authorSantiago Soto <santiago@soto.uy>
Thu, 4 Jul 2024 21:49:03 +0000 (18:49 -0300)
committerGitHub <noreply@github.com>
Thu, 4 Jul 2024 21:49:03 +0000 (23:49 +0200)
music_assistant/server/providers/snapcast/__init__.py

index 8d33af7db112ca8d015fae5eb0bd2c67df726648..73d829e91509a45f538d1dec63ddaafa834a564e 100644 (file)
@@ -229,6 +229,7 @@ class SnapCastProvider(PlayerProvider):
     _snapserver_runner: asyncio.Task | None
     _snapserver_started: asyncio.Event | None
     _ids_map: bidict  # ma_id / snapclient_id
+    _builtin_server_retry: int
 
     def _get_snapclient_id(self, player_id: str) -> str:
         search_dict = self._ids_map
@@ -282,12 +283,10 @@ class SnapCastProvider(PlayerProvider):
         self._snapcast_stream_dryout_ms = self.config.get_value(CONF_SERVER_DRYOUT_MS)
         self._stream_tasks = {}
         self._ids_map = bidict({})
+        self._builtin_server_retry = 0
 
         if self._use_builtin_server:
-            # start our own builtin snapserver
-            self._snapserver_started = asyncio.Event()
-            self._snapserver_runner = asyncio.create_task(self._builtin_server_runner())
-            await asyncio.wait_for(self._snapserver_started.wait(), 10)
+            await self._start_builtin_server()
         else:
             self._snapserver_runner = None
             self._snapserver_started = None
@@ -303,6 +302,9 @@ class SnapCastProvider(PlayerProvider):
                 "Started connection to Snapserver %s",
                 f"{self._snapcast_server_host}:{self._snapcast_server_control_port}",
             )
+            if self._use_builtin_server:
+                self._snapserver.set_on_disconnect_callback(self._restart_builtin_server)
+
         except OSError as err:
             msg = "Unable to start the Snapserver connection ?"
             raise SetupFailedError(msg) from err
@@ -318,10 +320,7 @@ class SnapCastProvider(PlayerProvider):
             player_id = self._get_ma_id(snap_client_id)
             await self.cmd_stop(player_id)
         self._snapserver.stop()
-        if self._snapserver_runner and not self._snapserver_runner.done():
-            self._snapserver_runner.cancel()
-        await asyncio.sleep(10)  # prevent race conditions when reloading
-        self._snapserver_started.clear()
+        await self._stop_builtin_server()
 
     def on_player_config_removed(self, player_id: str) -> None:
         """Call (by config manager) when the configuration of a player is removed."""
@@ -670,3 +669,29 @@ class SnapCastProvider(PlayerProvider):
                         # delay init a small bit to prevent race conditions
                         # where we try to connect too soon
                         self.mass.loop.call_later(2, self._snapserver_started.set)
+
+    async def _restart_builtin_server(self) -> None:
+        """Restart the built-in Snapserver."""
+        if self._use_builtin_server and self._builtin_server_retry > 1:
+            self.logger.info("Restarting, built-in Snapserver.")
+            await self._stop_builtin_server()
+            await asyncio.sleep(10)  # prevent race conditions when reloading
+            await self._start_builtin_server()
+        else:
+            self._builtin_server_retry += 1
+            self.logger.debug("Increase snapcast retry count")
+
+    async def _stop_builtin_server(self) -> None:
+        """Stop the built-in Snapserver."""
+        self.logger.info("Stopping, built-in Snapserver")
+        if self._snapserver_runner and not self._snapserver_runner.done():
+            self._snapserver_runner.cancel()
+            self._snapserver_started.clear()
+
+    async def _start_builtin_server(self) -> None:
+        """Start the built-in Snapserver."""
+        if self._use_builtin_server:
+            self._builtin_server_retry = 0
+            self._snapserver_started = asyncio.Event()
+            self._snapserver_runner = asyncio.create_task(self._builtin_server_runner())
+            await asyncio.wait_for(self._snapserver_started.wait(), 10)