Optimize playlist tracks listings (#580)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Mon, 27 Mar 2023 10:58:17 +0000 (12:58 +0200)
committerGitHub <noreply@github.com>
Mon, 27 Mar 2023 10:58:17 +0000 (12:58 +0200)
* change playlist tracks to async generators

* add support for ChunkedResultMessage

* fix some typos

* small improvement for playlist metadata (genres)

* adjust genre filter

* position count start at 1

music_assistant/common/models/api.py
music_assistant/server/controllers/media/playlists.py
music_assistant/server/controllers/metadata.py
music_assistant/server/models/music_provider.py
music_assistant/server/providers/filesystem_local/base.py
music_assistant/server/providers/qobuz/__init__.py
music_assistant/server/providers/soundcloud/__init__.py
music_assistant/server/providers/spotify/__init__.py
music_assistant/server/providers/websocket_api/__init__.py
music_assistant/server/providers/ytmusic/__init__.py

index e20a077f9b6be852fb1bd8f6b3cbfc170267aa9c..490652f08ae7f6049975f140d1a0b2026abf544e 100644 (file)
@@ -34,6 +34,14 @@ class SuccessResultMessage(ResultMessageBase):
     result: Any = field(default=None, metadata={"serialize": lambda v: get_serializable_value(v)})
 
 
+@dataclass
+class ChunkedResultMessage(ResultMessageBase):
+    """Message sent when the result of a command is sent in multiple chunks."""
+
+    result: Any = field(default=None, metadata={"serialize": lambda v: get_serializable_value(v)})
+    is_last_chunk: bool = False
+
+
 @dataclass
 class ErrorResultMessage(ResultMessageBase):
     """Message sent when a command did not execute successfully."""
index b2aa474a3b262d9aa5919fa27b561c8fb4b63b33..c85eb881aedb4623fc1c6865b77982ce22fe12ae 100644 (file)
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import random
+from collections.abc import AsyncGenerator
 from time import time
 from typing import Any
 
@@ -46,16 +47,17 @@ class PlaylistController(MediaControllerBase[Playlist]):
         item_id: str,
         provider_domain: str | None = None,
         provider_instance: str | None = None,
-    ) -> list[Track]:
+    ) -> AsyncGenerator[Track, None]:
         """Return playlist tracks for the given provider playlist id."""
         playlist = await self.get(item_id, provider_domain, provider_instance)
         prov = next(x for x in playlist.provider_mappings)
-        return await self._get_provider_playlist_tracks(
+        async for track in self._get_provider_playlist_tracks(
             prov.item_id,
             provider_domain=prov.provider_domain,
             provider_instance=prov.provider_instance,
             cache_checksum=playlist.metadata.checksum,
-        )
+        ):
+            yield track
 
     async def add(self, item: Playlist) -> Playlist:
         """Add playlist to local db and return the new database item."""
@@ -239,25 +241,30 @@ class PlaylistController(MediaControllerBase[Playlist]):
         provider_domain: str | None = None,
         provider_instance: str | None = None,
         cache_checksum: Any = None,
-    ) -> list[Track]:
+    ) -> AsyncGenerator[Track, None]:
         """Return album tracks for the given provider album id."""
         provider = self.mass.get_provider(provider_instance or provider_domain)
         if not provider:
-            return []
+            return
         # prefer cache items (if any)
         cache_key = f"{provider.instance_id}.playlist.{item_id}.tracks"
         if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
-            return [Track.from_dict(x) for x in cache]
+            for track_dict in cache:
+                yield Track.from_dict(track_dict)
+            return
         # no items in cache - get listing from provider
-        items = await provider.get_playlist_tracks(item_id)
-        # double check if position set
-        if items:
-            assert items[0].position is not None, "Playlist items require position to be set"
+        all_items = []
+        async for item in provider.get_playlist_tracks(item_id):
+            # double check if position set
+            assert item.position is not None, "Playlist items require position to be set"
+            yield item
+            all_items.append(item)
         # store (serializable items) in cache
         self.mass.create_task(
-            self.mass.cache.set(cache_key, [x.to_dict() for x in items], checksum=cache_checksum)
+            self.mass.cache.set(
+                cache_key, [x.to_dict() for x in all_items], checksum=cache_checksum
+            )
         )
