Typing fixes for the playlists controller (#2628)
authorOzGav <gavnosp@hotmail.com>
Tue, 25 Nov 2025 10:46:10 +0000 (20:46 +1000)
committerGitHub <noreply@github.com>
Tue, 25 Nov 2025 10:46:10 +0000 (10:46 +0000)
* Typing fixes for the playlists controller

* Re-add comment

* Fix typos

* Restore comments

* Remove db asserts

* Resolve conflicts

* adjust comment

---------

Co-authored-by: Marvin Schenkel <marvinschenkel@gmail.com>
music_assistant/controllers/media/playlists.py
pyproject.toml

index d08fa5044e4a3ac0fddb4e78ab6fd2080acac45b..541e5c436ecce6aef8cd3fa91ed1e289411a3049 100644 (file)
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from collections.abc import AsyncGenerator
-from typing import cast
+from typing import TYPE_CHECKING, cast
 
 from music_assistant_models.enums import MediaType, ProviderFeature
 from music_assistant_models.errors import (
@@ -23,6 +23,9 @@ from music_assistant.models.music_provider import MusicProvider
 
 from .base import MediaControllerBase
 
+if TYPE_CHECKING:
+    from music_assistant import MusicAssistant
+
 
 class PlaylistController(MediaControllerBase[Playlist]):
     """Controller managing MediaItems of type Playlist."""
@@ -31,9 +34,9 @@ class PlaylistController(MediaControllerBase[Playlist]):
     media_type = MediaType.PLAYLIST
     item_cls = Playlist
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, mass: MusicAssistant) -> None:
         """Initialize class."""
-        super().__init__(*args, **kwargs)
+        super().__init__(mass)
         # register (extra) api handlers
         api_base = self.api_base
         self.mass.register_api_command(f"music/{api_base}/create_playlist", self.create_playlist)
@@ -112,6 +115,8 @@ class PlaylistController(MediaControllerBase[Playlist]):
                 raise ProviderUnavailableError
         else:
             provider = self.mass.get_provider("builtin")
+        # grab all existing track ids in the playlist so we can check for duplicates
+        provider = cast("MusicProvider", provider)
 
         if "/" in name or "\\" in name or ".." in name:
             msg = f"{name} is not a valid Playlist name"
@@ -143,15 +148,23 @@ class PlaylistController(MediaControllerBase[Playlist]):
         playlist_prov_map = next(iter(playlist.provider_mappings))
         playlist_prov = self.mass.get_provider(playlist_prov_map.provider_instance)
         if not playlist_prov or not playlist_prov.available:
-            msg = f"Provider {playlist_prov_map.provider_instance} is not available"
-            raise ProviderUnavailableError(msg)
-        cur_playlist_track_ids = set()
-        cur_playlist_track_uris = set()
+            raise ProviderUnavailableError(
+                f"Provider {playlist_prov_map.provider_instance} is not available"
+            )
+        playlist_prov = cast("MusicProvider", playlist_prov)
+
+        # sets to track existing tracks
+        cur_playlist_track_ids: set[str] = set()
+        cur_playlist_track_uris: set[str] = set()
+
+        # collect current track IDs and URIs
         async for item in self.tracks(playlist.item_id, playlist.provider):
-            cur_playlist_track_uris.add(item.item_id)
-            cur_playlist_track_uris.add(item.uri)
+            if item.item_id:
+                cur_playlist_track_ids.add(item.item_id)
+            if item.uri:
+                cur_playlist_track_uris.add(item.uri)
 
-        # unwrap all uri's to track uri's
+        # unwrap URIs to individual track URIs
         unwrapped_uris: list[str] = []
         for uri in uris:
             # URI could be a playlist or album uri, unwrap it
@@ -167,13 +180,16 @@ class PlaylistController(MediaControllerBase[Playlist]):
             media_type_str, item_id = rest.split("/", 1)
             media_type = MediaType(media_type_str)
             if media_type == MediaType.ALBUM:
-                for track in await self.mass.music.albums.tracks(
+                album_tracks = await self.mass.music.albums.tracks(
                     item_id, provider_instance_id_or_domain
-                ):
-                    unwrapped_uris.append(track.uri)
+                )
+                for track in album_tracks:
+                    if track.uri is not None:
+                        unwrapped_uris.append(track.uri)
             elif media_type == MediaType.PLAYLIST:
-                for track in await self.tracks(item_id, provider_instance_id_or_domain):
-                    unwrapped_uris.append(track.uri)
+                async for track in self.tracks(item_id, provider_instance_id_or_domain):
+                    if track.uri is not None:
+                        unwrapped_uris.append(track.uri)
             elif media_type == MediaType.TRACK:
                 unwrapped_uris.append(uri)
             else:
@@ -330,6 +346,12 @@ class PlaylistController(MediaControllerBase[Playlist]):
             raise InvalidDataError(msg)
         for prov_mapping in playlist.provider_mappings:
             provider = self.mass.get_provider(prov_mapping.provider_instance)
+            if not provider or not isinstance(provider, MusicProvider):
+                self.logger.warning(
+                    "Provider %s is not available or does not support playlist editing",
+                    prov_mapping.provider_domain,
+                )
+                continue
             if ProviderFeature.PLAYLIST_TRACKS_EDIT not in provider.supported_features:
                 self.logger.warning(
                     "Provider %s does not support editing playlists",
@@ -340,7 +362,7 @@ class PlaylistController(MediaControllerBase[Playlist]):
 
         await self.update_item_in_library(db_playlist_id, playlist)
 
-    async def _add_library_item(self, item: Playlist) -> int:
+    async def _add_library_item(self, item: Playlist, overwrite_existing: bool = False) -> int:
         """Add a new record to the database."""
         db_id = await self.mass.music.database.insert(
             self.db_table,
@@ -353,7 +375,7 @@ class PlaylistController(MediaControllerBase[Playlist]):
                 "metadata": serialize_to_json(item.metadata),
                 "external_ids": serialize_to_json(item.external_ids),
                 "search_name": create_safe_string(item.name, True, True),
-                "search_sort_name": create_safe_string(item.sort_name, True, True),
+                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
             },
         )
         # update/set provider_mappings table
@@ -362,7 +384,7 @@ class PlaylistController(MediaControllerBase[Playlist]):
         return db_id
 
     async def _update_library_item(
-        self, item_id: int, update: Playlist, overwrite: bool = False
+        self, item_id: str | int, update: Playlist, overwrite: bool = False
     ) -> None:
         """Update existing record in the database."""
         db_id = int(item_id)  # ensure integer
@@ -386,7 +408,7 @@ class PlaylistController(MediaControllerBase[Playlist]):
                     update.external_ids if overwrite else cur_item.external_ids
                 ),
                 "search_name": create_safe_string(name, True, True),
