From: Marcel van der Veldt Date: Wed, 26 Nov 2025 15:36:19 +0000 (+0100) Subject: Add (mandatory) authentication to the webserver (#2684) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=a4585c516c079d4495aa0331d05d27e818974cab;p=music-assistant-server.git Add (mandatory) authentication to the webserver (#2684) --- diff --git a/music_assistant/constants.py b/music_assistant/constants.py index 85f34f77..01b545fc 100644 --- a/music_assistant/constants.py +++ b/music_assistant/constants.py @@ -14,12 +14,15 @@ from music_assistant_models.media_items import AudioFormat APPLICATION_NAME: Final = "Music Assistant" -API_SCHEMA_VERSION: Final[int] = 27 -MIN_SCHEMA_VERSION: Final[int] = 24 +API_SCHEMA_VERSION: Final[int] = 28 +MIN_SCHEMA_VERSION: Final[int] = 28 MASS_LOGGER_NAME: Final[str] = "music_assistant" +# Home Assistant system user +HOMEASSISTANT_SYSTEM_USER: Final[str] = "homeassistant_system" + UNKNOWN_ARTIST: Final[str] = "[unknown]" UNKNOWN_ARTIST_ID_MBID: Final[str] = "125ec42a-7229-4250-afc5-e057484327fe" VARIOUS_ARTISTS_NAME: Final[str] = "Various Artists" @@ -97,6 +100,7 @@ CONF_SMART_FADES_MODE: Final[str] = "smart_fades_mode" CONF_USE_SSL: Final[str] = "use_ssl" CONF_VERIFY_SSL: Final[str] = "verify_ssl" CONF_SSL_FINGERPRINT: Final[str] = "ssl_fingerprint" +CONF_AUTH_ALLOW_SELF_REGISTRATION: Final[str] = "auth_allow_self_registration" # config default values diff --git a/music_assistant/controllers/config.py b/music_assistant/controllers/config.py index 95248658..23d205c2 100644 --- a/music_assistant/controllers/config.py +++ b/music_assistant/controllers/config.py @@ -407,7 +407,7 @@ class ConfigController: ), ] - @api_command("config/providers/save") + @api_command("config/providers/save", required_role="admin") async def save_provider_config( self, provider_domain: str, @@ -431,7 +431,7 @@ class ConfigController: # return full config, just in case return await self.get_provider_config(config.instance_id) - @api_command("config/providers/remove") + @api_command("config/providers/remove", required_role="admin") async def remove_provider_config(self, instance_id: str) -> None: """Remove ProviderConfig.""" conf_key = f"{CONF_PROVIDERS}/{instance_id}" @@ -659,7 +659,7 @@ class ConfigController: } return cast("PlayerConfig", PlayerConfig.parse([], raw_conf)) - @api_command("config/players/save") + @api_command("config/players/save", required_role="admin") async def save_player_config( self, player_id: str, values: dict[str, ConfigValueType] ) -> PlayerConfig: @@ -683,7 +683,7 @@ class ConfigController: # return full player config (just in case) return await self.get_player_config(player_id) - @api_command("config/players/remove") + @api_command("config/players/remove", required_role="admin") async def remove_player_config(self, player_id: str) -> None: """Remove PlayerConfig.""" conf_key = f"{CONF_PLAYERS}/{player_id}" @@ -771,7 +771,7 @@ class ConfigController: return dsp_config - @api_command("config/players/dsp/save") + @api_command("config/players/dsp/save", required_role="admin") async def save_dsp_config(self, player_id: str, config: DSPConfig) -> DSPConfig: """ Save/update DSPConfig for a player. @@ -798,7 +798,7 @@ class ConfigController: raw_presets = self.get(CONF_PLAYER_DSP_PRESETS, {}) return [DSPConfigPreset.from_dict(preset) for preset in raw_presets.values()] - @api_command("config/dsp_presets/save") + @api_command("config/dsp_presets/save", required_role="admin") async def save_dsp_presets(self, preset: DSPConfigPreset) -> DSPConfigPreset: """ Save/update a user-defined DSP presets. @@ -823,7 +823,7 @@ class ConfigController: return preset - @api_command("config/dsp_presets/remove") + @api_command("config/dsp_presets/remove", required_role="admin") async def remove_dsp_preset(self, preset_id: str) -> None: """Remove a user-defined DSP preset.""" self.mass.config.remove(f"{CONF_PLAYER_DSP_PRESETS}/preset_{preset_id}") @@ -914,7 +914,7 @@ class ConfigController: conf_key = f"{CONF_PROVIDERS}/{default_config.instance_id}" self.set(conf_key, default_config.to_raw()) - @api_command("config/core") + @api_command("config/core", required_role="admin") async def get_core_configs(self, include_values: bool = False) -> list[CoreConfig]: """Return all core controllers config options.""" return [ @@ -1005,7 +1005,7 @@ class ConfigController: domain: str, action: str | None = None, values: dict[str, ConfigValueType] | None = None, - ) -> tuple[ConfigEntry, ...]: + ) -> list[ConfigEntry]: """ Return Config entries to configure a core controller. @@ -1016,12 +1016,12 @@ class ConfigController: if values is None: values = self.get(f"{CONF_CORE}/{domain}/values", {}) controller: CoreController = getattr(self.mass, domain) - return ( + return list( await controller.get_config_entries(action=action, values=values) + DEFAULT_CORE_CONFIG_ENTRIES ) - @api_command("config/core/save") + @api_command("config/core/save", required_role="admin") async def save_core_config( self, domain: str, @@ -1249,15 +1249,6 @@ class ConfigController: ] changed = True - # set 'onboard_done' flag if we have any (non default) provider configs - if self._data.get(CONF_ONBOARD_DONE) is None: - default_providers = {x.domain for x in self.mass.get_provider_manifests() if x.builtin} - for provider_config in self._data.get(CONF_PROVIDERS, {}).values(): - if provider_config["domain"] not in default_providers: - self._data[CONF_ONBOARD_DONE] = True - changed = True - break - # migrate player_group entries ugp_found = False for player_config in self._data.get(CONF_PLAYERS, {}).values(): @@ -1343,7 +1334,7 @@ class ConfigController: await _file.write(await async_json_dumps(self._data, indent=True)) LOGGER.debug("Saved data to persistent storage") - @api_command("config/providers/reload") + @api_command("config/providers/reload", required_role="admin") async def _reload_provider(self, instance_id: str) -> None: """Reload provider.""" try: diff --git a/music_assistant/controllers/media/base.py b/music_assistant/controllers/media/base.py index a63e0875..ffe6fb8b 100644 --- a/music_assistant/controllers/media/base.py +++ b/music_assistant/controllers/media/base.py @@ -10,10 +10,7 @@ from contextlib import suppress from typing import TYPE_CHECKING, Any, TypeVar, cast from music_assistant_models.enums import EventType, ExternalID, MediaType, ProviderFeature -from music_assistant_models.errors import ( - MediaNotFoundError, - ProviderUnavailableError, -) +from music_assistant_models.errors import MediaNotFoundError, ProviderUnavailableError from music_assistant_models.media_items import ItemMapping, MediaItemType, ProviderMapping, Track from music_assistant.constants import DB_TABLE_PLAYLOG, DB_TABLE_PROVIDER_MAPPINGS, MASS_LOGGER_NAME @@ -102,10 +99,16 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): self.mass.register_api_command(f"music/{api_base}/count", self.library_count) self.mass.register_api_command(f"music/{api_base}/library_items", self.library_items) self.mass.register_api_command(f"music/{api_base}/get", self.get) - self.mass.register_api_command(f"music/{api_base}/get_{self.media_type}", self.get) - self.mass.register_api_command(f"music/{api_base}/add", self.add_item_to_library) - self.mass.register_api_command(f"music/{api_base}/update", self.update_item_in_library) - self.mass.register_api_command(f"music/{api_base}/remove", self.remove_item_from_library) + # Backward compatibility alias - prefer the generic "get" endpoint + self.mass.register_api_command( + f"music/{api_base}/get_{self.media_type}", self.get, alias=True + ) + self.mass.register_api_command( + f"music/{api_base}/update", self.update_item_in_library, required_role="admin" + ) + self.mass.register_api_command( + f"music/{api_base}/remove", self.remove_item_from_library, required_role="admin" + ) self._db_add_lock = asyncio.Lock() async def add_item_to_library( diff --git a/music_assistant/controllers/webserver.py b/music_assistant/controllers/webserver.py deleted file mode 100644 index 1f09203c..00000000 --- a/music_assistant/controllers/webserver.py +++ /dev/null @@ -1,544 +0,0 @@ -""" -Controller that manages the builtin webserver that hosts the api and frontend. - -Unlike the streamserver (which is as simple and unprotected as possible), -this webserver allows for more fine grained configuration to better secure it. -""" - -from __future__ import annotations - -import asyncio -import html -import logging -import os -import urllib.parse -from collections.abc import Awaitable, Callable -from concurrent import futures -from contextlib import suppress -from functools import partial -from typing import TYPE_CHECKING, Any, Final, cast - -import aiofiles -from aiohttp import WSMsgType, web -from mashumaro.exceptions import MissingField -from music_assistant_frontend import where as locate_frontend -from music_assistant_models.api import ( - CommandMessage, - ErrorResultMessage, - MessageType, - SuccessResultMessage, -) -from music_assistant_models.config_entries import ConfigEntry, ConfigValueOption -from music_assistant_models.enums import ConfigEntryType -from music_assistant_models.errors import InvalidCommand - -from music_assistant.constants import CONF_BIND_IP, CONF_BIND_PORT, VERBOSE_LOG_LEVEL -from music_assistant.helpers.api import APICommandHandler, parse_arguments -from music_assistant.helpers.api_docs import ( - generate_commands_reference, - generate_openapi_spec, - generate_schemas_reference, -) -from music_assistant.helpers.audio import get_preview_stream -from music_assistant.helpers.json import json_dumps, json_loads -from music_assistant.helpers.util import get_ip_addresses -from music_assistant.helpers.webserver import Webserver -from music_assistant.models.core_controller import CoreController - -if TYPE_CHECKING: - from music_assistant_models.config_entries import ConfigValueType, CoreConfig - from music_assistant_models.event import MassEvent - - from music_assistant import MusicAssistant - -DEFAULT_SERVER_PORT = 8095 -INGRESS_SERVER_PORT = 8094 -CONF_BASE_URL = "base_url" -MAX_PENDING_MSG = 512 -CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError) - - -class WebserverController(CoreController): - """Core Controller that manages the builtin webserver that hosts the api and frontend.""" - - domain: str = "webserver" - - def __init__(self, mass: MusicAssistant) -> None: - """Initialize instance.""" - super().__init__(mass) - self._server = Webserver(self.logger, enable_dynamic_routes=True) - self.register_dynamic_route = self._server.register_dynamic_route - self.unregister_dynamic_route = self._server.unregister_dynamic_route - self.clients: set[WebsocketClientHandler] = set() - self.manifest.name = "Web Server (frontend and api)" - self.manifest.description = ( - "The built-in webserver that hosts the Music Assistant Websockets API and frontend" - ) - self.manifest.icon = "web-box" - - @property - def base_url(self) -> str: - """Return the base_url for the streamserver.""" - return self._server.base_url - - async def get_config_entries( - self, - action: str | None = None, - values: dict[str, ConfigValueType] | None = None, - ) -> tuple[ConfigEntry, ...]: - """Return all Config Entries for this core module (if any).""" - ip_addresses = await get_ip_addresses() - default_publish_ip = ip_addresses[0] - default_base_url = f"http://{default_publish_ip}:{DEFAULT_SERVER_PORT}" - return ( - ConfigEntry( - key="webserver_warn", - type=ConfigEntryType.ALERT, - label="Please note that the webserver is unprotected. " - "Never ever expose the webserver directly to the internet! \n\n" - "Use a reverse proxy or VPN to secure access.", - required=False, - ), - ConfigEntry( - key=CONF_BASE_URL, - type=ConfigEntryType.STRING, - default_value=default_base_url, - label="Base URL", - description="The (base) URL to reach this webserver in the network. \n" - "Override this in advanced scenarios where for example you're running " - "the webserver behind a reverse proxy.", - ), - ConfigEntry( - key=CONF_BIND_PORT, - type=ConfigEntryType.INTEGER, - default_value=DEFAULT_SERVER_PORT, - label="TCP Port", - description="The TCP port to run the webserver.", - ), - ConfigEntry( - key=CONF_BIND_IP, - type=ConfigEntryType.STRING, - default_value="0.0.0.0", - options=[ConfigValueOption(x, x) for x in {"0.0.0.0", *ip_addresses}], - label="Bind to IP/interface", - description="Bind the (web)server to this specific interface. \n" - "Use 0.0.0.0 to bind to all interfaces. \n" - "Set this address for example to a docker-internal network, " - "when you are running a reverse proxy to enhance security and " - "protect outside access to the webinterface and API. \n\n" - "This is an advanced setting that should normally " - "not be adjusted in regular setups.", - category="advanced", - ), - ) - - async def setup(self, config: CoreConfig) -> None: - """Async initialize of module.""" - # work out all routes - routes: list[tuple[str, str, Callable[[web.Request], Awaitable[web.StreamResponse]]]] = [] - # frontend routes - frontend_dir = locate_frontend() - for filename in next(os.walk(frontend_dir))[2]: - if filename.endswith(".py"): - continue - filepath = os.path.join(frontend_dir, filename) - handler = partial(self._server.serve_static, filepath) - routes.append(("GET", f"/{filename}", handler)) - # add index - index_path = os.path.join(frontend_dir, "index.html") - handler = partial(self._server.serve_static, index_path) - routes.append(("GET", "/", handler)) - # add info - routes.append(("GET", "/info", self._handle_server_info)) - # add logging - routes.append(("GET", "/music-assistant.log", self._handle_application_log)) - # add websocket api - routes.append(("GET", "/ws", self._handle_ws_client)) - # also host the image proxy on the webserver - routes.append(("GET", "/imageproxy", self.mass.metadata.handle_imageproxy)) - # also host the audio preview service - routes.append(("GET", "/preview", self.serve_preview_stream)) - # add jsonrpc api - routes.append(("POST", "/api", self._handle_jsonrpc_api_command)) - # add api documentation - routes.append(("GET", "/api-docs", self._handle_api_intro)) - routes.append(("GET", "/api-docs/", self._handle_api_intro)) - routes.append(("GET", "/api-docs/commands", self._handle_commands_reference)) - routes.append(("GET", "/api-docs/commands/", self._handle_commands_reference)) - routes.append(("GET", "/api-docs/schemas", self._handle_schemas_reference)) - routes.append(("GET", "/api-docs/schemas/", self._handle_schemas_reference)) - routes.append(("GET", "/api-docs/openapi.json", self._handle_openapi_spec)) - routes.append(("GET", "/api-docs/swagger", self._handle_swagger_ui)) - routes.append(("GET", "/api-docs/swagger/", self._handle_swagger_ui)) - # start the webserver - all_ip_addresses = await get_ip_addresses() - default_publish_ip = all_ip_addresses[0] - if self.mass.running_as_hass_addon: - # if we're running on the HA supervisor we start an additional TCP site - # on the internal ("172.30.32.) IP for the HA ingress proxy - ingress_host = next( - (x for x in all_ip_addresses if x.startswith("172.30.32.")), default_publish_ip - ) - ingress_tcp_site_params = (ingress_host, INGRESS_SERVER_PORT) - else: - ingress_tcp_site_params = None - base_url = str(config.get_value(CONF_BASE_URL)) - port_value = config.get_value(CONF_BIND_PORT) - assert isinstance(port_value, int) - self.publish_port = port_value - self.publish_ip = default_publish_ip - bind_ip = cast("str | None", config.get_value(CONF_BIND_IP)) - # print a big fat message in the log where the webserver is running - # because this is a common source of issues for people with more complex setups - if not self.mass.config.onboard_done: - self.logger.warning( - "\n\n################################################################################\n" - "Starting webserver on %s:%s - base url: %s\n" - "If this is incorrect, see the documentation how to configure the Webserver\n" - "in Settings --> Core modules --> Webserver\n" - "################################################################################\n", - bind_ip, - self.publish_port, - base_url, - ) - else: - self.logger.info( - "Starting webserver on %s:%s - base url: %s\n#\n", - bind_ip, - self.publish_port, - base_url, - ) - await self._server.setup( - bind_ip=bind_ip, - bind_port=self.publish_port, - base_url=base_url, - static_routes=routes, - # add assets subdir as static_content - static_content=("/assets", os.path.join(frontend_dir, "assets"), "assets"), - ingress_tcp_site_params=ingress_tcp_site_params, - ) - - async def close(self) -> None: - """Cleanup on exit.""" - for client in set(self.clients): - await client.disconnect() - await self._server.close() - - async def serve_preview_stream(self, request: web.Request) -> web.StreamResponse: - """Serve short preview sample.""" - provider_instance_id_or_domain = request.query["provider"] - item_id = urllib.parse.unquote(request.query["item_id"]) - resp = web.StreamResponse(status=200, reason="OK", headers={"Content-Type": "audio/aac"}) - await resp.prepare(request) - async for chunk in get_preview_stream(self.mass, provider_instance_id_or_domain, item_id): - await resp.write(chunk) - return resp - - async def _handle_server_info(self, request: web.Request) -> web.Response: - """Handle request for server info.""" - return web.json_response(self.mass.get_server_info().to_dict()) - - async def _handle_ws_client(self, request: web.Request) -> web.WebSocketResponse: - connection = WebsocketClientHandler(self, request) - if lang := request.headers.get("Accept-Language"): - self.mass.metadata.set_default_preferred_language(lang.split(",")[0]) - try: - self.clients.add(connection) - return await connection.handle_client() - finally: - self.clients.remove(connection) - - async def _handle_jsonrpc_api_command(self, request: web.Request) -> web.Response: - """Handle incoming JSON RPC API command.""" - if not request.can_read_body: - return web.Response(status=400, text="Body required") - cmd_data = await request.read() - self.logger.log(VERBOSE_LOG_LEVEL, "Received on JSONRPC API: %s", cmd_data) - try: - command_msg = CommandMessage.from_json(cmd_data) - except ValueError: - error = f"Invalid JSON: {cmd_data.decode()}" - self.logger.error("Unhandled JSONRPC API error: %s", error) - return web.Response(status=400, text=error) - except MissingField as e: - # be forgiving if message_id is missing - cmd_data_dict = json_loads(cmd_data) - if e.field_name == "message_id" and "command" in cmd_data_dict: - cmd_data_dict["message_id"] = "unknown" - command_msg = CommandMessage.from_dict(cmd_data_dict) - else: - error = f"Missing field in JSON: {e!s}" - self.logger.error("Unhandled JSONRPC API error: %s", error) - return web.Response(status=400, text=error) - - # work out handler for the given path/command - handler = self.mass.command_handlers.get(command_msg.command) - if handler is None: - error = f"Invalid Command: {command_msg.command}" - self.logger.error("Unhandled JSONRPC API error: %s", error) - return web.Response(status=400, text=error) - try: - args = parse_arguments(handler.signature, handler.type_hints, command_msg.args) - result: Any = handler.target(**args) - if hasattr(result, "__anext__"): - # handle async generator (for really large listings) - result = [item async for item in result] - elif asyncio.iscoroutine(result): - result = await result - return web.json_response(result, dumps=json_dumps) - except Exception as e: - # Return clean error message without stacktrace - error_type = type(e).__name__ - error_msg = str(e) - error = f"{error_type}: {error_msg}" - self.logger.error("Error executing command %s: %s", command_msg.command, error) - return web.Response(status=500, text=error) - - async def _handle_application_log(self, request: web.Request) -> web.Response: - """Handle request to get the application log.""" - log_data = await self.mass.get_application_log() - return web.Response(text=log_data, content_type="text/text") - - async def _handle_api_intro(self, request: web.Request) -> web.Response: - """Handle request for API introduction/documentation page.""" - intro_html_path = os.path.join( - os.path.dirname(__file__), "..", "helpers", "resources", "api_docs.html" - ) - # Read the template - async with aiofiles.open(intro_html_path) as f: - html_content = await f.read() - - # Replace placeholders (escape values to prevent XSS) - html_content = html_content.replace("{VERSION}", html.escape(self.mass.version)) - html_content = html_content.replace("{BASE_URL}", html.escape(self.base_url)) - html_content = html_content.replace("{SERVER_HOST}", html.escape(request.host)) - - return web.Response(text=html_content, content_type="text/html") - - async def _handle_openapi_spec(self, request: web.Request) -> web.Response: - """Handle request for OpenAPI specification (generated on-the-fly).""" - spec = generate_openapi_spec( - self.mass.command_handlers, server_url=self.base_url, version=self.mass.version - ) - return web.json_response(spec) - - async def _handle_commands_reference(self, request: web.Request) -> web.Response: - """Handle request for commands reference page (generated on-the-fly).""" - html = generate_commands_reference(self.mass.command_handlers, server_url=self.base_url) - return web.Response(text=html, content_type="text/html") - - async def _handle_schemas_reference(self, request: web.Request) -> web.Response: - """Handle request for schemas reference page (generated on-the-fly).""" - html = generate_schemas_reference(self.mass.command_handlers) - return web.Response(text=html, content_type="text/html") - - async def _handle_swagger_ui(self, request: web.Request) -> web.FileResponse: - """Handle request for Swagger UI.""" - swagger_html_path = os.path.join( - os.path.dirname(__file__), "..", "helpers", "resources", "swagger_ui.html" - ) - return await self._server.serve_static(swagger_html_path, request) - - -class WebsocketClientHandler: - """Handle an active websocket client connection.""" - - def __init__(self, webserver: WebserverController, request: web.Request) -> None: - """Initialize an active connection.""" - self.mass = webserver.mass - self.request = request - self.wsock = web.WebSocketResponse(heartbeat=55) - self._to_write: asyncio.Queue[str | None] = asyncio.Queue(maxsize=MAX_PENDING_MSG) - self._handle_task: asyncio.Task[Any] | None = None - self._writer_task: asyncio.Task[None] | None = None - self._logger = webserver.logger - # try to dynamically detect the base_url of a client if proxied or behind Ingress - self.base_url: str | None = None - if forward_host := request.headers.get("X-Forwarded-Host"): - ingress_path = request.headers.get("X-Ingress-Path", "") - forward_proto = request.headers.get("X-Forwarded-Proto", request.protocol) - self.base_url = f"{forward_proto}://{forward_host}{ingress_path}" - - async def disconnect(self) -> None: - """Disconnect client.""" - self._cancel() - if self._writer_task is not None: - await self._writer_task - - async def handle_client(self) -> web.WebSocketResponse: - """Handle a websocket response.""" - # ruff: noqa: PLR0915 - request = self.request - wsock = self.wsock - try: - async with asyncio.timeout(10): - await wsock.prepare(request) - except TimeoutError: - self._logger.warning("Timeout preparing request from %s", request.remote) - return wsock - - self._logger.log(VERBOSE_LOG_LEVEL, "Connection from %s", request.remote) - self._handle_task = asyncio.current_task() - self._writer_task = self.mass.create_task(self._writer()) - - # send server(version) info when client connects - await self._send_message(self.mass.get_server_info()) - - # forward all events to clients - def handle_event(event: MassEvent) -> None: - self._send_message_sync(event) - - unsub_callback = self.mass.subscribe(handle_event) - - disconnect_warn = None - - try: - while not wsock.closed: - msg = await wsock.receive() - - if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): - break - - if msg.type != WSMsgType.TEXT: - continue - - self._logger.log(VERBOSE_LOG_LEVEL, "Received: %s", msg.data) - - try: - command_msg = CommandMessage.from_json(msg.data) - except ValueError: - disconnect_warn = f"Received invalid JSON: {msg.data}" - break - - await self._handle_command(command_msg) - - except asyncio.CancelledError: - self._logger.debug("Connection closed by client") - - except Exception: - self._logger.exception("Unexpected error inside websocket API") - - finally: - # Handle connection shutting down. - unsub_callback() - self._logger.log(VERBOSE_LOG_LEVEL, "Unsubscribed from events") - - try: - self._to_write.put_nowait(None) - # Make sure all error messages are written before closing - await self._writer_task - await wsock.close() - except asyncio.QueueFull: # can be raised by put_nowait - self._writer_task.cancel() - - finally: - if disconnect_warn is None: - self._logger.log(VERBOSE_LOG_LEVEL, "Disconnected") - else: - self._logger.warning("Disconnected: %s", disconnect_warn) - - return wsock - - async def _handle_command(self, msg: CommandMessage) -> None: - """Handle an incoming command from the client.""" - self._logger.debug("Handling command %s", msg.command) - - # work out handler for the given path/command - handler = self.mass.command_handlers.get(msg.command) - - if handler is None: - await self._send_message( - ErrorResultMessage( - msg.message_id, - InvalidCommand.error_code, - f"Invalid command: {msg.command}", - ) - ) - self._logger.warning("Invalid command: %s", msg.command) - return - - # schedule task to handle the command - self.mass.create_task(self._run_handler(handler, msg)) - - async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None: - try: - args = parse_arguments(handler.signature, handler.type_hints, msg.args) - result: Any = handler.target(**args) - if hasattr(result, "__anext__"): - # handle async generator (for really large listings) - items: list[Any] = [] - async for item in result: - items.append(item) - if len(items) >= 500: - await self._send_message( - SuccessResultMessage(msg.message_id, items, partial=True) - ) - items = [] - result = items - elif asyncio.iscoroutine(result): - result = await result - await self._send_message(SuccessResultMessage(msg.message_id, result)) - except Exception as err: - if self._logger.isEnabledFor(logging.DEBUG): - self._logger.exception("Error handling message: %s", msg) - else: - self._logger.error("Error handling message: %s: %s", msg.command, str(err)) - err_msg = str(err) or err.__class__.__name__ - await self._send_message( - ErrorResultMessage(msg.message_id, getattr(err, "error_code", 999), err_msg) - ) - - async def _writer(self) -> None: - """Write outgoing messages.""" - # Exceptions if Socket disconnected or cancelled by connection handler - with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): - while not self.wsock.closed: - if (process := await self._to_write.get()) is None: - break - - if callable(process): - message: str = process() - else: - message = process - self._logger.log(VERBOSE_LOG_LEVEL, "Writing: %s", message) - await self.wsock.send_str(message) - - async def _send_message(self, message: MessageType) -> None: - """Send a message to the client (for large response messages). - - Runs JSON serialization in executor to avoid blocking for large messages. - Closes connection if the client is not reading the messages. - - Async friendly. - """ - # Run JSON serialization in executor to avoid blocking for large messages - loop = asyncio.get_running_loop() - _message = await loop.run_in_executor(None, message.to_json) - - try: - self._to_write.put_nowait(_message) - except asyncio.QueueFull: - self._logger.error("Client exceeded max pending messages: %s", MAX_PENDING_MSG) - - self._cancel() - - def _send_message_sync(self, message: MessageType) -> None: - """Send a message from a sync context (for small messages like events). - - Serializes inline without executor overhead since events are typically small. - """ - _message = message.to_json() - - try: - self._to_write.put_nowait(_message) - except asyncio.QueueFull: - self._logger.error("Client exceeded max pending messages: %s", MAX_PENDING_MSG) - - self._cancel() - - def _cancel(self) -> None: - """Cancel the connection.""" - if self._handle_task is not None: - self._handle_task.cancel() - if self._writer_task is not None: - self._writer_task.cancel() diff --git a/music_assistant/controllers/webserver/__init__.py b/music_assistant/controllers/webserver/__init__.py new file mode 100644 index 00000000..a8cd31ec --- /dev/null +++ b/music_assistant/controllers/webserver/__init__.py @@ -0,0 +1,10 @@ +"""Webserver Controller for Music Assistant. + +Handles the built-in webserver that hosts the API, frontend, and authentication. +""" + +from __future__ import annotations + +from .controller import WebserverController + +__all__ = ["WebserverController"] diff --git a/music_assistant/controllers/webserver/api_docs.py b/music_assistant/controllers/webserver/api_docs.py new file mode 100644 index 00000000..0c3793d7 --- /dev/null +++ b/music_assistant/controllers/webserver/api_docs.py @@ -0,0 +1,1208 @@ +"""Helpers for generating API documentation and OpenAPI specifications.""" + +from __future__ import annotations + +import collections.abc +import inspect +import re +from collections.abc import Callable +from dataclasses import MISSING +from datetime import datetime +from enum import Enum +from types import NoneType, UnionType +from typing import Any, Union, get_args, get_origin, get_type_hints + +from music_assistant_models.player import Player as PlayerState + +from music_assistant.helpers.api import APICommandHandler + + +def _format_type_name(type_hint: Any) -> str: + """Format a type hint as a user-friendly string, using JSON types instead of Python types.""" + if type_hint is NoneType or type_hint is type(None): + return "null" + + # Handle internal Player model - replace with PlayerState + if hasattr(type_hint, "__name__") and type_hint.__name__ == "Player": + if ( + hasattr(type_hint, "__module__") + and type_hint.__module__ == "music_assistant.models.player" + ): + return "PlayerState" + + # Handle PluginSource - replace with PlayerSource (parent type) + if hasattr(type_hint, "__name__") and type_hint.__name__ == "PluginSource": + if ( + hasattr(type_hint, "__module__") + and type_hint.__module__ == "music_assistant.models.plugin" + ): + return "PlayerSource" + + # Map Python types to JSON types + type_name_mapping = { + "str": "string", + "int": "integer", + "float": "number", + "bool": "boolean", + "dict": "object", + "list": "array", + "tuple": "array", + "set": "array", + "frozenset": "array", + "Sequence": "array", + "UniqueList": "array", + "None": "null", + } + + if hasattr(type_hint, "__name__"): + type_name = str(type_hint.__name__) + return type_name_mapping.get(type_name, type_name) + + type_str = str(type_hint).replace("NoneType", "null") + # Replace Python types with JSON types in complex type strings + for python_type, json_type in type_name_mapping.items(): + type_str = type_str.replace(python_type, json_type) + return type_str + + +def _generate_type_alias_description(type_alias: Any, alias_name: str) -> str: + """Generate a human-readable description of a type alias from its definition. + + :param type_alias: The type alias to describe (e.g., ConfigValueType) + :param alias_name: The name of the alias for display + :return: A human-readable description string + """ + # Get the union args + args = get_args(type_alias) + if not args: + return f"Type alias for {alias_name}." + + # Convert each type to a readable name + type_names = [] + for arg in args: + origin = get_origin(arg) + if origin in (list, tuple): + # Handle list types + inner_args = get_args(arg) + if inner_args: + inner_type = inner_args[0] + if inner_type is bool: + type_names.append("array of boolean") + elif inner_type is int: + type_names.append("array of integer") + elif inner_type is float: + type_names.append("array of number") + elif inner_type is str: + type_names.append("array of string") + else: + type_names.append( + f"array of {getattr(inner_type, '__name__', str(inner_type))}" + ) + else: + type_names.append("array") + elif arg is type(None) or arg is NoneType: + type_names.append("null") + elif arg is bool: + type_names.append("boolean") + elif arg is int: + type_names.append("integer") + elif arg is float: + type_names.append("number") + elif arg is str: + type_names.append("string") + elif hasattr(arg, "__name__"): + type_names.append(arg.__name__) + else: + type_names.append(str(arg)) + + # Format the list nicely + if len(type_names) == 1: + types_str = type_names[0] + elif len(type_names) == 2: + types_str = f"{type_names[0]} or {type_names[1]}" + else: + types_str = f"{', '.join(type_names[:-1])}, or {type_names[-1]}" + + return f"Type alias for {alias_name.lower()} types. Can be {types_str}." + + +def _get_type_schema( # noqa: PLR0911, PLR0915 + type_hint: Any, definitions: dict[str, Any] +) -> dict[str, Any]: + """Convert a Python type hint to an OpenAPI schema.""" + # Check if type_hint matches a type alias that was expanded by get_type_hints() + # Import type aliases to compare against + from music_assistant_models.config_entries import ( # noqa: PLC0415 + ConfigValueType as config_value_type, # noqa: N813 + ) + from music_assistant_models.media_items import ( # noqa: PLC0415 + MediaItemType as media_item_type, # noqa: N813 + ) + + if type_hint == config_value_type: + # This is the expanded ConfigValueType, treat it as the type alias + return _get_type_schema("ConfigValueType", definitions) + if type_hint == media_item_type: + # This is the expanded MediaItemType, treat it as the type alias + return _get_type_schema("MediaItemType", definitions) + + # Handle string type hints from __future__ annotations + if isinstance(type_hint, str): + # Handle simple primitive type names + if type_hint in ("str", "string"): + return {"type": "string"} + if type_hint in ("int", "integer"): + return {"type": "integer"} + if type_hint in ("float", "number"): + return {"type": "number"} + if type_hint in ("bool", "boolean"): + return {"type": "boolean"} + + # Special handling for type aliases - create proper schema definitions + if type_hint == "ConfigValueType": + if "ConfigValueType" not in definitions: + from music_assistant_models.config_entries import ( # noqa: PLC0415 + ConfigValueType as config_value_type, # noqa: N813 + ) + + # Dynamically create oneOf schema with description from the actual type + cvt_args = get_args(config_value_type) + definitions["ConfigValueType"] = { + "description": _generate_type_alias_description( + config_value_type, "configuration value" + ), + "oneOf": [_get_type_schema(arg, definitions) for arg in cvt_args], + } + return {"$ref": "#/components/schemas/ConfigValueType"} + + if type_hint == "MediaItemType": + if "MediaItemType" not in definitions: + from music_assistant_models.media_items import ( # noqa: PLC0415 + MediaItemType as media_item_type, # noqa: N813 + ) + + # Dynamically create oneOf schema with description from the actual type + mit_origin = get_origin(media_item_type) + if mit_origin in (Union, UnionType): + mit_args = get_args(media_item_type) + definitions["MediaItemType"] = { + "description": _generate_type_alias_description( + media_item_type, "media item" + ), + "oneOf": [_get_type_schema(arg, definitions) for arg in mit_args], + } + else: + definitions["MediaItemType"] = _get_type_schema(media_item_type, definitions) + return {"$ref": "#/components/schemas/MediaItemType"} + + # Handle PluginSource - replace with PlayerSource (parent type) + if type_hint == "PluginSource": + return _get_type_schema("PlayerSource", definitions) + + # Check if it looks like a simple class name (no special chars, starts with uppercase) + # Examples: "PlayerType", "DeviceInfo", "PlaybackState" + # Exclude generic types like "Any", "Union", "Optional", etc. + excluded_types = {"Any", "Union", "Optional", "List", "Dict", "Tuple", "Set"} + if type_hint.isidentifier() and type_hint[0].isupper() and type_hint not in excluded_types: + # Create a schema reference for this type + if type_hint not in definitions: + definitions[type_hint] = {"type": "object"} + return {"$ref": f"#/components/schemas/{type_hint}"} + + # If it's "Any", return generic object without creating a schema + if type_hint == "Any": + return {"type": "object"} + + # For complex type expressions like "str | None", "list[str]", return generic object + return {"type": "object"} + + # Handle None type + if type_hint is NoneType or type_hint is type(None): + return {"type": "null"} + + # Handle internal Player model - replace with external PlayerState + if hasattr(type_hint, "__name__") and type_hint.__name__ == "Player": + # Check if this is the internal Player (from music_assistant.models.player) + if ( + hasattr(type_hint, "__module__") + and type_hint.__module__ == "music_assistant.models.player" + ): + # Replace with PlayerState from music_assistant_models + return _get_type_schema(PlayerState, definitions) + + # Handle PluginSource - replace with PlayerSource (parent type) + if hasattr(type_hint, "__name__") and type_hint.__name__ == "PluginSource": + # Check if this is PluginSource from music_assistant.models.plugin + if ( + hasattr(type_hint, "__module__") + and type_hint.__module__ == "music_assistant.models.plugin" + ): + # Replace with PlayerSource from music_assistant.models.player + from music_assistant.models.player import PlayerSource # noqa: PLC0415 + + return _get_type_schema(PlayerSource, definitions) + + # Handle Union types (including Optional) + origin = get_origin(type_hint) + if origin is Union or origin is UnionType: + args = get_args(type_hint) + # Check if it's Optional (Union with None) + non_none_args = [arg for arg in args if arg not in (NoneType, type(None))] + if (len(non_none_args) == 1 and NoneType in args) or type(None) in args: + # It's Optional[T], make it nullable + schema = _get_type_schema(non_none_args[0], definitions) + schema["nullable"] = True + return schema + # It's a union of multiple types + return {"oneOf": [_get_type_schema(arg, definitions) for arg in args]} + + # Handle UniqueList (treat as array) + if hasattr(type_hint, "__name__") and type_hint.__name__ == "UniqueList": + args = get_args(type_hint) + if args: + return {"type": "array", "items": _get_type_schema(args[0], definitions)} + return {"type": "array", "items": {}} + + # Handle Sequence types (from collections.abc or typing) + if origin is collections.abc.Sequence or ( + hasattr(origin, "__name__") and origin.__name__ == "Sequence" + ): + args = get_args(type_hint) + if args: + return {"type": "array", "items": _get_type_schema(args[0], definitions)} + return {"type": "array", "items": {}} + + # Handle set/frozenset types + if origin in (set, frozenset): + args = get_args(type_hint) + if args: + return {"type": "array", "items": _get_type_schema(args[0], definitions)} + return {"type": "array", "items": {}} + + # Handle list/tuple types + if origin in (list, tuple): + args = get_args(type_hint) + if args: + return {"type": "array", "items": _get_type_schema(args[0], definitions)} + return {"type": "array", "items": {}} + + # Handle dict types + if origin is dict: + args = get_args(type_hint) + if len(args) == 2: + return { + "type": "object", + "additionalProperties": _get_type_schema(args[1], definitions), + } + return {"type": "object", "additionalProperties": True} + + # Handle Enum types - add them to definitions as explorable objects + if inspect.isclass(type_hint) and issubclass(type_hint, Enum): + enum_name = type_hint.__name__ + if enum_name not in definitions: + enum_values = [item.value for item in type_hint] + enum_type = type(enum_values[0]).__name__ if enum_values else "string" + openapi_type = { + "str": "string", + "int": "integer", + "float": "number", + "bool": "boolean", + }.get(enum_type, "string") + + # Create a detailed enum definition with descriptions + enum_values_str = ", ".join(str(v) for v in enum_values) + definitions[enum_name] = { + "type": openapi_type, + "enum": enum_values, + "description": f"Enum: {enum_name}. Possible values: {enum_values_str}", + } + return {"$ref": f"#/components/schemas/{enum_name}"} + + # Handle datetime + if type_hint is datetime: + return {"type": "string", "format": "date-time"} + + # Handle primitive types - check both exact type and type name + if type_hint is str or (hasattr(type_hint, "__name__") and type_hint.__name__ == "str"): + return {"type": "string"} + if type_hint is int or (hasattr(type_hint, "__name__") and type_hint.__name__ == "int"): + return {"type": "integer"} + if type_hint is float or (hasattr(type_hint, "__name__") and type_hint.__name__ == "float"): + return {"type": "number"} + if type_hint is bool or (hasattr(type_hint, "__name__") and type_hint.__name__ == "bool"): + return {"type": "boolean"} + + # Handle complex types (dataclasses, models) + # Check for __annotations__ or if it's a class (not already handled above) + if hasattr(type_hint, "__annotations__") or ( + inspect.isclass(type_hint) and not issubclass(type_hint, (str, int, float, bool, Enum)) + ): + type_name = getattr(type_hint, "__name__", str(type_hint)) + # Add to definitions if not already there + if type_name not in definitions: + properties = {} + required = [] + + # Check if this is a dataclass with fields + if hasattr(type_hint, "__dataclass_fields__"): + # Resolve type hints to handle forward references from __future__ annotations + try: + resolved_hints = get_type_hints(type_hint) + except Exception: + resolved_hints = {} + + # Use dataclass fields to get proper info including defaults and metadata + for field_name, field_info in type_hint.__dataclass_fields__.items(): + # Skip fields marked with serialize="omit" in metadata + if field_info.metadata: + # Check for mashumaro field_options + if "serialize" in field_info.metadata: + if field_info.metadata["serialize"] == "omit": + continue + + # Use resolved type hint if available, otherwise fall back to field type + field_type = resolved_hints.get(field_name, field_info.type) + field_schema = _get_type_schema(field_type, definitions) + + # Add default value if present + if field_info.default is not MISSING: + field_schema["default"] = field_info.default + elif ( + hasattr(field_info, "default_factory") + and field_info.default_factory is not MISSING + ): + # Has a default factory - don't add anything, just skip + pass + + properties[field_name] = field_schema + + # Check if field is required (not Optional and no default) + has_default = field_info.default is not MISSING or ( + hasattr(field_info, "default_factory") + and field_info.default_factory is not MISSING + ) + is_optional = get_origin(field_type) in ( + Union, + UnionType, + ) and NoneType in get_args(field_type) + if not has_default and not is_optional: + required.append(field_name) + elif hasattr(type_hint, "__annotations__"): + # Fallback for non-dataclass types with annotations + for field_name, field_type in type_hint.__annotations__.items(): + properties[field_name] = _get_type_schema(field_type, definitions) + # Check if field is required (not Optional) + if not ( + get_origin(field_type) in (Union, UnionType) + and NoneType in get_args(field_type) + ): + required.append(field_name) + else: + # Class without dataclass fields or annotations - treat as generic object + pass # Will create empty properties + + definitions[type_name] = { + "type": "object", + "properties": properties, + } + if required: + definitions[type_name]["required"] = required + + return {"$ref": f"#/components/schemas/{type_name}"} + + # Handle Any + if type_hint is Any: + return {"type": "object"} + + # Fallback - for types we don't recognize, at least return a generic object type + return {"type": "object"} + + +def _parse_docstring( # noqa: PLR0915 + func: Callable[..., Any], +) -> tuple[str, str, dict[str, str]]: + """Parse docstring to extract summary, description and parameter descriptions. + + Returns: + Tuple of (short_summary, full_description, param_descriptions) + + Handles multiple docstring formats: + - reStructuredText (:param name: description) + - Google style (Args: section) + - NumPy style (Parameters section) + """ + docstring = inspect.getdoc(func) + if not docstring: + return "", "", {} + + lines = docstring.split("\n") + description_lines = [] + param_descriptions = {} + current_section = "description" + current_param = None + + for line in lines: + stripped = line.strip() + + # Check for section headers + if stripped.lower() in ("args:", "arguments:", "parameters:", "params:"): + current_section = "params" + current_param = None + continue + if stripped.lower() in ( + "returns:", + "return:", + "yields:", + "raises:", + "raises", + "examples:", + "example:", + "note:", + "notes:", + "see also:", + "warning:", + "warnings:", + ): + current_section = "other" + current_param = None + continue + + # Parse :param style + if stripped.startswith(":param "): + current_section = "params" + parts = stripped[7:].split(":", 1) + if len(parts) == 2: + current_param = parts[0].strip() + desc = parts[1].strip() + if desc: + param_descriptions[current_param] = desc + continue + + if stripped.startswith((":type ", ":rtype", ":return")): + current_section = "other" + current_param = None + continue + + # Detect bullet-style params even without explicit section header + # Format: "- param_name: description" + if stripped.startswith("- ") and ":" in stripped: + # This is likely a bullet-style parameter + current_section = "params" + content = stripped[2:] # Remove "- " + parts = content.split(":", 1) + param_name = parts[0].strip() + desc_part = parts[1].strip() if len(parts) > 1 else "" + if param_name and not param_name.startswith(("return", "yield", "raise")): + current_param = param_name + if desc_part: + param_descriptions[current_param] = desc_part + continue + + # In params section, detect param lines (indented or starting with name) + if current_section == "params" and stripped: + # Google/NumPy style: "param_name: description" or "param_name (type): description" + if ":" in stripped and not stripped.startswith(" "): + # Likely a parameter definition + if "(" in stripped and ")" in stripped: + # Format: param_name (type): description + param_part = stripped.split(":")[0] + param_name = param_part.split("(")[0].strip() + desc_part = ":".join(stripped.split(":")[1:]).strip() + else: + # Format: param_name: description + parts = stripped.split(":", 1) + param_name = parts[0].strip() + desc_part = parts[1].strip() if len(parts) > 1 else "" + + if param_name and not param_name.startswith(("return", "yield", "raise")): + current_param = param_name + if desc_part: + param_descriptions[current_param] = desc_part + elif current_param and stripped: + # Continuation of previous parameter description + param_descriptions[current_param] = ( + param_descriptions.get(current_param, "") + " " + stripped + ).strip() + continue + + # Collect description lines (only before params/returns sections) + if current_section == "description" and stripped: + description_lines.append(stripped) + elif current_section == "description" and not stripped and description_lines: + # Empty line in description - keep it for paragraph breaks + description_lines.append("") + + # Join description lines, removing excessive empty lines + description = "\n".join(description_lines).strip() + # Collapse multiple empty lines into one + while "\n\n\n" in description: + description = description.replace("\n\n\n", "\n\n") + + # Extract first sentence/line as summary + summary = "" + if description: + # Get first line or first sentence (whichever is shorter) + first_line = description.split("\n")[0] + # Try to get first sentence (ending with .) + summary = first_line.split(".")[0] + "." if "." in first_line else first_line + + return summary, description, param_descriptions + + +def generate_openapi_spec( + command_handlers: dict[str, APICommandHandler], + server_url: str = "http://localhost:8095", + version: str = "1.0.0", +) -> dict[str, Any]: + """Generate simplified OpenAPI 3.0 specification focusing on data models. + + This spec documents the single /api endpoint and all data models/schemas. + For detailed command documentation, see the Commands Reference page. + """ + definitions: dict[str, Any] = {} + + # Build all schemas from command handlers (this populates definitions) + for handler in command_handlers.values(): + # Skip aliases - they are for backward compatibility only + if handler.alias: + continue + # Build parameter schemas + for param_name in handler.signature.parameters: + if param_name == "self": + continue + # Skip return_type parameter (used only for type hints) + if param_name == "return_type": + continue + param_type = handler.type_hints.get(param_name, Any) + # Skip Any types as they don't provide useful schema information + if param_type is not Any and str(param_type) != "typing.Any": + _get_type_schema(param_type, definitions) + + # Build return type schema + return_type = handler.type_hints.get("return", Any) + # Skip Any types as they don't provide useful schema information + if return_type is not Any and str(return_type) != "typing.Any": + _get_type_schema(return_type, definitions) + + # Build a single /api endpoint with generic request/response + paths = { + "/api": { + "post": { + "summary": "Execute API command", + "description": ( + "Execute any Music Assistant API command.\n\n" + "See the **Commands Reference** page for a complete list of available " + "commands with examples." + ), + "operationId": "execute_command", + "security": [{"bearerAuth": []}], + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["command"], + "properties": { + "command": { + "type": "string", + "description": ( + "The command to execute (e.g., 'players/all')" + ), + "example": "players/all", + }, + "args": { + "type": "object", + "description": "Command arguments (varies by command)", + "additionalProperties": True, + "example": {}, + }, + }, + }, + "examples": { + "get_players": { + "summary": "Get all players", + "value": {"command": "players/all", "args": {}}, + }, + "play_media": { + "summary": "Play media on a player", + "value": { + "command": "players/cmd/play", + "args": {"player_id": "player123"}, + }, + }, + }, + } + }, + }, + "responses": { + "200": { + "description": "Successful command execution", + "content": { + "application/json": { + "schema": {"description": "Command result (varies by command)"} + } + }, + }, + "400": {"description": "Bad request - invalid command or parameters"}, + "401": {"description": "Unauthorized - authentication required"}, + "403": {"description": "Forbidden - insufficient permissions"}, + "500": {"description": "Internal server error"}, + }, + } + }, + "/auth/login": { + "post": { + "summary": "Authenticate with credentials", + "description": "Login with username and password to obtain an access token.", + "operationId": "auth_login", + "tags": ["Authentication"], + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "provider_id": { + "type": "string", + "description": "Auth provider ID (defaults to 'builtin')", + "example": "builtin", + }, + "credentials": { + "type": "object", + "description": "Provider-specific credentials", + "properties": { + "username": {"type": "string"}, + "password": {"type": "string"}, + }, + }, + }, + } + } + }, + }, + "responses": { + "200": { + "description": "Login successful", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "token": {"type": "string"}, + "user": {"type": "object"}, + }, + } + } + }, + }, + "400": {"description": "Invalid credentials"}, + }, + } + }, + "/auth/providers": { + "get": { + "summary": "Get available auth providers", + "description": "Returns list of configured authentication providers.", + "operationId": "auth_providers", + "tags": ["Authentication"], + "responses": { + "200": { + "description": "List of auth providers", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "providers": { + "type": "array", + "items": {"type": "object"}, + } + }, + } + } + }, + } + }, + } + }, + "/setup": { + "post": { + "summary": "Initial server setup", + "description": ( + "Handle initial setup of the Music Assistant server including creating " + "the first admin user. Only accessible when no users exist " + "(onboard_done=false)." + ), + "operationId": "setup", + "tags": ["Server"], + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["username", "password"], + "properties": { + "username": {"type": "string"}, + "password": {"type": "string"}, + "display_name": {"type": "string"}, + }, + } + } + }, + }, + "responses": { + "200": { + "description": "Setup completed successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "success": {"type": "boolean"}, + "token": {"type": "string"}, + "user": {"type": "object"}, + }, + } + } + }, + }, + "400": {"description": "Setup already completed or invalid request"}, + }, + } + }, + "/info": { + "get": { + "summary": "Get server info", + "description": ( + "Returns server information including schema version and authentication status." + ), + "operationId": "get_info", + "tags": ["Server"], + "responses": { + "200": { + "description": "Server information", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "schema_version": {"type": "integer"}, + "server_version": {"type": "string"}, + "onboard_done": {"type": "boolean"}, + "homeassistant_addon": {"type": "boolean"}, + }, + } + } + }, + } + }, + } + }, + } + + # Build OpenAPI spec + return { + "openapi": "3.0.0", + "info": { + "title": "Music Assistant API", + "version": version, + "description": ( + "Music Assistant API provides control over your music library, " + "players, and playback.\n\n" + "This specification documents the API structure and data models. " + "For a complete list of available commands with examples, " + "see the Commands Reference page." + ), + "contact": { + "name": "Music Assistant", + "url": "https://music-assistant.io", + }, + }, + "servers": [{"url": server_url, "description": "Music Assistant Server"}], + "paths": paths, + "components": { + "schemas": definitions, + "securitySchemes": { + "bearerAuth": { + "type": "http", + "scheme": "bearer", + "description": "Access token obtained from /auth/login or /auth/setup", + } + }, + }, + } + + +def _split_union_type(type_str: str) -> list[str]: + """Split a union type on | but respect brackets and parentheses. + + This ensures that list[A | B] and (A | B) are not split at the inner |. + """ + parts = [] + current_part = "" + bracket_depth = 0 + paren_depth = 0 + i = 0 + while i < len(type_str): + char = type_str[i] + if char == "[": + bracket_depth += 1 + current_part += char + elif char == "]": + bracket_depth -= 1 + current_part += char + elif char == "(": + paren_depth += 1 + current_part += char + elif char == ")": + paren_depth -= 1 + current_part += char + elif char == "|" and bracket_depth == 0 and paren_depth == 0: + # Check if this is a union separator (has space before and after) + if ( + i > 0 + and i < len(type_str) - 1 + and type_str[i - 1] == " " + and type_str[i + 1] == " " + ): + parts.append(current_part.strip()) + current_part = "" + i += 1 # Skip the space after |, the loop will handle incrementing i + else: + current_part += char + else: + current_part += char + i += 1 + if current_part.strip(): + parts.append(current_part.strip()) + return parts + + +def _extract_generic_inner_type(type_str: str) -> str | None: + """Extract inner type from generic type like list[T] or dict[K, V]. + + :param type_str: Type string like "list[str]" or "dict[str, int]" + :return: Inner type string "str" or "str, int", or None if not a complete generic type + """ + # Find the matching closing bracket + bracket_count = 0 + start_idx = type_str.index("[") + 1 + end_idx = -1 + for i in range(start_idx, len(type_str)): + if type_str[i] == "[": + bracket_count += 1 + elif type_str[i] == "]": + if bracket_count == 0: + end_idx = i + break + bracket_count -= 1 + + # Check if this is a complete generic type (ends with the closing bracket) + if end_idx == len(type_str) - 1: + return type_str[start_idx:end_idx].strip() + return None + + +def _parse_dict_type_params(inner_type: str) -> tuple[str, str] | None: + """Parse key and value types from dict inner type string. + + :param inner_type: The content inside dict[...], e.g., "str, ConfigValueType" + :return: Tuple of (key_type, value_type) or None if parsing fails + """ + # Split on comma to get key and value types + # Need to be careful with nested types like dict[str, list[int]] + parts = [] + current_part = "" + bracket_depth = 0 + for char in inner_type: + if char == "[": + bracket_depth += 1 + current_part += char + elif char == "]": + bracket_depth -= 1 + current_part += char + elif char == "," and bracket_depth == 0: + parts.append(current_part.strip()) + current_part = "" + else: + current_part += char + if current_part: + parts.append(current_part.strip()) + + if len(parts) == 2: + return parts[0], parts[1] + return None + + +def _python_type_to_json_type(type_str: str, _depth: int = 0) -> str: + """Convert Python type string to JSON/JavaScript type string. + + Args: + type_str: The type string to convert + _depth: Internal recursion depth tracker (do not set manually) + """ + # Prevent infinite recursion + if _depth > 50: + return "any" + + # Remove typing module prefix and class markers + type_str = type_str.replace("typing.", "").replace("", "") + + # Remove module paths from type names (e.g., "music_assistant.models.Artist" -> "Artist") + type_str = re.sub(r"[\w.]+\.(\w+)", r"\1", type_str) + + # Check for type aliases that should be preserved as-is + # These will have schema definitions in the API docs + if type_str in ("ConfigValueType", "MediaItemType"): + return type_str + + # Map Python types to JSON types + type_mappings = { + "str": "string", + "int": "integer", + "float": "number", + "bool": "boolean", + "dict": "object", + "Dict": "object", + "list": "array", + "tuple": "array", + "Tuple": "array", + "None": "null", + "NoneType": "null", + } + + # Check for List/list/UniqueList/tuple with type parameter BEFORE checking for union types + # This is important because list[A | B] contains " | " but should be handled as a list first + # codespell:ignore + if type_str.startswith(("list[", "List[", "UniqueList[", "tuple[", "Tuple[")): + inner_type = _extract_generic_inner_type(type_str) + if inner_type: + # Handle variable-length tuple (e.g., tuple[str, ...]) + # The ellipsis means "variable length of this type" + if inner_type.endswith(", ..."): + # Remove the ellipsis and just use the type + inner_type = inner_type[:-5].strip() + # Recursively convert the inner type + inner_json_type = _python_type_to_json_type(inner_type, _depth + 1) + # For list[A | B], wrap in parentheses to keep it as one unit + # This prevents "Array of A | B" from being split into separate union parts + if " | " in inner_json_type: + return f"Array of ({inner_json_type})" + return f"Array of {inner_json_type}" + + # Check for dict/Dict with type parameters BEFORE checking for union types + # This is important because dict[str, A | B] contains " | " + # but should be handled as a dict first + # codespell:ignore + if type_str.startswith(("dict[", "Dict[")): + inner_type = _extract_generic_inner_type(type_str) + if inner_type: + parsed = _parse_dict_type_params(inner_type) + if parsed: + key_type_str, value_type_str = parsed + key_type = _python_type_to_json_type(key_type_str, _depth + 1) + value_type = _python_type_to_json_type(value_type_str, _depth + 1) + # Use more descriptive format: "object with {key_type} keys and {value_type} values" + return f"object with {key_type} keys and {value_type} values" + + # Handle Union types by splitting on | and recursively processing each part + if " | " in type_str: + # Use helper to split on | but respect brackets + parts = _split_union_type(type_str) + + # Filter out None/null types (None, NoneType, null all mean JSON null) + parts = [part for part in parts if part not in ("None", "NoneType", "null")] + + # If splitting didn't help (only one part or same as input), avoid infinite recursion + if not parts or (len(parts) == 1 and parts[0] == type_str): + # Can't split further, return as-is or "any" + return type_str if parts else "any" + + if parts: + converted_parts = [_python_type_to_json_type(part, _depth + 1) for part in parts] + # Remove duplicates while preserving order + seen = set() + unique_parts = [] + for part in converted_parts: + if part not in seen: + seen.add(part) + unique_parts.append(part) + return " | ".join(unique_parts) + return "any" + + # Check for Union/Optional types with brackets + if "Union[" in type_str or "Optional[" in type_str: + # Extract content from Union[...] or Optional[...] + union_match = re.search(r"(?:Union|Optional)\[([^\]]+)\]", type_str) + if union_match: + inner = union_match.group(1) + # Recursively process the union content + return _python_type_to_json_type(inner, _depth + 1) + + # Direct mapping for basic types + for py_type, json_type in type_mappings.items(): + if type_str == py_type: + return json_type + + # Check if it's a complex type (starts with capital letter) + complex_match = re.search(r"^([A-Z][a-zA-Z0-9_]*)$", type_str) + if complex_match: + return complex_match.group(1) + + # Default to the original string if no mapping found + return type_str + + +def _make_type_links(type_str: str, server_url: str, as_list: bool = False) -> str: + """Convert type string to HTML with links to schemas reference for complex types. + + Args: + type_str: The type string to convert + server_url: Base server URL for building links + as_list: If True and type contains |, format as "Any of:" bullet list + """ + + # Find all complex types (capitalized words that aren't basic types) + def replace_type(match: re.Match[str]) -> str: + type_name = match.group(0) + # Check if it's a complex type (starts with capital letter) + # Exclude basic types and "Array" (which is used in "Array of Type") + excluded = {"Union", "Optional", "List", "Dict", "Array", "None", "NoneType"} + if type_name[0].isupper() and type_name not in excluded: + # Create link to our schemas reference page + schema_url = f"{server_url}/api-docs/schemas#schema-{type_name}" + return f'{type_name}' + return type_name + + # If it's a union type with multiple options and as_list is True, format as bullet list + if as_list and " | " in type_str: + # Use the bracket/parenthesis-aware splitter + parts = _split_union_type(type_str) + # Only use list format if there are 3+ options + if len(parts) >= 3: + html = '
Any of:
    ' + for part in parts: + linked_part = re.sub(r"\b[A-Z][a-zA-Z0-9_]*\b", replace_type, part) + html += f"
  • {linked_part}
  • " + html += "
" + return html + + # Replace complex type names with links + result: str = re.sub(r"\b[A-Z][a-zA-Z0-9_]*\b", replace_type, type_str) + return result + + +def generate_commands_json(command_handlers: dict[str, APICommandHandler]) -> list[dict[str, Any]]: + """Generate JSON representation of all available API commands. + + This is used by client libraries to sync their methods with the server API. + + Returns a list of command objects with the following structure: + { + "command": str, # Command name (e.g., "music/tracks/library_items") + "category": str, # Category (e.g., "Music") + "summary": str, # Short description + "description": str, # Full description + "parameters": [ # List of parameters + { + "name": str, + "type": str, # JSON type (string, integer, boolean, etc.) + "required": bool, + "description": str + } + ], + "return_type": str, # Return type + "authenticated": bool, # Whether authentication is required + "required_role": str | None, # Required user role (if any) + } + """ + commands_data = [] + + for command, handler in sorted(command_handlers.items()): + # Skip aliases - they are for backward compatibility only + if handler.alias: + continue + # Parse docstring + summary, description, param_descriptions = _parse_docstring(handler.target) + + # Get return type + return_type = handler.type_hints.get("return", Any) + # If type is already a string (e.g., "ConfigValueType"), use it directly + return_type_str = _python_type_to_json_type( + return_type if isinstance(return_type, str) else str(return_type) + ) + + # Extract category from command name + category = command.split("/")[0] if "/" in command else "general" + category_display = category.replace("_", " ").title() + + # Build parameters list + parameters = [] + for param_name, param in handler.signature.parameters.items(): + if param_name in ("self", "return_type"): + continue + + is_required = param.default is inspect.Parameter.empty + param_type = handler.type_hints.get(param_name, Any) + # If type is already a string (e.g., "ConfigValueType"), use it directly + type_str = param_type if isinstance(param_type, str) else str(param_type) + json_type_str = _python_type_to_json_type(type_str) + param_desc = param_descriptions.get(param_name, "") + + parameters.append( + { + "name": param_name, + "type": json_type_str, + "required": is_required, + "description": param_desc, + } + ) + + commands_data.append( + { + "command": command, + "category": category_display, + "summary": summary or "", + "description": description or "", + "parameters": parameters, + "return_type": return_type_str, + "authenticated": handler.authenticated, + "required_role": handler.required_role, + } + ) + + return commands_data + + +def generate_schemas_json(command_handlers: dict[str, APICommandHandler]) -> dict[str, Any]: + """Generate JSON representation of all schemas/data models. + + Returns a dict mapping schema names to their OpenAPI schema definitions. + """ + schemas: dict[str, Any] = {} + + for handler in command_handlers.values(): + # Skip aliases - they are for backward compatibility only + if handler.alias: + continue + # Collect schemas from parameters + for param_name in handler.signature.parameters: + if param_name == "self": + continue + # Skip return_type parameter (used only for type hints) + if param_name == "return_type": + continue + param_type = handler.type_hints.get(param_name, Any) + if param_type is not Any and str(param_type) != "typing.Any": + _get_type_schema(param_type, schemas) + + # Collect schemas from return type + return_type = handler.type_hints.get("return", Any) + if return_type is not Any and str(return_type) != "typing.Any": + _get_type_schema(return_type, schemas) + + return schemas diff --git a/music_assistant/controllers/webserver/auth.py b/music_assistant/controllers/webserver/auth.py new file mode 100644 index 00000000..7019c686 --- /dev/null +++ b/music_assistant/controllers/webserver/auth.py @@ -0,0 +1,1216 @@ +"""Authentication manager for Music Assistant webserver.""" + +from __future__ import annotations + +import hashlib +import logging +import secrets +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +from music_assistant_models.auth import ( + AuthProviderType, + AuthToken, + User, + UserAuthProvider, + UserRole, +) +from music_assistant_models.errors import ( + AuthenticationFailed, + AuthenticationRequired, + InsufficientPermissions, + InvalidDataError, +) + +from music_assistant.constants import ( + CONF_AUTH_ALLOW_SELF_REGISTRATION, + CONF_ONBOARD_DONE, + HOMEASSISTANT_SYSTEM_USER, + MASS_LOGGER_NAME, +) +from music_assistant.controllers.webserver.helpers.auth_middleware import ( + get_current_token, + get_current_user, +) +from music_assistant.controllers.webserver.helpers.auth_providers import ( + AuthResult, + BuiltinLoginProvider, + HomeAssistantOAuthProvider, + HomeAssistantProviderConfig, + LoginProvider, + LoginProviderConfig, +) +from music_assistant.helpers.api import api_command +from music_assistant.helpers.database import DatabaseConnection +from music_assistant.helpers.datetime import utc +from music_assistant.helpers.json import json_dumps, json_loads + +if TYPE_CHECKING: + from music_assistant.controllers.webserver import WebserverController + +LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.auth") + +# Database schema version +DB_SCHEMA_VERSION = 2 + +# Token expiration constants (in days) +TOKEN_SHORT_LIVED_EXPIRATION = 30 # Short-lived tokens (auto-renewing on use) +TOKEN_LONG_LIVED_EXPIRATION = 3650 # Long-lived tokens (10 years, no auto-renewal) + + +class AuthenticationManager: + """Manager for authentication and user management (part of webserver controller).""" + + def __init__(self, webserver: WebserverController) -> None: + """ + Initialize the authentication manager. + + :param webserver: WebserverController instance. + """ + self.webserver = webserver + self.mass = webserver.mass + self.database: DatabaseConnection = None # type: ignore[assignment] + self.login_providers: dict[str, LoginProvider] = {} + self.logger = LOGGER + + async def setup(self) -> None: + """Initialize the authentication manager.""" + # Get auth settings from config + allow_self_registration = self.webserver.config.get_value(CONF_AUTH_ALLOW_SELF_REGISTRATION) + assert isinstance(allow_self_registration, bool) + + # Setup database + db_path = self.mass.storage_path + "/auth.db" + self.database = DatabaseConnection(db_path) + await self.database.setup() + + # Create database schema and handle migrations + await self._setup_database() + + # Setup login providers based on config + await self._setup_login_providers(allow_self_registration) + + # Migration: Reset onboard_done if no users exist + # This handles existing setups where authentication was optional + if self.mass.config.onboard_done and not await self.has_users(): + self.logger.warning( + "Authentication is mandatory but no users exist. " + "Resetting onboard_done to redirect to setup." + ) + self.mass.config.set(CONF_ONBOARD_DONE, False) + self.mass.config.save(immediate=True) + + self.logger.info( + "Authentication manager initialized (providers=%d)", len(self.login_providers) + ) + + async def close(self) -> None: + """Cleanup on exit.""" + if self.database: + await self.database.close() + + async def _setup_database(self) -> None: + """Set up database schema and handle migrations.""" + # Always create tables if they don't exist + await self._create_database_tables() + + # Check current schema version + try: + if db_row := await self.database.get_row("settings", {"key": "schema_version"}): + prev_version = int(db_row["value"]) + else: + prev_version = 0 + except (KeyError, ValueError, Exception): + # settings table doesn't exist yet or other error + prev_version = 0 + + # Perform migration if needed + if prev_version < DB_SCHEMA_VERSION: + self.logger.warning( + "Performing database migration from schema version %s to %s", + prev_version, + DB_SCHEMA_VERSION, + ) + await self._migrate_database(prev_version) + + # Store current schema version + await self.database.insert_or_replace( + "settings", + {"key": "schema_version", "value": str(DB_SCHEMA_VERSION), "type": "int"}, + ) + + # Create indexes + await self._create_database_indexes() + await self.database.commit() + + async def _create_database_tables(self) -> None: + """Create database tables.""" + # Settings table (for schema version and other settings) + await self.database.execute( + """ + CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT, + type TEXT + ) + """ + ) + + # Users table (decoupled from auth providers) + await self.database.execute( + """ + CREATE TABLE IF NOT EXISTS users ( + user_id TEXT PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + role TEXT NOT NULL, + enabled INTEGER DEFAULT 1, + created_at TEXT NOT NULL, + display_name TEXT, + avatar_url TEXT, + preferences TEXT DEFAULT '{}' + ) + """ + ) + + # User auth provider links (many-to-many) + await self.database.execute( + """ + CREATE TABLE IF NOT EXISTS user_auth_providers ( + link_id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + provider_type TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + created_at TEXT NOT NULL, + UNIQUE(provider_type, provider_user_id), + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ) + """ + ) + + # Auth tokens table + await self.database.execute( + """ + CREATE TABLE IF NOT EXISTS auth_tokens ( + token_id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + token_hash TEXT NOT NULL UNIQUE, + name TEXT NOT NULL, + created_at TEXT NOT NULL, + expires_at TEXT, + last_used_at TEXT, + is_long_lived INTEGER DEFAULT 0, + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ) + """ + ) + + await self.database.commit() + + async def _create_database_indexes(self) -> None: + """Create database indexes.""" + await self.database.execute( + "CREATE INDEX IF NOT EXISTS idx_user_auth_providers_user " + "ON user_auth_providers(user_id)" + ) + await self.database.execute( + "CREATE INDEX IF NOT EXISTS idx_user_auth_providers_provider " + "ON user_auth_providers(provider_type, provider_user_id)" + ) + await self.database.execute( + "CREATE INDEX IF NOT EXISTS idx_tokens_user ON auth_tokens(user_id)" + ) + await self.database.execute( + "CREATE INDEX IF NOT EXISTS idx_tokens_hash ON auth_tokens(token_hash)" + ) + + async def _migrate_database(self, from_version: int) -> None: + """Perform database migration. + + :param from_version: The schema version to migrate from. + """ + self.logger.info( + "Migrating auth database from version %s to %s", from_version, DB_SCHEMA_VERSION + ) + # Migration to version 2: Recreate tables due to password salt breaking change + if from_version < 2: + # Drop all auth-related tables + await self.database.execute("DROP TABLE IF EXISTS auth_tokens") + await self.database.execute("DROP TABLE IF EXISTS user_auth_providers") + await self.database.execute("DROP TABLE IF EXISTS users") + await self.database.commit() + + # Recreate tables with current schema + await self._create_database_tables() + + async def _setup_login_providers(self, allow_self_registration: bool) -> None: + """ + Set up available login providers based on configuration. + + :param allow_self_registration: Whether to allow self-registration via OAuth. + """ + # Always enable built-in provider + builtin_config: LoginProviderConfig = {"allow_self_registration": False} + self.login_providers["builtin"] = BuiltinLoginProvider(self.mass, "builtin", builtin_config) + + # Home Assistant OAuth provider + # Automatically enabled if HA provider (plugin) is configured + ha_provider = None + for provider in self.mass.providers: + if provider.domain == "hass" and provider.available: + ha_provider = provider + break + + if ha_provider: + # Get URL from the HA provider config + ha_url = ha_provider.config.get_value("url") + assert isinstance(ha_url, str) + ha_config: HomeAssistantProviderConfig = { + "ha_url": ha_url, + "allow_self_registration": allow_self_registration, + } + self.login_providers["homeassistant"] = HomeAssistantOAuthProvider( + self.mass, "homeassistant", ha_config + ) + self.logger.info( + "Home Assistant OAuth provider enabled (using URL from HA provider: %s)", + ha_url, + ) + + async def _sync_ha_oauth_provider(self) -> None: + """ + Sync HA OAuth provider with HA provider availability (dynamic check). + + Adds the provider if HA is available, removes it if HA is not available. + """ + # Find HA provider + ha_provider = None + for provider in self.mass.providers: + if provider.domain == "hass" and provider.available: + ha_provider = provider + break + + if ha_provider: + # HA provider exists and is available - ensure OAuth provider is registered + if "homeassistant" not in self.login_providers: + # Get allow_self_registration config + allow_self_registration = bool( + self.webserver.config.get_value(CONF_AUTH_ALLOW_SELF_REGISTRATION, True) + ) + + # Get URL from the HA provider config + ha_url = ha_provider.config.get_value("url") + assert isinstance(ha_url, str) + ha_config: HomeAssistantProviderConfig = { + "ha_url": ha_url, + "allow_self_registration": allow_self_registration, + } + self.login_providers["homeassistant"] = HomeAssistantOAuthProvider( + self.mass, "homeassistant", ha_config + ) + self.logger.info( + "Home Assistant OAuth provider dynamically enabled (using URL: %s)", + ha_url, + ) + # HA provider not available - remove OAuth provider if present + elif "homeassistant" in self.login_providers: + del self.login_providers["homeassistant"] + self.logger.info("Home Assistant OAuth provider removed (HA provider not available)") + + async def has_users(self) -> bool: + """Check if any users exist in the system.""" + count = await self.database.get_count("users") + return count > 0 + + async def authenticate_with_credentials( + self, provider_id: str, credentials: dict[str, Any] + ) -> AuthResult: + """ + Authenticate a user with credentials. + + :param provider_id: The login provider ID. + :param credentials: Provider-specific credentials. + """ + provider = self.login_providers.get(provider_id) + if not provider: + return AuthResult(success=False, error="Invalid provider") + + return await provider.authenticate(credentials) + + async def authenticate_with_token(self, token: str) -> User | None: + """ + Authenticate a user with an access token. + + :param token: The access token. + """ + # Hash the token to look it up + token_hash = hashlib.sha256(token.encode()).hexdigest() + + # Find token in database + token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash}) + if not token_row: + return None + + # Check if token is expired + if token_row["expires_at"]: + expires_at = datetime.fromisoformat(token_row["expires_at"]) + if utc() > expires_at: + # Token expired, delete it + await self.database.delete("auth_tokens", {"token_id": token_row["token_id"]}) + return None + + # Implement sliding expiration for short-lived tokens + is_long_lived = bool(token_row["is_long_lived"]) + now = utc() + updates = {"last_used_at": now.isoformat()} + + if not is_long_lived and token_row["expires_at"]: + # Short-lived token: extend expiration on each use (sliding window) + new_expires_at = now + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION) + updates["expires_at"] = new_expires_at.isoformat() + + # Update last used timestamp and potentially expiration + await self.database.update( + "auth_tokens", + {"token_id": token_row["token_id"]}, + updates, + ) + + # Get user + return await self.get_user(token_row["user_id"]) + + async def get_token_id_from_token(self, token: str) -> str | None: + """ + Get token_id from a token string (for tracking revocation). + + :param token: The access token. + :return: The token_id or None if token not found. + """ + # Hash the token to look it up + token_hash = hashlib.sha256(token.encode()).hexdigest() + + # Find token in database + token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash}) + if not token_row: + return None + + return str(token_row["token_id"]) + + @api_command("auth/user", required_role="admin") + async def get_user(self, user_id: str) -> User | None: + """ + Get user by ID (admin only). + + :param user_id: The user ID. + :return: User object or None if not found. + """ + user_row = await self.database.get_row("users", {"user_id": user_id}) + if not user_row or not user_row["enabled"]: + return None + + # Convert Row to dict for easier handling of optional fields + user_dict = dict(user_row) + + # Parse preferences from JSON + preferences = {} + if prefs_json := user_dict.get("preferences"): + try: + preferences = json_loads(prefs_json) + except Exception: + self.logger.warning("Failed to parse preferences for user %s", user_id) + + return User( + user_id=user_dict["user_id"], + username=user_dict["username"], + role=UserRole(user_dict["role"]), + enabled=bool(user_dict["enabled"]), + created_at=datetime.fromisoformat(user_dict["created_at"]), + display_name=user_dict.get("display_name"), + avatar_url=user_dict.get("avatar_url"), + preferences=preferences, + ) + + async def get_user_by_provider_link( + self, provider_type: AuthProviderType, provider_user_id: str + ) -> User | None: + """ + Get user by their provider link. + + :param provider_type: The auth provider type. + :param provider_user_id: The user ID from the provider. + """ + link_row = await self.database.get_row( + "user_auth_providers", + { + "provider_type": provider_type.value, + "provider_user_id": provider_user_id, + }, + ) + if not link_row: + return None + + return await self.get_user(link_row["user_id"]) + + async def create_user( + self, + username: str, + role: UserRole = UserRole.USER, + display_name: str | None = None, + avatar_url: str | None = None, + preferences: dict[str, Any] | None = None, + ) -> User: + """ + Create a new user. + + :param username: The username. + :param role: The user role (default: USER). + :param display_name: Optional display name. + :param avatar_url: Optional avatar URL. + :param preferences: Optional user preferences dict. + """ + user_id = secrets.token_urlsafe(32) + created_at = utc() + if preferences is None: + preferences = {} + + user_data = { + "user_id": user_id, + "username": username, + "role": role.value, + "enabled": True, + "created_at": created_at.isoformat(), + "display_name": display_name, + "avatar_url": avatar_url, + "preferences": json_dumps(preferences), + } + + await self.database.insert("users", user_data) + + return User( + user_id=user_id, + username=username, + role=role, + enabled=True, + created_at=created_at, + display_name=display_name, + avatar_url=avatar_url, + preferences=preferences, + ) + + async def get_homeassistant_system_user(self) -> User: + """ + Get or create the Home Assistant system user. + + This is a special system user created automatically for Home Assistant integration. + It bypasses normal authentication but is restricted to the ingress webserver. + + :return: The Home Assistant system user. + """ + username = HOMEASSISTANT_SYSTEM_USER + display_name = "Home Assistant Integration" + role = UserRole.USER + + # Try to find existing user by username + user_row = await self.database.get_row("users", {"username": username}) + if user_row: + # Use get_user to ensure preferences are parsed correctly + user = await self.get_user(user_row["user_id"]) + assert user is not None # User exists in DB, so get_user must return it + return user + + # Create new system user + user = await self.create_user( + username=username, + role=role, + display_name=display_name, + ) + self.logger.debug("Created Home Assistant system user: %s (role: %s)", username, role.value) + return user + + async def get_homeassistant_system_user_token(self) -> str: + """ + Get or create an auth token for the Home Assistant system user. + + This method ensures only one active token exists for the HA integration. + If an old token exists, it is deleted and a new one is created. + The token auto-renews on use (expires after 30 days of inactivity). + + :return: Authentication token for the Home Assistant system user. + """ + token_name = "Home Assistant Integration" + + # Get the system user + system_user = await self.get_homeassistant_system_user() + + # Delete any existing tokens with this name to avoid accumulation + # We can't retrieve the plain token from the hash, so we always create a new one + existing_tokens = await self.database.get_rows( + "auth_tokens", + {"user_id": system_user.user_id, "name": token_name}, + ) + for token_row in existing_tokens: + await self.database.delete("auth_tokens", {"token_id": token_row["token_id"]}) + + # Create a new token for the system user + return await self.create_token( + user=system_user, + name=token_name, + is_long_lived=False, + ) + + async def link_user_to_provider( + self, + user: User, + provider_type: AuthProviderType, + provider_user_id: str, + ) -> UserAuthProvider: + """ + Link a user to an authentication provider. + + :param user: The user to link. + :param provider_type: The provider type. + :param provider_user_id: The user ID from the provider (e.g., password hash, OAuth ID). + """ + link_id = secrets.token_urlsafe(32) + created_at = utc() + link_data = { + "link_id": link_id, + "user_id": user.user_id, + "provider_type": provider_type.value, + "provider_user_id": provider_user_id, + "created_at": created_at.isoformat(), + } + + await self.database.insert("user_auth_providers", link_data) + + return UserAuthProvider( + link_id=link_id, + user_id=user.user_id, + provider_type=provider_type, + provider_user_id=provider_user_id, + created_at=created_at, + ) + + async def update_user( + self, + user: User, + username: str | None = None, + display_name: str | None = None, + avatar_url: str | None = None, + ) -> User: + """ + Update a user's profile information. + + :param user: The user to update. + :param username: New username (optional). + :param display_name: New display name (optional). + :param avatar_url: New avatar URL (optional). + """ + updates = {} + if username is not None: + updates["username"] = username + if display_name is not None: + updates["display_name"] = display_name + if avatar_url is not None: + updates["avatar_url"] = avatar_url + + if updates: + await self.database.update("users", {"user_id": user.user_id}, updates) + + # Return updated user + updated_user = await self.get_user(user.user_id) + assert updated_user is not None # User exists, so get_user must return it + return updated_user + + async def update_user_preferences( + self, + user: User, + preferences: dict[str, Any], + ) -> User: + """ + Update a user's preferences. + + :param user: The user to update. + :param preferences: New preferences dict (completely replaces existing preferences). + """ + # Verify user exists + current_user = await self.get_user(user.user_id) + if not current_user: + raise ValueError(f"User {user.user_id} not found") + + # Update database with new preferences (complete replacement) + await self.database.update( + "users", + {"user_id": user.user_id}, + {"preferences": json_dumps(preferences)}, + ) + + # Return updated user + updated_user = await self.get_user(user.user_id) + assert updated_user is not None # User exists, so get_user must return it + return updated_user + + async def update_provider_link( + self, + user: User, + provider_type: AuthProviderType, + provider_user_id: str, + ) -> None: + """ + Update a user's provider link (e.g., change password). + + :param user: The user. + :param provider_type: The provider type. + :param provider_user_id: The new provider user ID (e.g., new password hash). + """ + # Find existing link + link_row = await self.database.get_row( + "user_auth_providers", + { + "user_id": user.user_id, + "provider_type": provider_type.value, + }, + ) + + if link_row: + # Update existing link + await self.database.update( + "user_auth_providers", + {"link_id": link_row["link_id"]}, + {"provider_user_id": provider_user_id}, + ) + else: + # Create new link + await self.link_user_to_provider(user, provider_type, provider_user_id) + + async def create_token(self, user: User, name: str, is_long_lived: bool = False) -> str: + """ + Create a new access token for a user. + + :param user: The user to create the token for. + :param name: A name/description for the token (e.g., device name). + :param is_long_lived: Whether this is a long-lived token (default: False). + Short-lived tokens (False): Auto-renewing on use, expire after 30 days of inactivity. + Long-lived tokens (True): No auto-renewal, expire after 10 years. + """ + # Generate token + token = secrets.token_urlsafe(48) + token_hash = hashlib.sha256(token.encode()).hexdigest() + + # Calculate expiration based on token type + created_at = utc() + if is_long_lived: + # Long-lived tokens expire after 10 years (no auto-renewal) + expires_at = created_at + timedelta(days=TOKEN_LONG_LIVED_EXPIRATION) + else: + # Short-lived tokens expire after 30 days (with auto-renewal on use) + expires_at = created_at + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION) + + # Store token + token_data = { + "token_id": secrets.token_urlsafe(32), + "user_id": user.user_id, + "token_hash": token_hash, + "name": name, + "created_at": created_at.isoformat(), + "expires_at": expires_at.isoformat(), + "is_long_lived": 1 if is_long_lived else 0, + } + await self.database.insert("auth_tokens", token_data) + + return token + + @api_command("auth/token/revoke") + async def revoke_token(self, token_id: str) -> None: + """ + Revoke an auth token. + + :param token_id: The token ID to revoke. + """ + user = get_current_user() + if not user: + raise AuthenticationRequired("Not authenticated") + + token_row = await self.database.get_row("auth_tokens", {"token_id": token_id}) + if not token_row: + raise InvalidDataError("Token not found") + + # Check permissions - users can only revoke their own tokens unless admin + if token_row["user_id"] != user.user_id and user.role != UserRole.ADMIN: + raise InsufficientPermissions("You can only revoke your own tokens") + + await self.database.delete("auth_tokens", {"token_id": token_id}) + + # Disconnect any WebSocket connections using this token + self.webserver.disconnect_websockets_for_token(token_id) + + @api_command("auth/tokens") + async def get_user_tokens(self, user_id: str | None = None) -> list[AuthToken]: + """ + Get current user's auth tokens or another user's tokens (admin only). + + :param user_id: Optional user ID to get tokens for (admin only). + :return: List of auth tokens. + """ + current_user = get_current_user() + if not current_user: + return [] + + # If user_id is provided and different from current user, require admin + if user_id and user_id != current_user.user_id: + if current_user.role != UserRole.ADMIN: + return [] + target_user = await self.get_user(user_id) + if not target_user: + return [] + else: + target_user = current_user + + token_rows = await self.database.get_rows( + "auth_tokens", {"user_id": target_user.user_id}, limit=100 + ) + return [AuthToken.from_dict(dict(row)) for row in token_rows] + + @api_command("auth/users", required_role="admin") + async def list_users(self) -> list[User]: + """ + Get all users (admin only). + + System users are excluded from the list. + + :return: List of user objects. + """ + user_rows = await self.database.get_rows("users", limit=1000) + users = [] + for row in user_rows: + row_dict = dict(row) + + # Skip system users + if row_dict["username"] == HOMEASSISTANT_SYSTEM_USER: + continue + + # Parse preferences + preferences = {} + if prefs_json := row_dict.get("preferences"): + try: + preferences = json_loads(prefs_json) + except Exception: + self.logger.warning( + "Failed to parse preferences for user %s", row_dict["user_id"] + ) + + users.append( + User( + user_id=row_dict["user_id"], + username=row_dict["username"], + role=UserRole(row_dict["role"]), + enabled=bool(row_dict["enabled"]), + created_at=datetime.fromisoformat(row_dict["created_at"]), + display_name=row_dict.get("display_name"), + avatar_url=row_dict.get("avatar_url"), + preferences=preferences, + ) + ) + return users + + async def update_user_role(self, user_id: str, new_role: UserRole, admin_user: User) -> bool: + """ + Update a user's role (admin only). + + :param user_id: The user ID to update. + :param new_role: The new role to assign. + :param admin_user: The admin user performing the action. + """ + if admin_user.role != UserRole.ADMIN: + return False + + user_row = await self.database.get_row("users", {"user_id": user_id}) + if not user_row: + return False + + await self.database.update( + "users", + {"user_id": user_id}, + {"role": new_role.value}, + ) + return True + + @api_command("auth/user/enable", required_role="admin") + async def enable_user(self, user_id: str) -> None: + """ + Enable user account (admin only). + + :param user_id: The user ID. + """ + await self.database.update( + "users", + {"user_id": user_id}, + {"enabled": 1}, + ) + + @api_command("auth/user/disable", required_role="admin") + async def disable_user(self, user_id: str) -> None: + """ + Disable user account (admin only). + + :param user_id: The user ID. + """ + admin_user = get_current_user() + if not admin_user: + raise AuthenticationRequired("Not authenticated") + + # Cannot disable yourself + if user_id == admin_user.user_id: + raise InvalidDataError("Cannot disable your own account") + + await self.database.update( + "users", + {"user_id": user_id}, + {"enabled": 0}, + ) + + # Disconnect all WebSocket connections for this user + self.webserver.disconnect_websockets_for_user(user_id) + + async def get_login_providers(self) -> list[dict[str, Any]]: + """Get list of available login providers (dynamically checks for HA provider).""" + # Sync HA OAuth provider with HA provider availability + await self._sync_ha_oauth_provider() + + providers = [] + for provider_id, provider in self.login_providers.items(): + providers.append( + { + "provider_id": provider_id, + "provider_type": provider.provider_type.value, + "requires_redirect": provider.requires_redirect, + } + ) + return providers + + async def get_authorization_url( + self, provider_id: str, return_url: str | None = None + ) -> str | None: + """ + Get OAuth authorization URL for a provider. + + :param provider_id: The provider ID. + :param return_url: Optional URL to redirect to after successful login. + """ + provider = self.login_providers.get(provider_id) + if not provider or not provider.requires_redirect: + return None + + # Build callback redirect_uri + redirect_uri = f"{self.webserver.base_url}/auth/callback?provider_id={provider_id}" + return await provider.get_authorization_url(redirect_uri, return_url) + + async def handle_oauth_callback( + self, provider_id: str, code: str, state: str, redirect_uri: str + ) -> AuthResult: + """ + Handle OAuth callback. + + :param provider_id: The provider ID. + :param code: OAuth authorization code. + :param state: OAuth state parameter. + :param redirect_uri: The callback URL. + """ + provider = self.login_providers.get(provider_id) + if not provider: + return AuthResult(success=False, error="Invalid provider") + + return await provider.handle_oauth_callback(code, state, redirect_uri) + + @api_command("auth/token/create") + async def create_long_lived_token(self, name: str, user_id: str | None = None) -> str: + """ + Create a new long-lived access token for current user or another user (admin only). + + Long-lived tokens are intended for external integrations and API access. + They expire after 10 years and do NOT auto-renew on use. + + Short-lived tokens (for regular user sessions) are only created during login + and auto-renew on each use (sliding 30-day expiration window). + + :param name: The name/description for the token (e.g., "Home Assistant", "Mobile App"). + :param user_id: Optional user ID to create token for (admin only). + :return: The created token string. + """ + current_user = get_current_user() + if not current_user: + raise AuthenticationRequired("Not authenticated") + + # If user_id is provided and different from current user, require admin + if user_id and user_id != current_user.user_id: + if current_user.role != UserRole.ADMIN: + raise InsufficientPermissions( + "Admin access required to create tokens for other users" + ) + target_user = await self.get_user(user_id) + if not target_user: + raise InvalidDataError("User not found") + else: + target_user = current_user + + # Create a long-lived token (only long-lived tokens can be created via this command) + token = await self.create_token(target_user, name, is_long_lived=True) + self.logger.info("Created long-lived token '%s' for user '%s'", name, target_user.username) + return token + + @api_command("auth/user/create", required_role="admin") + async def create_user_with_api( + self, + username: str, + password: str, + role: str = "user", + display_name: str | None = None, + avatar_url: str | None = None, + ) -> User: + """ + Create a new user with built-in authentication (admin only). + + :param username: The username (minimum 3 characters). + :param password: The password (minimum 8 characters). + :param role: User role - "admin" or "user" (default: "user"). + :param display_name: Optional display name. + :param avatar_url: Optional avatar URL. + :return: Created user object. + """ + # Validation + if not username or len(username) < 3: + raise InvalidDataError("Username must be at least 3 characters") + + if not password or len(password) < 8: + raise InvalidDataError("Password must be at least 8 characters") + + # Validate role + try: + user_role = UserRole(role) + except ValueError as err: + raise InvalidDataError("Invalid role. Must be 'admin' or 'user'") from err + + # Get built-in provider + builtin_provider = self.login_providers.get("builtin") + if not builtin_provider or not isinstance(builtin_provider, BuiltinLoginProvider): + raise InvalidDataError("Built-in auth provider not available") + + # Create user with password + user = await builtin_provider.create_user_with_password(username, password, role=user_role) + + # Update optional fields if provided + if display_name or avatar_url: + updated_user = await self.update_user( + user, display_name=display_name, avatar_url=avatar_url + ) + if updated_user: + user = updated_user + + self.logger.info("User created by admin: %s (role: %s)", username, role) + return user + + @api_command("auth/user/delete", required_role="admin") + async def delete_user(self, user_id: str) -> None: + """ + Delete user account (admin only). + + :param user_id: The user ID. + """ + admin_user = get_current_user() + if not admin_user: + raise AuthenticationRequired("Not authenticated") + + # Don't allow deleting yourself + if user_id == admin_user.user_id: + raise InvalidDataError("Cannot delete your own account") + + # Delete user from database + await self.database.delete("users", {"user_id": user_id}) + await self.database.commit() + + # Disconnect all WebSocket connections for this user + self.webserver.disconnect_websockets_for_user(user_id) + + @api_command("auth/me") + async def get_current_user_info(self) -> User: + """Get current authenticated user information.""" + current_user_obj = get_current_user() + if not current_user_obj: + raise AuthenticationRequired("Not authenticated") + return current_user_obj + + async def _update_profile_password( + self, + target_user: User, + password: str, + old_password: str | None, + is_admin_update: bool, + current_user: User, + ) -> None: + """Update user password (helper method).""" + if len(password) < 8: + raise InvalidDataError("Password must be at least 8 characters") + + builtin_provider = self.login_providers.get("builtin") + if not builtin_provider or not isinstance(builtin_provider, BuiltinLoginProvider): + raise InvalidDataError("Built-in auth not available") + + if is_admin_update: + # Admin can reset password without old password + await builtin_provider.reset_password(target_user, password) + self.logger.info( + "Password reset for user %s by admin %s", + target_user.username, + current_user.username, + ) + else: + # User updating own password - requires old password verification + if not old_password: + raise InvalidDataError("old_password is required to change your own password") + + # Verify old password and change to new one + success = await builtin_provider.change_password(target_user, old_password, password) + if not success: + raise AuthenticationFailed("Invalid current password") + + self.logger.info("Password changed for user %s", target_user.username) + + @api_command("auth/user/update") + async def update_user_profile( + self, + user_id: str | None = None, + username: str | None = None, + display_name: str | None = None, + avatar_url: str | None = None, + password: str | None = None, + old_password: str | None = None, + role: str | None = None, + preferences: dict[str, Any] | None = None, + ) -> User: + """ + Update user profile information. + + Users can update their own profile. Admins can update any user including role and password. + + :param user_id: User ID to update (optional, defaults to current user). + :param username: New username (optional). + :param display_name: New display name (optional). + :param avatar_url: New avatar URL (optional). + :param password: New password (optional, minimum 8 characters). + :param old_password: Current password (required when user updates own password). + :param role: New role - "admin" or "user" (optional, admin only). + :param preferences: User preferences dict (completely replaces existing, optional). + :return: Updated user object. + """ + current_user_obj = get_current_user() + if not current_user_obj: + raise AuthenticationRequired("Not authenticated") + + # Determine target user + if user_id and user_id != current_user_obj.user_id: + # Updating another user - requires admin + if current_user_obj.role != UserRole.ADMIN: + raise InsufficientPermissions("Admin access required") + target_user = await self.get_user(user_id) + if not target_user: + raise InvalidDataError("User not found") + is_admin_update = True + else: + # Updating own profile + target_user = current_user_obj + is_admin_update = False + + # Update role (admin only) + if role: + if not is_admin_update: + raise InsufficientPermissions("Only admins can update user roles") + + try: + new_role = UserRole(role) + except ValueError as err: + raise InvalidDataError("Invalid role. Must be 'admin' or 'user'") from err + + success = await self.update_user_role(target_user.user_id, new_role, current_user_obj) + if not success: + raise InvalidDataError("Failed to update role") + + # Refresh target user to get updated role + refreshed_user = await self.get_user(target_user.user_id) + if not refreshed_user: + raise InvalidDataError("Failed to refresh user after role update") + target_user = refreshed_user + + # Update basic profile fields + if username or display_name or avatar_url: + updated_user = await self.update_user( + target_user, + username=username, + display_name=display_name, + avatar_url=avatar_url, + ) + if not updated_user: + raise InvalidDataError("Failed to update user profile") + target_user = updated_user + + # Update preferences if provided + if preferences is not None: + target_user = await self.update_user_preferences(target_user, preferences) + + # Update password if provided + if password: + await self._update_profile_password( + target_user, password, old_password, is_admin_update, current_user_obj + ) + + return target_user + + @api_command("auth/logout") + async def logout(self) -> None: + """Logout current user by revoking the current token.""" + user = get_current_user() + if not user: + raise AuthenticationRequired("Not authenticated") + + # Get current token from context + token = get_current_token() + if not token: + raise InvalidDataError("No token in context") + + # Find and revoke the token + token_hash = hashlib.sha256(token.encode()).hexdigest() + token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash}) + if token_row: + await self.database.delete("auth_tokens", {"token_id": token_row["token_id"]}) + + # Disconnect any WebSocket connections using this token + self.webserver.disconnect_websockets_for_token(token_row["token_id"]) + + @api_command("auth/user/providers") + async def get_my_providers(self) -> list[dict[str, Any]]: + """ + Get current user's linked authentication providers. + + :return: List of provider links. + """ + user = get_current_user() + if not user: + return [] + + # Get provider links from database + rows = await self.database.get_rows("user_auth_providers", {"user_id": user.user_id}) + providers = [UserAuthProvider.from_dict(dict(row)) for row in rows] + return [p.to_dict() for p in providers] + + @api_command("auth/user/unlink_provider", required_role="admin") + async def unlink_provider(self, user_id: str, provider_type: str) -> bool: + """ + Unlink authentication provider from user (admin only). + + :param user_id: The user ID. + :param provider_type: Provider type to unlink. + :return: True if successful. + """ + await self.database.delete( + "user_auth_providers", {"user_id": user_id, "provider_type": provider_type} + ) + await self.database.commit() + return True diff --git a/music_assistant/controllers/webserver/controller.py b/music_assistant/controllers/webserver/controller.py new file mode 100644 index 00000000..03d8386d --- /dev/null +++ b/music_assistant/controllers/webserver/controller.py @@ -0,0 +1,1045 @@ +""" +Controller that manages the builtin webserver that hosts the api and frontend. + +Unlike the streamserver (which is as simple and unprotected as possible), +this webserver allows for more fine grained configuration to better secure it. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import html +import json +import os +import urllib.parse +from collections.abc import Awaitable, Callable +from concurrent import futures +from functools import partial +from typing import TYPE_CHECKING, Any, Final, cast +from urllib.parse import quote + +import aiofiles +from aiohttp import ClientTimeout, web +from mashumaro.exceptions import MissingField +from music_assistant_frontend import where as locate_frontend +from music_assistant_models.api import CommandMessage +from music_assistant_models.auth import AuthProviderType, User, UserRole +from music_assistant_models.config_entries import ConfigEntry, ConfigValueOption +from music_assistant_models.enums import ConfigEntryType + +from music_assistant.constants import ( + CONF_AUTH_ALLOW_SELF_REGISTRATION, + CONF_BIND_IP, + CONF_BIND_PORT, + CONF_ONBOARD_DONE, + RESOURCES_DIR, + VERBOSE_LOG_LEVEL, +) +from music_assistant.helpers.api import parse_arguments +from music_assistant.helpers.audio import get_preview_stream +from music_assistant.helpers.json import json_dumps, json_loads +from music_assistant.helpers.redirect_validation import is_allowed_redirect_url +from music_assistant.helpers.util import get_ip_addresses +from music_assistant.helpers.webserver import Webserver +from music_assistant.models.core_controller import CoreController + +from .api_docs import generate_commands_json, generate_openapi_spec, generate_schemas_json +from .auth import AuthenticationManager +from .helpers.auth_middleware import ( + get_authenticated_user, + is_request_from_ingress, + set_current_user, +) +from .helpers.auth_providers import BuiltinLoginProvider +from .websocket_client import WebsocketClientHandler + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ConfigValueType, CoreConfig + + from music_assistant import MusicAssistant + +DEFAULT_SERVER_PORT = 8095 +INGRESS_SERVER_PORT = 8094 +CONF_BASE_URL = "base_url" +MAX_PENDING_MSG = 512 +CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError) + + +class WebserverController(CoreController): + """Core Controller that manages the builtin webserver that hosts the api and frontend.""" + + domain: str = "webserver" + + def __init__(self, mass: MusicAssistant) -> None: + """Initialize instance.""" + super().__init__(mass) + self._server = Webserver(self.logger, enable_dynamic_routes=True) + self.register_dynamic_route = self._server.register_dynamic_route + self.unregister_dynamic_route = self._server.unregister_dynamic_route + self.clients: set[WebsocketClientHandler] = set() + self.manifest.name = "Web Server (frontend and api)" + self.manifest.description = ( + "The built-in webserver that hosts the Music Assistant Websockets API and frontend" + ) + self.manifest.icon = "web-box" + self.auth = AuthenticationManager(self) + + @property + def base_url(self) -> str: + """Return the base_url for the streamserver.""" + return self._server.base_url + + async def get_config_entries( + self, + action: str | None = None, + values: dict[str, ConfigValueType] | None = None, + ) -> tuple[ConfigEntry, ...]: + """Return all Config Entries for this core module (if any).""" + ip_addresses = await get_ip_addresses() + default_publish_ip = ip_addresses[0] + default_base_url = f"http://{default_publish_ip}:{DEFAULT_SERVER_PORT}" + return ( + ConfigEntry( + key="webserver_warn", + type=ConfigEntryType.ALERT, + label="Please note that the webserver is unprotected. " + "Never ever expose the webserver directly to the internet! \n\n" + "Use a reverse proxy or VPN to secure access.", + required=False, + ), + ConfigEntry( + key=CONF_BASE_URL, + type=ConfigEntryType.STRING, + default_value=default_base_url, + label="Base URL", + description="The (base) URL to reach this webserver in the network. \n" + "Override this in advanced scenarios where for example you're running " + "the webserver behind a reverse proxy.", + ), + ConfigEntry( + key=CONF_BIND_PORT, + type=ConfigEntryType.INTEGER, + default_value=DEFAULT_SERVER_PORT, + label="TCP Port", + description="The TCP port to run the webserver.", + ), + ConfigEntry( + key=CONF_BIND_IP, + type=ConfigEntryType.STRING, + default_value="0.0.0.0", + options=[ConfigValueOption(x, x) for x in {"0.0.0.0", *ip_addresses}], + label="Bind to IP/interface", + description="Bind the (web)server to this specific interface. \n" + "Use 0.0.0.0 to bind to all interfaces. \n" + "Set this address for example to a docker-internal network, " + "when you are running a reverse proxy to enhance security and " + "protect outside access to the webinterface and API. \n\n" + "This is an advanced setting that should normally " + "not be adjusted in regular setups.", + category="advanced", + ), + ConfigEntry( + key=CONF_AUTH_ALLOW_SELF_REGISTRATION, + type=ConfigEntryType.BOOLEAN, + default_value=True, + label="Allow Self-Registration", + description="Allow users to create accounts via Home Assistant OAuth. \n" + "New users will have USER role by default.", + category="advanced", + hidden=not any(provider.domain == "hass" for provider in self.mass.providers), + ), + ) + + async def setup(self, config: CoreConfig) -> None: # noqa: PLR0915 + """Async initialize of module.""" + self.config = config + # work out all routes + routes: list[tuple[str, str, Callable[[web.Request], Awaitable[web.StreamResponse]]]] = [] + # frontend routes + frontend_dir = locate_frontend() + for filename in next(os.walk(frontend_dir))[2]: + if filename.endswith(".py"): + continue + filepath = os.path.join(frontend_dir, filename) + handler = partial(self._server.serve_static, filepath) + routes.append(("GET", f"/{filename}", handler)) + # add index + index_path = os.path.join(frontend_dir, "index.html") + handler = partial(self._server.serve_static, index_path) + routes.append(("GET", "/", handler)) + # add logo + logo_path = str(RESOURCES_DIR.joinpath("logo.png")) + handler = partial(self._server.serve_static, logo_path) + routes.append(("GET", "/logo.png", handler)) + # add common CSS for HTML resources + common_css_path = str(RESOURCES_DIR.joinpath("common.css")) + handler = partial(self._server.serve_static, common_css_path) + routes.append(("GET", "/resources/common.css", handler)) + # add info + routes.append(("GET", "/info", self._handle_server_info)) + routes.append(("OPTIONS", "/info", self._handle_cors_preflight)) + # add logging + routes.append(("GET", "/music-assistant.log", self._handle_application_log)) + # add websocket api + routes.append(("GET", "/ws", self._handle_ws_client)) + # also host the image proxy on the webserver + routes.append(("GET", "/imageproxy", self.mass.metadata.handle_imageproxy)) + # also host the audio preview service + routes.append(("GET", "/preview", self.serve_preview_stream)) + # add jsonrpc api + routes.append(("POST", "/api", self._handle_jsonrpc_api_command)) + # add api documentation + routes.append(("GET", "/api-docs", self._handle_api_intro)) + routes.append(("GET", "/api-docs/", self._handle_api_intro)) + routes.append(("GET", "/api-docs/commands", self._handle_commands_reference)) + routes.append(("GET", "/api-docs/commands/", self._handle_commands_reference)) + routes.append(("GET", "/api-docs/commands.json", self._handle_commands_json)) + routes.append(("GET", "/api-docs/schemas", self._handle_schemas_reference)) + routes.append(("GET", "/api-docs/schemas/", self._handle_schemas_reference)) + routes.append(("GET", "/api-docs/schemas.json", self._handle_schemas_json)) + routes.append(("GET", "/api-docs/openapi.json", self._handle_openapi_spec)) + routes.append(("GET", "/api-docs/swagger", self._handle_swagger_ui)) + routes.append(("GET", "/api-docs/swagger/", self._handle_swagger_ui)) + # add authentication routes + routes.append(("GET", "/login", self._handle_login_page)) + routes.append(("POST", "/auth/login", self._handle_auth_login)) + routes.append(("OPTIONS", "/auth/login", self._handle_cors_preflight)) + routes.append(("POST", "/auth/logout", self._handle_auth_logout)) + routes.append(("GET", "/auth/me", self._handle_auth_me)) + routes.append(("PATCH", "/auth/me", self._handle_auth_me_update)) + routes.append(("GET", "/auth/providers", self._handle_auth_providers)) + routes.append(("GET", "/auth/authorize", self._handle_auth_authorize)) + routes.append(("GET", "/auth/callback", self._handle_auth_callback)) + # add first-time setup routes + routes.append(("GET", "/setup", self._handle_setup_page)) + routes.append(("POST", "/setup", self._handle_setup)) + # Initialize authentication manager + await self.auth.setup() + # start the webserver + all_ip_addresses = await get_ip_addresses() + default_publish_ip = all_ip_addresses[0] + if self.mass.running_as_hass_addon: + # if we're running on the HA supervisor we start an additional TCP site + # on the internal ("172.30.32.) IP for the HA ingress proxy + ingress_host = next( + (x for x in all_ip_addresses if x.startswith("172.30.32.")), default_publish_ip + ) + ingress_tcp_site_params = (ingress_host, INGRESS_SERVER_PORT) + else: + ingress_tcp_site_params = None + base_url = str(config.get_value(CONF_BASE_URL)) + port_value = config.get_value(CONF_BIND_PORT) + assert isinstance(port_value, int) + self.publish_port = port_value + self.publish_ip = default_publish_ip + bind_ip = cast("str | None", config.get_value(CONF_BIND_IP)) + # print a big fat message in the log where the webserver is running + # because this is a common source of issues for people with more complex setups + if not self.mass.config.onboard_done: + self.logger.warning( + "\n\n################################################################################\n" + "Starting webserver on %s:%s - base url: %s\n" + "If this is incorrect, see the documentation how to configure the Webserver\n" + "in Settings --> Core modules --> Webserver\n" + "################################################################################\n", + bind_ip, + self.publish_port, + base_url, + ) + else: + self.logger.info( + "Starting webserver on %s:%s - base url: %s\n#\n", + bind_ip, + self.publish_port, + base_url, + ) + await self._server.setup( + bind_ip=bind_ip, + bind_port=self.publish_port, + base_url=base_url, + static_routes=routes, + # add assets subdir as static_content + static_content=("/assets", os.path.join(frontend_dir, "assets"), "assets"), + ingress_tcp_site_params=ingress_tcp_site_params, + # Add mass object to app for use in auth middleware + app_state={"mass": self.mass}, + ) + if self.mass.running_as_hass_addon: + # announce to HA supervisor + await self._announce_to_homeassistant() + + async def close(self) -> None: + """Cleanup on exit.""" + for client in set(self.clients): + await client.disconnect() + await self._server.close() + await self.auth.close() + + def register_websocket_client(self, client: WebsocketClientHandler) -> None: + """Register a WebSocket client for tracking.""" + self.clients.add(client) + + def unregister_websocket_client(self, client: WebsocketClientHandler) -> None: + """Unregister a WebSocket client.""" + self.clients.discard(client) + + def disconnect_websockets_for_token(self, token_id: str) -> None: + """Disconnect all WebSocket clients using a specific token.""" + for client in list(self.clients): + if hasattr(client, "_token_id") and client._token_id == token_id: + username = ( + client._authenticated_user.username if client._authenticated_user else "unknown" + ) + self.logger.warning( + "Disconnecting WebSocket client due to token revocation: %s", + username, + ) + client._cancel() + + def disconnect_websockets_for_user(self, user_id: str) -> None: + """Disconnect all WebSocket clients for a specific user.""" + for client in list(self.clients): + if ( + hasattr(client, "_authenticated_user") + and client._authenticated_user + and client._authenticated_user.user_id == user_id + ): + self.logger.warning( + "Disconnecting WebSocket client due to user action: %s", + client._authenticated_user.username, + ) + client._cancel() + + async def serve_preview_stream(self, request: web.Request) -> web.StreamResponse: + """Serve short preview sample.""" + provider_instance_id_or_domain = request.query["provider"] + item_id = urllib.parse.unquote(request.query["item_id"]) + resp = web.StreamResponse(status=200, reason="OK", headers={"Content-Type": "audio/aac"}) + await resp.prepare(request) + async for chunk in get_preview_stream(self.mass, provider_instance_id_or_domain, item_id): + await resp.write(chunk) + return resp + + async def _handle_cors_preflight(self, request: web.Request) -> web.Response: + """Handle CORS preflight OPTIONS request.""" + return web.Response( + status=200, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization", + "Access-Control-Max-Age": "86400", # Cache preflight for 24 hours + }, + ) + + async def _handle_server_info(self, request: web.Request) -> web.Response: + """Handle request for server info.""" + server_info = self.mass.get_server_info() + # Add CORS headers to allow frontend to call from any origin + return web.json_response( + server_info.to_dict(), + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization", + }, + ) + + async def _handle_ws_client(self, request: web.Request) -> web.WebSocketResponse: + connection = WebsocketClientHandler(self, request) + if lang := request.headers.get("Accept-Language"): + self.mass.metadata.set_default_preferred_language(lang.split(",")[0]) + try: + self.clients.add(connection) + return await connection.handle_client() + finally: + self.clients.discard(connection) + + async def _handle_jsonrpc_api_command(self, request: web.Request) -> web.Response: + """Handle incoming JSON RPC API command.""" + if not request.can_read_body: + return web.Response(status=400, text="Body required") + cmd_data = await request.read() + self.logger.log(VERBOSE_LOG_LEVEL, "Received on JSONRPC API: %s", cmd_data) + try: + command_msg = CommandMessage.from_json(cmd_data) + except ValueError: + error = f"Invalid JSON: {cmd_data.decode()}" + self.logger.error("Unhandled JSONRPC API error: %s", error) + return web.Response(status=400, text=error) + except MissingField as e: + # be forgiving if message_id is missing + cmd_data_dict = json_loads(cmd_data) + if e.field_name == "message_id" and "command" in cmd_data_dict: + cmd_data_dict["message_id"] = "unknown" + command_msg = CommandMessage.from_dict(cmd_data_dict) + else: + error = f"Missing field in JSON: {e.field_name}" + self.logger.error("Unhandled JSONRPC API error: %s", error) + return web.Response(status=400, text="Invalid JSON: missing required field") + + # work out handler for the given path/command + handler = self.mass.command_handlers.get(command_msg.command) + if handler is None: + error = f"Invalid Command: {command_msg.command}" + self.logger.error("Unhandled JSONRPC API error: %s", error) + return web.Response(status=400, text=error) + + # Check authentication if required + if handler.authenticated or handler.required_role: + if is_request_from_ingress(request): + # Ingress authentication (Home Assistant) + user = await self._get_ingress_user(request) + if not user: + # This should not happen - ingress requests should have user headers + return web.Response( + status=401, + text="Ingress authentication failed - missing user information", + ) + else: + # Regular authentication (non-ingress) + try: + user = await get_authenticated_user(request) + except Exception as e: + self.logger.exception("Authentication error: %s", e) + return web.Response( + status=401, + text="Authentication failed", + headers={"WWW-Authenticate": 'Bearer realm="Music Assistant"'}, + ) + + if not user: + return web.Response( + status=401, + text="Authentication required", + headers={"WWW-Authenticate": 'Bearer realm="Music Assistant"'}, + ) + + # Set user in context and check role + set_current_user(user) + if handler.required_role == "admin" and user.role != UserRole.ADMIN: + return web.Response( + status=403, + text="Admin access required", + ) + + try: + args = parse_arguments(handler.signature, handler.type_hints, command_msg.args) + result: Any = handler.target(**args) + if hasattr(result, "__anext__"): + # handle async generator (for really large listings) + result = [item async for item in result] + elif asyncio.iscoroutine(result): + result = await result + return web.json_response(result, dumps=json_dumps) + except Exception as e: + # Return clean error message without stacktrace + error_type = type(e).__name__ + error_msg = str(e) + error = f"{error_type}: {error_msg}" + self.logger.exception("Error executing command %s: %s", command_msg.command, error) + return web.Response(status=500, text="Internal server error") + + async def _handle_application_log(self, request: web.Request) -> web.Response: + """Handle request to get the application log.""" + log_data = await self.mass.get_application_log() + return web.Response(text=log_data, content_type="text/text") + + async def _handle_api_intro(self, request: web.Request) -> web.Response: + """Handle request for API introduction/documentation page.""" + intro_html_path = str(RESOURCES_DIR.joinpath("api_docs.html")) + # Read the template + async with aiofiles.open(intro_html_path) as f: + html_content = await f.read() + + # Replace placeholders (escape values to prevent XSS) + html_content = html_content.replace("{VERSION}", html.escape(self.mass.version)) + html_content = html_content.replace("{BASE_URL}", html.escape(self.base_url)) + html_content = html_content.replace("{SERVER_HOST}", html.escape(request.host)) + + return web.Response(text=html_content, content_type="text/html") + + async def _handle_openapi_spec(self, request: web.Request) -> web.Response: + """Handle request for OpenAPI specification (generated on-the-fly).""" + spec = generate_openapi_spec( + self.mass.command_handlers, server_url=self.base_url, version=self.mass.version + ) + return web.json_response(spec) + + async def _handle_commands_reference(self, request: web.Request) -> web.FileResponse: + """Handle request for commands reference page.""" + commands_html_path = str(RESOURCES_DIR.joinpath("commands_reference.html")) + return await self._server.serve_static(commands_html_path, request) + + async def _handle_commands_json(self, request: web.Request) -> web.Response: + """Handle request for commands JSON data (generated on-the-fly).""" + commands_data = generate_commands_json(self.mass.command_handlers) + return web.json_response(commands_data) + + async def _handle_schemas_reference(self, request: web.Request) -> web.FileResponse: + """Handle request for schemas reference page.""" + schemas_html_path = str(RESOURCES_DIR.joinpath("schemas_reference.html")) + return await self._server.serve_static(schemas_html_path, request) + + async def _handle_schemas_json(self, request: web.Request) -> web.Response: + """Handle request for schemas JSON data (generated on-the-fly).""" + schemas_data = generate_schemas_json(self.mass.command_handlers) + return web.json_response(schemas_data) + + async def _handle_swagger_ui(self, request: web.Request) -> web.FileResponse: + """Handle request for Swagger UI.""" + swagger_html_path = str(RESOURCES_DIR.joinpath("swagger_ui.html")) + return await self._server.serve_static(swagger_html_path, request) + + async def _handle_login_page(self, request: web.Request) -> web.Response: + """Handle request for login page.""" + # If not yet onboarded, redirect to setup + if not self.mass.config.onboard_done or not await self.auth.has_users(): + return_url = request.query.get("return_url", "") + device_name = request.query.get("device_name", "") + setup_url = ( + f"/setup?return_url={return_url}&device_name={device_name}" + if return_url + else "/setup" + ) + return web.Response(status=302, headers={"Location": setup_url}) + + # Check if this is an ingress request - if so, auto-authenticate and redirect with token + if is_request_from_ingress(request): + ingress_user_id = request.headers.get("X-Remote-User-ID") + ingress_username = request.headers.get("X-Remote-User-Name") + + if ingress_user_id and ingress_username: + # Try to find existing user linked to this HA user ID + user = await self.auth.get_user_by_provider_link( + AuthProviderType.HOME_ASSISTANT, ingress_user_id + ) + + if user: + # User exists, create token and redirect + device_name = request.query.get( + "device_name", f"Home Assistant Ingress ({ingress_username})" + ) + token = await self.auth.create_token(user, device_name) + + return_url = request.query.get("return_url", "/") + + # Insert code parameter before any hash fragment + code_param = f"code={quote(token, safe='')}" + if "#" in return_url: + url_parts = return_url.split("#", 1) + base_part = url_parts[0] + hash_part = url_parts[1] + separator = "&" if "?" in base_part else "?" + redirect_url = f"{base_part}{separator}{code_param}#{hash_part}" + elif "?" in return_url: + redirect_url = f"{return_url}&{code_param}" + else: + redirect_url = f"{return_url}?{code_param}" + + return web.Response(status=302, headers={"Location": redirect_url}) + + # Not ingress or user doesn't exist - serve login page + login_html_path = str(RESOURCES_DIR.joinpath("login.html")) + async with aiofiles.open(login_html_path) as f: + html_content = await f.read() + return web.Response(text=html_content, content_type="text/html") + + async def _handle_auth_login(self, request: web.Request) -> web.Response: + """Handle login request.""" + try: + if not request.can_read_body: + return web.Response(status=400, text="Body required") + + body = await request.json() + provider_id = body.get("provider_id", "builtin") # Default to built-in provider + credentials = body.get("credentials", {}) + return_url = body.get("return_url") # Optional return URL for redirect after login + + # Authenticate with provider + auth_result = await self.auth.authenticate_with_credentials(provider_id, credentials) + + if not auth_result.success or not auth_result.user: + return web.json_response( + {"success": False, "error": auth_result.error}, + status=401, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization", + }, + ) + + # Create token for user + device_name = body.get( + "device_name", f"{request.headers.get('User-Agent', 'Unknown')[:50]}" + ) + token = await self.auth.create_token(auth_result.user, device_name) + + # Prepare response data + response_data = { + "success": True, + "token": token, + "user": auth_result.user.to_dict(), + } + + # If return_url provided, append code parameter and return as redirect_to + if return_url: + # Insert code parameter before any hash fragment + code_param = f"code={quote(token, safe='')}" + if "#" in return_url: + url_parts = return_url.split("#", 1) + base_part = url_parts[0] + hash_part = url_parts[1] + separator = "&" if "?" in base_part else "?" + redirect_url = f"{base_part}{separator}{code_param}#{hash_part}" + elif "?" in return_url: + redirect_url = f"{return_url}&{code_param}" + else: + redirect_url = f"{return_url}?{code_param}" + + response_data["redirect_to"] = redirect_url + self.logger.debug( + "Login successful, returning redirect_to: %s", + redirect_url.replace(token, "***TOKEN***"), + ) + + # Add CORS headers to allow login from any origin + return web.json_response( + response_data, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization", + }, + ) + except Exception: + self.logger.exception("Error during login") + return web.json_response( + {"success": False, "error": "Login failed"}, + status=500, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization", + }, + ) + + async def _handle_auth_logout(self, request: web.Request) -> web.Response: + """Handle logout request.""" + user = await get_authenticated_user(request) + if not user: + return web.Response(status=401, text="Not authenticated") + + # Get token from request + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + # Find and revoke the token + token_hash = hashlib.sha256(token.encode()).hexdigest() + token_row = await self.auth.database.get_row("auth_tokens", {"token_hash": token_hash}) + if token_row: + await self.auth.database.delete("auth_tokens", {"token_id": token_row["token_id"]}) + + return web.json_response({"success": True}) + + async def _handle_auth_me(self, request: web.Request) -> web.Response: + """Handle request for current user information.""" + user = await get_authenticated_user(request) + if not user: + return web.Response(status=401, text="Not authenticated") + + return web.json_response(user.to_dict()) + + async def _handle_auth_me_update(self, request: web.Request) -> web.Response: + """Handle request to update current user's profile.""" + user = await get_authenticated_user(request) + if not user: + return web.Response(status=401, text="Not authenticated") + + try: + if not request.can_read_body: + return web.Response(status=400, text="Body required") + + body = await request.json() + username = body.get("username") + display_name = body.get("display_name") + avatar_url = body.get("avatar_url") + + # Update user + updated_user = await self.auth.update_user( + user, + username=username, + display_name=display_name, + avatar_url=avatar_url, + ) + + return web.json_response({"success": True, "user": updated_user.to_dict()}) + except Exception: + self.logger.exception("Error updating user profile") + return web.json_response( + {"success": False, "error": "Failed to update profile"}, status=500 + ) + + async def _handle_auth_providers(self, request: web.Request) -> web.Response: + """Handle request for available login providers.""" + try: + providers = await self.auth.get_login_providers() + return web.json_response(providers) + except Exception: + self.logger.exception("Error getting auth providers") + return web.json_response({"error": "Failed to get auth providers"}, status=500) + + async def _handle_auth_authorize(self, request: web.Request) -> web.Response: + """Handle OAuth authorization request.""" + try: + provider_id = request.query.get("provider_id") + return_url = request.query.get("return_url") + + self.logger.debug( + "OAuth authorize request: provider_id=%s, return_url=%s", provider_id, return_url + ) + + if not provider_id: + return web.Response(status=400, text="provider_id required") + + # Validate return_url if provided + if return_url: + is_valid, _ = is_allowed_redirect_url(return_url, request, self.base_url) + if not is_valid: + return web.Response(status=400, text="Invalid return_url") + + auth_url = await self.auth.get_authorization_url(provider_id, return_url) + if not auth_url: + return web.Response( + status=400, text="Provider does not support OAuth or is not configured" + ) + + return web.json_response({"authorization_url": auth_url}) + except Exception: + self.logger.exception("Error during OAuth authorization") + return web.json_response({"error": "Authorization failed"}, status=500) + + async def _handle_auth_callback(self, request: web.Request) -> web.Response: + """Handle OAuth callback.""" + try: + code = request.query.get("code") + state = request.query.get("state") + provider_id = request.query.get("provider_id") + + if not code or not state or not provider_id: + return web.Response(status=400, text="code, state, and provider_id required") + + redirect_uri = f"{self.base_url}/auth/callback?provider_id={provider_id}" + auth_result = await self.auth.handle_oauth_callback( + provider_id, code, state, redirect_uri + ) + + if not auth_result.success or not auth_result.user: + # Return error page + error_html = f""" + + +

Authentication Failed

+

{html.escape(auth_result.error or "Unknown error")}

+ Back to Login + + + """ + return web.Response(text=error_html, content_type="text/html", status=400) + + # Create token + device_name = f"OAuth ({provider_id})" + token = await self.auth.create_token(auth_result.user, device_name) + + # Determine redirect URL (use return_url from OAuth flow or default to root) + final_redirect_url = auth_result.return_url or "/" + requires_consent = False + + # Validate redirect URL for security + if auth_result.return_url: + is_valid, category = is_allowed_redirect_url( + auth_result.return_url, request, self.base_url + ) + if not is_valid: + self.logger.warning("Invalid return_url blocked: %s", auth_result.return_url) + final_redirect_url = "/" + elif category == "external": + # External domain - require user consent + requires_consent = True + # Add code parameter to redirect URL (the token URL-encoded) + # Important: Insert code BEFORE any hash fragment (e.g., #/) to ensure + # it's in query params, not inside the hash where Vue Router can't access it + code_param = f"code={quote(token, safe='')}" + + # Split URL by hash to insert code in the right place + if "#" in final_redirect_url: + # URL has a hash fragment (e.g., http://example.com/#/ or http://example.com/path#section) + url_parts = final_redirect_url.split("#", 1) + base_url = url_parts[0] + hash_part = url_parts[1] + + # Add code to base URL (before hash) + separator = "&" if "?" in base_url else "?" + final_redirect_url = f"{base_url}{separator}{code_param}#{hash_part}" + # No hash fragment, simple case + elif "?" in final_redirect_url: + final_redirect_url = f"{final_redirect_url}&{code_param}" + else: + final_redirect_url = f"{final_redirect_url}?{code_param}" + + # Load OAuth callback success page template and inject token and redirect URL + oauth_callback_html_path = str(RESOURCES_DIR.joinpath("oauth_callback.html")) + async with aiofiles.open(oauth_callback_html_path) as f: + success_html = await f.read() + + # Replace template placeholders + success_html = success_html.replace("{TOKEN}", token) + success_html = success_html.replace("{REDIRECT_URL}", final_redirect_url) + success_html = success_html.replace( + "{REQUIRES_CONSENT}", "true" if requires_consent else "false" + ) + + return web.Response(text=success_html, content_type="text/html") + except Exception: + self.logger.exception("Error during OAuth callback") + error_html = """ + + +

Authentication Failed

+

An error occurred during authentication

+ Back to Login + + + """ + return web.Response(text=error_html, content_type="text/html", status=500) + + async def _handle_setup_page(self, request: web.Request) -> web.Response: + """Handle request for first-time setup page.""" + # Check if setup is needed + # Allow setup if either: + # 1. No users exist yet (fresh install) + # 2. Users exist but onboarding not done (e.g., Ingress auto-created user) + if await self.auth.has_users() and self.mass.config.get(CONF_ONBOARD_DONE): + # Setup already completed, redirect to login + return web.Response(status=302, headers={"Location": "/login"}) + + # Validate return_url if provided + return_url = request.query.get("return_url") + if return_url: + is_valid, _ = is_allowed_redirect_url(return_url, request, self.base_url) + if not is_valid: + return web.Response(status=400, text="Invalid return_url") + + # Serve setup page + setup_html_path = str(RESOURCES_DIR.joinpath("setup.html")) + async with aiofiles.open(setup_html_path) as f: + html_content = await f.read() + + # Check if this is from Ingress - if so, pre-fill user info + if is_request_from_ingress(request): + ingress_username = request.headers.get("X-Remote-User-Name", "") + ingress_display_name = request.headers.get("X-Remote-User-Display-Name", "") + + # Inject ingress user info into the page (use json.dumps to escape properly) + html_content = html_content.replace( + "const deviceName = urlParams.get('device_name');", + f"const deviceName = urlParams.get('device_name');\n" + f" const ingressUsername = {json.dumps(ingress_username)};\n" + f" const ingressDisplayName = {json.dumps(ingress_display_name)};", + ) + + return web.Response(text=html_content, content_type="text/html") + + async def _handle_setup(self, request: web.Request) -> web.Response: + """Handle first-time setup request to create admin user.""" + # Check if setup is still needed (allow if onboard_done is false) + if await self.auth.has_users() and self.mass.config.get(CONF_ONBOARD_DONE): + return web.json_response( + {"success": False, "error": "Setup already completed"}, status=400 + ) + + if not request.can_read_body: + return web.Response(status=400, text="Body required") + + body = await request.json() + username = body.get("username", "").strip() + password = body.get("password", "") + from_ingress = body.get("from_ingress", False) + display_name = body.get("display_name") + + # Validation + if not username or len(username) < 3: + return web.json_response( + {"success": False, "error": "Username must be at least 3 characters"}, status=400 + ) + + if not password or len(password) < 8: + return web.json_response( + {"success": False, "error": "Password must be at least 8 characters"}, status=400 + ) + + try: + # Get built-in provider + builtin_provider = self.auth.login_providers.get("builtin") + if not builtin_provider: + return web.json_response( + {"success": False, "error": "Built-in auth provider not available"}, status=500 + ) + + if not isinstance(builtin_provider, BuiltinLoginProvider): + return web.json_response( + {"success": False, "error": "Built-in provider configuration error"}, status=500 + ) + + # Check if this is an Ingress setup where user already exists + user = None + if from_ingress and is_request_from_ingress(request): + ha_user_id = request.headers.get("X-Remote-User-ID") + if ha_user_id: + # Try to find existing auto-created Ingress user + user = await self.auth.get_user_by_provider_link( + AuthProviderType.HOME_ASSISTANT, ha_user_id + ) + + if user: + # User already exists (auto-created from Ingress), update and add password + updates = {} + if display_name and not user.display_name: + updates["display_name"] = display_name + user.display_name = display_name + + # Make user admin if not already + if user.role != UserRole.ADMIN: + updates["role"] = UserRole.ADMIN.value + user.role = UserRole.ADMIN + + # Apply updates if any + if updates: + await self.auth.database.update( + "users", + {"user_id": user.user_id}, + updates, + ) + + # Add password authentication to existing user + password_hash = builtin_provider._hash_password(password, username) + await self.auth.link_user_to_provider(user, AuthProviderType.BUILTIN, password_hash) + else: + # Create new admin user with password + user = await builtin_provider.create_user_with_password( + username, password, role=UserRole.ADMIN, display_name=display_name + ) + + # If from Ingress, also link to HA provider + if from_ingress and is_request_from_ingress(request): + ha_user_id = request.headers.get("X-Remote-User-ID") + if ha_user_id: + # Link user to Home Assistant provider + await self.auth.link_user_to_provider( + user, AuthProviderType.HOME_ASSISTANT, ha_user_id + ) + + # Create token for the new admin + device_name = body.get( + "device_name", f"Setup ({request.headers.get('User-Agent', 'Unknown')[:50]})" + ) + token = await self.auth.create_token(user, device_name) + + # Mark onboarding as complete + self.mass.config.set(CONF_ONBOARD_DONE, True) + self.mass.config.save(immediate=True) + + self.logger.info("First admin user created: %s", username) + + return web.json_response( + { + "success": True, + "token": token, + "user": user.to_dict(), + } + ) + + except Exception as e: + self.logger.exception("Error during setup") + return web.json_response( + {"success": False, "error": f"Setup failed: {e!s}"}, status=500 + ) + + async def _get_ingress_user(self, request: web.Request) -> User | None: + """ + Get or create user for ingress (Home Assistant) requests. + + Extracts user information from Home Assistant ingress headers and either + finds the existing linked user or creates a new one. + + :param request: The web request with HA ingress headers. + :return: User object or None if headers are missing. + """ + ingress_user_id = request.headers.get("X-Remote-User-ID") + ingress_username = request.headers.get("X-Remote-User-Name") + ingress_display_name = request.headers.get("X-Remote-User-Display-Name") + + if not ingress_user_id or not ingress_username: + # No user headers available + return None + + # Try to find existing user linked to this HA user ID + user = await self.auth.get_user_by_provider_link( + AuthProviderType.HOME_ASSISTANT, ingress_user_id + ) + + if not user: + # Security: Ensure at least one user exists (setup should have been completed) + if not await self.auth.has_users(): + self.logger.warning("Ingress request attempted before setup completed") + return None + + # Auto-create user for Ingress (they're already authenticated by HA) + # Always create with USER role (admin is created during setup) + user = await self.auth.create_user( + username=ingress_username, + role=UserRole.USER, + display_name=ingress_display_name, + ) + # Link to Home Assistant provider + await self.auth.link_user_to_provider( + user, AuthProviderType.HOME_ASSISTANT, ingress_user_id + ) + self.logger.info("Auto-created ingress user: %s", ingress_username) + + return user + + async def _announce_to_homeassistant(self) -> None: + """Announce Music Assistant Ingress server to Home Assistant via Supervisor API.""" + supervisor_token = os.environ["SUPERVISOR_TOKEN"] + addon_hostname = os.environ["HOSTNAME"] + + # Get or create auth token for the HA system user + ha_integration_token = await self.auth.get_homeassistant_system_user_token() + + discovery_payload = { + "service": "music_assistant", + "config": { + "host": addon_hostname, + "port": INGRESS_SERVER_PORT, + "auth_token": ha_integration_token, + }, + } + + try: + async with self.mass.http_session_no_ssl.post( + "http://supervisor/discovery", + headers={"Authorization": f"Bearer {supervisor_token}"}, + json=discovery_payload, + timeout=ClientTimeout(total=10), + ) as response: + response.raise_for_status() + result = await response.json() + self.logger.debug( + "Successfully announced to Home Assistant. Discovery UUID: %s", + result.get("uuid"), + ) + except Exception as err: + self.logger.warning("Failed to announce to Home Assistant: %s", err) diff --git a/music_assistant/controllers/webserver/helpers/__init__.py b/music_assistant/controllers/webserver/helpers/__init__.py new file mode 100644 index 00000000..b223a211 --- /dev/null +++ b/music_assistant/controllers/webserver/helpers/__init__.py @@ -0,0 +1 @@ +"""Helpers for the webserver controller.""" diff --git a/music_assistant/controllers/webserver/helpers/auth_middleware.py b/music_assistant/controllers/webserver/helpers/auth_middleware.py new file mode 100644 index 00000000..a23d4ce3 --- /dev/null +++ b/music_assistant/controllers/webserver/helpers/auth_middleware.py @@ -0,0 +1,229 @@ +"""Authentication middleware and helpers for HTTP requests and WebSocket connections.""" + +from __future__ import annotations + +from contextvars import ContextVar +from typing import TYPE_CHECKING, Any, cast + +from aiohttp import web +from music_assistant_models.auth import AuthProviderType, User, UserRole + +from music_assistant.constants import HOMEASSISTANT_SYSTEM_USER + +if TYPE_CHECKING: + from music_assistant import MusicAssistant + +# Context key for storing authenticated user in request +USER_CONTEXT_KEY = "authenticated_user" + +# ContextVar for tracking current user and token across async calls +current_user: ContextVar[User | None] = ContextVar("current_user", default=None) +current_token: ContextVar[str | None] = ContextVar("current_token", default=None) + + +async def get_authenticated_user(request: web.Request) -> User | None: + """Get authenticated user from request. + + :param request: The aiohttp request. + """ + # Check if user is already in context (from middleware) + if USER_CONTEXT_KEY in request: + return cast("User | None", request[USER_CONTEXT_KEY]) + + mass: MusicAssistant = request.app["mass"] + + # Check for Home Assistant Ingress connections + if is_request_from_ingress(request): + ingress_user_id = request.headers.get("X-Remote-User-ID") + ingress_username = request.headers.get("X-Remote-User-Name") + ingress_display_name = request.headers.get("X-Remote-User-Display-Name") + + # Require all Ingress headers to be present for security + if not (ingress_user_id and ingress_username): + return None + + # Try to find existing user linked to this HA user ID + user = await mass.webserver.auth.get_user_by_provider_link( + AuthProviderType.HOME_ASSISTANT, ingress_user_id + ) + + if not user: + # Security: Ensure at least one user exists (setup should have been completed) + if not await mass.webserver.auth.has_users(): + # No users exist - setup has not been completed + # This should not happen as the server redirects to /setup + return None + + # Auto-create user for Ingress (they're already authenticated by HA) + # Always create with USER role (admin is created during setup) + user = await mass.webserver.auth.create_user( + username=ingress_username, + role=UserRole.USER, + display_name=ingress_display_name, + ) + # Link to Home Assistant provider + await mass.webserver.auth.link_user_to_provider( + user, AuthProviderType.HOME_ASSISTANT, ingress_user_id + ) + + # Store in request context + request[USER_CONTEXT_KEY] = user + return user + + # Try to authenticate from Authorization header + auth_header = request.headers.get("Authorization") + if not auth_header: + return None + + # Expected format: "Bearer " + parts = auth_header.split(" ", 1) + if len(parts) != 2 or parts[0].lower() != "bearer": + return None + + token = parts[1] + + # Authenticate with token (works for both user tokens and API keys) + user = await mass.webserver.auth.authenticate_with_token(token) + if user: + # Security: Deny homeassistant system user on regular (non-Ingress) webserver + if not is_request_from_ingress(request) and user.username == HOMEASSISTANT_SYSTEM_USER: + # Reject system user on regular webserver (should only use Ingress server) + return None + + # Store in request context + request[USER_CONTEXT_KEY] = user + + return user + + +async def require_authentication(request: web.Request) -> User: + """Require authentication for a request, raise 401 if not authenticated. + + :param request: The aiohttp request. + """ + user = await get_authenticated_user(request) + if not user: + raise web.HTTPUnauthorized( + text="Authentication required", + headers={"WWW-Authenticate": 'Bearer realm="Music Assistant"'}, + ) + return user + + +async def require_admin(request: web.Request) -> User: + """Require admin role for a request, raise 403 if not admin. + + :param request: The aiohttp request. + """ + user = await require_authentication(request) + if user.role != UserRole.ADMIN: + raise web.HTTPForbidden(text="Admin access required") + return user + + +def get_current_user() -> User | None: + """ + Get the current authenticated user from context. + + :return: The current user or None if not authenticated. + """ + return current_user.get() + + +def set_current_user(user: User | None) -> None: + """ + Set the current authenticated user in context. + + :param user: The user to set as current. + """ + current_user.set(user) + + +def get_current_token() -> str | None: + """ + Get the current authentication token from context. + + :return: The current token or None if not authenticated. + """ + return current_token.get() + + +def set_current_token(token: str | None) -> None: + """ + Set the current authentication token in context. + + :param token: The token to set as current. + """ + current_token.set(token) + + +def is_request_from_ingress(request: web.Request) -> bool: + """Check if request is coming from Home Assistant Ingress (internal network). + + Security is enforced by socket-level verification (IP/port binding), not headers. + Only requests on the internal ingress TCP site (172.30.32.x:8094) are accepted. + + :param request: The aiohttp request. + """ + # Check if ingress site is configured in the app + ingress_site_params = request.app.get("ingress_site") + if not ingress_site_params: + # No ingress site configured, can't be an ingress request + return False + + try: + # Security: Verify the request came through the ingress site by checking socket + # to prevent bypassing authentication on the regular webserver + transport = request.transport + if transport: + sockname = transport.get_extra_info("sockname") + if sockname and len(sockname) >= 2: + server_ip, server_port = sockname[0], sockname[1] + expected_ip, expected_port = ingress_site_params + # Request must match the ingress site's bind address and port + return bool(server_ip == expected_ip and server_port == expected_port) + except Exception: # noqa: S110 + pass + + return False + + +@web.middleware +async def auth_middleware(request: web.Request, handler: Any) -> web.StreamResponse: + """Authenticate requests and store user in context. + + :param request: The aiohttp request. + :param handler: The request handler. + """ + # Skip authentication for ingress requests (HA handles auth) + if is_request_from_ingress(request): + return cast("web.StreamResponse", await handler(request)) + + # Unauthenticated routes (static files, info, login, setup, etc.) + unauthenticated_paths = [ + "/info", + "/login", + "/setup", + "/auth/", + "/api-docs/", + "/assets/", + "/favicon.ico", + "/manifest.json", + "/index.html", + "/", + ] + + # Check if path should bypass auth + for path_prefix in unauthenticated_paths: + if request.path.startswith(path_prefix): + return cast("web.StreamResponse", await handler(request)) + + # Try to authenticate + user = await get_authenticated_user(request) + + # Store user in context (might be None for unauthenticated requests) + request[USER_CONTEXT_KEY] = user + + # Let the handler decide if authentication is required + # The handler will call require_authentication() if needed + return cast("web.StreamResponse", await handler(request)) diff --git a/music_assistant/controllers/webserver/helpers/auth_providers.py b/music_assistant/controllers/webserver/helpers/auth_providers.py new file mode 100644 index 00000000..74fe1ed3 --- /dev/null +++ b/music_assistant/controllers/webserver/helpers/auth_providers.py @@ -0,0 +1,577 @@ +"""Authentication provider base classes and implementations.""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import secrets +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any, TypedDict, cast +from urllib.parse import urlparse + +from hass_client import HomeAssistantClient +from hass_client.exceptions import BaseHassClientError +from hass_client.utils import base_url, get_auth_url, get_token, get_websocket_url +from music_assistant_models.auth import AuthProviderType, User, UserRole + +from music_assistant.constants import MASS_LOGGER_NAME + +if TYPE_CHECKING: + from music_assistant import MusicAssistant + from music_assistant.controllers.webserver.auth import AuthenticationManager + from music_assistant.providers.hass import HomeAssistantProvider + +LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.auth") + + +class LoginProviderConfig(TypedDict, total=False): + """Base configuration for login providers.""" + + allow_self_registration: bool + + +class HomeAssistantProviderConfig(LoginProviderConfig): + """Configuration for Home Assistant OAuth provider.""" + + ha_url: str + + +@dataclass +class AuthResult: + """Result of an authentication attempt.""" + + success: bool + user: User | None = None + error: str | None = None + access_token: str | None = None + return_url: str | None = None + + +class LoginProvider(ABC): + """Base class for login providers.""" + + def __init__(self, mass: MusicAssistant, provider_id: str, config: LoginProviderConfig) -> None: + """ + Initialize login provider. + + :param mass: MusicAssistant instance. + :param provider_id: Unique identifier for this provider instance. + :param config: Provider-specific configuration. + """ + self.mass = mass + self.provider_id = provider_id + self.config = config + self.logger = LOGGER + self.allow_self_registration = config.get("allow_self_registration", False) + + @property + def auth_manager(self) -> AuthenticationManager: + """Get auth manager from webserver.""" + return self.mass.webserver.auth + + @property + @abstractmethod + def provider_type(self) -> AuthProviderType: + """Return the provider type.""" + + @property + @abstractmethod + def requires_redirect(self) -> bool: + """Return True if this provider requires OAuth redirect.""" + + @abstractmethod + async def authenticate(self, credentials: dict[str, Any]) -> AuthResult: + """ + Authenticate user with provided credentials. + + :param credentials: Provider-specific credentials (username/password, OAuth code, etc). + """ + + async def get_authorization_url( + self, redirect_uri: str, return_url: str | None = None + ) -> str | None: + """ + Get OAuth authorization URL if applicable. + + :param redirect_uri: The callback URL for OAuth flow. + :param return_url: Optional URL to redirect to after successful login. + """ + return None + + async def handle_oauth_callback(self, code: str, state: str, redirect_uri: str) -> AuthResult: + """ + Handle OAuth callback if applicable. + + :param code: OAuth authorization code. + :param state: OAuth state parameter for CSRF protection. + :param redirect_uri: The callback URL. + """ + return AuthResult(success=False, error="OAuth not supported by this provider") + + +class BuiltinLoginProvider(LoginProvider): + """Built-in username/password login provider.""" + + @property + def provider_type(self) -> AuthProviderType: + """Return the provider type.""" + return AuthProviderType.BUILTIN + + @property + def requires_redirect(self) -> bool: + """Return False - built-in provider doesn't need redirect.""" + return False + + async def authenticate(self, credentials: dict[str, Any]) -> AuthResult: + """ + Authenticate user with username and password. + + :param credentials: Dict containing 'username' and 'password'. + """ + username = credentials.get("username") + password = credentials.get("password") + + if not username or not password: + return AuthResult(success=False, error="Username and password required") + + # First, look up user by username to get user_id + # This is needed to create the password hash with user_id in the salt + user_row = await self.auth_manager.database.get_row("users", {"username": username}) + if not user_row: + return AuthResult(success=False, error="Invalid username or password") + + user_id = user_row["user_id"] + + # Hash the password using user_id for enhanced security + password_hash = self._hash_password(password, user_id) + + # Verify the password by checking if provider link exists + user = await self.auth_manager.get_user_by_provider_link( + AuthProviderType.BUILTIN, password_hash + ) + + if not user: + return AuthResult(success=False, error="Invalid username or password") + + # Check if user is enabled + if not user.enabled: + return AuthResult(success=False, error="User account is disabled") + + return AuthResult(success=True, user=user) + + async def create_user_with_password( + self, + username: str, + password: str, + role: UserRole = UserRole.USER, + display_name: str | None = None, + ) -> User: + """ + Create a new built-in user with password. + + :param username: The username. + :param password: The password (will be hashed). + :param role: The user role (default: USER). + :param display_name: Optional display name. + """ + # Create the user + user = await self.auth_manager.create_user( + username=username, + role=role, + display_name=display_name, + ) + + # Hash password using user_id for enhanced security + password_hash = self._hash_password(password, user.user_id) + await self.auth_manager.link_user_to_provider(user, AuthProviderType.BUILTIN, password_hash) + + return user + + async def change_password(self, user: User, old_password: str, new_password: str) -> bool: + """ + Change user password. + + :param user: The user. + :param old_password: Current password for verification. + :param new_password: The new password. + """ + # Verify old password first using user_id + old_password_hash = self._hash_password(old_password, user.user_id) + existing_user = await self.auth_manager.get_user_by_provider_link( + AuthProviderType.BUILTIN, old_password_hash + ) + + if not existing_user or existing_user.user_id != user.user_id: + return False + + # Update password link with new hash using user_id + new_password_hash = self._hash_password(new_password, user.user_id) + await self.auth_manager.update_provider_link( + user, AuthProviderType.BUILTIN, new_password_hash + ) + + return True + + async def reset_password(self, user: User, new_password: str) -> None: + """ + Reset user password (admin only - no old password verification). + + :param user: The user whose password to reset. + :param new_password: The new password. + """ + # Hash new password using user_id and update provider link + new_password_hash = self._hash_password(new_password, user.user_id) + await self.auth_manager.update_provider_link( + user, AuthProviderType.BUILTIN, new_password_hash + ) + + def _hash_password(self, password: str, user_id: str) -> str: + """ + Hash password with salt combining user ID and server ID. + + :param password: Plain text password. + :param user_id: User ID to include in salt (random token for high entropy). + """ + # Combine user_id (random) and server_id for maximum security + salt = f"{user_id}:{self.mass.server_id}" + return hashlib.pbkdf2_hmac( + "sha256", password.encode(), salt.encode(), iterations=100000 + ).hex() + + +class HomeAssistantOAuthProvider(LoginProvider): + """Home Assistant OAuth login provider.""" + + @property + def provider_type(self) -> AuthProviderType: + """Return the provider type.""" + return AuthProviderType.HOME_ASSISTANT + + @property + def requires_redirect(self) -> bool: + """Return True - Home Assistant OAuth requires redirect.""" + return True + + async def authenticate(self, credentials: dict[str, Any]) -> AuthResult: + """ + Not used for OAuth providers - use handle_oauth_callback instead. + + :param credentials: Not used. + """ + return AuthResult(success=False, error="Use OAuth flow for Home Assistant authentication") + + async def _get_external_ha_url(self) -> str | None: + """ + Get the external URL for Home Assistant from the config API. + + This is needed when MA runs as HA add-on and connects via internal docker network + (http://supervisor/api) but needs the external URL for OAuth redirects. + + :return: External URL if available, otherwise None. + """ + ha_url = cast("str", self.config.get("ha_url")) if self.config.get("ha_url") else None + if not ha_url: + return None + + # Check if we're using the internal supervisor URL + if "supervisor" not in ha_url.lower(): + # Not using internal URL, return as-is + return ha_url + + # We're using internal URL - try to get external URL from HA provider + ha_provider = self.mass.get_provider("hass") + if not ha_provider: + # No HA provider available, use configured URL + return ha_url + + ha_provider = cast("HomeAssistantProvider", ha_provider) + + try: + # Access the hass client from the provider + hass_client = ha_provider.hass + if not hass_client or not hass_client.connected: + return ha_url + + # Get network URLs from Home Assistant using WebSocket API + # This command returns internal, external, and cloud URLs + network_urls = await hass_client.send_command("network/url") + + if network_urls: + # Priority: external > cloud > internal + # External is the manually configured external URL + # Cloud is the Nabu Casa cloud URL + # Internal is the local network URL + external_url = network_urls.get("external") + cloud_url = network_urls.get("cloud") + internal_url = network_urls.get("internal") + + # Use external URL first, then cloud, then internal + final_url = cast("str", external_url or cloud_url or internal_url) + if final_url: + self.logger.debug( + "Using HA URL for OAuth: %s (from network/url, configured: %s)", + final_url, + ha_url, + ) + return final_url + except Exception as err: + self.logger.warning("Failed to fetch HA network URLs: %s", err, exc_info=True) + + # Fallback to configured URL + return ha_url + + async def get_authorization_url( + self, redirect_uri: str, return_url: str | None = None + ) -> str | None: + """ + Get Home Assistant OAuth authorization URL using hass_client. + + :param redirect_uri: The callback URL. + :param return_url: Optional URL to redirect to after successful login. + """ + # Get the correct HA URL (external URL if running as add-on) + ha_url = await self._get_external_ha_url() + if not ha_url: + return None + + # If HA URL is still the internal supervisor URL (no external_url in HA config), + # infer from redirect_uri (the URL user is accessing MA from) + if "supervisor" in ha_url.lower(): + # Extract scheme and host from redirect_uri to build external HA URL + parsed = urlparse(redirect_uri) + # HA typically runs on port 8123, but use default ports for HTTPS (443) or HTTP (80) + if parsed.scheme == "https": + # HTTPS - use default port 443 (no port in URL) + inferred_ha_url = f"{parsed.scheme}://{parsed.hostname}" + else: + # HTTP - assume HA runs on default port 8123 + inferred_ha_url = f"{parsed.scheme}://{parsed.hostname}:8123" + + self.logger.debug( + "HA external_url not configured, inferring from callback URL: %s", + inferred_ha_url, + ) + ha_url = inferred_ha_url + + state = secrets.token_urlsafe(32) + # Store state and return_url for verification and final redirect + self._oauth_state = state + self._oauth_return_url = return_url + + # Use base_url of callback as client_id (same as HA provider does) + client_id = base_url(redirect_uri) + + # Use hass_client's get_auth_url utility + return cast( + "str", + get_auth_url( + ha_url, + redirect_uri, + client_id=client_id, + state=state, + ), + ) + + def _decode_ha_jwt_token(self, access_token: str) -> tuple[str | None, str | None]: + """ + Decode Home Assistant JWT token to extract user ID and name. + + :param access_token: The JWT access token from Home Assistant. + :return: Tuple of (user_id, username) or (None, None) if decoding fails. + """ + try: + # JWT tokens have 3 parts separated by dots: header.payload.signature + parts = access_token.split(".") + if len(parts) >= 2: + # Decode the payload (second part) + # Add padding if needed (JWT base64 may not be padded) + payload = parts[1] + payload += "=" * (4 - len(payload) % 4) + decoded = base64.urlsafe_b64decode(payload) + token_data = json.loads(decoded) + + # Home Assistant JWT tokens use 'iss' as the user ID + ha_user_id: str | None = token_data.get("iss") + + if not ha_user_id: + # Fallback to 'sub' if 'iss' is not present + ha_user_id = token_data.get("sub") + + # Try to extract username from token (name, username, or other fields) + username = token_data.get("name") or token_data.get("username") + + if ha_user_id: + return str(ha_user_id), username + return None, None + except Exception as decode_error: + self.logger.error("Failed to decode HA JWT token: %s", decode_error) + + return None, None + + async def _fetch_ha_user_via_websocket( + self, ha_url: str, access_token: str + ) -> tuple[str | None, str | None]: + """ + Fetch user information from Home Assistant via WebSocket. + + :param ha_url: Home Assistant URL. + :param access_token: Access token for WebSocket authentication. + :return: Tuple of (username, display_name) or (None, None) if fetch fails. + """ + ws_url = get_websocket_url(ha_url) + + try: + # Use context manager to automatically handle connect/disconnect + async with HomeAssistantClient(ws_url, access_token, self.mass.http_session) as client: + # Use the auth/current_user command to get user details + result = await client.send_command("auth/current_user") + + if result: + # Extract username and display name from response + username = result.get("name") or result.get("username") + display_name = result.get("name") + if username: + return username, display_name + + self.logger.warning("auth/current_user returned no user data") + return None, None + + except BaseHassClientError as ws_error: + self.logger.error("Failed to fetch HA user via WebSocket: %s", ws_error) + return None, None + + async def _get_or_create_user( + self, username: str, display_name: str | None, ha_user_id: str + ) -> User | None: + """ + Get or create a user for Home Assistant OAuth authentication. + + :param username: Username from Home Assistant. + :param display_name: Display name from Home Assistant. + :param ha_user_id: Home Assistant user ID. + :return: User object or None if creation failed. + """ + # Check if user already linked to HA + user = await self.auth_manager.get_user_by_provider_link( + AuthProviderType.HOME_ASSISTANT, ha_user_id + ) + if user: + return user + + # Check if a user with this username already exists (from built-in provider) + user_row = await self.auth_manager.database.get_row("users", {"username": username}) + if user_row: + # User exists with this username - link them to HA provider + user_dict = dict(user_row) + existing_user = User( + user_id=user_dict["user_id"], + username=user_dict["username"], + role=UserRole(user_dict["role"]), + enabled=bool(user_dict["enabled"]), + created_at=datetime.fromisoformat(user_dict["created_at"]), + display_name=user_dict["display_name"], + avatar_url=user_dict["avatar_url"], + ) + + # Link existing user to Home Assistant + await self.auth_manager.link_user_to_provider( + existing_user, AuthProviderType.HOME_ASSISTANT, ha_user_id + ) + + self.logger.debug("Linked existing user '%s' to Home Assistant provider", username) + return existing_user + + # New HA user - check if self-registration allowed + if not self.allow_self_registration: + return None + + # Create new user with USER role + user = await self.auth_manager.create_user( + username=username, + role=UserRole.USER, + display_name=display_name or username, + ) + + # Link to Home Assistant + await self.auth_manager.link_user_to_provider( + user, AuthProviderType.HOME_ASSISTANT, ha_user_id + ) + + return user + + async def handle_oauth_callback(self, code: str, state: str, redirect_uri: str) -> AuthResult: + """ + Handle Home Assistant OAuth callback using hass_client. + + :param code: OAuth authorization code. + :param state: OAuth state parameter. + :param redirect_uri: The callback URL. + """ + # Verify state + if not hasattr(self, "_oauth_state") or state != self._oauth_state: + return AuthResult(success=False, error="Invalid state parameter") + + # Get the correct HA URL (external URL if running as add-on) + # This must be the same URL used in get_authorization_url + ha_url = await self._get_external_ha_url() + if not ha_url: + return AuthResult(success=False, error="Home Assistant URL not configured") + + try: + # Use base_url of callback as client_id (same as HA provider does) + client_id = base_url(redirect_uri) + + # Use hass_client's get_token utility - no client_secret needed! + try: + token_details = await get_token(ha_url, code, client_id=client_id) + except Exception as token_error: + self.logger.error( + "Failed to get token from HA: %s (client_id: %s, ha_url: %s)", + token_error, + client_id, + ha_url, + ) + return AuthResult( + success=False, error=f"Failed to exchange OAuth code: {token_error}" + ) + + access_token = token_details.get("access_token") + if not access_token: + return AuthResult(success=False, error="No access token received from HA") + + # Decode JWT token to get HA user ID + ha_user_id, _ = self._decode_ha_jwt_token(access_token) + if not ha_user_id: + return AuthResult(success=False, error="Failed to decode token") + + # Fetch user information from HA via WebSocket + username, display_name = await self._fetch_ha_user_via_websocket(ha_url, access_token) + + # If we couldn't get username from WebSocket, fail authentication + if not username: + return AuthResult( + success=False, + error="Failed to get username from Home Assistant", + ) + + # Get or create user + user = await self._get_or_create_user(username, display_name, ha_user_id) + + # Get stored return_url from OAuth state + return_url = getattr(self, "_oauth_return_url", None) + + if not user: + return AuthResult( + success=False, + error="Self-registration is disabled. Please contact an administrator.", + ) + + return AuthResult(success=True, user=user, return_url=return_url) + + except Exception as e: + self.logger.exception("Error during Home Assistant OAuth callback") + return AuthResult(success=False, error=str(e)) diff --git a/music_assistant/controllers/webserver/websocket_client.py b/music_assistant/controllers/webserver/websocket_client.py new file mode 100644 index 00000000..1e0417e1 --- /dev/null +++ b/music_assistant/controllers/webserver/websocket_client.py @@ -0,0 +1,404 @@ +"""WebSocket client handler for Music Assistant API.""" + +from __future__ import annotations + +import asyncio +import logging +from concurrent import futures +from contextlib import suppress +from typing import TYPE_CHECKING, Any, Final + +from aiohttp import WSMsgType, web +from music_assistant_models.api import ( + CommandMessage, + ErrorResultMessage, + MessageType, + SuccessResultMessage, +) +from music_assistant_models.auth import AuthProviderType, UserRole +from music_assistant_models.errors import ( + AuthenticationRequired, + InsufficientPermissions, + InvalidCommand, + InvalidToken, +) + +from music_assistant.constants import HOMEASSISTANT_SYSTEM_USER, VERBOSE_LOG_LEVEL +from music_assistant.helpers.api import APICommandHandler, parse_arguments + +from .helpers.auth_middleware import is_request_from_ingress, set_current_token, set_current_user + +if TYPE_CHECKING: + from music_assistant.controllers.webserver import WebserverController + +MAX_PENDING_MSG = 512 +CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError) + + +class WebsocketClientHandler: + """Handle an active websocket client connection.""" + + def __init__(self, webserver: WebserverController, request: web.Request) -> None: + """Initialize an active connection.""" + self.webserver = webserver + self.mass = webserver.mass + self.request = request + self.wsock = web.WebSocketResponse(heartbeat=55) + self._to_write: asyncio.Queue[str | None] = asyncio.Queue(maxsize=MAX_PENDING_MSG) + self._handle_task: asyncio.Task[Any] | None = None + self._writer_task: asyncio.Task[None] | None = None + self._logger = webserver.logger + self._authenticated_user: Any = None # Will be set after auth command or from Ingress + self._current_token: str | None = None # Will be set after auth command + self._token_id: str | None = None # Will be set after auth for tracking revocation + self._is_ingress = is_request_from_ingress(request) + self._events_unsub_callback: Any = None # Will be set after authentication + # try to dynamically detect the base_url of a client if proxied or behind Ingress + self.base_url: str | None = None + if forward_host := request.headers.get("X-Forwarded-Host"): + ingress_path = request.headers.get("X-Ingress-Path", "") + forward_proto = request.headers.get("X-Forwarded-Proto", request.protocol) + self.base_url = f"{forward_proto}://{forward_host}{ingress_path}" + + async def disconnect(self) -> None: + """Disconnect client.""" + self._cancel() + if self._writer_task is not None: + await self._writer_task + + async def handle_client(self) -> web.WebSocketResponse: + """Handle a websocket response.""" + # ruff: noqa: PLR0915 + request = self.request + wsock = self.wsock + try: + async with asyncio.timeout(10): + await wsock.prepare(request) + except TimeoutError: + self._logger.warning("Timeout preparing request from %s", request.remote) + return wsock + + self._logger.log(VERBOSE_LOG_LEVEL, "Connection from %s", request.remote) + self._handle_task = asyncio.current_task() + self._writer_task = self.mass.create_task(self._writer()) + + # send server(version) info when client connects + server_info = self.mass.get_server_info() + await self._send_message(server_info) + + # For Ingress connections, auto-create/link user and subscribe to events immediately + # For regular connections, events will be subscribed after successful authentication + if self._is_ingress: + await self._handle_ingress_auth() + self._subscribe_to_events() + + disconnect_warn = None + + try: + while not wsock.closed: + msg = await wsock.receive() + + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + break + + if msg.type != WSMsgType.TEXT: + continue + + self._logger.log(VERBOSE_LOG_LEVEL, "Received: %s", msg.data) + + try: + command_msg = CommandMessage.from_json(msg.data) + except ValueError: + disconnect_warn = f"Received invalid JSON: {msg.data}" + break + + await self._handle_command(command_msg) + + except asyncio.CancelledError: + self._logger.debug("Connection closed by client") + + except Exception: + self._logger.exception("Unexpected error inside websocket API") + + finally: + # Handle connection shutting down. + if self._events_unsub_callback: + self._events_unsub_callback() + self._logger.log(VERBOSE_LOG_LEVEL, "Unsubscribed from events") + + # Unregister from webserver tracking + self.webserver.unregister_websocket_client(self) + + try: + self._to_write.put_nowait(None) + # Make sure all error messages are written before closing + await self._writer_task + await wsock.close() + except asyncio.QueueFull: # can be raised by put_nowait + self._writer_task.cancel() + + finally: + if disconnect_warn is None: + self._logger.log(VERBOSE_LOG_LEVEL, "Disconnected") + else: + self._logger.warning("Disconnected: %s", disconnect_warn) + + return wsock + + async def _handle_command(self, msg: CommandMessage) -> None: + """Handle an incoming command from the client.""" + self._logger.debug("Handling command %s", msg.command) + + # Handle special "auth" command + if msg.command == "auth": + await self._handle_auth_command(msg) + return + + # work out handler for the given path/command + handler = self.mass.command_handlers.get(msg.command) + + if handler is None: + await self._send_message( + ErrorResultMessage( + msg.message_id, + InvalidCommand.error_code, + f"Invalid command: {msg.command}", + ) + ) + self._logger.warning("Invalid command: %s", msg.command) + return + + # Check authentication if required + if handler.authenticated or handler.required_role: + # For Ingress, user should already be set from _handle_ingress_auth + # For regular connections, user must be set via auth command + if self._authenticated_user is None: + await self._send_message( + ErrorResultMessage( + msg.message_id, + AuthenticationRequired.error_code, + "Authentication required. Please send auth command first.", + ) + ) + return + + # Set user and token in context for API methods + set_current_user(self._authenticated_user) + set_current_token(self._current_token) + + # Check role if required + if handler.required_role == "admin": + if self._authenticated_user.role != UserRole.ADMIN: + await self._send_message( + ErrorResultMessage( + msg.message_id, + InsufficientPermissions.error_code, + "Admin access required", + ) + ) + return + + # schedule task to handle the command + self.mass.create_task(self._run_handler(handler, msg)) + + async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None: + """Run command handler and send response.""" + try: + args = parse_arguments(handler.signature, handler.type_hints, msg.args) + result: Any = handler.target(**args) + if hasattr(result, "__anext__"): + # handle async generator (for really large listings) + items: list[Any] = [] + async for item in result: + items.append(item) + if len(items) >= 500: + await self._send_message( + SuccessResultMessage(msg.message_id, items, partial=True) + ) + items = [] + result = items + elif asyncio.iscoroutine(result): + result = await result + await self._send_message(SuccessResultMessage(msg.message_id, result)) + except Exception as err: + if self._logger.isEnabledFor(logging.DEBUG): + self._logger.exception("Error handling message: %s", msg) + else: + self._logger.error("Error handling message: %s: %s", msg.command, str(err)) + err_msg = str(err) or err.__class__.__name__ + await self._send_message( + ErrorResultMessage(msg.message_id, getattr(err, "error_code", 999), err_msg) + ) + + async def _writer(self) -> None: + """Write outgoing messages.""" + # Exceptions if Socket disconnected or cancelled by connection handler + with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): + while not self.wsock.closed: + if (process := await self._to_write.get()) is None: + break + + if callable(process): + message: str = process() + else: + message = process + self._logger.log(VERBOSE_LOG_LEVEL, "Writing: %s", message) + await self.wsock.send_str(message) + + async def _send_message(self, message: MessageType) -> None: + """Send a message to the client (for large response messages). + + Runs JSON serialization in executor to avoid blocking for large messages. + Closes connection if the client is not reading the messages. + + Async friendly. + """ + # Run JSON serialization in executor to avoid blocking for large messages + loop = asyncio.get_running_loop() + _message = await loop.run_in_executor(None, message.to_json) + + try: + self._to_write.put_nowait(_message) + except asyncio.QueueFull: + self._logger.error("Client exceeded max pending messages: %s", MAX_PENDING_MSG) + + self._cancel() + + def _send_message_sync(self, message: MessageType) -> None: + """Send a message from a sync context (for small messages like events). + + Serializes inline without executor overhead since events are typically small. + """ + _message = message.to_json() + + try: + self._to_write.put_nowait(_message) + except asyncio.QueueFull: + self._logger.error("Client exceeded max pending messages: %s", MAX_PENDING_MSG) + + self._cancel() + + async def _handle_auth_command(self, msg: CommandMessage) -> None: + """Handle WebSocket authentication command. + + :param msg: The auth command message with access token. + """ + # Extract token from args (support both 'token' and 'access_token' for backward compat) + token = msg.args.get("token") if msg.args else None + if not token: + token = msg.args.get("access_token") if msg.args else None + if not token: + await self._send_message( + ErrorResultMessage( + msg.message_id, + AuthenticationRequired.error_code, + "token required in args", + ) + ) + return + + # Authenticate with token + user = await self.webserver.auth.authenticate_with_token(token) + if not user: + await self._send_message( + ErrorResultMessage( + msg.message_id, + InvalidToken.error_code, + "Invalid or expired token", + ) + ) + return + + # Security: Deny homeassistant system user on regular (non-Ingress) webserver + if not self._is_ingress and user.username == HOMEASSISTANT_SYSTEM_USER: + await self._send_message( + ErrorResultMessage( + msg.message_id, + InvalidToken.error_code, + "Home Assistant system user not allowed on regular webserver", + ) + ) + return + + # Get token_id for tracking revocation events + token_id = await self.webserver.auth.get_token_id_from_token(token) + + # Store authenticated user, token, and token_id + self._authenticated_user = user + self._current_token = token + self._token_id = token_id + self._logger.info("WebSocket client authenticated as %s", user.username) + + # Send success response + await self._send_message( + SuccessResultMessage( + msg.message_id, + {"authenticated": True, "user": user.to_dict()}, + ) + ) + + # Subscribe to events after successful authentication + self._subscribe_to_events() + + # Register with webserver for tracking + self.webserver.register_websocket_client(self) + + async def _handle_ingress_auth(self) -> None: + """Handle authentication for Ingress connections (auto-create/link user).""" + ingress_user_id = self.request.headers.get("X-Remote-User-ID") + ingress_username = self.request.headers.get("X-Remote-User-Name") + ingress_display_name = self.request.headers.get("X-Remote-User-Display-Name") + + if ingress_user_id and ingress_username: + # Try to find existing user linked to this HA user ID + user = await self.webserver.auth.get_user_by_provider_link( + AuthProviderType.HOME_ASSISTANT, ingress_user_id + ) + + if not user: + # Security: Ensure at least one user exists (setup should have been completed) + if not await self.webserver.auth.has_users(): + # No users exist - setup has not been completed + # This should not happen as the server redirects to /setup + self._logger.warning("Ingress connection attempted before setup completed") + return + + # Auto-create user for Ingress (they're already authenticated by HA) + # Always create with USER role (admin is created during setup) + user = await self.webserver.auth.create_user( + username=ingress_username, + role=UserRole.USER, + display_name=ingress_display_name, + ) + # Link to Home Assistant provider + await self.webserver.auth.link_user_to_provider( + user, AuthProviderType.HOME_ASSISTANT, ingress_user_id + ) + + self._authenticated_user = user + self._logger.debug("Ingress user authenticated: %s", user.username) + else: + # No HA user headers - allow homeassistant system user to connect with token + # This allows the Home Assistant integration to connect via the internal network + # The token authentication happens in _handle_auth_message + self._logger.debug("Ingress connection without user headers, expecting token auth") + + def _subscribe_to_events(self) -> None: + """Subscribe to Mass events and forward them to the client.""" + if self._events_unsub_callback is not None: + # Already subscribed + return + + def handle_event(event: Any) -> None: + # event is MassEvent but we use Any to avoid runtime import + self._send_message_sync(event) + + self._events_unsub_callback = self.mass.subscribe(handle_event) + self._logger.debug("Subscribed to events") + + def _cancel(self) -> None: + """Cancel the connection.""" + if self._handle_task is not None: + self._handle_task.cancel() + if self._writer_task is not None: + self._writer_task.cancel() diff --git a/music_assistant/helpers/api.py b/music_assistant/helpers/api.py index 53403986..2c382958 100644 --- a/music_assistant/helpers/api.py +++ b/music_assistant/helpers/api.py @@ -2,6 +2,7 @@ from __future__ import annotations +import importlib import inspect import logging from collections.abc import AsyncGenerator, Callable, Coroutine @@ -19,6 +20,161 @@ LOGGER = logging.getLogger(__name__) _F = TypeVar("_F", bound=Callable[..., Any]) +# Cache for resolved type alias strings to avoid repeated imports +_TYPE_ALIAS_CACHE: dict[str, Any] = {} + + +def _resolve_string_type(type_str: str) -> Any: + """ + Resolve a string type reference back to the actual type. + + This is needed when type aliases like ConfigValueType are converted to strings + during type hint resolution to avoid isinstance() errors with complex unions. + + Uses a module-level cache to avoid repeated imports. + + :param type_str: String name of the type (e.g., "ConfigValueType"). + :return: The actual type object, or the string if resolution fails. + """ + # Check cache first + if type_str in _TYPE_ALIAS_CACHE: + return _TYPE_ALIAS_CACHE[type_str] + + type_alias_map = { + "ConfigValueType": ("music_assistant_models.config_entries", "ConfigValueType"), + "MediaItemType": ("music_assistant_models.media_items", "MediaItemType"), + } + + if type_str not in type_alias_map: + # Cache the string itself for unknown types + _TYPE_ALIAS_CACHE[type_str] = type_str + return type_str + + module_name, type_name = type_alias_map[type_str] + try: + module = importlib.import_module(module_name) + resolved_type = getattr(module, type_name) + # Cache the successfully resolved type + _TYPE_ALIAS_CACHE[type_str] = resolved_type + return resolved_type + except (ImportError, AttributeError) as err: + LOGGER.warning("Failed to resolve type alias %s: %s", type_str, err) + # Cache the string to avoid repeated failed attempts + _TYPE_ALIAS_CACHE[type_str] = type_str + return type_str + + +def _resolve_generic_type_args( + args: tuple[Any, ...], + func: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]], + config_value_type: Any, + media_item_type: Any, +) -> tuple[list[Any], bool]: + """Resolve TypeVars and type aliases in generic type arguments. + + :param args: Type arguments from a generic type (e.g., from list[T] or dict[K, V]) + :param func: The function being analyzed + :param config_value_type: The ConfigValueType type alias to compare against + :param media_item_type: The MediaItemType type alias to compare against + :return: Tuple of (resolved_args, changed) where changed is True if any args were modified + """ + new_args: list[Any] = [] + changed = False + + for arg in args: + # Check if arg matches ConfigValueType union (type alias that was expanded) + if arg == config_value_type: + # Replace with string reference to preserve type alias + new_args.append("ConfigValueType") + changed = True + # Check if arg matches MediaItemType union (type alias that was expanded) + elif arg == media_item_type: + # Replace with string reference to preserve type alias + new_args.append("MediaItemType") + changed = True + elif isinstance(arg, TypeVar): + # For ItemCls, resolve to concrete type + if arg.__name__ == "ItemCls" and hasattr(func, "__self__"): + if hasattr(func.__self__, "item_cls"): + new_args.append(func.__self__.item_cls) + changed = True + else: + new_args.append(arg) + # For ConfigValue TypeVars, resolve to string name + elif "ConfigValue" in arg.__name__: + new_args.append("ConfigValueType") + changed = True + else: + new_args.append(arg) + # Check if arg is a Union containing a TypeVar + elif get_origin(arg) in (Union, UnionType): + union_args = get_args(arg) + for union_arg in union_args: + if isinstance(union_arg, TypeVar) and union_arg.__bound__ is not None: + # Resolve the TypeVar in the union + union_arg_index = union_args.index(union_arg) + resolved = _resolve_typevar_in_union( + union_arg, func, union_args, union_arg_index + ) + new_args.append(resolved) + changed = True + break + else: + # No TypeVar found in union, keep as-is + new_args.append(arg) + else: + new_args.append(arg) + + return new_args, changed + + +def _resolve_typevar_in_union( + arg: TypeVar, + func: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]], + args: tuple[Any, ...], + i: int, +) -> Any: + """Resolve a TypeVar found in a Union to its concrete type. + + :param arg: The TypeVar to resolve. + :param func: The function being analyzed. + :param args: All args from the Union. + :param i: Index of the TypeVar in the args. + """ + bound_type = arg.__bound__ + if not bound_type or not hasattr(arg, "__name__"): + return bound_type + + type_var_name = arg.__name__ + + # Map TypeVar names to their type alias names + if "ConfigValue" in type_var_name: + return "ConfigValueType" + + if type_var_name == "ItemCls": + # Resolve ItemCls to the actual media item class (e.g., Artist, Album, Track) + if hasattr(func, "__self__") and hasattr(func.__self__, "item_cls"): + resolved_type = func.__self__.item_cls + # Preserve other types in the union (like None for Optional) + other_args = [a for j, a in enumerate(args) if j != i] + if other_args: + # Reconstruct union with resolved type + return Union[resolved_type, *other_args] + return resolved_type + # Fallback to bound if we can't get item_cls + return bound_type + + # Check if the bound is MediaItemType by comparing the union + from music_assistant_models.media_items import ( # noqa: PLC0415 + MediaItemType as media_item_type, # noqa: N813 + ) + + if bound_type == media_item_type: + return "MediaItemType" + + # Fallback to the bound type + return bound_type + @dataclass class APICommandHandler: @@ -28,34 +184,106 @@ class APICommandHandler: signature: inspect.Signature type_hints: dict[str, Any] target: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]] + authenticated: bool = True + required_role: str | None = None # "admin" or "user" or None + alias: bool = False # If True, this is an alias for backward compatibility @classmethod def parse( - cls, command: str, func: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]] + cls, + command: str, + func: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]], + authenticated: bool = True, + required_role: str | None = None, + alias: bool = False, ) -> APICommandHandler: - """Parse APICommandHandler by providing a function.""" + """Parse APICommandHandler by providing a function. + + :param command: The command name/path. + :param func: The function to handle the command. + :param authenticated: Whether authentication is required (default: True). + :param required_role: Required user role ("admin" or "user") + None for any authenticated user. + :param alias: Whether this is an alias for backward compatibility (default: False). + """ type_hints = get_type_hints(func) # workaround for generic typevar ItemCls that needs to be resolved # to the real media item type. TODO: find a better way to do this # without this hack + # Import type aliases to compare against + from music_assistant_models.config_entries import ( # noqa: PLC0415 + ConfigValueType as config_value_type, # noqa: N813 + ) + from music_assistant_models.media_items import ( # noqa: PLC0415 + MediaItemType as media_item_type, # noqa: N813 + ) + for key, value in type_hints.items(): + # Handle generic types (list, tuple, dict, etc.) that may contain TypeVars + # For example: list[ItemCls] should become list[Artist] + # For example: dict[str, ConfigValueType] should preserve ConfigValueType + origin = get_origin(value) + if origin in (list, tuple, set, frozenset, dict): + args = get_args(value) + if args: + new_args, changed = _resolve_generic_type_args( + args, func, config_value_type, media_item_type + ) + if changed: + # Reconstruct the generic type with resolved TypeVars + type_hints[key] = origin[tuple(new_args)] + continue + + # Handle Union types that may contain TypeVars + # For example: _ConfigValueT | ConfigValueType should become just "ConfigValueType" + # when _ConfigValueT is bound to ConfigValueType + if origin is Union or origin is UnionType: + args = get_args(value) + # Check if union contains a TypeVar + # If the TypeVar's bound is a union that was flattened into the current union, + # we can just use the bound type for documentation purposes + typevar_found = False + for i, arg in enumerate(args): + if isinstance(arg, TypeVar) and arg.__bound__ is not None: + typevar_found = True + type_hints[key] = _resolve_typevar_in_union(arg, func, args, i) + break + if typevar_found: + continue if not hasattr(value, "__name__"): continue if value.__name__ == "ItemCls": type_hints[key] = func.__self__.item_cls # type: ignore[attr-defined] + # Resolve TypeVars to their bound type for API documentation + # This handles cases like _ConfigValueT which should show as ConfigValueType + elif isinstance(value, TypeVar): + if value.__bound__ is not None: + type_hints[key] = value.__bound__ return APICommandHandler( command=command, signature=inspect.signature(func), type_hints=type_hints, target=func, + authenticated=authenticated, + required_role=required_role, + alias=alias, ) -def api_command(command: str) -> Callable[[_F], _F]: - """Decorate a function as API route/command.""" +def api_command( + command: str, authenticated: bool = True, required_role: str | None = None +) -> Callable[[_F], _F]: + """Decorate a function as API route/command. + + :param command: The command name/path. + :param authenticated: Whether authentication is required (default: True). + :param required_role: Required user role ("admin" or "user"), None means any authenticated user. + """ def decorate(func: _F) -> _F: func.api_cmd = command # type: ignore[attr-defined] + func.api_authenticated = authenticated # type: ignore[attr-defined] + func.api_required_role = required_role # type: ignore[attr-defined] return func return decorate @@ -103,6 +331,14 @@ def parse_value( # noqa: PLR0911 allow_value_convert: bool = False, ) -> Any: """Try to parse a value from raw (json) data and type annotations.""" + # Resolve string type hints early for proper handling + if isinstance(value_type, str): + value_type = _resolve_string_type(value_type) + # If still a string after resolution, return value as-is + if isinstance(value_type, str): + LOGGER.debug("Unknown string type hint: %s, returning value as-is", value_type) + return value + if isinstance(value, dict) and hasattr(value_type, "from_dict"): if ( "media_type" in value diff --git a/music_assistant/helpers/api_docs.py b/music_assistant/helpers/api_docs.py deleted file mode 100644 index ba5dc6b2..00000000 --- a/music_assistant/helpers/api_docs.py +++ /dev/null @@ -1,2424 +0,0 @@ -"""Helpers for generating API documentation and OpenAPI specifications.""" - -from __future__ import annotations - -import collections.abc -import inspect -from collections.abc import Callable -from dataclasses import MISSING -from datetime import datetime -from enum import Enum -from types import NoneType, UnionType -from typing import Any, Union, get_args, get_origin, get_type_hints - -from music_assistant_models.player import Player as PlayerState - -from music_assistant.helpers.api import APICommandHandler - - -def _format_type_name(type_hint: Any) -> str: - """Format a type hint as a user-friendly string, using JSON types instead of Python types.""" - if type_hint is NoneType or type_hint is type(None): - return "null" - - # Handle internal Player model - replace with PlayerState - if hasattr(type_hint, "__name__") and type_hint.__name__ == "Player": - if ( - hasattr(type_hint, "__module__") - and type_hint.__module__ == "music_assistant.models.player" - ): - return "PlayerState" - - # Map Python types to JSON types - type_name_mapping = { - "str": "string", - "int": "integer", - "float": "number", - "bool": "boolean", - "dict": "object", - "list": "array", - "tuple": "array", - "set": "array", - "frozenset": "array", - "Sequence": "array", - "UniqueList": "array", - "None": "null", - } - - if hasattr(type_hint, "__name__"): - type_name = str(type_hint.__name__) - return type_name_mapping.get(type_name, type_name) - - type_str = str(type_hint).replace("NoneType", "null") - # Replace Python types with JSON types in complex type strings - for python_type, json_type in type_name_mapping.items(): - type_str = type_str.replace(python_type, json_type) - return type_str - - -def _get_type_schema( # noqa: PLR0911, PLR0915 - type_hint: Any, definitions: dict[str, Any] -) -> dict[str, Any]: - """Convert a Python type hint to an OpenAPI schema.""" - # Handle string type hints from __future__ annotations - if isinstance(type_hint, str): - # Handle simple primitive type names - if type_hint in ("str", "string"): - return {"type": "string"} - if type_hint in ("int", "integer"): - return {"type": "integer"} - if type_hint in ("float", "number"): - return {"type": "number"} - if type_hint in ("bool", "boolean"): - return {"type": "boolean"} - - # Check if it looks like a simple class name (no special chars, starts with uppercase) - # Examples: "PlayerType", "DeviceInfo", "PlaybackState" - # Exclude generic types like "Any", "Union", "Optional", etc. - excluded_types = {"Any", "Union", "Optional", "List", "Dict", "Tuple", "Set"} - if type_hint.isidentifier() and type_hint[0].isupper() and type_hint not in excluded_types: - # Create a schema reference for this type - if type_hint not in definitions: - definitions[type_hint] = {"type": "object"} - return {"$ref": f"#/components/schemas/{type_hint}"} - - # If it's "Any", return generic object without creating a schema - if type_hint == "Any": - return {"type": "object"} - - # For complex type expressions like "str | None", "list[str]", return generic object - return {"type": "object"} - - # Handle None type - if type_hint is NoneType or type_hint is type(None): - return {"type": "null"} - - # Handle internal Player model - replace with external PlayerState - if hasattr(type_hint, "__name__") and type_hint.__name__ == "Player": - # Check if this is the internal Player (from music_assistant.models.player) - if ( - hasattr(type_hint, "__module__") - and type_hint.__module__ == "music_assistant.models.player" - ): - # Replace with PlayerState from music_assistant_models - return _get_type_schema(PlayerState, definitions) - - # Handle Union types (including Optional) - origin = get_origin(type_hint) - if origin is Union or origin is UnionType: - args = get_args(type_hint) - # Check if it's Optional (Union with None) - non_none_args = [arg for arg in args if arg not in (NoneType, type(None))] - if (len(non_none_args) == 1 and NoneType in args) or type(None) in args: - # It's Optional[T], make it nullable - schema = _get_type_schema(non_none_args[0], definitions) - schema["nullable"] = True - return schema - # It's a union of multiple types - return {"oneOf": [_get_type_schema(arg, definitions) for arg in args]} - - # Handle UniqueList (treat as array) - if hasattr(type_hint, "__name__") and type_hint.__name__ == "UniqueList": - args = get_args(type_hint) - if args: - return {"type": "array", "items": _get_type_schema(args[0], definitions)} - return {"type": "array", "items": {}} - - # Handle Sequence types (from collections.abc or typing) - if origin is collections.abc.Sequence or ( - hasattr(origin, "__name__") and origin.__name__ == "Sequence" - ): - args = get_args(type_hint) - if args: - return {"type": "array", "items": _get_type_schema(args[0], definitions)} - return {"type": "array", "items": {}} - - # Handle set/frozenset types - if origin in (set, frozenset): - args = get_args(type_hint) - if args: - return {"type": "array", "items": _get_type_schema(args[0], definitions)} - return {"type": "array", "items": {}} - - # Handle list/tuple types - if origin in (list, tuple): - args = get_args(type_hint) - if args: - return {"type": "array", "items": _get_type_schema(args[0], definitions)} - return {"type": "array", "items": {}} - - # Handle dict types - if origin is dict: - args = get_args(type_hint) - if len(args) == 2: - return { - "type": "object", - "additionalProperties": _get_type_schema(args[1], definitions), - } - return {"type": "object", "additionalProperties": True} - - # Handle Enum types - add them to definitions as explorable objects - if inspect.isclass(type_hint) and issubclass(type_hint, Enum): - enum_name = type_hint.__name__ - if enum_name not in definitions: - enum_values = [item.value for item in type_hint] - enum_type = type(enum_values[0]).__name__ if enum_values else "string" - openapi_type = { - "str": "string", - "int": "integer", - "float": "number", - "bool": "boolean", - }.get(enum_type, "string") - - # Create a detailed enum definition with descriptions - enum_values_str = ", ".join(str(v) for v in enum_values) - definitions[enum_name] = { - "type": openapi_type, - "enum": enum_values, - "description": f"Enum: {enum_name}. Possible values: {enum_values_str}", - } - return {"$ref": f"#/components/schemas/{enum_name}"} - - # Handle datetime - if type_hint is datetime: - return {"type": "string", "format": "date-time"} - - # Handle primitive types - check both exact type and type name - if type_hint is str or (hasattr(type_hint, "__name__") and type_hint.__name__ == "str"): - return {"type": "string"} - if type_hint is int or (hasattr(type_hint, "__name__") and type_hint.__name__ == "int"): - return {"type": "integer"} - if type_hint is float or (hasattr(type_hint, "__name__") and type_hint.__name__ == "float"): - return {"type": "number"} - if type_hint is bool or (hasattr(type_hint, "__name__") and type_hint.__name__ == "bool"): - return {"type": "boolean"} - - # Handle complex types (dataclasses, models) - # Check for __annotations__ or if it's a class (not already handled above) - if hasattr(type_hint, "__annotations__") or ( - inspect.isclass(type_hint) and not issubclass(type_hint, (str, int, float, bool, Enum)) - ): - type_name = getattr(type_hint, "__name__", str(type_hint)) - # Add to definitions if not already there - if type_name not in definitions: - properties = {} - required = [] - - # Check if this is a dataclass with fields - if hasattr(type_hint, "__dataclass_fields__"): - # Resolve type hints to handle forward references from __future__ annotations - try: - resolved_hints = get_type_hints(type_hint) - except Exception: - resolved_hints = {} - - # Use dataclass fields to get proper info including defaults and metadata - for field_name, field_info in type_hint.__dataclass_fields__.items(): - # Skip fields marked with serialize="omit" in metadata - if field_info.metadata: - # Check for mashumaro field_options - if "serialize" in field_info.metadata: - if field_info.metadata["serialize"] == "omit": - continue - - # Use resolved type hint if available, otherwise fall back to field type - field_type = resolved_hints.get(field_name, field_info.type) - field_schema = _get_type_schema(field_type, definitions) - - # Add default value if present - if field_info.default is not MISSING: - field_schema["default"] = field_info.default - elif ( - hasattr(field_info, "default_factory") - and field_info.default_factory is not MISSING - ): - # Has a default factory - don't add anything, just skip - pass - - properties[field_name] = field_schema - - # Check if field is required (not Optional and no default) - has_default = field_info.default is not MISSING or ( - hasattr(field_info, "default_factory") - and field_info.default_factory is not MISSING - ) - is_optional = get_origin(field_type) in ( - Union, - UnionType, - ) and NoneType in get_args(field_type) - if not has_default and not is_optional: - required.append(field_name) - elif hasattr(type_hint, "__annotations__"): - # Fallback for non-dataclass types with annotations - for field_name, field_type in type_hint.__annotations__.items(): - properties[field_name] = _get_type_schema(field_type, definitions) - # Check if field is required (not Optional) - if not ( - get_origin(field_type) in (Union, UnionType) - and NoneType in get_args(field_type) - ): - required.append(field_name) - else: - # Class without dataclass fields or annotations - treat as generic object - pass # Will create empty properties - - definitions[type_name] = { - "type": "object", - "properties": properties, - } - if required: - definitions[type_name]["required"] = required - - return {"$ref": f"#/components/schemas/{type_name}"} - - # Handle Any - if type_hint is Any: - return {"type": "object"} - - # Fallback - for types we don't recognize, at least return a generic object type - return {"type": "object"} - - -def _parse_docstring( # noqa: PLR0915 - func: Callable[..., Any], -) -> tuple[str, str, dict[str, str]]: - """Parse docstring to extract summary, description and parameter descriptions. - - Returns: - Tuple of (short_summary, full_description, param_descriptions) - - Handles multiple docstring formats: - - reStructuredText (:param name: description) - - Google style (Args: section) - - NumPy style (Parameters section) - """ - docstring = inspect.getdoc(func) - if not docstring: - return "", "", {} - - lines = docstring.split("\n") - description_lines = [] - param_descriptions = {} - current_section = "description" - current_param = None - - for line in lines: - stripped = line.strip() - - # Check for section headers - if stripped.lower() in ("args:", "arguments:", "parameters:", "params:"): - current_section = "params" - current_param = None - continue - if stripped.lower() in ( - "returns:", - "return:", - "yields:", - "raises:", - "raises", - "examples:", - "example:", - "note:", - "notes:", - "see also:", - "warning:", - "warnings:", - ): - current_section = "other" - current_param = None - continue - - # Parse :param style - if stripped.startswith(":param "): - current_section = "params" - parts = stripped[7:].split(":", 1) - if len(parts) == 2: - current_param = parts[0].strip() - desc = parts[1].strip() - if desc: - param_descriptions[current_param] = desc - continue - - if stripped.startswith((":type ", ":rtype", ":return")): - current_section = "other" - current_param = None - continue - - # Detect bullet-style params even without explicit section header - # Format: "- param_name: description" - if stripped.startswith("- ") and ":" in stripped: - # This is likely a bullet-style parameter - current_section = "params" - content = stripped[2:] # Remove "- " - parts = content.split(":", 1) - param_name = parts[0].strip() - desc_part = parts[1].strip() if len(parts) > 1 else "" - if param_name and not param_name.startswith(("return", "yield", "raise")): - current_param = param_name - if desc_part: - param_descriptions[current_param] = desc_part - continue - - # In params section, detect param lines (indented or starting with name) - if current_section == "params" and stripped: - # Google/NumPy style: "param_name: description" or "param_name (type): description" - if ":" in stripped and not stripped.startswith(" "): - # Likely a parameter definition - if "(" in stripped and ")" in stripped: - # Format: param_name (type): description - param_part = stripped.split(":")[0] - param_name = param_part.split("(")[0].strip() - desc_part = ":".join(stripped.split(":")[1:]).strip() - else: - # Format: param_name: description - parts = stripped.split(":", 1) - param_name = parts[0].strip() - desc_part = parts[1].strip() if len(parts) > 1 else "" - - if param_name and not param_name.startswith(("return", "yield", "raise")): - current_param = param_name - if desc_part: - param_descriptions[current_param] = desc_part - elif current_param and stripped: - # Continuation of previous parameter description - param_descriptions[current_param] = ( - param_descriptions.get(current_param, "") + " " + stripped - ).strip() - continue - - # Collect description lines (only before params/returns sections) - if current_section == "description" and stripped: - description_lines.append(stripped) - elif current_section == "description" and not stripped and description_lines: - # Empty line in description - keep it for paragraph breaks - description_lines.append("") - - # Join description lines, removing excessive empty lines - description = "\n".join(description_lines).strip() - # Collapse multiple empty lines into one - while "\n\n\n" in description: - description = description.replace("\n\n\n", "\n\n") - - # Extract first sentence/line as summary - summary = "" - if description: - # Get first line or first sentence (whichever is shorter) - first_line = description.split("\n")[0] - # Try to get first sentence (ending with .) - summary = first_line.split(".")[0] + "." if "." in first_line else first_line - - return summary, description, param_descriptions - - -def generate_openapi_spec( - command_handlers: dict[str, APICommandHandler], - server_url: str = "http://localhost:8095", - version: str = "1.0.0", -) -> dict[str, Any]: - """Generate simplified OpenAPI 3.0 specification focusing on data models. - - This spec documents the single /api endpoint and all data models/schemas. - For detailed command documentation, see the Commands Reference page. - """ - definitions: dict[str, Any] = {} - - # Build all schemas from command handlers (this populates definitions) - for handler in command_handlers.values(): - # Build parameter schemas - for param_name in handler.signature.parameters: - if param_name == "self": - continue - # Skip return_type parameter (used only for type hints) - if param_name == "return_type": - continue - param_type = handler.type_hints.get(param_name, Any) - # Skip Any types as they don't provide useful schema information - if param_type is not Any and str(param_type) != "typing.Any": - _get_type_schema(param_type, definitions) - - # Build return type schema - return_type = handler.type_hints.get("return", Any) - # Skip Any types as they don't provide useful schema information - if return_type is not Any and str(return_type) != "typing.Any": - _get_type_schema(return_type, definitions) - - # Build a single /api endpoint with generic request/response - paths = { - "/api": { - "post": { - "summary": "Execute API command", - "description": ( - "Execute any Music Assistant API command.\n\n" - "See the **Commands Reference** page for a complete list of available " - "commands with examples." - ), - "operationId": "execute_command", - "requestBody": { - "required": True, - "content": { - "application/json": { - "schema": { - "type": "object", - "required": ["command"], - "properties": { - "command": { - "type": "string", - "description": ( - "The command to execute (e.g., 'players/all')" - ), - "example": "players/all", - }, - "args": { - "type": "object", - "description": "Command arguments (varies by command)", - "additionalProperties": True, - "example": {}, - }, - }, - }, - "examples": { - "get_players": { - "summary": "Get all players", - "value": {"command": "players/all", "args": {}}, - }, - "play_media": { - "summary": "Play media on a player", - "value": { - "command": "players/cmd/play", - "args": {"player_id": "player123"}, - }, - }, - }, - } - }, - }, - "responses": { - "200": { - "description": "Successful command execution", - "content": { - "application/json": { - "schema": {"description": "Command result (varies by command)"} - } - }, - }, - "400": {"description": "Bad request - invalid command or parameters"}, - "500": {"description": "Internal server error"}, - }, - } - } - } - - # Build OpenAPI spec - return { - "openapi": "3.0.0", - "info": { - "title": "Music Assistant API", - "version": version, - "description": ( - "Music Assistant API provides control over your music library, " - "players, and playback.\n\n" - "This specification documents the API structure and data models. " - "For a complete list of available commands with examples, " - "see the Commands Reference page." - ), - "contact": { - "name": "Music Assistant", - "url": "https://music-assistant.io", - }, - }, - "servers": [{"url": server_url, "description": "Music Assistant Server"}], - "paths": paths, - "components": {"schemas": definitions}, - } - - -def _split_union_type(type_str: str) -> list[str]: - """Split a union type on | but respect brackets and parentheses. - - This ensures that list[A | B] and (A | B) are not split at the inner |. - """ - parts = [] - current_part = "" - bracket_depth = 0 - paren_depth = 0 - i = 0 - while i < len(type_str): - char = type_str[i] - if char == "[": - bracket_depth += 1 - current_part += char - elif char == "]": - bracket_depth -= 1 - current_part += char - elif char == "(": - paren_depth += 1 - current_part += char - elif char == ")": - paren_depth -= 1 - current_part += char - elif char == "|" and bracket_depth == 0 and paren_depth == 0: - # Check if this is a union separator (has space before and after) - if ( - i > 0 - and i < len(type_str) - 1 - and type_str[i - 1] == " " - and type_str[i + 1] == " " - ): - parts.append(current_part.strip()) - current_part = "" - i += 1 # Skip the space after |, the loop will handle incrementing i - else: - current_part += char - else: - current_part += char - i += 1 - if current_part.strip(): - parts.append(current_part.strip()) - return parts - - -def _python_type_to_json_type(type_str: str, _depth: int = 0) -> str: - """Convert Python type string to JSON/JavaScript type string. - - Args: - type_str: The type string to convert - _depth: Internal recursion depth tracker (do not set manually) - """ - import re # noqa: PLC0415 - - # Prevent infinite recursion - if _depth > 50: - return "any" - - # Remove typing module prefix and class markers - type_str = type_str.replace("typing.", "").replace("", "") - - # Remove module paths from type names (e.g., "music_assistant.models.Artist" -> "Artist") - type_str = re.sub(r"[\w.]+\.(\w+)", r"\1", type_str) - - # Map Python types to JSON types - type_mappings = { - "str": "string", - "int": "integer", - "float": "number", - "bool": "boolean", - "dict": "object", - "Dict": "object", - "None": "null", - "NoneType": "null", - } - - # Check for List/list/UniqueList with type parameter BEFORE checking for union types - # This is important because list[A | B] contains " | " but should be handled as a list first - # We need to match list[...] where the brackets are balanced - if type_str.startswith(("list[", "List[", "UniqueList[")): # codespell:ignore - # Find the matching closing bracket - bracket_count = 0 - start_idx = type_str.index("[") + 1 - end_idx = -1 - for i in range(start_idx, len(type_str)): - if type_str[i] == "[": - bracket_count += 1 - elif type_str[i] == "]": - if bracket_count == 0: - end_idx = i - break - bracket_count -= 1 - - # Check if this is a complete list type (ends with the closing bracket) - if end_idx == len(type_str) - 1: - inner_type = type_str[start_idx:end_idx].strip() - # Recursively convert the inner type - inner_json_type = _python_type_to_json_type(inner_type, _depth + 1) - # For list[A | B], wrap in parentheses to keep it as one unit - # This prevents "Array of A | B" from being split into separate union parts - if " | " in inner_json_type: - return f"Array of ({inner_json_type})" - return f"Array of {inner_json_type}" - - # Handle Union types by splitting on | and recursively processing each part - if " | " in type_str: - # Use helper to split on | but respect brackets - parts = _split_union_type(type_str) - - # Filter out None types - parts = [part for part in parts if part != "None"] - - # If splitting didn't help (only one part or same as input), avoid infinite recursion - if not parts or (len(parts) == 1 and parts[0] == type_str): - # Can't split further, return as-is or "any" - return type_str if parts else "any" - - if parts: - converted_parts = [_python_type_to_json_type(part, _depth + 1) for part in parts] - # Remove duplicates while preserving order - seen = set() - unique_parts = [] - for part in converted_parts: - if part not in seen: - seen.add(part) - unique_parts.append(part) - return " | ".join(unique_parts) - return "any" - - # Check for Union/Optional types with brackets - if "Union[" in type_str or "Optional[" in type_str: - # Extract content from Union[...] or Optional[...] - union_match = re.search(r"(?:Union|Optional)\[([^\]]+)\]", type_str) - if union_match: - inner = union_match.group(1) - # Recursively process the union content - return _python_type_to_json_type(inner, _depth + 1) - - # Direct mapping for basic types - for py_type, json_type in type_mappings.items(): - if type_str == py_type: - return json_type - - # Check if it's a complex type (starts with capital letter) - complex_match = re.search(r"^([A-Z][a-zA-Z0-9_]*)$", type_str) - if complex_match: - return complex_match.group(1) - - # Default to the original string if no mapping found - return type_str - - -def _make_type_links(type_str: str, server_url: str, as_list: bool = False) -> str: - """Convert type string to HTML with links to schemas reference for complex types. - - Args: - type_str: The type string to convert - server_url: Base server URL for building links - as_list: If True and type contains |, format as "Any of:" bullet list - """ - import re # noqa: PLC0415 - from re import Match # noqa: PLC0415 - - # Find all complex types (capitalized words that aren't basic types) - def replace_type(match: Match[str]) -> str: - type_name = match.group(0) - # Check if it's a complex type (starts with capital letter) - # Exclude basic types and "Array" (which is used in "Array of Type") - excluded = {"Union", "Optional", "List", "Dict", "Array"} - if type_name[0].isupper() and type_name not in excluded: - # Create link to our schemas reference page - schema_url = f"{server_url}/api-docs/schemas#schema-{type_name}" - return f'{type_name}' - return type_name - - # If it's a union type with multiple options and as_list is True, format as bullet list - if as_list and " | " in type_str: - # Use the bracket/parenthesis-aware splitter - parts = _split_union_type(type_str) - # Only use list format if there are 3+ options - if len(parts) >= 3: - html = '
Any of:
    ' - for part in parts: - linked_part = re.sub(r"\b[A-Z][a-zA-Z0-9_]*\b", replace_type, part) - html += f"
  • {linked_part}
  • " - html += "
" - return html - - # Replace complex type names with links - result: str = re.sub(r"\b[A-Z][a-zA-Z0-9_]*\b", replace_type, type_str) - return result - - -def generate_commands_reference( # noqa: PLR0915 - command_handlers: dict[str, APICommandHandler], - server_url: str = "http://localhost:8095", -) -> str: - """Generate HTML commands reference page with all available commands.""" - import json # noqa: PLC0415 - - # Group commands by category - categories: dict[str, list[tuple[str, APICommandHandler]]] = {} - for command, handler in sorted(command_handlers.items()): - category = command.split("/")[0] if "/" in command else "general" - if category not in categories: - categories[category] = [] - categories[category].append((command, handler)) - - html = """ - - - - - Music Assistant API - Commands Reference - - - -
-

Commands Reference

-

Complete list of Music Assistant API commands

-
- - - -
-""" - - for category, commands in sorted(categories.items()): - category_display = category.replace("_", " ").title() - html += f'
\n' - html += f'
{category_display}
\n' - html += '
\n' - - for command, handler in commands: - # Parse docstring - summary, description, param_descriptions = _parse_docstring(handler.target) - - # Get return type - return_type = handler.type_hints.get("return", Any) - return_type_str = _python_type_to_json_type(str(return_type)) - - html += f'
\n' - html += ( - '
\n' - ) - html += '
\n' - html += f'
{command}
\n' - if summary: - summary_escaped = summary.replace("<", "<").replace(">", ">") - html += ( - f'
' - f"{summary_escaped}
\n" - ) - html += "
\n" - html += '
▼
\n' - html += "
\n" - - # Command details (collapsed by default) - html += '
\n' - - if description and description != summary: - desc_escaped = description.replace("<", "<").replace(">", ">") - html += ( - f'
' - f"{desc_escaped}
\n" - ) - - # Return type with links - return_type_html = _make_type_links(return_type_str, server_url) - html += '
\n' - html += ' Returns:\n' - html += f' {return_type_html}\n' # noqa: E501 - html += "
\n" - - # Parameters - params = [] - for param_name, param in handler.signature.parameters.items(): - if param_name == "self": - continue - # Skip return_type parameter (used only for type hints) - if param_name == "return_type": - continue - is_required = param.default is inspect.Parameter.empty - param_type = handler.type_hints.get(param_name, Any) - type_str = str(param_type) - json_type_str = _python_type_to_json_type(type_str) - param_desc = param_descriptions.get(param_name, "") - params.append((param_name, is_required, json_type_str, param_desc)) - - if params: - html += '
\n' - html += '
Parameters:
\n' - for param_name, is_required, type_str, param_desc in params: - # Convert type to HTML with links (use list format for unions) - type_html = _make_type_links(type_str, server_url, as_list=True) - html += '
\n' - html += ( - f' ' - f"{param_name}\n" - ) - if is_required: - html += ( - ' ' - "REQUIRED\n" - ) - # If it's a list format, display it differently - if "
    " in type_html: - html += ( - '
    ' - f"{type_html}
    \n" - ) - else: - html += ( - f' ' - f"{type_html}\n" - ) - if param_desc: - html += ( - f'
    ' - f"{param_desc}
    \n" - ) - html += "
\n" - html += "
\n" - - # Build example curl command with JSON types - example_args: dict[str, Any] = {} - for param_name, is_required, type_str, _ in params: - # Include optional params if few params - if is_required or len(params) <= 2: - if type_str == "string": - example_args[param_name] = "example_value" - elif type_str == "integer": - example_args[param_name] = 0 - elif type_str == "number": - example_args[param_name] = 0.0 - elif type_str == "boolean": - example_args[param_name] = True - elif type_str == "object": - example_args[param_name] = {} - elif type_str == "null": - example_args[param_name] = None - elif type_str.startswith("Array of "): - # Array type with item type specified (e.g., "Array of Artist") - item_type = type_str[9:] # Remove "Array of " - if item_type in {"string", "integer", "number", "boolean"}: - example_args[param_name] = [] - else: - # Complex type array - example_args[param_name] = [ - {"_comment": f"See {item_type} schema in Swagger UI"} - ] - else: - # Complex type (Artist, Player, etc.) - use placeholder object - # Extract the primary type if it's a union (e.g., "Artist | string") - primary_type = type_str.split(" | ")[0] if " | " in type_str else type_str - example_args[param_name] = { - "_comment": f"See {primary_type} schema in Swagger UI" - } - - request_body: dict[str, Any] = {"command": command} - if example_args: - request_body["args"] = example_args - - curl_cmd = ( - f"curl -X POST {server_url}/api \\\n" - ' -H "Content-Type: application/json" \\\n' - f" -d '{json.dumps(request_body, indent=2)}'" - ) - - # Add tabs for curl example and try it - html += '
\n' - html += '
\n' - html += ( - ' \n" - ) - html += ( - ' \n" # noqa: E501 - ) - html += "
\n" - - # cURL tab - html += f'
\n' # noqa: E501 - html += '
\n' - html += ( - ' \n' - ) - html += f"
{curl_cmd}
\n" - html += "
\n" - html += "
\n" - - # Try It tab - html += f'
\n' # noqa: E501 - html += '
\n' - # HTML-escape the JSON for the textarea - json_str = json.dumps(request_body, indent=2) - # Escape HTML entities - json_str_escaped = ( - json_str.replace("&", "&") - .replace("<", "<") - .replace(">", ">") - .replace('"', """) - .replace("'", "'") - ) - html += f' \n' # noqa: E501 - html += ( - f' \n" - ) - html += '
\n' - html += "
\n" - html += "
\n" - - html += "
\n" - # Close command-details div - html += "
\n" - # Close command div - html += "
\n" - - html += "
\n" - html += "
\n" - - html += """
- - - - -""" - - return html - - -def generate_schemas_reference( # noqa: PLR0915 - command_handlers: dict[str, APICommandHandler], -) -> str: - """Generate HTML schemas reference page with all data models.""" - # Collect all unique schemas from commands - schemas: dict[str, Any] = {} - - for handler in command_handlers.values(): - # Collect schemas from parameters - for param_name in handler.signature.parameters: - if param_name == "self": - continue - # Skip return_type parameter (used only for type hints) - if param_name == "return_type": - continue - param_type = handler.type_hints.get(param_name, Any) - if param_type is not Any and str(param_type) != "typing.Any": - _get_type_schema(param_type, schemas) - - # Collect schemas from return type - return_type = handler.type_hints.get("return", Any) - if return_type is not Any and str(return_type) != "typing.Any": - _get_type_schema(return_type, schemas) - - # Build HTML - html = """ - - - - - Music Assistant API - Schemas Reference - - - -
-

