Several small (typing) fixes for the Deezer provider (#2413)
authorOzGav <gavnosp@hotmail.com>
Mon, 22 Sep 2025 18:55:19 +0000 (04:55 +1000)
committerGitHub <noreply@github.com>
Mon, 22 Sep 2025 18:55:19 +0000 (20:55 +0200)
music_assistant/providers/deezer/__init__.py
music_assistant/providers/deezer/gw_client.py

index a47a9d2c100e0588924befc9943257975d698259..1d2011e28630a62b8057ac2eecb9ce9d1b71e1c5 100644 (file)
@@ -6,6 +6,7 @@ from asyncio import TaskGroup
 from collections.abc import AsyncGenerator
 from dataclasses import dataclass
 from math import ceil
+from typing import Any, Literal, cast
 
 import deezer
 from aiohttp import ClientSession, ClientTimeout
@@ -22,7 +23,7 @@ from music_assistant_models.enums import (
     ProviderFeature,
     StreamType,
 )
-from music_assistant_models.errors import LoginFailed
+from music_assistant_models.errors import InvalidDataError, LoginFailed, MediaNotFoundError
 from music_assistant_models.media_items import (
     Album,
     Artist,
@@ -36,12 +37,13 @@ from music_assistant_models.media_items import (
     RecommendationFolder,
     SearchResults,
     Track,
+    UniqueList,
 )
 from music_assistant_models.provider import ProviderManifest
 from music_assistant_models.streamdetails import StreamDetails
 
 from music_assistant import MusicAssistant
-from music_assistant.helpers.app_vars import app_var
+from music_assistant.helpers.app_vars import app_var  # type: ignore[attr-defined]
 from music_assistant.helpers.auth import AuthenticationHelper
 from music_assistant.helpers.datetime import utc_timestamp
 from music_assistant.helpers.util import infer_album_type
@@ -130,11 +132,13 @@ async def get_config_entries(
     # Action is to launch oauth flow
     if action == CONF_ACTION_AUTH:
         # Use the AuthenticationHelper to authenticate
-        async with AuthenticationHelper(mass, values["session_id"]) as auth_helper:  # type: ignore
+        if not values or "session_id" not in values:
+            raise InvalidDataError("session_id not found in values")
+        async with AuthenticationHelper(mass, cast("str", values["session_id"])) as auth_helper:
             url = f"{DEEZER_AUTH_URL}?app_id={DEEZER_APP_ID}&redirect_uri={RELAY_URL}\
 &perms={DEEZER_PERMS}&state={auth_helper.callback_url}"
             code = (await auth_helper.authenticate(url))["code"]
-            values[CONF_ACCESS_TOKEN] = await get_access_token(  # type: ignore
+            values[CONF_ACCESS_TOKEN] = await get_access_token(
                 DEEZER_APP_ID, DEEZER_APP_SECRET, code, mass.http_session
             )
 
@@ -173,7 +177,7 @@ class DeezerProvider(MusicProvider):
         self.credentials = DeezerCredentials(
             app_id=DEEZER_APP_ID,
             app_secret=DEEZER_APP_SECRET,
-            access_token=self.config.get_value(CONF_ACCESS_TOKEN),  # type: ignore
+            access_token=cast("str", self.config.get_value(CONF_ACCESS_TOKEN)),
         )
 
         self.client = deezer.Client(
@@ -186,13 +190,13 @@ class DeezerProvider(MusicProvider):
 
         self.gw_client = GWClient(
             self.mass.http_session,
-            self.config.get_value(CONF_ACCESS_TOKEN),
-            self.config.get_value(CONF_ARL_TOKEN),
+            str(self.config.get_value(CONF_ACCESS_TOKEN)),
+            str(self.config.get_value(CONF_ARL_TOKEN)),
         )
         await self.gw_client.setup()
 
     async def search(
-        self, search_query: str, media_types=list[MediaType], limit: int = 5
+        self, search_query: str, media_typeslist[MediaType], limit: int = 5
     ) -> SearchResults:
         """Perform search on music provider.
 
@@ -200,7 +204,7 @@ class DeezerProvider(MusicProvider):
         :param media_types: A list of media_types to include. All types if None.
         """
         # Create a task for each media_type
-        tasks = {}
+        tasks: dict[MediaType, Any] = {}
 
         async with TaskGroup() as taskgroup:
             for media_type in media_types:
@@ -267,6 +271,7 @@ class DeezerProvider(MusicProvider):
             )
         except deezer_exceptions.DeezerErrorResponse as error:
             self.logger.warning("Failed getting artist: %s", error)
+            raise MediaNotFoundError(f"Artist {prov_artist_id} not found on Deezer") from error
 
     async def get_album(self, prov_album_id: str) -> Album:
         """Get full album details by id."""
@@ -274,6 +279,7 @@ class DeezerProvider(MusicProvider):
             return self.parse_album(album=await self.client.get_album(album_id=int(prov_album_id)))
         except deezer_exceptions.DeezerErrorResponse as error:
             self.logger.warning("Failed getting album: %s", error)
+            raise MediaNotFoundError(f"Album {prov_album_id} not found on Deezer") from error
 
     async def get_playlist(self, prov_playlist_id: str) -> Playlist:
         """Get full playlist details by id."""
@@ -283,6 +289,7 @@ class DeezerProvider(MusicProvider):
             )
         except deezer_exceptions.DeezerErrorResponse as error:
             self.logger.warning("Failed getting playlist: %s", error)
+            raise MediaNotFoundError(f"Album {prov_playlist_id} not found on Deezer") from error
 
     async def get_track(self, prov_track_id: str) -> Track:
         """Get full track details by id."""
@@ -293,6 +300,7 @@ class DeezerProvider(MusicProvider):
             )
         except deezer_exceptions.DeezerErrorResponse as error:
             self.logger.warning("Failed getting track: %s", error)
+            raise MediaNotFoundError(f"Album {prov_track_id} not found on Deezer") from error
 
     async def get_album_tracks(self, prov_album_id: str) -> list[Track]:
         """Get all tracks in an album."""
@@ -390,12 +398,15 @@ class DeezerProvider(MusicProvider):
         return [
             RecommendationFolder(
                 item_id="recommended_tracks",
+                provider=self.lookup_key,
                 name="Recommended tracks",
                 translation_key="recommended_tracks",
-                items=[
-                    self.parse_track(track=track, user_country=self.gw_client.user_country)
-                    for track in await self.client.get_user_recommended_tracks()
-                ],
+                items=UniqueList(
+                    [
+                        self.parse_track(track=track, user_country=self.gw_client.user_country)
+                        for track in await self.client.get_user_recommended_tracks()
+                    ]
+                ),
             )
         ]
 
@@ -423,7 +434,7 @@ class DeezerProvider(MusicProvider):
         playlist = await self.client.get_playlist(playlist_id)
         return self.parse_playlist(playlist=playlist)
 
-    async def get_similar_tracks(self, prov_track_id, limit=25) -> list[Track]:
+    async def get_similar_tracks(self, prov_track_id: str, limit: int = 25) -> list[Track]:
         """Retrieve a dynamic list of tracks based on the provided item."""
         endpoint = "song.getSearchTrackMix"
         tracks = (await self.gw_client._gw_api_call(endpoint, args={"SNG_ID": prov_track_id}))[
@@ -459,7 +470,7 @@ class DeezerProvider(MusicProvider):
         blowfish_key = self.get_blowfish_key(streamdetails.data["track_id"])
         chunk_index = 0
         timeout = ClientTimeout(total=0, connect=30, sock_read=600)
-        headers = {}
+        headers: dict[str, str] = {}
         # if seek_position and streamdetails.size:
         #     chunk_count = ceil(streamdetails.size / 2048)
         #     chunk_index = int(chunk_count / streamdetails.duration) * seek_position
@@ -469,7 +480,7 @@ class DeezerProvider(MusicProvider):
         # NOTE: Seek with using the Range header is not working properly
         # causing malformed audio so this is a temporary patch
         # by just skipping chunks
-        if seek_position and streamdetails.size:
+        if seek_position and streamdetails.size and streamdetails.duration:
             chunk_count = ceil(streamdetails.size / 2048)
             skip_chunks = int(chunk_count / streamdetails.duration) * seek_position
         else:
@@ -511,46 +522,48 @@ class DeezerProvider(MusicProvider):
             metadata.preview = track.preview
         if hasattr(track, "explicit_lyrics"):
             metadata.explicit = track.explicit_lyrics
-        if hasattr(track, "duration"):
-            metadata.duration = track.duration
         if hasattr(track, "rank"):
             metadata.popularity = track.rank
         if hasattr(track, "album") and hasattr(track.album, "cover_big"):
-            metadata.images = [
+            metadata.add_image(
                 MediaItemImage(
                     type=ImageType.THUMB,
                     path=track.album.cover_big,
                     provider=self.lookup_key,
                     remotely_accessible=True,
                 )
-            ]
+            )
         return metadata
 
     def parse_metadata_album(self, album: deezer.Album) -> MediaItemMetadata:
         """Parse the album metadata."""
         return MediaItemMetadata(
             explicit=album.explicit_lyrics,
-            images=[
-                MediaItemImage(
-                    type=ImageType.THUMB,
-                    path=album.cover_big,
-                    provider=self.lookup_key,
-                    remotely_accessible=True,
-                )
-            ],
+            images=UniqueList(
+                [
+                    MediaItemImage(
+                        type=ImageType.THUMB,
+                        path=album.cover_big,
+                        provider=self.lookup_key,
+                        remotely_accessible=True,
+                    )
+                ]
+            ),
         )
 
     def parse_metadata_artist(self, artist: deezer.Artist) -> MediaItemMetadata:
         """Parse the artist metadata."""
         return MediaItemMetadata(
-            images=[
-                MediaItemImage(
-                    type=ImageType.THUMB,
-                    path=artist.picture_big,
-                    provider=self.lookup_key,
-                    remotely_accessible=True,
-                )
-            ],
+            images=UniqueList(
+                [
+                    MediaItemImage(
+                        type=ImageType.THUMB,
+                        path=artist.picture_big,
+                        provider=self.lookup_key,
+                        remotely_accessible=True,
+                    )
+                ]
+            ),
         )
 
     ### PARSING FUNCTIONS ###
@@ -566,7 +579,7 @@ class DeezerProvider(MusicProvider):
                     item_id=str(artist.id),
                     provider_domain=self.domain,
                     provider_instance=self.instance_id,
-                    url=artist.link,
+                    url=getattr(artist, "link", None),  # Sometimes the API doesn't return a link
                 )
             },
             metadata=self.parse_metadata_artist(artist=artist),
@@ -579,21 +592,23 @@ class DeezerProvider(MusicProvider):
             item_id=str(album.id),
             provider=self.lookup_key,
             name=album.title,
-            artists=[
-                ItemMapping(
-                    media_type=MediaType.ARTIST,
-                    item_id=str(album.artist.id),
-                    provider=self.lookup_key,
-                    name=album.artist.name,
-                )
-            ],
+            artists=UniqueList(
+                [
+                    ItemMapping(
+                        media_type=MediaType.ARTIST,
+                        item_id=str(album.artist.id),
+                        provider=self.lookup_key,
+                        name=album.artist.name,
+                    )
+                ]
+            ),
             media_type=MediaType.ALBUM,
             provider_mappings={
                 ProviderMapping(
                     item_id=str(album.id),
                     provider_domain=self.domain,
                     provider_instance=self.instance_id,
-                    url=album.link,
+                    url=getattr(album, "link", None),
                 )
             },
             metadata=self.parse_metadata_album(album=album),
@@ -613,25 +628,27 @@ class DeezerProvider(MusicProvider):
                     item_id=str(playlist.id),
                     provider_domain=self.domain,
                     provider_instance=self.instance_id,
-                    url=playlist.link,
+                    url=getattr(playlist, "link", None),
                 )
             },
             metadata=MediaItemMetadata(
-                images=[
-                    MediaItemImage(
-                        type=ImageType.THUMB,
-                        path=playlist.picture_big,
-                        provider=self.lookup_key,
-                        remotely_accessible=True,
-                    )
-                ],
+                images=UniqueList(
+                    [
+                        MediaItemImage(
+                            type=ImageType.THUMB,
+                            path=playlist.picture_big,
+                            provider=self.lookup_key,
+                            remotely_accessible=True,
+                        )
+                    ]
+                ),
             ),
             is_editable=is_editable,
             owner=creator.name,
             cache_checksum=playlist.checksum,
         )
 
