Typing fixes for the base media controller (#2633)
authorOzGav <gavnosp@hotmail.com>
Tue, 25 Nov 2025 10:42:29 +0000 (20:42 +1000)
committerGitHub <noreply@github.com>
Tue, 25 Nov 2025 10:42:29 +0000 (11:42 +0100)
* Typing fixes for the base media controller

* Remove db asserts

* PR review comments

* remove unused function

---------

Co-authored-by: Marvin Schenkel <marvinschenkel@gmail.com>
music_assistant/controllers/media/base.py
pyproject.toml

index bd3a88ef54a92b3763592427468be4132fe300ac..a63e0875d8c031610181bdfa06848e7d71374972 100644 (file)
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
     from collections.abc import AsyncGenerator, Mapping
 
     from music_assistant import MusicAssistant
-    from music_assistant.models import MusicProvider
+    from music_assistant.models.music_provider import MusicProvider
 
 
 ItemCls = TypeVar("ItemCls", bound="MediaItemType")
@@ -133,20 +133,24 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
         )
         return library_item
 
-    async def _get_library_item_by_match(self, item: Track | ItemMapping) -> int | None:
+    async def _get_library_item_by_match(self, item: ItemCls | ItemMapping) -> int | None:
         if item.provider == "library":
             return int(item.item_id)
-        # search by provider mappings
+        # search by provider mappings if item is ItemMapping
         if isinstance(item, ItemMapping):
             if cur_item := await self.get_library_item_by_prov_id(item.item_id, item.provider):
-                return cur_item.item_id
-        elif cur_item := await self.get_library_item_by_prov_mappings(item.provider_mappings):
-            return cur_item.item_id
+                return int(cur_item.item_id)
+
+        # for all other items that are MediaItemType, check provider_mappings if it exists
+        provider_mappings = getattr(item, "provider_mappings", None)
+        if provider_mappings:
+            if cur_item := await self.get_library_item_by_prov_mappings(provider_mappings):
+                return int(cur_item.item_id)
         if cur_item := await self.get_library_item_by_external_ids(item.external_ids):
             # existing item match by external id
             # Double check external IDs - if MBID exists, regards that as overriding
             if compare_media_item(item, cur_item):
-                return cur_item.item_id
+                return int(cur_item.item_id)
         # search by (exact) name match
         query = f"{self.db_table}.name = :name OR {self.db_table}.sort_name = :sort_name"
         query_params = {"name": item.name, "sort_name": item.sort_name}
@@ -154,7 +158,7 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
             extra_query=query, extra_query_params=query_params
         ):
             if compare_media_item(db_item, item, True):
-                return db_item.item_id
+                return int(db_item.item_id)
         return None
 
     async def update_item_in_library(
@@ -281,6 +285,7 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
         ):
             # schedule a refresh of the metadata on access of the item
             # e.g. the item is being played or opened in the UI
+            assert library_item.uri is not None
             self.mass.metadata.schedule_update_metadata(library_item.uri)
             return library_item
         # grab full details from the provider
@@ -315,46 +320,22 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
         )
         match self.media_type:
             case MediaType.ARTIST:
-                return searchresult.artists
+                return cast("list[ItemCls]", searchresult.artists)
             case MediaType.ALBUM:
-                return searchresult.albums
+                return cast("list[ItemCls]", searchresult.albums)
             case MediaType.TRACK:
-                return searchresult.tracks
+                return cast("list[ItemCls]", searchresult.tracks)
             case MediaType.PLAYLIST:
-                return searchresult.playlists
+                return cast("list[ItemCls]", searchresult.playlists)
             case MediaType.AUDIOBOOK:
-                return searchresult.audiobooks
+                return cast("list[ItemCls]", searchresult.audiobooks)
             case MediaType.PODCAST:
-                return searchresult.podcasts
+                return cast("list[ItemCls]", searchresult.podcasts)
             case MediaType.RADIO:
-                return searchresult.radio
+                return cast("list[ItemCls]", searchresult.radio)
             case _:
                 return []
 
