Tidal typing (#1057)
authorJozef Kruszynski <60214390+jozefKruszynski@users.noreply.github.com>
Fri, 9 Feb 2024 10:25:42 +0000 (11:25 +0100)
committerGitHub <noreply@github.com>
Fri, 9 Feb 2024 10:25:42 +0000 (11:25 +0100)
music_assistant/server/providers/tidal/__init__.py
music_assistant/server/providers/tidal/helpers.py

index 470ec88e326588c657db48e59a4db396d6dc5da4..a476095bd538946424ee5a156cf946036f129dfc 100644 (file)
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import asyncio
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 from asyncio_throttle import Throttler
 from tidalapi import Album as TidalAlbum
@@ -127,8 +127,11 @@ async def get_config_entries(
     """
     # config flow auth action/step (authenticate button clicked)
     if action == CONF_ACTION_AUTH:
-        async with AuthenticationHelper(mass, values["session_id"]) as auth_helper:
-            tidal_session = await tidal_code_login(auth_helper, values.get(CONF_QUALITY))
+        async with AuthenticationHelper(mass, cast(str, values["session_id"])) as auth_helper:
+            quality: str | int | float | list[str] | list[int] | None = (
+                values.get(CONF_QUALITY) if values else None
+            )
+            tidal_session = await tidal_code_login(auth_helper, cast(str, quality))
             if not tidal_session.check_login():
                 msg = "Authentication to Tidal failed"
                 raise LoginFailed(msg)
@@ -150,7 +153,7 @@ async def get_config_entries(
             label="Quality",
             required=True,
             description="The Tidal Quality you wish to use",
-            options=[
+            options=(
                 ConfigValueOption(
                     title=TidalQuality.low_96k.value, value=TidalQuality.low_96k.name
                 ),
@@ -162,7 +165,7 @@ async def get_config_entries(
                     value=TidalQuality.high_lossless.name,
                 ),
                 ConfigValueOption(title=TidalQuality.hi_res.value, value=TidalQuality.hi_res.name),
-            ],
+            ),
             default_value=TidalQuality.high_lossless.name,
             value=values.get(CONF_QUALITY) if values else None,
         ),
@@ -210,7 +213,7 @@ class TidalProvider(MusicProvider):
 
     async def handle_setup(self) -> None:
         """Handle async initialization of the provider."""
-        self._tidal_user_id = self.config.get_value(CONF_USER_ID)
+        self._tidal_user_id: str = self.config.get_value(CONF_USER_ID)
         self._tidal_session = await self._get_tidal_session()
         self._throttler = Throttler(rate_limit=1, period=0.1)
 
@@ -236,7 +239,10 @@ class TidalProvider(MusicProvider):
         )
 
     async def search(
-        self, search_query: str, media_types=list[MediaType] | None, limit: int = 5
+        self,
+        search_query: str,
+        media_types: list[MediaType] | None = None,
+        limit: int = 5,
     ) -> SearchResults:
         """Perform search on musicprovider.
 
@@ -298,20 +304,23 @@ class TidalProvider(MusicProvider):
         ):
             yield await self._parse_playlist(playlist_obj=playlist)
 
-    async def get_album_tracks(self, prov_album_id: str) -> list[Track]:
+    async def get_album_tracks(self, prov_album_id: str) -> list[AlbumTrack]:
         """Get album tracks for given album id."""
         tidal_session = await self._get_tidal_session()
         async with self._throttler:
-            return [
-                await self._parse_track(
-                    track_obj=track_obj,
-                    extra_init_kwargs={
-                        "disc_number": track_obj.volume_num,
-                        "track_number": track_obj.track_num,
-                    },
-                )
-                for track_obj in await get_album_tracks(tidal_session, prov_album_id)
-            ]
+            return cast(
+                list[AlbumTrack],
+                [
+                    await self._parse_track(
+                        track_obj=track_obj,
+                        extra_init_kwargs={
+                            "disc_number": track_obj.volume_num,
+                            "track_number": track_obj.track_num,
+                        },
+                    )
+                    for track_obj in await get_album_tracks(tidal_session, prov_album_id)
+                ],
+            )
 
     async def get_artist_albums(self, prov_artist_id: str) -> list[Album]:
         """Get a list of all albums for the given artist."""
@@ -357,29 +366,29 @@ class TidalProvider(MusicProvider):
                 for track in await get_similar_tracks(tidal_session, prov_track_id, limit)
             ]
 
-    async def library_add(self, prov_item_id: str, media_type: MediaType):
+    async def library_add(self, prov_item_id: str, media_type: MediaType) -> bool:
         """Add item to library."""
         tidal_session = await self._get_tidal_session()
         return await library_items_add_remove(
             tidal_session,
-            self._tidal_user_id,
+            str(self._tidal_user_id),
             prov_item_id,
             media_type,
             add=True,
         )
 
