From: Marcel van der Veldt Date: Mon, 27 Mar 2023 10:58:17 +0000 (+0200) Subject: Optimize playlist tracks listings (#580) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=dd0f0ee4b5aedac7cbd5dbb713116cd19b6a8c14;p=music-assistant-server.git Optimize playlist tracks listings (#580) * change playlist tracks to async generators * add support for ChunkedResultMessage * fix some typos * small improvement for playlist metadata (genres) * adjust genre filter * position count start at 1 --- diff --git a/music_assistant/common/models/api.py b/music_assistant/common/models/api.py index e20a077f..490652f0 100644 --- a/music_assistant/common/models/api.py +++ b/music_assistant/common/models/api.py @@ -34,6 +34,14 @@ class SuccessResultMessage(ResultMessageBase): result: Any = field(default=None, metadata={"serialize": lambda v: get_serializable_value(v)}) +@dataclass +class ChunkedResultMessage(ResultMessageBase): + """Message sent when the result of a command is sent in multiple chunks.""" + + result: Any = field(default=None, metadata={"serialize": lambda v: get_serializable_value(v)}) + is_last_chunk: bool = False + + @dataclass class ErrorResultMessage(ResultMessageBase): """Message sent when a command did not execute successfully.""" diff --git a/music_assistant/server/controllers/media/playlists.py b/music_assistant/server/controllers/media/playlists.py index b2aa474a..c85eb881 100644 --- a/music_assistant/server/controllers/media/playlists.py +++ b/music_assistant/server/controllers/media/playlists.py @@ -2,6 +2,7 @@ from __future__ import annotations import random +from collections.abc import AsyncGenerator from time import time from typing import Any @@ -46,16 +47,17 @@ class PlaylistController(MediaControllerBase[Playlist]): item_id: str, provider_domain: str | None = None, provider_instance: str | None = None, - ) -> list[Track]: + ) -> AsyncGenerator[Track, None]: """Return playlist tracks for the given provider playlist id.""" playlist = await self.get(item_id, provider_domain, provider_instance) prov = next(x for x in playlist.provider_mappings) - return await self._get_provider_playlist_tracks( + async for track in self._get_provider_playlist_tracks( prov.item_id, provider_domain=prov.provider_domain, provider_instance=prov.provider_instance, cache_checksum=playlist.metadata.checksum, - ) + ): + yield track async def add(self, item: Playlist) -> Playlist: """Add playlist to local db and return the new database item.""" @@ -239,25 +241,30 @@ class PlaylistController(MediaControllerBase[Playlist]): provider_domain: str | None = None, provider_instance: str | None = None, cache_checksum: Any = None, - ) -> list[Track]: + ) -> AsyncGenerator[Track, None]: """Return album tracks for the given provider album id.""" provider = self.mass.get_provider(provider_instance or provider_domain) if not provider: - return [] + return # prefer cache items (if any) cache_key = f"{provider.instance_id}.playlist.{item_id}.tracks" if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum): - return [Track.from_dict(x) for x in cache] + for track_dict in cache: + yield Track.from_dict(track_dict) + return # no items in cache - get listing from provider - items = await provider.get_playlist_tracks(item_id) - # double check if position set - if items: - assert items[0].position is not None, "Playlist items require position to be set" + all_items = [] + async for item in provider.get_playlist_tracks(item_id): + # double check if position set + assert item.position is not None, "Playlist items require position to be set" + yield item + all_items.append(item) # store (serializable items) in cache self.mass.create_task( - self.mass.cache.set(cache_key, [x.to_dict() for x in items], checksum=cache_checksum) + self.mass.cache.set( + cache_key, [x.to_dict() for x in all_items], checksum=cache_checksum + ) ) - return items async def _get_provider_dynamic_tracks( self, @@ -270,13 +277,16 @@ class PlaylistController(MediaControllerBase[Playlist]): provider = self.mass.get_provider(provider_instance or provider_domain) if not provider or ProviderFeature.SIMILAR_TRACKS not in provider.supported_features: return [] - playlist_tracks = await self._get_provider_playlist_tracks( - item_id=item_id, - provider_domain=provider_domain, - provider_instance=provider_instance, - ) - # filter out unavailable tracks - playlist_tracks = [x for x in playlist_tracks if x.available] + playlist_tracks = [ + x + async for x in self._get_provider_playlist_tracks( + item_id=item_id, + provider_domain=provider_domain, + provider_instance=provider_instance, + ) + # filter out unavailable tracks + if x.available + ] limit = min(limit, len(playlist_tracks)) # use set to prevent duplicates final_items = set() diff --git a/music_assistant/server/controllers/metadata.py b/music_assistant/server/controllers/metadata.py index 7527f04f..f9171475 100755 --- a/music_assistant/server/controllers/metadata.py +++ b/music_assistant/server/controllers/metadata.py @@ -174,7 +174,8 @@ class MetaDataController: playlist.metadata.genres = set() image_urls = set() try: - for track in await self.mass.music.playlists.tracks( + playlist_genres: dict[str, int] = {} + async for track in self.mass.music.playlists.tracks( playlist.item_id, playlist.provider ): if not playlist.image and track.image: @@ -182,12 +183,24 @@ class MetaDataController: if track.media_type != MediaType.TRACK: # filter out radio items continue - assert isinstance(track, Track) - assert isinstance(track.album, Album) + if not isinstance(track, Track): + continue if track.metadata.genres: - playlist.metadata.genres.update(track.metadata.genres) - elif track.album and track.album.metadata.genres: - playlist.metadata.genres.update(track.album.metadata.genres) + genres = track.metadata.genres + elif track.album and isinstance(track.album, Album) and track.album.metadata.genres: + genres = track.album.metadata.genres + else: + genres = set() + for genre in genres: + if genre not in playlist_genres: + playlist_genres[genre] = 0 + playlist_genres[genre] += 1 + + playlist_genres_filtered = { + genre for genre, count in playlist_genres.items() if count > 5 + } + playlist.metadata.genres.update(playlist_genres_filtered) + # create collage thumb/fanart from playlist tracks if image_urls: if playlist.image and self.mass.storage_path in playlist.image: diff --git a/music_assistant/server/models/music_provider.py b/music_assistant/server/models/music_provider.py index 169af0eb..e042dff0 100644 --- a/music_assistant/server/models/music_provider.py +++ b/music_assistant/server/models/music_provider.py @@ -116,10 +116,11 @@ class MusicProvider(Provider): async def get_playlist_tracks( # type: ignore[return] self, prov_playlist_id: str - ) -> list[Track]: + ) -> AsyncGenerator[Track, None]: """Get all playlist tracks for given playlist id.""" if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features: raise NotImplementedError + yield # type: ignore async def library_add(self, prov_item_id: str, media_type: MediaType) -> bool: """Add item to provider's library. Return true on success.""" diff --git a/music_assistant/server/providers/filesystem_local/base.py b/music_assistant/server/providers/filesystem_local/base.py index 035ecc94..355fe4cd 100644 --- a/music_assistant/server/providers/filesystem_local/base.py +++ b/music_assistant/server/providers/filesystem_local/base.py @@ -424,9 +424,8 @@ class FileSystemProviderBase(MusicProvider): result.append(track) return sorted(result, key=lambda x: (x.disc_number or 0, x.track_number or 0)) - async def get_playlist_tracks(self, prov_playlist_id: str) -> list[Track]: + async def get_playlist_tracks(self, prov_playlist_id: str) -> AsyncGenerator[Track, None]: """Get playlist tracks for given playlist id.""" - result = [] if not await self.exists(prov_playlist_id): raise MediaNotFoundError(f"Playlist path does not exist: {prov_playlist_id}") @@ -448,12 +447,11 @@ class FileSystemProviderBase(MusicProvider): playlist_line, os.path.dirname(prov_playlist_id) ): # use the linenumber as position for easier deletions - media_item.position = line_no - result.append(media_item) + media_item.position = line_no + 1 + yield media_item except Exception as err: # pylint: disable=broad-except self.logger.warning("Error while parsing playlist %s", prov_playlist_id, exc_info=err) - return result async def _parse_playlist_line(self, line: str, playlist_path: str) -> Track | Radio | None: """Try to parse a track from a playlist line.""" diff --git a/music_assistant/server/providers/qobuz/__init__.py b/music_assistant/server/providers/qobuz/__init__.py index 04749e6a..91a00628 100644 --- a/music_assistant/server/providers/qobuz/__init__.py +++ b/music_assistant/server/providers/qobuz/__init__.py @@ -215,10 +215,9 @@ class QobuzProvider(MusicProvider): if (item and item["id"]) ] - async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]: + async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]: """Get all playlist tracks for given playlist id.""" - count = 0 - result = [] + count = 1 for item in await self._get_all_items( "playlist/get", key="tracks", @@ -230,9 +229,8 @@ class QobuzProvider(MusicProvider): track = await self._parse_track(item) # use count as position track.position = count - result.append(track) + yield track count += 1 - return result async def get_artist_albums(self, prov_artist_id) -> list[Album]: """Get a list of albums for the given artist.""" @@ -324,7 +322,7 @@ class QobuzProvider(MusicProvider): ) -> None: """Remove track(s) from playlist.""" playlist_track_ids = set() - for track in await self.get_playlist_tracks(prov_playlist_id): + async for track in self.get_playlist_tracks(prov_playlist_id): if track.position in positions_to_remove: playlist_track_ids.add(str(track["playlist_track_id"])) if len(playlist_track_ids) == positions_to_remove: diff --git a/music_assistant/server/providers/soundcloud/__init__.py b/music_assistant/server/providers/soundcloud/__init__.py index ab070e35..c2e74ae2 100644 --- a/music_assistant/server/providers/soundcloud/__init__.py +++ b/music_assistant/server/providers/soundcloud/__init__.py @@ -196,23 +196,21 @@ class SoundcloudMusicProvider(MusicProvider): self.logger.debug("Parse playlist failed: %s", playlist_obj, exc_info=error) return playlist - async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]: + async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]: """Get all playlist tracks for given playlist id.""" playlist_obj = await self._soundcloud.get_playlist_details(playlist_id=prov_playlist_id) if "tracks" not in playlist_obj: - return [] - tracks = [] + return for index, item in enumerate(playlist_obj["tracks"]): song = await self._soundcloud.get_track_details(item["id"]) try: track = await self._parse_track(song[0]) if track: - track.position = index - tracks.append(track) + track.position = index + 1 + yield track except (KeyError, TypeError, InvalidDataError, IndexError) as error: self.logger.debug("Parse track failed: %s", song, exc_info=error) continue - return tracks async def get_artist_toptracks(self, prov_artist_id) -> list[Track]: """Get a list of 25 most popular tracks for the given artist.""" diff --git a/music_assistant/server/providers/spotify/__init__.py b/music_assistant/server/providers/spotify/__init__.py index 147e42d2..a89b5473 100644 --- a/music_assistant/server/providers/spotify/__init__.py +++ b/music_assistant/server/providers/spotify/__init__.py @@ -239,10 +239,9 @@ class SpotifyProvider(MusicProvider): if (item and item["id"]) ] - async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]: + async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]: """Get all playlist tracks for given playlist id.""" - count = 0 - result = [] + count = 1 for item in await self._get_all_items( f"playlists/{prov_playlist_id}/tracks", ): @@ -251,9 +250,8 @@ class SpotifyProvider(MusicProvider): track = await self._parse_track(item["track"]) # use count as position track.position = count - result.append(track) + yield track count += 1 - return result async def get_artist_albums(self, prov_artist_id) -> list[Album]: """Get a list of all albums for the given artist.""" @@ -319,7 +317,7 @@ class SpotifyProvider(MusicProvider): ) -> None: """Remove track(s) from playlist.""" track_uris = [] - for track in await self.get_playlist_tracks(prov_playlist_id): + async for track in self.get_playlist_tracks(prov_playlist_id): if track.position in positions_to_remove: track_uris.append({"uri": f"spotify:track:{track.item_id}"}) if len(track_uris) == positions_to_remove: diff --git a/music_assistant/server/providers/websocket_api/__init__.py b/music_assistant/server/providers/websocket_api/__init__.py index 6364bb1b..c6559db7 100644 --- a/music_assistant/server/providers/websocket_api/__init__.py +++ b/music_assistant/server/providers/websocket_api/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +import inspect import logging import weakref from concurrent import futures @@ -11,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Final from aiohttp import WSMsgType, web from music_assistant.common.models.api import ( + ChunkedResultMessage, CommandMessage, ErrorResultMessage, MessageType, @@ -215,6 +217,19 @@ class WebsocketClientHandler: try: args = parse_arguments(handler.signature, handler.type_hints, msg.args) result = handler.target(**args) + if inspect.isasyncgen(result): + # async generator = send chunked response + chunk_size = 100 + batch: list[Any] = [] + async for item in result: + batch.append(item) + if len(batch) == chunk_size: + self._send_message(ChunkedResultMessage(msg.message_id, batch)) + batch = [] + # send last chunk + self._send_message(ChunkedResultMessage(msg.message_id, batch, True)) + del batch + return if asyncio.iscoroutine(result): result = await result self._send_message(SuccessResultMessage(msg.message_id, result)) diff --git a/music_assistant/server/providers/ytmusic/__init__.py b/music_assistant/server/providers/ytmusic/__init__.py index 4b37d2d9..6642b4f3 100644 --- a/music_assistant/server/providers/ytmusic/__init__.py +++ b/music_assistant/server/providers/ytmusic/__init__.py @@ -248,7 +248,7 @@ class YoutubeMusicProvider(MusicProvider): ) return await self._parse_playlist(playlist_obj) - async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]: + async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]: """Get all playlist tracks for given playlist id.""" playlist_obj = await get_playlist( prov_playlist_id=prov_playlist_id, @@ -256,8 +256,7 @@ class YoutubeMusicProvider(MusicProvider): username=self.config.get_value(CONF_USERNAME), ) if "tracks" not in playlist_obj: - return [] - tracks = [] + return for index, track in enumerate(playlist_obj["tracks"]): if track["isAvailable"]: # Playlist tracks sometimes do not have a valid artist id @@ -265,14 +264,13 @@ class YoutubeMusicProvider(MusicProvider): try: track = await self._parse_track(track) if track: - track.position = index - tracks.append(track) + track.position = index + 1 + yield track except InvalidDataError: track = await self.get_track(track["videoId"]) if track: - track.position = index - tracks.append(track) - return tracks + track.position = index + 1 + yield track async def get_artist_albums(self, prov_artist_id) -> list[Album]: """Get a list of albums for the given artist.""" @@ -293,7 +291,9 @@ class YoutubeMusicProvider(MusicProvider): artist_obj = await get_artist(prov_artist_id=prov_artist_id) if artist_obj.get("songs") and artist_obj["songs"].get("browseId"): prov_playlist_id = artist_obj["songs"]["browseId"] - playlist_tracks = await self.get_playlist_tracks(prov_playlist_id=prov_playlist_id) + playlist_tracks = [ + x async for x in self.get_playlist_tracks(prov_playlist_id=prov_playlist_id) + ] return playlist_tracks[:25] return []