Schemas Reference

-

Data models and types used in the Music Assistant API

-
- - - -
- ← Back to API Documentation -""" - - # Add each schema - for schema_name in sorted(schemas.keys()): - schema_def = schemas[schema_name] - html += ( - f'
\n' - ) - html += '
\n' - html += f'
{schema_name}
\n' - html += '
▼
\n' - html += "
\n" - html += '
\n' - - # Add description if available - if "description" in schema_def: - desc = schema_def["description"] - html += f'
{desc}
\n' - - # Add properties if available - if "properties" in schema_def: - html += '
\n' - html += '
Properties:
\n' - - # Get required fields list - required_fields = schema_def.get("required", []) - - for prop_name, prop_def in schema_def["properties"].items(): - html += '
\n' - html += f' {prop_name}\n' - - # Check if field is required - is_required = prop_name in required_fields - - # Check if field is nullable (type is "null" or has null in anyOf/oneOf) - is_nullable = False - if "type" in prop_def and prop_def["type"] == "null": - is_nullable = True - elif "anyOf" in prop_def: - is_nullable = any(item.get("type") == "null" for item in prop_def["anyOf"]) - elif "oneOf" in prop_def: - is_nullable = any(item.get("type") == "null" for item in prop_def["oneOf"]) - - # Add required/optional badge - if is_required: - html += ( - ' REQUIRED\n' - ) - else: - html += ( - ' OPTIONAL\n' - ) - - # Add nullable badge if applicable - if is_nullable: - html += ( - ' NULLABLE\n' - ) - - # Add type - if "type" in prop_def: - prop_type = prop_def["type"] - html += ( - f' {prop_type}\n' - ) - elif "$ref" in prop_def: - # Extract type name from $ref - ref_type = prop_def["$ref"].split("/")[-1] - html += ( - f' ' - f'' - f"{ref_type}\n" - ) - - # Add description - if "description" in prop_def: - prop_desc = prop_def["description"] - html += ( - f'
' - f"{prop_desc}
\n" - ) - - # Add enum values if present - if "enum" in prop_def: - html += '
\n' - html += ( - '
' - "Possible values:
\n" - ) - for enum_val in prop_def["enum"]: - html += ( - f' ' - f"{enum_val}\n" - ) - html += "
\n" - - html += "
\n" - - html += "
\n" - - html += "
\n" - html += "
\n" - - html += """ - -
- - - - -""" - - return html - - -def generate_html_docs( # noqa: PLR0915 - command_handlers: dict[str, APICommandHandler], - server_url: str = "http://localhost:8095", - version: str = "1.0.0", -) -> str: - """Generate HTML documentation from API command handlers.""" - # Group commands by category - categories: dict[str, list[tuple[str, APICommandHandler]]] = {} - for command, handler in sorted(command_handlers.items()): - category = command.split("/")[0] if "/" in command else "general" - if category not in categories: - categories[category] = [] - categories[category].append((command, handler)) - - # Start building HTML - html_parts = [ - """ - - - - - Music Assistant API Documentation - - - -
-
-

