From: OzGav Date: Tue, 25 Nov 2025 10:46:10 +0000 (+1000) Subject: Typing fixes for the playlists controller (#2628) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=6c3fded49aafb5d3cb0a3dd65ec5ea49201ed5ad;p=music-assistant-server.git Typing fixes for the playlists controller (#2628) * Typing fixes for the playlists controller * Re-add comment * Fix typos * Restore comments * Remove db asserts * Resolve conflicts * adjust comment --------- Co-authored-by: Marvin Schenkel --- diff --git a/music_assistant/controllers/media/playlists.py b/music_assistant/controllers/media/playlists.py index d08fa504..541e5c43 100644 --- a/music_assistant/controllers/media/playlists.py +++ b/music_assistant/controllers/media/playlists.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import AsyncGenerator -from typing import cast +from typing import TYPE_CHECKING, cast from music_assistant_models.enums import MediaType, ProviderFeature from music_assistant_models.errors import ( @@ -23,6 +23,9 @@ from music_assistant.models.music_provider import MusicProvider from .base import MediaControllerBase +if TYPE_CHECKING: + from music_assistant import MusicAssistant + class PlaylistController(MediaControllerBase[Playlist]): """Controller managing MediaItems of type Playlist.""" @@ -31,9 +34,9 @@ class PlaylistController(MediaControllerBase[Playlist]): media_type = MediaType.PLAYLIST item_cls = Playlist - def __init__(self, *args, **kwargs) -> None: + def __init__(self, mass: MusicAssistant) -> None: """Initialize class.""" - super().__init__(*args, **kwargs) + super().__init__(mass) # register (extra) api handlers api_base = self.api_base self.mass.register_api_command(f"music/{api_base}/create_playlist", self.create_playlist) @@ -112,6 +115,8 @@ class PlaylistController(MediaControllerBase[Playlist]): raise ProviderUnavailableError else: provider = self.mass.get_provider("builtin") + # grab all existing track ids in the playlist so we can check for duplicates + provider = cast("MusicProvider", provider) if "/" in name or "\\" in name or ".." in name: msg = f"{name} is not a valid Playlist name" @@ -143,15 +148,23 @@ class PlaylistController(MediaControllerBase[Playlist]): playlist_prov_map = next(iter(playlist.provider_mappings)) playlist_prov = self.mass.get_provider(playlist_prov_map.provider_instance) if not playlist_prov or not playlist_prov.available: - msg = f"Provider {playlist_prov_map.provider_instance} is not available" - raise ProviderUnavailableError(msg) - cur_playlist_track_ids = set() - cur_playlist_track_uris = set() + raise ProviderUnavailableError( + f"Provider {playlist_prov_map.provider_instance} is not available" + ) + playlist_prov = cast("MusicProvider", playlist_prov) + + # sets to track existing tracks + cur_playlist_track_ids: set[str] = set() + cur_playlist_track_uris: set[str] = set() + + # collect current track IDs and URIs async for item in self.tracks(playlist.item_id, playlist.provider): - cur_playlist_track_uris.add(item.item_id) - cur_playlist_track_uris.add(item.uri) + if item.item_id: + cur_playlist_track_ids.add(item.item_id) + if item.uri: + cur_playlist_track_uris.add(item.uri) - # unwrap all uri's to track uri's + # unwrap URIs to individual track URIs unwrapped_uris: list[str] = [] for uri in uris: # URI could be a playlist or album uri, unwrap it @@ -167,13 +180,16 @@ class PlaylistController(MediaControllerBase[Playlist]): media_type_str, item_id = rest.split("/", 1) media_type = MediaType(media_type_str) if media_type == MediaType.ALBUM: - for track in await self.mass.music.albums.tracks( + album_tracks = await self.mass.music.albums.tracks( item_id, provider_instance_id_or_domain - ): - unwrapped_uris.append(track.uri) + ) + for track in album_tracks: + if track.uri is not None: + unwrapped_uris.append(track.uri) elif media_type == MediaType.PLAYLIST: - for track in await self.tracks(item_id, provider_instance_id_or_domain): - unwrapped_uris.append(track.uri) + async for track in self.tracks(item_id, provider_instance_id_or_domain): + if track.uri is not None: + unwrapped_uris.append(track.uri) elif media_type == MediaType.TRACK: unwrapped_uris.append(uri) else: @@ -330,6 +346,12 @@ class PlaylistController(MediaControllerBase[Playlist]): raise InvalidDataError(msg) for prov_mapping in playlist.provider_mappings: provider = self.mass.get_provider(prov_mapping.provider_instance) + if not provider or not isinstance(provider, MusicProvider): + self.logger.warning( + "Provider %s is not available or does not support playlist editing", + prov_mapping.provider_domain, + ) + continue if ProviderFeature.PLAYLIST_TRACKS_EDIT not in provider.supported_features: self.logger.warning( "Provider %s does not support editing playlists", @@ -340,7 +362,7 @@ class PlaylistController(MediaControllerBase[Playlist]): await self.update_item_in_library(db_playlist_id, playlist) - async def _add_library_item(self, item: Playlist) -> int: + async def _add_library_item(self, item: Playlist, overwrite_existing: bool = False) -> int: """Add a new record to the database.""" db_id = await self.mass.music.database.insert( self.db_table, @@ -353,7 +375,7 @@ class PlaylistController(MediaControllerBase[Playlist]): "metadata": serialize_to_json(item.metadata), "external_ids": serialize_to_json(item.external_ids), "search_name": create_safe_string(item.name, True, True), - "search_sort_name": create_safe_string(item.sort_name, True, True), + "search_sort_name": create_safe_string(item.sort_name or "", True, True), }, ) # update/set provider_mappings table @@ -362,7 +384,7 @@ class PlaylistController(MediaControllerBase[Playlist]): return db_id async def _update_library_item( - self, item_id: int, update: Playlist, overwrite: bool = False + self, item_id: str | int, update: Playlist, overwrite: bool = False ) -> None: """Update existing record in the database.""" db_id = int(item_id) # ensure integer @@ -386,7 +408,7 @@ class PlaylistController(MediaControllerBase[Playlist]): update.external_ids if overwrite else cur_item.external_ids ), "search_name": create_safe_string(name, True, True), - "search_sort_name": create_safe_string(sort_name, True, True), + "search_sort_name": create_safe_string(sort_name or "", True, True), }, ) # update/set provider_mappings table @@ -398,7 +420,7 @@ class PlaylistController(MediaControllerBase[Playlist]): await self.set_provider_mappings(db_id, provider_mappings, overwrite) self.logger.debug("updated %s in database: (id %s)", update.name, db_id) - @guard_single_request + @guard_single_request # type: ignore[type-var] # TODO: fix typing in util.py async def _get_provider_playlist_tracks( self, item_id: str, @@ -418,7 +440,7 @@ class PlaylistController(MediaControllerBase[Playlist]): self, item_id: str, provider_instance_id_or_domain: str, - ): + ) -> list[Track]: """Get the list of base tracks from the controller used to calculate the dynamic radio.""" return [ x @@ -437,7 +459,7 @@ class PlaylistController(MediaControllerBase[Playlist]): def _refresh_playlist_tracks(self, playlist: Playlist) -> None: """Refresh playlist tracks by forcing a cache refresh.""" - async def _refresh(playlist: Playlist): + async def _refresh(playlist: Playlist) -> None: # simply iterate all tracks with force_refresh=True to refresh the cache async for _ in self.tracks(playlist.item_id, playlist.provider, force_refresh=True): pass diff --git a/pyproject.toml b/pyproject.toml index a09cab87..b36ebddb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,7 +130,6 @@ enable_error_code = [ "truthy-iterable", ] exclude = [ - '^music_assistant/controllers/media/playlists.py*$', '^music_assistant/controllers/media/tracks.py*$', '^music_assistant/controllers/music.py$', '^music_assistant/helpers/app_vars.py',