Add type hints to config controller functions (#2639)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Mon, 17 Nov 2025 11:20:30 +0000 (12:20 +0100)
committerGitHub <noreply@github.com>
Mon, 17 Nov 2025 11:20:30 +0000 (12:20 +0100)
music_assistant/controllers/config.py
music_assistant/controllers/music.py
music_assistant/providers/airplay/protocols/raop.py
music_assistant/providers/squeezelite/provider.py
music_assistant/providers/universal_group/player.py

index b42b8ff9a352a9f9a79d65e1e2661388de89f783..765d4f625b83802f4e0b2d3a0bc58ccf97ad5597 100644 (file)
@@ -5,7 +5,7 @@ from __future__ import annotations
 import base64
 import logging
 import os
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload
 from uuid import uuid4
 
 import aiofiles
@@ -82,6 +82,9 @@ DEFAULT_SAVE_DELAY = 5
 
 BASE_KEYS = ("enabled", "name", "available", "default_name", "provider", "type")
 
+# TypeVar for config value type inference
+_ConfigValueT = TypeVar("_ConfigValueT", bound=ConfigValueType)
+
 isfile = wrap(os.path.isfile)
 remove = wrap(os.remove)
 rename = wrap(os.rename)
@@ -229,9 +232,34 @@ class ConfigController:
         msg = f"No config found for provider id {instance_id}"
         raise KeyError(msg)
 
+    @overload
+    async def get_provider_config_value(
+        self, instance_id: str, key: str, *, return_type: type[_ConfigValueT] = ...
+    ) -> _ConfigValueT: ...
+
+    @overload
+    async def get_provider_config_value(
+        self, instance_id: str, key: str, *, return_type: None = ...
+    ) -> ConfigValueType: ...
+
     @api_command("config/providers/get_value")
-    async def get_provider_config_value(self, instance_id: str, key: str) -> ConfigValueType:
-        """Return single configentry value for a provider."""
+    async def get_provider_config_value(
+        self,
+        instance_id: str,
+        key: str,
+        *,
+        return_type: type[_ConfigValueT | ConfigValueType] | None = None,
+    ) -> _ConfigValueT | ConfigValueType:
+        """
+        Return single configentry value for a provider.
+
+        :param instance_id: The provider instance ID.
+        :param key: The config key to retrieve.
+        :param return_type: Optional type hint for type inference (e.g., str, int, bool).
+            Note: This parameter is used purely for static type checking and does not
+            perform runtime type validation. Callers are responsible for ensuring the
+            specified type matches the actual config value type.
+        """
         cache_key = f"prov_conf_value_{instance_id}.{key}"
         if (cached_value := self._value_cache.get(cache_key)) is not None:
             return cached_value
@@ -481,14 +509,55 @@ class ConfigController:
 
         return await player.get_config_entries(action=action, values=values)
 
+    @overload
+    async def get_player_config_value(
+        self,
+        player_id: str,
+        key: str,
+        unpack_splitted_values: Literal[True],
+        *,
+        return_type: type[_ConfigValueT] | None = ...,
+    ) -> tuple[str, ...] | list[tuple[str, ...]]: ...
+
+    @overload
+    async def get_player_config_value(
+        self,
+        player_id: str,
+        key: str,
+        unpack_splitted_values: Literal[False] = False,
+        *,
+        return_type: type[_ConfigValueT] = ...,
+    ) -> _ConfigValueT: ...
+
+    @overload
+    async def get_player_config_value(
+        self,
+        player_id: str,
+        key: str,
+        unpack_splitted_values: Literal[False] = False,
+        *,
+        return_type: None = ...,
+    ) -> ConfigValueType: ...
+
     @api_command("config/players/get_value")
     async def get_player_config_value(
         self,
         player_id: str,
         key: str,
         unpack_splitted_values: bool = False,
-    ) -> ConfigValueType | tuple[str, ...] | list[tuple[str, ...]]:
-        """Return single configentry value for a player."""
+        return_type: type[_ConfigValueT | ConfigValueType] | None = None,
+    ) -> _ConfigValueT | ConfigValueType | tuple[str, ...] | list[tuple[str, ...]]:
+        """
+        Return single configentry value for a player.
+
+        :param player_id: The player ID.
+        :param key: The config key to retrieve.
+        :param unpack_splitted_values: Whether to unpack multi-value config entries.
+        :param return_type: Optional type hint for type inference (e.g., str, int, bool).
+            Note: This parameter is used purely for static type checking and does not
+            perform runtime type validation. Callers are responsible for ensuring the
+            specified type matches the actual config value type.
+        """
         conf = await self.get_player_config(player_id)
         if unpack_splitted_values:
             return conf.values[key].get_splitted_values()
