Tidal PKCE login (#1509)
authorJozef Kruszynski <60214390+jozefKruszynski@users.noreply.github.com>
Sun, 21 Jul 2024 23:58:56 +0000 (01:58 +0200)
committerGitHub <noreply@github.com>
Sun, 21 Jul 2024 23:58:56 +0000 (01:58 +0200)
music_assistant/server/providers/tidal/__init__.py
music_assistant/server/providers/tidal/helpers.py

index e9018d2aa31d47a7bd1bc5f376c56ea183852de2..63a18446885223126eaadde958493d90942bf591 100644 (file)
@@ -3,15 +3,17 @@
 from __future__ import annotations
 
 import asyncio
+import base64
+import pickle
 from contextlib import suppress
 from datetime import datetime, timedelta
+from enum import StrEnum
 from typing import TYPE_CHECKING, Any, cast
 
 from tidalapi import Album as TidalAlbum
 from tidalapi import Artist as TidalArtist
 from tidalapi import Config as TidalConfig
 from tidalapi import Playlist as TidalPlaylist
-from tidalapi import Quality as TidalQuality
 from tidalapi import Session as TidalSession
 from tidalapi import Track as TidalTrack
 from tidalapi import exceptions as tidal_exceptions
@@ -66,8 +68,8 @@ from .helpers import (
     get_playlist,
     get_playlist_tracks,
     get_similar_tracks,
+    get_stream,
     get_track,
-    get_track_url,
     library_items_add_remove,
     remove_playlist_tracks,
     search,
@@ -77,6 +79,7 @@ if TYPE_CHECKING:
     from collections.abc import AsyncGenerator, Awaitable, Callable
 
     from tidalapi.media import Lyrics as TidalLyrics
+    from tidalapi.media import Stream as TidalStream
 
     from music_assistant.common.models.config_entries import ProviderConfig
     from music_assistant.common.models.provider import ProviderManifest
@@ -84,17 +87,39 @@ if TYPE_CHECKING:
     from music_assistant.server.models import ProviderInstanceType
 
 TOKEN_TYPE = "Bearer"
-CONF_ACTION_AUTH = "auth"
+
+# Actions
+CONF_ACTION_START_PKCE_LOGIN = "start_pkce_login"
+CONF_ACTION_COMPLETE_PKCE_LOGIN = "auth"
+CONF_ACTION_CLEAR_AUTH = "clear_auth"
+
+# Intermediate steps
+CONF_TEMP_SESSION = "temp_session"
+CONF_OOPS_URL = "oops_url"
+
+# Config keys
 CONF_AUTH_TOKEN = "auth_token"
 CONF_REFRESH_TOKEN = "refresh_token"
 CONF_USER_ID = "user_id"
 CONF_EXPIRY_TIME = "expiry_time"
 CONF_QUALITY = "quality"
 
+# Labels
+LABEL_START_PKCE_LOGIN = "start_pkce_login_label"
+LABEL_OOPS_URL = "oops_url_label"
+LABEL_COMPLETE_PKCE_LOGIN = "complete_pkce_login_label"
+
 BROWSE_URL = "https://tidal.com/browse"
 RESOURCES_URL = "https://resources.tidal.com/images"
 
 
+class TidalQualityEnum(StrEnum):
+    """Enum for Tidal Quality."""
+
+    HIGH_LOSSLESS = "HIGH_LOSSLESS"  # "High - 16bit, 44.1kHz"
+    HI_RES = "HI_RES"  # "Max - Up to 24bit, 192kHz"
+
+
 async def setup(
     mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig
 ) -> ProviderInstanceType:
@@ -104,15 +129,31 @@ async def setup(
     return prov
 
 
-async def tidal_code_login(auth_helper: AuthenticationHelper, quality: str) -> TidalSession:
-    """Async wrapper around the tidalapi Session function."""
+async def tidal_auth_url(auth_helper: AuthenticationHelper, quality: str) -> str:
+    """Generate the Tidal authentication URL."""
 
     def inner() -> TidalSession:
+        # global glob_temp_session
         config = TidalConfig(quality=quality, item_limit=10000, alac=False)
         session = TidalSession(config=config)
-        login, future = session.login_oauth()
-        auth_helper.send_url(f"https://{login.verification_uri_complete}")
-        future.result()
+        url = session.pkce_login_url()
+        auth_helper.send_url(url)
+        session_bytes = pickle.dumps(session)
+        base64_bytes = base64.b64encode(session_bytes)
+        return base64_bytes.decode("utf-8")
+
+    return await asyncio.to_thread(inner)
+
+
+async def tidal_pkce_login(base64_session: str, url: str) -> TidalSession:
+    """Async wrapper around the tidalapi Session function."""
+
+    def inner() -> TidalSession:
+        base64_bytes = base64_session.encode("utf-8")
+        message_bytes = base64.b64decode(base64_bytes)
+        session = pickle.loads(message_bytes)  # noqa: S301
+        token = session.pkce_get_auth_token(url_redirect=url)
+        session.process_auth_token(token)
         return session
 
     return await asyncio.to_thread(inner)
@@ -131,54 +172,133 @@ async def get_config_entries(
     action: [optional] action key called from config entries UI.
     values: the (intermediate) raw values for config entries sent with the action.
     """
-    # config flow auth action/step (authenticate button clicked)
-    if action == CONF_ACTION_AUTH:
+    if action == CONF_ACTION_START_PKCE_LOGIN:
         async with AuthenticationHelper(mass, cast(str, values["session_id"])) as auth_helper:
-            quality: str | int | float | list[str] | list[int] | None = (
-                values.get(CONF_QUALITY) if values else None
-            )
-            tidal_session = await tidal_code_login(auth_helper, cast(str, quality))
-            if not tidal_session.check_login():
-                msg = "Authentication to Tidal failed"
-                raise LoginFailed(msg)
-            # set the retrieved token on the values object to pass along
-            values[CONF_AUTH_TOKEN] = tidal_session.access_token
-            values[CONF_REFRESH_TOKEN] = tidal_session.refresh_token
-            values[CONF_EXPIRY_TIME] = tidal_session.expiry_time.isoformat()
-            values[CONF_USER_ID] = str(tidal_session.user.id)
-
-    # config flow auth action/step to pick the library to use
-    # because this call is very slow, we only show/calculate the dropdown if we do
-    # not yet have this info or we/user invalidated it.
+            quality: str = values.get(CONF_QUALITY) if values else None
+            base64_session = await tidal_auth_url(auth_helper, cast(str, quality))
+            values[CONF_TEMP_SESSION] = base64_session
+
+    if action == CONF_ACTION_COMPLETE_PKCE_LOGIN:
+        quality: str = values.get(CONF_QUALITY) if values else None
+        pkce_url: str = values.get(CONF_OOPS_URL) if values else None
+        base64_session = values.get(CONF_TEMP_SESSION) if values else None
+        tidal_session = await tidal_pkce_login(base64_session, pkce_url)
+        if not tidal_session.check_login():
+            msg = "Authentication to Tidal failed"
+            raise LoginFailed(msg)
+        # set the retrieved token on the values object to pass along
+        values[CONF_AUTH_TOKEN] = tidal_session.access_token
+        values[CONF_REFRESH_TOKEN] = tidal_session.refresh_token
+        values[CONF_EXPIRY_TIME] = tidal_session.expiry_time.isoformat()
+        values[CONF_USER_ID] = str(tidal_session.user.id)
+        values[CONF_TEMP_SESSION] = ""
+
+    if action == CONF_ACTION_CLEAR_AUTH:
+        values[CONF_AUTH_TOKEN] = None
+
+    if values.get(CONF_AUTH_TOKEN):
+        auth_entries = (
+            ConfigEntry(
+                key=CONF_ACTION_CLEAR_AUTH,
+                type=ConfigEntryType.ACTION,
+                label="Reset authentication",
+                description="Reset the authentication for Tidal",
+                action=CONF_ACTION_CLEAR_AUTH,
+                value=None,
+            ),
+            ConfigEntry(
+                key=CONF_QUALITY,
+                type=ConfigEntryType.STRING,
+                label=CONF_QUALITY,
+                required=True,
+                hidden=True,
+                default_value=values.get(CONF_QUALITY, TidalQualityEnum.HI_RES.value),
+                value=values.get(CONF_QUALITY),
+            ),
+        )
+    else:
+        auth_entries = (
+            ConfigEntry(
+                key=CONF_QUALITY,
+                type=ConfigEntryType.STRING,
+                label="Quality setting for Tidal:",
+                required=True,
+                description="HIGH_LOSSLESS = 16bit 44.1kHz, HI_RES = Up to 24bit 192kHz",
+                options=tuple(ConfigValueOption(x.value, x.name) for x in TidalQualityEnum),
+                default_value=TidalQualityEnum.HI_RES.value,
+                value=values.get(CONF_QUALITY) if values else None,
+            ),
+            ConfigEntry(
+                key=LABEL_START_PKCE_LOGIN,
+                type=ConfigEntryType.LABEL,
+                label="The button below will redirect you to Tidal.com to authenticate."
+                " After authenticating, you will be redirected to a page that prominently displays"
+                " 'Oops' at the top.",
+            ),
+            ConfigEntry(
+                key=CONF_ACTION_START_PKCE_LOGIN,
+                type=ConfigEntryType.ACTION,
+                label="Starts the auth process via PKCE on Tidal.com",
+                description="This button will redirect you to Tidal.com to authenticate."
+                " After authenticating, you will be redirected to a page that prominently displays"
+                " 'Oops' at the top.",
+                action=CONF_ACTION_START_PKCE_LOGIN,
+                depends_on=CONF_QUALITY,
+                action_label="Starts the auth process via PKCE on Tidal.com",
+                value=values.get(CONF_TEMP_SESSION) if values else None,
+            ),
+            ConfigEntry(
+                key=CONF_TEMP_SESSION,
+                type=ConfigEntryType.STRING,
+                label="Temporary session for Tidal",
+                hidden=True,
+                required=False,
+                value=values.get(CONF_TEMP_SESSION) if values else None,
+            ),
+            ConfigEntry(
+                key=LABEL_OOPS_URL,
+                type=ConfigEntryType.LABEL,
+                label="Copy the URL from the 'Oops' page that you were previously redirected to"
+                " and paste it in the field below",
+            ),
+            ConfigEntry(
+                key=CONF_OOPS_URL,
+                type=ConfigEntryType.STRING,
+                label="Oops URL from Tidal redirect",
+                description="This field should be filled manually by you after authenticating on"
+                " Tidal.com and being redirected to a page that prominently displays"
+                " 'Oops' at the top.",
+                depends_on=CONF_ACTION_START_PKCE_LOGIN,
+                value=values.get(CONF_OOPS_URL) if values else None,
+            ),
+            ConfigEntry(
+                key=LABEL_COMPLETE_PKCE_LOGIN,
+                type=ConfigEntryType.LABEL,
+                label="After pasting the URL in the field above, click the button below to complete"
+                " the process.",
+            ),
+            ConfigEntry(
+                key=CONF_ACTION_COMPLETE_PKCE_LOGIN,
+                type=ConfigEntryType.ACTION,
+                label="Complete the auth process via PKCE on Tidal.com",
+                description="Click this after adding the 'Oops' URL above, this will complete the"
+                " authentication process.",
+                action=CONF_ACTION_COMPLETE_PKCE_LOGIN,
+                depends_on=CONF_OOPS_URL,
+                action_label="Complete the auth process via PKCE on Tidal.com",
+                value=None,
+            ),
+        )
 
     # return the collected config entries
     return (
-        ConfigEntry(
-            key=CONF_QUALITY,
-            type=ConfigEntryType.STRING,
-            label="Quality",
-            required=True,
-            description="The Tidal Quality you wish to use",
-            options=(
-                ConfigValueOption(title=TidalQuality.low_96k, value=TidalQuality.low_96k),
-                ConfigValueOption(title=TidalQuality.low_320k, value=TidalQuality.low_320k),
-                ConfigValueOption(
-                    title=TidalQuality.high_lossless,
-                    value=TidalQuality.high_lossless,
-                ),
-                ConfigValueOption(title=TidalQuality.hi_res, value=TidalQuality.hi_res),
-            ),
-            default_value=TidalQuality.high_lossless,
-            value=values.get(CONF_QUALITY) if values else None,
-        ),
+        *auth_entries,
         ConfigEntry(
             key=CONF_AUTH_TOKEN,
             type=ConfigEntryType.SECURE_STRING,
             label="Authentication token for Tidal",
             description="You need to link Music Assistant to your Tidal account.",
-            action=CONF_ACTION_AUTH,
-            depends_on=CONF_QUALITY,
-            action_label="Authenticate on Tidal.com",
+            hidden=True,
             value=values.get(CONF_AUTH_TOKEN) if values else None,
         ),
         ConfigEntry(
@@ -429,22 +549,28 @@ class TidalProvider(MusicProvider):
 
     async def get_stream_details(self, item_id: str) -> StreamDetails:
         """Return the content details for the given track when it will be streamed."""
-        # make sure a valid track is requested.
         tidal_session = await self._get_tidal_session()
-        track = await get_track(tidal_session, item_id)
-        url = await get_track_url(tidal_session, item_id)
-        media_info = await self._get_media_info(item_id=item_id, url=url)
-        if not track:
+        # make sure a valid track is requested.
+        if not (track := await get_track(tidal_session, item_id)):
             msg = f"track {item_id} not found"
             raise MediaNotFoundError(msg)
+        stream: TidalStream = await get_stream(track)
+        manifest = stream.get_stream_manifest()
+        if manifest.is_MPD:
+            # for mpeg-dash streams we just pass the complete base64 manifest
+            url = f"data:application/dash+xml;base64,{manifest.manifest}"
+        else:
+            # as far as I can oversee a BTS stream is just a single URL
+            url = manifest.urls[0]
+
         return StreamDetails(
             item_id=track.id,
             provider=self.instance_id,
             audio_format=AudioFormat(
-                content_type=ContentType.try_parse(media_info.format),
-                sample_rate=media_info.sample_rate,
-                bit_depth=media_info.bits_per_sample,
-                channels=media_info.channels,
+                content_type=ContentType.try_parse(manifest.codecs),
+                sample_rate=manifest.sample_rate,
+                bit_depth=stream.bit_depth,
+                channels=2,
             ),
             stream_type=StreamType.HTTP,
             duration=track.duration,
@@ -548,7 +674,13 @@ class TidalProvider(MusicProvider):
         def inner() -> TidalSession:
             config = TidalConfig(quality=quality, item_limit=10000, alac=False)
             session = TidalSession(config=config)
-            session.load_oauth_session(token_type, access_token, refresh_token, expiry_time)
+            session.load_oauth_session(
+                token_type=token_type,
+                access_token=access_token,
+                refresh_token=refresh_token,
+                expiry_time=expiry_time,
+                is_pkce=True,
+            )
             return session
 
         return await asyncio.to_thread(inner)
index 035afd23be9ec562cefa6cb82c19109148ea71d2..e5b21cc92b53ae72b708f3df3f5cbcf8b4b064e7 100644 (file)
@@ -21,6 +21,7 @@ from tidalapi import Session as TidalSession
 from tidalapi import Track as TidalTrack
 from tidalapi import UserPlaylist as TidalUserPlaylist
 from tidalapi.exceptions import MetadataNotAvailable, ObjectNotFound, TooManyRequests
+from tidalapi.media import Stream as TidalStream
 
 from music_assistant.common.models.enums import MediaType
 from music_assistant.common.models.errors import (
@@ -186,6 +187,22 @@ async def get_track(session: TidalSession, prov_track_id: str) -> TidalTrack:
     return await asyncio.to_thread(inner)
 
 
+async def get_stream(track: TidalTrack) -> TidalStream:
+    """Async wrapper around the tidalapi Track.get_stream_url function."""
+
+    def inner() -> str:
+        try:
+            return track.get_stream()
+        except ObjectNotFound as err:
+            msg = f"Track {track.id} has no available stream"
+            raise MediaNotFoundError(msg) from err
+        except TooManyRequests:
+            msg = "Tidal API rate limit reached"
+            raise ResourceTemporarilyUnavailable(msg)
+
+    return await asyncio.to_thread(inner)
+
+
 async def get_track_url(session: TidalSession, prov_track_id: str) -> str:
     """Async wrapper around the tidalapi Track.get_url function."""