From: Fabian Munkes <105975993+fmunkes@users.noreply.github.com> Date: Tue, 24 Jun 2025 11:12:44 +0000 (+0200) Subject: Chore: Mypy for models (#2195) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=0f417a1da84aa7cdbdfe4b728fe0d9734e8cf98f;p=music-assistant-server.git Chore: Mypy for models (#2195) --- diff --git a/music_assistant/controllers/media/base.py b/music_assistant/controllers/media/base.py index 723b5cf6..f98b4a72 100644 --- a/music_assistant/controllers/media/base.py +++ b/music_assistant/controllers/media/base.py @@ -408,7 +408,7 @@ class MediaControllerBase(Generic[ItemCls], metaclass=ABCMeta): async def get_library_item_by_prov_mappings( self, - provider_mappings: list[ProviderMapping], + provider_mappings: Iterable[ProviderMapping], ) -> ItemCls | None: """Get the library item for the given provider_instance.""" # always prefer provider instance first diff --git a/music_assistant/controllers/music.py b/music_assistant/controllers/music.py index 31c8010e..e11b2f6f 100644 --- a/music_assistant/controllers/music.py +++ b/music_assistant/controllers/music.py @@ -1023,7 +1023,7 @@ class MusicController(CoreController): return self.podcasts if media_type == MediaType.PODCAST_EPISODE: return self.podcasts - return None + raise NotImplementedError def get_unique_providers(self) -> set[str]: """ diff --git a/music_assistant/models/__init__.py b/music_assistant/models/__init__.py index ce2bab18..468f8d71 100644 --- a/music_assistant/models/__init__.py +++ b/music_assistant/models/__init__.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from music_assistant_models.config_entries import ConfigEntry, ConfigValueType, ProviderConfig from music_assistant_models.provider import ProviderManifest - from music_assistant import MusicAssistant + from music_assistant.mass import MusicAssistant ProviderInstanceType = MetadataProvider | MusicProvider | PlayerProvider | PluginProvider @@ -27,6 +27,7 @@ class ProviderModuleType(Protocol): mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig ) -> ProviderInstanceType: """Initialize provider(instance) with given configuration.""" + raise NotImplementedError @staticmethod async def get_config_entries( @@ -42,3 +43,4 @@ class ProviderModuleType(Protocol): action: [optional] action key called from config entries UI. values: the (intermediate) raw values for config entries sent with the action. """ + raise NotImplementedError diff --git a/music_assistant/models/core_controller.py b/music_assistant/models/core_controller.py index 7f2532c3..efff1497 100644 --- a/music_assistant/models/core_controller.py +++ b/music_assistant/models/core_controller.py @@ -13,7 +13,7 @@ from music_assistant.constants import CONF_LOG_LEVEL, MASS_LOGGER_NAME if TYPE_CHECKING: from music_assistant_models.config_entries import ConfigEntry, ConfigValueType, CoreConfig - from music_assistant import MusicAssistant + from music_assistant.mass import MusicAssistant class CoreController: @@ -54,7 +54,7 @@ class CoreController: await self.close() if config is None: config = await self.mass.config.get_core_config(self.domain) - log_level = config.get_value(CONF_LOG_LEVEL) + log_level = str(config.get_value(CONF_LOG_LEVEL)) self._set_logger(log_level) await self.setup(config) @@ -63,8 +63,8 @@ class CoreController: mass_logger = logging.getLogger(MASS_LOGGER_NAME) self.logger = mass_logger.getChild(self.domain) if log_level is None: - log_level = self.mass.config.get_raw_core_config_value( - self.domain, CONF_LOG_LEVEL, "GLOBAL" + log_level = str( + self.mass.config.get_raw_core_config_value(self.domain, CONF_LOG_LEVEL, "GLOBAL") ) if log_level == "GLOBAL": self.logger.setLevel(mass_logger.level) diff --git a/music_assistant/models/metadata_provider.py b/music_assistant/models/metadata_provider.py index bf674961..bb16c06b 100644 --- a/music_assistant/models/metadata_provider.py +++ b/music_assistant/models/metadata_provider.py @@ -35,16 +35,19 @@ class MetadataProvider(Provider): """Retrieve metadata for an artist on this Metadata provider.""" if ProviderFeature.ARTIST_METADATA in self.supported_features: raise NotImplementedError + return None async def get_album_metadata(self, album: Album) -> MediaItemMetadata | None: """Retrieve metadata for an album on this Metadata provider.""" if ProviderFeature.ALBUM_METADATA in self.supported_features: raise NotImplementedError + return None async def get_track_metadata(self, track: Track) -> MediaItemMetadata | None: """Retrieve metadata for a track on this Metadata provider.""" if ProviderFeature.TRACK_METADATA in self.supported_features: raise NotImplementedError + return None async def resolve_image(self, path: str) -> str | bytes: """ diff --git a/music_assistant/models/music_provider.py b/music_assistant/models/music_provider.py index b293c94b..1a0129a1 100644 --- a/music_assistant/models/music_provider.py +++ b/music_assistant/models/music_provider.py @@ -37,6 +37,14 @@ if TYPE_CHECKING: from music_assistant_models.streamdetails import StreamDetails + from music_assistant.controllers.media.albums import AlbumsController + from music_assistant.controllers.media.artists import ArtistsController + from music_assistant.controllers.media.audiobooks import AudiobooksController + from music_assistant.controllers.media.playlists import PlaylistController + from music_assistant.controllers.media.podcasts import PodcastsController + from music_assistant.controllers.media.radio import RadioController + from music_assistant.controllers.media.tracks import TracksController + # ruff: noqa: ARG001, ARG002 @@ -89,37 +97,37 @@ class MusicProvider(Provider): async def get_library_artists(self) -> AsyncGenerator[Artist, None]: """Retrieve library artists from the provider.""" - yield + yield # type: ignore[misc] raise NotImplementedError async def get_library_albums(self) -> AsyncGenerator[Album, None]: """Retrieve library albums from the provider.""" - yield + yield # type: ignore[misc] raise NotImplementedError async def get_library_tracks(self) -> AsyncGenerator[Track, None]: """Retrieve library tracks from the provider.""" - yield + yield # type: ignore[misc] raise NotImplementedError async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]: """Retrieve library/subscribed playlists from the provider.""" - yield + yield # type: ignore[misc] raise NotImplementedError async def get_library_radios(self) -> AsyncGenerator[Radio, None]: """Retrieve library/subscribed radio stations from the provider.""" - yield + yield # type: ignore[misc] raise NotImplementedError async def get_library_audiobooks(self) -> AsyncGenerator[Audiobook, None]: """Retrieve library/subscribed audiobooks from the provider.""" - yield + yield # type: ignore[misc] raise NotImplementedError async def get_library_podcasts(self) -> AsyncGenerator[Podcast, None]: """Retrieve library/subscribed podcasts from the provider.""" - yield + yield # type: ignore[misc] raise NotImplementedError async def get_artist(self, prov_artist_id: str) -> Artist: @@ -127,77 +135,99 @@ class MusicProvider(Provider): raise NotImplementedError async def get_artist_albums(self, prov_artist_id: str) -> list[Album]: - """Get a list of all albums for the given artist.""" - if ProviderFeature.ARTIST_ALBUMS in self.supported_features: - raise NotImplementedError - return [] + """Get a list of all albums for the given artist. + + Only called if provider supports ProviderFeature.ARTIST_ALBUMS. + """ + raise NotImplementedError async def get_artist_toptracks(self, prov_artist_id: str) -> list[Track]: - """Get a list of most popular tracks for the given artist.""" - if ProviderFeature.ARTIST_TOPTRACKS in self.supported_features: - raise NotImplementedError - return [] + """Get a list of most popular tracks for the given artist. - async def get_album(self, prov_album_id: str) -> Album: # type: ignore[return] - """Get full album details by id.""" - if ProviderFeature.LIBRARY_ALBUMS in self.supported_features: - raise NotImplementedError + Only called if provider supports ProviderFeature.ARTIST_TOPTRACKS. + """ + raise NotImplementedError - async def get_track(self, prov_track_id: str) -> Track: # type: ignore[return] - """Get full track details by id.""" - if ProviderFeature.LIBRARY_TRACKS in self.supported_features: - raise NotImplementedError + async def get_album(self, prov_album_id: str) -> Album: + """Get full album details by id. - async def get_playlist(self, prov_playlist_id: str) -> Playlist: # type: ignore[return] - """Get full playlist details by id.""" - if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features: - raise NotImplementedError + Only called if provider supports ProviderFeature.LIBRARY_ALBUMS. + """ + raise NotImplementedError - async def get_radio(self, prov_radio_id: str) -> Radio: # type: ignore[return] - """Get full radio details by id.""" - if ProviderFeature.LIBRARY_RADIOS in self.supported_features: - raise NotImplementedError + async def get_track(self, prov_track_id: str) -> Track: + """Get full track details by id. - async def get_audiobook(self, prov_audiobook_id: str) -> Audiobook: # type: ignore[return] - """Get full audiobook details by id.""" - if ProviderFeature.LIBRARY_AUDIOBOOKS in self.supported_features: - raise NotImplementedError + Only called if provider supports ProviderFeature.LIBRARY_TRACKS. + """ + raise NotImplementedError - async def get_podcast(self, prov_podcast_id: str) -> Podcast: # type: ignore[return] - """Get full audiobook details by id.""" - if ProviderFeature.LIBRARY_PODCASTS in self.supported_features: - raise NotImplementedError + async def get_playlist(self, prov_playlist_id: str) -> Playlist: + """Get full playlist details by id. + + Only called if provider supports ProviderFeature.LIBRARY_PLAYLISTS. + """ + raise NotImplementedError + + async def get_radio(self, prov_radio_id: str) -> Radio: + """Get full radio details by id. + + Only called if provider supports ProviderFeature.LIBRARY_RADIOS. + """ + raise NotImplementedError + + async def get_audiobook(self, prov_audiobook_id: str) -> Audiobook: + """Get full audiobook details by id. + + Only called if provider supports ProviderFeature.LIBRARY_AUDIOBOOKS. + """ + raise NotImplementedError + + async def get_podcast(self, prov_podcast_id: str) -> Podcast: + """Get full podcast details by id. + + Only called if provider supports ProviderFeature.LIBRARY_PODCASTS. + """ + raise NotImplementedError async def get_podcast_episode(self, prov_episode_id: str) -> PodcastEpisode: - """Get (full) podcast episode details by id.""" - if ProviderFeature.LIBRARY_PODCASTS in self.supported_features: - raise NotImplementedError + """Get (full) podcast episode details by id. + + Only called if provider supports ProviderFeature.LIBRARY_PODCASTS. + """ + raise NotImplementedError async def get_album_tracks( self, - prov_album_id: str, # type: ignore[return] + prov_album_id: str, ) -> list[Track]: - """Get album tracks for given album id.""" - if ProviderFeature.LIBRARY_ALBUMS in self.supported_features: - raise NotImplementedError + """Get album tracks for given album id. + + Only called if provider supports ProviderFeature.LIBRARY_ALBUMS. + """ + raise NotImplementedError async def get_playlist_tracks( self, prov_playlist_id: str, page: int = 0, ) -> list[Track]: - """Get all playlist tracks for given playlist id.""" - if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features: - raise NotImplementedError + """Get all playlist tracks for given playlist id. + + Only called if provider supports ProviderFeature.LIBRARY_PLAYLISTS. + """ + raise NotImplementedError async def get_podcast_episodes( self, prov_podcast_id: str, ) -> AsyncGenerator[PodcastEpisode, None]: - """Get all PodcastEpisodes for given podcast id.""" - yield - if ProviderFeature.LIBRARY_PODCASTS in self.supported_features: - raise NotImplementedError + """Get all PodcastEpisodes for given podcast id. + + Only called if provider supports ProviderFeature.LIBRARY_PODCASTS. + """ + yield # type: ignore[misc] + raise NotImplementedError async def library_add(self, item: MediaItemType) -> bool: """Add item to provider's library. Return true on success.""" @@ -288,28 +318,34 @@ class MusicProvider(Provider): return True async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]) -> None: - """Add track(s) to playlist.""" - if ProviderFeature.PLAYLIST_TRACKS_EDIT in self.supported_features: - raise NotImplementedError + """Add track(s) to playlist. + + Only called if provider supports ProviderFeature.PLAYLIST_TRACKS_EDIT. + """ + raise NotImplementedError async def remove_playlist_tracks( self, prov_playlist_id: str, positions_to_remove: tuple[int, ...] ) -> None: - """Remove track(s) from playlist.""" - if ProviderFeature.PLAYLIST_TRACKS_EDIT in self.supported_features: - raise NotImplementedError + """Remove track(s) from playlist. - async def create_playlist(self, name: str) -> Playlist: # type: ignore[return] - """Create a new playlist on provider with given name.""" - if ProviderFeature.PLAYLIST_CREATE in self.supported_features: - raise NotImplementedError + Only called if provider supports ProviderFeature.PLAYLIST_TRACKS_EDIT. + """ + raise NotImplementedError - async def get_similar_tracks( # type: ignore[return] - self, prov_track_id: str, limit: int = 25 - ) -> list[Track]: - """Retrieve a dynamic list of similar tracks based on the provided track.""" - if ProviderFeature.SIMILAR_TRACKS in self.supported_features: - raise NotImplementedError + async def create_playlist(self, name: str) -> Playlist: + """Create a new playlist on provider with given name. + + Only called if provider supports ProviderFeature.PLAYLIST_CREATE. + """ + raise NotImplementedError + + async def get_similar_tracks(self, prov_track_id: str, limit: int = 25) -> list[Track]: + """Retrieve a dynamic list of similar tracks based on the provided track. + + Only called if provider supports ProviderFeature.SIMILAR_TRACKS. + """ + raise NotImplementedError async def get_resume_position(self, item_id: str, media_type: MediaType) -> tuple[bool, int]: """ @@ -330,7 +366,7 @@ class MusicProvider(Provider): """Get streamdetails for a track/radio/chapter/episode.""" raise NotImplementedError - async def get_audio_stream( # type: ignore[return] + async def get_audio_stream( self, streamdetails: StreamDetails, seek_position: int = 0 ) -> AsyncGenerator[bytes, None]: """ @@ -338,8 +374,7 @@ class MusicProvider(Provider): Will only be called when the stream_type is set to CUSTOM. """ - if False: - yield + yield b"" raise NotImplementedError async def on_streamed( @@ -412,7 +447,7 @@ class MusicProvider(Provider): return await self.get_podcast_episode(prov_item_id) return await self.get_track(prov_item_id) - async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | BrowseFolder]: # noqa: PLR0911, PLR0915 + async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | BrowseFolder]: # noqa: PLR0911 """Browse this provider's items. :param path: The path to browse, (e.g. provider_id://artists). @@ -622,8 +657,56 @@ class MusicProvider(Provider): async def sync_library(self, media_type: MediaType) -> None: """Run library sync for this provider.""" + # ruff: noqa: PLR0915 # too many statements # this reference implementation can be overridden # with a provider specific approach if needed + + async def _controller_update_item_in_library( + controller: ArtistsController + | AlbumsController + | TracksController + | RadioController + | PlaylistController + | AudiobooksController + | PodcastsController, + prov_item: MediaItemType, + item_id: str | int, + ) -> Artist | Album | Track | Radio | Playlist | Audiobook | Podcast: + """Update media item in controller including type checking. + + all isinstance(...) for type checking. The statement + library_item = await controller.update_item_in_library(prov_item) + cannot be moved out of this scope. + """ + library_item: Artist | Album | Track | Radio | Playlist | Audiobook | Podcast + if TYPE_CHECKING: + if isinstance(prov_item, Artist): + assert isinstance(controller, ArtistsController) + library_item = await controller.update_item_in_library(item_id, prov_item) + elif isinstance(prov_item, Album): + assert isinstance(controller, AlbumsController) + library_item = await controller.update_item_in_library(item_id, prov_item) + elif isinstance(prov_item, Track): + assert isinstance(controller, TracksController) + library_item = await controller.update_item_in_library(item_id, prov_item) + elif isinstance(prov_item, Radio): + assert isinstance(controller, RadioController) + library_item = await controller.update_item_in_library(item_id, prov_item) + elif isinstance(prov_item, Playlist): + assert isinstance(controller, PlaylistController) + library_item = await controller.update_item_in_library(item_id, prov_item) + elif isinstance(prov_item, Audiobook): + assert isinstance(controller, AudiobooksController) + library_item = await controller.update_item_in_library(item_id, prov_item) + elif isinstance(prov_item, Podcast): + assert isinstance(controller, PodcastsController) + library_item = await controller.update_item_in_library(item_id, prov_item) + else: + raise TypeError("Prov item unknown in this context.") + return library_item + else: + return await controller.update_item_in_library(item_id, prov_item) + if not self.library_supported(media_type): raise UnsupportedFeaturedException("Library sync not supported for this media type") self.logger.debug("Start sync of %s items.", media_type.value) @@ -633,6 +716,7 @@ class MusicProvider(Provider): library_item = await controller.get_library_item_by_prov_mappings( prov_item.provider_mappings, ) + assert not isinstance(prov_item, PodcastEpisode) try: if not library_item and not prov_item.available: # skip unavailable tracks @@ -647,18 +731,50 @@ class MusicProvider(Provider): # the additional metadata is then lazy retrieved afterwards if self.is_streaming_provider: prov_item.favorite = True - library_item = await controller.add_item_to_library(prov_item) + + # all isinstance(...) for type checking. The statement + # library_item = await controller.add_item_to_library(prov_item) + # cannot be moved out of this scope. + if TYPE_CHECKING: + if isinstance(prov_item, Artist): + assert isinstance(controller, ArtistsController) + library_item = await controller.add_item_to_library(prov_item) + elif isinstance(prov_item, Album): + assert isinstance(controller, AlbumsController) + library_item = await controller.add_item_to_library(prov_item) + elif isinstance(prov_item, Track): + assert isinstance(controller, TracksController) + library_item = await controller.add_item_to_library(prov_item) + elif isinstance(prov_item, Radio): + assert isinstance(controller, RadioController) + library_item = await controller.add_item_to_library(prov_item) + elif isinstance(prov_item, Playlist): + assert isinstance(controller, PlaylistController) + library_item = await controller.add_item_to_library(prov_item) + elif isinstance(prov_item, Audiobook): + assert isinstance(controller, AudiobooksController) + library_item = await controller.add_item_to_library(prov_item) + elif isinstance(prov_item, Podcast): + assert isinstance(controller, PodcastsController) + library_item = await controller.add_item_to_library(prov_item) + else: + raise RuntimeError + else: + library_item = await controller.add_item_to_library(prov_item) elif getattr(library_item, "cache_checksum", None) != getattr( prov_item, "cache_checksum", None ): # existing dbitem checksum changed (playlists only) + if TYPE_CHECKING: + assert isinstance(prov_item, Playlist) + assert isinstance(controller, PlaylistController) library_item = await controller.update_item_in_library( library_item.item_id, prov_item ) if library_item.available != prov_item.available: # existing item availability changed - library_item = await controller.update_item_in_library( - library_item.item_id, prov_item + library_item = await _controller_update_item_in_library( + controller, prov_item, library_item.item_id ) # check if resume_position_ms or fully_played changed (audiobook only) resume_pos_prov = getattr(prov_item, "resume_position_ms", None) @@ -671,8 +787,8 @@ class MusicProvider(Provider): or getattr(library_item, "fully_played", None) != fully_played_prov ) ): - library_item = await controller.update_item_in_library( - library_item.item_id, prov_item + library_item = await _controller_update_item_in_library( + controller, prov_item, library_item.item_id ) cur_db_ids.add(int(library_item.item_id)) diff --git a/music_assistant/models/plugin.py b/music_assistant/models/plugin.py index e8d77461..26bae5b1 100644 --- a/music_assistant/models/plugin.py +++ b/music_assistant/models/plugin.py @@ -4,14 +4,17 @@ from __future__ import annotations from collections.abc import AsyncGenerator from dataclasses import dataclass, field +from typing import TYPE_CHECKING from mashumaro import field_options, pass_through from music_assistant_models.enums import StreamType from music_assistant_models.player import PlayerMedia, PlayerSource -from music_assistant_models.streamdetails import AudioFormat # noqa: TC002 from .provider import Provider +if TYPE_CHECKING: + from music_assistant_models.media_items.audio_format import AudioFormat + # ruff: noqa: ARG001, ARG002 @@ -72,7 +75,7 @@ class PluginProvider(Provider): Plugin Provider implementations should inherit from this base model. """ - def get_source(self) -> PluginSource: # type: ignore[return] + def get_source(self) -> PluginSource: """Get (audio)source details for this plugin.""" # Will only be called if ProviderFeature.AUDIO_SOURCE is declared raise NotImplementedError @@ -86,6 +89,5 @@ class PluginProvider(Provider): The player_id is the id of the player that is requesting the stream. """ - if False: - yield b"" + yield b"" raise NotImplementedError diff --git a/music_assistant/models/provider.py b/music_assistant/models/provider.py index 3d579dad..6dd596ef 100644 --- a/music_assistant/models/provider.py +++ b/music_assistant/models/provider.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from zeroconf import ServiceStateChange from zeroconf.asyncio import AsyncServiceInfo - from music_assistant import MusicAssistant + from music_assistant.mass import MusicAssistant class Provider: @@ -29,7 +29,7 @@ class Provider: self.config = config mass_logger = logging.getLogger(MASS_LOGGER_NAME) self.logger = mass_logger.getChild(self.domain) - log_level = config.get_value(CONF_LOG_LEVEL) + log_level = str(config.get_value(CONF_LOG_LEVEL)) if log_level == "GLOBAL": self.logger.setLevel(mass_logger.level) else: @@ -44,7 +44,7 @@ class Provider: @property def supported_features(self) -> set[ProviderFeature]: """Return the features supported by this Provider.""" - return () + return set() @property def lookup_key(self) -> str: @@ -129,7 +129,7 @@ class Provider: """Unload provider with error message.""" self.mass.call_later(1, self.mass.unload_provider, self.instance_id, error) - def to_dict(self, *args, **kwargs) -> dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Return Provider(instance) as serializable dict.""" return { "type": self.type.value, diff --git a/pyproject.toml b/pyproject.toml index f9e3e8e7..4618f7dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,7 +125,7 @@ exclude = [ '^music_assistant/controllers/.*$', '^music_assistant/helpers/app_vars.py', '^music_assistant/helpers/webserver.py', - '^music_assistant/models/.*$', + '^music_assistant/models/player_provider.py', '^music_assistant/providers/apple_music/.*$', '^music_assistant/providers/bluesound/.*$', '^music_assistant/providers/chromecast/.*$',