better handling of multi client queue stream
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Wed, 6 Apr 2022 11:24:47 +0000 (13:24 +0200)
committerMarcel van der Veldt <m.vanderveldt@outlook.com>
Wed, 6 Apr 2022 11:24:47 +0000 (13:24 +0200)
music_assistant/controllers/stream.py

index 71428644465a3cb767f53ca789a5a555ab754988..29b6715122b4db1f8bd8bf618955b87bb6914a55 100644 (file)
@@ -65,9 +65,6 @@ class StreamController:
 
         async def on_shutdown_event(*args, **kwargs):
             """Handle shutdown event."""
-            for subscribers in self._subscribers.values():
-                for callback in subscribers.values():
-                    await callback(b"")
             for task in self._stream_tasks.values():
                 task.cancel()
             await http_site.stop()
@@ -180,13 +177,13 @@ class StreamController:
                 except BrokenPipeError:
                     pass  # race condition
 
-            await self.subscribe(queue_id, client_id, audio_callback)
+            await self.subscribe_client(queue_id, client_id, audio_callback)
             await last_chunk_received.wait()
         finally:
-            await self.unsubscribe(queue_id, client_id)
+            await self.unsubscribe_client(queue_id, client_id)
         return resp
 
-    async def subscribe(
+    async def subscribe_client(
         self, queue_id: str, client_id: str, callback: Awaitable
     ) -> None:
         """Subscribe client to queue stream."""
@@ -204,33 +201,19 @@ class StreamController:
         assert expected_clients, "No clients expected for this stream"
 
         stream_task = self._stream_tasks.get(queue_id)
-        if stream_task is not None:
-            # a new client connected while we're already streaming, tell the queue to restart
-            stream_task.cancel()
-            await queue.resume()
-            return
         # we start the stream as soon as we've reached the expected number of clients
         # TODO: add timeout guard just in case we don't reach the number of expected client
         if stream_task is None and len(self._subscribers[queue_id]) >= expected_clients:
             # start the stream task
-            self._stream_tasks[queue_id] = task = asyncio.create_task(
-                self.start_multi_queue_stream(queue_id)
-            )
-            self.logger.debug("Multi client queue stream %s started", queue.queue_id)
-
-            def task_done_callback(*args, **kwargs):
-                self._stream_tasks.pop(queue_id, None)
-
-            task.add_done_callback(task_done_callback)
+            await self.start_multi_queue_stream(queue_id)
 
         self.logger.debug(
             "Subscribed client %s to multi queue stream %s",
             client_id,
             queue.queue_id,
         )
-        return client_id
 
-    async def unsubscribe(self, queue_id: str, clientid: str):
+    async def unsubscribe_client(self, queue_id: str, clientid: str):
         """Unsubscribe client from queue stream."""
         self._subscribers[queue_id].pop(clientid, None)
         self.logger.debug(
@@ -247,17 +230,40 @@ class StreamController:
     async def start_multi_queue_stream(self, queue_id: str) -> None:
         """Start the Queue stream feeding callbacks of listeners.."""
         queue = self.mass.players.get_player_queue(queue_id)
-        async for chunk in self._get_queue_stream(queue, 44100, 16, 2, resample=True):
-            if len(self._subscribers[queue_id].values()) == 0:
-                # just in case of race conditions
-                return
-            await asyncio.gather(
-                *[cb(chunk) for cb in list(self._subscribers[queue_id].values())]
-            )
-        # send empty chunk to inform EOF
-        await asyncio.gather(
-            *[cb(b"") for cb in list(self._subscribers[queue_id].values())]
-        )
+        assert queue_id not in self._stream_tasks, "already running!"
+
+        async def queue_task():
+            self.logger.debug("Multi client queue stream %s started", queue.queue_id)
+            try:
+                async for chunk in self._get_queue_stream(
+                    queue, 44100, 16, 2, resample=True
+                ):
+                    if len(self._subscribers[queue_id].values()) == 0:
+                        # just in case of race conditions
+                        return
+                    await asyncio.gather(
+                        *[
+                            cb(chunk)
+                            for cb in list(self._subscribers[queue_id].values())
+                        ]
+                    )
+            finally:
+                self._stream_tasks.pop(queue_id, None)
+                # send empty chunk to inform EOF
+                await asyncio.gather(
+                    *[cb(b"") for cb in list(self._subscribers[queue_id].values())]
+                )
+                self.logger.debug("Multi client queue stream %s ended", queue.queue_id)
+
+        self._stream_tasks[queue_id] = asyncio.create_task(queue_task())
+
+    async def stop_multi_queue_stream(self, queue_id: str) -> None:
+        """Signal a running queue stream task and its listeners to stop."""
+        if queue_id not in self._stream_tasks:
+            return
+
+        if stream_task := self._stream_tasks.pop(queue_id, None):
+            stream_task.cancel()
 
     async def _get_queue_stream(
         self,