@@ -498,6 +567,19 @@ class ConfigController:
             else conf.values[key].default_value
         )
 
+    if TYPE_CHECKING:
+        # Overload for when default is provided - return type matches default type
+        @overload
+        def get_raw_player_config_value(
+            self, player_id: str, key: str, default: _ConfigValueT
+        ) -> _ConfigValueT: ...
+
+        # Overload for when no default is provided - return ConfigValueType | None
+        @overload
+        def get_raw_player_config_value(
+            self, player_id: str, key: str, default: None = None
+        ) -> ConfigValueType | None: ...
+
     def get_raw_player_config_value(
         self, player_id: str, key: str, default: ConfigValueType = None
     ) -> ConfigValueType:
@@ -806,9 +888,34 @@ class ConfigController:
         config_entries = await self.get_core_config_entries(domain)
         return cast("CoreConfig", CoreConfig.parse(config_entries, raw_conf))
 
+    @overload
+    async def get_core_config_value(
+        self, domain: str, key: str, *, return_type: type[_ConfigValueT] = ...
+    ) -> _ConfigValueT: ...
+
+    @overload
+    async def get_core_config_value(
+        self, domain: str, key: str, *, return_type: None = ...
+    ) -> ConfigValueType: ...
+
     @api_command("config/core/get_value")
-    async def get_core_config_value(self, domain: str, key: str) -> ConfigValueType:
-        """Return single configentry value for a core controller."""
+    async def get_core_config_value(
+        self,
+        domain: str,
+        key: str,
+        *,
+        return_type: type[_ConfigValueT | ConfigValueType] | None = None,
+    ) -> _ConfigValueT | ConfigValueType:
+        """
+        Return single configentry value for a core controller.
+
+        :param domain: The core controller domain.
+        :param key: The config key to retrieve.
+        :param return_type: Optional type hint for type inference (e.g., str, int, bool).
+            Note: This parameter is used purely for static type checking and does not
+            perform runtime type validation. Callers are responsible for ensuring the
+            specified type matches the actual config value type.
+        """
         conf = await self.get_core_config(domain)
         return (
             conf.values[key].value
@@ -862,6 +969,19 @@ class ConfigController:
         # return full config, just in case
         return await self.get_core_config(domain)
 
+    if TYPE_CHECKING:
+        # Overload for when default is provided - return type matches default type
+        @overload
+        def get_raw_core_config_value(
+            self, core_module: str, key: str, default: _ConfigValueT
+        ) -> _ConfigValueT: ...
+
+        # Overload for when no default is provided - return ConfigValueType | None
+        @overload
+        def get_raw_core_config_value(
+            self, core_module: str, key: str, default: None = None
+        ) -> ConfigValueType | None: ...
+
     def get_raw_core_config_value(
         self, core_module: str, key: str, default: ConfigValueType = None
     ) -> ConfigValueType:
@@ -878,6 +998,19 @@ class ConfigController:
             ),
         )
 
