from __future__ import annotations
from collections.abc import AsyncGenerator
-from typing import cast
+from typing import TYPE_CHECKING, cast
from music_assistant_models.enums import MediaType, ProviderFeature
from music_assistant_models.errors import (
from .base import MediaControllerBase
+if TYPE_CHECKING:
+ from music_assistant import MusicAssistant
+
class PlaylistController(MediaControllerBase[Playlist]):
"""Controller managing MediaItems of type Playlist."""
media_type = MediaType.PLAYLIST
item_cls = Playlist
- def __init__(self, *args, **kwargs) -> None:
+ def __init__(self, mass: MusicAssistant) -> None:
"""Initialize class."""
- super().__init__(*args, **kwargs)
+ super().__init__(mass)
# register (extra) api handlers
api_base = self.api_base
self.mass.register_api_command(f"music/{api_base}/create_playlist", self.create_playlist)
raise ProviderUnavailableError
else:
provider = self.mass.get_provider("builtin")
+ # grab all existing track ids in the playlist so we can check for duplicates
+ provider = cast("MusicProvider", provider)
if "/" in name or "\\" in name or ".." in name:
msg = f"{name} is not a valid Playlist name"
playlist_prov_map = next(iter(playlist.provider_mappings))
playlist_prov = self.mass.get_provider(playlist_prov_map.provider_instance)
if not playlist_prov or not playlist_prov.available:
- msg = f"Provider {playlist_prov_map.provider_instance} is not available"
- raise ProviderUnavailableError(msg)
- cur_playlist_track_ids = set()
- cur_playlist_track_uris = set()
+ raise ProviderUnavailableError(
+ f"Provider {playlist_prov_map.provider_instance} is not available"
+ )
+ playlist_prov = cast("MusicProvider", playlist_prov)
+
+ # sets to track existing tracks
+ cur_playlist_track_ids: set[str] = set()
+ cur_playlist_track_uris: set[str] = set()
+
+ # collect current track IDs and URIs
async for item in self.tracks(playlist.item_id, playlist.provider):
- cur_playlist_track_uris.add(item.item_id)
- cur_playlist_track_uris.add(item.uri)
+ if item.item_id:
+ cur_playlist_track_ids.add(item.item_id)
+ if item.uri:
+ cur_playlist_track_uris.add(item.uri)
- # unwrap all uri's to track uri's
+ # unwrap URIs to individual track URIs
unwrapped_uris: list[str] = []
for uri in uris:
# URI could be a playlist or album uri, unwrap it
media_type_str, item_id = rest.split("/", 1)
media_type = MediaType(media_type_str)
if media_type == MediaType.ALBUM:
- for track in await self.mass.music.albums.tracks(
+ album_tracks = await self.mass.music.albums.tracks(
item_id, provider_instance_id_or_domain
- ):
- unwrapped_uris.append(track.uri)
+ )
+ for track in album_tracks:
+ if track.uri is not None:
+ unwrapped_uris.append(track.uri)
elif media_type == MediaType.PLAYLIST:
- for track in await self.tracks(item_id, provider_instance_id_or_domain):
- unwrapped_uris.append(track.uri)
+ async for track in self.tracks(item_id, provider_instance_id_or_domain):
+ if track.uri is not None:
+ unwrapped_uris.append(track.uri)
elif media_type == MediaType.TRACK:
unwrapped_uris.append(uri)
else:
raise InvalidDataError(msg)
for prov_mapping in playlist.provider_mappings:
provider = self.mass.get_provider(prov_mapping.provider_instance)
+ if not provider or not isinstance(provider, MusicProvider):
+ self.logger.warning(
+ "Provider %s is not available or does not support playlist editing",
+ prov_mapping.provider_domain,
+ )
+ continue
if ProviderFeature.PLAYLIST_TRACKS_EDIT not in provider.supported_features:
self.logger.warning(
"Provider %s does not support editing playlists",
await self.update_item_in_library(db_playlist_id, playlist)
- async def _add_library_item(self, item: Playlist) -> int:
+ async def _add_library_item(self, item: Playlist, overwrite_existing: bool = False) -> int:
"""Add a new record to the database."""
db_id = await self.mass.music.database.insert(
self.db_table,
"metadata": serialize_to_json(item.metadata),
"external_ids": serialize_to_json(item.external_ids),
"search_name": create_safe_string(item.name, True, True),
- "search_sort_name": create_safe_string(item.sort_name, True, True),
+ "search_sort_name": create_safe_string(item.sort_name or "", True, True),
},
)
# update/set provider_mappings table
return db_id
async def _update_library_item(
- self, item_id: int, update: Playlist, overwrite: bool = False
+ self, item_id: str | int, update: Playlist, overwrite: bool = False
) -> None:
"""Update existing record in the database."""
db_id = int(item_id) # ensure integer
update.external_ids if overwrite else cur_item.external_ids
),
"search_name": create_safe_string(name, True, True),
- "search_sort_name": create_safe_string(sort_name, True, True),
+ "search_sort_name": create_safe_string(sort_name or "", True, True),
},
)
# update/set provider_mappings table
await self.set_provider_mappings(db_id, provider_mappings, overwrite)
self.logger.debug("updated %s in database: (id %s)", update.name, db_id)
- @guard_single_request
+ @guard_single_request # type: ignore[type-var] # TODO: fix typing in util.py
async def _get_provider_playlist_tracks(
self,
item_id: str,
self,
item_id: str,
provider_instance_id_or_domain: str,
- ):
+ ) -> list[Track]:
"""Get the list of base tracks from the controller used to calculate the dynamic radio."""
return [
x
def _refresh_playlist_tracks(self, playlist: Playlist) -> None:
"""Refresh playlist tracks by forcing a cache refresh."""
- async def _refresh(playlist: Playlist):
+ async def _refresh(playlist: Playlist) -> None:
# simply iterate all tracks with force_refresh=True to refresh the cache
async for _ in self.tracks(playlist.item_id, playlist.provider, force_refresh=True):
pass