Mypy: Add music_assistant.common (#1428)
authorJc2k <john.carr@unrouted.co.uk>
Tue, 2 Jul 2024 20:50:19 +0000 (21:50 +0100)
committerGitHub <noreply@github.com>
Tue, 2 Jul 2024 20:50:19 +0000 (22:50 +0200)
17 files changed:
music_assistant/client/music.py
music_assistant/common/helpers/datetime.py
music_assistant/common/helpers/global_cache.py
music_assistant/common/helpers/json.py
music_assistant/common/helpers/uri.py
music_assistant/common/helpers/util.py
music_assistant/common/models/api.py
music_assistant/common/models/config_entries.py
music_assistant/common/models/enums.py
music_assistant/common/models/errors.py
music_assistant/common/models/media_items.py
music_assistant/common/models/player.py
music_assistant/common/models/player_queue.py
music_assistant/common/models/provider.py
music_assistant/common/models/queue_item.py
music_assistant/server/providers/plex/__init__.py
mypy.ini

index 628290bb264d734fc3c856a3a95d3c3f9e6e4e3c..858352ddd7893f40d19475341f3cba7d94ffc442 100644 (file)
@@ -440,7 +440,7 @@ class Music:
         path: str | None = None,
         limit: int | None = None,
         offset: int | None = None,
-    ) -> list[MediaItemType]:
+    ) -> list[MediaItemType | ItemMapping]:
         """Browse Music providers."""
         return [
             media_from_dict(obj)
@@ -454,7 +454,7 @@ class Music:
 
     async def recently_played(
         self, limit: int = 10, media_types: list[MediaType] | None = None
-    ) -> list[MediaItemType]:
+    ) -> list[MediaItemType | ItemMapping]:
         """Return a list of the last played items."""
         return [
             media_from_dict(item)
@@ -466,7 +466,7 @@ class Music:
     async def get_item_by_uri(
         self,
         uri: str,
-    ) -> MediaItemType:
+    ) -> MediaItemType | ItemMapping:
         """Get single music item providing a mediaitem uri."""
         return media_from_dict(await self.client.send_command("music/item_by_uri", uri=uri))
 
@@ -478,7 +478,7 @@ class Music:
         force_refresh: bool = False,
         lazy: bool = True,
         add_to_library: bool = False,
-    ) -> MediaItemType:
+    ) -> MediaItemType | ItemMapping:
         """Get single music item by id and media type."""
         return media_from_dict(
             await self.client.send_command(
@@ -534,7 +534,7 @@ class Music:
     async def refresh_item(
         self,
         media_item: MediaItemType,
-    ) -> MediaItemType | None:
+    ) -> MediaItemType | ItemMapping | None:
         """Try to refresh a mediaitem by requesting it's full object or search for substitutes."""
         if result := await self.client.send_command("music/refresh_item", media_item=media_item):
             return media_from_dict(result)
index 80fa279f7e30e4c57d5f0cad45a639c0cb9fd613..af8f4b24ce4bfb836f6fea545a2e6e2a135967cd 100644 (file)
@@ -27,7 +27,7 @@ def now_timestamp() -> float:
     return now().timestamp()
 
 
-def future_timestamp(**kwargs) -> float:
+def future_timestamp(**kwargs: float) -> float:
     """Return current timestamp + timedelta."""
     return (now() + datetime.timedelta(**kwargs)).timestamp()
 
index a33f50a9bce71d460dbd02a4ad6dc0100e474912..6cd741dd91ce217fe86f914122a83a24b1585729 100644 (file)
@@ -8,7 +8,7 @@ from typing import Any
 # global cache - we use this on a few places (as limited as possible)
 # where we have no other options
 _global_cache_lock = asyncio.Lock()
-_global_cache = {}
+_global_cache: dict[str, Any] = {}
 
 
 def get_global_cache_value(key: str, default: Any = None) -> Any:
index 845fbecedcef1e4e37794ad01638b96385c13f7e..8fb679e1191076291ca486818f088bff99f5554b 100644 (file)
@@ -4,10 +4,11 @@ import asyncio
 import base64
 from _collections_abc import dict_keys, dict_values
 from types import MethodType
-from typing import Any
+from typing import Any, TypeVar
 
 import aiofiles
 import orjson
+from mashumaro.mixins.orjson import DataClassORJSONMixin
 
 JSON_ENCODE_EXCEPTIONS = (TypeError, ValueError)
 JSON_DECODE_EXCEPTIONS = (orjson.JSONDecodeError,)
@@ -59,12 +60,11 @@ def json_dumps(data: Any, indent: bool = False) -> str:
 
 json_loads = orjson.loads
 
+TargetT = TypeVar("TargetT", bound=DataClassORJSONMixin)
 
-async def load_json_file(path: str, target_class: type | None = None) -> dict:
+
+async def load_json_file(path: str, target_class: type[TargetT]) -> TargetT:
     """Load JSON from file."""
     async with aiofiles.open(path, "r") as _file:
         content = await _file.read()
-        if target_class:
-            # support for a mashumaro model
-            return target_class.from_json(content)
-        return json_loads(content)
+        return target_class.from_json(content)
index 93ed85cd88d03ed1115c2d752f04384e743c14c5..a381861dcc3a8d4c94f956e46c4bd92b70129778 100644 (file)
@@ -10,7 +10,7 @@ from music_assistant.common.models.errors import InvalidProviderID, InvalidProvi
 base62_length22_id_pattern = re.compile(r"^[a-zA-Z0-9]{22}$")
 
 
-def valid_base62_length22(item_id) -> bool:
+def valid_base62_length22(item_id: str) -> bool:
     """Validate Spotify style ID."""
     return bool(base62_length22_id_pattern.match(item_id))
 
index bbf28c1845cdd3585064ad71d33b29f53f4444cb..62b5c433601f610849c581b8bbe942884d66e7b9 100644 (file)
@@ -7,14 +7,13 @@ import os
 import re
 import socket
 from collections.abc import Callable
+from collections.abc import Set as AbstractSet
 from typing import Any, TypeVar
 from urllib.parse import urlparse
 from uuid import UUID
 
 # pylint: disable=invalid-name
 T = TypeVar("T")
-_UNDEF: dict = {}
-CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
 CALLBACK_TYPE = Callable[[], None]
 # pylint: enable=invalid-name
 
@@ -50,7 +49,7 @@ def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float |
         return default
 
 
-def try_parse_bool(possible_bool: Any) -> str:
+def try_parse_bool(possible_bool: Any) -> bool:
     """Try to parse a bool."""
     if isinstance(possible_bool, bool):
         return possible_bool
@@ -79,7 +78,7 @@ def create_sort_name(input_str: str) -> str:
     return input_str.strip()
 
 
-def parse_title_and_version(title: str, track_version: str | None = None):
+def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
     """Try to parse clean track title and version from the title."""
     version = ""
     for splitter in [" (", " [", " - ", " (", " [", "-"]:
@@ -135,7 +134,7 @@ def clean_title(title: str) -> str:
     return title.strip()
 
 
-def get_version_substitute(version_str: str):
+def get_version_substitute(version_str: str) -> str:
     """Transform provider version str to universal version type."""
     version_str = version_str.lower()
     # substitute edit and edition with version
@@ -169,7 +168,7 @@ def strip_url(line: str) -> str:
     ).rstrip()
 
 
