mypy fixes for webserver.py (#2437)
authorOzGav <gavnosp@hotmail.com>
Tue, 30 Sep 2025 15:54:58 +0000 (01:54 +1000)
committerGitHub <noreply@github.com>
Tue, 30 Sep 2025 15:54:58 +0000 (17:54 +0200)
music_assistant/controllers/webserver.py
pyproject.toml

index e551a4bdda9ff2d1c4dea1c6c98c910b0c5aaf7c..54a52f0dcda59ffadfd0cead4d86451483d5a7fe 100644 (file)
@@ -37,11 +37,12 @@ from music_assistant.helpers.webserver import Webserver
 from music_assistant.models.core_controller import CoreController
 
 if TYPE_CHECKING:
-    from collections.abc import Awaitable
-
+    from aiohttp.typedefs import Handler
     from music_assistant_models.config_entries import ConfigValueType, CoreConfig
     from music_assistant_models.event import MassEvent
 
+    from music_assistant.mass import MusicAssistant
+
 DEFAULT_SERVER_PORT = 8095
 INGRESS_SERVER_PORT = 8094
 CONF_BASE_URL = "base_url"
@@ -53,10 +54,11 @@ class WebserverController(CoreController):
     """Core Controller that manages the builtin webserver that hosts the api and frontend."""
 
     domain: str = "webserver"
+    _server: Webserver
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, mass: MusicAssistant) -> None:
         """Initialize instance."""
-        super().__init__(*args, **kwargs)
+        super().__init__(mass)
         self._server = Webserver(self.logger, enable_dynamic_routes=True)
         self.register_dynamic_route = self._server.register_dynamic_route
         self.unregister_dynamic_route = self._server.unregister_dynamic_route
@@ -126,7 +128,7 @@ class WebserverController(CoreController):
     async def setup(self, config: CoreConfig) -> None:
         """Async initialize of module."""
         # work out all routes
-        routes: list[tuple[str, str, Awaitable]] = []
+        routes: list[tuple[str, str, Handler]] = []
         # frontend routes
         frontend_dir = locate_frontend()
         for filename in next(os.walk(frontend_dir))[2]:
@@ -164,9 +166,12 @@ class WebserverController(CoreController):
         else:
             ingress_tcp_site_params = None
         base_url = str(config.get_value(CONF_BASE_URL))
-        self.publish_port = int(config.get_value(CONF_BIND_PORT))
+        port_value = config.get_value(CONF_BIND_PORT)
+        assert isinstance(port_value, int)
+        self.publish_port = port_value
         self.publish_ip = default_publish_ip
         bind_ip = config.get_value(CONF_BIND_IP)
+        assert isinstance(bind_ip, str)
         # print a big fat message in the log where the webserver is running
         # because this is a common source of issues for people with more complex setups
         if not self.mass.config.onboard_done:
@@ -203,7 +208,7 @@ class WebserverController(CoreController):
             await client.disconnect()
         await self._server.close()
 
-    async def serve_preview_stream(self, request: web.Request):
+    async def serve_preview_stream(self, request: web.Request) -> web.StreamResponse:
         """Serve short preview sample."""
         provider_instance_id_or_domain = request.query["provider"]
         item_id = urllib.parse.unquote(request.query["item_id"])
@@ -236,7 +241,7 @@ class WebserverController(CoreController):
         try:
             command_msg = CommandMessage.from_json(cmd_data)
         except ValueError:
-            error = f"Invalid JSON: {cmd_data}"
+            error = f"Invalid JSON: {cmd_data!r}"
             self.logger.error("Unhandled JSONRPC API error: %s", error)
             return web.Response(status=400, text=error)
 
@@ -247,7 +252,7 @@ class WebserverController(CoreController):
             self.logger.error("Unhandled JSONRPC API error: %s", error)
             return web.Response(status=400, text=error)
         args = parse_arguments(handler.signature, handler.type_hints, command_msg.args)
-        result = handler.target(**args)
+        result: Any = handler.target(**args)
         if hasattr(result, "__anext__"):
             # handle async generator (for really large listings)
             result = [item async for item in result]
@@ -269,9 +274,9 @@ class WebsocketClientHandler:
         self.mass = webserver.mass
         self.request = request
         self.wsock = web.WebSocketResponse(heartbeat=55)
-        self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
-        self._handle_task: asyncio.Task | None = None
-        self._writer_task: asyncio.Task | None = None
+        self._to_write: asyncio.Queue[str | None] = asyncio.Queue(maxsize=MAX_PENDING_MSG)
+        self._handle_task: asyncio.Task[Any] | None = None
+        self._writer_task: asyncio.Task[None] | None = None
         self._logger = webserver.logger
         # try to dynamically detect the base_url of a client if proxied or behind Ingress
         self.base_url: str | None = None
@@ -370,7 +375,7 @@ class WebsocketClientHandler:
         if handler is None:
             self._send_message(
                 ErrorResultMessage(
-                    msg.message_id,
+                    str(msg.message_id),
                     InvalidCommand.error_code,
                     f"Invalid command: {msg.command}",
                 )
@@ -384,21 +389,22 @@ class WebsocketClientHandler:
     async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None:
         try:
             args = parse_arguments(handler.signature, handler.type_hints, msg.args)
-            result = handler.target(**args)
+            result: Any = handler.target(**args)
             if hasattr(result, "__anext__"):
                 # handle async generator (for really large listings)
                 iterator = result
-                result: list[Any] = []
+                items: list[Any] = []
                 async for item in iterator:
-                    result.append(item)
-                    if len(result) >= 500:
+                    items.append(item)
+                    if len(items) >= 500:
                         self._send_message(
-                            SuccessResultMessage(msg.message_id, result, partial=True)
+                            SuccessResultMessage(str(msg.message_id), items, partial=True)
                         )
-                        result = []
+                        items = []
+                result = items
             elif asyncio.iscoroutine(result):
                 result = await result
-            self._send_message(SuccessResultMessage(msg.message_id, result))
+            self._send_message(SuccessResultMessage(str(msg.message_id), result))
         except Exception as err:
             if self._logger.isEnabledFor(logging.DEBUG):
                 self._logger.exception("Error handling message: %s", msg)
@@ -406,7 +412,7 @@ class WebsocketClientHandler:
                 self._logger.error("Error handling message: %s: %s", msg.command, str(err))
             err_msg = str(err) or err.__class__.__name__
             self._send_message(
-                ErrorResultMessage(msg.message_id, getattr(err, "error_code", 999), err_msg)
+                ErrorResultMessage(str(msg.message_id), getattr(err, "error_code", 999), err_msg)
             )
 
     async def _writer(self) -> None:
@@ -417,7 +423,7 @@ class WebsocketClientHandler:
                 if (process := await self._to_write.get()) is None:
                     break
 
-                if not isinstance(process, str):
+                if callable(process):
                     message: str = process()
                 else:
                     message = process
index 420327e801e2a5eda4d19a6ce320dbe0548d6912..f85cc4fc1ba5a2dab15903da5b8da5696e8a6f49 100644 (file)
@@ -130,7 +130,15 @@ enable_error_code = [
   "truthy-iterable",
 ]
 exclude = [
-  '^music_assistant/controllers/.*$',
+  '^music_assistant/controllers/__init__.py$',
+  '^music_assistant/controllers/cache.py$',
+  '^music_assistant/controllers/config.py$',
+  '^music_assistant/controllers/media/.*$',
+  '^music_assistant/controllers/metadata.py$',
+  '^music_assistant/controllers/music.py$',
+  '^music_assistant/controllers/player_queues.py$',
+  '^music_assistant/controllers/players.py$',
+  '^music_assistant/controllers/streams.py$',
   '^music_assistant/helpers/app_vars.py',
   '^music_assistant/models/player_provider.py',
   '^music_assistant/providers/apple_music/.*$',