From 4d1e41de0fc53e13f9c2f933d0a7b988fa09e402 Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Wed, 6 Apr 2022 15:13:31 +0200 Subject: [PATCH] improve cleanup at close/stop (#246) --- examples/full.py | 28 ++---- examples/simple.py | 44 ++++----- music_assistant/controllers/music/__init__.py | 6 +- music_assistant/controllers/stream.py | 4 +- music_assistant/helpers/cache.py | 5 +- music_assistant/helpers/util.py | 50 +--------- music_assistant/mass.py | 95 ++++++++++++++++--- music_assistant/models/player.py | 5 +- music_assistant/models/player_queue.py | 9 +- requirements.txt | 1 - 10 files changed, 125 insertions(+), 122 deletions(-) diff --git a/examples/full.py b/examples/full.py index 8df91526..dba2ca64 100644 --- a/examples/full.py +++ b/examples/full.py @@ -4,8 +4,6 @@ import asyncio import logging import os -from aiorun import run - from music_assistant.mass import MusicAssistant from music_assistant.models.player import Player, PlayerState from music_assistant.providers.filesystem import FileSystemProvider @@ -76,8 +74,6 @@ if not os.path.isdir(data_dir): os.makedirs(data_dir) db_file = os.path.join(data_dir, "music_assistant.db") -mass = MusicAssistant(f"sqlite:///{db_file}") - providers = [] if args.spotify_username and args.spotify_password: @@ -145,14 +141,13 @@ class TestPlayer(Player): self.update_state() -def main(): +async def main(): """Handle main execution.""" - async def async_main(): - """Async main routine.""" - asyncio.get_event_loop().set_debug(args.debug) + asyncio.get_event_loop().set_debug(args.debug) + + async with MusicAssistant(f"sqlite:///{db_file}") as mass: - await mass.setup() # register music provider(s) for prov in providers: await mass.music.register_provider(prov) @@ -176,16 +171,11 @@ def main(): if len(playlists) > 0: await test_player.active_queue.play_media(playlists[0].uri) - def on_shutdown(loop): - loop.run_until_complete(mass.stop()) - - run( - async_main(), - use_uvloop=True, - shutdown_callback=on_shutdown, - executor_workers=64, - ) + await asyncio.sleep(3600) if __name__ == "__main__": - main() + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/examples/simple.py b/examples/simple.py index 0e00d6b3..54f3adfd 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -4,7 +4,6 @@ import asyncio import logging import os -from aiorun import run from music_assistant.mass import MusicAssistant from music_assistant.providers.spotify import SpotifyProvider @@ -52,30 +51,29 @@ mass = MusicAssistant(f"sqlite:///{db_file}") spotify = SpotifyProvider(args.username, args.password) -def main(): +async def main(): """Handle main execution.""" - async def async_main(): - """Async main routine.""" - asyncio.get_event_loop().set_debug(args.debug) - await mass.setup() - # register music provider(s) - await mass.music.register_provider(spotify) - # get some data - await mass.music.artists.library() - await mass.music.tracks.library() - await mass.music.radio.library() - - def on_shutdown(loop): - loop.run_until_complete(mass.stop()) - - run( - async_main(), - use_uvloop=True, - shutdown_callback=on_shutdown, - executor_workers=64, - ) + asyncio.get_event_loop().set_debug(args.debug) + + # without contextmanager we need to call the async setup + await mass.setup() + # register music provider(s) + await mass.music.register_provider(spotify) + # get some data + await mass.music.artists.library() + await mass.music.tracks.library() + await mass.music.radio.library() + + # run for an hour until someone hits CTRL+C + await asyncio.sleep(3600) + + # without contextmanager we need to call the stop + await mass.stop() if __name__ == "__main__": - main() + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/music_assistant/controllers/music/__init__.py b/music_assistant/controllers/music/__init__.py index 6012f9ad..475e0df4 100755 --- a/music_assistant/controllers/music/__init__.py +++ b/music_assistant/controllers/music/__init__.py @@ -14,7 +14,7 @@ from music_assistant.controllers.music.tracks import TracksController from music_assistant.helpers.cache import cached from music_assistant.helpers.datetime import utc_timestamp from music_assistant.helpers.typing import MusicAssistant -from music_assistant.helpers.util import create_task, run_periodic +from music_assistant.helpers.util import run_periodic from music_assistant.models.errors import ( AlreadyRegisteredError, MusicAssistantError, @@ -59,7 +59,7 @@ class MusicController: await self.tracks.setup() await self.radio.setup() await self.playlists.setup() - create_task(self.__periodic_sync) + self.mass.create_task(self.__periodic_sync) @property def provider_count(self) -> int: @@ -95,7 +95,7 @@ class MusicController: else: self._providers[provider.id] = provider self.mass.signal_event(EventType.PROVIDER_REGISTERED, provider) - create_task(self.run_provider_sync(provider.id)) + self.mass.create_task(self.run_provider_sync(provider.id)) async def search( self, search_query, media_types: List[MediaType], limit: int = 10 diff --git a/music_assistant/controllers/stream.py b/music_assistant/controllers/stream.py index 29b67151..c94445af 100644 --- a/music_assistant/controllers/stream.py +++ b/music_assistant/controllers/stream.py @@ -20,7 +20,7 @@ from music_assistant.helpers.audio import ( ) from music_assistant.helpers.process import AsyncProcess from music_assistant.helpers.typing import MusicAssistant -from music_assistant.helpers.util import create_task, get_ip +from music_assistant.helpers.util import get_ip from music_assistant.models.errors import MediaNotFoundError from music_assistant.models.media_items import ContentType from music_assistant.models.player_queue import PlayerQueue @@ -131,7 +131,7 @@ class StreamController: # write eof when last packet is received sox_proc.write_eof() - create_task(writer) + self.mass.create_task(writer) # read bytes from final output chunksize = 32000 if output_fmt == ContentType.MP3 else 90000 diff --git a/music_assistant/helpers/cache.py b/music_assistant/helpers/cache.py index a6444708..ae430e2c 100644 --- a/music_assistant/helpers/cache.py +++ b/music_assistant/helpers/cache.py @@ -8,7 +8,6 @@ import time from typing import Awaitable from music_assistant.helpers.typing import MusicAssistant -from music_assistant.helpers.util import create_task DB_TABLE = "cache" @@ -123,7 +122,7 @@ class Cache: async def cached( - cache, + cache: Cache, cache_key: str, coro_func: Awaitable, *args, @@ -138,5 +137,5 @@ async def cached( result = await coro_func else: result = await coro_func(*args) - create_task(cache.set(cache_key, result, checksum, expires)) + cache.mass.create_task(cache.set(cache_key, result, checksum, expires)) return result diff --git a/music_assistant/helpers/util.py b/music_assistant/helpers/util.py index f314e660..d49a8f3d 100755 --- a/music_assistant/helpers/util.py +++ b/music_assistant/helpers/util.py @@ -2,14 +2,11 @@ from __future__ import annotations import asyncio -import functools import os import platform import socket import tempfile -import threading -from asyncio.events import AbstractEventLoop -from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Set, TypeVar import memory_tempfile @@ -20,51 +17,6 @@ CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) CALLBACK_TYPE = Callable[[], None] # pylint: enable=invalid-name -DEFAULT_LOOP = None - - -def create_task( - target: Callable[..., Any], - *args: Any, - loop: AbstractEventLoop = None, - **kwargs: Any, -) -> Union[asyncio.Task, asyncio.Future]: - """Create Task on (main) event loop from Callable or awaitable. - - target: target to call. - loop: Running (main) event loop, defaults to loop in current thread - args/kwargs: parameters for method to call. - """ - try: - loop = loop or asyncio.get_running_loop() - except RuntimeError: - # try to fetch the default loop from global variable - loop = DEFAULT_LOOP - - # Check for partials to properly determine if coroutine function - check_target = target - while isinstance(check_target, functools.partial): - check_target = check_target.func - - async def executor_wrapper(_target: Callable, *_args, **_kwargs): - return await loop.run_in_executor(None, _target, *_args, **_kwargs) - - # called from other thread - if threading.current_thread() is not threading.main_thread(): - if asyncio.iscoroutine(check_target): - return asyncio.run_coroutine_threadsafe(target, loop) - if asyncio.iscoroutinefunction(check_target): - return asyncio.run_coroutine_threadsafe(target(*args), loop) - return asyncio.run_coroutine_threadsafe( - executor_wrapper(target, *args, **kwargs), loop - ) - - if asyncio.iscoroutine(check_target): - return loop.create_task(target) - if asyncio.iscoroutinefunction(check_target): - return loop.create_task(target(*args)) - return loop.create_task(executor_wrapper(target, *args, **kwargs)) - def run_periodic(delay: float, later: bool = False): """Run a coroutine at interval.""" diff --git a/music_assistant/mass.py b/music_assistant/mass.py index d36dcdf5..b3aa6343 100644 --- a/music_assistant/mass.py +++ b/music_assistant/mass.py @@ -2,9 +2,12 @@ from __future__ import annotations import asyncio +import functools import logging +import threading from time import time -from typing import Any, Callable, Coroutine, Optional, Tuple, Union +from types import TracebackType +from typing import Any, Callable, Coroutine, List, Optional, Tuple, Type, Union import aiohttp from databases import DatabaseURL @@ -12,10 +15,8 @@ from music_assistant.constants import EventType from music_assistant.controllers.metadata import MetaDataController from music_assistant.controllers.music import MusicController from music_assistant.controllers.players import PlayerController -from music_assistant.helpers import util from music_assistant.helpers.cache import Cache from music_assistant.helpers.database import Database -from music_assistant.helpers.util import create_task EventCallBackType = Callable[[EventType, Any], None] EventSubscriptionType = Tuple[EventCallBackType, Optional[Tuple[EventType]]] @@ -52,13 +53,12 @@ class MusicAssistant: self.metadata = MetaDataController(self) self.music = MusicController(self) self.players = PlayerController(self, stream_port) - self._jobs_task: asyncio.Task = None + self._tracked_tasks: List[asyncio.Task] = [] async def setup(self) -> None: """Async setup of music assistant.""" # initialize loop self.loop = asyncio.get_event_loop() - util.DEFAULT_LOOP = self.loop # create shared aiohttp ClientSession if not self.http_session: self.http_session = aiohttp.ClientSession( @@ -70,17 +70,20 @@ class MusicAssistant: await self.music.setup() await self.metadata.setup() await self.players.setup() - self._jobs_task = create_task(self.__process_jobs()) + self.create_task(self.__process_jobs()) async def stop(self) -> None: """Stop running the music assistant server.""" - self.logger.info("Application shutdown") + self.logger.info("Stop called, cleaning up...") + # cancel any running tasks + for task in self._tracked_tasks: + task.cancel() self.signal_event(EventType.SHUTDOWN) - if self._jobs_task is not None: - self._jobs_task.cancel() + # wait for any remaining tasks launched by the shutdown event + await asyncio.wait_for(asyncio.wait(self._tracked_tasks), 2) if self.http_session and not self.http_session_provided: await self.http_session.connector.close() - self.http_session.detach() + self.http_session.detach() def signal_event(self, event_type: EventType, event_details: Any = None) -> None: """ @@ -90,8 +93,8 @@ class MusicAssistant: :param event_details: optional details to send with the event. """ for cb_func, event_filter in self._listeners: - if not event_filter or event_type in event_filter: - create_task(cb_func, event_type, event_details) + if event_filter is None or event_type in event_filter: + self.create_task(cb_func, event_type, event_details) def subscribe( self, @@ -107,8 +110,6 @@ class MusicAssistant: """ if isinstance(event_filter, EventType): event_filter = (event_filter,) - elif event_filter is None: - event_filter = tuple() listener = (cb_func, event_filter) self._listeners.append(listener) @@ -123,6 +124,53 @@ class MusicAssistant: name = job.__qualname__ or job.__name__ self._jobs.put_nowait((name, job)) + def create_task( + self, + target: Callable[..., Any], + *args: Any, + **kwargs: Any, + ) -> Union[asyncio.Task, asyncio.Future]: + """ + Create Task on (main) event loop from Callable or awaitable. + + Tasks create dby this helper will be properly cancelled on stop. + """ + + # Check for partials to properly determine if coroutine function + check_target = target + while isinstance(check_target, functools.partial): + check_target = check_target.func + + async def executor_wrapper(_target: Callable, *_args, **_kwargs): + return await self.loop.run_in_executor(None, _target, *_args, **_kwargs) + + # called from other thread + if threading.current_thread() is not threading.main_thread(): + if asyncio.iscoroutine(check_target): + task = asyncio.run_coroutine_threadsafe(target, self.loop) + elif asyncio.iscoroutinefunction(check_target): + task = asyncio.run_coroutine_threadsafe(target(*args), self.loop) + else: + task = asyncio.run_coroutine_threadsafe( + executor_wrapper(target, *args, **kwargs), self.loop + ) + else: + if asyncio.iscoroutine(check_target): + task = self.loop.create_task(target) + elif asyncio.iscoroutinefunction(check_target): + task = self.loop.create_task(target(*args)) + else: + task = self.loop.create_task(executor_wrapper(target, *args, **kwargs)) + + def task_done_callback(*args, **kwargs): + self.logger.debug("task finished %s", task.get_name()) + self._tracked_tasks.remove(task) + + self._tracked_tasks.append(task) + task.add_done_callback(task_done_callback) + self.logger.debug("spawned task %s", task.get_name()) + return task + async def __process_jobs(self): """Process jobs in the background.""" while True: @@ -131,7 +179,7 @@ class MusicAssistant: self.logger.debug("Start processing job [%s].", name) try: # await job - task = asyncio.create_task(job, name=name) + task = self.create_task(job, name=name) await task except Exception as err: # pylint: disable=broad-except self.logger.error( @@ -140,3 +188,20 @@ class MusicAssistant: else: duration = round(time() - time_start, 2) self.logger.info("Finished job [%s] in %s seconds.", name, duration) + + async def __aenter__(self) -> "MusicAssistant": + """Return Context manager.""" + await self.setup() + return self + + async def __aexit__( + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> Optional[bool]: + """Exit context manager.""" + await self.stop() + if exc_val: + raise exc_val + return exc_type diff --git a/music_assistant/models/player.py b/music_assistant/models/player.py index 653a41a0..cb74b91f 100755 --- a/music_assistant/models/player.py +++ b/music_assistant/models/player.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Dict, List from mashumaro import DataClassDictMixin from music_assistant.constants import EventType from music_assistant.helpers.typing import MusicAssistant -from music_assistant.helpers.util import create_task if TYPE_CHECKING: from .player_queue import PlayerQueue @@ -210,12 +209,12 @@ class Player(ABC): # update group player childs when parent updates for child_player_id in self.group_childs: if player := self.mass.players.get_player(child_player_id): - create_task(player.update_state) + self.mass.create_task(player.update_state) else: # update group player when child updates for group_player_id in self.get_group_parents(): if player := self.mass.players.get_player(group_player_id): - create_task(player.update_state) + self.mass.create_task(player.update_state) def get_group_parents(self) -> List[str]: """Get any/all group player id's this player belongs to.""" diff --git a/music_assistant/models/player_queue.py b/music_assistant/models/player_queue.py index c818fbdf..5f45117f 100644 --- a/music_assistant/models/player_queue.py +++ b/music_assistant/models/player_queue.py @@ -14,7 +14,6 @@ from mashumaro import DataClassDictMixin from music_assistant.constants import EventType from music_assistant.helpers.audio import get_stream_details from music_assistant.helpers.typing import MusicAssistant -from music_assistant.helpers.util import create_task from music_assistant.models.errors import MediaNotFoundError, QueueEmpty from music_assistant.models.media_items import MediaType, StreamDetails @@ -479,11 +478,11 @@ class PlayerQueue: # handle case where stream stopped on purpose and we need to restart it if self.player.state != PlayerState.PLAYING and self._signal_next: self._signal_next = False - create_task(self.play()) + self.mass.create_task(self.play()) # start updater task if needed if self.player.state == PlayerState.PLAYING: if not self._update_task: - self._update_task = create_task(self.__update_task()) + self._update_task = self.mass.create_task(self.__update_task()) else: if self._update_task: self._update_task.cancel() @@ -655,4 +654,6 @@ class PlayerQueue: if self._save_task and not self._save_task.cancelled(): return - self._save_task = self.mass.loop.call_later(60, create_task, cache_items) + self._save_task = self.mass.loop.call_later( + 60, self.mass.create_task, cache_items + ) diff --git a/requirements.txt b/requirements.txt index 19d3a431..a085f694 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ databases>=0.5,<=0.5.5 aiosqlite>=0.13,<=0.17 python-slugify>=4.0,<=6.1.1 memory-tempfile<=2.2.3 -aiorun>=2021.10,<=2021.10.1 pillow>=8.0,<9.1.1 unidecode>=1.0,<=1.3.4 ujson>=4.0,<=5.1.0 -- 2.34.1