Music Assistant API Documentation

-

Version """, - version, - """

-
- -
-

Getting Started

-

Music Assistant provides two ways to interact with the API:

- -

🔌 WebSocket API (Recommended)

-

- The WebSocket API provides full access to all commands - and real-time event updates. -

-
    -
  • Endpoint: ws://""", - server_url.replace("http://", "").replace("https://", ""), - """/ws
  • -
  • - Best for: Applications that need live - updates and real-time communication -
  • -
  • - Bonus: When connected, you automatically - receive event messages for state changes -
  • -
-

Sending commands:

-
{
-  "message_id": "unique-id-123",
-  "command": "players/all",
-  "args": {}
-}
-

Receiving events:

-

- Once connected, you will automatically receive event messages - whenever something changes: -

-
{
-  "event": "player_updated",
-  "data": {
-    "player_id": "player_123",
-    ...player data...
-  }
-}
- -

🌐 REST API (Simple)

-

- The REST API provides a simple HTTP interface for - executing commands. -

-
    -
  • Endpoint: POST """, - server_url, - """/api
  • -
  • - Best for: Simple, incidental commands - without need for real-time updates -
  • -
-

Example request:

-
{
-  "command": "players/all",
-  "args": {}
-}
- -

📥 OpenAPI Specification

-

