from collections.abc import AsyncGenerator, Mapping
from music_assistant import MusicAssistant
- from music_assistant.models import MusicProvider
+ from music_assistant.models.music_provider import MusicProvider
ItemCls = TypeVar("ItemCls", bound="MediaItemType")
)
return library_item
- async def _get_library_item_by_match(self, item: Track | ItemMapping) -> int | None:
+ async def _get_library_item_by_match(self, item: ItemCls | ItemMapping) -> int | None:
if item.provider == "library":
return int(item.item_id)
- # search by provider mappings
+ # search by provider mappings if item is ItemMapping
if isinstance(item, ItemMapping):
if cur_item := await self.get_library_item_by_prov_id(item.item_id, item.provider):
- return cur_item.item_id
- elif cur_item := await self.get_library_item_by_prov_mappings(item.provider_mappings):
- return cur_item.item_id
+ return int(cur_item.item_id)
+
+ # for all other items that are MediaItemType, check provider_mappings if it exists
+ provider_mappings = getattr(item, "provider_mappings", None)
+ if provider_mappings:
+ if cur_item := await self.get_library_item_by_prov_mappings(provider_mappings):
+ return int(cur_item.item_id)
if cur_item := await self.get_library_item_by_external_ids(item.external_ids):
# existing item match by external id
# Double check external IDs - if MBID exists, regards that as overriding
if compare_media_item(item, cur_item):
- return cur_item.item_id
+ return int(cur_item.item_id)
# search by (exact) name match
query = f"{self.db_table}.name = :name OR {self.db_table}.sort_name = :sort_name"
query_params = {"name": item.name, "sort_name": item.sort_name}
extra_query=query, extra_query_params=query_params
):
if compare_media_item(db_item, item, True):
- return db_item.item_id
+ return int(db_item.item_id)
return None
async def update_item_in_library(
):
# schedule a refresh of the metadata on access of the item
# e.g. the item is being played or opened in the UI
+ assert library_item.uri is not None
self.mass.metadata.schedule_update_metadata(library_item.uri)
return library_item
# grab full details from the provider
)
match self.media_type:
case MediaType.ARTIST:
- return searchresult.artists
+ return cast("list[ItemCls]", searchresult.artists)
case MediaType.ALBUM:
- return searchresult.albums
+ return cast("list[ItemCls]", searchresult.albums)
case MediaType.TRACK:
- return searchresult.tracks
+ return cast("list[ItemCls]", searchresult.tracks)
case MediaType.PLAYLIST:
- return searchresult.playlists
+ return cast("list[ItemCls]", searchresult.playlists)
case MediaType.AUDIOBOOK:
- return searchresult.audiobooks
+ return cast("list[ItemCls]", searchresult.audiobooks)
case MediaType.PODCAST:
- return searchresult.podcasts
+ return cast("list[ItemCls]", searchresult.podcasts)
case MediaType.RADIO:
- return searchresult.radio
+ return cast("list[ItemCls]", searchresult.radio)
case _:
return []
- async def get_provider_mapping(self, item: ItemCls) -> tuple[str, str]:
- """Return (first) provider and item id."""
- if not getattr(item, "provider_mappings", None):
- if item.provider == "library":
- item = await self.get_library_item(item.item_id)
- return (item.provider, item.item_id)
- for prefer_unique in (True, False):
- for prov_mapping in item.provider_mappings:
- if not prov_mapping.available:
- continue
- if provider := self.mass.get_provider(
- prov_mapping.provider_instance
- if prefer_unique
- else prov_mapping.provider_domain
- ):
- if prefer_unique and provider.is_streaming_provider:
- continue
- return (prov_mapping.provider_instance, prov_mapping.item_id)
- # last resort: return just the first entry
- for prov_mapping in item.provider_mappings:
- return (prov_mapping.provider_domain, prov_mapping.item_id)
-
- return (None, None)
-
async def get_library_item(self, item_id: int | str) -> ItemCls:
"""Get single library item by id."""
db_id = int(item_id) # ensure integer
library_item = await self.get_library_item(db_id)
self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, library_item.uri, library_item)
- @guard_single_request
+ @guard_single_request # type: ignore[type-var] # TODO: fix typing for MediaControllerBase
async def get_provider_item(
self,
item_id: str,
provider_instance_id_or_domain: str,
force_refresh: bool = False,
- fallback: ItemMapping | ItemCls = None,
+ fallback: ItemMapping | ItemCls | None = None,
) -> ItemCls:
"""Return item details for the given provider item id."""
if provider_instance_id_or_domain == "library":
provider = cast("MusicProvider", provider)
with suppress(MediaNotFoundError):
async with self.mass.cache.handle_refresh(force_refresh):
- return await provider.get_item(self.media_type, item_id)
+ return cast("ItemCls", await provider.get_item(self.media_type, item_id))
# if we reach this point all possibilities failed and the item could not be found.
# There is a possibility that the (streaming) provider changed the id of the item
# so we return the previous details (if we have any) marked as unavailable, so
):
# fallback is a ItemMapping, try to convert to full item
with suppress(LookupError, TypeError, ValueError):
- return self.item_cls.from_dict(
- {
- **fallback.to_dict(),
- "provider_mappings": [
- {
- "item_id": fallback.item_id,
- "provider_domain": fallback_provider.domain,
- "provider_instance": fallback_provider.instance_id,
- "available": fallback.available,
- }
- ],
- }
+ return cast(
+ "ItemCls",
+ self.item_cls.from_dict(
+ {
+ **fallback.to_dict(),
+ "provider_mappings": [
+ {
+ "item_id": fallback.item_id,
+ "provider_domain": fallback_provider.domain,
+ "provider_instance": fallback_provider.instance_id,
+ "available": fallback.available,
+ }
+ ],
+ }
+ ),
)
if fallback:
# simply return the fallback item
- return fallback
+ return cast("ItemCls", fallback)
# all options exhausted, we really can not find this item
msg = (
f"{self.media_type.value}://{item_id} not "
# build and execute final query
sql_query = self._build_final_query(query_parts, join_parts, order_by)
-
return [
- self.item_cls.from_dict(self._parse_db_row(db_row))
+ cast("ItemCls", self.item_cls.from_dict(self._parse_db_row(db_row)))
for db_row in await self.mass.music.database.get_rows_from_query(
sql_query, query_params, limit=limit, offset=offset
)
return sql_query
@staticmethod
- def _parse_db_row(db_row: Mapping) -> dict[str, Any]:
+ def _parse_db_row(db_row: Mapping[str, Any]) -> dict[str, Any]:
"""Parse raw db Mapping into a dict."""
db_row_dict = dict(db_row)
db_row_dict["provider"] = "library"