Typing fixes for the config controller (#2570)
authorOzGav <gavnosp@hotmail.com>
Fri, 7 Nov 2025 16:18:39 +0000 (02:18 +1000)
committerGitHub <noreply@github.com>
Fri, 7 Nov 2025 16:18:39 +0000 (17:18 +0100)
music_assistant/controllers/config.py
pyproject.toml

index f8dae8ff378d77eae3181a0e63287d0d0d5da732..b42b8ff9a352a9f9a79d65e1e2661388de89f783 100644 (file)
@@ -5,7 +5,7 @@ from __future__ import annotations
 import base64
 import logging
 import os
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 from uuid import uuid4
 
 import aiofiles
@@ -69,6 +69,7 @@ from music_assistant.helpers.api import api_command
 from music_assistant.helpers.json import JSON_DECODE_EXCEPTIONS, async_json_dumps, async_json_loads
 from music_assistant.helpers.util import load_provider_module
 from music_assistant.models import ProviderModuleType
+from music_assistant.models.music_provider import MusicProvider
 
 if TYPE_CHECKING:
     import asyncio
@@ -117,7 +118,7 @@ class ConfigController:
     @property
     def onboard_done(self) -> bool:
         """Return True if onboarding is done."""
-        return self.get(CONF_ONBOARD_DONE, False)
+        return bool(self.get(CONF_ONBOARD_DONE, False))
 
     async def close(self) -> None:
         """Handle logic on server stop."""
@@ -196,12 +197,12 @@ class ConfigController:
         include_values: bool = False,
     ) -> list[ProviderConfig]:
         """Return all known provider configurations, optionally filtered by ProviderType."""
-        raw_values: dict[str, dict] = self.get(CONF_PROVIDERS, {})
+        raw_values = self.get(CONF_PROVIDERS, {})
         prov_entries = {x.domain for x in self.mass.get_provider_manifests()}
         return [
             await self.get_provider_config(prov_conf["instance_id"])
             if include_values
-            else ProviderConfig.parse([], prov_conf)
+            else cast("ProviderConfig", ProviderConfig.parse([], prov_conf))
             for prov_conf in raw_values.values()
             if (provider_type is None or prov_conf["type"] == provider_type)
             and (provider_domain is None or prov_conf["domain"] == provider_domain)
@@ -224,7 +225,7 @@ class ConfigController:
             else:
                 msg = f"Unknown provider domain: {raw_conf['domain']}"
                 raise KeyError(msg)
-            return ProviderConfig.parse(config_entries, raw_conf)
+            return cast("ProviderConfig", ProviderConfig.parse(config_entries, raw_conf))
         msg = f"No config found for provider id {instance_id}"
         raise KeyError(msg)
 
@@ -284,9 +285,7 @@ class ConfigController:
             supported_features = provider.supported_features
         else:
             provider = None
-            supported_features: set[ProviderFeature] = getattr(
-                prov_mod, "SUPPORTED_FEATURES", set()
-            )
+            supported_features = getattr(prov_mod, "SUPPORTED_FEATURES", set())
         extra_entries: list[ConfigEntry] = []
         if manifest.type == ProviderType.MUSIC:
             # library sync settings
@@ -294,13 +293,21 @@ class ConfigController:
                 extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ARTISTS)
             if ProviderFeature.LIBRARY_ALBUMS in supported_features:
                 extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ALBUMS)
-                if provider and provider.is_streaming_provider:
+                if (
+                    provider
+                    and isinstance(provider, MusicProvider)
+                    and provider.is_streaming_provider
+                ):
                     extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ALBUM_TRACKS)
             if ProviderFeature.LIBRARY_TRACKS in supported_features:
                 extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_TRACKS)
             if ProviderFeature.LIBRARY_PLAYLISTS in supported_features:
                 extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_PLAYLISTS)
-                if provider and provider.is_streaming_provider:
+                if (
+                    provider
+                    and isinstance(provider, MusicProvider)
+                    and provider.is_streaming_provider
+                ):
                     extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_PLAYLIST_TRACKS)
             if ProviderFeature.LIBRARY_AUDIOBOOKS in supported_features:
                 extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_AUDIOBOOKS)
@@ -413,7 +420,7 @@ class ConfigController:
         return [
             await self.get_player_config(raw_conf["player_id"])
             if include_values
-            else PlayerConfig.parse([], raw_conf)
+            else cast("PlayerConfig", PlayerConfig.parse([], raw_conf))
             for raw_conf in list(self.get(CONF_PLAYERS, {}).values())
             # filter out unavailable providers (only if we requested the full info)
             if (
@@ -447,7 +454,7 @@ class ConfigController:
                 raw_conf["available"] = False
                 raw_conf["name"] = raw_conf.get("name")
                 raw_conf["default_name"] = raw_conf.get("default_name") or raw_conf["player_id"]
-            return PlayerConfig.parse(conf_entries, raw_conf)
+            return cast("PlayerConfig", PlayerConfig.parse(conf_entries, raw_conf))
         msg = f"No config found for player id {player_id}"
         raise KeyError(msg)
 
@@ -480,7 +487,7 @@ class ConfigController:
         player_id: str,
         key: str,
         unpack_splitted_values: bool = False,
-    ) -> ConfigValueType:
+    ) -> ConfigValueType | tuple[str, ...] | list[tuple[str, ...]]:
         """Return single configentry value for a player."""
         conf = await self.get_player_config(player_id)
         if unpack_splitted_values:
@@ -499,9 +506,12 @@ class ConfigController:
 
         Note that this only returns the stored value without any validation or default.
         """
-        return self.get(
-            f"{CONF_PLAYERS}/{player_id}/values/{key}",
-            self.get(f"{CONF_PLAYERS}/{player_id}/{key}", default),
+        return cast(
+            "ConfigValueType",
+            self.get(
+                f"{CONF_PLAYERS}/{player_id}/values/{key}",
+                self.get(f"{CONF_PLAYERS}/{player_id}/{key}", default),
+            ),
         )
 
     def get_base_player_config(self, player_id: str, provider: str) -> PlayerConfig:
@@ -516,7 +526,7 @@ class ConfigController:
                 "player_id": player_id,
                 "provider": provider,
             }
-        return PlayerConfig.parse([], raw_conf)
+        return cast("PlayerConfig", PlayerConfig.parse([], raw_conf))
 
     @api_command("config/players/save")
     async def save_player_config(
@@ -527,7 +537,7 @@ class ConfigController:
         changed_keys = config.update(values)
         if not changed_keys:
             # no changes
-            return None
+            return config
         # validate/handle the update in the player manager
         await self.mass.players.on_player_config_change(config, changed_keys)
         # actually store changes (if the above did not raise)
@@ -602,9 +612,15 @@ class ConfigController:
                 dsp_config.filters.append(
                     ToneControlFilter(
                         enabled=True,
-                        bass_level=deprecated_eq_bass,
-                        mid_level=deprecated_eq_mid,
-                        treble_level=deprecated_eq_treble,
+                        bass_level=float(deprecated_eq_bass)
+                        if isinstance(deprecated_eq_bass, (int, float, str))
+                        else 0.0,
+                        mid_level=float(deprecated_eq_mid)
+                        if isinstance(deprecated_eq_mid, (int, float, str))
+                        else 0.0,
+                        treble_level=float(deprecated_eq_treble)
+                        if isinstance(deprecated_eq_treble, (int, float, str))
+                        else 0.0,
                     )
                 )
 
@@ -748,17 +764,20 @@ class ConfigController:
             instance_id = f"{manifest.domain}--{shortuuid.random(8)}"
         else:
             instance_id = manifest.domain
-        default_config: ProviderConfig = ProviderConfig.parse(
-            config_entries,
-            {
-                "type": manifest.type.value,
-                "domain": manifest.domain,
-                "instance_id": instance_id,
-                "name": manifest.name,
-                # note: this will only work for providers that do
-                # not have any required config entries or provide defaults
-                "values": {},
-            },
+        default_config = cast(
+            "ProviderConfig",
+            ProviderConfig.parse(
+                config_entries,
+                {
+                    "type": manifest.type.value,
+                    "domain": manifest.domain,
+                    "instance_id": instance_id,
+                    "name": manifest.name,
+                    # note: this will only work for providers that do
+                    # not have any required config entries or provide defaults
+                    "values": {},
+                },
+            ),
         )
         default_config.validate()
         conf_key = f"{CONF_PROVIDERS}/{default_config.instance_id}"
@@ -770,9 +789,12 @@ class ConfigController:
         return [
             await self.get_core_config(core_controller)
             if include_values
-            else CoreConfig.parse(
-                [],
-                self.get(f"{CONF_CORE}/{core_controller}", {"domain": core_controller}),
+            else cast(
+                "CoreConfig",
+                CoreConfig.parse(
+                    [],
+                    self.get(f"{CONF_CORE}/{core_controller}", {"domain": core_controller}),
+                ),
             )
             for core_controller in CONFIGURABLE_CORE_CONTROLLERS
         ]
@@ -782,7 +804,7 @@ class ConfigController:
         """Return configuration for a single core controller."""
         raw_conf = self.get(f"{CONF_CORE}/{domain}", {"domain": domain})
         config_entries = await self.get_core_config_entries(domain)
-        return CoreConfig.parse(config_entries, raw_conf)
+        return cast("CoreConfig", CoreConfig.parse(config_entries, raw_conf))
 
     @api_command("config/core/get_value")
     async def get_core_config_value(self, domain: str, key: str) -> ConfigValueType:
@@ -848,9 +870,12 @@ class ConfigController:
 
         Note that this only returns the stored value without any validation or default.
         """
