From 2cc7dcc87709cfe12e83e30ffa5455f390999954 Mon Sep 17 00:00:00 2001 From: OzGav Date: Tue, 25 Nov 2025 20:42:29 +1000 Subject: [PATCH] Typing fixes for the base media controller (#2633) * Typing fixes for the base media controller * Remove db asserts * PR review comments * remove unused function --------- Co-authored-by: Marvin Schenkel --- music_assistant/controllers/media/base.py | 99 ++++++++++------------- pyproject.toml | 1 - 2 files changed, 41 insertions(+), 59 deletions(-) diff --git a/music_assistant/controllers/media/base.py b/music_assistant/controllers/media/base.py index bd3a88ef..a63e0875 100644 --- a/music_assistant/controllers/media/base.py +++ b/music_assistant/controllers/media/base.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, Mapping from music_assistant import MusicAssistant - from music_assistant.models import MusicProvider + from music_assistant.models.music_provider import MusicProvider ItemCls = TypeVar("ItemCls", bound="MediaItemType") @@ -133,20 +133,24 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): ) return library_item - async def _get_library_item_by_match(self, item: Track | ItemMapping) -> int | None: + async def _get_library_item_by_match(self, item: ItemCls | ItemMapping) -> int | None: if item.provider == "library": return int(item.item_id) - # search by provider mappings + # search by provider mappings if item is ItemMapping if isinstance(item, ItemMapping): if cur_item := await self.get_library_item_by_prov_id(item.item_id, item.provider): - return cur_item.item_id - elif cur_item := await self.get_library_item_by_prov_mappings(item.provider_mappings): - return cur_item.item_id + return int(cur_item.item_id) + + # for all other items that are MediaItemType, check provider_mappings if it exists + provider_mappings = getattr(item, "provider_mappings", None) + if provider_mappings: + if cur_item := await self.get_library_item_by_prov_mappings(provider_mappings): + return int(cur_item.item_id) if cur_item := await self.get_library_item_by_external_ids(item.external_ids): # existing item match by external id # Double check external IDs - if MBID exists, regards that as overriding if compare_media_item(item, cur_item): - return cur_item.item_id + return int(cur_item.item_id) # search by (exact) name match query = f"{self.db_table}.name = :name OR {self.db_table}.sort_name = :sort_name" query_params = {"name": item.name, "sort_name": item.sort_name} @@ -154,7 +158,7 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): extra_query=query, extra_query_params=query_params ): if compare_media_item(db_item, item, True): - return db_item.item_id + return int(db_item.item_id) return None async def update_item_in_library( @@ -281,6 +285,7 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): ): # schedule a refresh of the metadata on access of the item # e.g. the item is being played or opened in the UI + assert library_item.uri is not None self.mass.metadata.schedule_update_metadata(library_item.uri) return library_item # grab full details from the provider @@ -315,46 +320,22 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): ) match self.media_type: case MediaType.ARTIST: - return searchresult.artists + return cast("list[ItemCls]", searchresult.artists) case MediaType.ALBUM: - return searchresult.albums + return cast("list[ItemCls]", searchresult.albums) case MediaType.TRACK: - return searchresult.tracks + return cast("list[ItemCls]", searchresult.tracks) case MediaType.PLAYLIST: - return searchresult.playlists + return cast("list[ItemCls]", searchresult.playlists) case MediaType.AUDIOBOOK: - return searchresult.audiobooks + return cast("list[ItemCls]", searchresult.audiobooks) case MediaType.PODCAST: - return searchresult.podcasts + return cast("list[ItemCls]", searchresult.podcasts) case MediaType.RADIO: - return searchresult.radio + return cast("list[ItemCls]", searchresult.radio) case _: return [] - async def get_provider_mapping(self, item: ItemCls) -> tuple[str, str]: - """Return (first) provider and item id.""" - if not getattr(item, "provider_mappings", None): - if item.provider == "library": - item = await self.get_library_item(item.item_id) - return (item.provider, item.item_id) - for prefer_unique in (True, False): - for prov_mapping in item.provider_mappings: - if not prov_mapping.available: - continue - if provider := self.mass.get_provider( - prov_mapping.provider_instance - if prefer_unique - else prov_mapping.provider_domain - ): - if prefer_unique and provider.is_streaming_provider: - continue - return (prov_mapping.provider_instance, prov_mapping.item_id) - # last resort: return just the first entry - for prov_mapping in item.provider_mappings: - return (prov_mapping.provider_domain, prov_mapping.item_id) - - return (None, None) - async def get_library_item(self, item_id: int | str) -> ItemCls: """Get single library item by id.""" db_id = int(item_id) # ensure integer @@ -498,13 +479,13 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): library_item = await self.get_library_item(db_id) self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, library_item.uri, library_item) - @guard_single_request + @guard_single_request # type: ignore[type-var] # TODO: fix typing for MediaControllerBase async def get_provider_item( self, item_id: str, provider_instance_id_or_domain: str, force_refresh: bool = False, - fallback: ItemMapping | ItemCls = None, + fallback: ItemMapping | ItemCls | None = None, ) -> ItemCls: """Return item details for the given provider item id.""" if provider_instance_id_or_domain == "library": @@ -515,7 +496,7 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): provider = cast("MusicProvider", provider) with suppress(MediaNotFoundError): async with self.mass.cache.handle_refresh(force_refresh): - return await provider.get_item(self.media_type, item_id) + return cast("ItemCls", await provider.get_item(self.media_type, item_id)) # if we reach this point all possibilities failed and the item could not be found. # There is a possibility that the (streaming) provider changed the id of the item # so we return the previous details (if we have any) marked as unavailable, so @@ -530,22 +511,25 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): ): # fallback is a ItemMapping, try to convert to full item with suppress(LookupError, TypeError, ValueError): - return self.item_cls.from_dict( - { - **fallback.to_dict(), - "provider_mappings": [ - { - "item_id": fallback.item_id, - "provider_domain": fallback_provider.domain, - "provider_instance": fallback_provider.instance_id, - "available": fallback.available, - } - ], - } + return cast( + "ItemCls", + self.item_cls.from_dict( + { + **fallback.to_dict(), + "provider_mappings": [ + { + "item_id": fallback.item_id, + "provider_domain": fallback_provider.domain, + "provider_instance": fallback_provider.instance_id, + "available": fallback.available, + } + ], + } + ), ) if fallback: # simply return the fallback item - return fallback + return cast("ItemCls", fallback) # all options exhausted, we really can not find this item msg = ( f"{self.media_type.value}://{item_id} not " @@ -740,9 +724,8 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): # build and execute final query sql_query = self._build_final_query(query_parts, join_parts, order_by) - return [ - self.item_cls.from_dict(self._parse_db_row(db_row)) + cast("ItemCls", self.item_cls.from_dict(self._parse_db_row(db_row))) for db_row in await self.mass.music.database.get_rows_from_query( sql_query, query_params, limit=limit, offset=offset ) @@ -853,7 +836,7 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): return sql_query @staticmethod - def _parse_db_row(db_row: Mapping) -> dict[str, Any]: + def _parse_db_row(db_row: Mapping[str, Any]) -> dict[str, Any]: """Parse raw db Mapping into a dict.""" db_row_dict = dict(db_row) db_row_dict["provider"] = "library" diff --git a/pyproject.toml b/pyproject.toml index f3b43eb5..a09cab87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,7 +130,6 @@ enable_error_code = [ "truthy-iterable", ] exclude = [ - '^music_assistant/controllers/media/base.py*$', '^music_assistant/controllers/media/playlists.py*$', '^music_assistant/controllers/media/tracks.py*$', '^music_assistant/controllers/music.py$', -- 2.34.1