Get mypy passing on music_assistant.client
authorJohn Carr <john.carr@unrouted.co.uk>
Wed, 26 Jun 2024 09:26:52 +0000 (10:26 +0100)
committerJc2k <john.carr@unrouted.co.uk>
Sat, 29 Jun 2024 12:21:39 +0000 (13:21 +0100)
music_assistant/client/client.py
music_assistant/client/config.py
music_assistant/client/connection.py
music_assistant/client/music.py
music_assistant/client/player_queues.py
music_assistant/client/players.py
music_assistant/common/models/api.py

index 6e8a84f7683bc1fdffa688dcb1008af734594df8..cc0ae4ce6a4e6d67d9cf7f11eb43e0fc05037944 100644 (file)
@@ -58,7 +58,7 @@ class MusicAssistantClient:
         self.server_url = server_url
         self.connection = WebsocketsConnection(server_url, aiohttp_session)
         self.logger = logging.getLogger(__package__)
-        self._result_futures: dict[str, asyncio.Future] = {}
+        self._result_futures: dict[str | int, asyncio.Future[Any]] = {}
         self._subscribers: list[EventSubscriptionType] = []
         self._stop_called: bool = False
         self._loop: asyncio.AbstractEventLoop | None = None
@@ -133,6 +133,7 @@ class MusicAssistantClient:
 
     def get_image_url(self, image: MediaItemImage, size: int = 0) -> str:
         """Get (proxied) URL for MediaItemImage."""
+        assert self.server_info
         if image.remotely_accessible and not size:
             return image.path
         if image.remotely_accessible and size:
@@ -167,9 +168,9 @@ class MusicAssistantClient:
     def subscribe(
         self,
         cb_func: EventCallBackType,
-        event_filter: EventType | tuple[EventType] | None = None,
-        id_filter: str | tuple[str] | None = None,
-    ) -> Callable:
+        event_filter: EventType | tuple[EventType, ...] | None = None,
+        id_filter: str | tuple[str, ...] | None = None,
+    ) -> Callable[[], None]:
         """Add callback to event listeners.
 
         Returns function to remove the listener.
@@ -360,6 +361,8 @@ class MusicAssistantClient:
         if self._stop_called:
             return
 
+        assert self._loop
+
         if event.event == EventType.PROVIDERS_UPDATED:
             self._providers = {x["instance_id"]: ProviderInstance.from_dict(x) for x in event.data}
 
@@ -386,6 +389,7 @@ class MusicAssistantClient:
     ) -> bool | None:
         """Exit context manager."""
         await self.disconnect()
+        return None
 
     def __repr__(self) -> str:
         """Return the representation."""
index f884a41c66e701197b8555a71d075a4ab971c8cd..5a901934952eecfe450f7836be273dc613eac442 100644 (file)
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 from music_assistant.common.models.config_entries import (
     ConfigEntry,
@@ -51,8 +51,11 @@ class Config:
 
     async def get_provider_config_value(self, instance_id: str, key: str) -> ConfigValueType:
         """Return single configentry value for a provider."""
-        return await self.client.send_command(
-            "config/providers/get_value", instance_id=instance_id, key=key
+        return cast(
+            ConfigValueType,
+            await self.client.send_command(
+                "config/providers/get_value", instance_id=instance_id, key=key
+            ),
         )
 
     async def get_provider_config_entries(
@@ -70,7 +73,7 @@ class Config:
         action: [optional] action key called from config entries UI.
         values: the (intermediate) raw values for config entries sent with the action.
         """