-def strip_dotcom(line: str):
+def strip_dotcom(line: str) -> str:
     """Strip scheme-less netloc from line."""
     return dot_com_pattern.sub("", line)
 
@@ -227,17 +226,17 @@ def clean_stream_title(line: str) -> str:
     return line
 
 
-async def get_ip():
+async def get_ip() -> str:
     """Get primary IP-address for this host."""
 
-    def _get_ip():
+    def _get_ip() -> str:
         """Get primary IP-address for this host."""
         # pylint: disable=broad-except,no-member
         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         try:
             # doesn't even have to be reachable
             sock.connect(("10.255.255.255", 1))
-            _ip = sock.getsockname()[0]
+            _ip = str(sock.getsockname()[0])
         except Exception:
             _ip = "127.0.0.1"
         finally:
@@ -273,7 +272,7 @@ async def select_free_port(range_start: int, range_end: int) -> int:
 async def get_ip_from_host(dns_name: str) -> str | None:
     """Resolve (first) IP-address for given dns name."""
 
-    def _resolve():
+    def _resolve() -> str | None:
         try:
             return socket.gethostbyname(dns_name)
         except Exception:  # pylint: disable=broad-except
@@ -283,7 +282,7 @@ async def get_ip_from_host(dns_name: str) -> str | None:
     return await asyncio.to_thread(_resolve)
 
 
-async def get_ip_pton(ip_string: str | None = None):
+async def get_ip_pton(ip_string: str | None = None) -> bytes:
     """Return socket pton for local ip."""
     if ip_string is None:
         ip_string = await get_ip()
@@ -294,7 +293,7 @@ async def get_ip_pton(ip_string: str | None = None):
         return await asyncio.to_thread(socket.inet_pton, socket.AF_INET6, ip_string)
 
 
-def get_folder_size(folderpath):
+def get_folder_size(folderpath: str) -> float:
     """Return folder size in gb."""
     total_size = 0
     # pylint: disable=unused-variable
@@ -306,7 +305,9 @@ def get_folder_size(folderpath):
     return total_size / float(1 << 30)
 
 
-def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False):
+def merge_dict(
+    base_dict: dict[Any, Any], new_dict: dict[Any, Any], allow_overwite: bool = False
+) -> dict[Any, Any]:
     """Merge dict without overwriting existing values."""
     final_dict = base_dict.copy()
     for key, value in new_dict.items():
@@ -321,12 +322,12 @@ def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False):
     return final_dict
 
 
-def merge_tuples(base: tuple, new: tuple) -> tuple:
+def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
     """Merge 2 tuples."""
     return tuple(x for x in base if x not in new) + tuple(new)
 
 
-def merge_lists(base: list, new: list) -> list:
+def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
     """Merge 2 lists."""
     return [x for x in base if x not in new] + list(new)
 
