Fix player lifecycle (enabling/disabling and config updates) (#3024)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Mon, 26 Jan 2026 21:47:40 +0000 (22:47 +0100)
committerGitHub <noreply@github.com>
Mon, 26 Jan 2026 21:47:40 +0000 (22:47 +0100)
music_assistant/controllers/config.py
music_assistant/controllers/players/player_controller.py
music_assistant/mass.py
music_assistant/models/player_provider.py
music_assistant/providers/_demo_player_provider/provider.py
music_assistant/providers/chromecast/provider.py
music_assistant/providers/dlna/player.py
music_assistant/providers/heos/provider.py
music_assistant/providers/roku_media_assistant/provider.py

index ada5827d41a11ada60d8fe41b2149f48a2c18997..b12c8860da4e5932dd2da91c3349087ba07ebb19 100644 (file)
@@ -7,6 +7,7 @@ import base64
 import contextlib
 import logging
 import os
+from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload
 from uuid import uuid4
 
@@ -749,15 +750,21 @@ class ConfigController:
     ) -> PlayerConfig:
         """Save/update PlayerConfig."""
         config = await self.get_player_config(player_id)
+        old_config = deepcopy(config)
         changed_keys = config.update(values)
         if not changed_keys:
             # no changes
             return config
-        # validate/handle the update in the player manager
-        await self.mass.players.on_player_config_change(config, changed_keys)
-        # actually store changes (if the above did not raise)
+        # store updated config first (to prevent issues with enabling/disabling players)
         conf_key = f"{CONF_PLAYERS}/{player_id}"
         self.set(conf_key, config.to_raw())