-        return self.get(
-            f"{CONF_CORE}/{core_module}/values/{key}",
-            self.get(f"{CONF_CORE}/{core_module}/{key}", default),
+        return cast(
+            "ConfigValueType",
+            self.get(
+                f"{CONF_CORE}/{core_module}/values/{key}",
+                self.get(f"{CONF_CORE}/{core_module}/{key}", default),
+            ),
         )
 
     def get_raw_provider_config_value(
@@ -861,9 +886,12 @@ class ConfigController:
 
         Note that this only returns the stored value without any validation or default.
         """
-        return self.get(
-            f"{CONF_PROVIDERS}/{provider_instance}/values/{key}",
-            self.get(f"{CONF_PROVIDERS}/{provider_instance}/{key}", default),
+        return cast(
+            "ConfigValueType",
+            self.get(
+                f"{CONF_PROVIDERS}/{provider_instance}/values/{key}",
+                self.get(f"{CONF_PROVIDERS}/{provider_instance}/{key}", default),
+            ),
         )
 
     def set_raw_provider_config_value(
@@ -883,6 +911,9 @@ class ConfigController:
             msg = f"Invalid provider_instance: {provider_instance}"
             raise KeyError(msg)
         if encrypted:
+            if not isinstance(value, str):
+                msg = f"Cannot encrypt non-string value for key {key}"
+                raise ValueError(msg)
             value = self.encrypt_string(value)
         if key in BASE_KEYS:
             self.set(f"{CONF_PROVIDERS}/{provider_instance}/{key}", value)
@@ -934,6 +965,7 @@ class ConfigController:
         """Encrypt a (password)string with Fernet."""
         if str_value.startswith(ENCRYPT_SUFFIX):
             return str_value
+        assert self._fernet is not None
         return ENCRYPT_SUFFIX + self._fernet.encrypt(str_value.encode()).decode()
 
     def decrypt_string(self, encrypted_str: str) -> str:
@@ -942,6 +974,7 @@ class ConfigController:
             return encrypted_str
         if not encrypted_str.startswith(ENCRYPT_SUFFIX):
             return encrypted_str
+        assert self._fernet is not None
         try:
             return self._fernet.decrypt(encrypted_str.replace(ENCRYPT_SUFFIX, "").encode()).decode()
         except InvalidToken as err:
@@ -972,7 +1005,6 @@ class ConfigController:
         instance_id: str
         provider_config: dict[str, Any]
         player_config: dict[str, Any]
-        values: dict[str, ConfigValueType]
 
         # Older versions of MA can create corrupt entries with no domain if retrying
         # logic runs after a provider has been removed. Remove those corrupt entries.
@@ -1020,7 +1052,12 @@ class ConfigController:
         # migrate player_group entries
         ugp_found = False
         for player_config in self._data.get(CONF_PLAYERS, {}).values():
-            if not player_config.get("provider").startswith("player_group"):
+            provider = player_config.get("provider")
+            if (
+                not provider
+                or not isinstance(provider, str)
+                or not provider.startswith("player_group")
+            ):
                 continue
             if not (values := player_config.get("values")):
                 continue
@@ -1144,7 +1181,7 @@ class ConfigController:
         self,
         provider_domain: str,
         values: dict[str, ConfigValueType],
-    ) -> list[ConfigEntry] | ProviderConfig:
+    ) -> ProviderConfig:
         """
         Add new Provider (instance).
 
@@ -1181,15 +1218,18 @@ class ConfigController:
         config_entries = await self.get_provider_config_entries(
             provider_domain=provider_domain, instance_id=instance_id, values=values
         )
-        config: ProviderConfig = ProviderConfig.parse(
-            config_entries,
-            {
-                "type": manifest.type.value,
-                "domain": manifest.domain,
-                "instance_id": instance_id,
-                "default_name": manifest.name,
-                "values": values,
-            },
+        config = cast(
+            "ProviderConfig",
+            ProviderConfig.parse(
+                config_entries,
+                {
+                    "type": manifest.type.value,
+                    "domain": manifest.domain,
+                    "instance_id": instance_id,
+                    "default_name": manifest.name,
+                    "values": values,
+                },
+            ),
         )
         # validate the new config
         config.validate()
index ae96ca3ffd8ea33e6a0fe9eec24a730aa6bcb173..24d106fa3a3628621e155db0a2104f2981e76d04 100644 (file)
@@ -132,7 +132,6 @@ enable_error_code = [
 ]
 exclude = [
   '^music_assistant/controllers/cache.py$',
-  '^music_assistant/controllers/config.py$',
   '^music_assistant/controllers/media/albums.py*$',
   '^music_assistant/controllers/media/artists.py*$',
   '^music_assistant/controllers/media/audiobooks.py*$',