From 0427f0fc0e79c6991956a322bcb49f30ac5df66e Mon Sep 17 00:00:00 2001 From: Jc2k Date: Mon, 13 Jan 2025 23:43:37 +0000 Subject: [PATCH] chore: mypy for mass.py (#1863) * chore: mypy for mass.py * fix: avoid stashing task_id on the task object * fix: force type --- music_assistant/mass.py | 60 ++++++++++++++++++++++++----------------- pyproject.toml | 1 - 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/music_assistant/mass.py b/music_assistant/mass.py index 7f921ae4..1e1b900e 100644 --- a/music_assistant/mass.py +++ b/music_assistant/mass.py @@ -6,7 +6,7 @@ import asyncio import logging import os from collections.abc import Awaitable, Callable, Coroutine -from typing import TYPE_CHECKING, Any, Self, TypeVar +from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar from uuid import uuid4 import aiofiles @@ -48,6 +48,8 @@ from music_assistant.helpers.util import ( load_provider_module, ) from music_assistant.models import ProviderInstanceType +from music_assistant.models.music_provider import MusicProvider +from music_assistant.models.player_provider import PlayerProvider if TYPE_CHECKING: from types import TracebackType @@ -77,6 +79,16 @@ PROVIDERS_PATH = os.path.join(BASE_DIR, "providers") _R = TypeVar("_R") +def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]: + """Type guard that returns true if a provider is a music provider.""" + return provider.type == ProviderType.MUSIC + + +def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]: + """Type guard that returns true if a provider is a player provider.""" + return provider.type == ProviderType.PLAYER + + class MusicAssistant: """Main MusicAssistant (Server) object.""" @@ -103,7 +115,7 @@ class MusicAssistant: self._subscribers: set[EventSubscriptionType] = set() self._provider_manifests: dict[str, ProviderManifest] = {} self._providers: dict[str, ProviderInstanceType] = {} - self._tracked_tasks: dict[str, asyncio.Task] = {} + self._tracked_tasks: dict[str, asyncio.Task[Any]] = {} self._tracked_timers: dict[str, asyncio.TimerHandle] = {} self.closing = False self.running_as_hass_addon: bool = False @@ -240,7 +252,7 @@ class MusicAssistant: """Return the application log from file.""" logfile = os.path.join(self.storage_path, "musicassistant.log") async with aiofiles.open(logfile) as _file: - return await _file.read() + return str(await _file.read()) @property def providers(self) -> list[ProviderInstanceType]: @@ -297,7 +309,7 @@ class MusicAssistant: cb_func: EventCallBackType, event_filter: EventType | tuple[EventType, ...] | None = None, id_filter: str | tuple[str, ...] | None = None, - ) -> Callable: + ) -> Callable[[], None]: """Add callback to event listeners. Returns function to remove the listener. @@ -329,9 +341,6 @@ class MusicAssistant: Tasks created by this helper will be properly cancelled on stop. """ - if target is None: - msg = "Target is missing" - raise RuntimeError(msg) if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done(): # prevent duplicate tasks if task_id is given and already present if abort_existing: @@ -344,12 +353,16 @@ class MusicAssistant: elif asyncio.iscoroutine(target): # coroutine task = self.loop.create_task(target) - else: + elif callable(target): task = self.loop.create_task(asyncio.to_thread(target, *args, **kwargs)) + else: + raise RuntimeError("Target is missing") - def task_done_callback(_task: asyncio.Task) -> None: - _task_id = task.task_id - self._tracked_tasks.pop(_task_id, None) + if task_id is None: + task_id = uuid4().hex + + def task_done_callback(_task: asyncio.Task[Any]) -> None: + self._tracked_tasks.pop(task_id, None) # log unhandled exceptions if ( LOGGER.isEnabledFor(logging.DEBUG) @@ -365,9 +378,6 @@ class MusicAssistant: exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None, ) - if task_id is None: - task_id = uuid4().hex - task.task_id = task_id self._tracked_tasks[task_id] = task task.add_done_callback(task_done_callback) return task @@ -375,7 +385,7 @@ class MusicAssistant: def call_later( self, delay: float, - target: Coroutine | Awaitable | Callable, + target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R], *args: Any, task_id: str | None = None, **kwargs: Any, @@ -399,7 +409,7 @@ class MusicAssistant: self._tracked_timers[task_id] = handle return handle - def get_task(self, task_id: str) -> asyncio.Task: + def get_task(self, task_id: str) -> asyncio.Task[Any]: """Get existing scheduled task.""" if existing := self._tracked_tasks.get(task_id): # prevent duplicate tasks if task_id is given and already present @@ -410,8 +420,8 @@ class MusicAssistant: def register_api_command( self, command: str, - handler: Callable, - ) -> None: + handler: Callable[..., Coroutine[Any, Any, Any]], + ) -> Callable[[], None]: """ Dynamically register a command on the API. @@ -516,12 +526,13 @@ class MusicAssistant: # make sure to stop any running sync tasks first for sync_task in self.music.in_progress_syncs: if sync_task.provider_instance == instance_id: - sync_task.task.cancel() + if sync_task.task: + sync_task.task.cancel() # check if there are no other providers dependent of this provider for dep_prov in self.providers: if dep_prov.manifest.depends_on == provider.domain: await self.unload_provider(dep_prov.instance_id) - if provider.type == ProviderType.PLAYER: + if is_player_provider(provider): # mark all players of this provider as unavailable for player in provider.players: player.available = False @@ -590,7 +601,7 @@ class MusicAssistant: prov_manifest = self._provider_manifests.get(domain) # check for other instances of this provider existing = next((x for x in self.providers if x.domain == domain), None) - if existing and not prov_manifest.multi_instance: + if existing and prov_manifest and not prov_manifest.multi_instance: msg = f"Provider {domain} already loaded and only one instance allowed." raise SetupFailedError(msg) # check valid manifest (just in case) @@ -719,7 +730,7 @@ class MusicAssistant: ) -> None: """Handle MDNS service state callback.""" - async def process_mdns_state_change(prov: ProviderInstanceType): + async def process_mdns_state_change(prov: ProviderInstanceType) -> None: if state_change == ServiceStateChange.Removed: info = None else: @@ -755,6 +766,7 @@ class MusicAssistant: ) -> bool | None: """Exit context manager.""" await self.stop() + return None async def _update_available_providers_cache(self) -> None: """Update the global cache variable of loaded/available providers.""" @@ -770,12 +782,12 @@ class MusicAssistant: "streaming_providers": { x.lookup_key for x in self.providers - if x.type == ProviderType.MUSIC and x.is_streaming_provider + if is_music_provider(x) and x.is_streaming_provider }, "non_streaming_providers": { x.lookup_key for x in self.providers - if not (x.type == ProviderType.MUSIC and x.is_streaming_provider) + if not (is_music_provider(x) and x.is_streaming_provider) }, } ) diff --git a/pyproject.toml b/pyproject.toml index 0e4f3389..84a04928 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,6 @@ exclude = [ '^music_assistant/controllers/.*$', '^music_assistant/helpers/.*$', '^music_assistant/models/.*$', - '^music_assistant/mass\.py$', '^music_assistant/providers/_template_music_provider/.*$', '^music_assistant/providers/_template_player_provider/.*$', '^music_assistant/providers/apple_music/.*$', -- 2.34.1