Use webserver for auth helper (#2170)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Sun, 4 May 2025 22:20:56 +0000 (00:20 +0200)
committerGitHub <noreply@github.com>
Sun, 4 May 2025 22:20:56 +0000 (00:20 +0200)
Always prefer webserver for auth helper and try to dynamically detect/handle reverse proxy/ingress in front of the webserver.

music_assistant/controllers/webserver.py
music_assistant/helpers/auth.py
music_assistant/providers/apple_music/__init__.py

index 78724ee24dda9bb83ea9fc44a650b8e8e67fea7f..2a6a4109fee0d15e4947259fcba0d3498a08f462 100644 (file)
@@ -13,6 +13,7 @@ import os
 import urllib.parse
 from concurrent import futures
 from contextlib import suppress
+from contextvars import ContextVar
 from functools import partial
 from typing import TYPE_CHECKING, Any, Final
 
@@ -47,6 +48,7 @@ CONF_BASE_URL = "base_url"
 CONF_EXPOSE_SERVER = "expose_server"
 MAX_PENDING_MSG = 512
 CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
+_BASE_URL: ContextVar[str] = ContextVar("_BASE_URL", default="")
 
 
 class WebserverController(CoreController):
@@ -70,7 +72,7 @@ class WebserverController(CoreController):
     @property
     def base_url(self) -> str:
         """Return the base_url for the streamserver."""
-        return self._server.base_url
+        return _BASE_URL.get(self._server.base_url)
 
     async def get_config_entries(
         self,
@@ -273,6 +275,12 @@ class WebsocketClientHandler:
         self._handle_task: asyncio.Task | None = None
         self._writer_task: asyncio.Task | None = None
         self._logger = webserver.logger
+        # try to dynamically detect the base_url of a client if proxied or behind Ingress
+        self.base_url: str | None = None
+        if forward_host := request.headers.get("X-Forwarded-Host"):
+            ingress_path = request.headers.get("X-Ingress-Path", "")
+            forward_proto = request.headers.get("X-Forwarded-Proto", request.protocol)
+            self.base_url = f"{forward_proto}://{forward_host}{ingress_path}"
 
     async def disconnect(self) -> None:
         """Disconnect client."""
@@ -357,6 +365,8 @@ class WebsocketClientHandler:
     def _handle_command(self, msg: CommandMessage) -> None:
         """Handle an incoming command from the client."""
         self._logger.debug("Handling command %s", msg.command)
+        if self.base_url:
+            _BASE_URL.set(self.base_url)
 
         # work out handler for the given path/command
         handler = self.mass.command_handlers.get(msg.command)
index 04ae1947aa7710439b0b9a1067c434f1e033a0eb..9f3465a5e1978f15cf5bd852d1c47573dee7f109 100644 (file)
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import asyncio
+import logging
 from types import TracebackType
 from typing import TYPE_CHECKING
 
@@ -13,6 +14,8 @@ from music_assistant_models.errors import LoginFailed
 if TYPE_CHECKING:
     from music_assistant import MusicAssistant
 
+LOGGER = logging.getLogger(__name__)
+
 
 class AuthenticationHelper:
     """Context manager helper class for authentication with a forward and redirect URL."""
@@ -27,18 +30,17 @@ class AuthenticationHelper:
         """
         self.mass = mass
         self.session_id = session_id
+        self._cb_path = f"/callback/{self.session_id}"
         self._callback_response: asyncio.Queue[dict[str, str]] = asyncio.Queue(1)
 
     @property
     def callback_url(self) -> str:
         """Return the callback URL."""
-        return f"{self.mass.streams.base_url}/callback/{self.session_id}"
+        return f"{self.mass.webserver.base_url}{self._cb_path}"
 
     async def __aenter__(self) -> AuthenticationHelper:
         """Enter context manager."""
-        self.mass.streams.register_dynamic_route(
-            f"/callback/{self.session_id}", self._handle_callback, "GET"
-        )
+        self.mass.webserver.register_dynamic_route(self._cb_path, self._handle_callback, "GET")
         return self
 
     async def __aexit__(
@@ -48,11 +50,12 @@ class AuthenticationHelper:
         exc_tb: TracebackType | None,
     ) -> bool | None:
         """Exit context manager."""
-        self.mass.streams.unregister_dynamic_route(f"/callback/{self.session_id}", "GET")
+        self.mass.webserver.unregister_dynamic_route(self._cb_path, "GET")
 
     async def authenticate(self, auth_url: str, timeout: int = 60) -> dict[str, str]:
         """Start the auth process and return any query params if received on the callback."""
         self.send_url(auth_url)
+        LOGGER.debug("Waiting for authentication callback on %s", self.callback_url)
         return await self.wait_for_callback(timeout)
 
     def send_url(self, auth_url: str) -> None:
@@ -72,6 +75,7 @@ class AuthenticationHelper:
         """Handle callback response."""
         params = dict(request.query)
         await self._callback_response.put(params)
+        LOGGER.debug("Received callback with params: %s", params)
         return_html = """
         <html>
         <body onload="window.close();">
index f7ce19b8da430f6661583cd732cd9c1f18d34a5b..d57d9efd5173c665a2a31caad2d4eac24954dbdf 100644 (file)
@@ -129,7 +129,8 @@ async def get_config_entries(
     if action == "CONF_ACTION_AUTH":
         # TODO: check the developer token is valid otherwise user is going to have bad experience
         async with AuthenticationHelper(mass, values["session_id"]) as auth_helper:
-            flow_base_url = f"apple_music_auth/{values['session_id']}/"
+            callback_url = auth_helper.callback_url
+            flow_base_path = f"apple_music_auth/{values['session_id']}/"
             flow_timeout = 600
             parent_file_path = pathlib.Path(__file__).parent.resolve()
 
@@ -144,17 +145,20 @@ async def get_config_entries(
             async def serve_mk_glue(request: web.Request) -> web.Response:
                 return_html = f"const app_token='{values[CONF_MUSIC_APP_TOKEN]}';"
                 return_html += f"const user_token='{values[CONF_MUSIC_USER_TOKEN]}';"
-                return_html += f"const return_url='{auth_helper.callback_url}';"
+                return_html += f"const return_url='{callback_url}';"
                 return_html += f"const flow_timeout={flow_timeout - 10};"
                 return_html += f"const mass_buid='{mass.version}';"
                 return web.Response(body=return_html, headers={"content-type": "text/javascript"})
 
-            mass.webserver.register_dynamic_route(f"/{flow_base_url}index.html", serve_mk_auth_page)
-            mass.webserver.register_dynamic_route(f"/{flow_base_url}index.css", serve_mk_auth_css)
-            mass.webserver.register_dynamic_route(f"/{flow_base_url}index.js", serve_mk_glue)
+            mass.webserver.register_dynamic_route(
+                f"/{flow_base_path}index.html", serve_mk_auth_page
+            )
+            mass.webserver.register_dynamic_route(f"/{flow_base_path}index.css", serve_mk_auth_css)
+            mass.webserver.register_dynamic_route(f"/{flow_base_path}index.js", serve_mk_glue)
+            flow_base_url = f"{mass.webserver.base_url}/{flow_base_path}index.html"
             try:
                 values[CONF_MUSIC_USER_TOKEN] = (
-                    await auth_helper.authenticate(f"{flow_base_url}index.html", flow_timeout)
+                    await auth_helper.authenticate(flow_base_url, flow_timeout)
                 )["music-user-token"]
             except KeyError:
                 # no music-user-token URL param was found so user probably cancelled the auth
@@ -162,9 +166,9 @@ async def get_config_entries(
             except Exception as error:
                 raise LoginFailed(f"Failed to authenticate with Apple '{error}'.")
             finally:
-                mass.webserver.unregister_dynamic_route(f"/{flow_base_url}index.html")
-                mass.webserver.unregister_dynamic_route(f"/{flow_base_url}index.css")
-                mass.webserver.unregister_dynamic_route(f"/{flow_base_url}index.js")
+                mass.webserver.unregister_dynamic_route(f"/{flow_base_path}index.html")
+                mass.webserver.unregister_dynamic_route(f"/{flow_base_path}index.css")
+                mass.webserver.unregister_dynamic_route(f"/{flow_base_path}index.js")
 
     # ruff: noqa: ARG001
     return (