Download the OpenAPI 3.0 specification for automated client generation:

- Download openapi.json - -

🚀 Interactive API Explorers

-

- Try out the API interactively with our API explorers. - Test endpoints, see live responses, and explore the full API: -

- - -

📡 WebSocket Events

-

- When connected via WebSocket, you automatically receive - real-time event notifications: -

-
- Player Events: -
    -
  • player_added - New player discovered
  • -
  • player_updated - Player state changed
  • -
  • player_removed - Player disconnected
  • -
  • player_config_updated - Player settings changed
  • -
- - Queue Events: -
    -
  • queue_added - New queue created
  • -
  • queue_updated - Queue state changed
  • -
  • queue_items_updated - Queue content changed
  • -
  • queue_time_updated - Playback position updated
  • -
- - Library Events: -
    -
  • media_item_added - New media added to library
  • -
  • media_item_updated - Media metadata updated
  • -
  • media_item_deleted - Media removed from library
  • -
  • media_item_played - Media playback started
  • -
- - System Events: -
    -
  • providers_updated - Provider status changed
  • -
  • sync_tasks_updated - Sync progress updated
  • -
  • application_shutdown - Server shutting down
  • -
-
-
- - -""" - ) - - # Add commands by category - for category, commands in sorted(categories.items()): - html_parts.append(f'
\n') - html_parts.append(f'
{category}
\n') - - for command, handler in commands: - _, description, param_descriptions = _parse_docstring(handler.target) - - html_parts.append('
\n') - html_parts.append(f'
{command}
\n') - - if description: - html_parts.append( - f'
{description}
\n' - ) - - # Parameters - params_html = [] - for param_name, param in handler.signature.parameters.items(): - if param_name == "self": - continue - # Skip return_type parameter (used only for type hints) - if param_name == "return_type": - continue - - param_type = handler.type_hints.get(param_name, Any) - is_required = param.default is inspect.Parameter.empty - param_desc = param_descriptions.get(param_name, "") - - # Format type name - type_name = _format_type_name(param_type) - if get_origin(param_type): - origin = get_origin(param_type) - args = get_args(param_type) - if origin is Union or origin is UnionType: - type_name = " | ".join(_format_type_name(arg) for arg in args) - elif origin in (list, tuple): - if args: - inner_type = _format_type_name(args[0]) - type_name = f"{origin.__name__}[{inner_type}]" - elif origin is dict: - if len(args) == 2: - key_type = _format_type_name(args[0]) - val_type = _format_type_name(args[1]) - type_name = f"dict[{key_type}, {val_type}]" - - required_badge = ( - 'required' - if is_required - else 'optional' - ) - - # Format default value - default_str = "" - if not is_required and param.default is not None: - try: - if isinstance(param.default, str): - default_str = f' = "{param.default}"' - elif isinstance(param.default, Enum): - default_str = f" = {param.default.value}" - elif isinstance(param.default, (int, float, bool, list, dict)): - default_str = f" = {param.default}" - except Exception: # noqa: S110 - pass # Can't serialize, skip default - - params_html.append( - f'
\n' - f' {param_name}\n' - f' ' - f"({type_name}{default_str})\n" - f" {required_badge}\n" - ) - if param_desc: - params_html.append( - f'
' - f"{param_desc}
\n" - ) - params_html.append("
\n") - - if params_html: - html_parts.append('
\n') - html_parts.append("

Parameters

\n") - html_parts.extend(params_html) - html_parts.append("
\n") - - # Return type - return_type = handler.type_hints.get("return", Any) - if return_type and return_type is not NoneType: - type_name = _format_type_name(return_type) - if get_origin(return_type): - origin = get_origin(return_type) - args = get_args(return_type) - if origin in (list, tuple) and args: - inner_type = _format_type_name(args[0]) - type_name = f"{origin.__name__}[{inner_type}]" - elif origin is Union or origin is UnionType: - type_name = " | ".join(_format_type_name(arg) for arg in args) - - html_parts.append('
\n') - html_parts.append("

Returns

\n") - html_parts.append( - f'
{type_name}
\n' - ) - html_parts.append("
\n") - - html_parts.append("
\n") - - html_parts.append("
\n") - - html_parts.append( - """
- - -""" - ) - - return "".join(html_parts) diff --git a/music_assistant/helpers/redirect_validation.py b/music_assistant/helpers/redirect_validation.py new file mode 100644 index 00000000..3d396631 --- /dev/null +++ b/music_assistant/helpers/redirect_validation.py @@ -0,0 +1,116 @@ +"""Helpers for validating redirect URLs in OAuth/auth flows.""" + +from __future__ import annotations + +import ipaddress +import logging +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +from music_assistant.constants import MASS_LOGGER_NAME + +if TYPE_CHECKING: + from aiohttp import web + +LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.redirect_validation") + +# Allowed redirect URI patterns +# Add custom URL schemes for mobile apps here +ALLOWED_REDIRECT_PATTERNS = [ + # Custom URL schemes for mobile apps + "musicassistant://", # Music Assistant mobile app + # Home Assistant domains + "https://my.home-assistant.io/", + "http://homeassistant.local/", + "https://homeassistant.local/", +] + + +def is_allowed_redirect_url( + url: str, + request: web.Request | None = None, + base_url: str | None = None, +) -> tuple[bool, str]: + """ + Validate if a redirect URL is allowed for OAuth/auth flows. + + Security rules (in order of priority): + 1. Must use http, https, or registered custom scheme (e.g., musicassistant://) + 2. Same origin as the request - auto-allowed (trusted) + 3. Localhost (127.0.0.1, ::1, localhost) - auto-allowed (trusted) + 4. Private network IPs (RFC 1918) - auto-allowed (trusted) + 5. Configured base_url - auto-allowed (trusted) + 6. Matches allowed redirect patterns - auto-allowed (trusted) + 7. Everything else - requires user consent (external) + + :param url: The redirect URL to validate. + :param request: Optional aiohttp request to compare origin. + :param base_url: Optional configured base URL to allow. + :return: Tuple of (is_valid, category) where category is: + - "trusted": Auto-allowed, no consent needed + - "external": Valid but requires user consent + - "blocked": Invalid/dangerous URL + """ + if not url: + return False, "blocked" + + try: + parsed = urlparse(url) + + # Check for custom URL schemes (mobile apps) + for pattern in ALLOWED_REDIRECT_PATTERNS: + if url.startswith(pattern): + LOGGER.debug("Redirect URL trusted (pattern match): %s", url) + return True, "trusted" + + # Only http/https for web URLs + if parsed.scheme not in ("http", "https"): + LOGGER.warning("Redirect URL blocked (invalid scheme): %s", url) + return False, "blocked" + + hostname = parsed.hostname + if not hostname: + LOGGER.warning("Redirect URL blocked (no hostname): %s", url) + return False, "blocked" + + # 1. Same origin as request - always trusted + if request: + request_host = request.host + if parsed.netloc == request_host: + LOGGER.debug("Redirect URL trusted (same origin): %s", url) + return True, "trusted" + + # 2. Localhost - always trusted (for development and mobile app testing) + if hostname in ("localhost", "127.0.0.1", "::1"): + LOGGER.debug("Redirect URL trusted (localhost): %s", url) + return True, "trusted" + + # 3. Private network IPs - always trusted (for local network access) + if _is_private_ip(hostname): + LOGGER.debug("Redirect URL trusted (private IP): %s", url) + return True, "trusted" + + # 4. Configured base_url - always trusted + if base_url: + base_parsed = urlparse(base_url) + if parsed.netloc == base_parsed.netloc: + LOGGER.debug("Redirect URL trusted (base_url): %s", url) + return True, "trusted" + + # If we get here, URL is external and requires user consent + LOGGER.info("Redirect URL is external (requires consent): %s", url) + return True, "external" + + except Exception as e: + LOGGER.exception("Error validating redirect URL: %s", e) + return False, "blocked" + + +def _is_private_ip(hostname: str) -> bool: + """Check if hostname is a private IP address (RFC 1918).""" + try: + ip = ipaddress.ip_address(hostname) + return ip.is_private + except ValueError: + # Not a valid IP address + return False diff --git a/music_assistant/helpers/resources/api_docs.html b/music_assistant/helpers/resources/api_docs.html index 04bfc12e..af5dc3f4 100644 --- a/music_assistant/helpers/resources/api_docs.html +++ b/music_assistant/helpers/resources/api_docs.html @@ -4,230 +4,403 @@ Music Assistant API Documentation +
- +

Music Assistant API

Version {VERSION}
@@ -286,7 +459,25 @@ # Connect to WebSocket ws://{SERVER_HOST}/ws -# Send a command (message_id is REQUIRED) +# Step 1: Authenticate (REQUIRED as first command) +{ + "message_id": "auth-123", + "command": "auth", + "args": { + "token": "your_access_token" + } +} + +# Auth response +{ + "message_id": "auth-123", + "result": { + "authenticated": true, + "user": {...user info...} + } +} + +# Step 2: Send commands (message_id is REQUIRED) { "message_id": "unique-id-123", "command": "players/all", @@ -317,16 +508,18 @@ ws://{SERVER_HOST}/ws since each HTTP request is isolated. The response returns the command result directly.
-# Get all players +# Get all players (requires authentication) curl -X POST {BASE_URL}/api \ + -H "Authorization: Bearer your_access_token" \ -H "Content-Type: application/json" \ -d '{ "command": "players/all", "args": {} }' -# Play media on a player +# Play media on a player (requires authentication) curl -X POST {BASE_URL}/api \ + -H "Authorization: Bearer your_access_token" \ -H "Content-Type: application/json" \ -d '{ "command": "player_queues/play_media", @@ -336,7 +529,7 @@ curl -X POST {BASE_URL}/api \ } }' -# Get server info +# Get server info (no authentication required) curl {BASE_URL}/info
@@ -459,7 +652,7 @@ api.subscribe('player_updated', (event) => {

Best Practices

✓ Do:

-
    +
    • Use WebSocket API for real-time applications
    • Handle connection drops and reconnect automatically
    • Subscribe to relevant events instead of polling
    • @@ -467,7 +660,7 @@ api.subscribe('player_updated', (event) => {
    • Implement proper error handling

    ✗ Don't:

    -
      +
      • Poll the REST API frequently for updates (use WebSocket events instead)
      • Send commands without waiting for previous responses
      • Ignore error responses
      • @@ -477,19 +670,132 @@ api.subscribe('player_updated', (event) => {

        Authentication

        +

        + As of API Schema Version 28, authentication is now mandatory for all API access + (except when accessed through Home Assistant Ingress). +

        + +

        Authentication Overview

        +

        Music Assistant supports the following authentication methods:

        +
          +
        • Username/Password - Built-in authentication provider
        • +
        • Home Assistant OAuth - OAuth flow for HA users (optional)
        • +
        • Bearer Tokens - Token-based authentication for HTTP and WebSocket
        • +
        + +

        HTTP Authentication Endpoints

        +

        The following HTTP endpoints are available for authentication (no auth required):

        +
        +# Get server info (includes onboard_done status) +GET {BASE_URL}/info + +# Get available login providers +GET {BASE_URL}/auth/providers + +# Login with credentials (built-in provider is default) +POST {BASE_URL}/auth/login +{ + "credentials": { + "username": "your_username", + "password": "your_password" + } +} + +# Or specify a different provider (e.g., Home Assistant OAuth) +POST {BASE_URL}/auth/login +{ + "provider_id": "homeassistant", + "credentials": {...provider-specific...} +} + +# Response includes access token +{ + "success": true, + "token": "your_access_token", + "user": { ...user info... } +} + +# First-time setup (only if no users exist) +POST {BASE_URL}/setup +{ + "username": "admin", + "password": "secure_password" +} +
        + +

        Using Bearer Tokens

        +

        Once you have an access token, include it in all HTTP requests:

        +
        +curl -X POST {BASE_URL}/api \ + -H "Authorization: Bearer your_access_token" \ + -H "Content-Type: application/json" \ + -d '{ + "command": "players/all", + "args": {} + }' +
        + +

        WebSocket Authentication

        +

        + After establishing a WebSocket connection, you must send an + auth command as the first message: +

        +
        +# Send auth command immediately after connection +{ + "message_id": "auth-123", + "command": "auth", + "args": { + "token": "your_access_token" + } +} + +# Response on success +{ + "message_id": "auth-123", + "result": { + "authenticated": true, + "user": { + "user_id": "...", + "username": "your_username", + "role": "admin" + } + } +} +
        +
        - Note: Authentication is not yet implemented but will be added - in a future release. For now, ensure your Music Assistant server is not directly - exposed to the internet. Use a VPN or reverse proxy for secure access. + Token Types: +
          +
        • Short-lived tokens: Created automatically during login. Expire after 30 days of inactivity but auto-renew on each use (sliding expiration window). Perfect for user sessions.
        • +
        • Long-lived tokens: Created via auth/token/create command. Expire after 10 years with no auto-renewal. Intended for external integrations (Home Assistant, mobile apps, API access).
        • +
        + Use the auth/tokens and auth/token/create WebSocket commands to manage your tokens.
        + +

        User Management Commands

        +

        The following WebSocket commands are available for authentication management:

        +
          +
        • auth/users - List all users (admin only)
        • +
        • auth/user - Get user by ID (admin only)
        • +
        • auth/user/create - Create a new user (admin only)
        • +
        • auth/user/update - Update user profile, password, or role (admin for other users)
        • +
        • auth/user/enable - Enable user (admin only)
        • +
        • auth/user/disable - Disable user (admin only)
        • +
        • auth/user/delete - Delete user (admin only)
        • +
        • auth/tokens - List your tokens
        • +
        • auth/token/create - Create a new long-lived token
        • +
        • auth/token/revoke - Revoke a token
        • +
        +

        See the Commands Reference for detailed documentation of all auth commands.

diff --git a/music_assistant/helpers/resources/commands_reference.html b/music_assistant/helpers/resources/commands_reference.html new file mode 100644 index 00000000..99aa9b94 --- /dev/null +++ b/music_assistant/helpers/resources/commands_reference.html @@ -0,0 +1,1201 @@ + + + + + + Music Assistant API - Commands Reference + + + + +
+ +

Commands Reference

+

Complete list of Music Assistant API commands

+
+ + + +
+
Loading commands...
+
+ + + + diff --git a/music_assistant/helpers/resources/common.css b/music_assistant/helpers/resources/common.css new file mode 100644 index 00000000..10260b94 --- /dev/null +++ b/music_assistant/helpers/resources/common.css @@ -0,0 +1,350 @@ +/* Music Assistant - Common Styles + * Shared CSS variables and base styles used across all HTML pages + */ + +/* CSS Variables for theming */ +:root { + --fg: #000000; + --background: #f5f5f5; + --overlay: #e7e7e7; + --panel: #ffffff; + --default: #ffffff; + --primary: #03a9f4; + --text-secondary: rgba(0, 0, 0, 0.6); + --text-tertiary: rgba(0, 0, 0, 0.4); + --border: rgba(0, 0, 0, 0.1); + --input-bg: rgba(0, 0, 0, 0.03); + --input-focus-bg: rgba(3, 169, 244, 0.05); + --primary-glow: rgba(3, 169, 244, 0.15); + --error-bg: rgba(244, 67, 54, 0.08); + --error-border: rgba(244, 67, 54, 0.2); + --error-text: #d32f2f; + --success: #4caf50; + --success-bg: rgba(76, 175, 80, 0.1); + --success-border: rgba(76, 175, 80, 0.3); + --code-bg: #2d2d2d; + --code-fg: #f8f8f2; + --code-comment: #75715e; + --code-string: #e6db74; + --code-keyword: #66d9ef; + --info-bg: rgba(3, 169, 244, 0.08); + --info-border: rgba(3, 169, 244, 0.3); +} + +/* Dark mode color scheme */ +@media (prefers-color-scheme: dark) { + :root { + --fg: #ffffff; + --background: #181818; + --overlay: #181818; + --panel: #232323; + --default: #000000; + --text-secondary: rgba(255, 255, 255, 0.7); + --text-tertiary: rgba(255, 255, 255, 0.4); + --border: rgba(255, 255, 255, 0.08); + --input-bg: rgba(255, 255, 255, 0.05); + --input-focus-bg: rgba(3, 169, 244, 0.08); + --primary-glow: rgba(3, 169, 244, 0.25); + --error-bg: rgba(244, 67, 54, 0.1); + --error-border: rgba(244, 67, 54, 0.25); + --error-text: #ff6b6b; + --success: #66bb6a; + --success-bg: rgba(102, 187, 106, 0.15); + --success-border: rgba(102, 187, 106, 0.4); + --code-bg: #1a1a1a; + --info-bg: rgba(3, 169, 244, 0.12); + --info-border: rgba(3, 169, 244, 0.4); + } +} + +/* Base reset and body styles */ +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + background: var(--background); + color: var(--fg); + line-height: 1.6; +} + +/* Common header styles */ +.header { + background: var(--panel); + color: var(--fg); + padding: 1.5rem 2rem; + text-align: center; + box-shadow: 0 2px 10px rgba(0,0,0,0.1); + border-bottom: 1px solid var(--border); +} + +.header h1 { + font-size: 1.8em; + margin-bottom: 0.3rem; + font-weight: 600; +} + +.header p { + font-size: 0.95em; + opacity: 0.9; +} + +.header .logo { + margin-bottom: 1rem; +} + +.header .logo img { + width: 60px; + height: 60px; + object-fit: contain; +} + +/* Logo styles */ +.logo { + text-align: center; + margin-bottom: 24px; +} + +.logo img { + width: 72px; + height: 72px; + object-fit: contain; +} + +/* Form elements */ +.form-group { + margin-bottom: 22px; +} + +label { + display: block; + color: var(--fg); + font-size: 13px; + font-weight: 500; + margin-bottom: 8px; + letter-spacing: 0.2px; +} + +input[type="text"], +input[type="password"] { + width: 100%; + padding: 14px 16px; + background: var(--input-bg); + border: 1px solid var(--border); + border-radius: 10px; + font-size: 15px; + color: var(--fg); + transition: all 0.2s ease; +} + +input[type="text"]::placeholder, +input[type="password"]::placeholder { + color: var(--text-tertiary); +} + +input[type="text"]:focus, +input[type="password"]:focus { + outline: none; + border-color: var(--primary); + background: var(--input-focus-bg); + box-shadow: 0 0 0 3px var(--primary-glow); +} + +input[type="text"]:disabled { + background: var(--overlay); + color: var(--text-tertiary); + cursor: not-allowed; +} + +/* Button styles */ +.btn { + width: 100%; + padding: 15px; + border: none; + border-radius: 10px; + font-size: 15px; + font-weight: 600; + cursor: pointer; + transition: all 0.2s ease; + letter-spacing: 0.3px; +} + +.btn-primary { + background: var(--primary); + color: white; +} + +.btn-primary:hover { + filter: brightness(1.1); + box-shadow: 0 8px 24px var(--primary-glow); + transform: translateY(-1px); +} + +.btn-primary:active { + transform: translateY(0); + filter: brightness(0.95); +} + +.btn-primary:disabled { + opacity: 0.5; + cursor: not-allowed; + transform: none; + box-shadow: none; + filter: none; +} + +.btn-secondary { + background: var(--input-bg); + color: var(--fg); + border: 1px solid var(--border); + display: flex; + align-items: center; + justify-content: center; + gap: 10px; +} + +.btn-secondary:hover { + background: var(--input-focus-bg); + border-color: var(--primary); +} + +/* Error and success messages */ +.error { + background: var(--error-bg); + border: 1px solid var(--error-border); + color: var(--error-text); + padding: 14px 16px; + border-radius: 10px; + margin-bottom: 22px; + font-size: 13px; + display: none; +} + +.error.show { + display: block; +} + +.error-message { + background: var(--error-bg); + color: var(--error-text); + padding: 14px 16px; + border-radius: 10px; + margin-bottom: 22px; + font-size: 13px; + display: none; + border: 1px solid var(--error-border); +} + +.error-message.show { + display: block; +} + +/* Container and panel styles */ +.container { + max-width: 1200px; + margin: 0 auto; + background: var(--panel); + border-radius: 16px; + box-shadow: 0 4px 24px rgba(0, 0, 0, 0.12), 0 0 0 1px var(--border); +} + +.panel { + background: var(--panel); + border-radius: 16px; + box-shadow: 0 4px 24px rgba(0, 0, 0, 0.12), 0 0 0 1px var(--border); + padding: 48px 40px; +} + +/* Loading indicator */ +.loading { + display: inline-block; + width: 16px; + height: 16px; + border: 2px solid rgba(255, 255, 255, 0.3); + border-radius: 50%; + border-top-color: #fff; + animation: spin 0.6s linear infinite; +} + +@keyframes spin { + to { transform: rotate(360deg); } +} + +/* Code blocks */ +.code-block, +.example { + background: var(--code-bg); + color: var(--code-fg); + padding: 1rem; + border-radius: 8px; + overflow-x: auto; + margin: 1rem 0; + font-family: 'Monaco', 'Courier New', monospace; + font-size: 0.9em; + line-height: 1.6; +} + +.example pre { + margin: 0; +} + +/* Divider */ +.divider { + text-align: center; + margin: 28px 0; + position: relative; +} + +.divider::before { + content: ''; + position: absolute; + top: 50%; + left: 0; + right: 0; + height: 1px; + background: var(--border); +} + +.divider span { + background: var(--panel); + padding: 0 16px; + color: var(--text-tertiary); + font-size: 13px; + position: relative; +} + +/* Link styles */ +.type-link { + color: var(--primary); + text-decoration: none; + border-bottom: 1px dashed var(--primary); + transition: all 0.2s; +} + +.type-link:hover { + opacity: 0.8; + border-bottom-color: transparent; +} + +.back-link { + display: inline-block; + margin-bottom: 1rem; + padding: 0.5rem 1rem; + background: var(--primary); + color: #ffffff; + text-decoration: none; + border-radius: 6px; + transition: background 0.2s; +} + +.back-link:hover { + opacity: 0.9; +} + +/* Utility classes */ +.hidden { + display: none; +} diff --git a/music_assistant/helpers/resources/login.html b/music_assistant/helpers/resources/login.html new file mode 100644 index 00000000..eac11731 --- /dev/null +++ b/music_assistant/helpers/resources/login.html @@ -0,0 +1,291 @@ + + + + + + Login - Music Assistant + + + + + + + + + diff --git a/music_assistant/helpers/resources/logo.png b/music_assistant/helpers/resources/logo.png index d00d8ffd..ac742642 100644 Binary files a/music_assistant/helpers/resources/logo.png and b/music_assistant/helpers/resources/logo.png differ diff --git a/music_assistant/helpers/resources/oauth_callback.html b/music_assistant/helpers/resources/oauth_callback.html new file mode 100644 index 00000000..1bd898d4 --- /dev/null +++ b/music_assistant/helpers/resources/oauth_callback.html @@ -0,0 +1,216 @@ + + + + + + Login Successful + + + + + +
+

Login Successful!

+

Redirecting...

+
+ + + diff --git a/music_assistant/helpers/resources/schemas_reference.html b/music_assistant/helpers/resources/schemas_reference.html new file mode 100644 index 00000000..cfc394f6 --- /dev/null +++ b/music_assistant/helpers/resources/schemas_reference.html @@ -0,0 +1,547 @@ + + + + + + Music Assistant API - Schemas Reference + + + + +
+ +

Schemas Reference

+

Data models and types used in the Music Assistant API

+
+ + + + + + + + diff --git a/music_assistant/helpers/resources/setup.html b/music_assistant/helpers/resources/setup.html new file mode 100644 index 00000000..936808e7 --- /dev/null +++ b/music_assistant/helpers/resources/setup.html @@ -0,0 +1,614 @@ + + + + + + Music Assistant - Setup + + + + +
+ + + +
+
+
+
+
+ + +
+
+

Welcome to Music Assistant!

+

Let's get you started with your personal music server. This setup wizard will guide you through the initial configuration.

+
+ +
+

What you'll set up:

+

Step 1: Create your administrator account

+

Step 2: Complete the setup process

+
+ +
+ +
+
+ + +
+
+

Create Administrator Account

+

Your admin credentials will be used to access the Music Assistant web interface and mobile apps.

+
+ + + +
+ +
+
+ + +
+ +
+ + +
+ Minimum 8 characters recommended +
+
+ +
+ + +
+ +
+ + +
+
+ +
+
+

Creating your account...

+
+
+ + +
+
+ ✓ +
+
+

Setup Complete!

+

Your Music Assistant server has been successfully configured and is ready to use.

+

You can now start adding music providers and connecting your speakers to begin enjoying your music library.

+
+ +
+ +
+
+
+ + + + diff --git a/music_assistant/helpers/resources/swagger_ui.html b/music_assistant/helpers/resources/swagger_ui.html index 872a6a70..8a8d3d2a 100644 --- a/music_assistant/helpers/resources/swagger_ui.html +++ b/music_assistant/helpers/resources/swagger_ui.html @@ -10,19 +10,11 @@ margin: 0; padding: 0; } - .topbar { - display: none; - } - .swagger-ui .info { - margin: 30px 0; - } - .swagger-ui .info .title { - font-size: 2.5em; - }
+