-    def get_playlist_creator(self, playlist: deezer.Playlist):
+    def get_playlist_creator(self, playlist: deezer.Playlist) -> deezer.User:
         """On playlists, the creator is called creator, elsewhere it's called user."""
         if hasattr(playlist, "creator"):
             return playlist.creator
@@ -664,7 +681,7 @@ class DeezerProvider(MusicProvider):
             name=track.title,
             sort_name=self.get_short_title(track),
             duration=track.duration,
-            artists=[artist] if artist else [],
+            artists=UniqueList([artist]) if artist else UniqueList(),
             album=album,
             provider_mappings={
                 ProviderMapping(
@@ -672,7 +689,7 @@ class DeezerProvider(MusicProvider):
                     provider_domain=self.domain,
                     provider_instance=self.instance_id,
                     available=self.track_available(track=track, user_country=user_country),
-                    url=track.link,
+                    url=getattr(track, "link", None),
                 )
             },
             metadata=self.parse_metadata_track(track=track),
@@ -684,11 +701,11 @@ class DeezerProvider(MusicProvider):
             item.external_ids.add((ExternalID.ISRC, isrc))
         return item
 
-    def get_short_title(self, track: deezer.Track):
+    def get_short_title(self, track: deezer.Track) -> str:
         """Short names only returned, if available."""
         if hasattr(track, "title_short"):
