Auto retry provider load if its unavailable or the connection gets lost (#1364)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Fri, 14 Jun 2024 20:08:46 +0000 (22:08 +0200)
committerGitHub <noreply@github.com>
Fri, 14 Jun 2024 20:08:46 +0000 (22:08 +0200)
music_assistant/server/controllers/config.py
music_assistant/server/providers/hass/__init__.py
music_assistant/server/providers/hass/manifest.json
music_assistant/server/providers/hass_players/__init__.py
music_assistant/server/server.py
requirements_all.txt

index 95ed6f46212476f077fce08fd7184e92edb05d8b..99a59999108c8c48ab73419ee75e49015a121118 100644 (file)
@@ -803,7 +803,7 @@ class ConfigController:
         # validate the new config
         config.validate()
         # try to load the provider first to catch errors before we save it.
-        await self.mass.load_provider(config)
+        await self.mass.load_provider(config, raise_on_error=True)
         # the load was a success, store this config
         conf_key = f"{CONF_PROVIDERS}/{config.instance_id}"
         self.set(conf_key, config.to_raw())
index 463b75d0993ac787feb5ce9919051f504df3f17d..0d41f3ef0f64b69593cd6d2fe82157ffeea6d7e0 100644 (file)
@@ -9,11 +9,13 @@ communication over the HA api for more flexibility as well as security.
 
 from __future__ import annotations
 
+import asyncio
 import logging
 from typing import TYPE_CHECKING
 
 import shortuuid
 from hass_client import HomeAssistantClient
+from hass_client.exceptions import BaseHassClientError
 from hass_client.utils import (
     async_is_supervisor,
     base_url,
@@ -152,6 +154,7 @@ class HomeAssistant(PluginProvider):
     """Home Assistant Plugin for Music Assistant."""
 
     hass: HomeAssistantClient
+    _listen_task: asyncio.Task | None = None
 
     async def handle_async_init(self) -> None:
         """Handle async initialization of the plugin."""
@@ -160,6 +163,7 @@ class HomeAssistant(PluginProvider):
         logging.getLogger("hass_client").setLevel(self.logger.level + 10)
         self.hass = HomeAssistantClient(url, token, self.mass.http_session)
         await self.hass.connect()
+        self._listen_task = self.mass.create_task(self._hass_listener())
 
     async def unload(self) -> None:
         """
@@ -167,4 +171,17 @@ class HomeAssistant(PluginProvider):
 
         Called when provider is deregistered (e.g. MA exiting or config reloading).
         """
+        if self._listen_task and not self._listen_task.done():
+            self._listen_task.cancel()
         await self.hass.disconnect()
+
+    async def _hass_listener(self) -> None:
+        """Start listening on the HA websockets."""
+        try:
+            # start listening will block until the connection is lost/closed
+            await self.hass.start_listening()
+        except BaseHassClientError as err:
+            self.logger.warning("Connection to HA lost due to error: %s", err)
+        self.logger.info("Connection to HA lost. Reloading provider in 5 seconds.")
+        # schedule a reload of the provider
+        self.mass.call_later(5, self.mass.config.reload_provider(self.instance_id))
index fb02b95466344bafcb5ac1ac7252cdcc09f4b2eb..7726c3246aa3f8730ddd13f9033df98986f4c2c8 100644 (file)
@@ -12,6 +12,6 @@
   "load_by_default": false,
   "icon": "md:webhook",
   "requirements": [
-    "hass-client==1.0.1"
+    "hass-client==1.1.0"
   ]
 }
index e3a66ffb4301f49ed101c3e5125cbe0324f72807..1e30337254c8ea3dfdba8c56c53619eedc156c63 100644 (file)
@@ -50,7 +50,7 @@ CONF_PLAYERS = "players"
 
 StateMap = {
     "playing": PlayerState.PLAYING,
-    "paused": PlayerState.PLAYING,
+    "paused": PlayerState.PAUSED,
     "buffering": PlayerState.PLAYING,
     "idle": PlayerState.IDLE,
     "off": PlayerState.IDLE,
@@ -350,7 +350,6 @@ class HomeAssistantPlayers(PlayerProvider):
         device_registry: dict[str, HassDevice],
     ) -> None:
         """Handle setup of a Player from an hass entity."""
-        # fetch the entity registry entry for this entity to obtain more details
         hass_device: HassDevice | None = None
         platform_players: list[str] = []
         if entity_registry_entry := entity_registry.get(state["entity_id"]):
@@ -408,7 +407,13 @@ class HomeAssistantPlayers(PlayerProvider):
             """Handle updating MA player with updated info in a HA CompressedState."""
             player = self.mass.players.get(entity_id)
             if player is None:
-                return  # should not happen, but guard just in case
+                # edge case - one of our subscribed entities was not available at startup
+                # and now came available - we should still set it up
+                player_ids: list[str] = self.config.get_value(CONF_PLAYERS)
+                if entity_id not in player_ids:
+                    return  # should not happen, but guard just in case
+                self.mass.create_task(self._late_add_player(entity_id))
+                return
             if "s" in state:
                 player.state = StateMap.get(state["s"], PlayerState.IDLE)
                 player.powered = state["s"] not in (
@@ -453,3 +458,15 @@ class HomeAssistantPlayers(PlayerProvider):
                 else:
                     player.group_childs = set()
                     player.synced_to = None
+
+    async def _late_add_player(self, entity_id: str) -> None:
+        """Handle setup of Player from HA entity that became available after startup."""
+        # prefetch the device- and entity registry
+        device_registry = {x["id"]: x for x in await self.hass_prov.hass.get_device_registry()}
+        entity_registry = {
+            x["entity_id"]: x for x in await self.hass_prov.hass.get_entity_registry()
+        }
+        async for state in _get_hass_media_players(self.hass_prov):
+            if state["entity_id"] != entity_id:
+                continue
+            await self._setup_player(state, entity_registry, device_registry)
index 9be9e59407c9e19fdd4496b7c7f4bbf88bdfcded..9443d4cd9034a2072ffc09e69a1c82e742fd8a14 100644 (file)
@@ -405,7 +405,96 @@ class MusicAssistant:
             raise RuntimeError(msg)
         self.command_handlers[command] = APICommandHandler.parse(command, handler)
 
-    async def load_provider(self, conf: ProviderConfig) -> None:
+    async def load_provider(
+        self,
+        prov_conf: ProviderConfig,
+        raise_on_error: bool = False,
+        schedule_retry: int | None = 10,
+    ) -> None:
+        """Try to load a provider and catch errors."""
+        try:
+            await self._load_provider(prov_conf)
+        # pylint: disable=broad-except
+        except Exception as exc:
+            LOGGER.exception(
+                "Error loading provider(instance) %s",
+                prov_conf.name or prov_conf.domain,
+            )
+            if raise_on_error:
+                raise
+            # if loading failed, we store the error in the config object
+            # so we can show something useful to the user
+            prov_conf.last_error = str(exc)
+            self.config.set(f"{CONF_PROVIDERS}/{prov_conf.instance_id}/last_error", str(exc))
+            # auto schedule a retry if the (re)load failed
+            if schedule_retry:
+                self.call_later(
+                    schedule_retry,
+                    self.load_provider,
+                    prov_conf,
+                    raise_on_error,
+                    min(schedule_retry + 10, 600),
+                )
+
+    async def unload_provider(self, instance_id: str) -> None:
+        """Unload a provider."""
+        if provider := self._providers.get(instance_id):
+            # remove mdns discovery if needed
+            if provider.manifest.mdns_discovery:
+                for mdns_type in provider.manifest.mdns_discovery:
+                    self._aiobrowser.types.discard(mdns_type)
+            # make sure to stop any running sync tasks first
+            for sync_task in self.music.in_progress_syncs:
+                if sync_task.provider_instance == instance_id:
+                    sync_task.task.cancel()
+            # check if there are no other providers dependent of this provider
+            for dep_prov in self.providers:
+                if dep_prov.manifest.depends_on == provider.domain:
+                    await self.unload_provider(dep_prov.instance_id)
+            try:
+                await provider.unload()
+            except Exception as err:
+                LOGGER.warning("Error while unload provider %s: %s", provider.name, str(err))
+            finally:
+                self._providers.pop(instance_id, None)
+                await self._update_available_providers_cache()
+                self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
+
+    def _register_api_commands(self) -> None:
+        """Register all methods decorated as api_command within a class(instance)."""
+        for cls in (
+            self,
+            self.config,
+            self.metadata,
+            self.music,
+            self.players,
+            self.player_queues,
+        ):
+            for attr_name in dir(cls):
+                if attr_name.startswith("__"):
+                    continue
+                obj = getattr(cls, attr_name)
+                if hasattr(obj, "api_cmd"):
+                    # method is decorated with our api decorator
+                    self.register_api_command(obj.api_cmd, obj)
+
+    async def _load_providers(self) -> None:
+        """Load providers from config."""
+        # create default config for any 'builtin' providers (e.g. URL provider)
+        for prov_manifest in self._provider_manifests.values():
+            if not prov_manifest.builtin:
+                continue
+            await self.config.create_builtin_provider_config(prov_manifest.domain)
+
+        # load all configured (and enabled) providers
+        prov_configs = await self.config.get_provider_configs(include_values=True)
+        async with asyncio.TaskGroup() as tg:
+            for prov_conf in prov_configs:
+                if not prov_conf.enabled:
+                    continue
+                tg.create_task(self.load_provider(prov_conf))
+
+    async def _load_provider(self, conf: ProviderConfig) -> None:
         """Load (or reload) a provider."""
         # if provider is already loaded, stop and unload it first
         await self.unload_provider(conf.instance_id)
@@ -480,79 +569,6 @@ class MusicAssistant:
         if provider.type == ProviderType.MUSIC:
             self.music.start_sync(providers=[provider.instance_id])
 
-    async def unload_provider(self, instance_id: str) -> None:
-        """Unload a provider."""
-        if provider := self._providers.get(instance_id):
-            # remove mdns discovery if needed
-            if provider.manifest.mdns_discovery:
-                for mdns_type in provider.manifest.mdns_discovery:
-                    self._aiobrowser.types.discard(mdns_type)
-            # make sure to stop any running sync tasks first
-            for sync_task in self.music.in_progress_syncs:
-                if sync_task.provider_instance == instance_id:
-                    sync_task.task.cancel()
-            # check if there are no other providers dependent of this provider
-            for dep_prov in self.providers:
-                if dep_prov.manifest.depends_on == provider.domain:
-                    await self.unload_provider(dep_prov.instance_id)
-            try:
-                await provider.unload()
-            except Exception as err:
-                LOGGER.warning("Error while unload provider %s: %s", provider.name, str(err))
-            finally:
-                self._providers.pop(instance_id, None)
-                await self._update_available_providers_cache()
-                self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
-
-    def _register_api_commands(self) -> None:
-        """Register all methods decorated as api_command within a class(instance)."""
-        for cls in (
-            self,
-            self.config,
-            self.metadata,
-            self.music,
-            self.players,
-            self.player_queues,
-        ):
-            for attr_name in dir(cls):
-                if attr_name.startswith("__"):
-                    continue
-                obj = getattr(cls, attr_name)
-                if hasattr(obj, "api_cmd"):
-                    # method is decorated with our api decorator
-                    self.register_api_command(obj.api_cmd, obj)
-
-    async def _load_providers(self) -> None:
-        """Load providers from config."""
-        # create default config for any 'builtin' providers (e.g. URL provider)
-        for prov_manifest in self._provider_manifests.values():
-            if not prov_manifest.builtin:
-                continue
-            await self.config.create_builtin_provider_config(prov_manifest.domain)
-
-        async def load_provider(prov_conf: ProviderConfig) -> None:
-            """Try to load a provider and catch errors."""
-            try:
-                await self.load_provider(prov_conf)
-            # pylint: disable=broad-except
-            except Exception as exc:
-                LOGGER.exception(
-                    "Error loading provider(instance) %s",
-                    prov_conf.name or prov_conf.domain,
-                )
-                # if loading failed, we store the error in the config object
-                # so we can show something useful to the user
-                prov_conf.last_error = str(exc)
-                self.config.set(f"{CONF_PROVIDERS}/{prov_conf.instance_id}/last_error", str(exc))
-
-        # load all configured (and enabled) providers
-        prov_configs = await self.config.get_provider_configs(include_values=True)
-        async with asyncio.TaskGroup() as tg:
-            for prov_conf in prov_configs:
-                if not prov_conf.enabled:
-                    continue
-                tg.create_task(load_provider(prov_conf))
-
     async def __load_provider_manifests(self) -> None:
         """Preload all available provider manifest files."""
 
index e27d7545dddd424218ab1839eda40f416f562106..434c5bd06a8290a87d54f12141f1ebdd405f81a6 100644 (file)
@@ -17,7 +17,7 @@ deezer-python-async==0.3.0
 defusedxml==0.7.1
 faust-cchardet>=2.1.18
 git+https://github.com/MarvinSchenkel/pytube.git
-hass-client==1.0.1
+hass-client==1.1.0
 ifaddr==0.2.0
 jellyfin_apiclient_python==1.9.2
 mashumaro==3.13