Typing fixes for the podcasts controller (#2627)
authorOzGav <gavnosp@hotmail.com>
Tue, 25 Nov 2025 10:34:03 +0000 (20:34 +1000)
committerGitHub <noreply@github.com>
Tue, 25 Nov 2025 10:34:03 +0000 (11:34 +0100)
* Typing fixes for the podcasts controller

* Fix overlooked typing issue

* Remove db asserts

* Remove ignore

music_assistant/controllers/media/podcasts.py
pyproject.toml

index e81bf0d9fb43c1b3de0b4da32373af4a40e3b1bc..eaba273e61bc7355ef47c7ba5909be9122ae2515 100644 (file)
@@ -6,8 +6,8 @@ from collections.abc import AsyncGenerator
 from typing import TYPE_CHECKING, Any
 
 from music_assistant_models.enums import MediaType, ProviderFeature
-from music_assistant_models.errors import InvalidDataError, MediaNotFoundError
-from music_assistant_models.media_items import Artist, Podcast, PodcastEpisode, UniqueList
+from music_assistant_models.errors import MediaNotFoundError, ProviderUnavailableError
+from music_assistant_models.media_items import Podcast, PodcastEpisode, UniqueList
 
 from music_assistant.constants import DB_TABLE_PLAYLOG, DB_TABLE_PODCASTS
 from music_assistant.controllers.media.base import MediaControllerBase
@@ -18,11 +18,12 @@ from music_assistant.helpers.compare import (
     loose_compare_strings,
 )
 from music_assistant.helpers.json import serialize_to_json
+from music_assistant.models.music_provider import MusicProvider
 
 if TYPE_CHECKING:
     from music_assistant_models.media_items import Track
 
-    from music_assistant.models.music_provider import MusicProvider
+    from music_assistant import MusicAssistant
 
 
 class PodcastsController(MediaControllerBase[Podcast]):
@@ -32,9 +33,9 @@ class PodcastsController(MediaControllerBase[Podcast]):
     media_type = MediaType.PODCAST
     item_cls = Podcast
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, mass: MusicAssistant) -> None:
         """Initialize class."""
-        super().__init__(*args, **kwargs)
+        super().__init__(mass)
         # register (extra) api handlers
         api_base = self.api_base
         self.mass.register_api_command(f"music/{api_base}/podcast_episodes", self.episodes)
@@ -51,9 +52,9 @@ class PodcastsController(MediaControllerBase[Podcast]):
         provider: str | None = None,
         extra_query: str | None = None,
         extra_query_params: dict[str, Any] | None = None,
-    ) -> list[Artist]:
+    ) -> list[Podcast]:
         """Get in-database podcasts."""
-        extra_query_params: dict[str, Any] = extra_query_params or {}
+        extra_query_params = extra_query_params or {}
         extra_query_parts: list[str] = [extra_query] if extra_query else []
         result = await self._get_library_items_by_query(
             favorite=favorite,
@@ -108,11 +109,11 @@ class PodcastsController(MediaControllerBase[Podcast]):
         self,
         item_id: str,
         provider_instance_id_or_domain: str,
-    ) -> UniqueList[PodcastEpisode]:
+    ) -> PodcastEpisode:
         """Return single podcast episode by the given provider podcast id."""
-        prov: MusicProvider = self.mass.get_provider(provider_instance_id_or_domain)
-        if not prov:
-            raise InvalidDataError("Provider not found")
+        prov = self.mass.get_provider(provider_instance_id_or_domain)
+        if not isinstance(prov, MusicProvider):
+            raise ProviderUnavailableError("Provider not found")
         return await prov.get_podcast_episode(item_id)
 
     async def versions(
@@ -126,7 +127,7 @@ class PodcastsController(MediaControllerBase[Podcast]):
         result: UniqueList[Podcast] = UniqueList()
         for provider_id in self.mass.music.get_unique_providers():
             provider = self.mass.get_provider(provider_id)
-            if not provider:
+            if not isinstance(provider, MusicProvider):
                 continue
             if not provider.library_supported(MediaType.PODCAST):
                 continue
@@ -139,11 +140,8 @@ class PodcastsController(MediaControllerBase[Podcast]):
             )
         return result
 
-    async def _add_library_item(self, item: Podcast) -> int:
+    async def _add_library_item(self, item: Podcast, overwrite_existing: bool = False) -> int:
         """Add a new record to the database."""
-        if not isinstance(item, Podcast):
-            msg = "Not a valid Podcast object (ItemMapping can not be added to db)"
-            raise InvalidDataError(msg)
         db_id = await self.mass.music.database.insert(
             self.db_table,
             {
@@ -156,7 +154,7 @@ class PodcastsController(MediaControllerBase[Podcast]):
                 "publisher": item.publisher,
                 "total_episodes": item.total_episodes or 0,
                 "search_name": create_safe_string(item.name, True, True),
-                "search_sort_name": create_safe_string(item.sort_name, True, True),
+                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
             },
         )
         # update/set provider_mappings table
@@ -188,7 +186,7 @@ class PodcastsController(MediaControllerBase[Podcast]):
                 "publisher": cur_item.publisher or update.publisher,
                 "total_episodes": cur_item.total_episodes or update.total_episodes or 0,
                 "search_name": create_safe_string(name, True, True),
-                "search_sort_name": create_safe_string(sort_name, True, True),
+                "search_sort_name": create_safe_string(sort_name or "", True, True),
             },
         )
         # update/set provider_mappings table
@@ -204,8 +202,8 @@ class PodcastsController(MediaControllerBase[Podcast]):
         self, item_id: str, provider_instance_id_or_domain: str
     ) -> AsyncGenerator[PodcastEpisode, None]:
         """Return podcast episodes for the given provider podcast id."""
-        prov: MusicProvider = self.mass.get_provider(provider_instance_id_or_domain)
-        if prov is None:
+        prov = self.mass.get_provider(provider_instance_id_or_domain)
+        if not isinstance(prov, MusicProvider):
             return
 
         async def set_resume_position(episode: PodcastEpisode) -> None:
@@ -254,7 +252,7 @@ class PodcastsController(MediaControllerBase[Podcast]):
         if db_podcast.provider != "library":
             return  # Matching only supported for database items
 
-        async def find_prov_match(provider: MusicProvider):
+        async def find_prov_match(provider: MusicProvider) -> bool:
             self.logger.debug(
                 "Trying to match podcast %s on provider %s",
                 db_podcast.name,
index 50fb0edcf326f4608126a6a68aab14dde3fc3512..f3b43eb59350ed71d38d660d606b3b3b060651f8 100644 (file)
@@ -132,7 +132,6 @@ enable_error_code = [
 exclude = [
   '^music_assistant/controllers/media/base.py*$',
   '^music_assistant/controllers/media/playlists.py*$',
-  '^music_assistant/controllers/media/podcasts.py*$',
   '^music_assistant/controllers/media/tracks.py*$',
   '^music_assistant/controllers/music.py$',
   '^music_assistant/helpers/app_vars.py',