+    if TYPE_CHECKING:
+        # Overload for when default is provided - return type matches default type
+        @overload
+        def get_raw_provider_config_value(
+            self, provider_instance: str, key: str, default: _ConfigValueT
+        ) -> _ConfigValueT: ...
+
+        # Overload for when no default is provided - return ConfigValueType | None
+        @overload
+        def get_raw_provider_config_value(
+            self, provider_instance: str, key: str, default: None = None
+        ) -> ConfigValueType | None: ...
+
     def get_raw_provider_config_value(
         self, provider_instance: str, key: str, default: ConfigValueType = None
     ) -> ConfigValueType:
index 6472226268be84d49ff6d94904cb85da1fca894f..3119a6191912ca0bf45319e03182cb229ec6fc4a 100644 (file)
@@ -1533,9 +1533,8 @@ class MusicController(CoreController):
         if not sync_conf:
             return
         conf_key = f"provider_sync_interval_{media_type.value}s"
-        sync_interval = cast(
-            "int",
-            await self.mass.config.get_provider_config_value(provider.instance_id, conf_key),
+        sync_interval = await self.mass.config.get_provider_config_value(
+            provider.instance_id, conf_key, return_type=int
         )
         if sync_interval <= 0:
             # sync disabled for this media type
index 9ece6daaaf699a3c96d2b39f866d41922ee42a35..45227df3d1d89580bf8f5b8df0098d51e33fc904 100644 (file)
@@ -53,7 +53,7 @@ class RaopStream(AirPlayProtocol):
             if prop_value := self.player.raop_discovery_info.decoded_properties.get(prop):
                 extra_args += [f"-{prop}", prop_value]
         if device_password := self.mass.config.get_raw_player_config_value(
-            player_id, CONF_PASSWORD, None
+            player_id, CONF_PASSWORD
         ):
             extra_args += ["-password", str(device_password)]
         # Add AirPlay credentials from pairing if available (for Apple devices)
@@ -64,7 +64,7 @@ class RaopStream(AirPlayProtocol):
         elif self.prov.logger.isEnabledFor(VERBOSE_LOG_LEVEL):
             extra_args += ["-debug", "10"]
         read_ahead = await self.mass.config.get_player_config_value(
-            player_id, CONF_READ_AHEAD_BUFFER
+            player_id, CONF_READ_AHEAD_BUFFER, return_type=int
         )
 
         # cliraop is the binary that handles the actual raop streaming to the player
index 62cb0b07e7782d5d0732d2d1041dd82758f1e9ab..3e2de090ef38cd4e60e4af7faf0b2db1969dac4d 100644 (file)
@@ -121,11 +121,10 @@ class SqueezelitePlayerProvider(PlayerProvider):
 
     def get_corrected_elapsed_milliseconds(self, slimplayer: SlimClient) -> int:
         """Return corrected elapsed milliseconds for a slimplayer."""
-        sync_delay = cast(
-            "int",
-            self.mass.config.get_raw_player_config_value(slimplayer.player_id, CONF_SYNC_ADJUST, 0),
+        sync_delay = self.mass.config.get_raw_player_config_value(
+            slimplayer.player_id, CONF_SYNC_ADJUST, 0
         )
-        return cast("int", slimplayer.elapsed_milliseconds - sync_delay)
+        return int(slimplayer.elapsed_milliseconds - sync_delay)
 
     def _handle_slimproto_event(
         self,
index 23a4fc5723998d8e234c2e43f0ce2b7d7cbaeb7c..b94a04791394c95e259b12282de9c6dfb0840b1b 100644 (file)
@@ -352,9 +352,8 @@ class UniversalGroupPlayer(GroupPlayer):
                 content_sample_rate=UGP_FORMAT.sample_rate,
                 content_bit_depth=UGP_FORMAT.bit_depth,
             )
-            http_profile = cast(
-                "str",
-                await self.mass.config.get_player_config_value(child_player_id, CONF_HTTP_PROFILE),
+            http_profile = await self.mass.config.get_player_config_value(
+                child_player_id, CONF_HTTP_PROFILE, return_type=str
             )
         elif output_format_str == "flac":
             output_format = AudioFormat(content_type=ContentType.FLAC)