From 498fb9db91593c36c613fecffaf7d96ba3f13fab Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Fri, 14 Jun 2024 22:08:46 +0200 Subject: [PATCH] Auto retry provider load if its unavailable or the connection gets lost (#1364) --- music_assistant/server/controllers/config.py | 2 +- .../server/providers/hass/__init__.py | 17 ++ .../server/providers/hass/manifest.json | 2 +- .../server/providers/hass_players/__init__.py | 23 ++- music_assistant/server/server.py | 164 ++++++++++-------- requirements_all.txt | 2 +- 6 files changed, 130 insertions(+), 80 deletions(-) diff --git a/music_assistant/server/controllers/config.py b/music_assistant/server/controllers/config.py index 95ed6f46..99a59999 100644 --- a/music_assistant/server/controllers/config.py +++ b/music_assistant/server/controllers/config.py @@ -803,7 +803,7 @@ class ConfigController: # validate the new config config.validate() # try to load the provider first to catch errors before we save it. - await self.mass.load_provider(config) + await self.mass.load_provider(config, raise_on_error=True) # the load was a success, store this config conf_key = f"{CONF_PROVIDERS}/{config.instance_id}" self.set(conf_key, config.to_raw()) diff --git a/music_assistant/server/providers/hass/__init__.py b/music_assistant/server/providers/hass/__init__.py index 463b75d0..0d41f3ef 100644 --- a/music_assistant/server/providers/hass/__init__.py +++ b/music_assistant/server/providers/hass/__init__.py @@ -9,11 +9,13 @@ communication over the HA api for more flexibility as well as security. from __future__ import annotations +import asyncio import logging from typing import TYPE_CHECKING import shortuuid from hass_client import HomeAssistantClient +from hass_client.exceptions import BaseHassClientError from hass_client.utils import ( async_is_supervisor, base_url, @@ -152,6 +154,7 @@ class HomeAssistant(PluginProvider): """Home Assistant Plugin for Music Assistant.""" hass: HomeAssistantClient + _listen_task: asyncio.Task | None = None async def handle_async_init(self) -> None: """Handle async initialization of the plugin.""" @@ -160,6 +163,7 @@ class HomeAssistant(PluginProvider): logging.getLogger("hass_client").setLevel(self.logger.level + 10) self.hass = HomeAssistantClient(url, token, self.mass.http_session) await self.hass.connect() + self._listen_task = self.mass.create_task(self._hass_listener()) async def unload(self) -> None: """ @@ -167,4 +171,17 @@ class HomeAssistant(PluginProvider): Called when provider is deregistered (e.g. MA exiting or config reloading). """ + if self._listen_task and not self._listen_task.done(): + self._listen_task.cancel() await self.hass.disconnect() + + async def _hass_listener(self) -> None: + """Start listening on the HA websockets.""" + try: + # start listening will block until the connection is lost/closed + await self.hass.start_listening() + except BaseHassClientError as err: + self.logger.warning("Connection to HA lost due to error: %s", err) + self.logger.info("Connection to HA lost. Reloading provider in 5 seconds.") + # schedule a reload of the provider + self.mass.call_later(5, self.mass.config.reload_provider(self.instance_id)) diff --git a/music_assistant/server/providers/hass/manifest.json b/music_assistant/server/providers/hass/manifest.json index fb02b954..7726c324 100644 --- a/music_assistant/server/providers/hass/manifest.json +++ b/music_assistant/server/providers/hass/manifest.json @@ -12,6 +12,6 @@ "load_by_default": false, "icon": "md:webhook", "requirements": [ - "hass-client==1.0.1" + "hass-client==1.1.0" ] } diff --git a/music_assistant/server/providers/hass_players/__init__.py b/music_assistant/server/providers/hass_players/__init__.py index e3a66ffb..1e303372 100644 --- a/music_assistant/server/providers/hass_players/__init__.py +++ b/music_assistant/server/providers/hass_players/__init__.py @@ -50,7 +50,7 @@ CONF_PLAYERS = "players" StateMap = { "playing": PlayerState.PLAYING, - "paused": PlayerState.PLAYING, + "paused": PlayerState.PAUSED, "buffering": PlayerState.PLAYING, "idle": PlayerState.IDLE, "off": PlayerState.IDLE, @@ -350,7 +350,6 @@ class HomeAssistantPlayers(PlayerProvider): device_registry: dict[str, HassDevice], ) -> None: """Handle setup of a Player from an hass entity.""" - # fetch the entity registry entry for this entity to obtain more details hass_device: HassDevice | None = None platform_players: list[str] = [] if entity_registry_entry := entity_registry.get(state["entity_id"]): @@ -408,7 +407,13 @@ class HomeAssistantPlayers(PlayerProvider): """Handle updating MA player with updated info in a HA CompressedState.""" player = self.mass.players.get(entity_id) if player is None: - return # should not happen, but guard just in case + # edge case - one of our subscribed entities was not available at startup + # and now came available - we should still set it up + player_ids: list[str] = self.config.get_value(CONF_PLAYERS) + if entity_id not in player_ids: + return # should not happen, but guard just in case + self.mass.create_task(self._late_add_player(entity_id)) + return if "s" in state: player.state = StateMap.get(state["s"], PlayerState.IDLE) player.powered = state["s"] not in ( @@ -453,3 +458,15 @@ class HomeAssistantPlayers(PlayerProvider): else: player.group_childs = set() player.synced_to = None + + async def _late_add_player(self, entity_id: str) -> None: + """Handle setup of Player from HA entity that became available after startup.""" + # prefetch the device- and entity registry + device_registry = {x["id"]: x for x in await self.hass_prov.hass.get_device_registry()} + entity_registry = { + x["entity_id"]: x for x in await self.hass_prov.hass.get_entity_registry() + } + async for state in _get_hass_media_players(self.hass_prov): + if state["entity_id"] != entity_id: + continue + await self._setup_player(state, entity_registry, device_registry) diff --git a/music_assistant/server/server.py b/music_assistant/server/server.py index 9be9e594..9443d4cd 100644 --- a/music_assistant/server/server.py +++ b/music_assistant/server/server.py @@ -405,7 +405,96 @@ class MusicAssistant: raise RuntimeError(msg) self.command_handlers[command] = APICommandHandler.parse(command, handler) - async def load_provider(self, conf: ProviderConfig) -> None: + async def load_provider( + self, + prov_conf: ProviderConfig, + raise_on_error: bool = False, + schedule_retry: int | None = 10, + ) -> None: + """Try to load a provider and catch errors.""" + try: + await self._load_provider(prov_conf) + # pylint: disable=broad-except + except Exception as exc: + LOGGER.exception( + "Error loading provider(instance) %s", + prov_conf.name or prov_conf.domain, + ) + if raise_on_error: + raise + # if loading failed, we store the error in the config object + # so we can show something useful to the user + prov_conf.last_error = str(exc) + self.config.set(f"{CONF_PROVIDERS}/{prov_conf.instance_id}/last_error", str(exc)) + # auto schedule a retry if the (re)load failed + if schedule_retry: + self.call_later( + schedule_retry, + self.load_provider, + prov_conf, + raise_on_error, + min(schedule_retry + 10, 600), + ) + + async def unload_provider(self, instance_id: str) -> None: + """Unload a provider.""" + if provider := self._providers.get(instance_id): + # remove mdns discovery if needed + if provider.manifest.mdns_discovery: + for mdns_type in provider.manifest.mdns_discovery: + self._aiobrowser.types.discard(mdns_type) + # make sure to stop any running sync tasks first + for sync_task in self.music.in_progress_syncs: + if sync_task.provider_instance == instance_id: + sync_task.task.cancel() + # check if there are no other providers dependent of this provider + for dep_prov in self.providers: + if dep_prov.manifest.depends_on == provider.domain: + await self.unload_provider(dep_prov.instance_id) + try: + await provider.unload() + except Exception as err: + LOGGER.warning("Error while unload provider %s: %s", provider.name, str(err)) + finally: + self._providers.pop(instance_id, None) + await self._update_available_providers_cache() + self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers()) + + def _register_api_commands(self) -> None: + """Register all methods decorated as api_command within a class(instance).""" + for cls in ( + self, + self.config, + self.metadata, + self.music, + self.players, + self.player_queues, + ): + for attr_name in dir(cls): + if attr_name.startswith("__"): + continue + obj = getattr(cls, attr_name) + if hasattr(obj, "api_cmd"): + # method is decorated with our api decorator + self.register_api_command(obj.api_cmd, obj) + + async def _load_providers(self) -> None: + """Load providers from config.""" + # create default config for any 'builtin' providers (e.g. URL provider) + for prov_manifest in self._provider_manifests.values(): + if not prov_manifest.builtin: + continue + await self.config.create_builtin_provider_config(prov_manifest.domain) + + # load all configured (and enabled) providers + prov_configs = await self.config.get_provider_configs(include_values=True) + async with asyncio.TaskGroup() as tg: + for prov_conf in prov_configs: + if not prov_conf.enabled: + continue + tg.create_task(self.load_provider(prov_conf)) + + async def _load_provider(self, conf: ProviderConfig) -> None: """Load (or reload) a provider.""" # if provider is already loaded, stop and unload it first await self.unload_provider(conf.instance_id) @@ -480,79 +569,6 @@ class MusicAssistant: if provider.type == ProviderType.MUSIC: self.music.start_sync(providers=[provider.instance_id]) - async def unload_provider(self, instance_id: str) -> None: - """Unload a provider.""" - if provider := self._providers.get(instance_id): - # remove mdns discovery if needed - if provider.manifest.mdns_discovery: - for mdns_type in provider.manifest.mdns_discovery: - self._aiobrowser.types.discard(mdns_type) - # make sure to stop any running sync tasks first - for sync_task in self.music.in_progress_syncs: - if sync_task.provider_instance == instance_id: - sync_task.task.cancel() - # check if there are no other providers dependent of this provider - for dep_prov in self.providers: - if dep_prov.manifest.depends_on == provider.domain: - await self.unload_provider(dep_prov.instance_id) - try: - await provider.unload() - except Exception as err: - LOGGER.warning("Error while unload provider %s: %s", provider.name, str(err)) - finally: - self._providers.pop(instance_id, None) - await self._update_available_providers_cache() - self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers()) - - def _register_api_commands(self) -> None: - """Register all methods decorated as api_command within a class(instance).""" - for cls in ( - self, - self.config, - self.metadata, - self.music, - self.players, - self.player_queues, - ): - for attr_name in dir(cls): - if attr_name.startswith("__"): - continue - obj = getattr(cls, attr_name) - if hasattr(obj, "api_cmd"): - # method is decorated with our api decorator - self.register_api_command(obj.api_cmd, obj) - - async def _load_providers(self) -> None: - """Load providers from config.""" - # create default config for any 'builtin' providers (e.g. URL provider) - for prov_manifest in self._provider_manifests.values(): - if not prov_manifest.builtin: - continue - await self.config.create_builtin_provider_config(prov_manifest.domain) - - async def load_provider(prov_conf: ProviderConfig) -> None: - """Try to load a provider and catch errors.""" - try: - await self.load_provider(prov_conf) - # pylint: disable=broad-except - except Exception as exc: - LOGGER.exception( - "Error loading provider(instance) %s", - prov_conf.name or prov_conf.domain, - ) - # if loading failed, we store the error in the config object - # so we can show something useful to the user - prov_conf.last_error = str(exc) - self.config.set(f"{CONF_PROVIDERS}/{prov_conf.instance_id}/last_error", str(exc)) - - # load all configured (and enabled) providers - prov_configs = await self.config.get_provider_configs(include_values=True) - async with asyncio.TaskGroup() as tg: - for prov_conf in prov_configs: - if not prov_conf.enabled: - continue - tg.create_task(load_provider(prov_conf)) - async def __load_provider_manifests(self) -> None: """Preload all available provider manifest files.""" diff --git a/requirements_all.txt b/requirements_all.txt index e27d7545..434c5bd0 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -17,7 +17,7 @@ deezer-python-async==0.3.0 defusedxml==0.7.1 faust-cchardet>=2.1.18 git+https://github.com/MarvinSchenkel/pytube.git -hass-client==1.0.1 +hass-client==1.1.0 ifaddr==0.2.0 jellyfin_apiclient_python==1.9.2 mashumaro==3.13 -- 2.34.1