improve cleanup at close/stop (#246)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Wed, 6 Apr 2022 13:13:31 +0000 (15:13 +0200)
committerGitHub <noreply@github.com>
Wed, 6 Apr 2022 13:13:31 +0000 (15:13 +0200)
examples/full.py
examples/simple.py
music_assistant/controllers/music/__init__.py
music_assistant/controllers/stream.py
music_assistant/helpers/cache.py
music_assistant/helpers/util.py
music_assistant/mass.py
music_assistant/models/player.py
music_assistant/models/player_queue.py
requirements.txt

index 8df915261b6234943aad3d00b5e7589f79be3fa2..dba2ca64ef35a5fd89e8c708b03a71073e659fc7 100644 (file)
@@ -4,8 +4,6 @@ import asyncio
 import logging
 import os
 
-from aiorun import run
-
 from music_assistant.mass import MusicAssistant
 from music_assistant.models.player import Player, PlayerState
 from music_assistant.providers.filesystem import FileSystemProvider
@@ -76,8 +74,6 @@ if not os.path.isdir(data_dir):
     os.makedirs(data_dir)
 db_file = os.path.join(data_dir, "music_assistant.db")
 
-mass = MusicAssistant(f"sqlite:///{db_file}")
-
 
 providers = []
 if args.spotify_username and args.spotify_password:
@@ -145,14 +141,13 @@ class TestPlayer(Player):
         self.update_state()
 
 
-def main():
+async def main():
     """Handle main execution."""
 
-    async def async_main():
-        """Async main routine."""
-        asyncio.get_event_loop().set_debug(args.debug)
+    asyncio.get_event_loop().set_debug(args.debug)
+
+    async with MusicAssistant(f"sqlite:///{db_file}") as mass:
 
-        await mass.setup()
         # register music provider(s)
         for prov in providers:
             await mass.music.register_provider(prov)
@@ -176,16 +171,11 @@ def main():
         if len(playlists) > 0:
             await test_player.active_queue.play_media(playlists[0].uri)
 
-    def on_shutdown(loop):
-        loop.run_until_complete(mass.stop())
-
-    run(
-        async_main(),
-        use_uvloop=True,
-        shutdown_callback=on_shutdown,
-        executor_workers=64,
-    )
+        await asyncio.sleep(3600)
 
 
 if __name__ == "__main__":
-    main()
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        pass
index 0e00d6b3192bf22d33a854677f6db7e1c1266a47..54f3adfd8a4041fd0d9a4a97a8c95fcd945e77fa 100644 (file)
@@ -4,7 +4,6 @@ import asyncio
 import logging
 import os
 
-from aiorun import run
 
 from music_assistant.mass import MusicAssistant
 from music_assistant.providers.spotify import SpotifyProvider
@@ -52,30 +51,29 @@ mass = MusicAssistant(f"sqlite:///{db_file}")
 spotify = SpotifyProvider(args.username, args.password)
 
 
-def main():
+async def main():
     """Handle main execution."""
 
-    async def async_main():
-        """Async main routine."""
-        asyncio.get_event_loop().set_debug(args.debug)
-        await mass.setup()
-        # register music provider(s)
-        await mass.music.register_provider(spotify)
-        # get some data
-        await mass.music.artists.library()
-        await mass.music.tracks.library()
-        await mass.music.radio.library()
-
-    def on_shutdown(loop):
-        loop.run_until_complete(mass.stop())
-
-    run(
-        async_main(),
-        use_uvloop=True,
-        shutdown_callback=on_shutdown,
-        executor_workers=64,
-    )
+    asyncio.get_event_loop().set_debug(args.debug)
+
+    # without contextmanager we need to call the async setup
+    await mass.setup()
+    # register music provider(s)
+    await mass.music.register_provider(spotify)
+    # get some data
+    await mass.music.artists.library()
+    await mass.music.tracks.library()
+    await mass.music.radio.library()
+
+    # run for an hour until someone hits CTRL+C
+    await asyncio.sleep(3600)
+
+    # without contextmanager we need to call the stop
+    await mass.stop()
 
 
 if __name__ == "__main__":
-    main()
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        pass
index 6012f9ad5589b36ed46f153a48ce3b7840b2adfb..475e0df4e966520a8c2aed1f33b2197810d51f35 100755 (executable)
@@ -14,7 +14,7 @@ from music_assistant.controllers.music.tracks import TracksController
 from music_assistant.helpers.cache import cached
 from music_assistant.helpers.datetime import utc_timestamp
 from music_assistant.helpers.typing import MusicAssistant
-from music_assistant.helpers.util import create_task, run_periodic
+from music_assistant.helpers.util import run_periodic
 from music_assistant.models.errors import (
     AlreadyRegisteredError,
     MusicAssistantError,
@@ -59,7 +59,7 @@ class MusicController:
         await self.tracks.setup()
         await self.radio.setup()
         await self.playlists.setup()
-        create_task(self.__periodic_sync)
+        self.mass.create_task(self.__periodic_sync)
 
     @property
     def provider_count(self) -> int:
@@ -95,7 +95,7 @@ class MusicController:
         else:
             self._providers[provider.id] = provider
             self.mass.signal_event(EventType.PROVIDER_REGISTERED, provider)
-            create_task(self.run_provider_sync(provider.id))
+            self.mass.create_task(self.run_provider_sync(provider.id))
 
     async def search(
         self, search_query, media_types: List[MediaType], limit: int = 10
index 29b6715122b4db1f8bd8bf618955b87bb6914a55..c94445af3a8fad23b3636d3f7e38b51ffd8d8f2a 100644 (file)
@@ -20,7 +20,7 @@ from music_assistant.helpers.audio import (
 )
 from music_assistant.helpers.process import AsyncProcess
 from music_assistant.helpers.typing import MusicAssistant
-from music_assistant.helpers.util import create_task, get_ip
+from music_assistant.helpers.util import get_ip
 from music_assistant.models.errors import MediaNotFoundError
 from music_assistant.models.media_items import ContentType
 from music_assistant.models.player_queue import PlayerQueue
@@ -131,7 +131,7 @@ class StreamController:
                 # write eof when last packet is received
                 sox_proc.write_eof()
 
-            create_task(writer)
+            self.mass.create_task(writer)
 
             # read bytes from final output
             chunksize = 32000 if output_fmt == ContentType.MP3 else 90000
index a6444708f1a476bb776de96d1ae398bb2ba347a8..ae430e2ccdb20feacf995bbaa3fd9b306d3c576b 100644 (file)
@@ -8,7 +8,6 @@ import time
 from typing import Awaitable
 
 from music_assistant.helpers.typing import MusicAssistant
-from music_assistant.helpers.util import create_task
 
 DB_TABLE = "cache"
 
@@ -123,7 +122,7 @@ class Cache:
 
 
 async def cached(
-    cache,
+    cache: Cache,
     cache_key: str,
     coro_func: Awaitable,
     *args,
@@ -138,5 +137,5 @@ async def cached(
         result = await coro_func
     else:
         result = await coro_func(*args)
-    create_task(cache.set(cache_key, result, checksum, expires))
+    cache.mass.create_task(cache.set(cache_key, result, checksum, expires))
     return result
index f314e660eeb0195bafc339befd89b016ce29afad..d49a8f3dd5807d2234cbf5edeb2f48377f2f3095 100755 (executable)
@@ -2,14 +2,11 @@
 from __future__ import annotations
 
 import asyncio
-import functools
 import os
 import platform
 import socket
 import tempfile
-import threading
-from asyncio.events import AbstractEventLoop
-from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
 
 import memory_tempfile
 
@@ -20,51 +17,6 @@ CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
 CALLBACK_TYPE = Callable[[], None]
 # pylint: enable=invalid-name
 
-DEFAULT_LOOP = None
-
-
-def create_task(
-    target: Callable[..., Any],
-    *args: Any,
-    loop: AbstractEventLoop = None,
-    **kwargs: Any,
-) -> Union[asyncio.Task, asyncio.Future]:
-    """Create Task on (main) event loop from Callable or awaitable.
-
-    target: target to call.
-    loop: Running (main) event loop, defaults to loop in current thread
-    args/kwargs: parameters for method to call.
-    """
-    try:
-        loop = loop or asyncio.get_running_loop()
-    except RuntimeError:
-        # try to fetch the default loop from global variable
-        loop = DEFAULT_LOOP
-
-    # Check for partials to properly determine if coroutine function
-    check_target = target
-    while isinstance(check_target, functools.partial):
-        check_target = check_target.func
-
-    async def executor_wrapper(_target: Callable, *_args, **_kwargs):
-        return await loop.run_in_executor(None, _target, *_args, **_kwargs)
-
-    # called from other thread
-    if threading.current_thread() is not threading.main_thread():
-        if asyncio.iscoroutine(check_target):
-            return asyncio.run_coroutine_threadsafe(target, loop)
-        if asyncio.iscoroutinefunction(check_target):
-            return asyncio.run_coroutine_threadsafe(target(*args), loop)
-        return asyncio.run_coroutine_threadsafe(
-            executor_wrapper(target, *args, **kwargs), loop
-        )
-
-    if asyncio.iscoroutine(check_target):
-        return loop.create_task(target)
-    if asyncio.iscoroutinefunction(check_target):
-        return loop.create_task(target(*args))
-    return loop.create_task(executor_wrapper(target, *args, **kwargs))
-
 
 def run_periodic(delay: float, later: bool = False):
     """Run a coroutine at interval."""
index d36dcdf5d0a0689d5643d8c416a7920b50343d82..b3aa6343e4dae8e95094a896b4e113131385a0da 100644 (file)
@@ -2,9 +2,12 @@
 from __future__ import annotations
 
 import asyncio
+import functools
 import logging
+import threading
 from time import time
-from typing import Any, Callable, Coroutine, Optional, Tuple, Union
+from types import TracebackType
+from typing import Any, Callable, Coroutine, List, Optional, Tuple, Type, Union
 
 import aiohttp
 from databases import DatabaseURL
@@ -12,10 +15,8 @@ from music_assistant.constants import EventType
 from music_assistant.controllers.metadata import MetaDataController
 from music_assistant.controllers.music import MusicController
 from music_assistant.controllers.players import PlayerController
-from music_assistant.helpers import util
 from music_assistant.helpers.cache import Cache
 from music_assistant.helpers.database import Database
-from music_assistant.helpers.util import create_task
 
 EventCallBackType = Callable[[EventType, Any], None]
 EventSubscriptionType = Tuple[EventCallBackType, Optional[Tuple[EventType]]]
@@ -52,13 +53,12 @@ class MusicAssistant:
         self.metadata = MetaDataController(self)
         self.music = MusicController(self)
         self.players = PlayerController(self, stream_port)
-        self._jobs_task: asyncio.Task = None
+        self._tracked_tasks: List[asyncio.Task] = []
 
     async def setup(self) -> None:
         """Async setup of music assistant."""
         # initialize loop
         self.loop = asyncio.get_event_loop()
-        util.DEFAULT_LOOP = self.loop
         # create shared aiohttp ClientSession
         if not self.http_session:
             self.http_session = aiohttp.ClientSession(
@@ -70,17 +70,20 @@ class MusicAssistant:
         await self.music.setup()
         await self.metadata.setup()
         await self.players.setup()
-        self._jobs_task = create_task(self.__process_jobs())
+        self.create_task(self.__process_jobs())
 
     async def stop(self) -> None:
         """Stop running the music assistant server."""
-        self.logger.info("Application shutdown")
+        self.logger.info("Stop called, cleaning up...")
+        # cancel any running tasks
+        for task in self._tracked_tasks:
+            task.cancel()
         self.signal_event(EventType.SHUTDOWN)
-        if self._jobs_task is not None:
-            self._jobs_task.cancel()
+        # wait for any remaining tasks launched by the shutdown event
+        await asyncio.wait_for(asyncio.wait(self._tracked_tasks), 2)
         if self.http_session and not self.http_session_provided:
             await self.http_session.connector.close()
-        self.http_session.detach()
+            self.http_session.detach()
 
     def signal_event(self, event_type: EventType, event_details: Any = None) -> None:
         """
@@ -90,8 +93,8 @@ class MusicAssistant:
             :param event_details: optional details to send with the event.
         """
         for cb_func, event_filter in self._listeners:
-            if not event_filter or event_type in event_filter:
-                create_task(cb_func, event_type, event_details)
+            if event_filter is None or event_type in event_filter:
+                self.create_task(cb_func, event_type, event_details)
 
     def subscribe(
         self,
@@ -107,8 +110,6 @@ class MusicAssistant:
         """
         if isinstance(event_filter, EventType):
             event_filter = (event_filter,)
-        elif event_filter is None:
-            event_filter = tuple()
         listener = (cb_func, event_filter)
         self._listeners.append(listener)
 
@@ -123,6 +124,53 @@ class MusicAssistant:
             name = job.__qualname__ or job.__name__
         self._jobs.put_nowait((name, job))
 
+    def create_task(
+        self,
+        target: Callable[..., Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Union[asyncio.Task, asyncio.Future]:
+        """
+        Create Task on (main) event loop from Callable or awaitable.
+
+        Tasks create dby this helper will be properly cancelled on stop.
+        """
+
+        # Check for partials to properly determine if coroutine function
+        check_target = target
+        while isinstance(check_target, functools.partial):
+            check_target = check_target.func
+
+        async def executor_wrapper(_target: Callable, *_args, **_kwargs):
+            return await self.loop.run_in_executor(None, _target, *_args, **_kwargs)
+
+        # called from other thread
+        if threading.current_thread() is not threading.main_thread():
+            if asyncio.iscoroutine(check_target):
+                task = asyncio.run_coroutine_threadsafe(target, self.loop)
+            elif asyncio.iscoroutinefunction(check_target):
+                task = asyncio.run_coroutine_threadsafe(target(*args), self.loop)
+            else:
+                task = asyncio.run_coroutine_threadsafe(
+                    executor_wrapper(target, *args, **kwargs), self.loop
+                )
+        else:
+            if asyncio.iscoroutine(check_target):
+                task = self.loop.create_task(target)
+            elif asyncio.iscoroutinefunction(check_target):
+                task = self.loop.create_task(target(*args))
+            else:
+                task = self.loop.create_task(executor_wrapper(target, *args, **kwargs))
+
+        def task_done_callback(*args, **kwargs):
+            self.logger.debug("task finished %s", task.get_name())
+            self._tracked_tasks.remove(task)
+
+        self._tracked_tasks.append(task)
+        task.add_done_callback(task_done_callback)
+        self.logger.debug("spawned task %s", task.get_name())
+        return task
+
     async def __process_jobs(self):
         """Process jobs in the background."""
         while True:
@@ -131,7 +179,7 @@ class MusicAssistant:
             self.logger.debug("Start processing job [%s].", name)
             try:
                 # await job
-                task = asyncio.create_task(job, name=name)
+                task = self.create_task(job, name=name)
                 await task
             except Exception as err:  # pylint: disable=broad-except
                 self.logger.error(
@@ -140,3 +188,20 @@ class MusicAssistant:
             else:
                 duration = round(time() - time_start, 2)
                 self.logger.info("Finished job [%s] in %s seconds.", name, duration)
+
+    async def __aenter__(self) -> "MusicAssistant":
+        """Return Context manager."""
+        await self.setup()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Type[BaseException],
+        exc_val: BaseException,
+        exc_tb: TracebackType,
+    ) -> Optional[bool]:
+        """Exit context manager."""
+        await self.stop()
+        if exc_val:
+            raise exc_val
+        return exc_type
index 653a41a0cea6d88e15e8d0ea2e528ddc57d724cd..cb74b91fcea7a1526605b132b1999cc5c4bea84a 100755 (executable)
@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Dict, List
 from mashumaro import DataClassDictMixin
 from music_assistant.constants import EventType
 from music_assistant.helpers.typing import MusicAssistant
-from music_assistant.helpers.util import create_task
 
 if TYPE_CHECKING:
     from .player_queue import PlayerQueue
@@ -210,12 +209,12 @@ class Player(ABC):
             # update group player childs when parent updates
             for child_player_id in self.group_childs:
                 if player := self.mass.players.get_player(child_player_id):
-                    create_task(player.update_state)
+                    self.mass.create_task(player.update_state)
         else:
             # update group player when child updates
             for group_player_id in self.get_group_parents():
                 if player := self.mass.players.get_player(group_player_id):
-                    create_task(player.update_state)
+                    self.mass.create_task(player.update_state)
 
     def get_group_parents(self) -> List[str]:
         """Get any/all group player id's this player belongs to."""
index c818fbdfe31efdf53936299bc63d379803201b6a..5f45117f0881de59f14b0ebf8b7b4805e7147141 100644 (file)
@@ -14,7 +14,6 @@ from mashumaro import DataClassDictMixin
 from music_assistant.constants import EventType
 from music_assistant.helpers.audio import get_stream_details
 from music_assistant.helpers.typing import MusicAssistant
-from music_assistant.helpers.util import create_task
 from music_assistant.models.errors import MediaNotFoundError, QueueEmpty
 from music_assistant.models.media_items import MediaType, StreamDetails
 
@@ -479,11 +478,11 @@ class PlayerQueue:
             # handle case where stream stopped on purpose and we need to restart it
             if self.player.state != PlayerState.PLAYING and self._signal_next:
                 self._signal_next = False
-                create_task(self.play())
+                self.mass.create_task(self.play())
             # start updater task if needed
             if self.player.state == PlayerState.PLAYING:
                 if not self._update_task:
-                    self._update_task = create_task(self.__update_task())
+                    self._update_task = self.mass.create_task(self.__update_task())
             else:
                 if self._update_task:
                     self._update_task.cancel()
@@ -655,4 +654,6 @@ class PlayerQueue:
 
         if self._save_task and not self._save_task.cancelled():
             return
-        self._save_task = self.mass.loop.call_later(60, create_task, cache_items)
+        self._save_task = self.mass.loop.call_later(
+            60, self.mass.create_task, cache_items
+        )
index 19d3a43127ed2e21ab11add9e41c87dec7e250d8..a085f69443b069414e7095b046f3bdf9e7a32861 100644 (file)
@@ -7,7 +7,6 @@ databases>=0.5,<=0.5.5
 aiosqlite>=0.13,<=0.17
 python-slugify>=4.0,<=6.1.1
 memory-tempfile<=2.2.3
-aiorun>=2021.10,<=2021.10.1
 pillow>=8.0,<9.1.1
 unidecode>=1.0,<=1.3.4
 ujson>=4.0,<=5.1.0