-        return items
 
     async def _get_provider_dynamic_tracks(
         self,
@@ -270,13 +277,16 @@ class PlaylistController(MediaControllerBase[Playlist]):
         provider = self.mass.get_provider(provider_instance or provider_domain)
         if not provider or ProviderFeature.SIMILAR_TRACKS not in provider.supported_features:
             return []
-        playlist_tracks = await self._get_provider_playlist_tracks(
-            item_id=item_id,
-            provider_domain=provider_domain,
-            provider_instance=provider_instance,
-        )
-        # filter out unavailable tracks
-        playlist_tracks = [x for x in playlist_tracks if x.available]
+        playlist_tracks = [
+            x
+            async for x in self._get_provider_playlist_tracks(
+                item_id=item_id,
+                provider_domain=provider_domain,
+                provider_instance=provider_instance,
+            )
+            # filter out unavailable tracks
+            if x.available
+        ]
         limit = min(limit, len(playlist_tracks))
         # use set to prevent duplicates
         final_items = set()
index 7527f04f82e35c7cf8af4ea9969977c9c800da3b..f9171475ed929e7c63ecd47907f815b5047ec43c 100755 (executable)
@@ -174,7 +174,8 @@ class MetaDataController:
         playlist.metadata.genres = set()
         image_urls = set()
         try:
-            for track in await self.mass.music.playlists.tracks(
+            playlist_genres: dict[str, int] = {}
+            async for track in self.mass.music.playlists.tracks(
                 playlist.item_id, playlist.provider
             ):
                 if not playlist.image and track.image:
@@ -182,12 +183,24 @@ class MetaDataController:
                 if track.media_type != MediaType.TRACK:
                     # filter out radio items
                     continue
-                assert isinstance(track, Track)
-                assert isinstance(track.album, Album)
+                if not isinstance(track, Track):
+                    continue
                 if track.metadata.genres:
-                    playlist.metadata.genres.update(track.metadata.genres)
-                elif track.album and track.album.metadata.genres:
-                    playlist.metadata.genres.update(track.album.metadata.genres)
+                    genres = track.metadata.genres
+                elif track.album and isinstance(track.album, Album) and track.album.metadata.genres:
+                    genres = track.album.metadata.genres
+                else:
+                    genres = set()
+                for genre in genres:
+                    if genre not in playlist_genres:
+                        playlist_genres[genre] = 0
+                    playlist_genres[genre] += 1
+
+            playlist_genres_filtered = {
+                genre for genre, count in playlist_genres.items() if count > 5
+            }
+            playlist.metadata.genres.update(playlist_genres_filtered)
+
             # create collage thumb/fanart from playlist tracks
             if image_urls:
                 if playlist.image and self.mass.storage_path in playlist.image:
index 169af0ebfe8cf2aa3834f46b1837ea17b3ca97ca..e042dff01caa467eda510eeeeda63c65cb0796f2 100644 (file)
@@ -116,10 +116,11 @@ class MusicProvider(Provider):
 
     async def get_playlist_tracks(  # type: ignore[return]
         self, prov_playlist_id: str
-    ) -> list[Track]:
+    ) -> AsyncGenerator[Track, None]:
         """Get all playlist tracks for given playlist id."""
         if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features:
             raise NotImplementedError
+        yield  # type: ignore
 
     async def library_add(self, prov_item_id: str, media_type: MediaType) -> bool:
         """Add item to provider's library. Return true on success."""
index 035ecc9456e9d5c000beefe767f54fbe27de5cac..355fe4cd86b64b547bb6da2dbc8317b427b53f51 100644 (file)
@@ -424,9 +424,8 @@ class FileSystemProviderBase(MusicProvider):
                 result.append(track)
         return sorted(result, key=lambda x: (x.disc_number or 0, x.track_number or 0))
 
-    async def get_playlist_tracks(self, prov_playlist_id: str) -> list[Track]:
+    async def get_playlist_tracks(self, prov_playlist_id: str) -> AsyncGenerator[Track, None]:
         """Get playlist tracks for given playlist id."""
-        result = []
         if not await self.exists(prov_playlist_id):
             raise MediaNotFoundError(f"Playlist path does not exist: {prov_playlist_id}")
 
@@ -448,12 +447,11 @@ class FileSystemProviderBase(MusicProvider):
                     playlist_line, os.path.dirname(prov_playlist_id)
                 ):
                     # use the linenumber as position for easier deletions