-            return track.title_short
-        return track.title
+            return str(track.title_short)
+        return str(track.title)
 
     def get_album_type(self, album: deezer.Album) -> AlbumType:
         """Read and convert the Deezer album type."""
@@ -758,7 +775,9 @@ class DeezerProvider(MusicProvider):
 
     ### OTHER FUNCTIONS ###
 
-    async def get_track_content_type(self, gw_client: GWClient, track_id: int):
+    async def get_track_content_type(
+        self, gw_client: GWClient, track_id: str
+    ) -> Literal[ContentType.FLAC, ContentType.MP3]:
         """Get a tracks contentType."""
         song_data = await gw_client.get_song_data(track_id)
         if song_data["results"]["FILESIZE_FLAC"]:
@@ -776,12 +795,12 @@ class DeezerProvider(MusicProvider):
             return user_country in track.available_countries
         return True
 
-    def _md5(self, data, data_type="ascii"):
+    def _md5(self, data: str, data_type: str = "ascii") -> str:
         md5sum = hashlib.md5()
         md5sum.update(data.encode(data_type))
         return md5sum.hexdigest()
 
-    def get_blowfish_key(self, track_id):
+    def get_blowfish_key(self, track_id: str) -> str:
         """Get blowfish key to decrypt a chunk of a track."""
         secret = app_var(5)
         id_md5 = self._md5(track_id)
