From 9702155da07bdf1d3bb528753a67fd1b44e7a0b6 Mon Sep 17 00:00:00 2001 From: Jc2k Date: Mon, 8 Jul 2024 23:39:01 +0100 Subject: [PATCH] Add typing for Plex (#1479) --- .../server/providers/plex/__init__.py | 307 +++++++++++------- .../server/providers/plex/helpers.py | 20 +- mypy.ini | 2 +- 3 files changed, 195 insertions(+), 134 deletions(-) diff --git a/music_assistant/server/providers/plex/__init__.py b/music_assistant/server/providers/plex/__init__.py index d82e9f57..b9a0ad51 100644 --- a/music_assistant/server/providers/plex/__init__.py +++ b/music_assistant/server/providers/plex/__init__.py @@ -4,9 +4,10 @@ from __future__ import annotations import asyncio import logging -from asyncio import TaskGroup +from asyncio import Task, TaskGroup +from collections.abc import Awaitable from contextlib import suppress -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast import plexapi.exceptions import requests @@ -14,6 +15,7 @@ from plexapi.audio import Album as PlexAlbum from plexapi.audio import Artist as PlexArtist from plexapi.audio import Playlist as PlexPlaylist from plexapi.audio import Track as PlexTrack +from plexapi.base import PlexObject from plexapi.myplex import MyPlexAccount, MyPlexPinLogin from plexapi.server import PlexServer @@ -50,6 +52,7 @@ from music_assistant.common.models.media_items import ( ProviderMapping, SearchResults, Track, + UniqueList, ) from music_assistant.common.models.streamdetails import StreamDetails from music_assistant.constants import UNKNOWN_ARTIST @@ -118,11 +121,13 @@ async def get_config_entries( # noqa: PLR0915 if action == CONF_ACTION_GDM: server_details = await discover_local_servers() if server_details and server_details[0] and server_details[1]: + assert values values[CONF_LOCAL_SERVER_IP] = server_details[0] values[CONF_LOCAL_SERVER_PORT] = server_details[1] values[CONF_LOCAL_SERVER_SSL] = False values[CONF_LOCAL_SERVER_VERIFY_CERT] = False else: + assert values values[CONF_LOCAL_SERVER_IP] = "Discovery failed, please add IP manually" values[CONF_LOCAL_SERVER_PORT] = 32400 values[CONF_LOCAL_SERVER_SSL] = False @@ -130,6 +135,7 @@ async def get_config_entries( # noqa: PLR0915 # handle action clear authentication if action == CONF_ACTION_CLEAR_AUTH: + assert values values[CONF_AUTH_TOKEN] = None values[CONF_LOCAL_SERVER_IP] = None values[CONF_LOCAL_SERVER_PORT] = 32400 @@ -138,8 +144,9 @@ async def get_config_entries( # noqa: PLR0915 # handle action MyPlex auth if action == CONF_ACTION_AUTH_MYPLEX: + assert values values[CONF_AUTH_TOKEN] = None - async with AuthenticationHelper(mass, values["session_id"]) as auth_helper: + async with AuthenticationHelper(mass, str(values["session_id"])) as auth_helper: plex_auth = MyPlexPinLogin(headers={"X-Plex-Product": "Music Assistant"}, oauth=True) auth_url = plex_auth.oauthUrl(auth_helper.callback_url) await auth_helper.authenticate(auth_url) @@ -151,6 +158,7 @@ async def get_config_entries( # noqa: PLR0915 # handle action Local auth (no MyPlex) if action == CONF_ACTION_AUTH_LOCAL: + assert values values[CONF_AUTH_TOKEN] = AUTH_TOKEN_UNAUTH # collect all config entries to show @@ -231,11 +239,11 @@ async def get_config_entries( # noqa: PLR0915 action_label="Select Plex Music Library", ) if action in (CONF_ACTION_LIBRARY, CONF_ACTION_AUTH_MYPLEX, CONF_ACTION_AUTH_LOCAL): - token = mass.config.decrypt_string(values.get(CONF_AUTH_TOKEN)) - server_http_ip = values.get(CONF_LOCAL_SERVER_IP) - server_http_port = values.get(CONF_LOCAL_SERVER_PORT) - server_http_ssl = values.get(CONF_LOCAL_SERVER_SSL) - server_http_verify_cert = values.get(CONF_LOCAL_SERVER_VERIFY_CERT) + token = mass.config.decrypt_string(str(values.get(CONF_AUTH_TOKEN))) + server_http_ip = str(values.get(CONF_LOCAL_SERVER_IP)) + server_http_port = str(values.get(CONF_LOCAL_SERVER_PORT)) + server_http_ssl = bool(values.get(CONF_LOCAL_SERVER_SSL)) + server_http_verify_cert = bool(values.get(CONF_LOCAL_SERVER_VERIFY_CERT)) if not ( libraries := await get_libraries( mass, @@ -301,19 +309,25 @@ async def get_config_entries( # noqa: PLR0915 return tuple(entries) +Param = ParamSpec("Param") +RetType = TypeVar("RetType") +PlexObjectT = TypeVar("PlexObjectT", bound=PlexObject) +MediaItemT = TypeVar("MediaItemT", bound=MediaItem) + + class PlexProvider(MusicProvider): """Provider for a plex music library.""" _plex_server: PlexServer = None _plex_library: PlexMusicSection = None _myplex_account: MyPlexAccount = None - _baseurl: str = None + _baseurl: str async def handle_async_init(self) -> None: """Set up the music provider by connecting to the server.""" # silence loggers logging.getLogger("plexapi").setLevel(self.logger.level + 10) - _, library_name = self.config.get_value(CONF_LIBRARY_ID).split(" / ", 1) + _, library_name = str(self.config.get_value(CONF_LIBRARY_ID)).split(" / ", 1) def connect() -> PlexServer: try: @@ -346,14 +360,19 @@ class PlexProvider(MusicProvider): except plexapi.exceptions.BadRequest as err: if "Invalid token" in str(err): # token invalid, invalidate the config - self.mass.config.remove_provider_config_value(self.instance_id, CONF_AUTH_TOKEN) + self.mass.call_later( + 0, + self.mass.config.remove_provider_config_value( + self.instance_id, CONF_AUTH_TOKEN + ), + ) msg = "Authentication failed" raise LoginFailed(msg) raise LoginFailed from err return plex_server self._myplex_account = await self.get_myplex_account_and_refresh_token( - self.config.get_value(CONF_AUTH_TOKEN) + str(self.config.get_value(CONF_AUTH_TOKEN)) ) try: self._plex_server = await self._run_async(connect) @@ -393,14 +412,17 @@ class PlexProvider(MusicProvider): async def resolve_image(self, path: str) -> str | bytes: """Return the full image URL including the auth token.""" - return self._plex_server.url(path, True) + return str(self._plex_server.url(path, True)) - async def _run_async(self, call: Callable, *args, **kwargs): - await self.get_myplex_account_and_refresh_token(self.config.get_value(CONF_AUTH_TOKEN)) + async def _run_async( + self, call: Callable[Param, RetType], *args: Param.args, **kwargs: Param.kwargs + ) -> RetType: + await self.get_myplex_account_and_refresh_token(str(self.config.get_value(CONF_AUTH_TOKEN))) return await asyncio.to_thread(call, *args, **kwargs) - async def _get_data(self, key, cls=None): - return await self._run_async(self._plex_library.fetchItem, key, cls) + async def _get_data(self, key: str, cls: type[PlexObjectT]) -> PlexObjectT: + results = await self._run_async(self._plex_library.fetchItem, key, cls) + return cast(PlexObjectT, results) def _get_item_mapping(self, media_type: MediaType, key: str, name: str) -> ItemMapping: name, version = parse_title_and_version(name) @@ -416,7 +438,7 @@ class PlexProvider(MusicProvider): version=version, ) - async def _get_or_create_artist_by_name(self, artist_name) -> Artist: + async def _get_or_create_artist_by_name(self, artist_name: str) -> Artist | ItemMapping: subquery = ( "WHERE provider_mappings.media_type = 'artist' " "AND provider_mappings.provider_instance = :provider_instance" @@ -445,7 +467,7 @@ class PlexProvider(MusicProvider): }, ) - async def _parse(self, plex_media) -> MediaItem | None: + async def _parse(self, plex_media: PlexObject) -> MediaItem | None: if plex_media.type == "artist": return await self._parse_artist(plex_media) elif plex_media.type == "album": @@ -456,45 +478,67 @@ class PlexProvider(MusicProvider): return await self._parse_playlist(plex_media) return None - async def _search_track(self, search_query, limit) -> list[PlexTrack]: - return await self._run_async( - self._plex_library.searchTracks, title=search_query, limit=limit + async def _search_track(self, search_query: str | None, limit: int) -> list[PlexTrack]: + return cast( + list[PlexTrack], + await self._run_async(self._plex_library.searchTracks, title=search_query, limit=limit), ) - async def _search_album(self, search_query, limit) -> list[PlexAlbum]: - return await self._run_async( - self._plex_library.searchAlbums, title=search_query, limit=limit + async def _search_album(self, search_query: str, limit: int) -> list[PlexAlbum]: + return cast( + list[PlexAlbum], + await self._run_async(self._plex_library.searchAlbums, title=search_query, limit=limit), ) - async def _search_artist(self, search_query, limit) -> list[PlexArtist]: - return await self._run_async( - self._plex_library.searchArtists, title=search_query, limit=limit + async def _search_artist(self, search_query: str, limit: int) -> list[PlexArtist]: + return cast( + list[PlexArtist], + await self._run_async( + self._plex_library.searchArtists, title=search_query, limit=limit + ), ) - async def _search_playlist(self, search_query, limit) -> list[PlexPlaylist]: - return await self._run_async(self._plex_library.playlists, title=search_query, limit=limit) + async def _search_playlist(self, search_query: str, limit: int) -> list[PlexPlaylist]: + return cast( + list[PlexPlaylist], + await self._run_async(self._plex_library.playlists, title=search_query, limit=limit), + ) - async def _search_track_advanced(self, limit, **kwargs) -> list[PlexTrack]: - return await self._run_async(self._plex_library.searchTracks, filters=kwargs, limit=limit) + async def _search_track_advanced(self, limit: int, **kwargs: Any) -> list[PlexTrack]: + return cast( + list[PlexPlaylist], + await self._run_async(self._plex_library.searchTracks, filters=kwargs, limit=limit), + ) - async def _search_album_advanced(self, limit, **kwargs) -> list[PlexAlbum]: - return await self._run_async(self._plex_library.searchAlbums, filters=kwargs, limit=limit) + async def _search_album_advanced(self, limit: int, **kwargs: Any) -> list[PlexAlbum]: + return cast( + list[PlexPlaylist], + await self._run_async(self._plex_library.searchAlbums, filters=kwargs, limit=limit), + ) - async def _search_artist_advanced(self, limit, **kwargs) -> list[PlexArtist]: - return await self._run_async(self._plex_library.searchArtists, filters=kwargs, limit=limit) + async def _search_artist_advanced(self, limit: int, **kwargs: Any) -> list[PlexArtist]: + return cast( + list[PlexPlaylist], + await self._run_async(self._plex_library.searchArtists, filters=kwargs, limit=limit), + ) - async def _search_playlist_advanced(self, limit, **kwargs) -> list[PlexPlaylist]: - return await self._run_async(self._plex_library.playlists, filters=kwargs, limit=limit) + async def _search_playlist_advanced(self, limit: int, **kwargs: Any) -> list[PlexPlaylist]: + return cast( + list[PlexPlaylist], + await self._run_async(self._plex_library.playlists, filters=kwargs, limit=limit), + ) async def _search_and_parse( - self, search_coro: Coroutine, parse_coro: Callable - ) -> list[MediaItem]: - task_results = [] + self, + search_coro: Awaitable[list[PlexObjectT]], + parse_coro: Callable[[PlexObjectT], Coroutine[Any, Any, MediaItemT]], + ) -> list[MediaItemT]: + task_results: list[Task[MediaItemT]] = [] async with TaskGroup() as tg: for item in await search_coro: task_results.append(tg.create_task(parse_coro(item))) - results = [] + results: list[MediaItemT] = [] for task in task_results: results.append(task.result()) @@ -526,14 +570,16 @@ class PlexProvider(MusicProvider): if plex_album.year: album.year = plex_album.year if thumb := plex_album.firstAttr("thumb", "parentThumb", "grandparentThumb"): - album.metadata.images = [ - MediaItemImage( - type=ImageType.THUMB, - path=thumb, - provider=self.instance_id, - remotely_accessible=False, - ) - ] + album.metadata.images = UniqueList( + [ + MediaItemImage( + type=ImageType.THUMB, + path=thumb, + provider=self.instance_id, + remotely_accessible=False, + ) + ] + ) if plex_album.summary: album.metadata.description = plex_album.summary @@ -568,14 +614,16 @@ class PlexProvider(MusicProvider): if plex_artist.summary: artist.metadata.description = plex_artist.summary if thumb := plex_artist.firstAttr("thumb", "parentThumb", "grandparentThumb"): - artist.metadata.images = [ - MediaItemImage( - type=ImageType.THUMB, - path=thumb, - provider=self.instance_id, - remotely_accessible=False, - ) - ] + artist.metadata.images = UniqueList( + [ + MediaItemImage( + type=ImageType.THUMB, + path=thumb, + provider=self.instance_id, + remotely_accessible=False, + ) + ] + ) return artist async def _parse_playlist(self, plex_playlist: PlexPlaylist) -> Playlist: @@ -596,14 +644,16 @@ class PlexProvider(MusicProvider): if plex_playlist.summary: playlist.metadata.description = plex_playlist.summary if thumb := plex_playlist.firstAttr("thumb", "parentThumb", "grandparentThumb"): - playlist.metadata.images = [ - MediaItemImage( - type=ImageType.THUMB, - path=thumb, - provider=self.instance_id, - remotely_accessible=False, - ) - ] + playlist.metadata.images = UniqueList( + [ + MediaItemImage( + type=ImageType.THUMB, + path=thumb, + provider=self.instance_id, + remotely_accessible=False, + ) + ] + ) playlist.is_editable = not plex_playlist.smart playlist.cache_checksum = str(plex_playlist.updatedAt.timestamp()) @@ -665,14 +715,16 @@ class PlexProvider(MusicProvider): raise InvalidDataError(msg) if thumb := plex_track.firstAttr("thumb", "parentThumb", "grandparentThumb"): - track.metadata.images = [ - MediaItemImage( - type=ImageType.THUMB, - path=thumb, - provider=self.instance_id, - remotely_accessible=False, - ) - ] + track.metadata.images = UniqueList( + [ + MediaItemImage( + type=ImageType.THUMB, + path=thumb, + provider=self.instance_id, + remotely_accessible=False, + ) + ] + ) if plex_track.parentKey: track.album = self._get_item_mapping( MediaType.ALBUM, plex_track.parentKey, plex_track.parentTitle @@ -680,15 +732,17 @@ class PlexProvider(MusicProvider): if plex_track.duration: track.duration = int(plex_track.duration / 1000) if plex_track.chapters: - track.metadata.chapters = [ - MediaItemChapter( - chapter_id=plex_chapter.id, - position_start=plex_chapter.start, - position_end=plex_chapter.end, - title=plex_chapter.title, - ) - for plex_chapter in plex_track.chapters - ] + track.metadata.chapters = UniqueList( + [ + MediaItemChapter( + chapter_id=plex_chapter.id, + position_start=plex_chapter.start, + position_end=plex_chapter.end, + title=plex_chapter.title, + ) + for plex_chapter in plex_track.chapters + ] + ) return track @@ -704,47 +758,54 @@ class PlexProvider(MusicProvider): :param media_types: A list of media_types to include. :param limit: Number of items to return in the search (per type). """ - tasks = {} + artists = None + albums = None + tracks = None + playlists = None async with TaskGroup() as tg: - for media_type in media_types: - if media_type == MediaType.ARTIST: - tasks[MediaType.ARTIST] = tg.create_task( - self._search_and_parse( - self._search_artist(search_query, limit), self._parse_artist - ) + if MediaType.ARTIST in media_types: + artists = tg.create_task( + self._search_and_parse( + self._search_artist(search_query, limit), self._parse_artist ) - elif media_type == MediaType.ALBUM: - tasks[MediaType.ARTIST] = tg.create_task( - self._search_and_parse( - self._search_album(search_query, limit), self._parse_album - ) + ) + + if MediaType.ALBUM in media_types: + albums = tg.create_task( + self._search_and_parse( + self._search_album(search_query, limit), self._parse_album ) - elif media_type == MediaType.TRACK: - tasks[MediaType.ARTIST] = tg.create_task( - self._search_and_parse( - self._search_track(search_query, limit), self._parse_track - ) + ) + + if MediaType.TRACK in media_types: + tracks = tg.create_task( + self._search_and_parse( + self._search_track(search_query, limit), self._parse_track ) - elif media_type == MediaType.PLAYLIST: - tasks[MediaType.ARTIST] = tg.create_task( - self._search_and_parse( - self._search_playlist(search_query, limit), - self._parse_playlist, - ) + ) + + if MediaType.PLAYLIST in media_types: + playlists = tg.create_task( + self._search_and_parse( + self._search_playlist(search_query, limit), + self._parse_playlist, ) + ) search_results = SearchResults() - for media_type, task in tasks.items(): - if media_type == MediaType.ARTIST: - search_results.artists = task.result() - elif media_type == MediaType.ALBUM: - search_results.albums = task.result() - elif media_type == MediaType.TRACK: - search_results.tracks = task.result() - elif media_type == MediaType.PLAYLIST: - search_results.playlists = task.result() + if artists: + search_results.artists = artists.result() + + if albums: + search_results.albums = albums.result() + + if tracks: + search_results.tracks = tracks.result() + + if playlists: + search_results.playlists = playlists.result() return search_results @@ -772,7 +833,7 @@ class PlexProvider(MusicProvider): for track in tracks_obj: yield await self._parse_track(track) - async def get_album(self, prov_album_id) -> Album: + async def get_album(self, prov_album_id: str) -> Album: """Get full album details by id.""" if plex_album := await self._get_data(prov_album_id, PlexAlbum): return await self._parse_album(plex_album) @@ -790,7 +851,7 @@ class PlexProvider(MusicProvider): tracks.append(track) return tracks - async def get_artist(self, prov_artist_id) -> Artist: + async def get_artist(self, prov_artist_id: str) -> Artist: """Get full artist details by id.""" if prov_artist_id.startswith(FAKE_ARTIST_PREFIX): # This artist does not exist in plex, so we can just load it from DB. @@ -807,14 +868,14 @@ class PlexProvider(MusicProvider): msg = f"Item {prov_artist_id} not found" raise MediaNotFoundError(msg) - async def get_track(self, prov_track_id) -> Track: + async def get_track(self, prov_track_id: str) -> Track: """Get full track details by id.""" if plex_track := await self._get_data(prov_track_id, PlexTrack): return await self._parse_track(plex_track) msg = f"Item {prov_track_id} not found" raise MediaNotFoundError(msg) - async def get_playlist(self, prov_playlist_id) -> Playlist: + async def get_playlist(self, prov_playlist_id: str) -> Playlist: """Get full playlist details by id.""" if plex_playlist := await self._get_data(prov_playlist_id, PlexPlaylist): return await self._parse_playlist(plex_playlist) @@ -838,11 +899,11 @@ class PlexProvider(MusicProvider): result.append(track) return result - async def get_artist_albums(self, prov_artist_id) -> list[Album]: + async def get_artist_albums(self, prov_artist_id: str) -> list[Album]: """Get a list of albums for the given artist.""" if not prov_artist_id.startswith(FAKE_ARTIST_PREFIX): plex_artist = await self._get_data(prov_artist_id, PlexArtist) - plex_albums = await self._run_async(plex_artist.albums) + plex_albums = cast(list[PlexAlbum], await self._run_async(plex_artist.albums)) if plex_albums: albums = [] for album_obj in plex_albums: @@ -898,7 +959,7 @@ class PlexProvider(MusicProvider): async def on_streamed(self, streamdetails: StreamDetails, seconds_streamed: int) -> None: """Handle callback when an item completed streaming.""" - def mark_played(): + def mark_played() -> None: item = streamdetails.data params = {"key": str(item.ratingKey), "identifier": "com.plexapp.plugins.library"} self._plex_server.query("/:/scrobble", params=params) @@ -910,7 +971,7 @@ class PlexProvider(MusicProvider): if auth_token == AUTH_TOKEN_UNAUTH: return self._myplex_account - def _refresh_plex_token(): + def _refresh_plex_token() -> MyPlexAccount: if self._myplex_account is None: myplex_account = MyPlexAccount(token=auth_token) self._myplex_account = myplex_account diff --git a/music_assistant/server/providers/plex/helpers.py b/music_assistant/server/providers/plex/helpers.py index 5afbe3e1..af33e53d 100644 --- a/music_assistant/server/providers/plex/helpers.py +++ b/music_assistant/server/providers/plex/helpers.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import requests from plexapi.gdm import GDM @@ -17,7 +17,7 @@ if TYPE_CHECKING: async def get_libraries( mass: MusicAssistant, - auth_token: str, + auth_token: str | None, local_server_ssl: bool, local_server_ip: str, local_server_port: str, @@ -30,24 +30,24 @@ async def get_libraries( """ cache_key = "plex_libraries" - def _get_libraries(): + def _get_libraries() -> list[str]: # create a listing of available music libraries on all servers all_libraries: list[str] = [] session = requests.Session() session.verify = local_server_verify_cert local_server_protocol = "https" if local_server_ssl else "http" + plex_server: PlexServer if auth_token is None: - plex_server: PlexServer = PlexServer( + plex_server = PlexServer( f"{local_server_protocol}://{local_server_ip}:{local_server_port}" ) else: - plex_server: PlexServer = PlexServer( + plex_server = PlexServer( f"{local_server_protocol}://{local_server_ip}:{local_server_port}", auth_token, session=session, ) - for media_section in plex_server.library.sections(): - media_section: PlexLibrarySection + for media_section in cast(list[PlexLibrarySection], plex_server.library.sections()): if media_section.type != PlexMusicSection.TYPE: continue # TODO: figure out what plex uses as stable id and use that instead of names @@ -55,7 +55,7 @@ async def get_libraries( return all_libraries if cache := await mass.cache.get(cache_key, checksum=auth_token): - return cache + return cast(list[str], cache) result = await asyncio.to_thread(_get_libraries) # use short expiration for in-memory cache @@ -63,10 +63,10 @@ async def get_libraries( return result -async def discover_local_servers(): +async def discover_local_servers() -> tuple[str, int] | tuple[None, None]: """Discover all local plex servers on the network.""" - def _discover_local_servers(): + def _discover_local_servers() -> tuple[str, int] | tuple[None, None]: gdm = GDM() gdm.scan() if len(gdm.entries) > 0: diff --git a/mypy.ini b/mypy.ini index c99b9645..796e73f6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -21,4 +21,4 @@ disallow_untyped_decorators = true disallow_untyped_defs = true warn_return_any = true warn_unreachable = true -packages=tests,music_assistant.client,music_assistant.common,music_assistant.server.providers.builtin,music_assistant.server.providers.filesystem_local,music_assistant.server.providers.filesystem_smb,music_assistant.server.providers.jellyfin,music_assistant.server.providers.radiobrowser,music_assistant.server.providers.test +packages=tests,music_assistant.client,music_assistant.common,music_assistant.server.providers.builtin,music_assistant.server.providers.filesystem_local,music_assistant.server.providers.filesystem_smb,music_assistant.server.providers.jellyfin,music_assistant.server.providers.plex,music_assistant.server.providers.radiobrowser,music_assistant.server.providers.test -- 2.34.1