Types fixes for the Albums Controller (#2632)
authorOzGav <gavnosp@hotmail.com>
Fri, 21 Nov 2025 17:35:08 +0000 (03:35 +1000)
committerGitHub <noreply@github.com>
Fri, 21 Nov 2025 17:35:08 +0000 (18:35 +0100)
* Typing fixes for the albums controller

* Revert instance check

* Fix type annotations and mypy compliance

* Remove db asserts

* Switch uniquelist to list

* typo

music_assistant/controllers/media/albums.py
music_assistant/providers/builtin/__init__.py
pyproject.toml

index b18e4c35e92be91ed43dd90564e5feab66563352..9e248398301dd013ae27fdd74796cd34c5a79e48 100644 (file)
@@ -8,7 +8,14 @@ from typing import TYPE_CHECKING, Any, cast
 
 from music_assistant_models.enums import AlbumType, MediaType, ProviderFeature
 from music_assistant_models.errors import InvalidDataError, MediaNotFoundError, MusicAssistantError
-from music_assistant_models.media_items import Album, Artist, ItemMapping, Track, UniqueList
+from music_assistant_models.media_items import (
+    Album,
+    Artist,
+    ItemMapping,
+    MediaItemImage,
+    Track,
+    UniqueList,
+)
 
 from music_assistant.constants import DB_TABLE_ALBUM_ARTISTS, DB_TABLE_ALBUM_TRACKS, DB_TABLE_ALBUMS
 from music_assistant.controllers.media.base import MediaControllerBase
@@ -20,9 +27,10 @@ from music_assistant.helpers.compare import (
     loose_compare_strings,
 )
 from music_assistant.helpers.json import serialize_to_json
+from music_assistant.models.music_provider import MusicProvider
 
 if TYPE_CHECKING:
-    from music_assistant.models.music_provider import MusicProvider
+    from music_assistant import MusicAssistant
 
 
 class AlbumsController(MediaControllerBase[Album]):
@@ -32,9 +40,9 @@ class AlbumsController(MediaControllerBase[Album]):
     media_type = MediaType.ALBUM
     item_cls = Album
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, mass: MusicAssistant) -> None:
         """Initialize class."""
