Typing fixes for the Webserver controller (#2586)
authorOzGav <gavnosp@hotmail.com>
Fri, 7 Nov 2025 08:06:25 +0000 (18:06 +1000)
committerGitHub <noreply@github.com>
Fri, 7 Nov 2025 08:06:25 +0000 (09:06 +0100)
* Fix merge conflicts

* more mypy fixes

* Drafting post upstream changes

* fix run_handler

* PR review comment

* Merge conflict fixes

* Fix WebserverController __init__ signature to match CoreController

music_assistant/controllers/webserver.py
pyproject.toml

index 1b2570c6ad36202af4c2cc59e11b4080dfa0e0ca..6c3be81af6c5320e70d66524bf5d4414cdd723e6 100644 (file)
@@ -12,6 +12,7 @@ import html
 import logging
 import os
 import urllib.parse
+from collections.abc import Awaitable, Callable
 from concurrent import futures
 from contextlib import suppress
 from functools import partial
@@ -45,11 +46,11 @@ from music_assistant.helpers.webserver import Webserver
 from music_assistant.models.core_controller import CoreController
 
 if TYPE_CHECKING:
-    from collections.abc import Awaitable
-
     from music_assistant_models.config_entries import ConfigValueType, CoreConfig
     from music_assistant_models.event import MassEvent
 
+    from music_assistant import MusicAssistant
+
 DEFAULT_SERVER_PORT = 8095
 INGRESS_SERVER_PORT = 8094
 CONF_BASE_URL = "base_url"
@@ -62,9 +63,9 @@ class WebserverController(CoreController):
 
     domain: str = "webserver"
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, mass: MusicAssistant) -> None:
         """Initialize instance."""
-        super().__init__(*args, **kwargs)
+        super().__init__(mass)
         self._server = Webserver(self.logger, enable_dynamic_routes=True)
         self.register_dynamic_route = self._server.register_dynamic_route
         self.unregister_dynamic_route = self._server.unregister_dynamic_route
@@ -134,7 +135,7 @@ class WebserverController(CoreController):
     async def setup(self, config: CoreConfig) -> None:
         """Async initialize of module."""
         # work out all routes
-        routes: list[tuple[str, str, Awaitable]] = []
+        routes: list[tuple[str, str, Callable[[web.Request], Awaitable[web.StreamResponse]]]] = []
         # frontend routes
         frontend_dir = locate_frontend()
         for filename in next(os.walk(frontend_dir))[2]:
@@ -182,9 +183,12 @@ class WebserverController(CoreController):
         else:
             ingress_tcp_site_params = None
         base_url = str(config.get_value(CONF_BASE_URL))
-        self.publish_port = int(config.get_value(CONF_BIND_PORT))
+        port_value = config.get_value(CONF_BIND_PORT)
+        assert isinstance(port_value, int)
+        self.publish_port = port_value
         self.publish_ip = default_publish_ip
         bind_ip = config.get_value(CONF_BIND_IP)
+        assert isinstance(bind_ip, str)
         # print a big fat message in the log where the webserver is running
         # because this is a common source of issues for people with more complex setups
         if not self.mass.config.onboard_done:
@@ -221,7 +225,7 @@ class WebserverController(CoreController):
             await client.disconnect()
         await self._server.close()
 
-    async def serve_preview_stream(self, request: web.Request):
+    async def serve_preview_stream(self, request: web.Request) -> web.StreamResponse:
         """Serve short preview sample."""
         provider_instance_id_or_domain = request.query["provider"]
         item_id = urllib.parse.unquote(request.query["item_id"])
@@ -254,7 +258,7 @@ class WebserverController(CoreController):
         try:
             command_msg = CommandMessage.from_json(cmd_data)
         except ValueError:
-            error = f"Invalid JSON: {cmd_data}"
+            error = f"Invalid JSON: {cmd_data.decode()}"
             self.logger.error("Unhandled JSONRPC API error: %s", error)
             return web.Response(status=400, text=error)
         except MissingField as e:
@@ -274,10 +278,11 @@ class WebserverController(CoreController):
             error = f"Invalid Command: {command_msg.command}"
             self.logger.error("Unhandled JSONRPC API error: %s", error)
             return web.Response(status=400, text=error)
-
         try:
             args = parse_arguments(handler.signature, handler.type_hints, command_msg.args)
-            result = handler.target(**args)
+            result: Any = handler.target(**args)
+            if asyncio.iscoroutine(result):
+                result = await result
             if hasattr(result, "__anext__"):
                 # handle async generator (for really large listings)
                 result = [item async for item in result]
@@ -330,7 +335,7 @@ class WebserverController(CoreController):
         html = generate_schemas_reference(self.mass.command_handlers)
         return web.Response(text=html, content_type="text/html")
 
-    async def _handle_swagger_ui(self, request: web.Request) -> web.Response:
+    async def _handle_swagger_ui(self, request: web.Request) -> web.FileResponse:
         """Handle request for Swagger UI."""
         swagger_html_path = os.path.join(
             os.path.dirname(__file__), "..", "helpers", "resources", "swagger_ui.html"
@@ -346,9 +351,9 @@ class WebsocketClientHandler:
         self.mass = webserver.mass
         self.request = request
         self.wsock = web.WebSocketResponse(heartbeat=55)
-        self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
-        self._handle_task: asyncio.Task | None = None
-        self._writer_task: asyncio.Task | None = None
+        self._to_write: asyncio.Queue[str | None] = asyncio.Queue(maxsize=MAX_PENDING_MSG)
+        self._handle_task: asyncio.Task[Any] | None = None
+        self._writer_task: asyncio.Task[None] | None = None
         self._logger = webserver.logger
         # try to dynamically detect the base_url of a client if proxied or behind Ingress
         self.base_url: str | None = None
@@ -461,18 +466,18 @@ class WebsocketClientHandler:
     async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None:
         try:
             args = parse_arguments(handler.signature, handler.type_hints, msg.args)
-            result = handler.target(**args)
+            result: Any = handler.target(**args)
             if hasattr(result, "__anext__"):
                 # handle async generator (for really large listings)
-                iterator = result
-                result: list[Any] = []
-                async for item in iterator:
-                    result.append(item)
-                    if len(result) >= 500:
+                items: list[Any] = []
+                async for item in result:
+                    items.append(item)
+                    if len(items) >= 500:
                         await self._send_message(
-                            SuccessResultMessage(msg.message_id, result, partial=True)
+                            SuccessResultMessage(msg.message_id, items, partial=True)
                         )
-                        result = []
+                        items = []
+                result = items
             elif asyncio.iscoroutine(result):
                 result = await result
             await self._send_message(SuccessResultMessage(msg.message_id, result))
@@ -493,8 +498,10 @@ class WebsocketClientHandler:
             while not self.wsock.closed:
                 if (process := await self._to_write.get()) is None:
                     break
+                self._logger.log(VERBOSE_LOG_LEVEL, "Writing: %s", process)
+                await self.wsock.send_str(process)
 
-                if not isinstance(process, str):
+                if callable(process):
                     message: str = process()
                 else:
                     message = process
index 4056ac55dc601ac212b837ec231fdcdd173a24c9..39bb3bc283ced61a1026f77590920759f470f132 100644 (file)
@@ -143,7 +143,6 @@ exclude = [
   '^music_assistant/controllers/music.py$',
   '^music_assistant/controllers/player_queues.py$',
   '^music_assistant/controllers/streams.py$',
-  '^music_assistant/controllers/webserver.py',
   '^music_assistant/helpers/app_vars.py',
   '^music_assistant/providers/apple_music/.*$',
   '^music_assistant/providers/bluesound/.*$',