-        return (
+        return tuple(
             ConfigEntry.from_dict(x)
             for x in await self.client.send_command(
                 "config/providers/get_entries",
@@ -144,8 +147,11 @@ class Config:
         key: str,
     ) -> ConfigValueType:
         """Return single configentry value for a player."""
-        return await self.client.send_command(
-            "config/players/get_value", player_id=player_id, key=key
+        return cast(
+            ConfigValueType,
+            await self.client.send_command(
+                "config/players/get_value", player_id=player_id, key=key
+            ),
         )
 
     async def save_player_config(
@@ -185,7 +191,10 @@ class Config:
 
     async def get_core_config_value(self, domain: str, key: str) -> ConfigValueType:
         """Return single configentry value for a core controller."""
-        return await self.client.send_command("config/core/get_value", domain=domain, key=key)
+        return cast(
+            ConfigValueType,
+            await self.client.send_command("config/core/get_value", domain=domain, key=key),
+        )
 
     async def get_core_config_entries(
         self,
@@ -200,7 +209,7 @@ class Config:
         action: [optional] action key called from config entries UI.
         values: the (intermediate) raw values for config entries sent with the action.
         """
-        return (
+        return tuple(
             ConfigEntry.from_dict(x)
             for x in await self.client.send_command(
                 "config/core/get_entries",
index f3afb6aae14fc80267d455a592b1609dc9c872a2..9d7597a247e9715ee8da619f4793a1d4e204cf98 100644 (file)
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import logging
 import pprint
-from typing import Any
+from typing import Any, cast
 
 from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType, client_exceptions
 
@@ -39,7 +39,7 @@ class WebsocketsConnection:
         """Initialize."""
         self.ws_server_url = get_websocket_url(server_url)
         self._aiohttp_session_provided = aiohttp_session is not None
-        self._aiohttp_session = aiohttp_session or ClientSession()
+        self._aiohttp_session: ClientSession | None = aiohttp_session or ClientSession()
         self._ws_client: ClientWebSocketResponse | None = None
 
     @property
@@ -87,24 +87,20 @@ class WebsocketsConnection:
         ws_msg = await self._ws_client.receive()
 
         if ws_msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
-            msg = "Connection was closed."
-            raise ConnectionClosed(msg)
+            raise ConnectionClosed("Connection was closed.")
 
         if ws_msg.type == WSMsgType.ERROR:
             raise ConnectionFailed
 
         if ws_msg.type != WSMsgType.TEXT:
-            msg = f"Received non-Text message: {ws_msg.type}"
-            raise InvalidMessage(msg)
+            raise InvalidMessage(f"Received non-Text message: {ws_msg.type}")
 
         try:
-            msg = json_loads(ws_msg.data)
+            msg = cast(dict[str, Any], json_loads(ws_msg.data))
         except TypeError as err:
-            msg = f"Received unsupported JSON: {err}"
-            raise InvalidMessage(msg) from err
+            raise InvalidMessage(f"Received unsupported JSON: {err}") from err
         except ValueError as err:
-            msg = "Received invalid JSON."
-            raise InvalidMessage(msg) from err
+            raise InvalidMessage("Received invalid JSON.") from err
 
         if LOGGER.isEnabledFor(logging.DEBUG):
             LOGGER.debug("Received message:\n%s\n", pprint.pformat(ws_msg))
index 886e030a90b66a0a306e352d4441e2d41a28c165..628290bb264d734fc3c856a3a95d3c3f9e6e4e3c 100644 (file)
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import urllib.parse
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 from music_assistant.common.models.enums import ImageType, MediaType
 from music_assistant.common.models.media_items import (
@@ -22,6 +22,7 @@ from music_assistant.common.models.media_items import (
     media_from_dict,
 )
 from music_assistant.common.models.provider import SyncTask
+from music_assistant.common.models.queue_item import QueueItem
 
 if TYPE_CHECKING:
     from .client import MusicAssistantClient
@@ -111,6 +112,7 @@ class Music:
         provider_instance_id_or_domain: str,
     ) -> str:
         """Get URL to preview clip of given track."""
+        assert self.client.server_info
         encoded_url = urllib.parse.quote(urllib.parse.quote(item_id))
         return f"{self.client.server_info.base_url}/preview?path={encoded_url}&provider={provider_instance_id_or_domain}"  # noqa: E501
 
@@ -230,7 +232,7 @@ class Music:
     ) -> list[Track]:
         """Get (top)tracks for given artist."""
         return [
-            Artist.from_dict(item)
+            Track.from_dict(item)
             for item in await self.client.send_command(
                 "music/artists/artist_tracks",
                 item_id=item_id,
@@ -525,7 +527,9 @@ class Music:
 
     async def add_item_to_library(self, item: str | MediaItemType) -> MediaItemType:
         """Add item (uri or mediaitem) to the library."""
-        return await self.client.send_command("music/library/add_item", item=item)
+        return cast(
+            MediaItemType, await self.client.send_command("music/library/add_item", item=item)
+        )
 
     async def refresh_item(
         self,
@@ -540,7 +544,7 @@ class Music:
 
     def get_media_item_image(
         self,
-        item: MediaItemType | ItemMapping,
+        item: MediaItemType | ItemMapping | QueueItem,
         type: ImageType = ImageType.THUMB,  # noqa: A002
     ) -> MediaItemImage | None:
         """Get MediaItemImage for MediaItem, ItemMapping."""
@@ -556,11 +560,11 @@ class Music:
             if album_image := self.get_media_item_image(album, type):
                 return album_image
         # handle regular image within mediaitem
-        metadata: MediaItemMetadata
+        metadata: MediaItemMetadata | None
         if metadata := getattr(item, "metadata", None):
             for img in metadata.images or []:
                 if img.type == type:
-                    return img
+                    return cast(MediaItemImage, img)
         # retry with album/track artist(s)
         artists: list[Artist | ItemMapping] | None
         if artists := getattr(item, "artists", None):
index 940e75eae86ef554b0c230d0b4edd2262a51b1dc..2a2d4d94b788e1b15a05e4188345b51facbd1fbc 100644 (file)
@@ -161,7 +161,7 @@ class PlayerQueues:
         """
         await self.client.send_command("player_queues/skip", queue_id=queue_id, seconds=seconds)
 
-    async def queue_command_shuffle(self, queue_id: str, shuffle_enabled=bool) -> None:
+    async def queue_command_shuffle(self, queue_id: str, shuffle_enabledbool) -> None:
         """Configure shuffle mode on the the queue."""
         await self.client.send_command(
             "player_queues/shuffle", queue_id=queue_id, shuffle_enabled=shuffle_enabled
@@ -231,4 +231,6 @@ class PlayerQueues:
     def _handle_event(self, event: MassEvent) -> None:
         """Handle incoming player(queue) event."""
         if event.event in (EventType.QUEUE_ADDED, EventType.QUEUE_UPDATED):
+            # Queue events always have an object_id
+            assert event.object_id
             self._queues[event.object_id] = PlayerQueue.from_dict(event.data)
index 8c0c3a6a0c1861007e2b935f34a15c7b2c8157af..1378ce5136f95285eef1a3ea364193b21d90c78a 100644 (file)
@@ -130,7 +130,7 @@ class Players:
 
     async def cmd_unsync_many(self, player_ids: list[str]) -> None:
         """Create temporary sync group by joining given players to target player."""
-        await self.client.send_command("players/cmd/unsync_many", player_ids)
+        await self.client.send_command("players/cmd/unsync_many", player_ids=player_ids)
 
     async def play_announcement(
         self,
@@ -208,7 +208,11 @@ class Players:
     def _handle_event(self, event: MassEvent) -> None:
         """Handle incoming player event."""
         if event.event in (EventType.PLAYER_ADDED, EventType.PLAYER_UPDATED):
+            # Player events always have an object id
+            assert event.object_id
             self._players[event.object_id] = Player.from_dict(event.data)
             return
         if event.event == EventType.PLAYER_REMOVED:
+            # Player events always have an object id
+            assert event.object_id
             self._players.pop(event.object_id, None)
index 94f6cdce017a261436dd76cc86427b8619589718..d2456aaff390518e50aa61c8ba6865b65f9ca579 100644 (file)
@@ -40,7 +40,7 @@ class SuccessResultMessage(ResultMessageBase):
 class ErrorResultMessage(ResultMessageBase):
     """Message sent when a command did not execute successfully."""
 
-    error_code: str
+    error_code: int
     details: str | None = None