Apple music improvements (#2607)
authorMarvin Schenkel <marvinschenkel@gmail.com>
Fri, 7 Nov 2025 15:51:42 +0000 (16:51 +0100)
committerGitHub <noreply@github.com>
Fri, 7 Nov 2025 15:51:42 +0000 (16:51 +0100)
music_assistant/providers/apple_music/__init__.py

index f2f1b670b3ba7afdfaacf489a814068d8c6d299d..e64be99a2ea1b79b636019a967439de7cf8844ff 100644 (file)
@@ -90,6 +90,13 @@ SUPPORTED_FEATURES = {
     ProviderFeature.ARTIST_ALBUMS,
     ProviderFeature.ARTIST_TOPTRACKS,
     ProviderFeature.SIMILAR_TRACKS,
+    ProviderFeature.LIBRARY_ALBUMS_EDIT,
+    ProviderFeature.LIBRARY_ARTISTS_EDIT,
+    ProviderFeature.LIBRARY_PLAYLISTS_EDIT,
+    ProviderFeature.LIBRARY_TRACKS_EDIT,
+    ProviderFeature.FAVORITE_ALBUMS_EDIT,
+    ProviderFeature.FAVORITE_TRACKS_EDIT,
+    ProviderFeature.FAVORITE_PLAYLISTS_EDIT,
 }
 
 MUSIC_APP_TOKEN = app_var(8)
@@ -354,16 +361,14 @@ class AppleMusicProvider(MusicProvider):
         """Retrieve library tracks from the provider."""
         endpoint = "me/library/songs"
         song_catalog_ids = []
+        library_only_tracks = []
         for item in await self._get_all_items(endpoint):
             catalog_id = item.get("attributes", {}).get("playParams", {}).get("catalogId")
             if not catalog_id:
-                self.logger.debug(
-                    "Skipping track. No catalog version found for %s - %s",
-                    item["attributes"].get("artistName", ""),
-                    item["attributes"].get("name", ""),
-                )
-                continue
-            song_catalog_ids.append(catalog_id)
+                # Track is library-only (private/uploaded), use library ID instead
+                library_only_tracks.append(item)
+            else:
+                song_catalog_ids.append(catalog_id)
         # Obtain catalog info per 200 songs, the documented limit of 300 results in a 504 timeout
         max_limit = 200
         for i in range(0, len(song_catalog_ids), max_limit):
@@ -372,8 +377,15 @@ class AppleMusicProvider(MusicProvider):
             response = await self._get_data(
                 catalog_endpoint, ids=",".join(catalog_ids), include="artists,albums"
             )
+            # Fetch ratings for this batch
+            rating_response = await self._get_ratings(catalog_ids, MediaType.TRACK)
             for item in response["data"]:
-                yield self._parse_track(item)
+                is_favourite = rating_response.get(item["id"])
+                track = self._parse_track(item, is_favourite)
+                yield track
+        # Yield library-only tracks using their library metadata
+        for item in library_only_tracks:
+            yield self._parse_track(item)
 
     async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
         """Retrieve playlists from the provider."""
@@ -404,7 +416,9 @@ class AppleMusicProvider(MusicProvider):
         """Get full track details by id."""
         endpoint = f"catalog/{self._storefront}/songs/{prov_track_id}"
         response = await self._get_data(endpoint, include="artists,albums")
-        return self._parse_track(response["data"][0])
+        rating_response = await self._get_ratings([prov_track_id], MediaType.TRACK)
+        is_favourite = rating_response.get(prov_track_id)
+        return self._parse_track(response["data"][0], is_favourite)
 
     @use_cache()
     async def get_playlist(self, prov_playlist_id) -> Playlist:
@@ -424,11 +438,13 @@ class AppleMusicProvider(MusicProvider):
         response = await self._get_data(endpoint, include="artists")
         # Including albums results in a 504 error, so we need to fetch the album separately
         album = await self.get_album(prov_album_id)
+        track_ids = [track_obj["id"] for track_obj in response["data"] if "id" in track_obj]
+        rating_response = await self._get_ratings(track_ids, MediaType.TRACK)
         tracks = []
         for track_obj in response["data"]:
             if "id" not in track_obj:
                 continue
-            track = self._parse_track(track_obj)
+            track = self._parse_track(track_obj, rating_response.get(track_obj["id"]))
             track.album = album
             tracks.append(track)
         return tracks
@@ -479,23 +495,43 @@ class AppleMusicProvider(MusicProvider):
             return []
         return [self._parse_track(track) for track in response["data"] if track["id"]]
 
-    async def library_add(self, item: MediaItemType):
+    async def library_add(self, item: MediaItemType) -> None:
         """Add item to library."""
-        raise NotImplementedError("Not implemented!")
+        item_type = self._translate_media_type_to_apple_type(item.media_type)
+        kwargs = {
+            f"ids[{item_type}]": item.item_id,
+        }
+        await self._post_data("me/library/", **kwargs)
 
-    async def library_remove(self, prov_item_id, media_type: MediaType):
+    async def library_remove(self, prov_item_id, media_type: MediaType) -> None:
         """Remove item from library."""
-        raise NotImplementedError("Not implemented!")
+        self.logger.warning(
+            "Deleting items from your library is not yet supported by the Apple Music API. "
+            f"Skipping deletion of {media_type} - {prov_item_id}."
+        )
 
     async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]):
         """Add track(s) to playlist."""
