chore: mypy for mass.py (#1863)
authorJc2k <john.carr@unrouted.co.uk>
Mon, 13 Jan 2025 23:43:37 +0000 (23:43 +0000)
committerGitHub <noreply@github.com>
Mon, 13 Jan 2025 23:43:37 +0000 (00:43 +0100)
* chore: mypy for mass.py

* fix: avoid stashing task_id on the task object

* fix: force type

music_assistant/mass.py
pyproject.toml

index 7f921ae42a789e22140e3334fa5540765a405525..1e1b900eaaf7727c7666f5e3f88e3757dbfd5de9 100644 (file)
@@ -6,7 +6,7 @@ import asyncio
 import logging
 import os
 from collections.abc import Awaitable, Callable, Coroutine
-from typing import TYPE_CHECKING, Any, Self, TypeVar
+from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar
 from uuid import uuid4
 
 import aiofiles
@@ -48,6 +48,8 @@ from music_assistant.helpers.util import (
     load_provider_module,
 )
 from music_assistant.models import ProviderInstanceType
+from music_assistant.models.music_provider import MusicProvider
+from music_assistant.models.player_provider import PlayerProvider
 
 if TYPE_CHECKING:
     from types import TracebackType
@@ -77,6 +79,16 @@ PROVIDERS_PATH = os.path.join(BASE_DIR, "providers")
 _R = TypeVar("_R")
 
 
+def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]:
+    """Type guard that returns true if a provider is a music provider."""
+    return provider.type == ProviderType.MUSIC
+
+
+def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]:
+    """Type guard that returns true if a provider is a player provider."""
+    return provider.type == ProviderType.PLAYER
+
+
 class MusicAssistant:
     """Main MusicAssistant (Server) object."""
 
@@ -103,7 +115,7 @@ class MusicAssistant:
         self._subscribers: set[EventSubscriptionType] = set()
         self._provider_manifests: dict[str, ProviderManifest] = {}
         self._providers: dict[str, ProviderInstanceType] = {}
-        self._tracked_tasks: dict[str, asyncio.Task] = {}
+        self._tracked_tasks: dict[str, asyncio.Task[Any]] = {}
         self._tracked_timers: dict[str, asyncio.TimerHandle] = {}
         self.closing = False
         self.running_as_hass_addon: bool = False
@@ -240,7 +252,7 @@ class MusicAssistant:
         """Return the application log from file."""
         logfile = os.path.join(self.storage_path, "musicassistant.log")
         async with aiofiles.open(logfile) as _file:
-            return await _file.read()
+            return str(await _file.read())
 
     @property
     def providers(self) -> list[ProviderInstanceType]:
@@ -297,7 +309,7 @@ class MusicAssistant:
         cb_func: EventCallBackType,
         event_filter: EventType | tuple[EventType, ...] | None = None,
         id_filter: str | tuple[str, ...] | None = None,
-    ) -> Callable:
+    ) -> Callable[[], None]:
         """Add callback to event listeners.
 
         Returns function to remove the listener.
@@ -329,9 +341,6 @@ class MusicAssistant:
 
         Tasks created by this helper will be properly cancelled on stop.
         """
-        if target is None:
-            msg = "Target is missing"
-            raise RuntimeError(msg)
         if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done():
             # prevent duplicate tasks if task_id is given and already present
             if abort_existing:
@@ -344,12 +353,16 @@ class MusicAssistant:
         elif asyncio.iscoroutine(target):
             # coroutine
             task = self.loop.create_task(target)
-        else:
+        elif callable(target):
             task = self.loop.create_task(asyncio.to_thread(target, *args, **kwargs))
+        else:
+            raise RuntimeError("Target is missing")
 
-        def task_done_callback(_task: asyncio.Task) -> None:
-            _task_id = task.task_id
-            self._tracked_tasks.pop(_task_id, None)
+        if task_id is None:
+            task_id = uuid4().hex
+
+        def task_done_callback(_task: asyncio.Task[Any]) -> None:
+            self._tracked_tasks.pop(task_id, None)
             # log unhandled exceptions
             if (
                 LOGGER.isEnabledFor(logging.DEBUG)
@@ -365,9 +378,6 @@ class MusicAssistant:
                     exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None,
                 )
 
