Use TaskManager instead of TaskGroup (#1456)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Sat, 6 Jul 2024 10:20:11 +0000 (12:20 +0200)
committerGitHub <noreply@github.com>
Sat, 6 Jul 2024 10:20:11 +0000 (12:20 +0200)
for operations where we don't want to stop other tasks when one of the tasks fails

music_assistant/server/controllers/players.py
music_assistant/server/helpers/multi_client_stream.py
music_assistant/server/helpers/util.py
music_assistant/server/providers/airplay/__init__.py
music_assistant/server/providers/dlna/__init__.py
music_assistant/server/providers/slimproto/__init__.py
music_assistant/server/providers/sonos/__init__.py
music_assistant/server/providers/ugp/__init__.py
music_assistant/server/server.py

index 1e856917db1efca1edde533e062418a5c36fd769..e119fee96b548858ad9c1da5221581b7bff0e024 100644 (file)
@@ -44,6 +44,7 @@ from music_assistant.constants import (
 )
 from music_assistant.server.helpers.api import api_command
 from music_assistant.server.helpers.tags import parse_tags
+from music_assistant.server.helpers.util import TaskManager
 from music_assistant.server.models.core_controller import CoreController
 from music_assistant.server.models.player_provider import PlayerProvider
 
@@ -399,8 +400,8 @@ class PlayerController(CoreController):
         if not power and group_player.state in (PlayerState.PLAYING, PlayerState.PAUSED):
             await self.cmd_stop(player_id)
 
-        async with asyncio.TaskGroup() as tg:
-            any_member_powered = False
+        any_member_powered = False
+        async with TaskManager(self.mass) as tg:
             for member in self.iter_group_members(group_player, only_powered=True):
                 any_member_powered = True
                 if power:
@@ -502,7 +503,7 @@ class PlayerController(CoreController):
                         "while one or more individual players are playing. "
                         "This announcement will be redirected to the individual players."
                     )
-                    async with asyncio.TaskGroup() as tg:
+                    async with TaskManager(self.mass) as tg:
                         for group_member in player.group_childs:
                             tg.create_task(
                                 self.play_announcement(
@@ -1147,7 +1148,7 @@ class PlayerController(CoreController):
         # adjust volume if needed
         # in case of a (sync) group, we need to do this for all child players
         prev_volumes: dict[str, int] = {}
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for volume_player_id in player.group_childs or (player.player_id,):
                 if not (volume_player := self.get(volume_player_id)):
                     continue
@@ -1194,7 +1195,7 @@ class PlayerController(CoreController):
             "Announcement to player %s - restore previous state...", player.display_name
         )
         # restore volume
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for volume_player_id, prev_volume in prev_volumes.items():
                 tg.create_task(self.cmd_volume_set(volume_player_id, prev_volume))
 
index 542b879ab03142773e920d0c4ee36dfbe1cd65ac..c0a9981bb639b705b3f791d70a7a5688f3a9c77a 100644 (file)
@@ -8,6 +8,7 @@ from contextlib import suppress
 from music_assistant.common.helpers.util import empty_queue
 from music_assistant.common.models.media_items import AudioFormat
 from music_assistant.server.helpers.audio import get_ffmpeg_stream
+from music_assistant.server.helpers.util import TaskManager
 
 LOGGER = logging.getLogger(__name__)
 
@@ -93,6 +94,6 @@ class MultiClientStream:
                 *[sub.put(chunk) for sub in self.subscribers], return_exceptions=True
             )
         # EOF: send empty chunk
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for sub in list(self.subscribers):
                 tg.create_task(sub.put(b""))
index 352af3783da079eda2c82c6334bd9b83a862d90e..08f6d79b5cecaa80b9c4ad5544e3e7b8e956d453 100644 (file)
@@ -10,10 +10,12 @@ import tempfile
 import urllib.error
 import urllib.parse
 import urllib.request
+from collections.abc import Coroutine
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError
 from importlib.metadata import version as pkg_version
-from typing import TYPE_CHECKING
+from types import TracebackType
+from typing import TYPE_CHECKING, Self
 
 import ifaddr
 import memory_tempfile
@@ -23,6 +25,7 @@ from music_assistant.server.helpers.process import check_output
 if TYPE_CHECKING:
     from collections.abc import Iterator
 
+    from music_assistant.server import MusicAssistant
     from music_assistant.server.models import ProviderModuleType
 
 LOGGER = logging.getLogger(__name__)
@@ -130,3 +133,37 @@ def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
     """Chunk bytes data into smaller chunks."""
     for i in range(0, len(data), chunk_size):
         yield data[i : i + chunk_size]