+        try:
+            # validate/handle the update in the player manager
+            await self.mass.players.on_player_config_change(config, changed_keys)
+        except Exception:
+            # rollback on error
+            self.set(conf_key, old_config.to_raw())
+            raise
         # send config updated event
         self.mass.signal_event(
             EventType.PLAYER_CONFIG_UPDATED,
index 2a6b2bd5fae522c49e1244f502b239e5153cbe14..af8744f2b6703a6b87953a9aaedf1dc20d5566d0 100644 (file)
@@ -1761,31 +1761,40 @@ class PlayerController(CoreController):
 
     async def on_player_config_change(self, config: PlayerConfig, changed_keys: set[str]) -> None:
         """Call (by config manager) when the configuration of a player changes."""
+        player = self.get(config.player_id)
+        player_provider = self.mass.get_provider(config.provider)
         player_disabled = ATTR_ENABLED in changed_keys and not config.enabled
+        player_enabled = ATTR_ENABLED in changed_keys and config.enabled
+
+        if player_disabled and player and player.available:
+            # edge case: ensure that the player is powered off if the player gets disabled
+            if player.power_control != PLAYER_CONTROL_NONE:
+                await self._handle_cmd_power(config.player_id, False)
+            elif player.playback_state != PlaybackState.IDLE:
+                await self.cmd_stop(config.player_id)
+
         # signal player provider that the player got enabled/disabled
-        if player_provider := self.mass.get_provider(config.provider):
+        if (player_enabled or player_disabled) and player_provider:
             assert isinstance(player_provider, PlayerProvider)  # for type checking
-            if ATTR_ENABLED in changed_keys and not config.enabled:
+            if player_disabled:
                 player_provider.on_player_disabled(config.player_id)
-            elif ATTR_ENABLED in changed_keys and config.enabled:
+            elif player_enabled:
                 player_provider.on_player_enabled(config.player_id)
-        if not (player := self.get(config.player_id)):
+            return  # enabling/disabling a player will be handled by the provider
+
+        if not player:
             return  # guard against player not being registered (yet)
+
         resume_queue: PlayerQueue | None = (
             self.mass.player_queues.get(player.active_source) if player.active_source else None
         )
-        if player_disabled and player.available:
-            # edge case: ensure that the player is powered off if the player gets disabled
-            if player.power_control != PLAYER_CONTROL_NONE:
-                await self._handle_cmd_power(config.player_id, False)
-            elif player.playback_state != PlaybackState.IDLE:
-                await self.cmd_stop(config.player_id)
+
         # ensure player state gets updated with any updated config
         player.set_config(config)
         await player.on_config_updated()
         player.update_state()
         # if the PlayerQueue was playing, restart playback
-        if not player_disabled and resume_queue and resume_queue.state == PlaybackState.PLAYING:
+        if resume_queue and resume_queue.state == PlaybackState.PLAYING:
             requires_restart = any(
                 v for v in config.values.values() if v.key in changed_keys and v.requires_reload
             )
index a5b85e9c04d3f46d7d708e133b39fe6002db17a1..c39e098092e4d16affd231f502066fa2dd39943f 100644 (file)
@@ -704,6 +704,31 @@ class MusicAssistant:
         self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", error)
         await self.unload_provider(instance_id)
 
+    async def run_provider_discovery(self, instance_id: str) -> None:
+        """
+        Run mDNS discovery for a given provider.
+
+        In case of a PlayerProvider, will also call its own discovery method.
+        """
+        provider = self.get_provider(instance_id, return_unavailable=False)
+        if not provider:
+            raise KeyError(f"Provider with instance ID {instance_id} not found")
+        if provider.manifest.mdns_discovery:
+            if provider.instance_id not in self._mdns_locks:
+                self._mdns_locks[provider.instance_id] = asyncio.Lock()
+            async with self._mdns_locks[provider.instance_id]:
+                for mdns_type in provider.manifest.mdns_discovery or []:
+                    for mdns_name in set(self.aiozc.zeroconf.cache.cache):
+                        if mdns_type not in mdns_name or mdns_type == mdns_name:
+                            continue
+                        info = AsyncServiceInfo(mdns_type, mdns_name)
+                        if await info.async_request(self.aiozc.zeroconf, 3000):
+                            await provider.on_mdns_service_state_change(
+                                mdns_name, ServiceStateChange.Added, info
+                            )
+        if isinstance(provider, PlayerProvider):
+            await provider.discover_players()
+
     def verify_event_loop_thread(self, what: str) -> None:
         """Report and raise if we are not running in the event loop thread."""
         if self.loop_thread_id != threading.get_ident():
@@ -757,7 +782,7 @@ class MusicAssistant:
             # If a provider fails, that will not block the loading of other providers.
             self.create_task(self.load_provider(prov_conf.instance_id, allow_retry=True))
 
-    async def _load_provider(self, conf: ProviderConfig) -> None:  # noqa: PLR0915
+    async def _load_provider(self, conf: ProviderConfig) -> None:
         """Load (or reload) a provider."""
         # if provider is already loaded, stop and unload it first
         await self.unload_provider(conf.instance_id)
@@ -815,21 +840,7 @@ class MusicAssistant:
         # execute post load actions
         async def _on_provider_loaded() -> None:
             await provider.loaded_in_mass()
-            if provider.type != ProviderType.PLAYER:
-                return
-            # add mdns discovery if needed
-            if provider.instance_id not in self._mdns_locks:
-                self._mdns_locks[provider.instance_id] = asyncio.Lock()
-            async with self._mdns_locks[provider.instance_id]:
-                for mdns_type in provider.manifest.mdns_discovery or []:
-                    for mdns_name in set(self.aiozc.zeroconf.cache.cache):
-                        if mdns_type not in mdns_name or mdns_type == mdns_name:
-                            continue
-                        info = AsyncServiceInfo(mdns_type, mdns_name)
-                        if await info.async_request(self.aiozc.zeroconf, 3000):
-                            await provider.on_mdns_service_state_change(
-                                mdns_name, ServiceStateChange.Added, info
-                            )
+            await self.run_provider_discovery(provider.instance_id)
 
         self.create_task(_on_provider_loaded())
 
index 51877417f14ef536b6bb65425b840e97c654bc1c..b16d95f2d388c0c74b93c66bf65c4088cbc39845 100644 (file)
@@ -19,16 +19,19 @@ class PlayerProvider(Provider):
 
     async def loaded_in_mass(self) -> None:
         """Call after the provider has been loaded."""
-        await self.discover_players()
 
     def on_player_enabled(self, player_id: str) -> None:
         """Call (by config manager) when a player gets enabled."""
         # default implementation: trigger discovery - feel free to override
         task_id = f"discover_players_{self.instance_id}"
-        self.mass.call_later(5, self.discover_players, task_id=task_id)
+        self.mass.call_later(5, self.mass.run_provider_discovery, self.instance_id, task_id=task_id)
 
     def on_player_disabled(self, player_id: str) -> None:
         """Call (by config manager) when a player gets disabled."""
+        # default implementation: unregister player from player controller
+        # which will also trigger an unload on the player instance
+        # feel free to override with a better implementation
+        self.mass.create_task(self.mass.players.unregister(player_id))
 
     async def remove_player(self, player_id: str) -> None:
         """Remove a player from this provider."""
index be7643ae1cd9a9bee9c286cc8f674cc29ab34772..49617a44220a3923e77c532283117994240e0da0 100644 (file)
@@ -49,10 +49,7 @@ class DemoPlayerprovider(PlayerProvider):
         # this is an optional method that you can implement if
         # relevant or leave out completely if not needed.
         # it will be called after the provider has been fully loaded into Music Assistant.
-        # you can use this for instance to trigger custom (non-mdns) discovery of players
-        # or any other logic that needs to run after the provider is fully loaded.
         self.logger.info("DemoPlayerProvider loaded")
-        await self.discover_players()
 
     async def unload(self, is_removed: bool = False) -> None:
         """
@@ -77,6 +74,7 @@ class DemoPlayerprovider(PlayerProvider):
         # OPTIONAL
         # this is an optional method that you can implement if
         # you want to do something special when a player is enabled.
+        super().on_player_enabled(player_id)
 
     def on_player_disabled(self, player_id: str) -> None:
         """Call (by config manager) when a player gets disabled."""
@@ -84,6 +82,7 @@ class DemoPlayerprovider(PlayerProvider):
         # this is an optional method that you can implement if
         # you want to do something special when a player is disabled.
         # e.g. you can stop polling the player or disconnect from it.
+        super().on_player_disabled(player_id)
 
     async def remove_player(self, player_id: str) -> None:
         """Remove a player from this provider."""
index 7255b8fed69c853471e6d4b6bcd4eaea4057df39..be72ed72dfae3942ffb7d90dae0832b359c73528 100644 (file)
@@ -56,6 +56,7 @@ class ChromecastProvider(PlayerProvider):
             self.mass.aiozc.zeroconf,
             known_hosts=manual_ip_config,
         )
+        self._discovery_running = False
         # set-up pychromecast logging
         if self.logger.isEnabledFor(VERBOSE_LOG_LEVEL):
             logging.getLogger("pychromecast").setLevel(logging.DEBUG)
@@ -64,6 +65,9 @@ class ChromecastProvider(PlayerProvider):
 
     async def discover_players(self) -> None:
         """Discover Cast players on the network."""
+        if self._discovery_running:
+            return
+        self._discovery_running = True
         assert self.browser is not None  # for type checking
         await self.mass.loop.run_in_executor(None, self.browser.start_discovery)
 
@@ -82,6 +86,7 @@ class ChromecastProvider(PlayerProvider):
 
             self.browser.host_browser.stop.set()
             self.browser.host_browser.join()
+            self._discovery_running = False
 
         await self.mass.loop.run_in_executor(None, stop_discovery)
 
index e230ff08622d107ac332f269c58e79a19136b784..69948c37b753d2187fee6b0d4f862eec0d676b46 100644 (file)
@@ -270,19 +270,6 @@ class DLNAPlayer(Player):
         """Return all (provider/player specific) Config Entries for the given player (if any)."""
         return [*PLAYER_CONFIG_ENTRIES]
 
-    # async def on_player_config_change(
-    #     self,
-    #     config: PlayerConfig,
-    #     changed_keys: set[str],
-    # ) -> None:
-    #     """Call (by config manager) when the configuration of a player changes."""
-    #     if dlna_player := self.dlnaplayers.get(config.player_id):
-    #         # reset player features based on config values
-    #         self._set_player_features(dlna_player)
-    #     else:
-    #         # run discovery to catch any re-enabled players
-    #         self.mass.create_task(self.discover_players())
-
     # COMMANDS
     @catch_request_errors
     async def stop(self) -> None:
index c6dc950a3cb383775eed268c4fe6c5697d6cc78d..516169e5941589b9b8e47887c4f69ab3afcc5ab6 100644 (file)
@@ -8,10 +8,7 @@ from music_assistant_models.errors import SetupFailedError
 from music_assistant_models.player import PlayerSource
 from pyheos import Heos, HeosError, HeosOptions, MediaItem, PlayerUpdateResult, const
 
-from music_assistant.constants import (
-    CONF_IP_ADDRESS,
-    VERBOSE_LOG_LEVEL,
-)
+from music_assistant.constants import CONF_ENABLED, CONF_IP_ADDRESS, VERBOSE_LOG_LEVEL
 from music_assistant.models.player_provider import PlayerProvider
 from music_assistant.providers.heos.constants import HEOS_PASSIVE_SOURCES
 
@@ -24,6 +21,7 @@ class HeosPlayerProvider(PlayerProvider):
     _heos: Heos
     _music_source_list: list[PlayerSource] = []
     _input_source_list: list[MediaItem] = []
+    _discovery_running: bool = False
 
     async def handle_async_init(self) -> None:
         """Handle async initialization of the provider."""
@@ -51,13 +49,7 @@ class HeosPlayerProvider(PlayerProvider):
         try:
             # Populate source lists
             await self._populate_sources()
-
-            # Build player configs
-            devices = await self._heos.get_players()
-            for device in devices.values():
-                heos_player = HeosPlayer(self, device)
-
-                await heos_player.setup()
+            # NOTE: players are discovered via discovery method (called automatically by core)
         except HeosError as e:
             self.logger.error(f"Unexpected error setting up HEOS controller: {e}")
             raise SetupFailedError("Unexpected error setting up HEOS controller") from e
@@ -126,17 +118,31 @@ class HeosPlayerProvider(PlayerProvider):
             self.logger.debug("Unloading player %s", player.name)
             await self.mass.players.unregister(player.player_id)
 
-    def on_player_disabled(self, player_id: str) -> None:
-        """Unregister player when it is disabled, cleans up connections."""
-        # Clean up event handling connection
-        self.mass.create_task(self.mass.players.unregister(player_id))
-
-    # TODO: Re-enable when MA lifecycles get updated.
-    # Currently a race-condition prevents `register_or_update` to finish because Enabled is still false  # noqa: E501
-    # def on_player_enabled(self, player_id: str) -> None:
-    #     """Reregister player when it is enabled."""
-    #     self.logger.debug("Attempting player re-enabling")
-    #     if device := self._device_map.get(player_id):
-    #         # Reinstantiate the player
-    #         heos_player = HeosPlayer(self, device)
-    #         self.mass.create_task(heos_player.setup())
+    async def discover_players(self) -> None:
+        """Discover players for this provider."""
+        if self._discovery_running:
+            return  # discovery already running
+        try:
+            self._discovery_running = True
+            self.logger.debug("Discovering HEOS players")
+            devices = await self._heos.get_players()
+            already_registered = {p.player_id for p in self.players}
+            for device in devices.values():
+                player_id = str(device.player_id)
+                if player_id in already_registered:
+                    continue  # already registered
+                # ignore disabled players in discovery
+                player_enabled = self.mass.config.get_raw_player_config_value(
+                    player_id, CONF_ENABLED, default=True
+                )
+                if not player_enabled:
+                    continue
+                self.logger.info("Discovered new HEOS player: %s (%s)", device.name, player_id)
+
+                heos_player = HeosPlayer(self, device)
+                await heos_player.setup()
+        finally:
+            self._discovery_running = False
+        # reschedule discovery
+        task_id = f"discover_players_{self.instance_id}"
+        self.mass.call_later(600, self.discover_players, task_id=task_id)
index c18d7b5bf4eb425c2b3b2910c8024f2084b89ef8..9dbbe1edd2a8a7be0888e1518e9dfb07d8314d0d 100644 (file)
@@ -72,45 +72,46 @@ class MediaAssistantprovider(PlayerProvider):
 
     async def discover_players(self) -> None:
         """Discover Roku players on the network."""
-        if self.config.get_value(CONF_AUTO_DISCOVER):
-            if self._discovery_running:
-                return
-            try:
-                self._discovery_running = True
-                self.logger.debug("Roku discovery started...")
-                discovered_devices: set[str] = set()
-
-                async def on_response(discovery_info: CaseInsensitiveDict) -> None:
-                    """Process discovered device from ssdp search."""
-                    ssdp_st: str | None = discovery_info.get("st")
-                    if not ssdp_st:
-                        return
+        if not self.config.get_value(CONF_AUTO_DISCOVER):
+            return
+        if self._discovery_running:
+            return
+        try:
+            self._discovery_running = True
+            self.logger.debug("Roku discovery started...")
+            discovered_devices: set[str] = set()
+
+            async def on_response(discovery_info: CaseInsensitiveDict) -> None:
+                """Process discovered device from ssdp search."""
+                ssdp_st: str | None = discovery_info.get("st")
+                if not ssdp_st:
+                    return
 
-                    if "roku:ecp" not in ssdp_st:
-                        # we're only interested in Roku devices
-                        return
+                if "roku:ecp" not in ssdp_st:
+                    # we're only interested in Roku devices
+                    return
 
-                    ssdp_usn: str = discovery_info["usn"]
-                    ssdp_udn: str | None = discovery_info.get("_udn")
-                    if not ssdp_udn and ssdp_usn.startswith("uuid:"):
-                        ssdp_udn = "ROKU_" + ssdp_usn.split(":")[-1]
-                    elif ssdp_udn:
-                        ssdp_udn = "ROKU_" + ssdp_udn.split(":")[-1]
-                    else:
-                        return
+                ssdp_usn: str = discovery_info["usn"]
+                ssdp_udn: str | None = discovery_info.get("_udn")
+                if not ssdp_udn and ssdp_usn.startswith("uuid:"):
+                    ssdp_udn = "ROKU_" + ssdp_usn.split(":")[-1]
+                elif ssdp_udn:
+                    ssdp_udn = "ROKU_" + ssdp_udn.split(":")[-1]
+                else:
+                    return
 
-                    if ssdp_udn in discovered_devices:
-                        # already processed this device
-                        return
+                if ssdp_udn in discovered_devices:
+                    # already processed this device
+                    return
 
-                    discovered_devices.add(ssdp_udn)
+                discovered_devices.add(ssdp_udn)
 
-                    await self._device_discovered(discovery_info["_host"])
+                await self._device_discovered(discovery_info["_host"])
 
-                await async_search(on_response, search_target="roku:ecp")
+            await async_search(on_response, search_target="roku:ecp")
 
-            finally:
-                self._discovery_running = False
+        finally:
+            self._discovery_running = False
 
         def reschedule() -> None:
             self.mass.create_task(self.discover_players())