From: OzGav Date: Fri, 7 Nov 2025 16:18:39 +0000 (+1000) Subject: Typing fixes for the config controller (#2570) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=0c00c003c13a153eace46102acf3cfb8f58235f4;p=music-assistant-server.git Typing fixes for the config controller (#2570) --- diff --git a/music_assistant/controllers/config.py b/music_assistant/controllers/config.py index f8dae8ff..b42b8ff9 100644 --- a/music_assistant/controllers/config.py +++ b/music_assistant/controllers/config.py @@ -5,7 +5,7 @@ from __future__ import annotations import base64 import logging import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from uuid import uuid4 import aiofiles @@ -69,6 +69,7 @@ from music_assistant.helpers.api import api_command from music_assistant.helpers.json import JSON_DECODE_EXCEPTIONS, async_json_dumps, async_json_loads from music_assistant.helpers.util import load_provider_module from music_assistant.models import ProviderModuleType +from music_assistant.models.music_provider import MusicProvider if TYPE_CHECKING: import asyncio @@ -117,7 +118,7 @@ class ConfigController: @property def onboard_done(self) -> bool: """Return True if onboarding is done.""" - return self.get(CONF_ONBOARD_DONE, False) + return bool(self.get(CONF_ONBOARD_DONE, False)) async def close(self) -> None: """Handle logic on server stop.""" @@ -196,12 +197,12 @@ class ConfigController: include_values: bool = False, ) -> list[ProviderConfig]: """Return all known provider configurations, optionally filtered by ProviderType.""" - raw_values: dict[str, dict] = self.get(CONF_PROVIDERS, {}) + raw_values = self.get(CONF_PROVIDERS, {}) prov_entries = {x.domain for x in self.mass.get_provider_manifests()} return [ await self.get_provider_config(prov_conf["instance_id"]) if include_values - else ProviderConfig.parse([], prov_conf) + else cast("ProviderConfig", ProviderConfig.parse([], prov_conf)) for prov_conf in raw_values.values() if (provider_type is None or prov_conf["type"] == provider_type) and (provider_domain is None or prov_conf["domain"] == provider_domain) @@ -224,7 +225,7 @@ class ConfigController: else: msg = f"Unknown provider domain: {raw_conf['domain']}" raise KeyError(msg) - return ProviderConfig.parse(config_entries, raw_conf) + return cast("ProviderConfig", ProviderConfig.parse(config_entries, raw_conf)) msg = f"No config found for provider id {instance_id}" raise KeyError(msg) @@ -284,9 +285,7 @@ class ConfigController: supported_features = provider.supported_features else: provider = None - supported_features: set[ProviderFeature] = getattr( - prov_mod, "SUPPORTED_FEATURES", set() - ) + supported_features = getattr(prov_mod, "SUPPORTED_FEATURES", set()) extra_entries: list[ConfigEntry] = [] if manifest.type == ProviderType.MUSIC: # library sync settings @@ -294,13 +293,21 @@ class ConfigController: extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ARTISTS) if ProviderFeature.LIBRARY_ALBUMS in supported_features: extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ALBUMS) - if provider and provider.is_streaming_provider: + if ( + provider + and isinstance(provider, MusicProvider) + and provider.is_streaming_provider + ): extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ALBUM_TRACKS) if ProviderFeature.LIBRARY_TRACKS in supported_features: extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_TRACKS) if ProviderFeature.LIBRARY_PLAYLISTS in supported_features: extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_PLAYLISTS) - if provider and provider.is_streaming_provider: + if ( + provider + and isinstance(provider, MusicProvider) + and provider.is_streaming_provider + ): extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_PLAYLIST_TRACKS) if ProviderFeature.LIBRARY_AUDIOBOOKS in supported_features: extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_AUDIOBOOKS) @@ -413,7 +420,7 @@ class ConfigController: return [ await self.get_player_config(raw_conf["player_id"]) if include_values - else PlayerConfig.parse([], raw_conf) + else cast("PlayerConfig", PlayerConfig.parse([], raw_conf)) for raw_conf in list(self.get(CONF_PLAYERS, {}).values()) # filter out unavailable providers (only if we requested the full info) if ( @@ -447,7 +454,7 @@ class ConfigController: raw_conf["available"] = False raw_conf["name"] = raw_conf.get("name") raw_conf["default_name"] = raw_conf.get("default_name") or raw_conf["player_id"] - return PlayerConfig.parse(conf_entries, raw_conf) + return cast("PlayerConfig", PlayerConfig.parse(conf_entries, raw_conf)) msg = f"No config found for player id {player_id}" raise KeyError(msg) @@ -480,7 +487,7 @@ class ConfigController: player_id: str, key: str, unpack_splitted_values: bool = False, - ) -> ConfigValueType: + ) -> ConfigValueType | tuple[str, ...] | list[tuple[str, ...]]: """Return single configentry value for a player.""" conf = await self.get_player_config(player_id) if unpack_splitted_values: @@ -499,9 +506,12 @@ class ConfigController: Note that this only returns the stored value without any validation or default. """ - return self.get( - f"{CONF_PLAYERS}/{player_id}/values/{key}", - self.get(f"{CONF_PLAYERS}/{player_id}/{key}", default), + return cast( + "ConfigValueType", + self.get( + f"{CONF_PLAYERS}/{player_id}/values/{key}", + self.get(f"{CONF_PLAYERS}/{player_id}/{key}", default), + ), ) def get_base_player_config(self, player_id: str, provider: str) -> PlayerConfig: @@ -516,7 +526,7 @@ class ConfigController: "player_id": player_id, "provider": provider, } - return PlayerConfig.parse([], raw_conf) + return cast("PlayerConfig", PlayerConfig.parse([], raw_conf)) @api_command("config/players/save") async def save_player_config( @@ -527,7 +537,7 @@ class ConfigController: changed_keys = config.update(values) if not changed_keys: # no changes - return None + return config # validate/handle the update in the player manager await self.mass.players.on_player_config_change(config, changed_keys) # actually store changes (if the above did not raise) @@ -602,9 +612,15 @@ class ConfigController: dsp_config.filters.append( ToneControlFilter( enabled=True, - bass_level=deprecated_eq_bass, - mid_level=deprecated_eq_mid, - treble_level=deprecated_eq_treble, + bass_level=float(deprecated_eq_bass) + if isinstance(deprecated_eq_bass, (int, float, str)) + else 0.0, + mid_level=float(deprecated_eq_mid) + if isinstance(deprecated_eq_mid, (int, float, str)) + else 0.0, + treble_level=float(deprecated_eq_treble) + if isinstance(deprecated_eq_treble, (int, float, str)) + else 0.0, ) ) @@ -748,17 +764,20 @@ class ConfigController: instance_id = f"{manifest.domain}--{shortuuid.random(8)}" else: instance_id = manifest.domain - default_config: ProviderConfig = ProviderConfig.parse( - config_entries, - { - "type": manifest.type.value, - "domain": manifest.domain, - "instance_id": instance_id, - "name": manifest.name, - # note: this will only work for providers that do - # not have any required config entries or provide defaults - "values": {}, - }, + default_config = cast( + "ProviderConfig", + ProviderConfig.parse( + config_entries, + { + "type": manifest.type.value, + "domain": manifest.domain, + "instance_id": instance_id, + "name": manifest.name, + # note: this will only work for providers that do + # not have any required config entries or provide defaults + "values": {}, + }, + ), ) default_config.validate() conf_key = f"{CONF_PROVIDERS}/{default_config.instance_id}" @@ -770,9 +789,12 @@ class ConfigController: return [ await self.get_core_config(core_controller) if include_values - else CoreConfig.parse( - [], - self.get(f"{CONF_CORE}/{core_controller}", {"domain": core_controller}), + else cast( + "CoreConfig", + CoreConfig.parse( + [], + self.get(f"{CONF_CORE}/{core_controller}", {"domain": core_controller}), + ), ) for core_controller in CONFIGURABLE_CORE_CONTROLLERS ] @@ -782,7 +804,7 @@ class ConfigController: """Return configuration for a single core controller.""" raw_conf = self.get(f"{CONF_CORE}/{domain}", {"domain": domain}) config_entries = await self.get_core_config_entries(domain) - return CoreConfig.parse(config_entries, raw_conf) + return cast("CoreConfig", CoreConfig.parse(config_entries, raw_conf)) @api_command("config/core/get_value") async def get_core_config_value(self, domain: str, key: str) -> ConfigValueType: @@ -848,9 +870,12 @@ class ConfigController: Note that this only returns the stored value without any validation or default. """ - return self.get( - f"{CONF_CORE}/{core_module}/values/{key}", - self.get(f"{CONF_CORE}/{core_module}/{key}", default), + return cast( + "ConfigValueType", + self.get( + f"{CONF_CORE}/{core_module}/values/{key}", + self.get(f"{CONF_CORE}/{core_module}/{key}", default), + ), ) def get_raw_provider_config_value( @@ -861,9 +886,12 @@ class ConfigController: Note that this only returns the stored value without any validation or default. """ - return self.get( - f"{CONF_PROVIDERS}/{provider_instance}/values/{key}", - self.get(f"{CONF_PROVIDERS}/{provider_instance}/{key}", default), + return cast( + "ConfigValueType", + self.get( + f"{CONF_PROVIDERS}/{provider_instance}/values/{key}", + self.get(f"{CONF_PROVIDERS}/{provider_instance}/{key}", default), + ), ) def set_raw_provider_config_value( @@ -883,6 +911,9 @@ class ConfigController: msg = f"Invalid provider_instance: {provider_instance}" raise KeyError(msg) if encrypted: + if not isinstance(value, str): + msg = f"Cannot encrypt non-string value for key {key}" + raise ValueError(msg) value = self.encrypt_string(value) if key in BASE_KEYS: self.set(f"{CONF_PROVIDERS}/{provider_instance}/{key}", value) @@ -934,6 +965,7 @@ class ConfigController: """Encrypt a (password)string with Fernet.""" if str_value.startswith(ENCRYPT_SUFFIX): return str_value + assert self._fernet is not None return ENCRYPT_SUFFIX + self._fernet.encrypt(str_value.encode()).decode() def decrypt_string(self, encrypted_str: str) -> str: @@ -942,6 +974,7 @@ class ConfigController: return encrypted_str if not encrypted_str.startswith(ENCRYPT_SUFFIX): return encrypted_str + assert self._fernet is not None try: return self._fernet.decrypt(encrypted_str.replace(ENCRYPT_SUFFIX, "").encode()).decode() except InvalidToken as err: @@ -972,7 +1005,6 @@ class ConfigController: instance_id: str provider_config: dict[str, Any] player_config: dict[str, Any] - values: dict[str, ConfigValueType] # Older versions of MA can create corrupt entries with no domain if retrying # logic runs after a provider has been removed. Remove those corrupt entries. @@ -1020,7 +1052,12 @@ class ConfigController: # migrate player_group entries ugp_found = False for player_config in self._data.get(CONF_PLAYERS, {}).values(): - if not player_config.get("provider").startswith("player_group"): + provider = player_config.get("provider") + if ( + not provider + or not isinstance(provider, str) + or not provider.startswith("player_group") + ): continue if not (values := player_config.get("values")): continue @@ -1144,7 +1181,7 @@ class ConfigController: self, provider_domain: str, values: dict[str, ConfigValueType], - ) -> list[ConfigEntry] | ProviderConfig: + ) -> ProviderConfig: """ Add new Provider (instance). @@ -1181,15 +1218,18 @@ class ConfigController: config_entries = await self.get_provider_config_entries( provider_domain=provider_domain, instance_id=instance_id, values=values ) - config: ProviderConfig = ProviderConfig.parse( - config_entries, - { - "type": manifest.type.value, - "domain": manifest.domain, - "instance_id": instance_id, - "default_name": manifest.name, - "values": values, - }, + config = cast( + "ProviderConfig", + ProviderConfig.parse( + config_entries, + { + "type": manifest.type.value, + "domain": manifest.domain, + "instance_id": instance_id, + "default_name": manifest.name, + "values": values, + }, + ), ) # validate the new config config.validate() diff --git a/pyproject.toml b/pyproject.toml index ae96ca3f..24d106fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,7 +132,6 @@ enable_error_code = [ ] exclude = [ '^music_assistant/controllers/cache.py$', - '^music_assistant/controllers/config.py$', '^music_assistant/controllers/media/albums.py*$', '^music_assistant/controllers/media/artists.py*$', '^music_assistant/controllers/media/audiobooks.py*$',