-                    media_item.position = line_no
-                    result.append(media_item)
+                    media_item.position = line_no + 1
+                    yield media_item
 
         except Exception as err:  # pylint: disable=broad-except
             self.logger.warning("Error while parsing playlist %s", prov_playlist_id, exc_info=err)
-        return result
 
     async def _parse_playlist_line(self, line: str, playlist_path: str) -> Track | Radio | None:
         """Try to parse a track from a playlist line."""
index 04749e6a658d536adaf749d41c3480a99f507f39..91a0062811e0013ae2fd84d3ed6d696b60d27253 100644 (file)
@@ -215,10 +215,9 @@ class QobuzProvider(MusicProvider):
             if (item and item["id"])
         ]
 
-    async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+    async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]:
         """Get all playlist tracks for given playlist id."""
-        count = 0
-        result = []
+        count = 1
         for item in await self._get_all_items(
             "playlist/get",
             key="tracks",
@@ -230,9 +229,8 @@ class QobuzProvider(MusicProvider):
             track = await self._parse_track(item)
             # use count as position
             track.position = count
-            result.append(track)
+            yield track
             count += 1
-        return result
 
     async def get_artist_albums(self, prov_artist_id) -> list[Album]:
         """Get a list of albums for the given artist."""
@@ -324,7 +322,7 @@ class QobuzProvider(MusicProvider):
     ) -> None:
         """Remove track(s) from playlist."""
         playlist_track_ids = set()
-        for track in await self.get_playlist_tracks(prov_playlist_id):
+        async for track in self.get_playlist_tracks(prov_playlist_id):
             if track.position in positions_to_remove:
                 playlist_track_ids.add(str(track["playlist_track_id"]))
             if len(playlist_track_ids) == positions_to_remove:
index ab070e35206515fae46d37e1dfd9b5b1b661ded6..c2e74ae26f8a8edee470ce6b2fcc32171b2e7f79 100644 (file)
@@ -196,23 +196,21 @@ class SoundcloudMusicProvider(MusicProvider):
             self.logger.debug("Parse playlist failed: %s", playlist_obj, exc_info=error)
         return playlist
 
-    async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+    async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]:
         """Get all playlist tracks for given playlist id."""
         playlist_obj = await self._soundcloud.get_playlist_details(playlist_id=prov_playlist_id)
         if "tracks" not in playlist_obj:
-            return []
-        tracks = []
+            return
         for index, item in enumerate(playlist_obj["tracks"]):
             song = await self._soundcloud.get_track_details(item["id"])
             try:
                 track = await self._parse_track(song[0])
                 if track:
-                    track.position = index
-                    tracks.append(track)
+                    track.position = index + 1
+                    yield track
             except (KeyError, TypeError, InvalidDataError, IndexError) as error:
                 self.logger.debug("Parse track failed: %s", song, exc_info=error)
                 continue
-        return tracks
 
     async def get_artist_toptracks(self, prov_artist_id) -> list[Track]:
         """Get a list of 25 most popular tracks for the given artist."""
index 147e42d2624c3da3a10f585e782fd36859fe1d89..a89b5473d95e6afbc90b8e5a06d2862911e41205 100644 (file)
@@ -239,10 +239,9 @@ class SpotifyProvider(MusicProvider):
             if (item and item["id"])
         ]
 
-    async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+    async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]:
         """Get all playlist tracks for given playlist id."""
-        count = 0
-        result = []
+        count = 1
         for item in await self._get_all_items(
             f"playlists/{prov_playlist_id}/tracks",
         ):