-        super().__init__(*args, **kwargs)
+        super().__init__(mass)
         self.base_query = """
         SELECT
             albums.*,
@@ -78,7 +86,7 @@ class AlbumsController(MediaControllerBase[Album]):
             return album
 
         # append artist details to full album item (resolve ItemMappings)
-        album_artists = UniqueList()
+        album_artists: UniqueList[Artist | ItemMapping] = UniqueList()
         for artist in album.artists:
             if not isinstance(artist, ItemMapping):
                 album_artists.append(artist)
@@ -106,7 +114,7 @@ class AlbumsController(MediaControllerBase[Album]):
         album_types: list[AlbumType] | None = None,
     ) -> list[Album]:
         """Get in-database albums."""
-        extra_query_params: dict[str, Any] = extra_query_params or {}
+        extra_query_params = extra_query_params or {}
         extra_query_parts: list[str] = [extra_query] if extra_query else []
         extra_join_parts: list[str] = []
         artist_table_joined = False
@@ -223,7 +231,7 @@ class AlbumsController(MediaControllerBase[Album]):
         item_id: str,
         provider_instance_id_or_domain: str,
         in_library_only: bool = False,
-    ) -> UniqueList[Track]:
+    ) -> list[Track]:
         """Return album tracks for the given provider album id."""
         # always check if we have a library item for this album
         library_album = await self.get_library_item_by_prov_id(
@@ -231,8 +239,9 @@ class AlbumsController(MediaControllerBase[Album]):
         )
         if not library_album:
             return await self._get_provider_album_tracks(item_id, provider_instance_id_or_domain)
+
         db_items = await self.get_library_album_tracks(library_album.item_id)
-        result: UniqueList[Track] = UniqueList(db_items)
+        result: list[Track] = list(db_items)
         if in_library_only:
             # return in-library items only
             return sorted(db_items, key=lambda x: (x.disc_number, x.track_number))
@@ -243,7 +252,7 @@ class AlbumsController(MediaControllerBase[Album]):
         unique_ids: set[str] = {f"{x.disc_number}.{x.track_number}" for x in db_items}
         unique_ids.update({f"{x.name.lower()}.{x.version.lower()}" for x in db_items})
         for db_item in db_items:
-            unique_ids.add(x.item_id for x in db_item.provider_mappings)
+            unique_ids.update(x.item_id for x in db_item.provider_mappings)
         for provider_mapping in library_album.provider_mappings:
             provider_tracks = await self._get_provider_album_tracks(
                 provider_mapping.item_id, provider_mapping.provider_instance
@@ -267,8 +276,8 @@ class AlbumsController(MediaControllerBase[Album]):
                     and db_track.track_number != provider_track.track_number
                 ):
                     await self._set_album_track(
-                        db_id=library_album.item_id,
-                        db_track_id=db_track.item_id,
+                        db_id=int(library_album.item_id),
+                        db_track_id=int(db_track.item_id),
                         track=provider_track,
                     )
                 if provider_track.item_id in unique_ids:
@@ -283,8 +292,8 @@ class AlbumsController(MediaControllerBase[Album]):
                 provider_track.album = library_album
                 # always prefer album image
                 album_images = [library_album.image] if library_album.image else []
-                track_images = provider_track.metadata.images or []
-                provider_track.metadata.images = album_images + track_images
+                track_images: list[MediaItemImage] = provider_track.metadata.images or []
+                provider_track.metadata.images = UniqueList(album_images + track_images)
                 result.append(provider_track)
         # NOTE: we need to return the results sorted on disc/track here
         # to ensure the correct order at playback
@@ -301,7 +310,7 @@ class AlbumsController(MediaControllerBase[Album]):
         result: UniqueList[Album] = UniqueList()
         for provider_id in self.mass.music.get_unique_providers():
             provider = self.mass.get_provider(provider_id)
-            if not provider:
+            if not provider or not isinstance(provider, MusicProvider):
                 continue
             if not provider.library_supported(MediaType.ALBUM):
                 continue
@@ -334,10 +343,10 @@ class AlbumsController(MediaControllerBase[Album]):
         album = self.album_from_item_mapping(item)
         return await self.add_item_to_library(album)
 
-    async def _add_library_item(self, item: Album) -> int:
+    async def _add_library_item(self, item: Album, overwrite_existing: bool = False) -> int:
         """Add a new record to the database."""
-        if not isinstance(item, Album):
-            msg = "Not a valid Album object (ItemMapping can not be added to db)"
+        if not isinstance(item, Album):  # TODO: Remove this once the codebase is fully typed
+            msg = "Not a valid Album object (ItemMapping can not be added to db)"  # type: ignore[unreachable]
             raise InvalidDataError(msg)
         db_id = await self.mass.music.database.insert(
             self.db_table,
@@ -351,7 +360,7 @@ class AlbumsController(MediaControllerBase[Album]):
                 "metadata": serialize_to_json(item.metadata),
                 "external_ids": serialize_to_json(item.external_ids),
                 "search_name": create_safe_string(item.name, True, True),
-                "search_sort_name": create_safe_string(item.sort_name, True, True),
+                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
             },
         )
         # update/set provider_mappings table
@@ -389,7 +398,7 @@ class AlbumsController(MediaControllerBase[Album]):
                     update.external_ids if overwrite else cur_item.external_ids
                 ),
                 "search_name": create_safe_string(name, True, True),
-                "search_sort_name": create_safe_string(sort_name, True, True),
+                "search_sort_name": create_safe_string(sort_name or "", True, True),
             },
         )
         # update/set provider_mappings table
@@ -417,7 +426,7 @@ class AlbumsController(MediaControllerBase[Album]):
         self,
         item_id: str,
         provider_instance_id_or_domain: str,
-    ):
+    ) -> list[Track]:
         """Get the list of base tracks from the controller used to calculate the dynamic radio."""
         return await self.tracks(item_id, provider_instance_id_or_domain, in_library_only=False)
 
@@ -443,7 +452,7 @@ class AlbumsController(MediaControllerBase[Album]):
         self, db_id: int, artist: Artist | ItemMapping, overwrite: bool = False
     ) -> ItemMapping:
         """Store Album Artist info."""
-        db_artist: Artist | ItemMapping = None
+        db_artist: Artist | ItemMapping | None = None
         if artist.provider == "library":
             db_artist = artist
         elif existing := await self.mass.music.artists.get_library_item_by_prov_id(
@@ -452,9 +461,14 @@ class AlbumsController(MediaControllerBase[Album]):
             db_artist = existing
 
         if not db_artist or overwrite:
-            db_artist = await self.mass.music.artists.add_item_to_library(
-                artist, overwrite_existing=overwrite
-            )
+            # Type narrowing: if artist is an ItemMapping, convert it or handle it
+            if isinstance(artist, ItemMapping):
+                # ItemMapping can't be added directly, use the existing or skip
+                db_artist = artist
+            else:
+                db_artist = await self.mass.music.artists.add_item_to_library(
+                    artist, overwrite_existing=overwrite
+                )
         # write (or update) record in album_artists table
         await self.mass.music.database.insert_or_replace(
             DB_TABLE_ALBUM_ARTISTS,
@@ -489,7 +503,7 @@ class AlbumsController(MediaControllerBase[Album]):
             return  # guard
         artist_name = db_album.artists[0].name
 
-        async def find_prov_match(provider: MusicProvider):
+        async def find_prov_match(provider: MusicProvider) -> bool:
             self.logger.debug(
                 "Trying to match album %s on provider %s", db_album.name, provider.name
             )
index 4250ccae555b3afbe00d75d7156f082c0a3dd484..6e30610399727defbc4c12f825d371177b03c95e 100644 (file)
@@ -561,7 +561,7 @@ class BuiltinProvider(MusicProvider):
             result.append(item)
         return result
 
-    async def _get_builtin_playlist_random_album(self) -> UniqueList[Track]:
+    async def _get_builtin_playlist_random_album(self) -> list[Track]:
         for in_library_only in (True, False):
             for min_tracks_required in (10, 5, 1):
                 for random_album in await self.mass.music.albums.library_items(
@@ -575,7 +575,7 @@ class BuiltinProvider(MusicProvider):
                     for idx, track in enumerate(tracks, 1):
                         track.position = idx
                     return tracks
-        return UniqueList()
+        return []
 
     async def _get_builtin_playlist_random_artist(self) -> list[Track]:
         for in_library_only in (True, False):
index 341064e26ad0b491087a823da2098e7b16768fa1..264417d777e9d474e8556cbd3c1f07f8847784cf 100644 (file)
@@ -130,7 +130,6 @@ enable_error_code = [
   "truthy-iterable",
 ]
 exclude = [
-  '^music_assistant/controllers/media/albums.py*$',
   '^music_assistant/controllers/media/base.py*$',
   '^music_assistant/controllers/media/playlists.py*$',
   '^music_assistant/controllers/media/podcasts.py*$',