+
+
+class TaskManager:
+    """
+    Helper class to run many tasks at once.
+
+    This is basically an alternative to asyncio.TaskGroup but this will not
+    cancel all operations when one of the tasks fails.
+    Logging of exceptions is done by the mass.create_task helper.
+    """
+
+    def __init__(self, mass: MusicAssistant):
+        """Initialize the TaskManager."""
+        self.mass = mass
+        self._tasks: list[asyncio.Task] = []
+
+    def create_task(self, coro: Coroutine) -> None:
+        """Create a new task and add it to the manager."""
+        task = self.mass.create_task(coro)
+        self._tasks.append(task)
+
+    async def __aenter__(self) -> Self:
+        """Enter context manager."""
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> bool | None:
+        """Exit context manager."""
+        await asyncio.wait(self._tasks)
+        self._tasks.clear()
index f5de17fdc371da91470b583fa10b801bbbd54506..0cc5d43a446fd3efa0e300b5d52c2ec87d28e973 100644 (file)
@@ -46,6 +46,7 @@ from music_assistant.common.models.player_queue import PlayerQueue
 from music_assistant.constants import CONF_SYNC_ADJUST, VERBOSE_LOG_LEVEL
 from music_assistant.server.helpers.audio import FFMpeg, get_ffmpeg_stream, get_player_filter_params
 from music_assistant.server.helpers.process import AsyncProcess, check_output
+from music_assistant.server.helpers.util import TaskManager
 from music_assistant.server.models.player_provider import PlayerProvider
 
 if TYPE_CHECKING:
@@ -590,7 +591,7 @@ class AirplayProvider(PlayerProvider):
         - player_id: player_id of the player to handle the command.
         """
         # forward command to player and any connected sync members
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for airplay_player in self._get_sync_clients(player_id):
                 if airplay_player.active_stream:
                     tg.create_task(airplay_player.active_stream.stop())
@@ -604,7 +605,7 @@ class AirplayProvider(PlayerProvider):
         - player_id: player_id of the player to handle the command.
         """
         # forward command to player and any connected sync members
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for airplay_player in self._get_sync_clients(player_id):
                 if airplay_player.active_stream and airplay_player.active_stream.running:
                     # prefer interactive command to our streamer
@@ -639,7 +640,7 @@ class AirplayProvider(PlayerProvider):
             # should not happen, but just in case
             raise RuntimeError("Player is synced")
         # always stop existing stream first
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for airplay_player in self._get_sync_clients(player_id):
                 if airplay_player.active_stream and airplay_player.active_stream:
                     tg.create_task(airplay_player.active_stream.stop())
index 117bb7c640fd0a358b83508b6cb21b3fd5284962..aa9be58ceac3d14cdac975e2563c598e5f4d2811 100644 (file)
@@ -41,6 +41,7 @@ from music_assistant.common.models.errors import PlayerUnavailableError
 from music_assistant.common.models.player import DeviceInfo, Player, PlayerMedia
 from music_assistant.constants import CONF_ENFORCE_MP3, CONF_PLAYERS, VERBOSE_LOG_LEVEL
 from music_assistant.server.helpers.didl_lite import create_didl_metadata
+from music_assistant.server.helpers.util import TaskManager
 from music_assistant.server.models.player_provider import PlayerProvider
 
 from .helpers import DLNANotifyServer
