From ee691d124ef2e0ae8c580f496dc837df1c6a2271 Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Mon, 17 Nov 2025 12:20:30 +0100 Subject: [PATCH] Add type hints to config controller functions (#2639) --- music_assistant/controllers/config.py | 147 +++++++++++++++++- music_assistant/controllers/music.py | 5 +- .../providers/airplay/protocols/raop.py | 4 +- .../providers/squeezelite/provider.py | 7 +- .../providers/universal_group/player.py | 5 +- 5 files changed, 149 insertions(+), 19 deletions(-) diff --git a/music_assistant/controllers/config.py b/music_assistant/controllers/config.py index b42b8ff9..765d4f62 100644 --- a/music_assistant/controllers/config.py +++ b/music_assistant/controllers/config.py @@ -5,7 +5,7 @@ from __future__ import annotations import base64 import logging import os -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload from uuid import uuid4 import aiofiles @@ -82,6 +82,9 @@ DEFAULT_SAVE_DELAY = 5 BASE_KEYS = ("enabled", "name", "available", "default_name", "provider", "type") +# TypeVar for config value type inference +_ConfigValueT = TypeVar("_ConfigValueT", bound=ConfigValueType) + isfile = wrap(os.path.isfile) remove = wrap(os.remove) rename = wrap(os.rename) @@ -229,9 +232,34 @@ class ConfigController: msg = f"No config found for provider id {instance_id}" raise KeyError(msg) + @overload + async def get_provider_config_value( + self, instance_id: str, key: str, *, return_type: type[_ConfigValueT] = ... + ) -> _ConfigValueT: ... + + @overload + async def get_provider_config_value( + self, instance_id: str, key: str, *, return_type: None = ... + ) -> ConfigValueType: ... + @api_command("config/providers/get_value") - async def get_provider_config_value(self, instance_id: str, key: str) -> ConfigValueType: - """Return single configentry value for a provider.""" + async def get_provider_config_value( + self, + instance_id: str, + key: str, + *, + return_type: type[_ConfigValueT | ConfigValueType] | None = None, + ) -> _ConfigValueT | ConfigValueType: + """ + Return single configentry value for a provider. + + :param instance_id: The provider instance ID. + :param key: The config key to retrieve. + :param return_type: Optional type hint for type inference (e.g., str, int, bool). + Note: This parameter is used purely for static type checking and does not + perform runtime type validation. Callers are responsible for ensuring the + specified type matches the actual config value type. + """ cache_key = f"prov_conf_value_{instance_id}.{key}" if (cached_value := self._value_cache.get(cache_key)) is not None: return cached_value @@ -481,14 +509,55 @@ class ConfigController: return await player.get_config_entries(action=action, values=values) + @overload + async def get_player_config_value( + self, + player_id: str, + key: str, + unpack_splitted_values: Literal[True], + *, + return_type: type[_ConfigValueT] | None = ..., + ) -> tuple[str, ...] | list[tuple[str, ...]]: ... + + @overload + async def get_player_config_value( + self, + player_id: str, + key: str, + unpack_splitted_values: Literal[False] = False, + *, + return_type: type[_ConfigValueT] = ..., + ) -> _ConfigValueT: ... + + @overload + async def get_player_config_value( + self, + player_id: str, + key: str, + unpack_splitted_values: Literal[False] = False, + *, + return_type: None = ..., + ) -> ConfigValueType: ... + @api_command("config/players/get_value") async def get_player_config_value( self, player_id: str, key: str, unpack_splitted_values: bool = False, - ) -> ConfigValueType | tuple[str, ...] | list[tuple[str, ...]]: - """Return single configentry value for a player.""" + return_type: type[_ConfigValueT | ConfigValueType] | None = None, + ) -> _ConfigValueT | ConfigValueType | tuple[str, ...] | list[tuple[str, ...]]: + """ + Return single configentry value for a player. + + :param player_id: The player ID. + :param key: The config key to retrieve. + :param unpack_splitted_values: Whether to unpack multi-value config entries. + :param return_type: Optional type hint for type inference (e.g., str, int, bool). + Note: This parameter is used purely for static type checking and does not + perform runtime type validation. Callers are responsible for ensuring the + specified type matches the actual config value type. + """ conf = await self.get_player_config(player_id) if unpack_splitted_values: return conf.values[key].get_splitted_values() @@ -498,6 +567,19 @@ class ConfigController: else conf.values[key].default_value ) + if TYPE_CHECKING: + # Overload for when default is provided - return type matches default type + @overload + def get_raw_player_config_value( + self, player_id: str, key: str, default: _ConfigValueT + ) -> _ConfigValueT: ... + + # Overload for when no default is provided - return ConfigValueType | None + @overload + def get_raw_player_config_value( + self, player_id: str, key: str, default: None = None + ) -> ConfigValueType | None: ... + def get_raw_player_config_value( self, player_id: str, key: str, default: ConfigValueType = None ) -> ConfigValueType: @@ -806,9 +888,34 @@ class ConfigController: config_entries = await self.get_core_config_entries(domain) return cast("CoreConfig", CoreConfig.parse(config_entries, raw_conf)) + @overload + async def get_core_config_value( + self, domain: str, key: str, *, return_type: type[_ConfigValueT] = ... + ) -> _ConfigValueT: ... + + @overload + async def get_core_config_value( + self, domain: str, key: str, *, return_type: None = ... + ) -> ConfigValueType: ... + @api_command("config/core/get_value") - async def get_core_config_value(self, domain: str, key: str) -> ConfigValueType: - """Return single configentry value for a core controller.""" + async def get_core_config_value( + self, + domain: str, + key: str, + *, + return_type: type[_ConfigValueT | ConfigValueType] | None = None, + ) -> _ConfigValueT | ConfigValueType: + """ + Return single configentry value for a core controller. + + :param domain: The core controller domain. + :param key: The config key to retrieve. + :param return_type: Optional type hint for type inference (e.g., str, int, bool). + Note: This parameter is used purely for static type checking and does not + perform runtime type validation. Callers are responsible for ensuring the + specified type matches the actual config value type. + """ conf = await self.get_core_config(domain) return ( conf.values[key].value @@ -862,6 +969,19 @@ class ConfigController: # return full config, just in case return await self.get_core_config(domain) + if TYPE_CHECKING: + # Overload for when default is provided - return type matches default type + @overload + def get_raw_core_config_value( + self, core_module: str, key: str, default: _ConfigValueT + ) -> _ConfigValueT: ... + + # Overload for when no default is provided - return ConfigValueType | None + @overload + def get_raw_core_config_value( + self, core_module: str, key: str, default: None = None + ) -> ConfigValueType | None: ... + def get_raw_core_config_value( self, core_module: str, key: str, default: ConfigValueType = None ) -> ConfigValueType: @@ -878,6 +998,19 @@ class ConfigController: ), ) + if TYPE_CHECKING: + # Overload for when default is provided - return type matches default type + @overload + def get_raw_provider_config_value( + self, provider_instance: str, key: str, default: _ConfigValueT + ) -> _ConfigValueT: ... + + # Overload for when no default is provided - return ConfigValueType | None + @overload + def get_raw_provider_config_value( + self, provider_instance: str, key: str, default: None = None + ) -> ConfigValueType | None: ... + def get_raw_provider_config_value( self, provider_instance: str, key: str, default: ConfigValueType = None ) -> ConfigValueType: diff --git a/music_assistant/controllers/music.py b/music_assistant/controllers/music.py index 64722262..3119a619 100644 --- a/music_assistant/controllers/music.py +++ b/music_assistant/controllers/music.py @@ -1533,9 +1533,8 @@ class MusicController(CoreController): if not sync_conf: return conf_key = f"provider_sync_interval_{media_type.value}s" - sync_interval = cast( - "int", - await self.mass.config.get_provider_config_value(provider.instance_id, conf_key), + sync_interval = await self.mass.config.get_provider_config_value( + provider.instance_id, conf_key, return_type=int ) if sync_interval <= 0: # sync disabled for this media type diff --git a/music_assistant/providers/airplay/protocols/raop.py b/music_assistant/providers/airplay/protocols/raop.py index 9ece6daa..45227df3 100644 --- a/music_assistant/providers/airplay/protocols/raop.py +++ b/music_assistant/providers/airplay/protocols/raop.py @@ -53,7 +53,7 @@ class RaopStream(AirPlayProtocol): if prop_value := self.player.raop_discovery_info.decoded_properties.get(prop): extra_args += [f"-{prop}", prop_value] if device_password := self.mass.config.get_raw_player_config_value( - player_id, CONF_PASSWORD, None + player_id, CONF_PASSWORD ): extra_args += ["-password", str(device_password)] # Add AirPlay credentials from pairing if available (for Apple devices) @@ -64,7 +64,7 @@ class RaopStream(AirPlayProtocol): elif self.prov.logger.isEnabledFor(VERBOSE_LOG_LEVEL): extra_args += ["-debug", "10"] read_ahead = await self.mass.config.get_player_config_value( - player_id, CONF_READ_AHEAD_BUFFER + player_id, CONF_READ_AHEAD_BUFFER, return_type=int ) # cliraop is the binary that handles the actual raop streaming to the player diff --git a/music_assistant/providers/squeezelite/provider.py b/music_assistant/providers/squeezelite/provider.py index 62cb0b07..3e2de090 100644 --- a/music_assistant/providers/squeezelite/provider.py +++ b/music_assistant/providers/squeezelite/provider.py @@ -121,11 +121,10 @@ class SqueezelitePlayerProvider(PlayerProvider): def get_corrected_elapsed_milliseconds(self, slimplayer: SlimClient) -> int: """Return corrected elapsed milliseconds for a slimplayer.""" - sync_delay = cast( - "int", - self.mass.config.get_raw_player_config_value(slimplayer.player_id, CONF_SYNC_ADJUST, 0), + sync_delay = self.mass.config.get_raw_player_config_value( + slimplayer.player_id, CONF_SYNC_ADJUST, 0 ) - return cast("int", slimplayer.elapsed_milliseconds - sync_delay) + return int(slimplayer.elapsed_milliseconds - sync_delay) def _handle_slimproto_event( self, diff --git a/music_assistant/providers/universal_group/player.py b/music_assistant/providers/universal_group/player.py index 23a4fc57..b94a0479 100644 --- a/music_assistant/providers/universal_group/player.py +++ b/music_assistant/providers/universal_group/player.py @@ -352,9 +352,8 @@ class UniversalGroupPlayer(GroupPlayer): content_sample_rate=UGP_FORMAT.sample_rate, content_bit_depth=UGP_FORMAT.bit_depth, ) - http_profile = cast( - "str", - await self.mass.config.get_player_config_value(child_player_id, CONF_HTTP_PROFILE), + http_profile = await self.mass.config.get_player_config_value( + child_player_id, CONF_HTTP_PROFILE, return_type=str ) elif output_format_str == "flac": output_format = AudioFormat(content_type=ContentType.FLAC) -- 2.34.1