-        if task_id is None:
-            task_id = uuid4().hex
-        task.task_id = task_id
         self._tracked_tasks[task_id] = task
         task.add_done_callback(task_done_callback)
         return task
@@ -375,7 +385,7 @@ class MusicAssistant:
     def call_later(
         self,
         delay: float,
-        target: Coroutine | Awaitable | Callable,
+        target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R],
         *args: Any,
         task_id: str | None = None,
         **kwargs: Any,
@@ -399,7 +409,7 @@ class MusicAssistant:
         self._tracked_timers[task_id] = handle
         return handle
 
-    def get_task(self, task_id: str) -> asyncio.Task:
+    def get_task(self, task_id: str) -> asyncio.Task[Any]:
         """Get existing scheduled task."""
         if existing := self._tracked_tasks.get(task_id):
             # prevent duplicate tasks if task_id is given and already present
@@ -410,8 +420,8 @@ class MusicAssistant:
     def register_api_command(
         self,
         command: str,
-        handler: Callable,
-    ) -> None:
+        handler: Callable[..., Coroutine[Any, Any, Any]],
+    ) -> Callable[[], None]:
         """
         Dynamically register a command on the API.
 
@@ -516,12 +526,13 @@ class MusicAssistant:
             # make sure to stop any running sync tasks first
             for sync_task in self.music.in_progress_syncs:
                 if sync_task.provider_instance == instance_id:
-                    sync_task.task.cancel()
+                    if sync_task.task:
+                        sync_task.task.cancel()
             # check if there are no other providers dependent of this provider
             for dep_prov in self.providers:
                 if dep_prov.manifest.depends_on == provider.domain:
                     await self.unload_provider(dep_prov.instance_id)
-            if provider.type == ProviderType.PLAYER:
+            if is_player_provider(provider):
                 # mark all players of this provider as unavailable
                 for player in provider.players:
                     player.available = False
@@ -590,7 +601,7 @@ class MusicAssistant:
         prov_manifest = self._provider_manifests.get(domain)
         # check for other instances of this provider
         existing = next((x for x in self.providers if x.domain == domain), None)
-        if existing and not prov_manifest.multi_instance:
+        if existing and prov_manifest and not prov_manifest.multi_instance:
             msg = f"Provider {domain} already loaded and only one instance allowed."
             raise SetupFailedError(msg)
         # check valid manifest (just in case)
@@ -719,7 +730,7 @@ class MusicAssistant:
     ) -> None:
         """Handle MDNS service state callback."""
 
-        async def process_mdns_state_change(prov: ProviderInstanceType):
+        async def process_mdns_state_change(prov: ProviderInstanceType) -> None:
             if state_change == ServiceStateChange.Removed:
                 info = None
             else:
@@ -755,6 +766,7 @@ class MusicAssistant:
     ) -> bool | None:
         """Exit context manager."""
         await self.stop()
+        return None
 
     async def _update_available_providers_cache(self) -> None:
         """Update the global cache variable of loaded/available providers."""
@@ -770,12 +782,12 @@ class MusicAssistant:
                 "streaming_providers": {
                     x.lookup_key
                     for x in self.providers
-                    if x.type == ProviderType.MUSIC and x.is_streaming_provider
+                    if is_music_provider(x) and x.is_streaming_provider
                 },
                 "non_streaming_providers": {
                     x.lookup_key
                     for x in self.providers
-                    if not (x.type == ProviderType.MUSIC and x.is_streaming_provider)
+                    if not (is_music_provider(x) and x.is_streaming_provider)
                 },
             }
         )
index 0e4f3389705d282424d7b0876a0384f63c25634d..84a04928ad7352212af99da26c9ac315894b08bc 100644 (file)
@@ -122,7 +122,6 @@ exclude = [
   '^music_assistant/controllers/.*$',
   '^music_assistant/helpers/.*$',
   '^music_assistant/models/.*$',
-  '^music_assistant/mass\.py$',
   '^music_assistant/providers/_template_music_provider/.*$',
   '^music_assistant/providers/_template_player_provider/.*$',
   '^music_assistant/providers/apple_music/.*$',