Reload Snapcast provider when connection to the server gets lost (#1447)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Fri, 5 Jul 2024 10:18:45 +0000 (12:18 +0200)
committerGitHub <noreply@github.com>
Fri, 5 Jul 2024 10:18:45 +0000 (12:18 +0200)
music_assistant/server/controllers/config.py
music_assistant/server/providers/snapcast/__init__.py

index e504ee303c385d7f5669fb0f8a51ee8dcd912a2f..b378a0915f9bffa33a830897873c4b27823c9a30 100644 (file)
@@ -313,7 +313,11 @@ class ConfigController:
     @api_command("config/providers/reload")
     async def reload_provider(self, instance_id: str) -> None:
         """Reload provider."""
-        config = await self.get_provider_config(instance_id)
+        try:
+            config = await self.get_provider_config(instance_id)
+        except KeyError:
+            # Edge case: Provider was removed before we could reload it
+            return
         await self._load_provider_config(config)
 
     @api_command("config/players")
index 73ec3d3e567cb7a11531725265f6054ec7ab0b6d..95d3b23fd5e5d6e7a22c719a178162aa5ac77227 100644 (file)
@@ -229,7 +229,6 @@ class SnapCastProvider(PlayerProvider):
     _snapserver_runner: asyncio.Task | None
     _snapserver_started: asyncio.Event | None
     _ids_map: bidict  # ma_id / snapclient_id
-    _builtin_server_retry: int
 
     def _get_snapclient_id(self, player_id: str) -> str:
         search_dict = self._ids_map
@@ -283,7 +282,6 @@ class SnapCastProvider(PlayerProvider):
         self._snapcast_stream_dryout_ms = self.config.get_value(CONF_SERVER_DRYOUT_MS)
         self._stream_tasks = {}
         self._ids_map = bidict({})
-        self._builtin_server_retry = 0
 
         if self._use_builtin_server:
             await self._start_builtin_server()
@@ -302,8 +300,8 @@ class SnapCastProvider(PlayerProvider):
                 "Started connection to Snapserver %s",
                 f"{self._snapcast_server_host}:{self._snapcast_server_control_port}",
             )
-            if self._use_builtin_server:
-                self._snapserver.set_on_disconnect_callback(self._restart_builtin_server)
+            # register callback for when the connection gets lost to the snapserver
+            self._snapserver.set_on_disconnect_callback(self._handle_disconnect)
 
         except OSError as err:
             msg = "Unable to start the Snapserver connection ?"
@@ -670,17 +668,6 @@ class SnapCastProvider(PlayerProvider):
                         # where we try to connect too soon
                         self.mass.loop.call_later(2, self._snapserver_started.set)
 
-    async def _restart_builtin_server(self) -> None:
-        """Restart the built-in Snapserver."""
-        if self._use_builtin_server and self._builtin_server_retry > 1:
-            self.logger.info("Restarting, built-in Snapserver.")
-            await self._stop_builtin_server()
-            await asyncio.sleep(10)  # prevent race conditions when reloading
-            await self._start_builtin_server()
-        else:
-            self._builtin_server_retry += 1
-            self.logger.debug("Increase snapcast retry count")
-
     async def _stop_builtin_server(self) -> None:
         """Stop the built-in Snapserver."""
         self.logger.info("Stopping, built-in Snapserver")
@@ -691,7 +678,14 @@ class SnapCastProvider(PlayerProvider):
     async def _start_builtin_server(self) -> None:
         """Start the built-in Snapserver."""
         if self._use_builtin_server:
-            self._builtin_server_retry = 0
             self._snapserver_started = asyncio.Event()
             self._snapserver_runner = asyncio.create_task(self._builtin_server_runner())
             await asyncio.wait_for(self._snapserver_started.wait(), 10)
+
+    def _handle_disconnect(self, exc: Exception) -> None:
+        """Handle disconnect callback from snapserver."""
+        self.logger.info(
+            "Connection to SnapServer lost, reason: %s. Reloading provider in 5 seconds.", str(exc)
+        )
+        # schedule a reload of the provider
+        self.mass.call_later(5, self.mass.config.reload_provider(self.instance_id))