@@ -789,7 +808,7 @@ class DeezerProvider(MusicProvider):
             chr(ord(id_md5[i]) ^ ord(id_md5[i + 16]) ^ ord(secret[i])) for i in range(16)
         )
 
-    def decrypt_chunk(self, chunk, blowfish_key):
+    def decrypt_chunk(self, chunk: bytes, blowfish_key: str) -> bytes:
         """Decrypt a given chunk using the blow fish key."""
         cipher = Blowfish.new(
             blowfish_key.encode("ascii"),
index f5dda1eba097901f69253f76fc84b846959faa0c..8633135341c4352d9759fa9ae3c408419cf0d17e 100644 (file)
@@ -5,9 +5,11 @@ cookie based on the api_token.
 """
 
 import datetime
+from collections.abc import Mapping
 from http.cookies import BaseCookie, Morsel
+from typing import Any, cast
 
-from aiohttp import ClientSession
+from aiohttp import ClientSession, ClientTimeout
 from music_assistant_models.streamdetails import StreamDetails
 from yarl import URL
 
@@ -46,12 +48,10 @@ class GWClient:
         self.session = session
 
     async def _set_cookie(self) -> None:
-        cookie = Morsel()
+        cookie: Morsel[str] = Morsel()
 
         cookie.set("arl", self._arl_token, self._arl_token)
-        cookie.domain = ".deezer.com"
-        cookie.path = "/"
-        cookie.httponly = {"HttpOnly": True}
+        cookie.update({"domain": ".deezer.com", "path": "/", "httponly": "True"})
 
         self.session.cookie_jar.update_cookies(BaseCookie({"arl": cookie}), URL(GW_LIGHT_URL))
 
@@ -84,7 +84,7 @@ class GWClient:
         await self._set_cookie()
         await self._update_user_data()
 
-    async def _get_license(self):
+    async def _get_license(self) -> str | None:
         if (
             self._license_expiration_timestamp
             < (datetime.datetime.now() + datetime.timedelta(days=1)).timestamp()
@@ -93,8 +93,14 @@ class GWClient:
         return self._license
 
     async def _gw_api_call(
-        self, method, use_csrf_token=True, args=None, params=None, http_method="POST", retry=True
-    ):
+        self,
+        method: str,
+        use_csrf_token: bool = True,
+        args: dict[str, Any] | None = None,
+        params: dict[str, Any] | None = None,
+        http_method: str = "POST",
+        retry: bool = True,
+    ) -> dict[str, Any]:
         csrf_token = self._gw_csrf_token if use_csrf_token else "null"
         if params is None:
             params = {}
@@ -103,8 +109,8 @@ class GWClient:
         result = await self.session.request(
             http_method,
             GW_LIGHT_URL,
-            params=parameters,
-            timeout=30,
+            params=cast("Mapping[str, str]", parameters),
+            timeout=ClientTimeout(total=30),
             json=args,
             headers={"User-Agent": USER_AGENT_HEADER},
         )
@@ -119,13 +125,13 @@ class GWClient:
             else:
                 msg = "Failed to call GW-API"
                 raise DeezerGWError(msg, result_json["error"])
-        return result_json
+        return cast("dict[str, Any]", result_json)
 
-    async def get_song_data(self, track_id):
+    async def get_song_data(self, track_id: str) -> dict[str, Any]:
         """Get data such as the track token for a given track."""
         return await self._gw_api_call("song.getData", args={"SNG_ID": track_id})
 
-    async def get_deezer_track_urls(self, track_id):
+    async def get_deezer_track_urls(self, track_id: str) -> tuple[dict[str, Any], dict[str, Any]]:
         """Get the URL for a given track id."""
         dz_license = await self._get_license()
 
@@ -171,7 +177,7 @@ class GWClient:
             msg = "last or current track information must be provided."
             raise DeezerGWError(msg)
 
-        payload = {}
+        payload: dict[str, Any] = {}
 
         if next_track:
             payload["next_media"] = {"media": {"id": next_track, "type": "song"}}