result: Any = field(default=None, metadata={"serialize": lambda v: get_serializable_value(v)})
+@dataclass
+class ChunkedResultMessage(ResultMessageBase):
+ """Message sent when the result of a command is sent in multiple chunks."""
+
+ result: Any = field(default=None, metadata={"serialize": lambda v: get_serializable_value(v)})
+ is_last_chunk: bool = False
+
+
@dataclass
class ErrorResultMessage(ResultMessageBase):
"""Message sent when a command did not execute successfully."""
from __future__ import annotations
import random
+from collections.abc import AsyncGenerator
from time import time
from typing import Any
item_id: str,
provider_domain: str | None = None,
provider_instance: str | None = None,
- ) -> list[Track]:
+ ) -> AsyncGenerator[Track, None]:
"""Return playlist tracks for the given provider playlist id."""
playlist = await self.get(item_id, provider_domain, provider_instance)
prov = next(x for x in playlist.provider_mappings)
- return await self._get_provider_playlist_tracks(
+ async for track in self._get_provider_playlist_tracks(
prov.item_id,
provider_domain=prov.provider_domain,
provider_instance=prov.provider_instance,
cache_checksum=playlist.metadata.checksum,
- )
+ ):
+ yield track
async def add(self, item: Playlist) -> Playlist:
"""Add playlist to local db and return the new database item."""
provider_domain: str | None = None,
provider_instance: str | None = None,
cache_checksum: Any = None,
- ) -> list[Track]:
+ ) -> AsyncGenerator[Track, None]:
"""Return album tracks for the given provider album id."""
provider = self.mass.get_provider(provider_instance or provider_domain)
if not provider:
- return []
+ return
# prefer cache items (if any)
cache_key = f"{provider.instance_id}.playlist.{item_id}.tracks"
if cache := await self.mass.cache.get(cache_key, checksum=cache_checksum):
- return [Track.from_dict(x) for x in cache]
+ for track_dict in cache:
+ yield Track.from_dict(track_dict)
+ return
# no items in cache - get listing from provider
- items = await provider.get_playlist_tracks(item_id)
- # double check if position set
- if items:
- assert items[0].position is not None, "Playlist items require position to be set"
+ all_items = []
+ async for item in provider.get_playlist_tracks(item_id):
+ # double check if position set
+ assert item.position is not None, "Playlist items require position to be set"
+ yield item
+ all_items.append(item)
# store (serializable items) in cache
self.mass.create_task(
- self.mass.cache.set(cache_key, [x.to_dict() for x in items], checksum=cache_checksum)
+ self.mass.cache.set(
+ cache_key, [x.to_dict() for x in all_items], checksum=cache_checksum
+ )
)
- return items
async def _get_provider_dynamic_tracks(
self,
provider = self.mass.get_provider(provider_instance or provider_domain)
if not provider or ProviderFeature.SIMILAR_TRACKS not in provider.supported_features:
return []
- playlist_tracks = await self._get_provider_playlist_tracks(
- item_id=item_id,
- provider_domain=provider_domain,
- provider_instance=provider_instance,
- )
- # filter out unavailable tracks
- playlist_tracks = [x for x in playlist_tracks if x.available]
+ playlist_tracks = [
+ x
+ async for x in self._get_provider_playlist_tracks(
+ item_id=item_id,
+ provider_domain=provider_domain,
+ provider_instance=provider_instance,
+ )
+ # filter out unavailable tracks
+ if x.available
+ ]
limit = min(limit, len(playlist_tracks))
# use set to prevent duplicates
final_items = set()
playlist.metadata.genres = set()
image_urls = set()
try:
- for track in await self.mass.music.playlists.tracks(
+ playlist_genres: dict[str, int] = {}
+ async for track in self.mass.music.playlists.tracks(
playlist.item_id, playlist.provider
):
if not playlist.image and track.image:
if track.media_type != MediaType.TRACK:
# filter out radio items
continue
- assert isinstance(track, Track)
- assert isinstance(track.album, Album)
+ if not isinstance(track, Track):
+ continue
if track.metadata.genres:
- playlist.metadata.genres.update(track.metadata.genres)
- elif track.album and track.album.metadata.genres:
- playlist.metadata.genres.update(track.album.metadata.genres)
+ genres = track.metadata.genres
+ elif track.album and isinstance(track.album, Album) and track.album.metadata.genres:
+ genres = track.album.metadata.genres
+ else:
+ genres = set()
+ for genre in genres:
+ if genre not in playlist_genres:
+ playlist_genres[genre] = 0
+ playlist_genres[genre] += 1
+
+ playlist_genres_filtered = {
+ genre for genre, count in playlist_genres.items() if count > 5
+ }
+ playlist.metadata.genres.update(playlist_genres_filtered)
+
# create collage thumb/fanart from playlist tracks
if image_urls:
if playlist.image and self.mass.storage_path in playlist.image:
async def get_playlist_tracks( # type: ignore[return]
self, prov_playlist_id: str
- ) -> list[Track]:
+ ) -> AsyncGenerator[Track, None]:
"""Get all playlist tracks for given playlist id."""
if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features:
raise NotImplementedError
+ yield # type: ignore
async def library_add(self, prov_item_id: str, media_type: MediaType) -> bool:
"""Add item to provider's library. Return true on success."""
result.append(track)
return sorted(result, key=lambda x: (x.disc_number or 0, x.track_number or 0))
- async def get_playlist_tracks(self, prov_playlist_id: str) -> list[Track]:
+ async def get_playlist_tracks(self, prov_playlist_id: str) -> AsyncGenerator[Track, None]:
"""Get playlist tracks for given playlist id."""
- result = []
if not await self.exists(prov_playlist_id):
raise MediaNotFoundError(f"Playlist path does not exist: {prov_playlist_id}")
playlist_line, os.path.dirname(prov_playlist_id)
):
# use the linenumber as position for easier deletions
- media_item.position = line_no
- result.append(media_item)
+ media_item.position = line_no + 1
+ yield media_item
except Exception as err: # pylint: disable=broad-except
self.logger.warning("Error while parsing playlist %s", prov_playlist_id, exc_info=err)
- return result
async def _parse_playlist_line(self, line: str, playlist_path: str) -> Track | Radio | None:
"""Try to parse a track from a playlist line."""
if (item and item["id"])
]
- async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+ async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]:
"""Get all playlist tracks for given playlist id."""
- count = 0
- result = []
+ count = 1
for item in await self._get_all_items(
"playlist/get",
key="tracks",
track = await self._parse_track(item)
# use count as position
track.position = count
- result.append(track)
+ yield track
count += 1
- return result
async def get_artist_albums(self, prov_artist_id) -> list[Album]:
"""Get a list of albums for the given artist."""
) -> None:
"""Remove track(s) from playlist."""
playlist_track_ids = set()
- for track in await self.get_playlist_tracks(prov_playlist_id):
+ async for track in self.get_playlist_tracks(prov_playlist_id):
if track.position in positions_to_remove:
playlist_track_ids.add(str(track["playlist_track_id"]))
if len(playlist_track_ids) == positions_to_remove:
self.logger.debug("Parse playlist failed: %s", playlist_obj, exc_info=error)
return playlist
- async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+ async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]:
"""Get all playlist tracks for given playlist id."""
playlist_obj = await self._soundcloud.get_playlist_details(playlist_id=prov_playlist_id)
if "tracks" not in playlist_obj:
- return []
- tracks = []
+ return
for index, item in enumerate(playlist_obj["tracks"]):
song = await self._soundcloud.get_track_details(item["id"])
try:
track = await self._parse_track(song[0])
if track:
- track.position = index
- tracks.append(track)
+ track.position = index + 1
+ yield track
except (KeyError, TypeError, InvalidDataError, IndexError) as error:
self.logger.debug("Parse track failed: %s", song, exc_info=error)
continue
- return tracks
async def get_artist_toptracks(self, prov_artist_id) -> list[Track]:
"""Get a list of 25 most popular tracks for the given artist."""
if (item and item["id"])
]
- async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+ async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]:
"""Get all playlist tracks for given playlist id."""
- count = 0
- result = []
+ count = 1
for item in await self._get_all_items(
f"playlists/{prov_playlist_id}/tracks",
):
track = await self._parse_track(item["track"])
# use count as position
track.position = count
- result.append(track)
+ yield track
count += 1
- return result
async def get_artist_albums(self, prov_artist_id) -> list[Album]:
"""Get a list of all albums for the given artist."""
) -> None:
"""Remove track(s) from playlist."""
track_uris = []
- for track in await self.get_playlist_tracks(prov_playlist_id):
+ async for track in self.get_playlist_tracks(prov_playlist_id):
if track.position in positions_to_remove:
track_uris.append({"uri": f"spotify:track:{track.item_id}"})
if len(track_uris) == positions_to_remove:
from __future__ import annotations
import asyncio
+import inspect
import logging
import weakref
from concurrent import futures
from aiohttp import WSMsgType, web
from music_assistant.common.models.api import (
+ ChunkedResultMessage,
CommandMessage,
ErrorResultMessage,
MessageType,
try:
args = parse_arguments(handler.signature, handler.type_hints, msg.args)
result = handler.target(**args)
+ if inspect.isasyncgen(result):
+ # async generator = send chunked response
+ chunk_size = 100
+ batch: list[Any] = []
+ async for item in result:
+ batch.append(item)
+ if len(batch) == chunk_size:
+ self._send_message(ChunkedResultMessage(msg.message_id, batch))
+ batch = []
+ # send last chunk
+ self._send_message(ChunkedResultMessage(msg.message_id, batch, True))
+ del batch
+ return
if asyncio.iscoroutine(result):
result = await result
self._send_message(SuccessResultMessage(msg.message_id, result))
)
return await self._parse_playlist(playlist_obj)
- async def get_playlist_tracks(self, prov_playlist_id) -> list[Track]:
+ async def get_playlist_tracks(self, prov_playlist_id) -> AsyncGenerator[Track, None]:
"""Get all playlist tracks for given playlist id."""
playlist_obj = await get_playlist(
prov_playlist_id=prov_playlist_id,
username=self.config.get_value(CONF_USERNAME),
)
if "tracks" not in playlist_obj:
- return []
- tracks = []
+ return
for index, track in enumerate(playlist_obj["tracks"]):
if track["isAvailable"]:
# Playlist tracks sometimes do not have a valid artist id
try:
track = await self._parse_track(track)
if track:
- track.position = index
- tracks.append(track)
+ track.position = index + 1
+ yield track
except InvalidDataError:
track = await self.get_track(track["videoId"])
if track:
- track.position = index
- tracks.append(track)
- return tracks
+ track.position = index + 1
+ yield track
async def get_artist_albums(self, prov_artist_id) -> list[Album]:
"""Get a list of albums for the given artist."""
artist_obj = await get_artist(prov_artist_id=prov_artist_id)
if artist_obj.get("songs") and artist_obj["songs"].get("browseId"):
prov_playlist_id = artist_obj["songs"]["browseId"]
- playlist_tracks = await self.get_playlist_tracks(prov_playlist_id=prov_playlist_id)
+ playlist_tracks = [
+ x async for x in self.get_playlist_tracks(prov_playlist_id=prov_playlist_id)
+ ]
return playlist_tracks[:25]
return []