From 3ea8a913b963d66452f1a94a209747b9631e59f4 Mon Sep 17 00:00:00 2001 From: OzGav Date: Sat, 22 Nov 2025 03:35:08 +1000 Subject: [PATCH] Types fixes for the Albums Controller (#2632) * Typing fixes for the albums controller * Revert instance check * Fix type annotations and mypy compliance * Remove db asserts * Switch uniquelist to list * typo --- music_assistant/controllers/media/albums.py | 64 +++++++++++-------- music_assistant/providers/builtin/__init__.py | 4 +- pyproject.toml | 1 - 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/music_assistant/controllers/media/albums.py b/music_assistant/controllers/media/albums.py index b18e4c35..9e248398 100644 --- a/music_assistant/controllers/media/albums.py +++ b/music_assistant/controllers/media/albums.py @@ -8,7 +8,14 @@ from typing import TYPE_CHECKING, Any, cast from music_assistant_models.enums import AlbumType, MediaType, ProviderFeature from music_assistant_models.errors import InvalidDataError, MediaNotFoundError, MusicAssistantError -from music_assistant_models.media_items import Album, Artist, ItemMapping, Track, UniqueList +from music_assistant_models.media_items import ( + Album, + Artist, + ItemMapping, + MediaItemImage, + Track, + UniqueList, +) from music_assistant.constants import DB_TABLE_ALBUM_ARTISTS, DB_TABLE_ALBUM_TRACKS, DB_TABLE_ALBUMS from music_assistant.controllers.media.base import MediaControllerBase @@ -20,9 +27,10 @@ from music_assistant.helpers.compare import ( loose_compare_strings, ) from music_assistant.helpers.json import serialize_to_json +from music_assistant.models.music_provider import MusicProvider if TYPE_CHECKING: - from music_assistant.models.music_provider import MusicProvider + from music_assistant import MusicAssistant class AlbumsController(MediaControllerBase[Album]): @@ -32,9 +40,9 @@ class AlbumsController(MediaControllerBase[Album]): media_type = MediaType.ALBUM item_cls = Album - def __init__(self, *args, **kwargs) -> None: + def __init__(self, mass: MusicAssistant) -> None: """Initialize class.""" - super().__init__(*args, **kwargs) + super().__init__(mass) self.base_query = """ SELECT albums.*, @@ -78,7 +86,7 @@ class AlbumsController(MediaControllerBase[Album]): return album # append artist details to full album item (resolve ItemMappings) - album_artists = UniqueList() + album_artists: UniqueList[Artist | ItemMapping] = UniqueList() for artist in album.artists: if not isinstance(artist, ItemMapping): album_artists.append(artist) @@ -106,7 +114,7 @@ class AlbumsController(MediaControllerBase[Album]): album_types: list[AlbumType] | None = None, ) -> list[Album]: """Get in-database albums.""" - extra_query_params: dict[str, Any] = extra_query_params or {} + extra_query_params = extra_query_params or {} extra_query_parts: list[str] = [extra_query] if extra_query else [] extra_join_parts: list[str] = [] artist_table_joined = False @@ -223,7 +231,7 @@ class AlbumsController(MediaControllerBase[Album]): item_id: str, provider_instance_id_or_domain: str, in_library_only: bool = False, - ) -> UniqueList[Track]: + ) -> list[Track]: """Return album tracks for the given provider album id.""" # always check if we have a library item for this album library_album = await self.get_library_item_by_prov_id( @@ -231,8 +239,9 @@ class AlbumsController(MediaControllerBase[Album]): ) if not library_album: return await self._get_provider_album_tracks(item_id, provider_instance_id_or_domain) + db_items = await self.get_library_album_tracks(library_album.item_id) - result: UniqueList[Track] = UniqueList(db_items) + result: list[Track] = list(db_items) if in_library_only: # return in-library items only return sorted(db_items, key=lambda x: (x.disc_number, x.track_number)) @@ -243,7 +252,7 @@ class AlbumsController(MediaControllerBase[Album]): unique_ids: set[str] = {f"{x.disc_number}.{x.track_number}" for x in db_items} unique_ids.update({f"{x.name.lower()}.{x.version.lower()}" for x in db_items}) for db_item in db_items: - unique_ids.add(x.item_id for x in db_item.provider_mappings) + unique_ids.update(x.item_id for x in db_item.provider_mappings) for provider_mapping in library_album.provider_mappings: provider_tracks = await self._get_provider_album_tracks( provider_mapping.item_id, provider_mapping.provider_instance @@ -267,8 +276,8 @@ class AlbumsController(MediaControllerBase[Album]): and db_track.track_number != provider_track.track_number ): await self._set_album_track( - db_id=library_album.item_id, - db_track_id=db_track.item_id, + db_id=int(library_album.item_id), + db_track_id=int(db_track.item_id), track=provider_track, ) if provider_track.item_id in unique_ids: @@ -283,8 +292,8 @@ class AlbumsController(MediaControllerBase[Album]): provider_track.album = library_album # always prefer album image album_images = [library_album.image] if library_album.image else [] - track_images = provider_track.metadata.images or [] - provider_track.metadata.images = album_images + track_images + track_images: list[MediaItemImage] = provider_track.metadata.images or [] + provider_track.metadata.images = UniqueList(album_images + track_images) result.append(provider_track) # NOTE: we need to return the results sorted on disc/track here # to ensure the correct order at playback @@ -301,7 +310,7 @@ class AlbumsController(MediaControllerBase[Album]): result: UniqueList[Album] = UniqueList() for provider_id in self.mass.music.get_unique_providers(): provider = self.mass.get_provider(provider_id) - if not provider: + if not provider or not isinstance(provider, MusicProvider): continue if not provider.library_supported(MediaType.ALBUM): continue @@ -334,10 +343,10 @@ class AlbumsController(MediaControllerBase[Album]): album = self.album_from_item_mapping(item) return await self.add_item_to_library(album) - async def _add_library_item(self, item: Album) -> int: + async def _add_library_item(self, item: Album, overwrite_existing: bool = False) -> int: """Add a new record to the database.""" - if not isinstance(item, Album): - msg = "Not a valid Album object (ItemMapping can not be added to db)" + if not isinstance(item, Album): # TODO: Remove this once the codebase is fully typed + msg = "Not a valid Album object (ItemMapping can not be added to db)" # type: ignore[unreachable] raise InvalidDataError(msg) db_id = await self.mass.music.database.insert( self.db_table, @@ -351,7 +360,7 @@ class AlbumsController(MediaControllerBase[Album]): "metadata": serialize_to_json(item.metadata), "external_ids": serialize_to_json(item.external_ids), "search_name": create_safe_string(item.name, True, True), - "search_sort_name": create_safe_string(item.sort_name, True, True), + "search_sort_name": create_safe_string(item.sort_name or "", True, True), }, ) # update/set provider_mappings table @@ -389,7 +398,7 @@ class AlbumsController(MediaControllerBase[Album]): update.external_ids if overwrite else cur_item.external_ids ), "search_name": create_safe_string(name, True, True), - "search_sort_name": create_safe_string(sort_name, True, True), + "search_sort_name": create_safe_string(sort_name or "", True, True), }, ) # update/set provider_mappings table @@ -417,7 +426,7 @@ class AlbumsController(MediaControllerBase[Album]): self, item_id: str, provider_instance_id_or_domain: str, - ): + ) -> list[Track]: """Get the list of base tracks from the controller used to calculate the dynamic radio.""" return await self.tracks(item_id, provider_instance_id_or_domain, in_library_only=False) @@ -443,7 +452,7 @@ class AlbumsController(MediaControllerBase[Album]): self, db_id: int, artist: Artist | ItemMapping, overwrite: bool = False ) -> ItemMapping: """Store Album Artist info.""" - db_artist: Artist | ItemMapping = None + db_artist: Artist | ItemMapping | None = None if artist.provider == "library": db_artist = artist elif existing := await self.mass.music.artists.get_library_item_by_prov_id( @@ -452,9 +461,14 @@ class AlbumsController(MediaControllerBase[Album]): db_artist = existing if not db_artist or overwrite: - db_artist = await self.mass.music.artists.add_item_to_library( - artist, overwrite_existing=overwrite - ) + # Type narrowing: if artist is an ItemMapping, convert it or handle it + if isinstance(artist, ItemMapping): + # ItemMapping can't be added directly, use the existing or skip + db_artist = artist + else: + db_artist = await self.mass.music.artists.add_item_to_library( + artist, overwrite_existing=overwrite + ) # write (or update) record in album_artists table await self.mass.music.database.insert_or_replace( DB_TABLE_ALBUM_ARTISTS, @@ -489,7 +503,7 @@ class AlbumsController(MediaControllerBase[Album]): return # guard artist_name = db_album.artists[0].name - async def find_prov_match(provider: MusicProvider): + async def find_prov_match(provider: MusicProvider) -> bool: self.logger.debug( "Trying to match album %s on provider %s", db_album.name, provider.name ) diff --git a/music_assistant/providers/builtin/__init__.py b/music_assistant/providers/builtin/__init__.py index 4250ccae..6e306103 100644 --- a/music_assistant/providers/builtin/__init__.py +++ b/music_assistant/providers/builtin/__init__.py @@ -561,7 +561,7 @@ class BuiltinProvider(MusicProvider): result.append(item) return result - async def _get_builtin_playlist_random_album(self) -> UniqueList[Track]: + async def _get_builtin_playlist_random_album(self) -> list[Track]: for in_library_only in (True, False): for min_tracks_required in (10, 5, 1): for random_album in await self.mass.music.albums.library_items( @@ -575,7 +575,7 @@ class BuiltinProvider(MusicProvider): for idx, track in enumerate(tracks, 1): track.position = idx return tracks - return UniqueList() + return [] async def _get_builtin_playlist_random_artist(self) -> list[Track]: for in_library_only in (True, False): diff --git a/pyproject.toml b/pyproject.toml index 341064e2..264417d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,7 +130,6 @@ enable_error_code = [ "truthy-iterable", ] exclude = [ - '^music_assistant/controllers/media/albums.py*$', '^music_assistant/controllers/media/base.py*$', '^music_assistant/controllers/media/playlists.py*$', '^music_assistant/controllers/media/podcasts.py*$', -- 2.34.1