From: Jc2k Date: Tue, 2 Jul 2024 20:50:19 +0000 (+0100) Subject: Mypy: Add music_assistant.common (#1428) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=c2884f9fa00d0046ea584d0a53347df08f5315a7;p=music-assistant-server.git Mypy: Add music_assistant.common (#1428) --- diff --git a/music_assistant/client/music.py b/music_assistant/client/music.py index 628290bb..858352dd 100644 --- a/music_assistant/client/music.py +++ b/music_assistant/client/music.py @@ -440,7 +440,7 @@ class Music: path: str | None = None, limit: int | None = None, offset: int | None = None, - ) -> list[MediaItemType]: + ) -> list[MediaItemType | ItemMapping]: """Browse Music providers.""" return [ media_from_dict(obj) @@ -454,7 +454,7 @@ class Music: async def recently_played( self, limit: int = 10, media_types: list[MediaType] | None = None - ) -> list[MediaItemType]: + ) -> list[MediaItemType | ItemMapping]: """Return a list of the last played items.""" return [ media_from_dict(item) @@ -466,7 +466,7 @@ class Music: async def get_item_by_uri( self, uri: str, - ) -> MediaItemType: + ) -> MediaItemType | ItemMapping: """Get single music item providing a mediaitem uri.""" return media_from_dict(await self.client.send_command("music/item_by_uri", uri=uri)) @@ -478,7 +478,7 @@ class Music: force_refresh: bool = False, lazy: bool = True, add_to_library: bool = False, - ) -> MediaItemType: + ) -> MediaItemType | ItemMapping: """Get single music item by id and media type.""" return media_from_dict( await self.client.send_command( @@ -534,7 +534,7 @@ class Music: async def refresh_item( self, media_item: MediaItemType, - ) -> MediaItemType | None: + ) -> MediaItemType | ItemMapping | None: """Try to refresh a mediaitem by requesting it's full object or search for substitutes.""" if result := await self.client.send_command("music/refresh_item", media_item=media_item): return media_from_dict(result) diff --git a/music_assistant/common/helpers/datetime.py b/music_assistant/common/helpers/datetime.py index 80fa279f..af8f4b24 100644 --- a/music_assistant/common/helpers/datetime.py +++ b/music_assistant/common/helpers/datetime.py @@ -27,7 +27,7 @@ def now_timestamp() -> float: return now().timestamp() -def future_timestamp(**kwargs) -> float: +def future_timestamp(**kwargs: float) -> float: """Return current timestamp + timedelta.""" return (now() + datetime.timedelta(**kwargs)).timestamp() diff --git a/music_assistant/common/helpers/global_cache.py b/music_assistant/common/helpers/global_cache.py index a33f50a9..6cd741dd 100644 --- a/music_assistant/common/helpers/global_cache.py +++ b/music_assistant/common/helpers/global_cache.py @@ -8,7 +8,7 @@ from typing import Any # global cache - we use this on a few places (as limited as possible) # where we have no other options _global_cache_lock = asyncio.Lock() -_global_cache = {} +_global_cache: dict[str, Any] = {} def get_global_cache_value(key: str, default: Any = None) -> Any: diff --git a/music_assistant/common/helpers/json.py b/music_assistant/common/helpers/json.py index 845fbece..8fb679e1 100644 --- a/music_assistant/common/helpers/json.py +++ b/music_assistant/common/helpers/json.py @@ -4,10 +4,11 @@ import asyncio import base64 from _collections_abc import dict_keys, dict_values from types import MethodType -from typing import Any +from typing import Any, TypeVar import aiofiles import orjson +from mashumaro.mixins.orjson import DataClassORJSONMixin JSON_ENCODE_EXCEPTIONS = (TypeError, ValueError) JSON_DECODE_EXCEPTIONS = (orjson.JSONDecodeError,) @@ -59,12 +60,11 @@ def json_dumps(data: Any, indent: bool = False) -> str: json_loads = orjson.loads +TargetT = TypeVar("TargetT", bound=DataClassORJSONMixin) -async def load_json_file(path: str, target_class: type | None = None) -> dict: + +async def load_json_file(path: str, target_class: type[TargetT]) -> TargetT: """Load JSON from file.""" async with aiofiles.open(path, "r") as _file: content = await _file.read() - if target_class: - # support for a mashumaro model - return target_class.from_json(content) - return json_loads(content) + return target_class.from_json(content) diff --git a/music_assistant/common/helpers/uri.py b/music_assistant/common/helpers/uri.py index 93ed85cd..a381861d 100644 --- a/music_assistant/common/helpers/uri.py +++ b/music_assistant/common/helpers/uri.py @@ -10,7 +10,7 @@ from music_assistant.common.models.errors import InvalidProviderID, InvalidProvi base62_length22_id_pattern = re.compile(r"^[a-zA-Z0-9]{22}$") -def valid_base62_length22(item_id) -> bool: +def valid_base62_length22(item_id: str) -> bool: """Validate Spotify style ID.""" return bool(base62_length22_id_pattern.match(item_id)) diff --git a/music_assistant/common/helpers/util.py b/music_assistant/common/helpers/util.py index bbf28c18..62b5c433 100644 --- a/music_assistant/common/helpers/util.py +++ b/music_assistant/common/helpers/util.py @@ -7,14 +7,13 @@ import os import re import socket from collections.abc import Callable +from collections.abc import Set as AbstractSet from typing import Any, TypeVar from urllib.parse import urlparse from uuid import UUID # pylint: disable=invalid-name T = TypeVar("T") -_UNDEF: dict = {} -CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) CALLBACK_TYPE = Callable[[], None] # pylint: enable=invalid-name @@ -50,7 +49,7 @@ def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float | return default -def try_parse_bool(possible_bool: Any) -> str: +def try_parse_bool(possible_bool: Any) -> bool: """Try to parse a bool.""" if isinstance(possible_bool, bool): return possible_bool @@ -79,7 +78,7 @@ def create_sort_name(input_str: str) -> str: return input_str.strip() -def parse_title_and_version(title: str, track_version: str | None = None): +def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]: """Try to parse clean track title and version from the title.""" version = "" for splitter in [" (", " [", " - ", " (", " [", "-"]: @@ -135,7 +134,7 @@ def clean_title(title: str) -> str: return title.strip() -def get_version_substitute(version_str: str): +def get_version_substitute(version_str: str) -> str: """Transform provider version str to universal version type.""" version_str = version_str.lower() # substitute edit and edition with version @@ -169,7 +168,7 @@ def strip_url(line: str) -> str: ).rstrip() -def strip_dotcom(line: str): +def strip_dotcom(line: str) -> str: """Strip scheme-less netloc from line.""" return dot_com_pattern.sub("", line) @@ -227,17 +226,17 @@ def clean_stream_title(line: str) -> str: return line -async def get_ip(): +async def get_ip() -> str: """Get primary IP-address for this host.""" - def _get_ip(): + def _get_ip() -> str: """Get primary IP-address for this host.""" # pylint: disable=broad-except,no-member sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: # doesn't even have to be reachable sock.connect(("10.255.255.255", 1)) - _ip = sock.getsockname()[0] + _ip = str(sock.getsockname()[0]) except Exception: _ip = "127.0.0.1" finally: @@ -273,7 +272,7 @@ async def select_free_port(range_start: int, range_end: int) -> int: async def get_ip_from_host(dns_name: str) -> str | None: """Resolve (first) IP-address for given dns name.""" - def _resolve(): + def _resolve() -> str | None: try: return socket.gethostbyname(dns_name) except Exception: # pylint: disable=broad-except @@ -283,7 +282,7 @@ async def get_ip_from_host(dns_name: str) -> str | None: return await asyncio.to_thread(_resolve) -async def get_ip_pton(ip_string: str | None = None): +async def get_ip_pton(ip_string: str | None = None) -> bytes: """Return socket pton for local ip.""" if ip_string is None: ip_string = await get_ip() @@ -294,7 +293,7 @@ async def get_ip_pton(ip_string: str | None = None): return await asyncio.to_thread(socket.inet_pton, socket.AF_INET6, ip_string) -def get_folder_size(folderpath): +def get_folder_size(folderpath: str) -> float: """Return folder size in gb.""" total_size = 0 # pylint: disable=unused-variable @@ -306,7 +305,9 @@ def get_folder_size(folderpath): return total_size / float(1 << 30) -def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False): +def merge_dict( + base_dict: dict[Any, Any], new_dict: dict[Any, Any], allow_overwite: bool = False +) -> dict[Any, Any]: """Merge dict without overwriting existing values.""" final_dict = base_dict.copy() for key, value in new_dict.items(): @@ -321,12 +322,12 @@ def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False): return final_dict -def merge_tuples(base: tuple, new: tuple) -> tuple: +def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]: """Merge 2 tuples.""" return tuple(x for x in base if x not in new) + tuple(new) -def merge_lists(base: list, new: list) -> list: +def merge_lists(base: list[Any], new: list[Any]) -> list[Any]: """Merge 2 lists.""" return [x for x in base if x not in new] + list(new) @@ -335,7 +336,7 @@ def get_changed_keys( dict1: dict[str, Any], dict2: dict[str, Any], ignore_keys: list[str] | None = None, -) -> set[str]: +) -> AbstractSet[str]: """Compare 2 dicts and return set of changed keys.""" return get_changed_values(dict1, dict2, ignore_keys).keys() @@ -369,7 +370,7 @@ def get_changed_values( return changed_values -def empty_queue(q: asyncio.Queue) -> None: +def empty_queue(q: asyncio.Queue[T]) -> None: """Empty an asyncio Queue.""" for _ in range(q.qsize()): try: @@ -386,10 +387,3 @@ def is_valid_uuid(uuid_to_test: str) -> bool: except ValueError: return False return str(uuid_obj) == uuid_to_test - - -class classproperty(property): # noqa: N801 - """Implement class property for python3.11+.""" - - def __get__(self, cls, owner): # noqa: D105 - return classmethod(self.fget).__get__(None, owner)() diff --git a/music_assistant/common/models/api.py b/music_assistant/common/models/api.py index d2456aaf..dfbfd663 100644 --- a/music_assistant/common/models/api.py +++ b/music_assistant/common/models/api.py @@ -65,7 +65,7 @@ MessageType = ( ) -def parse_message(raw: dict) -> MessageType: +def parse_message(raw: dict[Any, Any]) -> MessageType: """Parse Message from raw dict object.""" if "event" in raw: return EventMessage.from_dict(raw) diff --git a/music_assistant/common/models/config_entries.py b/music_assistant/common/models/config_entries.py index b2f00a4c..27ccf59d 100644 --- a/music_assistant/common/models/config_entries.py +++ b/music_assistant/common/models/config_entries.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging import warnings -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass from enum import Enum from types import NoneType @@ -46,8 +46,8 @@ warnings.filterwarnings("ignore", category=UserWarning, module="mashumaro") LOGGER = logging.getLogger(__name__) -ENCRYPT_CALLBACK: callable[[str], str] | None = None -DECRYPT_CALLBACK: callable[[str], str] | None = None +ENCRYPT_CALLBACK: Callable[[str], str] | None = None +DECRYPT_CALLBACK: Callable[[str], str] | None = None ConfigValueType = ( str @@ -58,15 +58,16 @@ ConfigValueType = ( | list[str] | list[int] | list[tuple[int, int]] + | Enum | None ) -ConfigEntryTypeMap = { +ConfigEntryTypeMap: dict[ConfigEntryType, type[ConfigValueType]] = { ConfigEntryType.BOOLEAN: bool, ConfigEntryType.STRING: str, ConfigEntryType.SECURE_STRING: str, ConfigEntryType.INTEGER: int, - ConfigEntryType.INTEGER_TUPLE: tuple[int, int], + ConfigEntryType.INTEGER_TUPLE: tuple[int, int], # type: ignore[dict-item] ConfigEntryType.FLOAT: float, ConfigEntryType.LABEL: str, ConfigEntryType.DIVIDER: str, @@ -187,6 +188,7 @@ class Config(DataClassDictMixin): """Return config value for given key.""" config_value = self.values[key] if config_value.type == ConfigEntryType.SECURE_STRING: + assert isinstance(config_value.value, str) assert DECRYPT_CALLBACK is not None return DECRYPT_CALLBACK(config_value.value) return config_value.value @@ -213,8 +215,9 @@ class Config(DataClassDictMixin): def to_raw(self) -> dict[str, Any]: """Return minimized/raw dict to store in persistent storage.""" - def _handle_value(value: ConfigEntry): + def _handle_value(value: ConfigEntry) -> ConfigValueType: if value.type == ConfigEntryType.SECURE_STRING: + assert isinstance(value.value, str) assert ENCRYPT_CALLBACK is not None return ENCRYPT_CALLBACK(value.value) return value.value @@ -253,9 +256,7 @@ class Config(DataClassDictMixin): setattr(self, key, new_val) changed_keys.add(key) - # config entry values - values = update.get("values", update) - for key, new_val in values.items(): + for key, new_val in update.items(): if key in root_values: continue cur_val = self.values[key].value if key in self.values else None @@ -366,12 +367,12 @@ CONF_ENTRY_AUTO_PLAY = ConfigEntry( CONF_ENTRY_OUTPUT_CHANNELS = ConfigEntry( key=CONF_OUTPUT_CHANNELS, type=ConfigEntryType.STRING, - options=[ + options=( ConfigValueOption("Stereo (both channels)", "stereo"), ConfigValueOption("Left channel", "left"), ConfigValueOption("Right channel", "right"), ConfigValueOption("Mono (both channels)", "mono"), - ], + ), default_value="stereo", label="Output Channel Mode", category="audio", @@ -503,12 +504,12 @@ CONF_ENTRY_TTS_PRE_ANNOUNCE = ConfigEntry( CONF_ENTRY_ANNOUNCE_VOLUME_STRATEGY = ConfigEntry( key=CONF_ANNOUNCE_VOLUME_STRATEGY, type=ConfigEntryType.STRING, - options=[ + options=( ConfigValueOption("Absolute volume", "absolute"), ConfigValueOption("Relative volume increase", "relative"), ConfigValueOption("Volume increase by fixed percentage", "percentual"), ConfigValueOption("Do not adjust volume", "none"), - ], + ), default_value="percentual", label="Volume strategy for Announcements", category="announcements", @@ -557,7 +558,7 @@ CONF_ENTRY_PLAYER_ICON_GROUP = ConfigEntry.from_dict( CONF_ENTRY_SAMPLE_RATES = ConfigEntry( key=CONF_SAMPLE_RATES, type=ConfigEntryType.INTEGER_TUPLE, - options=[ + options=( ConfigValueOption("44.1kHz / 16 bits", (44100, 16)), ConfigValueOption("44.1kHz / 24 bits", (44100, 24)), ConfigValueOption("48kHz / 16 bits", (48000, 16)), @@ -574,7 +575,7 @@ CONF_ENTRY_SAMPLE_RATES = ConfigEntry( ConfigValueOption("352.8kHz / 24 bits", (352800, 24)), ConfigValueOption("384kHz / 16 bits", (384000, 16)), ConfigValueOption("384kHz / 24 bits", (384000, 24)), - ], + ), default_value=[(44100, 16), (48000, 16)], required=True, multi_value=True, @@ -593,14 +594,19 @@ def create_sample_rates_config_entry( hidden: bool = False, ) -> ConfigEntry: """Create sample rates config entry based on player specific helpers.""" + assert CONF_ENTRY_SAMPLE_RATES.options conf_entry = ConfigEntry.from_dict(CONF_ENTRY_SAMPLE_RATES.to_dict()) - conf_entry.options = [] - conf_entry.default_value = [] conf_entry.hidden = hidden + options: list[ConfigValueOption] = [] + default_value: list[tuple[int, int]] = [] for option in CONF_ENTRY_SAMPLE_RATES.options: + if not isinstance(option.value, tuple): + continue sample_rate, bit_depth = option.value if sample_rate <= max_sample_rate and bit_depth <= max_bit_depth: - conf_entry.options.append(option) + options.append(option) if sample_rate <= safe_max_sample_rate and bit_depth <= safe_max_bit_depth: - conf_entry.default_value.append(option.value) + default_value.append(option.value) + conf_entry.options = tuple(options) + conf_entry.default_value = default_value return conf_entry diff --git a/music_assistant/common/models/enums.py b/music_assistant/common/models/enums.py index df4c5707..bd018efa 100644 --- a/music_assistant/common/models/enums.py +++ b/music_assistant/common/models/enums.py @@ -3,13 +3,25 @@ from __future__ import annotations import contextlib -from enum import StrEnum -from typing import Self +from enum import EnumType, StrEnum -from music_assistant.common.helpers.util import classproperty +class MediaTypeMeta(EnumType): + """Class properties for MediaType.""" -class MediaType(StrEnum): + @property + def ALL(self) -> list[MediaType]: # noqa: N802 + """All MediaTypes.""" + return [ + MediaType.ARTIST, + MediaType.ALBUM, + MediaType.TRACK, + MediaType.PLAYLIST, + MediaType.RADIO, + ] + + +class MediaType(StrEnum, metaclass=MediaTypeMeta): """Enum for MediaType.""" ARTIST = "artist" @@ -23,21 +35,10 @@ class MediaType(StrEnum): UNKNOWN = "unknown" @classmethod - def _missing_(cls: Self, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> MediaType: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.UNKNOWN - @classproperty - def ALL(self) -> tuple[MediaType, ...]: # noqa: N802 - """Return all (default) MediaTypes as tuple.""" - return ( - MediaType.ARTIST, - MediaType.ALBUM, - MediaType.TRACK, - MediaType.PLAYLIST, - MediaType.RADIO, - ) - class ExternalID(StrEnum): """Enum with External ID types.""" @@ -56,7 +57,7 @@ class ExternalID(StrEnum): UNKNOWN = "unknown" @classmethod - def _missing_(cls: Self, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> ExternalID: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.UNKNOWN @@ -78,7 +79,7 @@ class LinkType(StrEnum): UNKNOWN = "unknown" @classmethod - def _missing_(cls: Self, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> LinkType: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.UNKNOWN @@ -98,7 +99,7 @@ class ImageType(StrEnum): OTHER = "other" @classmethod - def _missing_(cls: Self, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> ImageType: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.OTHER @@ -142,12 +143,12 @@ class ContentType(StrEnum): UNKNOWN = "?" @classmethod - def _missing_(cls, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> ContentType: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.UNKNOWN @classmethod - def try_parse(cls, string: str) -> Self: + def try_parse(cls, string: str) -> ContentType: """Try to parse ContentType from (url)string/extension.""" tempstr = string.lower() if "audio/" in tempstr: @@ -247,7 +248,7 @@ class PlayerType(StrEnum): UNKNOWN = "unknown" @classmethod - def _missing_(cls: Self, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> PlayerType: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.UNKNOWN @@ -275,7 +276,7 @@ class PlayerFeature(StrEnum): UNKNOWN = "unknown" @classmethod - def _missing_(cls: Self, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> PlayerFeature: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.UNKNOWN @@ -302,7 +303,7 @@ class EventType(StrEnum): UNKNOWN = "unknown" @classmethod - def _missing_(cls: Self, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> EventType: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.UNKNOWN @@ -364,7 +365,7 @@ class ProviderFeature(StrEnum): UNKNOWN = "unknown" @classmethod - def _missing_(cls: Self, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> ProviderFeature: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.UNKNOWN @@ -396,7 +397,7 @@ class ConfigEntryType(StrEnum): UNKNOWN = "unknown" @classmethod - def _missing_(cls: Self, value: object) -> Self: # noqa: ARG003 + def _missing_(cls, value: object) -> ConfigEntryType: # noqa: ARG003 """Set default enum member if an unknown value is provided.""" return cls.UNKNOWN diff --git a/music_assistant/common/models/errors.py b/music_assistant/common/models/errors.py index fb0fc9ae..d1add798 100644 --- a/music_assistant/common/models/errors.py +++ b/music_assistant/common/models/errors.py @@ -115,7 +115,7 @@ class RetriesExhausted(MusicAssistantError): class ResourceTemporarilyUnavailable(MusicAssistantError): """Error thrown when a resource is temporarily unavailable.""" - def __init__(self, *args, backoff_time: int = 0) -> None: + def __init__(self, *args: object, backoff_time: int = 0) -> None: """Initialize.""" super().__init__(*args) self.backoff_time = backoff_time diff --git a/music_assistant/common/models/media_items.py b/music_assistant/common/models/media_items.py index 0549844d..d271b111 100644 --- a/music_assistant/common/models/media_items.py +++ b/music_assistant/common/models/media_items.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Iterable from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, Any, Self, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeGuard, TypeVar, cast from mashumaro import DataClassDictMixin @@ -34,7 +34,7 @@ class UniqueList(list[_T]): if not iterable: super().__init__() return - seen = set() + seen: set[_T] = set() seen_add = seen.add super().__init__(x for x in iterable if not (x in seen or seen_add(x))) @@ -61,7 +61,7 @@ class AudioFormat(DataClassDictMixin): output_format_str: str = "" bit_rate: int = 320 # optional - def __post_init__(self): + def __post_init__(self) -> None: """Execute actions after init.""" if not self.output_format_str and self.content_type.is_pcm(): self.output_format_str = ( @@ -90,9 +90,9 @@ class AudioFormat(DataClassDictMixin): """Return the PCM sample size.""" return int(self.sample_rate * (self.bit_depth / 8) * self.channels) - def __eq__(self, other: AudioFormat) -> bool: + def __eq__(self, other: object) -> bool: """Check equality of two items.""" - if not other: + if not isinstance(other, AudioFormat): return False return self.output_format_str == other.output_format_str @@ -121,7 +121,7 @@ class ProviderMapping(DataClassDictMixin): quality += 1 return quality - def __post_init__(self): + def __post_init__(self) -> None: """Call after init.""" # having items for unavailable providers can have all sorts # of unpredictable results so ensure we have accurate availability status @@ -138,9 +138,9 @@ class ProviderMapping(DataClassDictMixin): """Return custom hash.""" return hash((self.provider_instance, self.item_id)) - def __eq__(self, other: ProviderMapping) -> bool: + def __eq__(self, other: object) -> bool: """Check equality of two items.""" - if not other: + if not isinstance(other, ProviderMapping): return False return self.provider_instance == other.provider_instance and self.item_id == other.item_id @@ -156,8 +156,10 @@ class MediaItemLink(DataClassDictMixin): """Return custom hash.""" return hash(self.type) - def __eq__(self, other: MediaItemLink) -> bool: + def __eq__(self, other: object) -> bool: """Check equality of two items.""" + if not isinstance(other, MediaItemLink): + return False return self.url == other.url @@ -174,8 +176,10 @@ class MediaItemImage(DataClassDictMixin): """Return custom hash.""" return hash((self.type.value, self.path)) - def __eq__(self, other: MediaItemImage) -> bool: + def __eq__(self, other: object) -> bool: """Check equality of two items.""" + if not isinstance(other, MediaItemImage): + return False return self.__hash__() == other.__hash__() @classmethod @@ -202,8 +206,10 @@ class MediaItemChapter(DataClassDictMixin): """Return custom hash.""" return hash(self.chapter_id) - def __eq__(self, other: MediaItemChapter) -> bool: + def __eq__(self, other: object) -> bool: """Check equality of two items.""" + if not isinstance(other, MediaItemChapter): + return False return self.chapter_id == other.chapter_id @@ -248,8 +254,7 @@ class MediaItemMetadata(DataClassDictMixin): new_val = merge_lists(cur_val, new_val) setattr(self, fld.name, new_val) elif isinstance(cur_val, set) and isinstance(new_val, set | list | tuple): - new_val = cur_val.update(new_val) - setattr(self, fld.name, new_val) + cur_val.update(new_val) elif new_val and fld.name in ( "popularity", "last_refresh", @@ -277,12 +282,8 @@ class _MediaItemBase(DataClassDictMixin): external_ids: set[tuple[ExternalID, str]] = field(default_factory=set) media_type: MediaType = MediaType.UNKNOWN - def __post_init__(self): + def __post_init__(self) -> None: """Call after init.""" - if self.name is None: - # we've got some reports where the name was empty, causing weird issues. - # e.g. here: https://github.com/music-assistant/hass-music-assistant/issues/1515 - self.name = "[Unknown]" if self.uri is None: self.uri = create_uri(self.media_type, self.provider, self.item_id) if self.sort_name is None: @@ -318,8 +319,10 @@ class _MediaItemBase(DataClassDictMixin): """Return custom hash.""" return hash(self.uri) - def __eq__(self, other: MediaItem | ItemMapping) -> bool: + def __eq__(self, other: object) -> bool: """Check equality of two items.""" + if not isinstance(other, MediaItem | ItemMapping): + return False return self.uri == other.uri @@ -327,7 +330,6 @@ class _MediaItemBase(DataClassDictMixin): class MediaItem(_MediaItemBase): """Base representation of a media item.""" - __hash__ = _MediaItemBase.__hash__ __eq__ = _MediaItemBase.__eq__ provider_mappings: set[ProviderMapping] @@ -336,8 +338,12 @@ class MediaItem(_MediaItemBase): favorite: bool = False position: int | None = None # required for playlist tracks, optional for all other + def __hash__(self) -> int: + """Return hash of MediaItem.""" + return super().__hash__() + @property - def available(self): + def available(self) -> bool: """Return (calculated) availability.""" return any(x.available for x in self.provider_mappings) @@ -360,7 +366,7 @@ class ItemMapping(_MediaItemBase): image: MediaItemImage | None = None @classmethod - def from_item(cls, item: MediaItem) -> ItemMapping: + def from_item(cls, item: MediaItem | ItemMapping) -> ItemMapping: """Create ItemMapping object from regular item.""" if isinstance(item, ItemMapping): return item @@ -409,7 +415,6 @@ class Album(MediaItem): class Track(MediaItem): """Model for a track.""" - __hash__ = _MediaItemBase.__hash__ __eq__ = _MediaItemBase.__eq__ media_type: MediaType = MediaType.TRACK @@ -420,7 +425,7 @@ class Track(MediaItem): disc_number: int | None = None # required for album tracks track_number: int | None = None # required for album tracks - def __hash__(self): + def __hash__(self) -> int: """Return custom hash.""" return hash((self.provider, self.item_id)) @@ -432,7 +437,11 @@ class Track(MediaItem): This is often an indicator that this track is an episode from a Podcast or AudioBook. """ - return self.metadata and self.metadata.chapters and len(self.metadata.chapters) > 1 + if not self.metadata: + return False + if not self.metadata.chapters: + return False + return len(self.metadata.chapters) > 1 @property def image(self) -> MediaItemImage | None: @@ -465,28 +474,22 @@ class AlbumTrack(Track): @classmethod def from_track( - cls: type, + cls, track: Track, album: Album | None = None, disc_number: int | None = None, track_number: int | None = None, - ) -> Self: + ) -> AlbumTrack: """Cast Track to AlbumTrack.""" - if album is None: - album = track.album + album_track = track.to_dict() + if album is None and track.album: + album_track["album"] = track.album if disc_number is None: - disc_number = track.disc_number + album_track["disc_number"] = track.disc_number if track_number is None: - track_number = track.track_number + album_track["track_number"] = track.track_number # let mushmumaro instantiate a new object - this will ensure that valididation takes place - return AlbumTrack.from_dict( - { - **track.to_dict(), - "album": album.to_dict(), - "disc_number": disc_number, - "track_number": track_number, - } - ) + return AlbumTrack.from_dict(album_track) @dataclass(kw_only=True) @@ -540,7 +543,7 @@ class BrowseFolder(MediaItem): label: str = "" provider_mappings: set[ProviderMapping] = field(default_factory=set) - def __post_init__(self): + def __post_init__(self) -> None: """Call after init.""" super().__post_init__() if not self.path: @@ -571,7 +574,7 @@ class SearchResults(DataClassDictMixin): radio: list[Radio | ItemMapping] = field(default_factory=list) -def media_from_dict(media_item: dict) -> MediaItemType: +def media_from_dict(media_item: dict[str, Any]) -> MediaItemType | ItemMapping: """Return MediaItem from dict.""" if "provider_mappings" not in media_item: return ItemMapping.from_dict(media_item) @@ -585,4 +588,9 @@ def media_from_dict(media_item: dict) -> MediaItemType: return Playlist.from_dict(media_item) if media_item["media_type"] == "radio": return Radio.from_dict(media_item) - return MediaItem.from_dict(media_item) + raise InvalidDataError("Unknown media type") + + +def is_track(val: MediaItem) -> TypeGuard[Track]: + """Return true if this MediaItem is a track.""" + return val.media_type == MediaType.TRACK diff --git a/music_assistant/common/models/player.py b/music_assistant/common/models/player.py index ac90f491..71610683 100644 --- a/music_assistant/common/models/player.py +++ b/music_assistant/common/models/player.py @@ -33,7 +33,7 @@ class PlayerMedia(DataClassDictMixin): duration: int | None = None # optional queue_id: str | None = None # only present for requests from queue controller queue_item_id: str | None = None # only present for requests from queue controller - custom_data: dict | None = None # optional + custom_data: dict[str, Any] | None = None # optional @dataclass @@ -146,6 +146,6 @@ class Player(DataClassDictMixin): return None @current_item_id.setter - def current_item_id(self, uri: str) -> str | None: + def current_item_id(self, uri: str) -> None: """Set current_item_id (for backwards compatibility).""" self.current_media = PlayerMedia(uri) diff --git a/music_assistant/common/models/player_queue.py b/music_assistant/common/models/player_queue.py index d205ce43..87e37e8e 100644 --- a/music_assistant/common/models/player_queue.py +++ b/music_assistant/common/models/player_queue.py @@ -59,7 +59,7 @@ class PlayerQueue(DataClassDictMixin): return d @classmethod - def from_cache(cls: Self, d: dict[Any, Any]) -> Self: + def from_cache(cls, d: dict[Any, Any]) -> Self: """Restore a PlayerQueue from a cache dict.""" d.pop("current_item", None) d.pop("next_item", None) diff --git a/music_assistant/common/models/provider.py b/music_assistant/common/models/provider.py index 947e6077..7b706b6b 100644 --- a/music_assistant/common/models/provider.py +++ b/music_assistant/common/models/provider.py @@ -53,7 +53,7 @@ class ProviderManifest(DataClassORJSONMixin): mdns_discovery: list[str] | None = None @classmethod - async def parse(cls: ProviderManifest, manifest_file: str) -> ProviderManifest: + async def parse(cls, manifest_file: str) -> ProviderManifest: """Parse ProviderManifest from file.""" return await load_json_file(manifest_file, ProviderManifest) @@ -79,9 +79,9 @@ class SyncTask: provider_domain: str provider_instance: str media_types: tuple[MediaType, ...] - task: asyncio.Task + task: asyncio.Task[None] | None - def to_dict(self, *args, **kwargs) -> dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Return SyncTask as (serializable) dict.""" # ruff: noqa:ARG002 return { diff --git a/music_assistant/common/models/queue_item.py b/music_assistant/common/models/queue_item.py index 7cff789d..3c416342 100644 --- a/music_assistant/common/models/queue_item.py +++ b/music_assistant/common/models/queue_item.py @@ -9,7 +9,7 @@ from uuid import uuid4 from mashumaro import DataClassDictMixin from .enums import MediaType -from .media_items import ItemMapping, MediaItemImage, Radio, Track +from .media_items import ItemMapping, MediaItemImage, Radio, Track, UniqueList, is_track from .streamdetails import StreamDetails @@ -28,7 +28,7 @@ class QueueItem(DataClassDictMixin): image: MediaItemImage | None = None index: int = 0 - def __post_init__(self): + def __post_init__(self) -> None: """Set default values.""" if self.streamdetails and self.streamdetails.stream_title: self.name = self.streamdetails.stream_title @@ -47,7 +47,7 @@ class QueueItem(DataClassDictMixin): @property def uri(self) -> str: """Return uri for this QueueItem (for logging purposes).""" - if self.media_item: + if self.media_item and self.media_item.uri: return self.media_item.uri return self.queue_item_id @@ -63,14 +63,13 @@ class QueueItem(DataClassDictMixin): @classmethod def from_media_item(cls, queue_id: str, media_item: Track | Radio) -> QueueItem: """Construct QueueItem from track/radio item.""" - if media_item.media_type == MediaType.TRACK: + if is_track(media_item): artists = "/".join(x.name for x in media_item.artists) name = f"{artists} - {media_item.name}" # save a lot of data/bandwidth by simplifying nested objects - media_item.artists = [ItemMapping.from_item(x) for x in media_item.artists] + media_item.artists = UniqueList([ItemMapping.from_item(x) for x in media_item.artists]) if media_item.album: media_item.album = ItemMapping.from_item(media_item.album) - media_item.albums = [] else: name = media_item.name return cls( @@ -89,7 +88,7 @@ class QueueItem(DataClassDictMixin): return base @classmethod - def from_cache(cls: Self, d: dict[Any, Any]) -> Self: + def from_cache(cls, d: dict[Any, Any]) -> Self: """Restore a QueueItem from a cache dict.""" d.pop("streamdetails", None) return cls.from_dict(d) diff --git a/music_assistant/server/providers/plex/__init__.py b/music_assistant/server/providers/plex/__init__.py index bbebd7cd..813a2b5d 100644 --- a/music_assistant/server/providers/plex/__init__.py +++ b/music_assistant/server/providers/plex/__init__.py @@ -51,6 +51,7 @@ from music_assistant.common.models.media_items import ( Track, ) from music_assistant.common.models.streamdetails import StreamDetails +from music_assistant.constants import UNKNOWN_ARTIST from music_assistant.server.helpers.auth import AuthenticationHelper from music_assistant.server.helpers.tags import parse_tags from music_assistant.server.models.music_provider import MusicProvider @@ -426,7 +427,7 @@ class PlexProvider(MusicProvider): artist_id = FAKE_ARTIST_PREFIX + artist_name return Artist( item_id=artist_id, - name=artist_name, + name=artist_name or UNKNOWN_ARTIST, provider=self.domain, provider_mappings={ ProviderMapping( @@ -498,7 +499,7 @@ class PlexProvider(MusicProvider): album = Album( item_id=album_id, provider=self.domain, - name=plex_album.title, + name=plex_album.title or "[Unknown]", provider_mappings={ ProviderMapping( item_id=str(album_id), @@ -533,7 +534,7 @@ class PlexProvider(MusicProvider): self._get_item_mapping( MediaType.ARTIST, plex_album.parentKey, - plex_album.parentTitle, + plex_album.parentTitle or UNKNOWN_ARTIST, ) ) return album @@ -546,7 +547,7 @@ class PlexProvider(MusicProvider): raise InvalidDataError(msg) artist = Artist( item_id=artist_id, - name=plex_artist.title, + name=plex_artist.title or UNKNOWN_ARTIST, provider=self.domain, provider_mappings={ ProviderMapping( @@ -575,7 +576,7 @@ class PlexProvider(MusicProvider): playlist = Playlist( item_id=plex_playlist.key, provider=self.domain, - name=plex_playlist.title, + name=plex_playlist.title or "[Unknown]", provider_mappings={ ProviderMapping( item_id=plex_playlist.key, @@ -612,7 +613,7 @@ class PlexProvider(MusicProvider): track = Track( item_id=plex_track.key, provider=self.instance_id, - name=plex_track.title, + name=plex_track.title or "[Unknown]", provider_mappings={ ProviderMapping( item_id=plex_track.key, @@ -639,13 +640,15 @@ class PlexProvider(MusicProvider): # The artist of the track if different from the album's artist. # For this kind of artist, we just know the name, so we create a fake artist, # if it does not already exist. - track.artists.append(await self._get_or_create_artist_by_name(plex_track.originalTitle)) + track.artists.append( + await self._get_or_create_artist_by_name(plex_track.originalTitle or UNKNOWN_ARTIST) + ) elif plex_track.grandparentKey: track.artists.append( self._get_item_mapping( MediaType.ARTIST, plex_track.grandparentKey, - plex_track.grandparentTitle, + plex_track.grandparentTitle or UNKNOWN_ARTIST, ) ) else: diff --git a/mypy.ini b/mypy.ini index cf337fb1..2bca30da 100644 --- a/mypy.ini +++ b/mypy.ini @@ -21,4 +21,4 @@ disallow_untyped_decorators = true disallow_untyped_defs = true warn_return_any = true warn_unreachable = true -packages=tests,music_assistant.client,music_assistant.server.providers.jellyfin +packages=tests,music_assistant.client,music_assistant.common,music_assistant.server.providers.jellyfin