From: Fabian Munkes <105975993+fmunkes@users.noreply.github.com> Date: Thu, 15 May 2025 20:13:16 +0000 (+0200) Subject: Chore: Add mypy for helpers (#2097) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=31eaf97408c738264bc9a16a1a625a9f452ffa35;p=music-assistant-server.git Chore: Add mypy for helpers (#2097) --- diff --git a/music_assistant/helpers/api.py b/music_assistant/helpers/api.py index 6be15d3a..72b22e8f 100644 --- a/music_assistant/helpers/api.py +++ b/music_assistant/helpers/api.py @@ -42,7 +42,7 @@ class APICommandHandler: if not hasattr(value, "__name__"): continue if value.__name__ == "ItemCls": - type_hints[key] = func.__self__.item_cls + type_hints[key] = func.__self__.item_cls # type: ignore[attr-defined] return APICommandHandler( command=command, signature=inspect.signature(func), @@ -64,7 +64,7 @@ def api_command(command: str) -> Callable[[_F], _F]: def parse_arguments( func_sig: inspect.Signature, func_types: dict[str, Any], - args: dict | None, + args: dict[str, Any] | None, strict: bool = False, ) -> dict[str, Any]: """Parse (and convert) incoming arguments to correct types.""" @@ -161,6 +161,7 @@ def parse_value( # noqa: PLR0911 logging.getLogger(__name__).warning(err) return None if origin is type: + assert isinstance(value, str) # for type checking return eval(value) if value_type is Any: return value @@ -169,9 +170,10 @@ def parse_value( # noqa: PLR0911 raise KeyError(msg) try: - if issubclass(value_type, Enum): # type: ignore[arg-type] - return value_type(value) # type: ignore[operator] - if issubclass(value_type, datetime): # type: ignore[arg-type] + if issubclass(value_type, Enum): + return value_type(value) + if issubclass(value_type, datetime): + assert isinstance(value, str) # for type checking return parse_utc_timestamp(value) except TypeError: # happens if value_type is not a class @@ -190,7 +192,7 @@ def parse_value( # noqa: PLR0911 if value_type is bool and isinstance(value, str | int): return try_parse_bool(value) - if not isinstance(value, value_type): # type: ignore[arg-type] + if not isinstance(value, value_type): # all options failed, raise exception msg = ( f"Value {value} of type {type(value)} is invalid for {name}, " diff --git a/music_assistant/helpers/audio.py b/music_assistant/helpers/audio.py index aa231bf4..d9ccb257 100644 --- a/music_assistant/helpers/audio.py +++ b/music_assistant/helpers/audio.py @@ -31,7 +31,7 @@ from music_assistant_models.errors import ( MusicAssistantError, ProviderUnavailableError, ) -from music_assistant_models.streamdetails import AudioFormat +from music_assistant_models.media_items import AudioFormat from music_assistant.constants import ( CONF_ALLOW_AUDIO_CACHE, @@ -58,10 +58,12 @@ from .util import detect_charset, has_enough_space if TYPE_CHECKING: from music_assistant_models.config_entries import CoreConfig, PlayerConfig from music_assistant_models.player import Player - from music_assistant_models.player_queue import QueueItem + from music_assistant_models.queue_item import QueueItem from music_assistant_models.streamdetails import StreamDetails - from music_assistant import MusicAssistant + from music_assistant.mass import MusicAssistant + from music_assistant.models.music_provider import MusicProvider + from music_assistant.providers.player_group import PlayerGroupProvider LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.audio") @@ -85,6 +87,34 @@ class StreamCache: or when the audio stream is slow itself. """ + def __init__(self, mass: MusicAssistant, streamdetails: StreamDetails) -> None: + """Initialize the StreamCache.""" + self.mass = mass + self.streamdetails = streamdetails + self.logger = LOGGER.getChild("cache") + self._cache_file: str | None = None + self._fetch_task: asyncio.Task[None] | None = None + self._subscribers: int = 0 + self._first_part_received = asyncio.Event() + self._all_data_written: bool = False + self._stream_error: str | None = None + self.org_path: str | None = streamdetails.path + self.org_stream_type: StreamType | None = streamdetails.stream_type + self.org_extra_input_args: list[str] | None = streamdetails.extra_input_args + self.org_audio_format = streamdetails.audio_format + streamdetails.audio_format = AudioFormat( + content_type=ContentType.NUT, + codec_type=streamdetails.audio_format.codec_type, + sample_rate=streamdetails.audio_format.sample_rate, + bit_depth=streamdetails.audio_format.bit_depth, + channels=streamdetails.audio_format.channels, + ) + streamdetails.path = "-" + streamdetails.stream_type = StreamType.CACHE + streamdetails.can_seek = True + streamdetails.allow_seek = True + streamdetails.extra_input_args = [] + async def create(self) -> None: """Create the cache file (if needed).""" if self._cache_file is None: @@ -93,6 +123,7 @@ class StreamCache: ): # we have a mapping stored for this uri, prefer that self._cache_file = cached_cache_path + assert self._cache_file is not None # for type checking if await asyncio.to_thread(os.path.exists, self._cache_file): # cache file already exists from a previous session, # we can simply use that, there is nothing to create @@ -124,6 +155,7 @@ class StreamCache: """Release the cache file.""" self._subscribers -= 1 if self._subscribers <= 0: + assert self._cache_file is not None # for type checking CACHE_FILES_IN_USE.discard(self._cache_file) async def get_audio_stream(self) -> str | AsyncGenerator[bytes, None]: @@ -170,12 +202,17 @@ class StreamCache: async def _create_cache_file(self) -> None: time_start = time.time() self.logger.debug("Creating audio cache for %s", self.streamdetails.uri) + assert self._cache_file is not None # for type checking CACHE_FILES_IN_USE.add(self._cache_file) self._first_part_received.clear() self._all_data_written = False extra_input_args = ["-y", *(self.org_extra_input_args or [])] + audio_source: AsyncGenerator[bytes, None] | str | int if self.org_stream_type == StreamType.CUSTOM: - audio_source = self.mass.get_provider(self.streamdetails.provider).get_audio_stream( + provider = self.mass.get_provider(self.streamdetails.provider) + if TYPE_CHECKING: # avoid circular import + assert isinstance(provider, MusicProvider) + audio_source = provider.get_audio_stream( self.streamdetails, ) elif self.org_stream_type == StreamType.ICY: @@ -183,14 +220,18 @@ class StreamCache: elif self.org_stream_type == StreamType.HLS: if self.streamdetails.media_type == MediaType.RADIO: raise NotImplementedError("Caching of this streamtype is not supported!") + assert self.org_path is not None # for type checking substream = await get_hls_substream(self.mass, self.org_path) audio_source = substream.path elif self.org_stream_type == StreamType.ENCRYPTED_HTTP: + assert self.org_path is not None # for type checking + assert self.streamdetails.decryption_key is not None # for type checking audio_source = self.org_path extra_input_args += ["-decryption_key", self.streamdetails.decryption_key] elif self.org_stream_type == StreamType.MULTI_FILE: audio_source = get_multi_file_stream(self.mass, self.streamdetails) else: + assert self.org_path is not None # for type checking audio_source = self.org_path # we always use ffmpeg to fetch the original audio source @@ -199,15 +240,15 @@ class StreamCache: # and it also accounts for complicated cases such as encrypted streams or # m4a/mp4 streams with the moov atom at the end of the file. # ffmpeg will produce a lossless copy of the original codec. + ffmpeg_proc = FFMpeg( + audio_input=audio_source, + input_format=self.org_audio_format, + output_format=self.streamdetails.audio_format, + extra_input_args=extra_input_args, + audio_output=self._cache_file, + collect_log_history=True, + ) try: - ffmpeg_proc = FFMpeg( - audio_input=audio_source, - input_format=self.org_audio_format, - output_format=self.streamdetails.audio_format, - extra_input_args=extra_input_args, - audio_output=self._cache_file, - collect_log_history=True, - ) await ffmpeg_proc.start() # wait until the first data is written to the cache file while ffmpeg_proc.returncode is None: @@ -249,7 +290,7 @@ class StreamCache: await self._remove_cache_file() # unblock the waiting tasks by setting the event # this will allow the tasks to continue and handle the error - self._stream_error = str(err) or err.__qualname__ + self._stream_error = str(err) or err.__qualname__ # type: ignore [attr-defined] self._first_part_received.set() finally: await ffmpeg_proc.close() @@ -258,36 +299,9 @@ class StreamCache: self._first_part_received.clear() self._all_data_written = False self._fetch_task = None + assert self._cache_file is not None # for type checking await remove_file(self._cache_file) - def __init__(self, mass: MusicAssistant, streamdetails: StreamDetails) -> None: - """Initialize the StreamCache.""" - self.mass = mass - self.streamdetails = streamdetails - self.logger = LOGGER.getChild("cache") - self._cache_file: str | None = None - self._fetch_task: asyncio.Task | None = None - self._subscribers: int = 0 - self._first_part_received = asyncio.Event() - self._all_data_written: bool = False - self._stream_error: str | None = None - self.org_path: str | None = streamdetails.path - self.org_stream_type: StreamType | None = streamdetails.stream_type - self.org_extra_input_args: list[str] | None = streamdetails.extra_input_args - self.org_audio_format = streamdetails.audio_format - streamdetails.audio_format = AudioFormat( - content_type=ContentType.NUT, - codec_type=streamdetails.audio_format.codec_type, - sample_rate=streamdetails.audio_format.sample_rate, - bit_depth=streamdetails.audio_format.bit_depth, - channels=streamdetails.audio_format.channels, - ) - streamdetails.path = "-" - streamdetails.stream_type = StreamType.CACHE - streamdetails.can_seek = True - streamdetails.allow_seek = True - streamdetails.extra_input_args = [] - async def crossfade_pcm_parts( fade_in_part: bytes, @@ -439,7 +453,7 @@ async def strip_silence( def get_player_dsp_details( - mass: MusicAssistant, player: Player, group_preventing_dsp=False + mass: MusicAssistant, player: Player, group_preventing_dsp: bool = False ) -> DSPDetails: """Return DSP details of single a player. @@ -479,6 +493,7 @@ def get_stream_dsp_details( """Return DSP details of all players playing this queue, keyed by player_id.""" player = mass.players.get(queue_id) dsp: dict[str, DSPDetails] = {} + assert player is not None # for type checking group_preventing_dsp = is_grouping_preventing_dsp(player) output_format = None is_external_group = False @@ -488,6 +503,8 @@ def get_stream_dsp_details( try: # We need a bit of a hack here since only the leader knows the correct output format provider = mass.get_provider(player.provider) + if TYPE_CHECKING: # avoid circular import + assert isinstance(provider, PlayerGroupProvider) if provider: output_format = provider._get_sync_leader(player).output_format except RuntimeError: @@ -557,6 +574,7 @@ async def get_stream_details( else: # retrieve streamdetails from provider media_item = queue_item.media_item + assert media_item is not None # for type checking # sort by quality and check item's availability for prov_media in sorted( media_item.provider_mappings, key=lambda x: x.quality or 0, reverse=True @@ -566,12 +584,14 @@ async def get_stream_details( continue # guard that provider is available music_prov = mass.get_provider(prov_media.provider_instance) + if TYPE_CHECKING: # avoid circular import + assert isinstance(music_prov, MusicProvider) if not music_prov: LOGGER.debug(f"Skipping {prov_media} - provider not available") continue # provider not available ? # get streamdetails from provider try: - streamdetails: StreamDetails = await music_prov.get_stream_details( + streamdetails = await music_prov.get_stream_details( prov_media.item_id, media_item.media_type ) except MusicAssistantError as err: @@ -587,6 +607,7 @@ async def get_stream_details( streamdetails.stream_type in (StreamType.ICY, StreamType.HLS, StreamType.HTTP) and streamdetails.media_type == MediaType.RADIO ): + assert streamdetails.path is not None # for type checking resolved_url, stream_type = await resolve_radio_stream(mass, streamdetails.path) streamdetails.path = resolved_url streamdetails.stream_type = stream_type @@ -611,7 +632,7 @@ async def get_stream_details( player_settings = await mass.config.get_player_config(streamdetails.queue_id) core_config = await mass.config.get_core_config("streams") streamdetails.target_loudness = float( - player_settings.get_value(CONF_VOLUME_NORMALIZATION_TARGET) + str(player_settings.get_value(CONF_VOLUME_NORMALIZATION_TARGET)) ) streamdetails.volume_normalization_mode = _get_normalization_mode( core_config, player_settings, streamdetails @@ -714,6 +735,8 @@ async def get_media_stream( extra_input_args = streamdetails.extra_input_args or [] strip_silence_begin = streamdetails.strip_silence_begin strip_silence_end = streamdetails.strip_silence_end + if filter_params is None: + filter_params = [] if streamdetails.fade_in: filter_params.append("afade=type=in:start_time=0:duration=3") strip_silence_begin = False @@ -726,13 +749,18 @@ async def get_media_stream( elif stream_type == StreamType.MULTI_FILE: audio_source = get_multi_file_stream(mass, streamdetails) elif stream_type == StreamType.CUSTOM: - audio_source = mass.get_provider(streamdetails.provider).get_audio_stream( + music_prov = mass.get_provider(streamdetails.provider) + if TYPE_CHECKING: # avoid circular import + assert isinstance(music_prov, MusicProvider) + audio_source = music_prov.get_audio_stream( streamdetails, seek_position=streamdetails.seek_position if streamdetails.can_seek else 0, ) elif stream_type == StreamType.ICY: + assert streamdetails.path is not None # for type checking audio_source = get_icy_radio_stream(mass, streamdetails.path, streamdetails) elif stream_type == StreamType.HLS: + assert streamdetails.path is not None # for type checking substream = await get_hls_substream(mass, streamdetails.path) audio_source = substream.path if streamdetails.media_type == MediaType.RADIO: @@ -741,9 +769,12 @@ async def get_media_stream( # so we tell ffmpeg to loop around in this case. extra_input_args += ["-stream_loop", "-1", "-re"] elif stream_type == StreamType.ENCRYPTED_HTTP: + assert streamdetails.path is not None # for type checking + assert streamdetails.decryption_key is not None # for type checking audio_source = streamdetails.path extra_input_args += ["-decryption_key", streamdetails.decryption_key] else: + assert streamdetails.path is not None # for type checking audio_source = streamdetails.path # handle seek support @@ -774,6 +805,7 @@ async def get_media_stream( try: await ffmpeg_proc.start() + assert ffmpeg_proc.proc is not None # for type checking logger.debug( "Started media stream for %s" " - using streamtype: %s" @@ -886,7 +918,7 @@ async def get_media_stream( streamdetails.seconds_streamed = seconds_streamed # store accurate duration if finished and not streamdetails.seek_position and seconds_streamed: - streamdetails.duration = seconds_streamed + streamdetails.duration = int(seconds_streamed) # release cache if needed if cache := streamdetails.cache: @@ -936,10 +968,14 @@ async def get_media_stream( if (finished or seconds_streamed >= 30) and ( music_prov := mass.get_provider(streamdetails.provider) ): + if TYPE_CHECKING: # avoid circular import + assert isinstance(music_prov, MusicProvider) mass.create_task(music_prov.on_streamed(streamdetails)) -def create_wave_header(samplerate=44100, channels=2, bitspersample=16, duration=None): +def create_wave_header( + samplerate: int = 44100, channels: int = 2, bitspersample: int = 16, duration: int | None = None +) -> bytes: """Generate a wave header from given params.""" file = BytesIO() @@ -1005,7 +1041,7 @@ async def resolve_radio_stream(mass: MusicAssistant, url: str) -> tuple[str, Str """ cache_base_key = "resolved_radio_info" if cache := await mass.cache.get(url, base_key=cache_base_key): - return cache + return cast("tuple[str, StreamType]", cache) stream_type = StreamType.HTTP resolved_url = url timeout = ClientTimeout(total=0, connect=10, sock_read=5) @@ -1024,7 +1060,7 @@ async def resolve_radio_stream(mass: MusicAssistant, url: str) -> tuple[str, Str or ".m3u?" in url or ".m3u8?" in url or ".pls?" in url - or "audio/x-mpegurl" in headers.get("content-type") + or "audio/x-mpegurl" in headers.get("content-type", "") or "audio/x-scpls" in headers.get("content-type", "") ): # url is playlist, we need to unfold it @@ -1074,15 +1110,15 @@ async def get_icy_radio_stream( if not meta_data: continue meta_data = meta_data.rstrip(b"\0") - stream_title = re.search(rb"StreamTitle='([^']*)';", meta_data) - if not stream_title: + stream_title_re = re.search(rb"StreamTitle='([^']*)';", meta_data) + if not stream_title_re: continue try: # in 99% of the cases the stream title is utf-8 encoded - stream_title = stream_title.group(1).decode("utf-8") + stream_title = stream_title_re.group(1).decode("utf-8") except UnicodeDecodeError: # fallback to iso-8859-1 - stream_title = stream_title.group(1).decode("iso-8859-1", errors="replace") + stream_title = stream_title_re.group(1).decode("iso-8859-1", errors="replace") cleaned_stream_title = clean_stream_title(stream_title) if cleaned_stream_title != streamdetails.stream_title: LOGGER.log( @@ -1127,7 +1163,12 @@ async def get_hls_substream( return PlaylistItem(path=url, key=substreams[0].key) # sort substreams on best quality (highest bandwidth) when available if any(x for x in substreams if x.stream_info): - substreams.sort(key=lambda x: int(x.stream_info.get("BANDWIDTH", "0")), reverse=True) + substreams.sort( + key=lambda x: int( + x.stream_info.get("BANDWIDTH", "0") if x.stream_info is not None else 0 + ), + reverse=True, + ) substream = substreams[0] if not substream.path.startswith("http"): # path is relative, stitch it together @@ -1159,6 +1200,7 @@ async def get_http_stream( timeout = ClientTimeout(total=0, connect=30, sock_read=5 * 60) skip_bytes = 0 if seek_position and streamdetails.size: + assert streamdetails.duration is not None # for type checking skip_bytes = int(streamdetails.size / streamdetails.duration * seek_position) headers["Range"] = f"bytes={skip_bytes}-{streamdetails.size}" @@ -1238,6 +1280,7 @@ async def get_file_stream( chunk_size = get_chunksize(streamdetails.audio_format) async with aiofiles.open(streamdetails.data, "rb") as _file: if seek_position: + assert streamdetails.duration is not None # for type checking seek_pos = int((streamdetails.size / streamdetails.duration) * seek_position) await _file.seek(seek_pos) # yield chunks of data from file @@ -1286,11 +1329,18 @@ async def get_preview_stream( """Create a 30 seconds preview audioclip for the given streamdetails.""" if not (music_prov := mass.get_provider(provider_instance_id_or_domain)): raise ProviderUnavailableError + if TYPE_CHECKING: # avoid circular import + assert isinstance(music_prov, MusicProvider) streamdetails = await music_prov.get_stream_details(item_id, media_type) + + audio_input: AsyncGenerator[bytes, None] | str + if streamdetails.stream_type == StreamType.CUSTOM: + audio_input = music_prov.get_audio_stream(streamdetails, 30) + else: + assert streamdetails.path is not None # for type checking + audio_input = streamdetails.path async for chunk in get_ffmpeg_stream( - audio_input=music_prov.get_audio_stream(streamdetails, 30) - if streamdetails.stream_type == StreamType.CUSTOM - else streamdetails.path, + audio_input=audio_input, input_format=streamdetails.audio_format, output_format=AudioFormat(content_type=ContentType.AAC), extra_input_args=["-to", "30"], @@ -1401,11 +1451,12 @@ def is_output_limiter_enabled(mass: MusicAssistant, player: Player) -> bool: elif player.synced_to: # Not in sync group, but synced, get from the leader deciding_player_id = player.synced_to - return mass.config.get_raw_player_config_value( + output_limiter_enabled = mass.config.get_raw_player_config_value( deciding_player_id, CONF_ENTRY_OUTPUT_LIMITER.key, CONF_ENTRY_OUTPUT_LIMITER.default_value, ) + return bool(output_limiter_enabled) def get_player_filter_params( @@ -1434,7 +1485,8 @@ def get_player_filter_params( # We can still apply the DSP of that single player. if player.group_childs: child_player = mass.players.get(player.group_childs[0]) - dsp = mass.config.get_player_dsp_config(child_player) + assert child_player is not None # for type checking + dsp = mass.config.get_player_dsp_config(child_player.player_id) else: # This should normally never happen, but if it does, we disable DSP. dsp.enabled = False @@ -1525,16 +1577,23 @@ async def analyze_loudness( elif streamdetails.stream_type == StreamType.MULTI_FILE: audio_source = get_multi_file_stream(mass, streamdetails) elif streamdetails.stream_type == StreamType.CUSTOM: - audio_source = mass.get_provider(streamdetails.provider).get_audio_stream( + music_prov = mass.get_provider(streamdetails.provider) + if TYPE_CHECKING: # avoid circular import + assert isinstance(music_prov, MusicProvider) + audio_source = music_prov.get_audio_stream( streamdetails, ) elif streamdetails.stream_type == StreamType.HLS: + assert streamdetails.path is not None # for type checking substream = await get_hls_substream(mass, streamdetails.path) audio_source = substream.path elif streamdetails.stream_type == StreamType.ENCRYPTED_HTTP: + assert streamdetails.path is not None # for type checking + assert streamdetails.decryption_key is not None # for type checking audio_source = streamdetails.path extra_input_args += ["-decryption_key", streamdetails.decryption_key] else: + assert streamdetails.path is not None # for type checking audio_source = streamdetails.path # calculate BS.1770 R128 integrated loudness with ffmpeg @@ -1593,10 +1652,12 @@ def _get_normalization_mode( return VolumeNormalizationMode.DISABLED # work out preference for track or radio preference = VolumeNormalizationMode( - core_config.get_value( - CONF_VOLUME_NORMALIZATION_RADIO - if streamdetails.media_type == MediaType.RADIO - else CONF_VOLUME_NORMALIZATION_TRACKS, + str( + core_config.get_value( + CONF_VOLUME_NORMALIZATION_RADIO + if streamdetails.media_type == MediaType.RADIO + else CONF_VOLUME_NORMALIZATION_TRACKS, + ) ) ) diff --git a/music_assistant/helpers/auth.py b/music_assistant/helpers/auth.py index 9f3465a5..8a5ba50c 100644 --- a/music_assistant/helpers/auth.py +++ b/music_assistant/helpers/auth.py @@ -12,7 +12,7 @@ from music_assistant_models.enums import EventType from music_assistant_models.errors import LoginFailed if TYPE_CHECKING: - from music_assistant import MusicAssistant + from music_assistant.mass import MusicAssistant LOGGER = logging.getLogger(__name__) @@ -51,6 +51,7 @@ class AuthenticationHelper: ) -> bool | None: """Exit context manager.""" self.mass.webserver.unregister_dynamic_route(self._cb_path, "GET") + return None async def authenticate(self, auth_url: str, timeout: int = 60) -> dict[str, str]: """Start the auth process and return any query params if received on the callback.""" diff --git a/music_assistant/helpers/compare.py b/music_assistant/helpers/compare.py index 2162aaed..4f2e9150 100644 --- a/music_assistant/helpers/compare.py +++ b/music_assistant/helpers/compare.py @@ -36,22 +36,38 @@ def compare_media_item( ) -> bool | None: """Compare two media items and return True if they match.""" if base_item.media_type == MediaType.ARTIST and compare_item.media_type == MediaType.ARTIST: + assert isinstance(base_item, Artist | ItemMapping) # for type checking + assert isinstance(compare_item, Artist | ItemMapping) # for type checking return compare_artist(base_item, compare_item, strict) if base_item.media_type == MediaType.ALBUM and compare_item.media_type == MediaType.ALBUM: + assert isinstance(base_item, Album | ItemMapping) # for type checking + assert isinstance(compare_item, Album | ItemMapping) # for type checking return compare_album(base_item, compare_item, strict) if base_item.media_type == MediaType.TRACK and compare_item.media_type == MediaType.TRACK: + assert isinstance(base_item, Track) # for type checking + assert isinstance(compare_item, Track) # for type checking return compare_track(base_item, compare_item, strict) if base_item.media_type == MediaType.PLAYLIST and compare_item.media_type == MediaType.PLAYLIST: + assert isinstance(base_item, Playlist | ItemMapping) # for type checking + assert isinstance(compare_item, Playlist | ItemMapping) # for type checking return compare_playlist(base_item, compare_item, strict) if base_item.media_type == MediaType.RADIO and compare_item.media_type == MediaType.RADIO: + assert isinstance(base_item, Radio | ItemMapping) # for type checking + assert isinstance(compare_item, Radio | ItemMapping) # for type checking return compare_radio(base_item, compare_item, strict) if ( base_item.media_type == MediaType.AUDIOBOOK and compare_item.media_type == MediaType.AUDIOBOOK ): + assert isinstance(base_item, Audiobook | ItemMapping) # for type checking + assert isinstance(compare_item, Audiobook | ItemMapping) # for type checking return compare_audiobook(base_item, compare_item, strict) if base_item.media_type == MediaType.PODCAST and compare_item.media_type == MediaType.PODCAST: + assert isinstance(base_item, Podcast | ItemMapping) # for type checking + assert isinstance(compare_item, Podcast | ItemMapping) # for type checking return compare_podcast(base_item, compare_item, strict) + assert isinstance(base_item, ItemMapping) # for type checking + assert isinstance(compare_item, ItemMapping) # for type checking return compare_item_mapping(base_item, compare_item, strict) @@ -61,8 +77,6 @@ def compare_artist( strict: bool = True, ) -> bool | None: """Compare two artist items and return True if they match.""" - if base_item is None or compare_item is None: - return False # return early on exact item_id match if compare_item_ids(base_item, compare_item): return True @@ -78,13 +92,11 @@ def compare_artist( def compare_album( - base_item: Album | ItemMapping | None, - compare_item: Album | ItemMapping | None, + base_item: Album | ItemMapping, + compare_item: Album | ItemMapping, strict: bool = True, ) -> bool | None: """Compare two album items and return True if they match.""" - if base_item is None or compare_item is None: - return False # return early on exact item_id match if compare_item_ids(base_item, compare_item): return True @@ -125,14 +137,12 @@ def compare_album( def compare_track( - base_item: Track | None, - compare_item: Track | None, + base_item: Track, + compare_item: Track, strict: bool = True, track_albums: list[Album] | None = None, ) -> bool: """Compare two track items and return True if they match.""" - if base_item is None or compare_item is None: - return False # return early on exact item_id match if compare_item_ids(base_item, compare_item): return True @@ -243,8 +253,6 @@ def compare_playlist( strict: bool = True, ) -> bool | None: """Compare two Playlist items and return True if they match.""" - if base_item is None or compare_item is None: - return False # require (exact) name match if not compare_strings(base_item.name, compare_item.name, strict=strict): return False @@ -262,8 +270,6 @@ def compare_radio( strict: bool = True, ) -> bool | None: """Compare two Radio items and return True if they match.""" - if base_item is None or compare_item is None: - return False # return early on exact item_id match if compare_item_ids(base_item, compare_item): return True @@ -275,13 +281,11 @@ def compare_radio( def compare_audiobook( - base_item: Audiobook | ItemMapping | None, - compare_item: Audiobook | ItemMapping | None, + base_item: Audiobook | ItemMapping, + compare_item: Audiobook | ItemMapping, strict: bool = True, ) -> bool | None: """Compare two Audiobook items and return True if they match.""" - if base_item is None or compare_item is None: - return False # return early on exact item_id match if compare_item_ids(base_item, compare_item): return True @@ -324,13 +328,11 @@ def compare_audiobook( def compare_podcast( - base_item: Podcast | ItemMapping | None, - compare_item: Podcast | ItemMapping | None, + base_item: Podcast | ItemMapping, + compare_item: Podcast | ItemMapping, strict: bool = True, ) -> bool | None: """Compare two Podcast items and return True if they match.""" - if base_item is None or compare_item is None: - return False # return early on exact item_id match if compare_item_ids(base_item, compare_item): return True @@ -371,15 +373,17 @@ def compare_item_mapping( strict: bool = True, ) -> bool | None: """Compare two ItemMapping items and return True if they match.""" - if base_item is None or compare_item is None: - return False # return early on exact item_id match if compare_item_ids(base_item, compare_item): return True # return early on (un)matched external id - external_id_match = compare_external_ids(base_item.external_ids, compare_item.external_ids) - if external_id_match is not None: - return external_id_match + # check all ExternalID, as ItemMapping is a minimized obj for all MediaItems + for ext_id in ExternalID: + external_id_match = compare_external_ids( + base_item.external_ids, compare_item.external_ids, ext_id + ) + if external_id_match is not None: + return external_id_match # compare version if not compare_version(base_item.version, compare_item.version): return False @@ -440,6 +444,7 @@ def compare_item_ids( compare_prov_ids = getattr(compare_item, "provider_mappings", None) if base_prov_ids is not None: + assert isinstance(base_item, MediaItem) # for type checking for prov_l in base_item.provider_mappings: if ( prov_l.provider_domain == compare_item.provider @@ -448,11 +453,14 @@ def compare_item_ids( return True if compare_prov_ids is not None: + assert isinstance(compare_item, MediaItem) # for type checking for prov_r in compare_item.provider_mappings: if prov_r.provider_domain == base_item.provider and prov_r.item_id == base_item.item_id: return True if base_prov_ids is not None and compare_prov_ids is not None: + assert isinstance(base_item, MediaItem) # for type checking + assert isinstance(compare_item, MediaItem) # for type checking for prov_l in base_item.provider_mappings: for prov_r in compare_item.provider_mappings: if prov_l.provider_domain != prov_r.provider_domain: diff --git a/music_assistant/helpers/database.py b/music_assistant/helpers/database.py index 578f666f..c393d386 100644 --- a/music_assistant/helpers/database.py +++ b/music_assistant/helpers/database.py @@ -6,16 +6,17 @@ import asyncio import logging import os import time +from collections.abc import Mapping from contextlib import asynccontextmanager from sqlite3 import OperationalError -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import aiosqlite from music_assistant.constants import MASS_LOGGER_NAME if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Mapping + from collections.abc import AsyncGenerator LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.database") @@ -23,7 +24,9 @@ ENABLE_DEBUG = os.environ.get("PYTHONDEVMODE") == "1" @asynccontextmanager -async def debug_query(sql_query: str, query_params: dict | None = None): +async def debug_query( + sql_query: str, query_params: dict[str, Any] | None = None +) -> AsyncGenerator[None]: """Time the processing time of an sql query.""" if not ENABLE_DEBUG: yield @@ -46,7 +49,7 @@ async def debug_query(sql_query: str, query_params: dict | None = None): def query_params(query: str, params: dict[str, Any] | None) -> tuple[str, dict[str, Any]]: """Extend query parameters support.""" if params is None: - return (query, params) + return (query, {}) count = 0 result_query = query result_params = {} @@ -100,11 +103,11 @@ class DatabaseConnection: async def get_rows( self, table: str, - match: dict | None = None, + match: dict[str, Any] | None = None, order_by: str | None = None, limit: int = 500, offset: int = 0, - ) -> list[Mapping]: + ) -> list[Mapping[str, Any]]: """Get all rows for given table.""" sql_query = f"SELECT * FROM {table}" if match is not None: @@ -114,26 +117,28 @@ class DatabaseConnection: if limit: sql_query += f" LIMIT {limit} OFFSET {offset}" async with debug_query(sql_query): - return await self._db.execute_fetchall(sql_query, match) + return cast( + "list[Mapping[str, Any]]", await self._db.execute_fetchall(sql_query, match) + ) async def get_rows_from_query( self, query: str, - params: dict | None = None, + params: dict[str, Any] | None = None, limit: int = 500, offset: int = 0, - ) -> list[Mapping]: + ) -> list[Mapping[str, Any]]: """Get all rows for given custom query.""" if limit: query += f" LIMIT {limit} OFFSET {offset}" _query, _params = query_params(query, params) async with debug_query(_query, _params): - return await self._db.execute_fetchall(_query, _params) + return cast("list[Mapping[str, Any]]", await self._db.execute_fetchall(_query, _params)) async def get_count_from_query( self, query: str, - params: dict | None = None, + params: dict[str, Any] | None = None, ) -> int: """Get row count for given custom query.""" query = f"SELECT count() FROM ({query})" @@ -141,6 +146,7 @@ class DatabaseConnection: async with debug_query(_query): async with self._db.execute(_query, _params) as cursor: if result := await cursor.fetchone(): + assert isinstance(result[0], int) # for type checking return result[0] return 0 @@ -153,22 +159,27 @@ class DatabaseConnection: async with debug_query(query): async with self._db.execute(query) as cursor: if result := await cursor.fetchone(): + assert isinstance(result[0], int) # for type checking return result[0] return 0 - async def search(self, table: str, search: str, column: str = "name") -> list[Mapping]: + async def search( + self, table: str, search: str, column: str = "name" + ) -> list[Mapping[str, Any]]: """Search table by column.""" sql_query = f"SELECT * FROM {table} WHERE {table}.{column} LIKE :search" params = {"search": f"%{search}%"} async with debug_query(sql_query, params): - return await self._db.execute_fetchall(sql_query, params) + return cast( + "list[Mapping[str, Any]]", await self._db.execute_fetchall(sql_query, params) + ) - async def get_row(self, table: str, match: dict[str, Any]) -> Mapping | None: + async def get_row(self, table: str, match: dict[str, Any]) -> Mapping[str, Any] | None: """Get single row for given table where column matches keys/values.""" sql_query = f"SELECT * FROM {table} WHERE " sql_query += " AND ".join(f"{table}.{x} = :{x}" for x in match) async with debug_query(sql_query, match), self._db.execute(sql_query, match) as cursor: - return await cursor.fetchone() + return cast("Mapping[str, Any] | None", await cursor.fetchone()) async def insert( self, @@ -185,9 +196,11 @@ class DatabaseConnection: sql_query += f" VALUES ({','.join(f':{x}' for x in keys)})" row_id = await self._db.execute_insert(sql_query, values) await self._db.commit() + assert row_id is not None # for type checking + assert isinstance(row_id[0], int) # for type checking return row_id[0] - async def insert_or_replace(self, table: str, values: dict[str, Any]) -> Mapping: + async def insert_or_replace(self, table: str, values: dict[str, Any]) -> int: """Insert or replace data in given table.""" return await self.insert(table=table, values=values, allow_replace=True) @@ -196,7 +209,7 @@ class DatabaseConnection: table: str, match: dict[str, Any], values: dict[str, Any], - ) -> Mapping: + ) -> Mapping[str, Any]: """Update record.""" keys = tuple(values.keys()) sql_query = f"UPDATE {table} SET {','.join(f'{x}=:{x}' for x in keys)} WHERE " @@ -204,9 +217,13 @@ class DatabaseConnection: await self.execute(sql_query, {**match, **values}) await self._db.commit() # return updated item - return await self.get_row(table, match) + updated_item = await self.get_row(table, match) + assert updated_item is not None # for type checking + return updated_item - async def delete(self, table: str, match: dict | None = None, query: str | None = None) -> None: + async def delete( + self, table: str, match: dict[str, Any] | None = None, query: str | None = None + ) -> None: """Delete data in given table.""" assert not (query and "where" in query.lower()) sql_query = f"DELETE FROM {table} " @@ -225,7 +242,7 @@ class DatabaseConnection: await self.execute(sql_query) await self._db.commit() - async def execute(self, query: str, values: dict | None = None) -> Any: + async def execute(self, query: str, values: dict[str, Any] | None = None) -> Any: """Execute command on the database.""" return await self._db.execute(query, values) @@ -236,8 +253,8 @@ class DatabaseConnection: async def iter_items( self, table: str, - match: dict | None = None, - ) -> AsyncGenerator[Mapping, None]: + match: dict[str, Any] | None = None, + ) -> AsyncGenerator[Mapping[str, Any], None]: """Iterate all items within a table.""" limit: int = 500 offset: int = 0 diff --git a/music_assistant/helpers/dsp.py b/music_assistant/helpers/dsp.py index 1d741fb1..b22ce830 100644 --- a/music_assistant/helpers/dsp.py +++ b/music_assistant/helpers/dsp.py @@ -9,7 +9,7 @@ from music_assistant_models.dsp import ( ParametricEQFilter, ToneControlFilter, ) -from music_assistant_models.streamdetails import AudioFormat +from music_assistant_models.media_items.audio_format import AudioFormat # ruff: noqa: PLR0915 @@ -38,7 +38,9 @@ def filter_to_ffmpeg_params(dsp_filter: DSPFilter, input_format: AudioFormat) -> # Get gain for this channel, default to 0 if not specified gain_db = dsp_filter.per_channel_preamp.get(channel_id, 0) # Apply both the overall preamp and the per-channel preamp - total_gain_db = dsp_filter.preamp + gain_db + total_gain_db = ( + dsp_filter.preamp + gain_db if dsp_filter.preamp is not None else gain_db + ) if total_gain_db != 0: # Convert dB to linear gain gain = 10 ** (total_gain_db / 20) diff --git a/music_assistant/helpers/ffmpeg.py b/music_assistant/helpers/ffmpeg.py index fa9c2b8a..e9368446 100644 --- a/music_assistant/helpers/ffmpeg.py +++ b/music_assistant/helpers/ffmpeg.py @@ -57,9 +57,10 @@ class FFMpeg(AsyncProcess): self.input_format = input_format self.collect_log_history = collect_log_history self.log_history: deque[str] = deque(maxlen=100) - self._stdin_task: asyncio.Task | None = None - self._logger_task: asyncio.Task | None = None + self._stdin_task: asyncio.Task[None] | None = None + self._logger_task: asyncio.Task[None] | None = None self._input_codec_parsed = False + stdin: bool | int if audio_input == "-" or isinstance(audio_input, AsyncGenerator): stdin = True else: @@ -161,8 +162,8 @@ class FFMpeg(AsyncProcess): async def _feed_stdin(self) -> None: """Feed stdin with audio chunks from an AsyncGenerator.""" - if TYPE_CHECKING: - self.audio_input: AsyncGenerator[bytes, None] + assert not isinstance(self.audio_input, str | int) + generator_exhausted = False cancelled = False try: diff --git a/music_assistant/helpers/images.py b/music_assistant/helpers/images.py index 99e34c40..2fd31b9e 100644 --- a/music_assistant/helpers/images.py +++ b/music_assistant/helpers/images.py @@ -9,7 +9,7 @@ import random from base64 import b64decode from collections.abc import Iterable from io import BytesIO -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import aiofiles from aiohttp.client_exceptions import ClientError @@ -17,19 +17,20 @@ from PIL import Image, UnidentifiedImageError from music_assistant.helpers.tags import get_embedded_image from music_assistant.models.metadata_provider import MetadataProvider +from music_assistant.models.music_provider import MusicProvider if TYPE_CHECKING: from music_assistant_models.media_items import MediaItemImage + from PIL.Image import Image as ImageClass - from music_assistant import MusicAssistant - from music_assistant.models.music_provider import MusicProvider + from music_assistant.mass import MusicAssistant async def get_image_data(mass: MusicAssistant, path_or_url: str, provider: str) -> bytes: """Create thumbnail from image url.""" # TODO: add local cache here ! if prov := mass.get_provider(provider): - prov: MusicProvider | MetadataProvider + assert isinstance(prov, MusicProvider | MetadataProvider) if resolved_image := await prov.resolve_image(path_or_url): if isinstance(resolved_image, bytes): return resolved_image @@ -49,7 +50,7 @@ async def get_image_data(mass: MusicAssistant, path_or_url: str, provider: str) if path_or_url.endswith(("jpg", "JPG", "png", "PNG", "jpeg")): if await asyncio.to_thread(os.path.isfile, path_or_url): async with aiofiles.open(path_or_url, "rb") as _file: - return await _file.read() + return cast("bytes", await _file.read()) # use ffmpeg for embedded images if img_data := await get_embedded_image(path_or_url): return img_data @@ -72,7 +73,7 @@ async def get_image_thumb( if not size and image_format.encode() in img_data: return img_data - def _create_image(): + def _create_image() -> bytes: data = BytesIO() try: img = Image.open(BytesIO(img_data)) @@ -90,12 +91,14 @@ async def get_image_thumb( async def create_collage( - mass: MusicAssistant, images: Iterable[MediaItemImage], dimensions: tuple[int] = (1500, 1500) + mass: MusicAssistant, + images: Iterable[MediaItemImage], + dimensions: tuple[int, int] = (1500, 1500), ) -> bytes: """Create a basic collage image from multiple image urls.""" image_size = 250 - def _new_collage(): + def _new_collage() -> ImageClass: return Image.new("RGB", (dimensions[0], dimensions[1]), color=(255, 255, 255, 255)) collage = await asyncio.to_thread(_new_collage) @@ -122,7 +125,7 @@ async def create_collage( del img_data break - def _save_collage(): + def _save_collage() -> bytes: final_data = BytesIO() collage.convert("RGB").save(final_data, "JPEG", optimize=True) return final_data.getvalue() @@ -136,4 +139,5 @@ async def get_icon_string(icon_path: str) -> str: assert ext == "svg" async with aiofiles.open(icon_path) as _file: xml_data = await _file.read() + assert isinstance(xml_data, str) # for type checking return xml_data.replace("\n", "").strip() diff --git a/music_assistant/helpers/playlists.py b/music_assistant/helpers/playlists.py index f639d26f..837d6f47 100644 --- a/music_assistant/helpers/playlists.py +++ b/music_assistant/helpers/playlists.py @@ -8,13 +8,13 @@ from dataclasses import dataclass from typing import TYPE_CHECKING from urllib.parse import urlparse -from aiohttp import client_exceptions +from aiohttp import ClientTimeout, client_exceptions from music_assistant_models.errors import InvalidDataError from music_assistant.helpers.util import detect_charset if TYPE_CHECKING: - from music_assistant import MusicAssistant + from music_assistant.mass import MusicAssistant LOGGER = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def parse_m3u(m3u_data: str) -> list[PlaylistItem]: length = None title = None - stream_info = None + stream_info: dict[str, str] | None = None key = None for line in m3u_lines: @@ -148,7 +148,9 @@ async def fetch_playlist( ) -> list[PlaylistItem]: """Parse an online m3u or pls playlist.""" try: - async with mass.http_session.get(url, allow_redirects=True, timeout=5) as resp: + async with mass.http_session.get( + url, allow_redirects=True, timeout=ClientTimeout(total=5) + ) as resp: try: raw_data = await resp.content.read(64 * 1024) # NOTE: using resp.charset is not reliable, we need to detect it ourselves diff --git a/music_assistant/helpers/process.py b/music_assistant/helpers/process.py index b18f0a6b..ff65ca73 100644 --- a/music_assistant/helpers/process.py +++ b/music_assistant/helpers/process.py @@ -54,7 +54,7 @@ class AsyncProcess: self._stdout = None if stdout is False else stdout self._stderr = asyncio.subprocess.DEVNULL if stderr is False else stderr self._close_called = False - self._returncode: bool | None = None + self._returncode: int | None = None @property def closed(self) -> bool: @@ -87,6 +87,7 @@ class AsyncProcess: # send interrupt signal to process when we're cancelled await self.close(send_signal=exc_type in (GeneratorExit, asyncio.CancelledError)) self._returncode = self.returncode + return None async def start(self) -> None: """Perform Async init of process.""" @@ -120,6 +121,8 @@ class AsyncProcess: """Read exactly n bytes from the process stdout (or less if eof).""" if self._close_called: return b"" + assert self.proc is not None # for type checking + assert self.proc.stdout is not None # for type checking try: return await self.proc.stdout.readexactly(n) except asyncio.IncompleteReadError as err: @@ -134,12 +137,16 @@ class AsyncProcess: """ if self._close_called: return b"" + assert self.proc is not None # for type checking + assert self.proc.stdout is not None # for type checking return await self.proc.stdout.read(n) async def write(self, data: bytes) -> None: """Write data to process stdin.""" if self.closed: raise RuntimeError("write called while process already done") + assert self.proc is not None # for type checking + assert self.proc.stdin is not None # for type checking self.proc.stdin.write(data) with suppress(BrokenPipeError, ConnectionResetError): await self.proc.stdin.drain() @@ -148,6 +155,8 @@ class AsyncProcess: """Write end of file to to process stdin.""" if self.closed: return + assert self.proc is not None # for type checking + assert self.proc.stdin is not None # for type checking try: if self.proc.stdin.can_write_eof(): self.proc.stdin.write_eof() @@ -165,6 +174,8 @@ class AsyncProcess: """Read line from stderr.""" if self.returncode is not None: return b"" + assert self.proc is not None # for type checking + assert self.proc.stderr is not None # for type checking try: return await self.proc.stderr.readline() except ValueError as err: @@ -180,6 +191,7 @@ class AsyncProcess: async def iter_stderr(self) -> AsyncGenerator[str, None]: """Iterate lines from the stderr stream as string.""" + line: str | bytes while True: line = await self.read_stderr() if line == b"": @@ -198,13 +210,15 @@ class AsyncProcess: if self.closed: raise RuntimeError("communicate called while process already done") # abort existing readers on stderr/stdout first before we send communicate - waiter: asyncio.Future - if self.proc.stdout and (waiter := self.proc.stdout._waiter): - self.proc.stdout._waiter = None + waiter: asyncio.Future[None] + assert self.proc is not None # for type checking + # _waiter is attribute of StreamReader + if self.proc.stdout and (waiter := self.proc.stdout._waiter): # type: ignore[attr-defined] + self.proc.stdout._waiter = None # type: ignore[attr-defined] if waiter and not waiter.done(): waiter.set_exception(asyncio.CancelledError()) - if self.proc.stderr and (waiter := self.proc.stderr._waiter): - self.proc.stderr._waiter = None + if self.proc.stderr and (waiter := self.proc.stderr._waiter): # type: ignore[attr-defined] + self.proc.stderr._waiter = None # type: ignore[attr-defined] if waiter and not waiter.done(): waiter.set_exception(asyncio.CancelledError()) stdout, stderr = await asyncio.wait_for(self.proc.communicate(input), timeout) @@ -220,13 +234,13 @@ class AsyncProcess: if self.proc.stdin and not self.proc.stdin.is_closing(): self.proc.stdin.close() # abort existing readers on stderr/stdout first before we send communicate - waiter: asyncio.Future - if self.proc.stdout and (waiter := self.proc.stdout._waiter): - self.proc.stdout._waiter = None + waiter: asyncio.Future[None] + if self.proc.stdout and (waiter := self.proc.stdout._waiter): # type: ignore[attr-defined] + self.proc.stdout._waiter = None # type: ignore[attr-defined] if waiter and not waiter.done(): waiter.set_exception(asyncio.CancelledError()) - if self.proc.stderr and (waiter := self.proc.stderr._waiter): - self.proc.stderr._waiter = None + if self.proc.stderr and (waiter := self.proc.stderr._waiter): # type: ignore[attr-defined] + self.proc.stderr._waiter = None # type: ignore[attr-defined] if waiter and not waiter.done(): waiter.set_exception(asyncio.CancelledError()) await asyncio.sleep(0) # yield to loop @@ -261,6 +275,7 @@ class AsyncProcess: async def wait(self) -> int: """Wait for the process and return the returncode.""" if self._returncode is None: + assert self.proc is not None self._returncode = await self.proc.wait() return self._returncode @@ -275,6 +290,7 @@ async def check_output(*args: str, env: dict[str, str] | None = None) -> tuple[i *args, stderr=asyncio.subprocess.STDOUT, stdout=asyncio.subprocess.PIPE, env=env ) stdout, _ = await proc.communicate() + assert proc.returncode is not None # for type checking return (proc.returncode, stdout) @@ -290,4 +306,5 @@ async def communicate( stdin=asyncio.subprocess.PIPE if input is not None else None, ) stdout, stderr = await proc.communicate(input) + assert proc.returncode is not None # for type checking return (proc.returncode, stdout, stderr) diff --git a/music_assistant/helpers/tags.py b/music_assistant/helpers/tags.py index 2094fa7d..647d90f5 100644 --- a/music_assistant/helpers/tags.py +++ b/music_assistant/helpers/tags.py @@ -32,16 +32,18 @@ LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.tags") TAG_SPLITTER = ";" -def clean_tuple(values: Iterable[str]) -> tuple: +def clean_tuple(values: Iterable[str]) -> tuple[str, ...]: """Return a tuple with all empty values removed.""" return tuple(x.strip() for x in values if x not in (None, "", " ")) -def split_items(org_str: str, allow_unsafe_splitters: bool = False) -> tuple[str, ...]: +def split_items( + org_str: str | list[str] | tuple[str, ...] | None, allow_unsafe_splitters: bool = False +) -> tuple[str, ...]: """Split up a tags string by common splitter.""" if org_str is None: return () - if isinstance(org_str, list): + if isinstance(org_str, tuple | list): final_items: list[str] = [] for item in org_str: final_items.extend(split_items(item, allow_unsafe_splitters)) @@ -333,7 +335,7 @@ class AudioTags: return AlbumType.UNKNOWN @property - def isrc(self) -> tuple[str]: + def isrc(self) -> tuple[str, ...]: """Return isrc tag(s).""" for tag_name in ("isrc", "tsrc"): if tag := self.tags.get(tag_name): @@ -398,7 +400,7 @@ class AudioTags: return None @classmethod - def parse(cls, raw: dict) -> AudioTags: + def parse(cls, raw: dict[str, Any]) -> AudioTags: """Parse instance from raw ffmpeg info output.""" audio_stream = next((x for x in raw["streams"] if x["codec_type"] == "audio"), None) if audio_stream is None: @@ -435,7 +437,7 @@ class AudioTags: filename=raw["format"]["filename"], ) - def get(self, key: str, default=None) -> Any: + def get(self, key: str, default: Any | None = None) -> Any: """Get tag by key.""" return self.tags.get(key, default) @@ -532,7 +534,7 @@ def get_file_duration(input_file: str) -> float: # extract duration from ffmpeg output duration_str = res.split("time=")[-1].split(" ")[0].strip() duration_parts = duration_str.split(":") - duration = 0 + duration = 0.0 for part in duration_parts: duration = duration * 60 + float(part) return duration @@ -547,10 +549,11 @@ def parse_tags_mutagen(input_file: str) -> dict[str, Any]: NOT Async friendly. """ - result = {} + result: dict[str, Any] = {} try: # TODO: extend with more tags and file types! - tags = mutagen.File(input_file) + # https://mutagen.readthedocs.io/en/latest/user/gettingstarted.html + tags = mutagen.File(input_file) # type: ignore[attr-defined] if tags is None or not tags.tags: return result tags = dict(tags.tags) @@ -604,7 +607,7 @@ async def get_embedded_image(input_file: str) -> bytes | None: Input_file may be a (local) filename or URL accessible by ffmpeg. """ - args = ( + args = [ "ffmpeg", "-hide_banner", "-loglevel", @@ -617,7 +620,7 @@ async def get_embedded_image(input_file: str) -> bytes | None: "-f", "mjpeg", "-", - ) + ] async with AsyncProcess( args, stdin=False, stdout=True, stderr=None, name="ffmpeg_image" ) as ffmpeg: diff --git a/music_assistant/helpers/throttle_retry.py b/music_assistant/helpers/throttle_retry.py index 74a95738..fff16ca2 100644 --- a/music_assistant/helpers/throttle_retry.py +++ b/music_assistant/helpers/throttle_retry.py @@ -8,6 +8,7 @@ from collections import deque from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine from contextlib import asynccontextmanager from contextvars import ContextVar +from types import TracebackType from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar from music_assistant_models.errors import ResourceTemporarilyUnavailable, RetriesExhausted @@ -33,13 +34,13 @@ class Throttler: - Return the delay caused by acquire() """ - def __init__(self, rate_limit: int, period=1.0): + def __init__(self, rate_limit: int, period: float = 1.0) -> None: """Initialize the Throttler.""" self.rate_limit = rate_limit self.period = period self._task_logs: deque[float] = deque() - def _flush(self): + def _flush(self) -> None: now = time.monotonic() while self._task_logs: if now - self._task_logs[0] > self.period: @@ -67,21 +68,28 @@ class Throttler: """Wait until the lock is acquired, return the time delay.""" return await self.acquire() - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: """Nothing to do on exit.""" class ThrottlerManager: """Throttler manager that extends asyncio Throttle by retrying.""" - def __init__(self, rate_limit: int, period: float = 1, retry_attempts=5, initial_backoff=5): + def __init__( + self, rate_limit: int, period: float = 1, retry_attempts: int = 5, initial_backoff: int = 5 + ): """Initialize the AsyncThrottledContextManager.""" self.retry_attempts = retry_attempts self.initial_backoff = initial_backoff self.throttler = Throttler(rate_limit, period) @asynccontextmanager - async def acquire(self) -> AsyncGenerator[None, float]: + async def acquire(self) -> AsyncGenerator[float, None]: """Acquire a free slot from the Throttler, returns the throttled time.""" if BYPASS_THROTTLER.get(): yield 0 @@ -92,10 +100,12 @@ class ThrottlerManager: async def bypass(self) -> AsyncGenerator[None, None]: """Bypass the throttler.""" try: - token = BYPASS_THROTTLER.set(True) + BYPASS_THROTTLER.set(True) yield None finally: - BYPASS_THROTTLER.reset(token) + # TODO: token is unbound here + # BYPASS_THROTTLER.reset(token) + ... def throttle_with_retries( @@ -107,7 +117,7 @@ def throttle_with_retries( async def wrapper(self: _ProviderT, *args: _P.args, **kwargs: _P.kwargs) -> _R: """Call async function using the throttler with retries.""" # the trottler attribute must be present on the class - throttler: ThrottlerManager = self.throttler + throttler: ThrottlerManager = self.throttler # type: ignore[attr-defined] backoff_time = throttler.initial_backoff async with throttler.acquire() as delay: if delay != 0: diff --git a/music_assistant/helpers/upnp.py b/music_assistant/helpers/upnp.py index b560f3e5..47041a8c 100644 --- a/music_assistant/helpers/upnp.py +++ b/music_assistant/helpers/upnp.py @@ -130,6 +130,8 @@ def create_didl_metadata(media: PlayerMedia) -> str: ) duration_str = str(datetime.timedelta(seconds=media.duration or 0)) + ".000" + assert media.queue_item_id is not None # for type checking + return ( '' f'' diff --git a/music_assistant/helpers/util.py b/music_assistant/helpers/util.py index 7dca3724..c38b3db6 100644 --- a/music_assistant/helpers/util.py +++ b/music_assistant/helpers/util.py @@ -13,13 +13,13 @@ import socket import urllib.error import urllib.parse import urllib.request -from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine from contextlib import suppress from functools import lru_cache from importlib.metadata import PackageNotFoundError from importlib.metadata import version as pkg_version from types import TracebackType -from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar +from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar, cast from urllib.parse import urlparse import cchardet as chardet @@ -34,7 +34,7 @@ if TYPE_CHECKING: from zeroconf.asyncio import AsyncServiceInfo - from music_assistant import MusicAssistant + from music_assistant.mass import MusicAssistant from music_assistant.models import ProviderModuleType LOGGER = logging.getLogger(__name__) @@ -221,33 +221,34 @@ def clean_stream_title(line: str) -> str: return line -async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str]: +async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str, ...]: """Return all IP-adresses of all network interfaces.""" - def call() -> set[str]: + def call() -> tuple[str, ...]: result: list[tuple[int, str]] = [] adapters = ifaddr.get_adapters() for adapter in adapters: for ip in adapter.ips: if ip.is_IPv6 and not include_ipv6: continue - if ip.ip.startswith(("127", "169.254")): + ip_str = str(ip.ip) + if ip_str.startswith(("127", "169.254")): # filter out IPv4 loopback/APIPA address continue - if ip.ip.startswith(("::1", "::ffff:", "fe80")): + if ip_str.startswith(("::1", "::ffff:", "fe80")): # filter out IPv6 loopback/link-local address continue - if ip.ip.startswith(("192.168.",)): + if ip_str.startswith(("192.168.",)): # we rank the 192.168 range a bit higher as its most # often used as the private network subnet score = 2 - elif ip.ip.startswith(("172.", "10.", "192.")): + elif ip_str.startswith(("172.", "10.", "192.")): # we rank the 172 range a bit lower as its most # often used as the private docker network score = 1 else: score = 0 - result.append((score, ip.ip)) + result.append((score, ip_str)) result.sort(key=lambda x: x[0], reverse=True) return tuple(ip[1] for ip in result) @@ -406,7 +407,7 @@ async def get_package_version(pkg_name: str) -> str | None: async def is_hass_supervisor() -> bool: """Return if we're running inside the HA Supervisor (e.g. HAOS).""" - def _check(): + def _check() -> bool: try: urllib.request.urlopen("http://supervisor/core", timeout=1) except urllib.error.URLError as err: @@ -424,7 +425,9 @@ async def load_provider_module(domain: str, requirements: list[str]) -> Provider @lru_cache def _get_provider_module(domain: str) -> ProviderModuleType: - return importlib.import_module(f".{domain}", "music_assistant.providers") + return cast( + "ProviderModuleType", importlib.import_module(f".{domain}", "music_assistant.providers") + ) # ensure module requirements are met for requirement in requirements: @@ -474,8 +477,8 @@ async def get_free_space(folder: str) -> float: def _get_free_space(folder: str) -> float: """Return free space on given folderpath in GB.""" try: - if res := shutil.disk_usage(folder): - return res.free / float(1 << 30) + res = shutil.disk_usage(folder) + return res.free / float(1 << 30) except (FileNotFoundError, OSError, PermissionError): return 0.0 @@ -488,8 +491,8 @@ async def get_free_space_percentage(folder: str) -> float: def _get_free_space(folder: str) -> float: """Return free space on given folderpath in GB.""" try: - if res := shutil.disk_usage(folder): - return res.free / res.total * 100 + res = shutil.disk_usage(folder) + return res.free / res.total * 100 except (FileNotFoundError, OSError, PermissionError): return 0.0 @@ -528,7 +531,7 @@ def get_primary_ip_address_from_zeroconf(discovery_info: AsyncServiceInfo) -> st return None -def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> str | None: +def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> int | None: """Get primary IP address from zeroconf discovery info.""" return discovery_info.port @@ -542,11 +545,12 @@ async def close_async_generator(agen: AsyncGenerator[Any, None]) -> None: await agen.aclose() -async def detect_charset(data: bytes, fallback="utf-8") -> str: +async def detect_charset(data: bytes, fallback: str = "utf-8") -> str: """Detect charset of raw data.""" try: - detected = await asyncio.to_thread(chardet.detect, data) + detected: dict[str, Any] = await asyncio.to_thread(chardet.detect, data) if detected and detected["encoding"] and detected["confidence"] > 0.75: + assert isinstance(detected["encoding"], str) # for type checking return detected["encoding"] except Exception as err: LOGGER.debug("Failed to detect charset: %s", err) @@ -599,25 +603,26 @@ class TaskManager: def __init__(self, mass: MusicAssistant, limit: int = 0): """Initialize the TaskManager.""" self.mass = mass - self._tasks: list[asyncio.Task] = [] + self._tasks: list[asyncio.Task[None]] = [] self._semaphore = asyncio.Semaphore(limit) if limit else None - def create_task(self, coro: Coroutine) -> asyncio.Task: + def create_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]: """Create a new task and add it to the manager.""" task = self.mass.create_task(coro) self._tasks.append(task) return task - async def create_task_with_limit(self, coro: Coroutine) -> None: + async def create_task_with_limit(self, coro: Coroutine[Any, Any, None]) -> None: """Create a new task with semaphore limit.""" assert self._semaphore is not None - def task_done_callback(_task: asyncio.Task) -> None: + def task_done_callback(_task: asyncio.Task[None]) -> None: + assert self._semaphore is not None # for type checking self._tasks.remove(task) self._semaphore.release() await self._semaphore.acquire() - task: asyncio.Task = self.create_task(coro) + task: asyncio.Task[None] = self.create_task(coro) task.add_done_callback(task_done_callback) async def __aenter__(self) -> Self: @@ -634,6 +639,7 @@ class TaskManager: if len(self._tasks) > 0: await asyncio.wait(self._tasks) self._tasks.clear() + return None _R = TypeVar("_R") @@ -650,7 +656,7 @@ def lock( """Call async function using the throttler with retries.""" if not (func_lock := getattr(func, "lock", None)): func_lock = asyncio.Lock() - func.lock = func_lock + func.lock = func_lock # type: ignore[attr-defined] async with func_lock: return await func(*args, **kwargs) @@ -664,7 +670,7 @@ class TimedAsyncGenerator: Source: https://medium.com/@dmitry8912/implementing-timeouts-in-pythons-asynchronous-generators-f7cbaa6dc1e9 """ - def __init__(self, iterable, timeout=0): + def __init__(self, iterable: AsyncIterator[Any], timeout: int = 0): """ Initialize the AsyncTimedIterable. @@ -674,10 +680,10 @@ class TimedAsyncGenerator: """ class AsyncTimedIterator: - def __init__(self): + def __init__(self) -> None: self._iterator = iterable.__aiter__() - async def __anext__(self): + async def __anext__(self) -> Any: result = await asyncio.wait_for(self._iterator.__anext__(), int(timeout)) if not result: raise StopAsyncIteration @@ -685,6 +691,6 @@ class TimedAsyncGenerator: self._factory = AsyncTimedIterator - def __aiter__(self): + def __aiter__(self): # type: ignore[no-untyped-def] """Return the async iterator.""" return self._factory() diff --git a/music_assistant/helpers/webserver.py b/music_assistant/helpers/webserver.py index 9c28a301..137ad925 100644 --- a/music_assistant/helpers/webserver.py +++ b/music_assistant/helpers/webserver.py @@ -9,7 +9,9 @@ from aiohttp import web if TYPE_CHECKING: import logging - from collections.abc import Awaitable, Callable + from collections.abc import Callable + + from aiohttp.typedefs import Handler MAX_CLIENT_SIZE: Final = 1024**2 * 16 @@ -30,7 +32,7 @@ class Webserver: self._apprunner: web.AppRunner | None = None self._webapp: web.Application | None = None self._tcp_site: web.TCPSite | None = None - self._static_routes: list[tuple[str, str, Awaitable]] | None = None + self._static_routes: list[tuple[str, str, Handler]] | None = None self._dynamic_routes: dict[str, Callable] | None = {} if enable_dynamic_routes else None self._bind_port: int | None = None @@ -39,7 +41,7 @@ class Webserver: bind_ip: str | None, bind_port: int, base_url: str, - static_routes: list[tuple[str, str, Awaitable]] | None = None, + static_routes: list[tuple[str, str, Handler]] | None = None, static_content: tuple[str, str, str] | None = None, ) -> None: """Async initialize of module.""" @@ -95,12 +97,12 @@ class Webserver: await self._webapp.cleanup() @property - def base_url(self): + def base_url(self) -> str: """Return the base URL of this webserver.""" return self._base_url @property - def port(self): + def port(self) -> int | None: """Return the port of this webserver.""" return self._bind_port @@ -121,6 +123,7 @@ class Webserver: self._dynamic_routes[key] = handler def _remove(): + assert self._dynamic_routes is not None # for type checking return self._dynamic_routes.pop(key) return _remove @@ -142,6 +145,7 @@ class Webserver: """Redirect request to correct destination.""" # find handler for the request for key in (f"{request.method}.{request.path}", f"*.{request.path}"): + assert self._dynamic_routes is not None # for type checking if handler := self._dynamic_routes.get(key): return await handler(request) # deny all other requests diff --git a/pyproject.toml b/pyproject.toml index e12cbe23..c9fb15ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,8 @@ enable_error_code = [ ] exclude = [ '^music_assistant/controllers/.*$', - '^music_assistant/helpers/.*$', + '^music_assistant/helpers/app_vars.py', + '^music_assistant/helpers/webserver.py', '^music_assistant/models/.*$', '^music_assistant/providers/apple_music/.*$', '^music_assistant/providers/bluesound/.*$',