-        raise NotImplementedError("Not implemented!")
+        endpoint = f"me/library/playlists/{prov_playlist_id}/tracks"
+        data = {
+            "data": [
+                {
+                    "id": track_id,
+                    "type": "library-songs" if self.is_library_id(track_id) else "songs",
+                }
+                for track_id in prov_track_ids
+            ]
+        }
+        await self._post_data(endpoint, data=data)
 
     async def remove_playlist_tracks(
         self, prov_playlist_id: str, positions_to_remove: tuple[int, ...]
     ) -> None:
         """Remove track(s) from playlist."""
-        raise NotImplementedError("Not implemented!")
+        self.logger.warning(
+            "Removing tracks from playlists is not supported by the Apple Music "
+            "API. Make sure to delete them using the Apple Music app."
+        )
 
     @use_cache(3600 * 24)  # cache for 24 hours
     async def get_similar_tracks(self, prov_track_id, limit=25) -> list[Track]:
@@ -520,6 +556,24 @@ class AppleMusicProvider(MusicProvider):
     async def get_stream_details(self, item_id: str, media_type: MediaType) -> StreamDetails:
         """Return the content details for the given track when it will be streamed."""
         stream_metadata = await self._fetch_song_stream_metadata(item_id)
+        if self.is_library_id(item_id):
+            # Library items are not encrypted and do not need decryption keys
+            try:
+                stream_url = stream_metadata["assets"][0]["URL"]
+            except (KeyError, IndexError, TypeError) as exc:
+                raise MediaNotFoundError(
+                    f"Failed to extract stream URL for library track {item_id}: {exc}"
+                ) from exc
+            return StreamDetails(
+                item_id=item_id,
+                provider=self.lookup_key,
+                path=stream_url,
+                stream_type=StreamType.HTTP,
+                audio_format=AudioFormat(content_type=ContentType.UNKNOWN),
+                can_seek=True,
+                allow_seek=True,
+            )
+        # Continue to obtain decryption keys for catalog items
         license_url = stream_metadata["hls-key-server-url"]
         stream_url, uri = await self._parse_stream_url_and_uri(stream_metadata["assets"])
         if not stream_url or not uri:
@@ -536,6 +590,21 @@ class AppleMusicProvider(MusicProvider):
             allow_seek=True,
         )
 
+    async def set_favorite(self, prov_item_id: str, media_type: MediaType, favorite: bool) -> None:
+        """Set the favorite status of an item."""
+        data = {
+            "type": "ratings",
+            "attributes": {
+                "value": 1 if favorite else -1,
+            },
+        }
+        item_type = self._translate_media_type_to_apple_type(media_type)
+        if self._is_catalog_id(prov_item_id):
+            endpoint = f"me/ratings/{item_type}/{prov_item_id}"
+        else:
+            endpoint = f"me/ratings/library-{item_type}/{prov_item_id}"
+        await self._put_data(endpoint, data=data)
+
     def _parse_artist(self, artist_obj: dict[str, Any]) -> Artist:
         """Parse artist object to generic layout."""
         relationships = artist_obj.get("relationships", {})
@@ -684,13 +753,19 @@ class AppleMusicProvider(MusicProvider):
     def _parse_track(
         self,
         track_obj: dict[str, Any],
+        is_favourite: bool | None = None,
     ) -> Track:
         """Parse track object to generic layout."""
         relationships = track_obj.get("relationships", {})
-        if track_obj.get("type") == "library-songs" and relationships["catalog"]["data"] != []:
+        if (
+            track_obj.get("type") == "library-songs"
+            and relationships.get("catalog", {}).get("data", []) != []
+        ):
+            # Library track with catalog version available
             track_id = relationships.get("catalog", {})["data"][0]["id"]
             attributes = relationships.get("catalog", {})["data"][0]["attributes"]
         elif "attributes" in track_obj:
+            # Catalog track or library-only track
             track_id = track_obj["id"]
             attributes = track_obj["attributes"]
         else:
@@ -749,6 +824,7 @@ class AppleMusicProvider(MusicProvider):
             track.metadata.performers = set(composers.split(", "))
         if isrc := attributes.get("isrc"):
             track.external_ids.add((ExternalID.ISRC, isrc))
+        track.favorite = is_favourite or False
         return track
 
     def _parse_playlist(self, playlist_obj: dict[str, Any]) -> Playlist:
@@ -839,13 +915,49 @@ class AppleMusicProvider(MusicProvider):
             response.raise_for_status()
             return await response.json(loads=json_loads)
 
-    async def _delete_data(self, endpoint, data=None, **kwargs) -> str:
+    @throttle_with_retries
+    async def _delete_data(self, endpoint, data=None, **kwargs) -> None:
         """Delete data from api."""