-    async def get_provider_mapping(self, item: ItemCls) -> tuple[str, str]:
-        """Return (first) provider and item id."""
-        if not getattr(item, "provider_mappings", None):
-            if item.provider == "library":
-                item = await self.get_library_item(item.item_id)
-            return (item.provider, item.item_id)
-        for prefer_unique in (True, False):
-            for prov_mapping in item.provider_mappings:
-                if not prov_mapping.available:
-                    continue
-                if provider := self.mass.get_provider(
-                    prov_mapping.provider_instance
-                    if prefer_unique
-                    else prov_mapping.provider_domain
-                ):
-                    if prefer_unique and provider.is_streaming_provider:
-                        continue
-                    return (prov_mapping.provider_instance, prov_mapping.item_id)
-        # last resort: return just the first entry
-        for prov_mapping in item.provider_mappings:
-            return (prov_mapping.provider_domain, prov_mapping.item_id)
-
-        return (None, None)
-
     async def get_library_item(self, item_id: int | str) -> ItemCls:
         """Get single library item by id."""
         db_id = int(item_id)  # ensure integer
@@ -498,13 +479,13 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
         library_item = await self.get_library_item(db_id)
         self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, library_item.uri, library_item)
 
-    @guard_single_request
+    @guard_single_request  # type: ignore[type-var]  # TODO: fix typing for MediaControllerBase
     async def get_provider_item(
         self,
         item_id: str,
         provider_instance_id_or_domain: str,
         force_refresh: bool = False,
-        fallback: ItemMapping | ItemCls = None,
+        fallback: ItemMapping | ItemCls | None = None,
     ) -> ItemCls:
         """Return item details for the given provider item id."""
         if provider_instance_id_or_domain == "library":
@@ -515,7 +496,7 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
             provider = cast("MusicProvider", provider)
             with suppress(MediaNotFoundError):
                 async with self.mass.cache.handle_refresh(force_refresh):
-                    return await provider.get_item(self.media_type, item_id)
+                    return cast("ItemCls", await provider.get_item(self.media_type, item_id))
         # if we reach this point all possibilities failed and the item could not be found.
         # There is a possibility that the (streaming) provider changed the id of the item
         # so we return the previous details (if we have any) marked as unavailable, so
@@ -530,22 +511,25 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
         ):
             # fallback is a ItemMapping, try to convert to full item
             with suppress(LookupError, TypeError, ValueError):
-                return self.item_cls.from_dict(
-                    {
-                        **fallback.to_dict(),
-                        "provider_mappings": [
-                            {
-                                "item_id": fallback.item_id,
-                                "provider_domain": fallback_provider.domain,
-                                "provider_instance": fallback_provider.instance_id,
-                                "available": fallback.available,
-                            }
-                        ],
-                    }
+                return cast(
+                    "ItemCls",
+                    self.item_cls.from_dict(
+                        {
+                            **fallback.to_dict(),
+                            "provider_mappings": [
+                                {
+                                    "item_id": fallback.item_id,
+                                    "provider_domain": fallback_provider.domain,
+                                    "provider_instance": fallback_provider.instance_id,
+                                    "available": fallback.available,
+                                }
+                            ],
+                        }
+                    ),
                 )
         if fallback:
             # simply return the fallback item
-            return fallback
+            return cast("ItemCls", fallback)
         # all options exhausted, we really can not find this item
         msg = (
             f"{self.media_type.value}://{item_id} not "
@@ -740,9 +724,8 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
 
         # build and execute final query
         sql_query = self._build_final_query(query_parts, join_parts, order_by)
-
         return [
-            self.item_cls.from_dict(self._parse_db_row(db_row))
+            cast("ItemCls", self.item_cls.from_dict(self._parse_db_row(db_row)))
             for db_row in await self.mass.music.database.get_rows_from_query(
                 sql_query, query_params, limit=limit, offset=offset
             )
@@ -853,7 +836,7 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
         return sql_query
 
     @staticmethod
-    def _parse_db_row(db_row: Mapping) -> dict[str, Any]:
+    def _parse_db_row(db_row: Mapping[str, Any]) -> dict[str, Any]:
         """Parse raw db Mapping into a dict."""
         db_row_dict = dict(db_row)
         db_row_dict["provider"] = "library"
index f3b43eb59350ed71d38d660d606b3b3b060651f8..a09cab8700d7428f66cd697b0ba164aec1ca8857 100644 (file)
@@ -130,7 +130,6 @@ enable_error_code = [
   "truthy-iterable",
 ]
 exclude = [
-  '^music_assistant/controllers/media/base.py*$',
   '^music_assistant/controllers/media/playlists.py*$',
   '^music_assistant/controllers/media/tracks.py*$',
   '^music_assistant/controllers/music.py$',