@@ -251,9 +250,8 @@ class SpotifyProvider(MusicProvider):
             track = await self._parse_track(item["track"])
             # use count as position
             track.position = count
-            result.append(track)
+            yield track
             count += 1
-        return result
 
     async def get_artist_albums(self, prov_artist_id) -> list[Album]:
         """Get a list of all albums for the given artist."""
@@ -319,7 +317,7 @@ class SpotifyProvider(MusicProvider):
     ) -> None:
         """Remove track(s) from playlist."""
         track_uris = []
-        for track in await self.get_playlist_tracks(prov_playlist_id):
+        async for track in self.get_playlist_tracks(prov_playlist_id):
             if track.position in positions_to_remove:
                 track_uris.append({"uri": f"spotify:track:{track.item_id}"})
             if len(track_uris) == positions_to_remove:
index 6364bb1b1eef138ccb1bd10242eb733e7cc85d83..c6559db739d8fab2b3405193826e88488ee7a929 100644 (file)
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
 import logging
 import weakref
 from concurrent import futures
@@ -11,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Final
 from aiohttp import WSMsgType, web
 
 from music_assistant.common.models.api import (
+    ChunkedResultMessage,
     CommandMessage,
     ErrorResultMessage,
     MessageType,
@@ -215,6 +217,19 @@ class WebsocketClientHandler:
         try:
             args = parse_arguments(handler.signature, handler.type_hints, msg.args)
             result = handler.target(**args)
+            if inspect.isasyncgen(result):
+                # async generator = send chunked response
+                chunk_size = 100
+                batch: list[Any] = []
+                async for item in result:
+                    batch.append(item)
+                    if len(batch) == chunk_size:
+                        self._send_message(ChunkedResultMessage(msg.message_id, batch))
+                        batch = []
+                # send last chunk
+                self._send_message(ChunkedResultMessage(msg.message_id, batch, True))
+                del batch
+                return
             if asyncio.iscoroutine(result):
                 result = await result
             self._send_message(SuccessResultMessage(msg.message_id, result))
index 4b37d2d9002b4c65a0d1d871d9efba89574c9104..6642b4f3554cbbb670a5c0584e381db3f9c2c46b 100644 (file)
@@ -248,7 +248,7 @@ class YoutubeMusicProvider(MusicProvider):
         )
         return await self._parse_playlist(playlist_obj)
 
-    async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+    async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]:
         """Get all playlist tracks for given playlist id."""
         playlist_obj = await get_playlist(
             prov_playlist_id=prov_playlist_id,
@@ -256,8 +256,7 @@ class YoutubeMusicProvider(MusicProvider):
             username=self.config.get_value(CONF_USERNAME),
         )
         if "tracks" not in playlist_obj:
-            return []
-        tracks = []
+            return
         for index, track in enumerate(playlist_obj["tracks"]):
             if track["isAvailable"]:
                 # Playlist tracks sometimes do not have a valid artist id
@@ -265,14 +264,13 @@ class YoutubeMusicProvider(MusicProvider):
                 try:
                     track = await self._parse_track(track)
                     if track:
-                        track.position = index
-                        tracks.append(track)
+                        track.position = index + 1
+                        yield track
                 except InvalidDataError:
                     track = await self.get_track(track["videoId"])
                     if track:
-                        track.position = index
-                        tracks.append(track)
-        return tracks
+                        track.position = index + 1
+                        yield track
 
     async def get_artist_albums(self, prov_artist_id) -> list[Album]:
         """Get a list of albums for the given artist."""
@@ -293,7 +291,9 @@ class YoutubeMusicProvider(MusicProvider):
         artist_obj = await get_artist(prov_artist_id=prov_artist_id)
         if artist_obj.get("songs") and artist_obj["songs"].get("browseId"):
             prov_playlist_id = artist_obj["songs"]["browseId"]
-            playlist_tracks = await self.get_playlist_tracks(prov_playlist_id=prov_playlist_id)
+            playlist_tracks = [
+                x async for x in self.get_playlist_tracks(prov_playlist_id=prov_playlist_id)
+            ]
             return playlist_tracks[:25]
         return []