-                "search_sort_name": create_safe_string(sort_name, True, True),
+                "search_sort_name": create_safe_string(sort_name or "", True, True),
             },
         )
         # update/set provider_mappings table
@@ -398,7 +420,7 @@ class PlaylistController(MediaControllerBase[Playlist]):
         await self.set_provider_mappings(db_id, provider_mappings, overwrite)
         self.logger.debug("updated %s in database: (id %s)", update.name, db_id)
 
-    @guard_single_request
+    @guard_single_request  # type: ignore[type-var]  # TODO: fix typing in util.py
     async def _get_provider_playlist_tracks(
         self,
         item_id: str,
@@ -418,7 +440,7 @@ class PlaylistController(MediaControllerBase[Playlist]):
         self,
         item_id: str,
         provider_instance_id_or_domain: str,
-    ):
+    ) -> list[Track]:
         """Get the list of base tracks from the controller used to calculate the dynamic radio."""
         return [
             x
@@ -437,7 +459,7 @@ class PlaylistController(MediaControllerBase[Playlist]):
     def _refresh_playlist_tracks(self, playlist: Playlist) -> None:
         """Refresh playlist tracks by forcing a cache refresh."""
 
-        async def _refresh(playlist: Playlist):
+        async def _refresh(playlist: Playlist) -> None:
             # simply iterate all tracks with force_refresh=True to refresh the cache
             async for _ in self.tracks(playlist.item_id, playlist.provider, force_refresh=True):
                 pass
index a09cab8700d7428f66cd697b0ba164aec1ca8857..b36ebddb268836791d77e15f85223754c28ca727 100644 (file)
@@ -130,7 +130,6 @@ enable_error_code = [
   "truthy-iterable",
 ]
 exclude = [
-  '^music_assistant/controllers/media/playlists.py*$',
   '^music_assistant/controllers/media/tracks.py*$',
   '^music_assistant/controllers/music.py$',
   '^music_assistant/helpers/app_vars.py',