-        raise NotImplementedError("Not implemented!")
+        url = f"https://api.music.apple.com/v1/{endpoint}"
+        headers = {"Authorization": f"Bearer {self._music_app_token}"}
+        headers["Music-User-Token"] = self._music_user_token
+        async with (
+            self.mass.http_session.delete(
+                url, headers=headers, params=kwargs, json=data, ssl=True, timeout=120
+            ) as response,
+        ):
+            # Convert HTTP errors to exceptions
+            if response.status == 404:
+                raise MediaNotFoundError(f"{endpoint} not found")
+            if response.status == 429:
+                # Debug this for now to see if the response headers give us info about the
+                # backoff time. There is no documentation on this.
+                self.logger.debug("Apple Music Rate Limiter. Headers: %s", response.headers)
+                raise ResourceTemporarilyUnavailable("Apple Music Rate Limiter")
+            response.raise_for_status()
 
     async def _put_data(self, endpoint, data=None, **kwargs) -> str:
         """Put data on api."""
-        raise NotImplementedError("Not implemented!")
+        url = f"https://api.music.apple.com/v1/{endpoint}"
+        headers = {"Authorization": f"Bearer {self._music_app_token}"}
+        headers["Music-User-Token"] = self._music_user_token
+        async with (
+            self.mass.http_session.put(
+                url, headers=headers, params=kwargs, json=data, ssl=True, timeout=120
+            ) as response,
+        ):
+            # Convert HTTP errors to exceptions
+            if response.status == 404:
+                raise MediaNotFoundError(f"{endpoint} not found")
+            if response.status == 429:
+                # Debug this for now to see if the response headers give us info about the
+                # backoff time. There is no documentation on this.
+                self.logger.debug("Apple Music Rate Limiter. Headers: %s", response.headers)
+                raise ResourceTemporarilyUnavailable("Apple Music Rate Limiter")
+            response.raise_for_status()
+            if response.content_length:
+                return await response.json(loads=json_loads)
+            return {}
 
     @throttle_with_retries
     async def _post_data(self, endpoint, data=None, **kwargs) -> str:
@@ -876,6 +988,50 @@ class AppleMusicProvider(MusicProvider):
         result = await self._get_data("me/storefront", l=language)
         return result["data"][0]["id"]
 
+    async def _get_ratings(self, item_ids: list[str], media_type: MediaType) -> dict[str, bool]:
+        """Get ratings (aka favorites) for a list of item ids."""
+        if media_type == MediaType.ARTIST:
+            raise NotImplementedError(
+                "Ratings are not available for artist in the Apple Music API."
+            )
+        endpoint = self._translate_media_type_to_apple_type(media_type)
+        # Apple Music limits to 200 ids per request
+        max_ids_per_request = 200
+        results = {}
+        for i in range(0, len(item_ids), max_ids_per_request):
+            batch_ids = item_ids[i : i + max_ids_per_request]
+            response = await self._get_data(
+                f"me/ratings/{endpoint}",
+                ids=",".join(batch_ids),
+            )
+            results.update(
+                {
+                    item["id"]: bool(item["attributes"].get("value", False) == 1)
+                    for item in response.get("data", [])
+                }
+            )
+        return results
+
+    def _translate_media_type_to_apple_type(self, media_type: MediaType) -> str:
+        """Translate MediaType to Apple Music endpoint string."""
+        match media_type:
+            case MediaType.ARTIST:
+                return "artists"
+            case MediaType.ALBUM:
+                return "albums"
+            case MediaType.TRACK:
+                return "songs"
+            case MediaType.PLAYLIST:
+                return "playlists"
+        raise MusicAssistantError(f"Unsupported media type: {media_type}")
+
+    def is_library_id(self, library_id) -> bool:
+        """Check a library ID matches known format."""
+        if not isinstance(library_id, str):
+            return False
+        valid = re.findall(r"^(?:[a|i|l|p]{1}\.|pl\.u\-)[a-zA-Z0-9]+$", library_id)
+        return bool(valid)
+
     def _is_catalog_id(self, catalog_id: str) -> bool:
         """Check if input is a catalog id, or a library id."""
         return catalog_id.isnumeric() or catalog_id.startswith("pl.")
@@ -883,9 +1039,13 @@ class AppleMusicProvider(MusicProvider):
     async def _fetch_song_stream_metadata(self, song_id: str) -> str:
         """Get the stream URL for a song from Apple Music."""
         playback_url = "https://play.music.apple.com/WebObjects/MZPlay.woa/wa/webPlayback"
-        data = {
-            "salableAdamId": song_id,
-        }
+        data = {}
+        self.logger.debug("_fetch_song_stream_metadata: Check if Library ID: %s", song_id)
+        if self.is_library_id(song_id):
+            data["universalLibraryId"] = song_id
+            data["isLibrary"] = True
+        else:
+            data["salableAdamId"] = song_id
         for retry in (True, False):
             try:
                 async with self.mass.http_session.post(