Allow frontend to send base url for auth redirects (#1593)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Thu, 22 Aug 2024 09:51:46 +0000 (11:51 +0200)
committerGitHub <noreply@github.com>
Thu, 22 Aug 2024 09:51:46 +0000 (11:51 +0200)
music_assistant/server/controllers/webserver.py
music_assistant/server/helpers/auth.py
music_assistant/server/providers/spotify/__init__.py

index c3faff21ea935a713323997a9f6a1193c67e440b..725d140773642904a7938d87a23d38a38b281507 100644 (file)
@@ -11,6 +11,7 @@ import asyncio
 import logging
 import os
 import urllib.parse
+from collections.abc import Callable
 from concurrent import futures
 from contextlib import suppress
 from functools import partial
@@ -64,6 +65,7 @@ class WebserverController(CoreController):
             "The built-in webserver that hosts the Music Assistant Websockets API and frontend"
         )
         self.manifest.icon = "web-box"
+        self._auth_callbacks: dict[str, Callable] | None = {}
 
     @property
     def base_url(self) -> str:
@@ -160,6 +162,8 @@ class WebserverController(CoreController):
         routes.append(("GET", "/imageproxy", self.mass.metadata.handle_imageproxy))
         # also host the audio preview service
         routes.append(("GET", "/preview", self.serve_preview_stream))
+        # also host the auth callback service
+        routes.append(("*", "/callback/{session_id}", self._handle_auth_callback))
         # start the webserver
         default_publish_ip = await get_ip()
         if self.mass.running_as_hass_addon:
@@ -208,6 +212,22 @@ class WebserverController(CoreController):
             await resp.write(chunk)
         return resp
 
+    def register_auth_callback(self, session_id: str, handler: Awaitable) -> Callable:
+        """Register a auth callback, returns handler to unregister."""
+        if session_id in self._auth_callbacks:
+            msg = f"Session {session_id} already registered."
+            raise RuntimeError(msg)
+        self._auth_callbacks[session_id] = handler
+
+        def _remove():
+            return self._auth_callbacks.pop(session_id, None)
+
+        return _remove
+
+    def unregister_auth_callback(self, session_id: str) -> None:
+        """Unregister a auth callback from the webserver."""
+        self._auth_callbacks.pop(session_id)
+
     async def _handle_server_info(self, request: web.Request) -> web.Response:
         """Handle request for server info."""
         return web.json_response(self.mass.get_server_info().to_dict())
@@ -227,6 +247,13 @@ class WebserverController(CoreController):
         log_data = await self.mass.get_application_log()
         return web.Response(text=log_data, content_type="text/text")
 
+    async def _handle_auth_callback(self, request: web.Request) -> web.Response:
+        """Handle request for the auth callback."""
+        session_id = request.match_info["session_id"]
+        if handler := self._auth_callbacks.get(session_id):
+            return await handler(request)
+        return web.Response(status=403)
+
 
 class WebsocketClientHandler:
     """Handle an active websocket client connection."""
index 76c47741dccdd01cd3b5e8739397c45045d3d7a7..ce63d9c7ef57e723a1fc6452bce81321e386a487 100644 (file)
@@ -18,28 +18,40 @@ if TYPE_CHECKING:
 class AuthenticationHelper:
     """Context manager helper class for authentication with a forward and redirect URL."""
 
-    def __init__(self, mass: MusicAssistant, session_id: str) -> None:
+    def __init__(
+        self, mass: MusicAssistant, session_id: str, frontend_base_url: str | None = None
+    ) -> None:
         """
         Initialize the Authentication Helper.
 
         Params:
         - url: The URL the user needs to open for authentication.
         - session_id: a unique id for this auth session.
+        - (optional) frontend_base_url: The base URL the frontend is using.
         """
         self.mass = mass
         self.session_id = session_id
+        self.frontend_base_url = frontend_base_url
         self._callback_response: asyncio.Queue[dict[str, str]] = asyncio.Queue(1)
 
     @property
     def callback_url(self) -> str:
         """Return the callback URL."""
+        if self.frontend_base_url:
+            return f"{self.frontend_base_url}/callback/{self.session_id}"
         return f"{self.mass.streams.base_url}/callback/{self.session_id}"
 
     async def __aenter__(self) -> AuthenticationHelper:
         """Enter context manager."""
-        self.mass.streams.register_dynamic_route(
-            f"/callback/{self.session_id}", self._handle_callback, "GET"
-        )
+        if self.frontend_base_url:
+            self.mass.webserver.register_auth_callback(
+                self.session_id,
+                self._handle_callback,
+            )
+        else:
+            self.mass.streams.register_dynamic_route(
+                f"/callback/{self.session_id}", self._handle_callback, "GET"
+            )
         return self
 
     async def __aexit__(
@@ -49,7 +61,12 @@ class AuthenticationHelper:
         exc_tb: TracebackType | None,
     ) -> bool | None:
         """Exit context manager."""
-        self.mass.streams.unregister_dynamic_route(f"/callback/{self.session_id}", "GET")
+        if self.frontend_base_url:
+            self.mass.webserver.unregister_auth_callback(
+                self.session_id,
+            )
+        else:
+            self.mass.streams.unregister_dynamic_route(f"/callback/{self.session_id}", "GET")
 
     async def authenticate(self, auth_url: str, timeout: int = 60) -> dict[str, str]:
         """Start the auth process and return any query params if received on the callback."""
@@ -77,6 +94,7 @@ class AuthenticationHelper:
         <html>
         <body onload="window.close();">
             Authentication completed, you may now close this window.
+            Don't forget to press save in the Music Assistant settings page.
         </body>
         </html>
         """
index 7535ea5dacf6f7f3766e4c4d5943ed5067df5596..905383492402d9ad6f924733244038830d7d1f9a 100644 (file)
@@ -141,7 +141,9 @@ async def get_config_entries(
         import pkce
 
         code_verifier, code_challenge = pkce.generate_pkce_pair()
-        async with AuthenticationHelper(mass, cast(str, values["session_id"])) as auth_helper:
+        async with AuthenticationHelper(
+            mass, cast(str, values["session_id"]), values.get("frontend_base_url")
+        ) as auth_helper:
             params = {
                 "response_type": "code",
                 "client_id": values.get(CONF_CLIENT_ID) or app_var(2),