-    async def library_remove(self, prov_item_id: str, media_type: MediaType):
+    async def library_remove(self, prov_item_id: str, media_type: MediaType) -> bool:
         """Remove item from library."""
         tidal_session = await self._get_tidal_session()
         return await library_items_add_remove(
             tidal_session,
-            self._tidal_user_id,
+            str(self._tidal_user_id),
             prov_item_id,
             media_type,
             add=False,
         )
 
-    async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]):
+    async def add_playlist_tracks(self, prov_playlist_id: str, prov_track_ids: list[str]) -> None:
         """Add track(s) to playlist."""
         tidal_session = await self._get_tidal_session()
         return await add_remove_playlist_tracks(
@@ -404,7 +413,12 @@ class TidalProvider(MusicProvider):
     async def create_playlist(self, name: str) -> Playlist:
         """Create a new playlist on provider with given name."""
         tidal_session = await self._get_tidal_session()
-        playlist_obj = await create_playlist(tidal_session, self._tidal_user_id, name)
+        playlist_obj = await create_playlist(
+            session=tidal_session,
+            user_id=str(self._tidal_user_id),
+            title=name,
+            description="",
+        )
         return await self._parse_playlist(playlist_obj=playlist_obj)
 
     async def get_stream_details(self, item_id: str) -> StreamDetails:
@@ -480,16 +494,16 @@ class TidalProvider(MusicProvider):
         if (
             self._tidal_session
             and self._tidal_session.access_token
-            and datetime.fromisoformat(self.config.get_value(CONF_EXPIRY_TIME))
+            and datetime.fromisoformat(str(self.config.get_value(CONF_EXPIRY_TIME)))
             > (datetime.now() + timedelta(days=1))
         ):
             return self._tidal_session
         self._tidal_session = await self._load_tidal_session(
             token_type="Bearer",
             quality=self.config.get_value(CONF_QUALITY),
-            access_token=self.config.get_value(CONF_AUTH_TOKEN),
-            refresh_token=self.config.get_value(CONF_REFRESH_TOKEN),
-            expiry_time=datetime.fromisoformat(self.config.get_value(CONF_EXPIRY_TIME)),
+            access_token=str(self.config.get_value(CONF_AUTH_TOKEN)),
+            refresh_token=str(self.config.get_value(CONF_REFRESH_TOKEN)),
+            expiry_time=datetime.fromisoformat(str(self.config.get_value(CONF_EXPIRY_TIME))),
         )
         await self.mass.config.set_provider_config_value(
             self.config.instance_id,
@@ -510,11 +524,11 @@ class TidalProvider(MusicProvider):
 
     async def _load_tidal_session(
         self,
-        token_type,
+        token_type: str,
         quality: TidalQuality,
-        access_token,
-        refresh_token=None,
-        expiry_time=None,
+        access_token: str,
+        refresh_token: str,
+        expiry_time: datetime | None = None,
     ) -> TidalSession:
         """Load the tidalapi Session."""
 
@@ -593,7 +607,7 @@ class TidalProvider(MusicProvider):
         elif album_obj.type == "SINGLE":
             album.album_type = AlbumType.SINGLE
 
-        album.upc = album_obj.universal_product_number
+        album.upc = album_obj.universal_product_number
         album.year = int(album_obj.year)
         # metadata
         album.metadata.copyright = album_obj.copyright
@@ -680,6 +694,7 @@ class TidalProvider(MusicProvider):
                     path=image_url,
                 )
             ]
+
         return track
 
     async def _parse_playlist(
@@ -722,13 +737,16 @@ class TidalProvider(MusicProvider):
 
         return playlist
 
-    async def _get_image_url(self, item, size: int):
+    async def _get_image_url(
+        self, item: TidalArtist | TidalAlbum | TidalPlaylist, size: int
+    ) -> str:
         def inner() -> str:
-            return item.image(size)
+            image_url: str = item.image(size)
+            return image_url
 
         return await asyncio.to_thread(inner)
 
-    async def _get_lyrics(self, item):
+    async def _get_lyrics(self, item: TidalTrack) -> TidalLyrics:
         def inner() -> TidalLyrics:
             return item.lyrics
 
@@ -768,4 +786,5 @@ class TidalProvider(MusicProvider):
 
     def _is_hi_res(self, track_obj: TidalTrack) -> bool:
         """Check if track is hi-res."""
-        return track_obj.audio_quality.value == "HI_RES"
+        hi_res: bool = track_obj.audio_quality.value == "HI_RES"
+        return hi_res
index 9c4c52152737dc0805fbfb10f621eecb40c30f57..d6fc506c19fc3787d88f48085a557764617c1f74 100644 (file)
@@ -36,7 +36,10 @@ async def get_library_artists(
     """Async wrapper around the tidalapi Favorites.artists function."""
 
     def inner() -> list[TidalArtist]:
-        return TidalFavorites(session, user_id).artists(limit=limit, offset=offset)
+        artists: list[TidalArtist] = TidalFavorites(session, user_id).artists(
+            limit=limit, offset=offset
+        )
+        return artists
 
     return await asyncio.to_thread(inner)
 
@@ -50,34 +53,32 @@ async def library_items_add_remove(
 ) -> None:
     """Async wrapper around the tidalapi Favorites.items add/remove function."""
 
-    def inner() -> None:
-        match media_type:
-            case MediaType.ARTIST:
-                (
-                    TidalFavorites(session, user_id).add_artist(item_id)
-                    if add
-                    else TidalFavorites(session, user_id).remove_artist(item_id)
-                )
-            case MediaType.ALBUM:
-                (
-                    TidalFavorites(session, user_id).add_album(item_id)
-                    if add
-                    else TidalFavorites(session, user_id).remove_album(item_id)
-                )
-            case MediaType.TRACK:
-                (
-                    TidalFavorites(session, user_id).add_track(item_id)
-                    if add
-                    else TidalFavorites(session, user_id).remove_track(item_id)
-                )
-            case MediaType.PLAYLIST:
-                (
-                    TidalFavorites(session, user_id).add_playlist(item_id)
-                    if add
-                    else TidalFavorites(session, user_id).remove_playlist(item_id)
-                )
-            case MediaType.UNKNOWN:
-                return
+    def inner() -> bool:
+        tidal_favorites = TidalFavorites(session, user_id)
+        if MediaType.UNKNOWN:
+            return False
+        response: bool = False
+        if add:
+            match media_type:
+                case MediaType.ARTIST:
+                    response = tidal_favorites.add_artist(item_id)
+                case MediaType.ALBUM:
+                    response = tidal_favorites.add_album(item_id)
+                case MediaType.TRACK:
+                    response = tidal_favorites.add_track(item_id)
+                case MediaType.PLAYLIST:
+                    response = tidal_favorites.add_playlist(item_id)
+        else:
+            match media_type:
+                case MediaType.ARTIST:
+                    response = tidal_favorites.remove_artist(item_id)
+                case MediaType.ALBUM:
+                    response = tidal_favorites.remove_album(item_id)
+                case MediaType.TRACK:
+                    response = tidal_favorites.remove_track(item_id)
+                case MediaType.PLAYLIST:
+                    response = tidal_favorites.remove_playlist(item_id)
+        return response
 
     return await asyncio.to_thread(inner)
 
@@ -127,7 +128,10 @@ async def get_artist_toptracks(
     """Async wrapper around the tidalapi Artist.get_top_tracks function."""
 
     def inner() -> list[TidalTrack]:
-        return TidalArtist(session, prov_artist_id).get_top_tracks(limit=limit, offset=offset)
+        top_tracks: list[TidalTrack] = TidalArtist(session, prov_artist_id).get_top_tracks(
+            limit=limit, offset=offset
+        )
+        return top_tracks
 
     return await asyncio.to_thread(inner)
 
@@ -138,7 +142,10 @@ async def get_library_albums(
     """Async wrapper around the tidalapi Favorites.albums function."""
 
     def inner() -> list[TidalAlbum]:
-        return TidalFavorites(session, user_id).albums(limit=limit, offset=offset)
+        albums: list[TidalAlbum] = TidalFavorites(session, user_id).albums(
+            limit=limit, offset=offset
+        )
+        return albums
 
     return await asyncio.to_thread(inner)
 
@@ -173,12 +180,13 @@ async def get_track(session: TidalSession, prov_track_id: str) -> TidalTrack:
     return await asyncio.to_thread(inner)
 
 
-async def get_track_url(session: TidalSession, prov_track_id: str) -> dict[str, str]:
+async def get_track_url(session: TidalSession, prov_track_id: str) -> str:
     """Async wrapper around the tidalapi Track.get_url function."""
 
-    def inner() -> dict[str, str]:
+    def inner() -> str:
         try:
-            return TidalTrack(session, prov_track_id).get_url()
+            track_url: str = TidalTrack(session, prov_track_id).get_url()
+            return track_url
         except HTTPError as err:
             if err.response.status_code == 404:
                 msg = f"Track {prov_track_id} not found"
@@ -193,7 +201,10 @@ async def get_album_tracks(session: TidalSession, prov_album_id: str) -> list[Ti
 
     def inner() -> list[TidalTrack]:
         try:
-            return TidalAlbum(session, prov_album_id).tracks(limit=DEFAULT_LIMIT)
+            tracks: list[TidalTrack] = TidalAlbum(session, prov_album_id).tracks(
+                limit=DEFAULT_LIMIT
+            )
+            return tracks
         except HTTPError as err:
             if err.response.status_code == 404:
                 msg = f"Album {prov_album_id} not found"
@@ -209,7 +220,10 @@ async def get_library_tracks(
     """Async wrapper around the tidalapi Favorites.tracks function."""
 
     def inner() -> list[TidalTrack]:
-        return TidalFavorites(session, user_id).tracks(limit=limit, offset=offset)
+        tracks: list[TidalTrack] = TidalFavorites(session, user_id).tracks(
+            limit=limit, offset=offset
+        )
+        return tracks
 
     return await asyncio.to_thread(inner)
 
@@ -220,7 +234,10 @@ async def get_library_playlists(
     """Async wrapper around the tidalapi LoggedInUser.playlist_and_favorite_playlists function."""
 
     def inner() -> list[TidalPlaylist]:
-        return LoggedInUser(session, user_id).playlist_and_favorite_playlists(offset=offset)
+        playlists: list[TidalPlaylist] = LoggedInUser(
+            session, user_id
+        ).playlist_and_favorite_playlists(offset=offset)
+        return playlists
 
     return await asyncio.to_thread(inner)
 
@@ -250,7 +267,10 @@ async def get_playlist_tracks(
 
     def inner() -> list[TidalTrack]:
         try:
-            return TidalPlaylist(session, prov_playlist_id).tracks(limit=limit, offset=offset)
+            tracks: list[TidalTrack] = TidalPlaylist(session, prov_playlist_id).tracks(
+                limit=limit, offset=offset
+            )
+            return tracks
         except HTTPError as err:
             if err.response.status_code == 404:
                 msg = f"Playlist {prov_playlist_id} not found"
@@ -267,10 +287,9 @@ async def add_remove_playlist_tracks(
 
     def inner() -> None:
         if add:
-            return TidalUserPlaylist(session, prov_playlist_id).add(track_ids)
+            TidalUserPlaylist(session, prov_playlist_id).add(track_ids)
         for item in track_ids:
             TidalUserPlaylist(session, prov_playlist_id).remove_by_id(int(item))
-        return None
 
     return await asyncio.to_thread(inner)
 
@@ -281,7 +300,8 @@ async def create_playlist(
     """Async wrapper around the tidal LoggedInUser.create_playlist function."""
 
     def inner() -> TidalPlaylist:
-        return LoggedInUser(session, user_id).create_playlist(title, description)
+        playlist: TidalPlaylist = LoggedInUser(session, user_id).create_playlist(title, description)
+        return playlist
 
     return await asyncio.to_thread(inner)
 
@@ -293,8 +313,10 @@ async def get_similar_tracks(
 
     def inner() -> list[TidalTrack]:
         try:
-            # Re-add limit here after tidalapi supports it
-            return TidalTrack(session, prov_track_id).get_track_radio(limit=limit)
+            tracks: list[TidalTrack] = TidalTrack(session, prov_track_id).get_track_radio(
+                limit=limit
+            )
+            return tracks
         except HTTPError as err:
             if err.response.status_code == 404:
                 msg = f"Track {prov_track_id} not found"
@@ -305,22 +327,27 @@ async def get_similar_tracks(
 
 
 async def search(
-    session: TidalSession, query: str, media_types=None, limit=50, offset=0
+    session: TidalSession,
+    query: str,
+    media_types: list[MediaType] | None = None,
+    limit: int = 50,
+    offset: int = 0,
 ) -> dict[str, str]:
     """Async wrapper around the tidalapi Search function."""
 
     def inner() -> dict[str, str]:
         search_types = []
-        if MediaType.ARTIST in media_types:
+        if media_types and MediaType.ARTIST in media_types:
             search_types.append(TidalArtist)
-        if MediaType.ALBUM in media_types:
+        if media_types and MediaType.ALBUM in media_types:
             search_types.append(TidalAlbum)
-        if MediaType.TRACK in media_types:
+        if media_types and MediaType.TRACK in media_types:
             search_types.append(TidalTrack)
-        if MediaType.PLAYLIST in media_types:
+        if media_types and MediaType.PLAYLIST in media_types:
             search_types.append(TidalPlaylist)
 
         models = search_types if search_types else None
-        return session.search(query, models, limit, offset)
+        results: dict[str, str] = session.search(query, models, limit, offset)
+        return results
 
     return await asyncio.to_thread(inner)