From: OzGav Date: Fri, 7 Nov 2025 08:06:25 +0000 (+1000) Subject: Typing fixes for the Webserver controller (#2586) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=2b6f6a3f2ef0c8ead526f97ec42c79265b11a2a2;p=music-assistant-server.git Typing fixes for the Webserver controller (#2586) * Fix merge conflicts * more mypy fixes * Drafting post upstream changes * fix run_handler * PR review comment * Merge conflict fixes * Fix WebserverController __init__ signature to match CoreController --- diff --git a/music_assistant/controllers/webserver.py b/music_assistant/controllers/webserver.py index 1b2570c6..6c3be81a 100644 --- a/music_assistant/controllers/webserver.py +++ b/music_assistant/controllers/webserver.py @@ -12,6 +12,7 @@ import html import logging import os import urllib.parse +from collections.abc import Awaitable, Callable from concurrent import futures from contextlib import suppress from functools import partial @@ -45,11 +46,11 @@ from music_assistant.helpers.webserver import Webserver from music_assistant.models.core_controller import CoreController if TYPE_CHECKING: - from collections.abc import Awaitable - from music_assistant_models.config_entries import ConfigValueType, CoreConfig from music_assistant_models.event import MassEvent + from music_assistant import MusicAssistant + DEFAULT_SERVER_PORT = 8095 INGRESS_SERVER_PORT = 8094 CONF_BASE_URL = "base_url" @@ -62,9 +63,9 @@ class WebserverController(CoreController): domain: str = "webserver" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, mass: MusicAssistant) -> None: """Initialize instance.""" - super().__init__(*args, **kwargs) + super().__init__(mass) self._server = Webserver(self.logger, enable_dynamic_routes=True) self.register_dynamic_route = self._server.register_dynamic_route self.unregister_dynamic_route = self._server.unregister_dynamic_route @@ -134,7 +135,7 @@ class WebserverController(CoreController): async def setup(self, config: CoreConfig) -> None: """Async initialize of module.""" # work out all routes - routes: list[tuple[str, str, Awaitable]] = [] + routes: list[tuple[str, str, Callable[[web.Request], Awaitable[web.StreamResponse]]]] = [] # frontend routes frontend_dir = locate_frontend() for filename in next(os.walk(frontend_dir))[2]: @@ -182,9 +183,12 @@ class WebserverController(CoreController): else: ingress_tcp_site_params = None base_url = str(config.get_value(CONF_BASE_URL)) - self.publish_port = int(config.get_value(CONF_BIND_PORT)) + port_value = config.get_value(CONF_BIND_PORT) + assert isinstance(port_value, int) + self.publish_port = port_value self.publish_ip = default_publish_ip bind_ip = config.get_value(CONF_BIND_IP) + assert isinstance(bind_ip, str) # print a big fat message in the log where the webserver is running # because this is a common source of issues for people with more complex setups if not self.mass.config.onboard_done: @@ -221,7 +225,7 @@ class WebserverController(CoreController): await client.disconnect() await self._server.close() - async def serve_preview_stream(self, request: web.Request): + async def serve_preview_stream(self, request: web.Request) -> web.StreamResponse: """Serve short preview sample.""" provider_instance_id_or_domain = request.query["provider"] item_id = urllib.parse.unquote(request.query["item_id"]) @@ -254,7 +258,7 @@ class WebserverController(CoreController): try: command_msg = CommandMessage.from_json(cmd_data) except ValueError: - error = f"Invalid JSON: {cmd_data}" + error = f"Invalid JSON: {cmd_data.decode()}" self.logger.error("Unhandled JSONRPC API error: %s", error) return web.Response(status=400, text=error) except MissingField as e: @@ -274,10 +278,11 @@ class WebserverController(CoreController): error = f"Invalid Command: {command_msg.command}" self.logger.error("Unhandled JSONRPC API error: %s", error) return web.Response(status=400, text=error) - try: args = parse_arguments(handler.signature, handler.type_hints, command_msg.args) - result = handler.target(**args) + result: Any = handler.target(**args) + if asyncio.iscoroutine(result): + result = await result if hasattr(result, "__anext__"): # handle async generator (for really large listings) result = [item async for item in result] @@ -330,7 +335,7 @@ class WebserverController(CoreController): html = generate_schemas_reference(self.mass.command_handlers) return web.Response(text=html, content_type="text/html") - async def _handle_swagger_ui(self, request: web.Request) -> web.Response: + async def _handle_swagger_ui(self, request: web.Request) -> web.FileResponse: """Handle request for Swagger UI.""" swagger_html_path = os.path.join( os.path.dirname(__file__), "..", "helpers", "resources", "swagger_ui.html" @@ -346,9 +351,9 @@ class WebsocketClientHandler: self.mass = webserver.mass self.request = request self.wsock = web.WebSocketResponse(heartbeat=55) - self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG) - self._handle_task: asyncio.Task | None = None - self._writer_task: asyncio.Task | None = None + self._to_write: asyncio.Queue[str | None] = asyncio.Queue(maxsize=MAX_PENDING_MSG) + self._handle_task: asyncio.Task[Any] | None = None + self._writer_task: asyncio.Task[None] | None = None self._logger = webserver.logger # try to dynamically detect the base_url of a client if proxied or behind Ingress self.base_url: str | None = None @@ -461,18 +466,18 @@ class WebsocketClientHandler: async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None: try: args = parse_arguments(handler.signature, handler.type_hints, msg.args) - result = handler.target(**args) + result: Any = handler.target(**args) if hasattr(result, "__anext__"): # handle async generator (for really large listings) - iterator = result - result: list[Any] = [] - async for item in iterator: - result.append(item) - if len(result) >= 500: + items: list[Any] = [] + async for item in result: + items.append(item) + if len(items) >= 500: await self._send_message( - SuccessResultMessage(msg.message_id, result, partial=True) + SuccessResultMessage(msg.message_id, items, partial=True) ) - result = [] + items = [] + result = items elif asyncio.iscoroutine(result): result = await result await self._send_message(SuccessResultMessage(msg.message_id, result)) @@ -493,8 +498,10 @@ class WebsocketClientHandler: while not self.wsock.closed: if (process := await self._to_write.get()) is None: break + self._logger.log(VERBOSE_LOG_LEVEL, "Writing: %s", process) + await self.wsock.send_str(process) - if not isinstance(process, str): + if callable(process): message: str = process() else: message = process diff --git a/pyproject.toml b/pyproject.toml index 4056ac55..39bb3bc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,7 +143,6 @@ exclude = [ '^music_assistant/controllers/music.py$', '^music_assistant/controllers/player_queues.py$', '^music_assistant/controllers/streams.py$', - '^music_assistant/controllers/webserver.py', '^music_assistant/helpers/app_vars.py', '^music_assistant/providers/apple_music/.*$', '^music_assistant/providers/bluesound/.*$',