data_channel: Any = None
local_ws: Any = None
message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue)
+ forward_to_local_task: asyncio.Task[None] | None = None
+ forward_from_local_task: asyncio.Task[None] | None = None
class WebRTCGateway:
try:
session.local_ws = await local_session.ws_connect(self.local_ws_url)
loop = asyncio.get_event_loop()
- asyncio.create_task(self._forward_to_local(session))
- asyncio.create_task(self._forward_from_local(session, local_session))
+
+ # Store task references for proper cleanup
+ session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session))
+ session.forward_from_local_task = asyncio.create_task(
+ self._forward_from_local(session, local_session)
+ )
@channel.on("message") # type: ignore[misc]
def on_message(message: str) -> None:
# Called from aiortc thread, use call_soon_threadsafe
- loop.call_soon_threadsafe(session.message_queue.put_nowait, message)
+ # Only queue message if session is still active
+ if session.forward_to_local_task and not session.forward_to_local_task.done():
+ loop.call_soon_threadsafe(session.message_queue.put_nowait, message)
@channel.on("close") # type: ignore[misc]
def on_close() -> None:
# Regular WebSocket message
if session.local_ws and not session.local_ws.closed:
await session.local_ws.send_str(message)
+ except asyncio.CancelledError:
+ # Task was cancelled during cleanup, this is expected
+ self.logger.debug("Forward to local task cancelled for session %s", session.session_id)
+ raise
except Exception:
self.logger.exception("Error forwarding to local WebSocket")
session.data_channel.send(msg.data)
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
break
+ except asyncio.CancelledError:
+ # Task was cancelled during cleanup, this is expected
+ self.logger.debug(
+ "Forward from local task cancelled for session %s", session.session_id
+ )
+ raise
except Exception:
self.logger.exception("Error forwarding from local WebSocket")
finally:
session = self.sessions.pop(session_id, None)
if not session:
return
+
+ # Cancel forwarding tasks first to prevent race conditions
+ if session.forward_to_local_task and not session.forward_to_local_task.done():
+ session.forward_to_local_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await session.forward_to_local_task
+
+ if session.forward_from_local_task and not session.forward_from_local_task.done():
+ session.forward_from_local_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await session.forward_from_local_task
+
+ # Close connections
if session.local_ws and not session.local_ws.closed:
await session.local_ws.close()
if session.data_channel: