From: OzGav Date: Tue, 30 Sep 2025 15:54:58 +0000 (+1000) Subject: mypy fixes for webserver.py (#2437) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=793be87cb3b4b57e1f857dcc5b7e6076d5b61c32;p=music-assistant-server.git mypy fixes for webserver.py (#2437) --- diff --git a/music_assistant/controllers/webserver.py b/music_assistant/controllers/webserver.py index e551a4bd..54a52f0d 100644 --- a/music_assistant/controllers/webserver.py +++ b/music_assistant/controllers/webserver.py @@ -37,11 +37,12 @@ from music_assistant.helpers.webserver import Webserver from music_assistant.models.core_controller import CoreController if TYPE_CHECKING: - from collections.abc import Awaitable - + from aiohttp.typedefs import Handler from music_assistant_models.config_entries import ConfigValueType, CoreConfig from music_assistant_models.event import MassEvent + from music_assistant.mass import MusicAssistant + DEFAULT_SERVER_PORT = 8095 INGRESS_SERVER_PORT = 8094 CONF_BASE_URL = "base_url" @@ -53,10 +54,11 @@ class WebserverController(CoreController): """Core Controller that manages the builtin webserver that hosts the api and frontend.""" domain: str = "webserver" + _server: Webserver - def __init__(self, *args, **kwargs) -> None: + def __init__(self, mass: MusicAssistant) -> None: """Initialize instance.""" - super().__init__(*args, **kwargs) + super().__init__(mass) self._server = Webserver(self.logger, enable_dynamic_routes=True) self.register_dynamic_route = self._server.register_dynamic_route self.unregister_dynamic_route = self._server.unregister_dynamic_route @@ -126,7 +128,7 @@ class WebserverController(CoreController): async def setup(self, config: CoreConfig) -> None: """Async initialize of module.""" # work out all routes - routes: list[tuple[str, str, Awaitable]] = [] + routes: list[tuple[str, str, Handler]] = [] # frontend routes frontend_dir = locate_frontend() for filename in next(os.walk(frontend_dir))[2]: @@ -164,9 +166,12 @@ class WebserverController(CoreController): else: ingress_tcp_site_params = None base_url = str(config.get_value(CONF_BASE_URL)) - self.publish_port = int(config.get_value(CONF_BIND_PORT)) + port_value = config.get_value(CONF_BIND_PORT) + assert isinstance(port_value, int) + self.publish_port = port_value self.publish_ip = default_publish_ip bind_ip = config.get_value(CONF_BIND_IP) + assert isinstance(bind_ip, str) # print a big fat message in the log where the webserver is running # because this is a common source of issues for people with more complex setups if not self.mass.config.onboard_done: @@ -203,7 +208,7 @@ class WebserverController(CoreController): await client.disconnect() await self._server.close() - async def serve_preview_stream(self, request: web.Request): + async def serve_preview_stream(self, request: web.Request) -> web.StreamResponse: """Serve short preview sample.""" provider_instance_id_or_domain = request.query["provider"] item_id = urllib.parse.unquote(request.query["item_id"]) @@ -236,7 +241,7 @@ class WebserverController(CoreController): try: command_msg = CommandMessage.from_json(cmd_data) except ValueError: - error = f"Invalid JSON: {cmd_data}" + error = f"Invalid JSON: {cmd_data!r}" self.logger.error("Unhandled JSONRPC API error: %s", error) return web.Response(status=400, text=error) @@ -247,7 +252,7 @@ class WebserverController(CoreController): self.logger.error("Unhandled JSONRPC API error: %s", error) return web.Response(status=400, text=error) args = parse_arguments(handler.signature, handler.type_hints, command_msg.args) - result = handler.target(**args) + result: Any = handler.target(**args) if hasattr(result, "__anext__"): # handle async generator (for really large listings) result = [item async for item in result] @@ -269,9 +274,9 @@ class WebsocketClientHandler: self.mass = webserver.mass self.request = request self.wsock = web.WebSocketResponse(heartbeat=55) - self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG) - self._handle_task: asyncio.Task | None = None - self._writer_task: asyncio.Task | None = None + self._to_write: asyncio.Queue[str | None] = asyncio.Queue(maxsize=MAX_PENDING_MSG) + self._handle_task: asyncio.Task[Any] | None = None + self._writer_task: asyncio.Task[None] | None = None self._logger = webserver.logger # try to dynamically detect the base_url of a client if proxied or behind Ingress self.base_url: str | None = None @@ -370,7 +375,7 @@ class WebsocketClientHandler: if handler is None: self._send_message( ErrorResultMessage( - msg.message_id, + str(msg.message_id), InvalidCommand.error_code, f"Invalid command: {msg.command}", ) @@ -384,21 +389,22 @@ class WebsocketClientHandler: async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None: try: args = parse_arguments(handler.signature, handler.type_hints, msg.args) - result = handler.target(**args) + result: Any = handler.target(**args) if hasattr(result, "__anext__"): # handle async generator (for really large listings) iterator = result - result: list[Any] = [] + items: list[Any] = [] async for item in iterator: - result.append(item) - if len(result) >= 500: + items.append(item) + if len(items) >= 500: self._send_message( - SuccessResultMessage(msg.message_id, result, partial=True) + SuccessResultMessage(str(msg.message_id), items, partial=True) ) - result = [] + items = [] + result = items elif asyncio.iscoroutine(result): result = await result - self._send_message(SuccessResultMessage(msg.message_id, result)) + self._send_message(SuccessResultMessage(str(msg.message_id), result)) except Exception as err: if self._logger.isEnabledFor(logging.DEBUG): self._logger.exception("Error handling message: %s", msg) @@ -406,7 +412,7 @@ class WebsocketClientHandler: self._logger.error("Error handling message: %s: %s", msg.command, str(err)) err_msg = str(err) or err.__class__.__name__ self._send_message( - ErrorResultMessage(msg.message_id, getattr(err, "error_code", 999), err_msg) + ErrorResultMessage(str(msg.message_id), getattr(err, "error_code", 999), err_msg) ) async def _writer(self) -> None: @@ -417,7 +423,7 @@ class WebsocketClientHandler: if (process := await self._to_write.get()) is None: break - if not isinstance(process, str): + if callable(process): message: str = process() else: message = process diff --git a/pyproject.toml b/pyproject.toml index 420327e8..f85cc4fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,7 +130,15 @@ enable_error_code = [ "truthy-iterable", ] exclude = [ - '^music_assistant/controllers/.*$', + '^music_assistant/controllers/__init__.py$', + '^music_assistant/controllers/cache.py$', + '^music_assistant/controllers/config.py$', + '^music_assistant/controllers/media/.*$', + '^music_assistant/controllers/metadata.py$', + '^music_assistant/controllers/music.py$', + '^music_assistant/controllers/player_queues.py$', + '^music_assistant/controllers/players.py$', + '^music_assistant/controllers/streams.py$', '^music_assistant/helpers/app_vars.py', '^music_assistant/models/player_provider.py', '^music_assistant/providers/apple_music/.*$',