@@ -335,7 +336,7 @@ def get_changed_keys(
     dict1: dict[str, Any],
     dict2: dict[str, Any],
     ignore_keys: list[str] | None = None,
-) -> set[str]:
+) -> AbstractSet[str]:
     """Compare 2 dicts and return set of changed keys."""
     return get_changed_values(dict1, dict2, ignore_keys).keys()
 
@@ -369,7 +370,7 @@ def get_changed_values(
     return changed_values
 
 
-def empty_queue(q: asyncio.Queue) -> None:
+def empty_queue(q: asyncio.Queue[T]) -> None:
     """Empty an asyncio Queue."""
     for _ in range(q.qsize()):
         try:
@@ -386,10 +387,3 @@ def is_valid_uuid(uuid_to_test: str) -> bool:
     except ValueError:
         return False
     return str(uuid_obj) == uuid_to_test
-
-
-class classproperty(property):  # noqa: N801
-    """Implement class property for python3.11+."""
-
-    def __get__(self, cls, owner):  # noqa: D105
-        return classmethod(self.fget).__get__(None, owner)()
index d2456aaff390518e50aa61c8ba6865b65f9ca579..dfbfd66315c7c3377bd54b8a7c1cbc2a88bcf012 100644 (file)
@@ -65,7 +65,7 @@ MessageType = (
 )
 
 
-def parse_message(raw: dict) -> MessageType:
+def parse_message(raw: dict[Any, Any]) -> MessageType:
     """Parse Message from raw dict object."""
     if "event" in raw:
         return EventMessage.from_dict(raw)
index b2f00a4c8c326f98986db4ddb24931d172f0b762..27ccf59d16b67e260f859766747409a4c5831688 100644 (file)
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import logging
 import warnings
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from enum import Enum
 from types import NoneType
@@ -46,8 +46,8 @@ warnings.filterwarnings("ignore", category=UserWarning, module="mashumaro")
 
 LOGGER = logging.getLogger(__name__)
 
-ENCRYPT_CALLBACK: callable[[str], str] | None = None
-DECRYPT_CALLBACK: callable[[str], str] | None = None
+ENCRYPT_CALLBACK: Callable[[str], str] | None = None
+DECRYPT_CALLBACK: Callable[[str], str] | None = None
 
 ConfigValueType = (
     str
@@ -58,15 +58,16 @@ ConfigValueType = (
     | list[str]
     | list[int]
     | list[tuple[int, int]]
+    | Enum
     | None
 )
 
-ConfigEntryTypeMap = {
+ConfigEntryTypeMap: dict[ConfigEntryType, type[ConfigValueType]] = {
     ConfigEntryType.BOOLEAN: bool,
     ConfigEntryType.STRING: str,
     ConfigEntryType.SECURE_STRING: str,
     ConfigEntryType.INTEGER: int,
-    ConfigEntryType.INTEGER_TUPLE: tuple[int, int],
+    ConfigEntryType.INTEGER_TUPLE: tuple[int, int],  # type: ignore[dict-item]
     ConfigEntryType.FLOAT: float,
     ConfigEntryType.LABEL: str,
     ConfigEntryType.DIVIDER: str,
@@ -187,6 +188,7 @@ class Config(DataClassDictMixin):
         """Return config value for given key."""
         config_value = self.values[key]
         if config_value.type == ConfigEntryType.SECURE_STRING:
+            assert isinstance(config_value.value, str)
             assert DECRYPT_CALLBACK is not None
             return DECRYPT_CALLBACK(config_value.value)
         return config_value.value
@@ -213,8 +215,9 @@ class Config(DataClassDictMixin):
     def to_raw(self) -> dict[str, Any]:
         """Return minimized/raw dict to store in persistent storage."""
 
-        def _handle_value(value: ConfigEntry):
+        def _handle_value(value: ConfigEntry) -> ConfigValueType:
             if value.type == ConfigEntryType.SECURE_STRING:
+                assert isinstance(value.value, str)
                 assert ENCRYPT_CALLBACK is not None
                 return ENCRYPT_CALLBACK(value.value)
             return value.value
@@ -253,9 +256,7 @@ class Config(DataClassDictMixin):
             setattr(self, key, new_val)
             changed_keys.add(key)
 
-        # config entry values
-        values = update.get("values", update)
-        for key, new_val in values.items():
+        for key, new_val in update.items():
             if key in root_values:
                 continue
             cur_val = self.values[key].value if key in self.values else None
@@ -366,12 +367,12 @@ CONF_ENTRY_AUTO_PLAY = ConfigEntry(
 CONF_ENTRY_OUTPUT_CHANNELS = ConfigEntry(
     key=CONF_OUTPUT_CHANNELS,
     type=ConfigEntryType.STRING,
-    options=[
+    options=(
         ConfigValueOption("Stereo (both channels)", "stereo"),
         ConfigValueOption("Left channel", "left"),
         ConfigValueOption("Right channel", "right"),
         ConfigValueOption("Mono (both channels)", "mono"),
-    ],
+    ),
     default_value="stereo",
     label="Output Channel Mode",
     category="audio",
@@ -503,12 +504,12 @@ CONF_ENTRY_TTS_PRE_ANNOUNCE = ConfigEntry(
 CONF_ENTRY_ANNOUNCE_VOLUME_STRATEGY = ConfigEntry(
     key=CONF_ANNOUNCE_VOLUME_STRATEGY,
     type=ConfigEntryType.STRING,
-    options=[
+    options=(
         ConfigValueOption("Absolute volume", "absolute"),
         ConfigValueOption("Relative volume increase", "relative"),
         ConfigValueOption("Volume increase by fixed percentage", "percentual"),
         ConfigValueOption("Do not adjust volume", "none"),
-    ],
+    ),
     default_value="percentual",
     label="Volume strategy for Announcements",
     category="announcements",
@@ -557,7 +558,7 @@ CONF_ENTRY_PLAYER_ICON_GROUP = ConfigEntry.from_dict(
 CONF_ENTRY_SAMPLE_RATES = ConfigEntry(
     key=CONF_SAMPLE_RATES,
     type=ConfigEntryType.INTEGER_TUPLE,
-    options=[
+    options=(
         ConfigValueOption("44.1kHz / 16 bits", (44100, 16)),
         ConfigValueOption("44.1kHz / 24 bits", (44100, 24)),
         ConfigValueOption("48kHz / 16 bits", (48000, 16)),
@@ -574,7 +575,7 @@ CONF_ENTRY_SAMPLE_RATES = ConfigEntry(
         ConfigValueOption("352.8kHz / 24 bits", (352800, 24)),
         ConfigValueOption("384kHz / 16 bits", (384000, 16)),
         ConfigValueOption("384kHz / 24 bits", (384000, 24)),
-    ],
+    ),
     default_value=[(44100, 16), (48000, 16)],
     required=True,
     multi_value=True,
@@ -593,14 +594,19 @@ def create_sample_rates_config_entry(
     hidden: bool = False,
 ) -> ConfigEntry:
     """Create sample rates config entry based on player specific helpers."""
+    assert CONF_ENTRY_SAMPLE_RATES.options
     conf_entry = ConfigEntry.from_dict(CONF_ENTRY_SAMPLE_RATES.to_dict())
-    conf_entry.options = []
-    conf_entry.default_value = []
     conf_entry.hidden = hidden
+    options: list[ConfigValueOption] = []
+    default_value: list[tuple[int, int]] = []
     for option in CONF_ENTRY_SAMPLE_RATES.options:
+        if not isinstance(option.value, tuple):
+            continue
         sample_rate, bit_depth = option.value
         if sample_rate <= max_sample_rate and bit_depth <= max_bit_depth:
-            conf_entry.options.append(option)
+            options.append(option)
         if sample_rate <= safe_max_sample_rate and bit_depth <= safe_max_bit_depth:
-            conf_entry.default_value.append(option.value)
+            default_value.append(option.value)
+    conf_entry.options = tuple(options)
+    conf_entry.default_value = default_value
     return conf_entry
index df4c5707b55275338cf60f871723c78cf76b8bc3..bd018efa748c57bb3618338143203542020a44fc 100644 (file)
@@ -3,13 +3,25 @@
 from __future__ import annotations
 
 import contextlib
-from enum import StrEnum
-from typing import Self
+from enum import EnumType, StrEnum
 
-from music_assistant.common.helpers.util import classproperty
 
+class MediaTypeMeta(EnumType):
+    """Class properties for MediaType."""
 
-class MediaType(StrEnum):
+    @property
+    def ALL(self) -> list[MediaType]:  # noqa: N802
+        """All MediaTypes."""
+        return [
+            MediaType.ARTIST,
+            MediaType.ALBUM,
+            MediaType.TRACK,
+            MediaType.PLAYLIST,
+            MediaType.RADIO,
+        ]
+
+
+class MediaType(StrEnum, metaclass=MediaTypeMeta):
     """Enum for MediaType."""
 
     ARTIST = "artist"
@@ -23,21 +35,10 @@ class MediaType(StrEnum):
     UNKNOWN = "unknown"
 
     @classmethod
-    def _missing_(cls: Self, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> MediaType:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.UNKNOWN
 
-    @classproperty
-    def ALL(self) -> tuple[MediaType, ...]:  # noqa: N802
-        """Return all (default) MediaTypes as tuple."""
-        return (
-            MediaType.ARTIST,
-            MediaType.ALBUM,
-            MediaType.TRACK,
-            MediaType.PLAYLIST,
-            MediaType.RADIO,
-        )
-
 
 class ExternalID(StrEnum):
     """Enum with External ID types."""
@@ -56,7 +57,7 @@ class ExternalID(StrEnum):
     UNKNOWN = "unknown"
 
     @classmethod
-    def _missing_(cls: Self, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> ExternalID:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.UNKNOWN
 
@@ -78,7 +79,7 @@ class LinkType(StrEnum):
     UNKNOWN = "unknown"
 
     @classmethod
-    def _missing_(cls: Self, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> LinkType:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.UNKNOWN
 
@@ -98,7 +99,7 @@ class ImageType(StrEnum):
     OTHER = "other"
 
     @classmethod
-    def _missing_(cls: Self, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> ImageType:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.OTHER
 
@@ -142,12 +143,12 @@ class ContentType(StrEnum):
     UNKNOWN = "?"
 
     @classmethod
-    def _missing_(cls, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> ContentType:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.UNKNOWN
 
     @classmethod
-    def try_parse(cls, string: str) -> Self:
+    def try_parse(cls, string: str) -> ContentType:
         """Try to parse ContentType from (url)string/extension."""
         tempstr = string.lower()
         if "audio/" in tempstr:
@@ -247,7 +248,7 @@ class PlayerType(StrEnum):
     UNKNOWN = "unknown"
 
     @classmethod
-    def _missing_(cls: Self, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> PlayerType:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.UNKNOWN
 
@@ -275,7 +276,7 @@ class PlayerFeature(StrEnum):
     UNKNOWN = "unknown"
 
     @classmethod
-    def _missing_(cls: Self, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> PlayerFeature:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.UNKNOWN
 
@@ -302,7 +303,7 @@ class EventType(StrEnum):
     UNKNOWN = "unknown"
 
     @classmethod
-    def _missing_(cls: Self, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> EventType:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.UNKNOWN
 
@@ -364,7 +365,7 @@ class ProviderFeature(StrEnum):
     UNKNOWN = "unknown"
 
     @classmethod
-    def _missing_(cls: Self, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> ProviderFeature:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.UNKNOWN
 
@@ -396,7 +397,7 @@ class ConfigEntryType(StrEnum):
     UNKNOWN = "unknown"
 
     @classmethod
-    def _missing_(cls: Self, value: object) -> Self:  # noqa: ARG003
+    def _missing_(cls, value: object) -> ConfigEntryType:  # noqa: ARG003
         """Set default enum member if an unknown value is provided."""
         return cls.UNKNOWN
 
index fb0fc9ae24dec50698264d15c96bebbe1f8efc21..d1add798323350154f2bf86e799e5d54a396e380 100644 (file)
@@ -115,7 +115,7 @@ class RetriesExhausted(MusicAssistantError):
 class ResourceTemporarilyUnavailable(MusicAssistantError):
     """Error thrown when a resource is temporarily unavailable."""
 
-    def __init__(self, *args, backoff_time: int = 0) -> None:
+    def __init__(self, *args: object, backoff_time: int = 0) -> None:
         """Initialize."""
         super().__init__(*args)
         self.backoff_time = backoff_time
index 0549844d9320b13d03431808470fcbef8252e7c6..d271b111956c026d9433c68eda6ab63acd277c47 100644 (file)
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 from collections.abc import Iterable
 from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, Any, Self, TypeVar, cast
+from typing import TYPE_CHECKING, Any, TypeGuard, TypeVar, cast
 
 from mashumaro import DataClassDictMixin
 
@@ -34,7 +34,7 @@ class UniqueList(list[_T]):
         if not iterable:
             super().__init__()
             return
-        seen = set()
+        seen: set[_T] = set()
         seen_add = seen.add
         super().__init__(x for x in iterable if not (x in seen or seen_add(x)))
 
@@ -61,7 +61,7 @@ class AudioFormat(DataClassDictMixin):
     output_format_str: str = ""
     bit_rate: int = 320  # optional
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Execute actions after init."""
         if not self.output_format_str and self.content_type.is_pcm():
             self.output_format_str = (
@@ -90,9 +90,9 @@ class AudioFormat(DataClassDictMixin):
         """Return the PCM sample size."""
         return int(self.sample_rate * (self.bit_depth / 8) * self.channels)
 
-    def __eq__(self, other: AudioFormat) -> bool:
+    def __eq__(self, other: object) -> bool:
         """Check equality of two items."""
-        if not other:
+        if not isinstance(other, AudioFormat):
             return False
         return self.output_format_str == other.output_format_str
 
@@ -121,7 +121,7 @@ class ProviderMapping(DataClassDictMixin):
             quality += 1
         return quality
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Call after init."""
         # having items for unavailable providers can have all sorts
         # of unpredictable results so ensure we have accurate availability status
@@ -138,9 +138,9 @@ class ProviderMapping(DataClassDictMixin):
         """Return custom hash."""
         return hash((self.provider_instance, self.item_id))
 
-    def __eq__(self, other: ProviderMapping) -> bool:
+    def __eq__(self, other: object) -> bool:
         """Check equality of two items."""
-        if not other:
+        if not isinstance(other, ProviderMapping):
             return False
         return self.provider_instance == other.provider_instance and self.item_id == other.item_id
 
@@ -156,8 +156,10 @@ class MediaItemLink(DataClassDictMixin):
         """Return custom hash."""
         return hash(self.type)
 
-    def __eq__(self, other: MediaItemLink) -> bool:
+    def __eq__(self, other: object) -> bool:
         """Check equality of two items."""
+        if not isinstance(other, MediaItemLink):
+            return False
         return self.url == other.url
 
 
@@ -174,8 +176,10 @@ class MediaItemImage(DataClassDictMixin):
         """Return custom hash."""
         return hash((self.type.value, self.path))
 
-    def __eq__(self, other: MediaItemImage) -> bool:
+    def __eq__(self, other: object) -> bool:
         """Check equality of two items."""
+        if not isinstance(other, MediaItemImage):
+            return False
         return self.__hash__() == other.__hash__()
 
     @classmethod
@@ -202,8 +206,10 @@ class MediaItemChapter(DataClassDictMixin):
         """Return custom hash."""
         return hash(self.chapter_id)
 
-    def __eq__(self, other: MediaItemChapter) -> bool:
+    def __eq__(self, other: object) -> bool:
         """Check equality of two items."""
+        if not isinstance(other, MediaItemChapter):
+            return False
         return self.chapter_id == other.chapter_id
 
 
@@ -248,8 +254,7 @@ class MediaItemMetadata(DataClassDictMixin):
                 new_val = merge_lists(cur_val, new_val)
                 setattr(self, fld.name, new_val)
             elif isinstance(cur_val, set) and isinstance(new_val, set | list | tuple):
-                new_val = cur_val.update(new_val)
-                setattr(self, fld.name, new_val)
+                cur_val.update(new_val)
             elif new_val and fld.name in (
                 "popularity",
                 "last_refresh",
@@ -277,12 +282,8 @@ class _MediaItemBase(DataClassDictMixin):
     external_ids: set[tuple[ExternalID, str]] = field(default_factory=set)
     media_type: MediaType = MediaType.UNKNOWN
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Call after init."""
-        if self.name is None:
-            # we've got some reports where the name was empty, causing weird issues.
-            # e.g. here: https://github.com/music-assistant/hass-music-assistant/issues/1515
-            self.name = "[Unknown]"
         if self.uri is None:
             self.uri = create_uri(self.media_type, self.provider, self.item_id)
         if self.sort_name is None:
@@ -318,8 +319,10 @@ class _MediaItemBase(DataClassDictMixin):
         """Return custom hash."""
         return hash(self.uri)
 
-    def __eq__(self, other: MediaItem | ItemMapping) -> bool:
+    def __eq__(self, other: object) -> bool:
         """Check equality of two items."""
+        if not isinstance(other, MediaItem | ItemMapping):
+            return False
         return self.uri == other.uri
 
 
@@ -327,7 +330,6 @@ class _MediaItemBase(DataClassDictMixin):
 class MediaItem(_MediaItemBase):
     """Base representation of a media item."""
 
-    __hash__ = _MediaItemBase.__hash__
     __eq__ = _MediaItemBase.__eq__
 
     provider_mappings: set[ProviderMapping]
@@ -336,8 +338,12 @@ class MediaItem(_MediaItemBase):
     favorite: bool = False
     position: int | None = None  # required for playlist tracks, optional for all other
 
+    def __hash__(self) -> int:
+        """Return hash of MediaItem."""
+        return super().__hash__()
+
     @property
-    def available(self):
+    def available(self) -> bool:
         """Return (calculated) availability."""
         return any(x.available for x in self.provider_mappings)
 
@@ -360,7 +366,7 @@ class ItemMapping(_MediaItemBase):
     image: MediaItemImage | None = None
 
     @classmethod
-    def from_item(cls, item: MediaItem) -> ItemMapping:
+    def from_item(cls, item: MediaItem | ItemMapping) -> ItemMapping:
         """Create ItemMapping object from regular item."""
         if isinstance(item, ItemMapping):
             return item
@@ -409,7 +415,6 @@ class Album(MediaItem):
 class Track(MediaItem):
     """Model for a track."""
 
-    __hash__ = _MediaItemBase.__hash__
     __eq__ = _MediaItemBase.__eq__
 
     media_type: MediaType = MediaType.TRACK
@@ -420,7 +425,7 @@ class Track(MediaItem):
     disc_number: int | None = None  # required for album tracks
     track_number: int | None = None  # required for album tracks
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         """Return custom hash."""
         return hash((self.provider, self.item_id))
 
@@ -432,7 +437,11 @@ class Track(MediaItem):
         This is often an indicator that this track is an episode from a
         Podcast or AudioBook.
         """
-        return self.metadata and self.metadata.chapters and len(self.metadata.chapters) > 1
+        if not self.metadata:
+            return False
+        if not self.metadata.chapters:
+            return False
+        return len(self.metadata.chapters) > 1
 
     @property
     def image(self) -> MediaItemImage | None:
@@ -465,28 +474,22 @@ class AlbumTrack(Track):
 
     @classmethod
     def from_track(
-        cls: type,
+        cls,
         track: Track,
         album: Album | None = None,
         disc_number: int | None = None,
         track_number: int | None = None,
-    ) -> Self:
+    ) -> AlbumTrack:
         """Cast Track to AlbumTrack."""
-        if album is None:
-            album = track.album
+        album_track = track.to_dict()
+        if album is None and track.album:
+            album_track["album"] = track.album
         if disc_number is None:
-            disc_number = track.disc_number
+            album_track["disc_number"] = track.disc_number
         if track_number is None:
-            track_number = track.track_number
+            album_track["track_number"] = track.track_number
         # let mushmumaro instantiate a new object - this will ensure that valididation takes place
-        return AlbumTrack.from_dict(
-            {
-                **track.to_dict(),
-                "album": album.to_dict(),
-                "disc_number": disc_number,
-                "track_number": track_number,
-            }
-        )
+        return AlbumTrack.from_dict(album_track)
 
 
 @dataclass(kw_only=True)
@@ -540,7 +543,7 @@ class BrowseFolder(MediaItem):
     label: str = ""
     provider_mappings: set[ProviderMapping] = field(default_factory=set)
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Call after init."""
         super().__post_init__()
         if not self.path:
@@ -571,7 +574,7 @@ class SearchResults(DataClassDictMixin):
     radio: list[Radio | ItemMapping] = field(default_factory=list)
 
 
-def media_from_dict(media_item: dict) -> MediaItemType:
+def media_from_dict(media_item: dict[str, Any]) -> MediaItemType | ItemMapping:
     """Return MediaItem from dict."""
     if "provider_mappings" not in media_item:
         return ItemMapping.from_dict(media_item)
@@ -585,4 +588,9 @@ def media_from_dict(media_item: dict) -> MediaItemType:
         return Playlist.from_dict(media_item)
     if media_item["media_type"] == "radio":
         return Radio.from_dict(media_item)
-    return MediaItem.from_dict(media_item)
+    raise InvalidDataError("Unknown media type")
+
+
+def is_track(val: MediaItem) -> TypeGuard[Track]:
+    """Return true if this MediaItem is a track."""
+    return val.media_type == MediaType.TRACK
index ac90f491c61264f6de29577d45935b673be828fc..71610683d683d0b50ec0cbf38b47abf636692bbf 100644 (file)
@@ -33,7 +33,7 @@ class PlayerMedia(DataClassDictMixin):
     duration: int | None = None  # optional
     queue_id: str | None = None  # only present for requests from queue controller
     queue_item_id: str | None = None  # only present for requests from queue controller
-    custom_data: dict | None = None  # optional
+    custom_data: dict[str, Any] | None = None  # optional
 
 
 @dataclass
@@ -146,6 +146,6 @@ class Player(DataClassDictMixin):
         return None
 
     @current_item_id.setter
-    def current_item_id(self, uri: str) -> str | None:
+    def current_item_id(self, uri: str) -> None:
         """Set current_item_id (for backwards compatibility)."""
         self.current_media = PlayerMedia(uri)
index d205ce431a60f00fc5c88239f0d30c782bc603bd..87e37e8e01adcccc14cb5238de036ee5a4015e2d 100644 (file)
@@ -59,7 +59,7 @@ class PlayerQueue(DataClassDictMixin):
         return d
 
     @classmethod
-    def from_cache(cls: Self, d: dict[Any, Any]) -> Self:
+    def from_cache(cls, d: dict[Any, Any]) -> Self:
         """Restore a PlayerQueue from a cache dict."""
         d.pop("current_item", None)
         d.pop("next_item", None)
index 947e607700043919c7d68b94c88689347a4dc038..7b706b6b0d8b58c8edf9f42c5a65736c1dd7a192 100644 (file)
@@ -53,7 +53,7 @@ class ProviderManifest(DataClassORJSONMixin):
     mdns_discovery: list[str] | None = None
 
     @classmethod
-    async def parse(cls: ProviderManifest, manifest_file: str) -> ProviderManifest:
+    async def parse(cls, manifest_file: str) -> ProviderManifest:
         """Parse ProviderManifest from file."""
         return await load_json_file(manifest_file, ProviderManifest)
 
@@ -79,9 +79,9 @@ class SyncTask:
     provider_domain: str
     provider_instance: str
     media_types: tuple[MediaType, ...]
-    task: asyncio.Task
+    task: asyncio.Task[None] | None
 
-    def to_dict(self, *args, **kwargs) -> dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Return SyncTask as (serializable) dict."""
         # ruff: noqa:ARG002
         return {
index 7cff789da0051183be3b13992c17d0aa5e660371..3c416342f092956978daeae50a6758234b50b57c 100644 (file)
@@ -9,7 +9,7 @@ from uuid import uuid4
 from mashumaro import DataClassDictMixin
 
 from .enums import MediaType
-from .media_items import ItemMapping, MediaItemImage, Radio, Track
+from .media_items import ItemMapping, MediaItemImage, Radio, Track, UniqueList, is_track
 from .streamdetails import StreamDetails
 
 
@@ -28,7 +28,7 @@ class QueueItem(DataClassDictMixin):
     image: MediaItemImage | None = None
     index: int = 0
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Set default values."""
         if self.streamdetails and self.streamdetails.stream_title:
             self.name = self.streamdetails.stream_title
@@ -47,7 +47,7 @@ class QueueItem(DataClassDictMixin):
     @property
     def uri(self) -> str:
         """Return uri for this QueueItem (for logging purposes)."""
-        if self.media_item:
+        if self.media_item and self.media_item.uri:
             return self.media_item.uri
         return self.queue_item_id
 
@@ -63,14 +63,13 @@ class QueueItem(DataClassDictMixin):
     @classmethod
     def from_media_item(cls, queue_id: str, media_item: Track | Radio) -> QueueItem:
         """Construct QueueItem from track/radio item."""
-        if media_item.media_type == MediaType.TRACK:
+        if is_track(media_item):
             artists = "/".join(x.name for x in media_item.artists)
             name = f"{artists} - {media_item.name}"
             # save a lot of data/bandwidth by simplifying nested objects
-            media_item.artists = [ItemMapping.from_item(x) for x in media_item.artists]
+            media_item.artists = UniqueList([ItemMapping.from_item(x) for x in media_item.artists])
             if media_item.album:
                 media_item.album = ItemMapping.from_item(media_item.album)
-            media_item.albums = []
         else:
             name = media_item.name
         return cls(
@@ -89,7 +88,7 @@ class QueueItem(DataClassDictMixin):
         return base
 
     @classmethod
-    def from_cache(cls: Self, d: dict[Any, Any]) -> Self:
+    def from_cache(cls, d: dict[Any, Any]) -> Self:
         """Restore a QueueItem from a cache dict."""
         d.pop("streamdetails", None)
         return cls.from_dict(d)
index bbebd7cd5944f7bb5b8119441f1f7500ad555a49..813a2b5de05fdc89a7e43d5d669c7eba74fd4b6d 100644 (file)
@@ -51,6 +51,7 @@ from music_assistant.common.models.media_items import (
     Track,
 )
 from music_assistant.common.models.streamdetails import StreamDetails
+from music_assistant.constants import UNKNOWN_ARTIST
 from music_assistant.server.helpers.auth import AuthenticationHelper
 from music_assistant.server.helpers.tags import parse_tags
 from music_assistant.server.models.music_provider import MusicProvider
@@ -426,7 +427,7 @@ class PlexProvider(MusicProvider):
         artist_id = FAKE_ARTIST_PREFIX + artist_name
         return Artist(
             item_id=artist_id,
-            name=artist_name,
+            name=artist_name or UNKNOWN_ARTIST,
             provider=self.domain,
             provider_mappings={
                 ProviderMapping(
@@ -498,7 +499,7 @@ class PlexProvider(MusicProvider):
         album = Album(
             item_id=album_id,
             provider=self.domain,
-            name=plex_album.title,
+            name=plex_album.title or "[Unknown]",
             provider_mappings={
                 ProviderMapping(
                     item_id=str(album_id),
@@ -533,7 +534,7 @@ class PlexProvider(MusicProvider):
             self._get_item_mapping(
                 MediaType.ARTIST,
                 plex_album.parentKey,
-                plex_album.parentTitle,
+                plex_album.parentTitle or UNKNOWN_ARTIST,
             )
         )
         return album
@@ -546,7 +547,7 @@ class PlexProvider(MusicProvider):
             raise InvalidDataError(msg)
         artist = Artist(
             item_id=artist_id,
-            name=plex_artist.title,
+            name=plex_artist.title or UNKNOWN_ARTIST,
             provider=self.domain,
             provider_mappings={
                 ProviderMapping(
@@ -575,7 +576,7 @@ class PlexProvider(MusicProvider):
         playlist = Playlist(
             item_id=plex_playlist.key,
             provider=self.domain,
-            name=plex_playlist.title,
+            name=plex_playlist.title or "[Unknown]",
             provider_mappings={
                 ProviderMapping(
                     item_id=plex_playlist.key,
@@ -612,7 +613,7 @@ class PlexProvider(MusicProvider):
         track = Track(
             item_id=plex_track.key,
             provider=self.instance_id,
-            name=plex_track.title,
+            name=plex_track.title or "[Unknown]",
             provider_mappings={
                 ProviderMapping(
                     item_id=plex_track.key,
@@ -639,13 +640,15 @@ class PlexProvider(MusicProvider):
             # The artist of the track if different from the album's artist.
             # For this kind of artist, we just know the name, so we create a fake artist,
             # if it does not already exist.
-            track.artists.append(await self._get_or_create_artist_by_name(plex_track.originalTitle))
+            track.artists.append(
+                await self._get_or_create_artist_by_name(plex_track.originalTitle or UNKNOWN_ARTIST)
+            )
         elif plex_track.grandparentKey:
             track.artists.append(
                 self._get_item_mapping(
                     MediaType.ARTIST,
                     plex_track.grandparentKey,
-                    plex_track.grandparentTitle,
+                    plex_track.grandparentTitle or UNKNOWN_ARTIST,
                 )
             )
         else:
index cf337fb172b3de0c1d0f234e39688ef453d1be4c..2bca30dae39474e2633fc9890d7c713aa8cbebb5 100644 (file)
--- a/mypy.ini
+++ b/mypy.ini
@@ -21,4 +21,4 @@ disallow_untyped_decorators = true
 disallow_untyped_defs = true
 warn_return_any = true
 warn_unreachable = true
-packages=tests,music_assistant.client,music_assistant.server.providers.jellyfin
+packages=tests,music_assistant.client,music_assistant.common,music_assistant.server.providers.jellyfin