From: Marcel van der Veldt Date: Sun, 30 Nov 2025 21:34:57 +0000 (+0100) Subject: Add some guards to remote connection gateway X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=7600256dba936f1a07bca51b9fa03c737cfcf3b8;p=music-assistant-server.git Add some guards to remote connection gateway --- diff --git a/music_assistant/controllers/webserver/remote_access/gateway.py b/music_assistant/controllers/webserver/remote_access/gateway.py index fe516fef..47b8ecd7 100644 --- a/music_assistant/controllers/webserver/remote_access/gateway.py +++ b/music_assistant/controllers/webserver/remote_access/gateway.py @@ -54,6 +54,8 @@ class WebRTCSession: data_channel: Any = None local_ws: Any = None message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue) + forward_to_local_task: asyncio.Task[None] | None = None + forward_from_local_task: asyncio.Task[None] | None = None class WebRTCGateway: @@ -402,13 +404,19 @@ class WebRTCGateway: try: session.local_ws = await local_session.ws_connect(self.local_ws_url) loop = asyncio.get_event_loop() - asyncio.create_task(self._forward_to_local(session)) - asyncio.create_task(self._forward_from_local(session, local_session)) + + # Store task references for proper cleanup + session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session)) + session.forward_from_local_task = asyncio.create_task( + self._forward_from_local(session, local_session) + ) @channel.on("message") # type: ignore[misc] def on_message(message: str) -> None: # Called from aiortc thread, use call_soon_threadsafe - loop.call_soon_threadsafe(session.message_queue.put_nowait, message) + # Only queue message if session is still active + if session.forward_to_local_task and not session.forward_to_local_task.done(): + loop.call_soon_threadsafe(session.message_queue.put_nowait, message) @channel.on("close") # type: ignore[misc] def on_close() -> None: @@ -441,6 +449,10 @@ class WebRTCGateway: # Regular WebSocket message if session.local_ws and not session.local_ws.closed: await session.local_ws.send_str(message) + except asyncio.CancelledError: + # Task was cancelled during cleanup, this is expected + self.logger.debug("Forward to local task cancelled for session %s", session.session_id) + raise except Exception: self.logger.exception("Error forwarding to local WebSocket") @@ -459,6 +471,12 @@ class WebRTCGateway: session.data_channel.send(msg.data) elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): break + except asyncio.CancelledError: + # Task was cancelled during cleanup, this is expected + self.logger.debug( + "Forward from local task cancelled for session %s", session.session_id + ) + raise except Exception: self.logger.exception("Error forwarding from local WebSocket") finally: @@ -528,6 +546,19 @@ class WebRTCGateway: session = self.sessions.pop(session_id, None) if not session: return + + # Cancel forwarding tasks first to prevent race conditions + if session.forward_to_local_task and not session.forward_to_local_task.done(): + session.forward_to_local_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await session.forward_to_local_task + + if session.forward_from_local_task and not session.forward_from_local_task.done(): + session.forward_from_local_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await session.forward_from_local_task + + # Close connections if session.local_ws and not session.local_ws.closed: await session.local_ws.close() if session.data_channel: