Add some guards to remote connection gateway
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Sun, 30 Nov 2025 21:34:57 +0000 (22:34 +0100)
committerMarcel van der Veldt <m.vanderveldt@outlook.com>
Sun, 30 Nov 2025 21:34:57 +0000 (22:34 +0100)
music_assistant/controllers/webserver/remote_access/gateway.py

index fe516fefb3fd89aaec7d784acd121ac0075d7004..47b8ecd7405076b79df2fdd7d943e67c816e8010 100644 (file)
@@ -54,6 +54,8 @@ class WebRTCSession:
     data_channel: Any = None
     local_ws: Any = None
     message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue)
+    forward_to_local_task: asyncio.Task[None] | None = None
+    forward_from_local_task: asyncio.Task[None] | None = None
 
 
 class WebRTCGateway:
@@ -402,13 +404,19 @@ class WebRTCGateway:
         try:
             session.local_ws = await local_session.ws_connect(self.local_ws_url)
             loop = asyncio.get_event_loop()
-            asyncio.create_task(self._forward_to_local(session))
-            asyncio.create_task(self._forward_from_local(session, local_session))
+
+            # Store task references for proper cleanup
+            session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session))
+            session.forward_from_local_task = asyncio.create_task(
+                self._forward_from_local(session, local_session)
+            )
 
             @channel.on("message")  # type: ignore[misc]
             def on_message(message: str) -> None:
                 # Called from aiortc thread, use call_soon_threadsafe
-                loop.call_soon_threadsafe(session.message_queue.put_nowait, message)
+                # Only queue message if session is still active
+                if session.forward_to_local_task and not session.forward_to_local_task.done():
+                    loop.call_soon_threadsafe(session.message_queue.put_nowait, message)
 
             @channel.on("close")  # type: ignore[misc]
             def on_close() -> None:
@@ -441,6 +449,10 @@ class WebRTCGateway:
                 # Regular WebSocket message
                 if session.local_ws and not session.local_ws.closed:
                     await session.local_ws.send_str(message)
+        except asyncio.CancelledError:
+            # Task was cancelled during cleanup, this is expected
+            self.logger.debug("Forward to local task cancelled for session %s", session.session_id)
+            raise
         except Exception:
             self.logger.exception("Error forwarding to local WebSocket")
 
@@ -459,6 +471,12 @@ class WebRTCGateway:
                         session.data_channel.send(msg.data)
                 elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
                     break
+        except asyncio.CancelledError:
+            # Task was cancelled during cleanup, this is expected
+            self.logger.debug(
+                "Forward from local task cancelled for session %s", session.session_id
+            )
+            raise
         except Exception:
             self.logger.exception("Error forwarding from local WebSocket")
         finally:
@@ -528,6 +546,19 @@ class WebRTCGateway:
         session = self.sessions.pop(session_id, None)
         if not session:
             return
+
+        # Cancel forwarding tasks first to prevent race conditions
+        if session.forward_to_local_task and not session.forward_to_local_task.done():
+            session.forward_to_local_task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await session.forward_to_local_task
+
+        if session.forward_from_local_task and not session.forward_from_local_task.done():
+            session.forward_from_local_task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await session.forward_from_local_task
+
+        # Close connections
         if session.local_ws and not session.local_ws.closed:
             await session.local_ws.close()
         if session.data_channel: