From d5be10a3d72e39458c74a3ce3bb71c8aeaa03643 Mon Sep 17 00:00:00 2001 From: Jozef Kruszynski <60214390+jozefKruszynski@users.noreply.github.com> Date: Fri, 9 Feb 2024 11:25:42 +0100 Subject: [PATCH] Tidal typing (#1057) --- .../server/providers/tidal/__init__.py | 93 +++++++------ .../server/providers/tidal/helpers.py | 125 +++++++++++------- 2 files changed, 132 insertions(+), 86 deletions(-) diff --git a/music_assistant/server/providers/tidal/__init__.py b/music_assistant/server/providers/tidal/__init__.py index 470ec88e..a476095b 100644 --- a/music_assistant/server/providers/tidal/__init__.py +++ b/music_assistant/server/providers/tidal/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from asyncio_throttle import Throttler from tidalapi import Album as TidalAlbum @@ -127,8 +127,11 @@ async def get_config_entries( """ # config flow auth action/step (authenticate button clicked) if action == CONF_ACTION_AUTH: - async with AuthenticationHelper(mass, values["session_id"]) as auth_helper: - tidal_session = await tidal_code_login(auth_helper, values.get(CONF_QUALITY)) + async with AuthenticationHelper(mass, cast(str, values["session_id"])) as auth_helper: + quality: str | int | float | list[str] | list[int] | None = ( + values.get(CONF_QUALITY) if values else None + ) + tidal_session = await tidal_code_login(auth_helper, cast(str, quality)) if not tidal_session.check_login(): msg = "Authentication to Tidal failed" raise LoginFailed(msg) @@ -150,7 +153,7 @@ async def get_config_entries( label="Quality", required=True, description="The Tidal Quality you wish to use", - options=[ + options=( ConfigValueOption( title=TidalQuality.low_96k.value, value=TidalQuality.low_96k.name ), @@ -162,7 +165,7 @@ async def get_config_entries( value=TidalQuality.high_lossless.name, ), ConfigValueOption(title=TidalQuality.hi_res.value, value=TidalQuality.hi_res.name), - ], + ), default_value=TidalQuality.high_lossless.name, value=values.get(CONF_QUALITY) if values else None, ), @@ -210,7 +213,7 @@ class TidalProvider(MusicProvider): async def handle_setup(self) -> None: """Handle async initialization of the provider.""" - self._tidal_user_id = self.config.get_value(CONF_USER_ID) + self._tidal_user_id: str = self.config.get_value(CONF_USER_ID) self._tidal_session = await self._get_tidal_session() self._throttler = Throttler(rate_limit=1, period=0.1) @@ -236,7 +239,10 @@ class TidalProvider(MusicProvider): ) async def search( - self, search_query: str, media_types=list[MediaType] | None, limit: int = 5 + self, + search_query: str, + media_types: list[MediaType] | None = None, + limit: int = 5, ) -> SearchResults: """Perform search on musicprovider. @@ -298,20 +304,23 @@ class TidalProvider(MusicProvider): ): yield await self._parse_playlist(playlist_obj=playlist) - async def get_album_tracks(self, prov_album_id: str) -> list[Track]: + async def get_album_tracks(self, prov_album_id: str) -> list[AlbumTrack]: """Get album tracks for given album id.""" tidal_session = await self._get_tidal_session() async with self._throttler: - return [ - await self._parse_track( - track_obj=track_obj, - extra_init_kwargs={ - "disc_number": track_obj.volume_num, - "track_number": track_obj.track_num, - }, - ) - for track_obj in await get_album_tracks(tidal_session, prov_album_id) - ] + return cast( + list[AlbumTrack], + [ + await self._parse_track( + track_obj=track_obj, + extra_init_kwargs={ + "disc_number": track_obj.volume_num, + "track_number": track_obj.track_num, + }, + ) + for track_obj in await get_album_tracks(tidal_session, prov_album_id) + ], + ) async def get_artist_albums(self, prov_artist_id: str) -> list[Album]: """Get a list of all albums for the given artist.""" @@ -357,29 +366,29 @@ class TidalProvider(MusicProvider): for track in await get_similar_tracks(tidal_session, prov_track_id, limit) ] - async def library_add(self, prov_item_id: str, media_type: MediaType): + async def library_add(self, prov_item_id: str, media_type: MediaType) -> bool: """Add item to library.""" tidal_session = await self._get_tidal_session() return await library_items_add_remove( tidal_session, - self._tidal_user_id, + str(self._tidal_user_id), prov_item_id, media_type, add=True, ) - async def library_remove(self, prov_item_id: str, media_type: MediaType): + async def library_remove(self, prov_item_id: str, media_type: MediaType) -> bool: """Remove item from library.""" tidal_session = await self._get_tidal_session() return await library_items_add_remove( tidal_session, - self._tidal_user_id, + str(self._tidal_user_id), prov_item_id, media_type, add=False, ) - async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]): + async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]) -> None: """Add track(s) to playlist.""" tidal_session = await self._get_tidal_session() return await add_remove_playlist_tracks( @@ -404,7 +413,12 @@ class TidalProvider(MusicProvider): async def create_playlist(self, name: str) -> Playlist: """Create a new playlist on provider with given name.""" tidal_session = await self._get_tidal_session() - playlist_obj = await create_playlist(tidal_session, self._tidal_user_id, name) + playlist_obj = await create_playlist( + session=tidal_session, + user_id=str(self._tidal_user_id), + title=name, + description="", + ) return await self._parse_playlist(playlist_obj=playlist_obj) async def get_stream_details(self, item_id: str) -> StreamDetails: @@ -480,16 +494,16 @@ class TidalProvider(MusicProvider): if ( self._tidal_session and self._tidal_session.access_token - and datetime.fromisoformat(self.config.get_value(CONF_EXPIRY_TIME)) + and datetime.fromisoformat(str(self.config.get_value(CONF_EXPIRY_TIME))) > (datetime.now() + timedelta(days=1)) ): return self._tidal_session self._tidal_session = await self._load_tidal_session( token_type="Bearer", quality=self.config.get_value(CONF_QUALITY), - access_token=self.config.get_value(CONF_AUTH_TOKEN), - refresh_token=self.config.get_value(CONF_REFRESH_TOKEN), - expiry_time=datetime.fromisoformat(self.config.get_value(CONF_EXPIRY_TIME)), + access_token=str(self.config.get_value(CONF_AUTH_TOKEN)), + refresh_token=str(self.config.get_value(CONF_REFRESH_TOKEN)), + expiry_time=datetime.fromisoformat(str(self.config.get_value(CONF_EXPIRY_TIME))), ) await self.mass.config.set_provider_config_value( self.config.instance_id, @@ -510,11 +524,11 @@ class TidalProvider(MusicProvider): async def _load_tidal_session( self, - token_type, + token_type: str, quality: TidalQuality, - access_token, - refresh_token=None, - expiry_time=None, + access_token: str, + refresh_token: str, + expiry_time: datetime | None = None, ) -> TidalSession: """Load the tidalapi Session.""" @@ -593,7 +607,7 @@ class TidalProvider(MusicProvider): elif album_obj.type == "SINGLE": album.album_type = AlbumType.SINGLE - album.upc = album_obj.universal_product_number + # album.upc = album_obj.universal_product_number album.year = int(album_obj.year) # metadata album.metadata.copyright = album_obj.copyright @@ -680,6 +694,7 @@ class TidalProvider(MusicProvider): path=image_url, ) ] + return track async def _parse_playlist( @@ -722,13 +737,16 @@ class TidalProvider(MusicProvider): return playlist - async def _get_image_url(self, item, size: int): + async def _get_image_url( + self, item: TidalArtist | TidalAlbum | TidalPlaylist, size: int + ) -> str: def inner() -> str: - return item.image(size) + image_url: str = item.image(size) + return image_url return await asyncio.to_thread(inner) - async def _get_lyrics(self, item): + async def _get_lyrics(self, item: TidalTrack) -> TidalLyrics: def inner() -> TidalLyrics: return item.lyrics @@ -768,4 +786,5 @@ class TidalProvider(MusicProvider): def _is_hi_res(self, track_obj: TidalTrack) -> bool: """Check if track is hi-res.""" - return track_obj.audio_quality.value == "HI_RES" + hi_res: bool = track_obj.audio_quality.value == "HI_RES" + return hi_res diff --git a/music_assistant/server/providers/tidal/helpers.py b/music_assistant/server/providers/tidal/helpers.py index 9c4c5215..d6fc506c 100644 --- a/music_assistant/server/providers/tidal/helpers.py +++ b/music_assistant/server/providers/tidal/helpers.py @@ -36,7 +36,10 @@ async def get_library_artists( """Async wrapper around the tidalapi Favorites.artists function.""" def inner() -> list[TidalArtist]: - return TidalFavorites(session, user_id).artists(limit=limit, offset=offset) + artists: list[TidalArtist] = TidalFavorites(session, user_id).artists( + limit=limit, offset=offset + ) + return artists return await asyncio.to_thread(inner) @@ -50,34 +53,32 @@ async def library_items_add_remove( ) -> None: """Async wrapper around the tidalapi Favorites.items add/remove function.""" - def inner() -> None: - match media_type: - case MediaType.ARTIST: - ( - TidalFavorites(session, user_id).add_artist(item_id) - if add - else TidalFavorites(session, user_id).remove_artist(item_id) - ) - case MediaType.ALBUM: - ( - TidalFavorites(session, user_id).add_album(item_id) - if add - else TidalFavorites(session, user_id).remove_album(item_id) - ) - case MediaType.TRACK: - ( - TidalFavorites(session, user_id).add_track(item_id) - if add - else TidalFavorites(session, user_id).remove_track(item_id) - ) - case MediaType.PLAYLIST: - ( - TidalFavorites(session, user_id).add_playlist(item_id) - if add - else TidalFavorites(session, user_id).remove_playlist(item_id) - ) - case MediaType.UNKNOWN: - return + def inner() -> bool: + tidal_favorites = TidalFavorites(session, user_id) + if MediaType.UNKNOWN: + return False + response: bool = False + if add: + match media_type: + case MediaType.ARTIST: + response = tidal_favorites.add_artist(item_id) + case MediaType.ALBUM: + response = tidal_favorites.add_album(item_id) + case MediaType.TRACK: + response = tidal_favorites.add_track(item_id) + case MediaType.PLAYLIST: + response = tidal_favorites.add_playlist(item_id) + else: + match media_type: + case MediaType.ARTIST: + response = tidal_favorites.remove_artist(item_id) + case MediaType.ALBUM: + response = tidal_favorites.remove_album(item_id) + case MediaType.TRACK: + response = tidal_favorites.remove_track(item_id) + case MediaType.PLAYLIST: + response = tidal_favorites.remove_playlist(item_id) + return response return await asyncio.to_thread(inner) @@ -127,7 +128,10 @@ async def get_artist_toptracks( """Async wrapper around the tidalapi Artist.get_top_tracks function.""" def inner() -> list[TidalTrack]: - return TidalArtist(session, prov_artist_id).get_top_tracks(limit=limit, offset=offset) + top_tracks: list[TidalTrack] = TidalArtist(session, prov_artist_id).get_top_tracks( + limit=limit, offset=offset + ) + return top_tracks return await asyncio.to_thread(inner) @@ -138,7 +142,10 @@ async def get_library_albums( """Async wrapper around the tidalapi Favorites.albums function.""" def inner() -> list[TidalAlbum]: - return TidalFavorites(session, user_id).albums(limit=limit, offset=offset) + albums: list[TidalAlbum] = TidalFavorites(session, user_id).albums( + limit=limit, offset=offset + ) + return albums return await asyncio.to_thread(inner) @@ -173,12 +180,13 @@ async def get_track(session: TidalSession, prov_track_id: str) -> TidalTrack: return await asyncio.to_thread(inner) -async def get_track_url(session: TidalSession, prov_track_id: str) -> dict[str, str]: +async def get_track_url(session: TidalSession, prov_track_id: str) -> str: """Async wrapper around the tidalapi Track.get_url function.""" - def inner() -> dict[str, str]: + def inner() -> str: try: - return TidalTrack(session, prov_track_id).get_url() + track_url: str = TidalTrack(session, prov_track_id).get_url() + return track_url except HTTPError as err: if err.response.status_code == 404: msg = f"Track {prov_track_id} not found" @@ -193,7 +201,10 @@ async def get_album_tracks(session: TidalSession, prov_album_id: str) -> list[Ti def inner() -> list[TidalTrack]: try: - return TidalAlbum(session, prov_album_id).tracks(limit=DEFAULT_LIMIT) + tracks: list[TidalTrack] = TidalAlbum(session, prov_album_id).tracks( + limit=DEFAULT_LIMIT + ) + return tracks except HTTPError as err: if err.response.status_code == 404: msg = f"Album {prov_album_id} not found" @@ -209,7 +220,10 @@ async def get_library_tracks( """Async wrapper around the tidalapi Favorites.tracks function.""" def inner() -> list[TidalTrack]: - return TidalFavorites(session, user_id).tracks(limit=limit, offset=offset) + tracks: list[TidalTrack] = TidalFavorites(session, user_id).tracks( + limit=limit, offset=offset + ) + return tracks return await asyncio.to_thread(inner) @@ -220,7 +234,10 @@ async def get_library_playlists( """Async wrapper around the tidalapi LoggedInUser.playlist_and_favorite_playlists function.""" def inner() -> list[TidalPlaylist]: - return LoggedInUser(session, user_id).playlist_and_favorite_playlists(offset=offset) + playlists: list[TidalPlaylist] = LoggedInUser( + session, user_id + ).playlist_and_favorite_playlists(offset=offset) + return playlists return await asyncio.to_thread(inner) @@ -250,7 +267,10 @@ async def get_playlist_tracks( def inner() -> list[TidalTrack]: try: - return TidalPlaylist(session, prov_playlist_id).tracks(limit=limit, offset=offset) + tracks: list[TidalTrack] = TidalPlaylist(session, prov_playlist_id).tracks( + limit=limit, offset=offset + ) + return tracks except HTTPError as err: if err.response.status_code == 404: msg = f"Playlist {prov_playlist_id} not found" @@ -267,10 +287,9 @@ async def add_remove_playlist_tracks( def inner() -> None: if add: - return TidalUserPlaylist(session, prov_playlist_id).add(track_ids) + TidalUserPlaylist(session, prov_playlist_id).add(track_ids) for item in track_ids: TidalUserPlaylist(session, prov_playlist_id).remove_by_id(int(item)) - return None return await asyncio.to_thread(inner) @@ -281,7 +300,8 @@ async def create_playlist( """Async wrapper around the tidal LoggedInUser.create_playlist function.""" def inner() -> TidalPlaylist: - return LoggedInUser(session, user_id).create_playlist(title, description) + playlist: TidalPlaylist = LoggedInUser(session, user_id).create_playlist(title, description) + return playlist return await asyncio.to_thread(inner) @@ -293,8 +313,10 @@ async def get_similar_tracks( def inner() -> list[TidalTrack]: try: - # Re-add limit here after tidalapi supports it - return TidalTrack(session, prov_track_id).get_track_radio(limit=limit) + tracks: list[TidalTrack] = TidalTrack(session, prov_track_id).get_track_radio( + limit=limit + ) + return tracks except HTTPError as err: if err.response.status_code == 404: msg = f"Track {prov_track_id} not found" @@ -305,22 +327,27 @@ async def get_similar_tracks( async def search( - session: TidalSession, query: str, media_types=None, limit=50, offset=0 + session: TidalSession, + query: str, + media_types: list[MediaType] | None = None, + limit: int = 50, + offset: int = 0, ) -> dict[str, str]: """Async wrapper around the tidalapi Search function.""" def inner() -> dict[str, str]: search_types = [] - if MediaType.ARTIST in media_types: + if media_types and MediaType.ARTIST in media_types: search_types.append(TidalArtist) - if MediaType.ALBUM in media_types: + if media_types and MediaType.ALBUM in media_types: search_types.append(TidalAlbum) - if MediaType.TRACK in media_types: + if media_types and MediaType.TRACK in media_types: search_types.append(TidalTrack) - if MediaType.PLAYLIST in media_types: + if media_types and MediaType.PLAYLIST in media_types: search_types.append(TidalPlaylist) models = search_types if search_types else None - return session.search(query, models, limit, offset) + results: dict[str, str] = session.search(query, models, limit, offset) + return results return await asyncio.to_thread(inner) -- 2.34.1