Fix: apple music playlists (#1330)
authorMarvin Schenkel <marvinschenkel@gmail.com>
Wed, 5 Jun 2024 17:29:50 +0000 (19:29 +0200)
committerGitHub <noreply@github.com>
Wed, 5 Jun 2024 17:29:50 +0000 (19:29 +0200)
* Fix playlist bugs + add playlist track paging

* Fix playlist bugs + add playlist track paging

music_assistant/server/providers/apple_music/__init__.py

index 8f980ec3223b5f5658180b6abd34e9049815a0d6..ec269d97ff974e4f59276bd236b64d53d51ef231 100644 (file)
@@ -239,23 +239,21 @@ class AppleMusicProvider(MusicProvider):
         response = await self._get_data(endpoint, include="artists")
         return [self._parse_track(track) for track in response["data"] if track["id"]]
 
-    async def get_playlist_tracks(
-        self, prov_playlist_id, offset, limit
-    ) -> AsyncGenerator[Track, None]:
+    async def get_playlist_tracks(self, prov_playlist_id, offset, limit) -> list[Track]:
         """Get all playlist tracks for given playlist id."""
-        # TODO: Import paging
         if self._is_catalog_id(prov_playlist_id):
             endpoint = f"catalog/{self._storefront}/playlists/{prov_playlist_id}/tracks"
         else:
             endpoint = f"me/library/playlists/{prov_playlist_id}/tracks"
-        count = 1
         result = []
-        for track in await self._get_all_items(endpoint, include="artists,catalog"):
+        response = await self._get_data(
+            endpoint, include="artists,catalog", limit=limit, offset=offset
+        )
+        for index, track in enumerate(response["data"]):
             if track and track["id"]:
                 parsed_track = self._parse_track(track)
-                parsed_track.position = count
+                parsed_track.position = offset + index + 1
                 result.append(parsed_track)
-                count += 1
         return result
 
     async def get_artist_albums(self, prov_artist_id) -> list[Album]:
@@ -503,17 +501,18 @@ class AppleMusicProvider(MusicProvider):
             track.external_ids.add((ExternalID.ISRC, isrc))
         return track
 
-    def _parse_playlist(self, playlist_obj):
+    def _parse_playlist(self, playlist_obj) -> Playlist:
         """Parse Apple Music playlist object to generic layout."""
         attributes = playlist_obj["attributes"]
+        playlist_id = attributes["playParams"].get("globalId") or playlist_obj["id"]
         playlist = Playlist(
-            item_id=playlist_obj["id"],
+            item_id=playlist_id,
             provider=self.domain,
             name=attributes["name"],
             owner=attributes.get("curatorName", "me"),
             provider_mappings={
                 ProviderMapping(
-                    item_id=playlist_obj["id"],
+                    item_id=playlist_id,
                     provider_domain=self.domain,
                     provider_instance=self.instance_id,
                     url=attributes.get("url"),
@@ -521,10 +520,13 @@ class AppleMusicProvider(MusicProvider):
             },
         )
         if artwork := attributes.get("artwork"):
+            url = artwork["url"]
+            if artwork["width"] and artwork["height"]:
+                url = url.format(w=artwork["width"], h=artwork["height"])
             playlist.metadata.images = [
                 MediaItemImage(
                     type=ImageType.THUMB,
-                    path=artwork["url"].format(w=artwork["width"], h=artwork["height"]),
+                    path=url,
                     provider=self.instance_id,
                     remotely_accessible=True,
                 )
@@ -545,12 +547,10 @@ class AppleMusicProvider(MusicProvider):
             kwargs["limit"] = limit
             kwargs["offset"] = offset
             result = await self._get_data(endpoint, **kwargs)
-            offset += limit
-            if not result or key not in result or not result[key]:
-                break
             all_items += result[key]
-            if len(result[key]) < limit:
+            if not result.get("next"):
                 break
+            offset += limit
         return all_items
 
     @throttle_with_retries
@@ -614,7 +614,7 @@ class AppleMusicProvider(MusicProvider):
 
     def _is_catalog_id(self, catalog_id: str) -> bool:
         """Check if input is a catalog id, or a library id."""
-        return catalog_id.isnumeric()
+        return catalog_id.isnumeric() or catalog_id.startswith("pl.")
 
     async def _fetch_song_stream_metadata(self, song_id: str) -> str:
         """Get the stream URL for a song from Apple Music."""