from __future__ import annotations
+import asyncio
import logging
from typing import TYPE_CHECKING
import shortuuid
from hass_client import HomeAssistantClient
+from hass_client.exceptions import BaseHassClientError
from hass_client.utils import (
async_is_supervisor,
base_url,
"""Home Assistant Plugin for Music Assistant."""
hass: HomeAssistantClient
+ _listen_task: asyncio.Task | None = None
async def handle_async_init(self) -> None:
"""Handle async initialization of the plugin."""
logging.getLogger("hass_client").setLevel(self.logger.level + 10)
self.hass = HomeAssistantClient(url, token, self.mass.http_session)
await self.hass.connect()
+ self._listen_task = self.mass.create_task(self._hass_listener())
async def unload(self) -> None:
"""
Called when provider is deregistered (e.g. MA exiting or config reloading).
"""
+ if self._listen_task and not self._listen_task.done():
+ self._listen_task.cancel()
await self.hass.disconnect()
+
+ async def _hass_listener(self) -> None:
+ """Start listening on the HA websockets."""
+ try:
+ # start listening will block until the connection is lost/closed
+ await self.hass.start_listening()
+ except BaseHassClientError as err:
+ self.logger.warning("Connection to HA lost due to error: %s", err)
+ self.logger.info("Connection to HA lost. Reloading provider in 5 seconds.")
+ # schedule a reload of the provider
+ self.mass.call_later(5, self.mass.config.reload_provider(self.instance_id))
StateMap = {
"playing": PlayerState.PLAYING,
- "paused": PlayerState.PLAYING,
+ "paused": PlayerState.PAUSED,
"buffering": PlayerState.PLAYING,
"idle": PlayerState.IDLE,
"off": PlayerState.IDLE,
device_registry: dict[str, HassDevice],
) -> None:
"""Handle setup of a Player from an hass entity."""
- # fetch the entity registry entry for this entity to obtain more details
hass_device: HassDevice | None = None
platform_players: list[str] = []
if entity_registry_entry := entity_registry.get(state["entity_id"]):
"""Handle updating MA player with updated info in a HA CompressedState."""
player = self.mass.players.get(entity_id)
if player is None:
- return # should not happen, but guard just in case
+ # edge case - one of our subscribed entities was not available at startup
+ # and now came available - we should still set it up
+ player_ids: list[str] = self.config.get_value(CONF_PLAYERS)
+ if entity_id not in player_ids:
+ return # should not happen, but guard just in case
+ self.mass.create_task(self._late_add_player(entity_id))
+ return
if "s" in state:
player.state = StateMap.get(state["s"], PlayerState.IDLE)
player.powered = state["s"] not in (
else:
player.group_childs = set()
player.synced_to = None
+
+ async def _late_add_player(self, entity_id: str) -> None:
+ """Handle setup of Player from HA entity that became available after startup."""
+ # prefetch the device- and entity registry
+ device_registry = {x["id"]: x for x in await self.hass_prov.hass.get_device_registry()}
+ entity_registry = {
+ x["entity_id"]: x for x in await self.hass_prov.hass.get_entity_registry()
+ }
+ async for state in _get_hass_media_players(self.hass_prov):
+ if state["entity_id"] != entity_id:
+ continue
+ await self._setup_player(state, entity_registry, device_registry)
raise RuntimeError(msg)
self.command_handlers[command] = APICommandHandler.parse(command, handler)
- async def load_provider(self, conf: ProviderConfig) -> None:
+ async def load_provider(
+ self,
+ prov_conf: ProviderConfig,
+ raise_on_error: bool = False,
+ schedule_retry: int | None = 10,
+ ) -> None:
+ """Try to load a provider and catch errors."""
+ try:
+ await self._load_provider(prov_conf)
+ # pylint: disable=broad-except
+ except Exception as exc:
+ LOGGER.exception(
+ "Error loading provider(instance) %s",
+ prov_conf.name or prov_conf.domain,
+ )
+ if raise_on_error:
+ raise
+ # if loading failed, we store the error in the config object
+ # so we can show something useful to the user
+ prov_conf.last_error = str(exc)
+ self.config.set(f"{CONF_PROVIDERS}/{prov_conf.instance_id}/last_error", str(exc))
+ # auto schedule a retry if the (re)load failed
+ if schedule_retry:
+ self.call_later(
+ schedule_retry,
+ self.load_provider,
+ prov_conf,
+ raise_on_error,
+ min(schedule_retry + 10, 600),
+ )
+
+ async def unload_provider(self, instance_id: str) -> None:
+ """Unload a provider."""
+ if provider := self._providers.get(instance_id):
+ # remove mdns discovery if needed
+ if provider.manifest.mdns_discovery:
+ for mdns_type in provider.manifest.mdns_discovery:
+ self._aiobrowser.types.discard(mdns_type)
+ # make sure to stop any running sync tasks first
+ for sync_task in self.music.in_progress_syncs:
+ if sync_task.provider_instance == instance_id:
+ sync_task.task.cancel()
+ # check if there are no other providers dependent of this provider
+ for dep_prov in self.providers:
+ if dep_prov.manifest.depends_on == provider.domain:
+ await self.unload_provider(dep_prov.instance_id)
+ try:
+ await provider.unload()
+ except Exception as err:
+ LOGGER.warning("Error while unload provider %s: %s", provider.name, str(err))
+ finally:
+ self._providers.pop(instance_id, None)
+ await self._update_available_providers_cache()
+ self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
+
+ def _register_api_commands(self) -> None:
+ """Register all methods decorated as api_command within a class(instance)."""
+ for cls in (
+ self,
+ self.config,
+ self.metadata,
+ self.music,
+ self.players,
+ self.player_queues,
+ ):
+ for attr_name in dir(cls):
+ if attr_name.startswith("__"):
+ continue
+ obj = getattr(cls, attr_name)
+ if hasattr(obj, "api_cmd"):
+ # method is decorated with our api decorator
+ self.register_api_command(obj.api_cmd, obj)
+
+ async def _load_providers(self) -> None:
+ """Load providers from config."""
+ # create default config for any 'builtin' providers (e.g. URL provider)
+ for prov_manifest in self._provider_manifests.values():
+ if not prov_manifest.builtin:
+ continue
+ await self.config.create_builtin_provider_config(prov_manifest.domain)
+
+ # load all configured (and enabled) providers
+ prov_configs = await self.config.get_provider_configs(include_values=True)
+ async with asyncio.TaskGroup() as tg:
+ for prov_conf in prov_configs:
+ if not prov_conf.enabled:
+ continue
+ tg.create_task(self.load_provider(prov_conf))
+
+ async def _load_provider(self, conf: ProviderConfig) -> None:
"""Load (or reload) a provider."""
# if provider is already loaded, stop and unload it first
await self.unload_provider(conf.instance_id)
if provider.type == ProviderType.MUSIC:
self.music.start_sync(providers=[provider.instance_id])
- async def unload_provider(self, instance_id: str) -> None:
- """Unload a provider."""
- if provider := self._providers.get(instance_id):
- # remove mdns discovery if needed
- if provider.manifest.mdns_discovery:
- for mdns_type in provider.manifest.mdns_discovery:
- self._aiobrowser.types.discard(mdns_type)
- # make sure to stop any running sync tasks first
- for sync_task in self.music.in_progress_syncs:
- if sync_task.provider_instance == instance_id:
- sync_task.task.cancel()
- # check if there are no other providers dependent of this provider
- for dep_prov in self.providers:
- if dep_prov.manifest.depends_on == provider.domain:
- await self.unload_provider(dep_prov.instance_id)
- try:
- await provider.unload()
- except Exception as err:
- LOGGER.warning("Error while unload provider %s: %s", provider.name, str(err))
- finally:
- self._providers.pop(instance_id, None)
- await self._update_available_providers_cache()
- self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
-
- def _register_api_commands(self) -> None:
- """Register all methods decorated as api_command within a class(instance)."""
- for cls in (
- self,
- self.config,
- self.metadata,
- self.music,
- self.players,
- self.player_queues,
- ):
- for attr_name in dir(cls):
- if attr_name.startswith("__"):
- continue
- obj = getattr(cls, attr_name)
- if hasattr(obj, "api_cmd"):
- # method is decorated with our api decorator
- self.register_api_command(obj.api_cmd, obj)
-
- async def _load_providers(self) -> None:
- """Load providers from config."""
- # create default config for any 'builtin' providers (e.g. URL provider)
- for prov_manifest in self._provider_manifests.values():
- if not prov_manifest.builtin:
- continue
- await self.config.create_builtin_provider_config(prov_manifest.domain)
-
- async def load_provider(prov_conf: ProviderConfig) -> None:
- """Try to load a provider and catch errors."""
- try:
- await self.load_provider(prov_conf)
- # pylint: disable=broad-except
- except Exception as exc:
- LOGGER.exception(
- "Error loading provider(instance) %s",
- prov_conf.name or prov_conf.domain,
- )
- # if loading failed, we store the error in the config object
- # so we can show something useful to the user
- prov_conf.last_error = str(exc)
- self.config.set(f"{CONF_PROVIDERS}/{prov_conf.instance_id}/last_error", str(exc))
-
- # load all configured (and enabled) providers
- prov_configs = await self.config.get_provider_configs(include_values=True)
- async with asyncio.TaskGroup() as tg:
- for prov_conf in prov_configs:
- if not prov_conf.enabled:
- continue
- tg.create_task(load_provider(prov_conf))
-
async def __load_provider_manifests(self) -> None:
"""Preload all available provider manifest files."""