@@ -286,7 +287,7 @@ class DLNAPlayerProvider(PlayerProvider):
         Called when provider is deregistered (e.g. MA exiting or config reloading).
         """
         self.mass.streams.unregister_dynamic_route("/notify", "NOTIFY")
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for dlna_player in self.dlnaplayers.values():
                 tg.create_task(self._device_disconnect(dlna_player))
 
index c086ee1790f989f2fece0c9577bc865b5b93b562..2f46538dd0c624b19623ea180d60744f2a323652 100644 (file)
@@ -59,6 +59,7 @@ from music_assistant.constants import (
 )
 from music_assistant.server.helpers.audio import get_ffmpeg_stream, get_player_filter_params
 from music_assistant.server.helpers.multi_client_stream import MultiClientStream
+from music_assistant.server.helpers.util import TaskManager
 from music_assistant.server.models.player_provider import PlayerProvider
 from music_assistant.server.providers.ugp import UniversalGroupProvider
 
@@ -333,14 +334,14 @@ class SlimprotoProvider(PlayerProvider):
     async def cmd_stop(self, player_id: str) -> None:
         """Send STOP command to given player."""
         # forward command to player and any connected sync members
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for slimplayer in self._get_sync_clients(player_id):
                 tg.create_task(slimplayer.stop())
 
     async def cmd_play(self, player_id: str) -> None:
         """Send PLAY command to given player."""
         # forward command to player and any connected sync members
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for slimplayer in self._get_sync_clients(player_id):
                 tg.create_task(slimplayer.play())
 
@@ -407,7 +408,7 @@ class SlimprotoProvider(PlayerProvider):
         base_url = f"{self.mass.streams.base_url}/slimproto/multi?player_id={player_id}&fmt=flac"
 
         # forward to downstream play_media commands
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for slimplayer in self._get_sync_clients(player_id):
                 url = f"{base_url}&child_player_id={slimplayer.player_id}"
                 if self.mass.config.get_raw_player_config_value(
@@ -508,7 +509,7 @@ class SlimprotoProvider(PlayerProvider):
     async def cmd_pause(self, player_id: str) -> None:
         """Send PAUSE command to given player."""
         # forward command to player and any connected sync members
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for slimplayer in self._get_sync_clients(player_id):
                 tg.create_task(slimplayer.pause())
 
@@ -818,7 +819,7 @@ class SlimprotoProvider(PlayerProvider):
                 break
 
         # all child's ready (or timeout) - start play
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for _client in self._get_sync_clients(player.player_id):
                 self._sync_playpoints.setdefault(
                     _client.player_id, deque(maxlen=MIN_REQ_PLAYPOINTS)
index f6811d51e77618b3593c2b178f61bb76807f1e68..cca3fc861d7bebf82335ddff3f1e354261ceead1 100644 (file)
@@ -36,6 +36,7 @@ from music_assistant.common.models.errors import PlayerCommandFailed, PlayerUnav
 from music_assistant.common.models.player import DeviceInfo, Player, PlayerMedia
 from music_assistant.constants import CONF_CROSSFADE, SYNCGROUP_PREFIX, VERBOSE_LOG_LEVEL
 from music_assistant.server.helpers.didl_lite import create_didl_metadata
+from music_assistant.server.helpers.util import TaskManager
 from music_assistant.server.models.player_provider import PlayerProvider
 
 from .player import SonosPlayer
@@ -417,7 +418,7 @@ class SonosPlayerProvider(PlayerProvider):
         """Handle (provider native) playback of an announcement on given player."""
         if player_id.startswith(SYNCGROUP_PREFIX):
             # handle syncgroup, unwrap to all underlying child's
-            async with asyncio.TaskGroup() as tg:
+            async with TaskManager(self.mass) as tg:
                 if group_player := self.mass.players.get(player_id):
                     # execute on all child players
                     for child_player_id in group_player.group_childs:
index 37a450c7e66b2111a6df015883ed49a7012fa4c9..85983c6530babe37136848edd9b2df5e88fdfae2 100644 (file)
@@ -7,7 +7,6 @@ allowing the user to create player groups from all players known in the system.
 
 from __future__ import annotations
 
-import asyncio
 from time import time
 from typing import TYPE_CHECKING
 
@@ -38,6 +37,7 @@ from music_assistant.constants import CONF_GROUP_MEMBERS, SYNCGROUP_PREFIX
 from music_assistant.server.controllers.streams import DEFAULT_STREAM_HEADERS
 from music_assistant.server.helpers.audio import get_ffmpeg_stream, get_player_filter_params
 from music_assistant.server.helpers.multi_client_stream import MultiClientStream
+from music_assistant.server.helpers.util import TaskManager
 from music_assistant.server.models.player_provider import PlayerProvider
 
 if TYPE_CHECKING:
@@ -150,7 +150,7 @@ class UniversalGroupProvider(PlayerProvider):
         group_player.state = PlayerState.IDLE
         self.mass.players.update(player_id)
         # forward command to player and any connected sync child's
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for member in self.mass.players.iter_group_members(group_player, only_powered=True):
                 if member.state == PlayerState.IDLE:
                     continue
@@ -215,7 +215,7 @@ class UniversalGroupProvider(PlayerProvider):
         base_url = f"{self.mass.streams.base_url}/ugp/{player_id}.flac"
 
         # forward to downstream play_media commands
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self.mass) as tg:
             for member in self.mass.players.iter_group_members(group_player, only_powered=True):
                 if member.player_id.startswith(SYNCGROUP_PREFIX):
                     member = self.mass.players.get_sync_leader(member)  # noqa: PLW2901
index c09a2a30b65b24e8e73097ba5acbf368f01d3a49..9d7223981c4735668b624918df7da48d1adeebc9 100644 (file)
@@ -42,6 +42,7 @@ from music_assistant.server.controllers.webserver import WebserverController
 from music_assistant.server.helpers.api import APICommandHandler, api_command
 from music_assistant.server.helpers.images import get_icon_string
 from music_assistant.server.helpers.util import (
+    TaskManager,
     get_package_version,
     is_hass_supervisor,
     load_provider_module,
@@ -168,9 +169,10 @@ class MusicAssistant:
         for task in self._tracked_tasks.values():
             task.cancel()
         # cleanup all providers
-        async with asyncio.TaskGroup() as tg:
-            for prov_id in list(self._providers.keys()):
-                tg.create_task(self.unload_provider(prov_id))
+        await asyncio.gather(
+            *[self.unload_provider(prov_id) for prov_id in list(self._providers.keys())],
+            return_exceptions=True,
+        )
         # stop core controllers
         await self.streams.close()
         await self.webserver.close()
@@ -653,7 +655,7 @@ class MusicAssistant:
                         exc_info=exc,
                     )
 
-        async with asyncio.TaskGroup() as tg:
+        async with TaskManager(self) as tg:
             for dir_str in os.listdir(PROVIDERS_PATH):
                 dir_path = os.path.join(PROVIDERS_PATH, dir_str)
                 if dir_str == "test" and not ENABLE_DEBUG: