Chore: Add mypy for helpers (#2097)
authorFabian Munkes <105975993+fmunkes@users.noreply.github.com>
Thu, 15 May 2025 20:13:16 +0000 (22:13 +0200)
committerGitHub <noreply@github.com>
Thu, 15 May 2025 20:13:16 +0000 (22:13 +0200)
16 files changed:
music_assistant/helpers/api.py
music_assistant/helpers/audio.py
music_assistant/helpers/auth.py
music_assistant/helpers/compare.py
music_assistant/helpers/database.py
music_assistant/helpers/dsp.py
music_assistant/helpers/ffmpeg.py
music_assistant/helpers/images.py
music_assistant/helpers/playlists.py
music_assistant/helpers/process.py
music_assistant/helpers/tags.py
music_assistant/helpers/throttle_retry.py
music_assistant/helpers/upnp.py
music_assistant/helpers/util.py
music_assistant/helpers/webserver.py
pyproject.toml

index 6be15d3a6a880f59b661a5fd54305711b0931b3d..72b22e8f8a5dcd71047606c86de71b3a8c8cb111 100644 (file)
@@ -42,7 +42,7 @@ class APICommandHandler:
             if not hasattr(value, "__name__"):
                 continue
             if value.__name__ == "ItemCls":
-                type_hints[key] = func.__self__.item_cls
+                type_hints[key] = func.__self__.item_cls  # type: ignore[attr-defined]
         return APICommandHandler(
             command=command,
             signature=inspect.signature(func),
@@ -64,7 +64,7 @@ def api_command(command: str) -> Callable[[_F], _F]:
 def parse_arguments(
     func_sig: inspect.Signature,
     func_types: dict[str, Any],
-    args: dict | None,
+    args: dict[str, Any] | None,
     strict: bool = False,
 ) -> dict[str, Any]:
     """Parse (and convert) incoming arguments to correct types."""
@@ -161,6 +161,7 @@ def parse_value(  # noqa: PLR0911
         logging.getLogger(__name__).warning(err)
         return None
     if origin is type:
+        assert isinstance(value, str)  # for type checking
         return eval(value)
     if value_type is Any:
         return value
@@ -169,9 +170,10 @@ def parse_value(  # noqa: PLR0911
         raise KeyError(msg)
 
     try:
-        if issubclass(value_type, Enum):  # type: ignore[arg-type]
-            return value_type(value)  # type: ignore[operator]
-        if issubclass(value_type, datetime):  # type: ignore[arg-type]
+        if issubclass(value_type, Enum):
+            return value_type(value)
+        if issubclass(value_type, datetime):
+            assert isinstance(value, str)  # for type checking
             return parse_utc_timestamp(value)
     except TypeError:
         # happens if value_type is not a class
@@ -190,7 +192,7 @@ def parse_value(  # noqa: PLR0911
         if value_type is bool and isinstance(value, str | int):
             return try_parse_bool(value)
 
-    if not isinstance(value, value_type):  # type: ignore[arg-type]
+    if not isinstance(value, value_type):
         # all options failed, raise exception
         msg = (
             f"Value {value} of type {type(value)} is invalid for {name}, "
index aa231bf4e37e3b6a1030fc804839c8581a984c13..d9ccb257a2c5cd147962423c4be20231d159d4e0 100644 (file)
@@ -31,7 +31,7 @@ from music_assistant_models.errors import (
     MusicAssistantError,
     ProviderUnavailableError,
 )
-from music_assistant_models.streamdetails import AudioFormat
+from music_assistant_models.media_items import AudioFormat
 
 from music_assistant.constants import (
     CONF_ALLOW_AUDIO_CACHE,
@@ -58,10 +58,12 @@ from .util import detect_charset, has_enough_space
 if TYPE_CHECKING:
     from music_assistant_models.config_entries import CoreConfig, PlayerConfig
     from music_assistant_models.player import Player
-    from music_assistant_models.player_queue import QueueItem
+    from music_assistant_models.queue_item import QueueItem
     from music_assistant_models.streamdetails import StreamDetails
 
-    from music_assistant import MusicAssistant
+    from music_assistant.mass import MusicAssistant
+    from music_assistant.models.music_provider import MusicProvider
+    from music_assistant.providers.player_group import PlayerGroupProvider
 
 LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.audio")
 
@@ -85,6 +87,34 @@ class StreamCache:
     or when the audio stream is slow itself.
     """
 
+    def __init__(self, mass: MusicAssistant, streamdetails: StreamDetails) -> None:
+        """Initialize the StreamCache."""
+        self.mass = mass
+        self.streamdetails = streamdetails
+        self.logger = LOGGER.getChild("cache")
+        self._cache_file: str | None = None
+        self._fetch_task: asyncio.Task[None] | None = None
+        self._subscribers: int = 0
+        self._first_part_received = asyncio.Event()
+        self._all_data_written: bool = False
+        self._stream_error: str | None = None
+        self.org_path: str | None = streamdetails.path
+        self.org_stream_type: StreamType | None = streamdetails.stream_type
+        self.org_extra_input_args: list[str] | None = streamdetails.extra_input_args
+        self.org_audio_format = streamdetails.audio_format
+        streamdetails.audio_format = AudioFormat(
+            content_type=ContentType.NUT,
+            codec_type=streamdetails.audio_format.codec_type,
+            sample_rate=streamdetails.audio_format.sample_rate,
+            bit_depth=streamdetails.audio_format.bit_depth,
+            channels=streamdetails.audio_format.channels,
+        )
+        streamdetails.path = "-"
+        streamdetails.stream_type = StreamType.CACHE
+        streamdetails.can_seek = True
+        streamdetails.allow_seek = True
+        streamdetails.extra_input_args = []
+
     async def create(self) -> None:
         """Create the cache file (if needed)."""
         if self._cache_file is None:
@@ -93,6 +123,7 @@ class StreamCache:
             ):
                 # we have a mapping stored for this uri, prefer that
                 self._cache_file = cached_cache_path
+                assert self._cache_file is not None  # for type checking
                 if await asyncio.to_thread(os.path.exists, self._cache_file):
                     # cache file already exists from a previous session,
                     # we can simply use that, there is nothing to create
@@ -124,6 +155,7 @@ class StreamCache:
         """Release the cache file."""
         self._subscribers -= 1
         if self._subscribers <= 0:
+            assert self._cache_file is not None  # for type checking
             CACHE_FILES_IN_USE.discard(self._cache_file)
 
     async def get_audio_stream(self) -> str | AsyncGenerator[bytes, None]:
@@ -170,12 +202,17 @@ class StreamCache:
     async def _create_cache_file(self) -> None:
         time_start = time.time()
         self.logger.debug("Creating audio cache for %s", self.streamdetails.uri)
+        assert self._cache_file is not None  # for type checking
         CACHE_FILES_IN_USE.add(self._cache_file)
         self._first_part_received.clear()
         self._all_data_written = False
         extra_input_args = ["-y", *(self.org_extra_input_args or [])]
+        audio_source: AsyncGenerator[bytes, None] | str | int
         if self.org_stream_type == StreamType.CUSTOM:
-            audio_source = self.mass.get_provider(self.streamdetails.provider).get_audio_stream(
+            provider = self.mass.get_provider(self.streamdetails.provider)
+            if TYPE_CHECKING:  # avoid circular import
+                assert isinstance(provider, MusicProvider)
+            audio_source = provider.get_audio_stream(
                 self.streamdetails,
             )
         elif self.org_stream_type == StreamType.ICY:
@@ -183,14 +220,18 @@ class StreamCache:
         elif self.org_stream_type == StreamType.HLS:
             if self.streamdetails.media_type == MediaType.RADIO:
                 raise NotImplementedError("Caching of this streamtype is not supported!")
+            assert self.org_path is not None  # for type checking
             substream = await get_hls_substream(self.mass, self.org_path)
             audio_source = substream.path
         elif self.org_stream_type == StreamType.ENCRYPTED_HTTP:
+            assert self.org_path is not None  # for type checking
+            assert self.streamdetails.decryption_key is not None  # for type checking
             audio_source = self.org_path
             extra_input_args += ["-decryption_key", self.streamdetails.decryption_key]
         elif self.org_stream_type == StreamType.MULTI_FILE:
             audio_source = get_multi_file_stream(self.mass, self.streamdetails)
         else:
+            assert self.org_path is not None  # for type checking
             audio_source = self.org_path
 
         # we always use ffmpeg to fetch the original audio source
@@ -199,15 +240,15 @@ class StreamCache:
         # and it also accounts for complicated cases such as encrypted streams or
         # m4a/mp4 streams with the moov atom at the end of the file.
         # ffmpeg will produce a lossless copy of the original codec.
+        ffmpeg_proc = FFMpeg(
+            audio_input=audio_source,
+            input_format=self.org_audio_format,
+            output_format=self.streamdetails.audio_format,
+            extra_input_args=extra_input_args,
+            audio_output=self._cache_file,
+            collect_log_history=True,
+        )
         try:
-            ffmpeg_proc = FFMpeg(
-                audio_input=audio_source,
-                input_format=self.org_audio_format,
-                output_format=self.streamdetails.audio_format,
-                extra_input_args=extra_input_args,
-                audio_output=self._cache_file,
-                collect_log_history=True,
-            )
             await ffmpeg_proc.start()
             # wait until the first data is written to the cache file
             while ffmpeg_proc.returncode is None:
@@ -249,7 +290,7 @@ class StreamCache:
             await self._remove_cache_file()
             # unblock the waiting tasks by setting the event
             # this will allow the tasks to continue and handle the error
-            self._stream_error = str(err) or err.__qualname__
+            self._stream_error = str(err) or err.__qualname__  # type: ignore [attr-defined]
             self._first_part_received.set()
         finally:
             await ffmpeg_proc.close()
@@ -258,36 +299,9 @@ class StreamCache:
         self._first_part_received.clear()
         self._all_data_written = False
         self._fetch_task = None
+        assert self._cache_file is not None  # for type checking
         await remove_file(self._cache_file)
 
-    def __init__(self, mass: MusicAssistant, streamdetails: StreamDetails) -> None:
-        """Initialize the StreamCache."""
-        self.mass = mass
-        self.streamdetails = streamdetails
-        self.logger = LOGGER.getChild("cache")
-        self._cache_file: str | None = None
-        self._fetch_task: asyncio.Task | None = None
-        self._subscribers: int = 0
-        self._first_part_received = asyncio.Event()
-        self._all_data_written: bool = False
-        self._stream_error: str | None = None
-        self.org_path: str | None = streamdetails.path
-        self.org_stream_type: StreamType | None = streamdetails.stream_type
-        self.org_extra_input_args: list[str] | None = streamdetails.extra_input_args
-        self.org_audio_format = streamdetails.audio_format
-        streamdetails.audio_format = AudioFormat(
-            content_type=ContentType.NUT,
-            codec_type=streamdetails.audio_format.codec_type,
-            sample_rate=streamdetails.audio_format.sample_rate,
-            bit_depth=streamdetails.audio_format.bit_depth,
-            channels=streamdetails.audio_format.channels,
-        )
-        streamdetails.path = "-"
-        streamdetails.stream_type = StreamType.CACHE
-        streamdetails.can_seek = True
-        streamdetails.allow_seek = True
-        streamdetails.extra_input_args = []
-
 
 async def crossfade_pcm_parts(
     fade_in_part: bytes,
@@ -439,7 +453,7 @@ async def strip_silence(
 
 
 def get_player_dsp_details(
-    mass: MusicAssistant, player: Player, group_preventing_dsp=False
+    mass: MusicAssistant, player: Player, group_preventing_dsp: bool = False
 ) -> DSPDetails:
     """Return DSP details of single a player.
 
@@ -479,6 +493,7 @@ def get_stream_dsp_details(
     """Return DSP details of all players playing this queue, keyed by player_id."""
     player = mass.players.get(queue_id)
     dsp: dict[str, DSPDetails] = {}
+    assert player is not None  # for type checking
     group_preventing_dsp = is_grouping_preventing_dsp(player)
     output_format = None
     is_external_group = False
@@ -488,6 +503,8 @@ def get_stream_dsp_details(
             try:
                 # We need a bit of a hack here since only the leader knows the correct output format
                 provider = mass.get_provider(player.provider)
+                if TYPE_CHECKING:  # avoid circular import
+                    assert isinstance(provider, PlayerGroupProvider)
                 if provider:
                     output_format = provider._get_sync_leader(player).output_format
             except RuntimeError:
@@ -557,6 +574,7 @@ async def get_stream_details(
     else:
         # retrieve streamdetails from provider
         media_item = queue_item.media_item
+        assert media_item is not None  # for type checking
         # sort by quality and check item's availability
         for prov_media in sorted(
             media_item.provider_mappings, key=lambda x: x.quality or 0, reverse=True
@@ -566,12 +584,14 @@ async def get_stream_details(
                 continue
             # guard that provider is available
             music_prov = mass.get_provider(prov_media.provider_instance)
+            if TYPE_CHECKING:  # avoid circular import
+                assert isinstance(music_prov, MusicProvider)
             if not music_prov:
                 LOGGER.debug(f"Skipping {prov_media} - provider not available")
                 continue  # provider not available ?
             # get streamdetails from provider
             try:
-                streamdetails: StreamDetails = await music_prov.get_stream_details(
+                streamdetails = await music_prov.get_stream_details(
                     prov_media.item_id, media_item.media_type
                 )
             except MusicAssistantError as err:
@@ -587,6 +607,7 @@ async def get_stream_details(
             streamdetails.stream_type in (StreamType.ICY, StreamType.HLS, StreamType.HTTP)
             and streamdetails.media_type == MediaType.RADIO
         ):
+            assert streamdetails.path is not None  # for type checking
             resolved_url, stream_type = await resolve_radio_stream(mass, streamdetails.path)
             streamdetails.path = resolved_url
             streamdetails.stream_type = stream_type
@@ -611,7 +632,7 @@ async def get_stream_details(
     player_settings = await mass.config.get_player_config(streamdetails.queue_id)
     core_config = await mass.config.get_core_config("streams")
     streamdetails.target_loudness = float(
-        player_settings.get_value(CONF_VOLUME_NORMALIZATION_TARGET)
+        str(player_settings.get_value(CONF_VOLUME_NORMALIZATION_TARGET))
     )
     streamdetails.volume_normalization_mode = _get_normalization_mode(
         core_config, player_settings, streamdetails
@@ -714,6 +735,8 @@ async def get_media_stream(
     extra_input_args = streamdetails.extra_input_args or []
     strip_silence_begin = streamdetails.strip_silence_begin
     strip_silence_end = streamdetails.strip_silence_end
+    if filter_params is None:
+        filter_params = []
     if streamdetails.fade_in:
         filter_params.append("afade=type=in:start_time=0:duration=3")
         strip_silence_begin = False
@@ -726,13 +749,18 @@ async def get_media_stream(
     elif stream_type == StreamType.MULTI_FILE:
         audio_source = get_multi_file_stream(mass, streamdetails)
     elif stream_type == StreamType.CUSTOM:
-        audio_source = mass.get_provider(streamdetails.provider).get_audio_stream(
+        music_prov = mass.get_provider(streamdetails.provider)
+        if TYPE_CHECKING:  # avoid circular import
+            assert isinstance(music_prov, MusicProvider)
+        audio_source = music_prov.get_audio_stream(
             streamdetails,
             seek_position=streamdetails.seek_position if streamdetails.can_seek else 0,
         )
     elif stream_type == StreamType.ICY:
+        assert streamdetails.path is not None  # for type checking
         audio_source = get_icy_radio_stream(mass, streamdetails.path, streamdetails)
     elif stream_type == StreamType.HLS:
+        assert streamdetails.path is not None  # for type checking
         substream = await get_hls_substream(mass, streamdetails.path)
         audio_source = substream.path
         if streamdetails.media_type == MediaType.RADIO:
@@ -741,9 +769,12 @@ async def get_media_stream(
             # so we tell ffmpeg to loop around in this case.
             extra_input_args += ["-stream_loop", "-1", "-re"]
     elif stream_type == StreamType.ENCRYPTED_HTTP:
+        assert streamdetails.path is not None  # for type checking
+        assert streamdetails.decryption_key is not None  # for type checking
         audio_source = streamdetails.path
         extra_input_args += ["-decryption_key", streamdetails.decryption_key]
     else:
+        assert streamdetails.path is not None  # for type checking
         audio_source = streamdetails.path
 
     # handle seek support
@@ -774,6 +805,7 @@ async def get_media_stream(
 
     try:
         await ffmpeg_proc.start()
+        assert ffmpeg_proc.proc is not None  # for type checking
         logger.debug(
             "Started media stream for %s"
             " - using streamtype: %s"
@@ -886,7 +918,7 @@ async def get_media_stream(
         streamdetails.seconds_streamed = seconds_streamed
         # store accurate duration
         if finished and not streamdetails.seek_position and seconds_streamed:
-            streamdetails.duration = seconds_streamed
+            streamdetails.duration = int(seconds_streamed)
 
         # release cache if needed
         if cache := streamdetails.cache:
@@ -936,10 +968,14 @@ async def get_media_stream(
         if (finished or seconds_streamed >= 30) and (
             music_prov := mass.get_provider(streamdetails.provider)
         ):
+            if TYPE_CHECKING:  # avoid circular import
+                assert isinstance(music_prov, MusicProvider)
             mass.create_task(music_prov.on_streamed(streamdetails))
 
 
-def create_wave_header(samplerate=44100, channels=2, bitspersample=16, duration=None):
+def create_wave_header(
+    samplerate: int = 44100, channels: int = 2, bitspersample: int = 16, duration: int | None = None
+) -> bytes:
     """Generate a wave header from given params."""
     file = BytesIO()
 
@@ -1005,7 +1041,7 @@ async def resolve_radio_stream(mass: MusicAssistant, url: str) -> tuple[str, Str
     """
     cache_base_key = "resolved_radio_info"
     if cache := await mass.cache.get(url, base_key=cache_base_key):
-        return cache
+        return cast("tuple[str, StreamType]", cache)
     stream_type = StreamType.HTTP
     resolved_url = url
     timeout = ClientTimeout(total=0, connect=10, sock_read=5)
@@ -1024,7 +1060,7 @@ async def resolve_radio_stream(mass: MusicAssistant, url: str) -> tuple[str, Str
             or ".m3u?" in url
             or ".m3u8?" in url
             or ".pls?" in url
-            or "audio/x-mpegurl" in headers.get("content-type")
+            or "audio/x-mpegurl" in headers.get("content-type", "")
             or "audio/x-scpls" in headers.get("content-type", "")
         ):
             # url is playlist, we need to unfold it
@@ -1074,15 +1110,15 @@ async def get_icy_radio_stream(
             if not meta_data:
                 continue
             meta_data = meta_data.rstrip(b"\0")
-            stream_title = re.search(rb"StreamTitle='([^']*)';", meta_data)
-            if not stream_title:
+            stream_title_re = re.search(rb"StreamTitle='([^']*)';", meta_data)
+            if not stream_title_re:
                 continue
             try:
                 # in 99% of the cases the stream title is utf-8 encoded
-                stream_title = stream_title.group(1).decode("utf-8")
+                stream_title = stream_title_re.group(1).decode("utf-8")
             except UnicodeDecodeError:
                 # fallback to iso-8859-1
-                stream_title = stream_title.group(1).decode("iso-8859-1", errors="replace")
+                stream_title = stream_title_re.group(1).decode("iso-8859-1", errors="replace")
             cleaned_stream_title = clean_stream_title(stream_title)
             if cleaned_stream_title != streamdetails.stream_title:
                 LOGGER.log(
@@ -1127,7 +1163,12 @@ async def get_hls_substream(
         return PlaylistItem(path=url, key=substreams[0].key)
     # sort substreams on best quality (highest bandwidth) when available
     if any(x for x in substreams if x.stream_info):
-        substreams.sort(key=lambda x: int(x.stream_info.get("BANDWIDTH", "0")), reverse=True)
+        substreams.sort(
+            key=lambda x: int(
+                x.stream_info.get("BANDWIDTH", "0") if x.stream_info is not None else 0
+            ),
+            reverse=True,
+        )
     substream = substreams[0]
     if not substream.path.startswith("http"):
         # path is relative, stitch it together
@@ -1159,6 +1200,7 @@ async def get_http_stream(
     timeout = ClientTimeout(total=0, connect=30, sock_read=5 * 60)
     skip_bytes = 0
     if seek_position and streamdetails.size:
+        assert streamdetails.duration is not None  # for type checking
         skip_bytes = int(streamdetails.size / streamdetails.duration * seek_position)
         headers["Range"] = f"bytes={skip_bytes}-{streamdetails.size}"
 
@@ -1238,6 +1280,7 @@ async def get_file_stream(
     chunk_size = get_chunksize(streamdetails.audio_format)
     async with aiofiles.open(streamdetails.data, "rb") as _file:
         if seek_position:
+            assert streamdetails.duration is not None  # for type checking
             seek_pos = int((streamdetails.size / streamdetails.duration) * seek_position)
             await _file.seek(seek_pos)
         # yield chunks of data from file
@@ -1286,11 +1329,18 @@ async def get_preview_stream(
     """Create a 30 seconds preview audioclip for the given streamdetails."""
     if not (music_prov := mass.get_provider(provider_instance_id_or_domain)):
         raise ProviderUnavailableError
+    if TYPE_CHECKING:  # avoid circular import
+        assert isinstance(music_prov, MusicProvider)
     streamdetails = await music_prov.get_stream_details(item_id, media_type)
+
+    audio_input: AsyncGenerator[bytes, None] | str
+    if streamdetails.stream_type == StreamType.CUSTOM:
+        audio_input = music_prov.get_audio_stream(streamdetails, 30)
+    else:
+        assert streamdetails.path is not None  # for type checking
+        audio_input = streamdetails.path
     async for chunk in get_ffmpeg_stream(
-        audio_input=music_prov.get_audio_stream(streamdetails, 30)
-        if streamdetails.stream_type == StreamType.CUSTOM
-        else streamdetails.path,
+        audio_input=audio_input,
         input_format=streamdetails.audio_format,
         output_format=AudioFormat(content_type=ContentType.AAC),
         extra_input_args=["-to", "30"],
@@ -1401,11 +1451,12 @@ def is_output_limiter_enabled(mass: MusicAssistant, player: Player) -> bool:
     elif player.synced_to:
         # Not in sync group, but synced, get from the leader
         deciding_player_id = player.synced_to
-    return mass.config.get_raw_player_config_value(
+    output_limiter_enabled = mass.config.get_raw_player_config_value(
         deciding_player_id,
         CONF_ENTRY_OUTPUT_LIMITER.key,
         CONF_ENTRY_OUTPUT_LIMITER.default_value,
     )
+    return bool(output_limiter_enabled)
 
 
 def get_player_filter_params(
@@ -1434,7 +1485,8 @@ def get_player_filter_params(
             # We can still apply the DSP of that single player.
             if player.group_childs:
                 child_player = mass.players.get(player.group_childs[0])
-                dsp = mass.config.get_player_dsp_config(child_player)
+                assert child_player is not None  # for type checking
+                dsp = mass.config.get_player_dsp_config(child_player.player_id)
             else:
                 # This should normally never happen, but if it does, we disable DSP.
                 dsp.enabled = False
@@ -1525,16 +1577,23 @@ async def analyze_loudness(
     elif streamdetails.stream_type == StreamType.MULTI_FILE:
         audio_source = get_multi_file_stream(mass, streamdetails)
     elif streamdetails.stream_type == StreamType.CUSTOM:
-        audio_source = mass.get_provider(streamdetails.provider).get_audio_stream(
+        music_prov = mass.get_provider(streamdetails.provider)
+        if TYPE_CHECKING:  # avoid circular import
+            assert isinstance(music_prov, MusicProvider)
+        audio_source = music_prov.get_audio_stream(
             streamdetails,
         )
     elif streamdetails.stream_type == StreamType.HLS:
+        assert streamdetails.path is not None  # for type checking
         substream = await get_hls_substream(mass, streamdetails.path)
         audio_source = substream.path
     elif streamdetails.stream_type == StreamType.ENCRYPTED_HTTP:
+        assert streamdetails.path is not None  # for type checking
+        assert streamdetails.decryption_key is not None  # for type checking
         audio_source = streamdetails.path
         extra_input_args += ["-decryption_key", streamdetails.decryption_key]
     else:
+        assert streamdetails.path is not None  # for type checking
         audio_source = streamdetails.path
 
     # calculate BS.1770 R128 integrated loudness with ffmpeg
@@ -1593,10 +1652,12 @@ def _get_normalization_mode(
         return VolumeNormalizationMode.DISABLED
     # work out preference for track or radio
     preference = VolumeNormalizationMode(
-        core_config.get_value(
-            CONF_VOLUME_NORMALIZATION_RADIO
-            if streamdetails.media_type == MediaType.RADIO
-            else CONF_VOLUME_NORMALIZATION_TRACKS,
+        str(
+            core_config.get_value(
+                CONF_VOLUME_NORMALIZATION_RADIO
+                if streamdetails.media_type == MediaType.RADIO
+                else CONF_VOLUME_NORMALIZATION_TRACKS,
+            )
         )
     )
 
index 9f3465a5e1978f15cf5bd852d1c47573dee7f109..8a5ba50cea389e23f2d8740bf1ca9e593ac60d3b 100644 (file)
@@ -12,7 +12,7 @@ from music_assistant_models.enums import EventType
 from music_assistant_models.errors import LoginFailed
 
 if TYPE_CHECKING:
-    from music_assistant import MusicAssistant
+    from music_assistant.mass import MusicAssistant
 
 LOGGER = logging.getLogger(__name__)
 
@@ -51,6 +51,7 @@ class AuthenticationHelper:
     ) -> bool | None:
         """Exit context manager."""
         self.mass.webserver.unregister_dynamic_route(self._cb_path, "GET")
+        return None
 
     async def authenticate(self, auth_url: str, timeout: int = 60) -> dict[str, str]:
         """Start the auth process and return any query params if received on the callback."""
index 2162aaedc4c26ddaeaceac4558166fe830ad7f7f..4f2e91502eec586744ec3f6af0aa9b2e88e702d1 100644 (file)
@@ -36,22 +36,38 @@ def compare_media_item(
 ) -> bool | None:
     """Compare two media items and return True if they match."""
     if base_item.media_type == MediaType.ARTIST and compare_item.media_type == MediaType.ARTIST:
+        assert isinstance(base_item, Artist | ItemMapping)  # for type checking
+        assert isinstance(compare_item, Artist | ItemMapping)  # for type checking
         return compare_artist(base_item, compare_item, strict)
     if base_item.media_type == MediaType.ALBUM and compare_item.media_type == MediaType.ALBUM:
+        assert isinstance(base_item, Album | ItemMapping)  # for type checking
+        assert isinstance(compare_item, Album | ItemMapping)  # for type checking
         return compare_album(base_item, compare_item, strict)
     if base_item.media_type == MediaType.TRACK and compare_item.media_type == MediaType.TRACK:
+        assert isinstance(base_item, Track)  # for type checking
+        assert isinstance(compare_item, Track)  # for type checking
         return compare_track(base_item, compare_item, strict)
     if base_item.media_type == MediaType.PLAYLIST and compare_item.media_type == MediaType.PLAYLIST:
+        assert isinstance(base_item, Playlist | ItemMapping)  # for type checking
+        assert isinstance(compare_item, Playlist | ItemMapping)  # for type checking
         return compare_playlist(base_item, compare_item, strict)
     if base_item.media_type == MediaType.RADIO and compare_item.media_type == MediaType.RADIO:
+        assert isinstance(base_item, Radio | ItemMapping)  # for type checking
+        assert isinstance(compare_item, Radio | ItemMapping)  # for type checking
         return compare_radio(base_item, compare_item, strict)
     if (
         base_item.media_type == MediaType.AUDIOBOOK
         and compare_item.media_type == MediaType.AUDIOBOOK
     ):
+        assert isinstance(base_item, Audiobook | ItemMapping)  # for type checking
+        assert isinstance(compare_item, Audiobook | ItemMapping)  # for type checking
         return compare_audiobook(base_item, compare_item, strict)
     if base_item.media_type == MediaType.PODCAST and compare_item.media_type == MediaType.PODCAST:
+        assert isinstance(base_item, Podcast | ItemMapping)  # for type checking
+        assert isinstance(compare_item, Podcast | ItemMapping)  # for type checking
         return compare_podcast(base_item, compare_item, strict)
+    assert isinstance(base_item, ItemMapping)  # for type checking
+    assert isinstance(compare_item, ItemMapping)  # for type checking
     return compare_item_mapping(base_item, compare_item, strict)
 
 
@@ -61,8 +77,6 @@ def compare_artist(
     strict: bool = True,
 ) -> bool | None:
     """Compare two artist items and return True if they match."""
-    if base_item is None or compare_item is None:
-        return False
     # return early on exact item_id match
     if compare_item_ids(base_item, compare_item):
         return True
@@ -78,13 +92,11 @@ def compare_artist(
 
 
 def compare_album(
-    base_item: Album | ItemMapping | None,
-    compare_item: Album | ItemMapping | None,
+    base_item: Album | ItemMapping,
+    compare_item: Album | ItemMapping,
     strict: bool = True,
 ) -> bool | None:
     """Compare two album items and return True if they match."""
-    if base_item is None or compare_item is None:
-        return False
     # return early on exact item_id match
     if compare_item_ids(base_item, compare_item):
         return True
@@ -125,14 +137,12 @@ def compare_album(
 
 
 def compare_track(
-    base_item: Track | None,
-    compare_item: Track | None,
+    base_item: Track,
+    compare_item: Track,
     strict: bool = True,
     track_albums: list[Album] | None = None,
 ) -> bool:
     """Compare two track items and return True if they match."""
-    if base_item is None or compare_item is None:
-        return False
     # return early on exact item_id match
     if compare_item_ids(base_item, compare_item):
         return True
@@ -243,8 +253,6 @@ def compare_playlist(
     strict: bool = True,
 ) -> bool | None:
     """Compare two Playlist items and return True if they match."""
-    if base_item is None or compare_item is None:
-        return False
     # require (exact) name match
     if not compare_strings(base_item.name, compare_item.name, strict=strict):
         return False
@@ -262,8 +270,6 @@ def compare_radio(
     strict: bool = True,
 ) -> bool | None:
     """Compare two Radio items and return True if they match."""
-    if base_item is None or compare_item is None:
-        return False
     # return early on exact item_id match
     if compare_item_ids(base_item, compare_item):
         return True
@@ -275,13 +281,11 @@ def compare_radio(
 
 
 def compare_audiobook(
-    base_item: Audiobook | ItemMapping | None,
-    compare_item: Audiobook | ItemMapping | None,
+    base_item: Audiobook | ItemMapping,
+    compare_item: Audiobook | ItemMapping,
     strict: bool = True,
 ) -> bool | None:
     """Compare two Audiobook items and return True if they match."""
-    if base_item is None or compare_item is None:
-        return False
     # return early on exact item_id match
     if compare_item_ids(base_item, compare_item):
         return True
@@ -324,13 +328,11 @@ def compare_audiobook(
 
 
 def compare_podcast(
-    base_item: Podcast | ItemMapping | None,
-    compare_item: Podcast | ItemMapping | None,
+    base_item: Podcast | ItemMapping,
+    compare_item: Podcast | ItemMapping,
     strict: bool = True,
 ) -> bool | None:
     """Compare two Podcast items and return True if they match."""
-    if base_item is None or compare_item is None:
-        return False
     # return early on exact item_id match
     if compare_item_ids(base_item, compare_item):
         return True
@@ -371,15 +373,17 @@ def compare_item_mapping(
     strict: bool = True,
 ) -> bool | None:
     """Compare two ItemMapping items and return True if they match."""
-    if base_item is None or compare_item is None:
-        return False
     # return early on exact item_id match
     if compare_item_ids(base_item, compare_item):
         return True
     # return early on (un)matched external id
-    external_id_match = compare_external_ids(base_item.external_ids, compare_item.external_ids)
-    if external_id_match is not None:
-        return external_id_match
+    # check all ExternalID, as ItemMapping is a minimized obj for all MediaItems
+    for ext_id in ExternalID:
+        external_id_match = compare_external_ids(
+            base_item.external_ids, compare_item.external_ids, ext_id
+        )
+        if external_id_match is not None:
+            return external_id_match
     # compare version
     if not compare_version(base_item.version, compare_item.version):
         return False
@@ -440,6 +444,7 @@ def compare_item_ids(
     compare_prov_ids = getattr(compare_item, "provider_mappings", None)
 
     if base_prov_ids is not None:
+        assert isinstance(base_item, MediaItem)  # for type checking
         for prov_l in base_item.provider_mappings:
             if (
                 prov_l.provider_domain == compare_item.provider
@@ -448,11 +453,14 @@ def compare_item_ids(
                 return True
 
     if compare_prov_ids is not None:
+        assert isinstance(compare_item, MediaItem)  # for type checking
         for prov_r in compare_item.provider_mappings:
             if prov_r.provider_domain == base_item.provider and prov_r.item_id == base_item.item_id:
                 return True
 
     if base_prov_ids is not None and compare_prov_ids is not None:
+        assert isinstance(base_item, MediaItem)  # for type checking
+        assert isinstance(compare_item, MediaItem)  # for type checking
         for prov_l in base_item.provider_mappings:
             for prov_r in compare_item.provider_mappings:
                 if prov_l.provider_domain != prov_r.provider_domain:
index 578f666f4be4ae539c2fc343f9699ef36448aeaa..c393d386bfffc132757336fbeb20ecd5104bb37f 100644 (file)
@@ -6,16 +6,17 @@ import asyncio
 import logging
 import os
 import time
+from collections.abc import Mapping
 from contextlib import asynccontextmanager
 from sqlite3 import OperationalError
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 import aiosqlite
 
 from music_assistant.constants import MASS_LOGGER_NAME
 
 if TYPE_CHECKING:
-    from collections.abc import AsyncGenerator, Mapping
+    from collections.abc import AsyncGenerator
 
 LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.database")
 
@@ -23,7 +24,9 @@ ENABLE_DEBUG = os.environ.get("PYTHONDEVMODE") == "1"
 
 
 @asynccontextmanager
-async def debug_query(sql_query: str, query_params: dict | None = None):
+async def debug_query(
+    sql_query: str, query_params: dict[str, Any] | None = None
+) -> AsyncGenerator[None]:
     """Time the processing time of an sql query."""
     if not ENABLE_DEBUG:
         yield
@@ -46,7 +49,7 @@ async def debug_query(sql_query: str, query_params: dict | None = None):
 def query_params(query: str, params: dict[str, Any] | None) -> tuple[str, dict[str, Any]]:
     """Extend query parameters support."""
     if params is None:
-        return (query, params)
+        return (query, {})
     count = 0
     result_query = query
     result_params = {}
@@ -100,11 +103,11 @@ class DatabaseConnection:
     async def get_rows(
         self,
         table: str,
-        match: dict | None = None,
+        match: dict[str, Any] | None = None,
         order_by: str | None = None,
         limit: int = 500,
         offset: int = 0,
-    ) -> list[Mapping]:
+    ) -> list[Mapping[str, Any]]:
         """Get all rows for given table."""
         sql_query = f"SELECT * FROM {table}"
         if match is not None:
@@ -114,26 +117,28 @@ class DatabaseConnection:
         if limit:
             sql_query += f" LIMIT {limit} OFFSET {offset}"
         async with debug_query(sql_query):
-            return await self._db.execute_fetchall(sql_query, match)
+            return cast(
+                "list[Mapping[str, Any]]", await self._db.execute_fetchall(sql_query, match)
+            )
 
     async def get_rows_from_query(
         self,
         query: str,
-        params: dict | None = None,
+        params: dict[str, Any] | None = None,
         limit: int = 500,
         offset: int = 0,
-    ) -> list[Mapping]:
+    ) -> list[Mapping[str, Any]]:
         """Get all rows for given custom query."""
         if limit:
             query += f" LIMIT {limit} OFFSET {offset}"
         _query, _params = query_params(query, params)
         async with debug_query(_query, _params):
-            return await self._db.execute_fetchall(_query, _params)
+            return cast("list[Mapping[str, Any]]", await self._db.execute_fetchall(_query, _params))
 
     async def get_count_from_query(
         self,
         query: str,
-        params: dict | None = None,
+        params: dict[str, Any] | None = None,
     ) -> int:
         """Get row count for given custom query."""
         query = f"SELECT count() FROM ({query})"
@@ -141,6 +146,7 @@ class DatabaseConnection:
         async with debug_query(_query):
             async with self._db.execute(_query, _params) as cursor:
                 if result := await cursor.fetchone():
+                    assert isinstance(result[0], int)  # for type checking
                     return result[0]
             return 0
 
@@ -153,22 +159,27 @@ class DatabaseConnection:
         async with debug_query(query):
             async with self._db.execute(query) as cursor:
                 if result := await cursor.fetchone():
+                    assert isinstance(result[0], int)  # for type checking
                     return result[0]
             return 0
 
-    async def search(self, table: str, search: str, column: str = "name") -> list[Mapping]:
+    async def search(
+        self, table: str, search: str, column: str = "name"
+    ) -> list[Mapping[str, Any]]:
         """Search table by column."""
         sql_query = f"SELECT * FROM {table} WHERE {table}.{column} LIKE :search"
         params = {"search": f"%{search}%"}
         async with debug_query(sql_query, params):
-            return await self._db.execute_fetchall(sql_query, params)
+            return cast(
+                "list[Mapping[str, Any]]", await self._db.execute_fetchall(sql_query, params)
+            )
 
-    async def get_row(self, table: str, match: dict[str, Any]) -> Mapping | None:
+    async def get_row(self, table: str, match: dict[str, Any]) -> Mapping[str, Any] | None:
         """Get single row for given table where column matches keys/values."""
         sql_query = f"SELECT * FROM {table} WHERE "
         sql_query += " AND ".join(f"{table}.{x} = :{x}" for x in match)
         async with debug_query(sql_query, match), self._db.execute(sql_query, match) as cursor:
-            return await cursor.fetchone()
+            return cast("Mapping[str, Any] | None", await cursor.fetchone())
 
     async def insert(
         self,
@@ -185,9 +196,11 @@ class DatabaseConnection:
         sql_query += f" VALUES ({','.join(f':{x}' for x in keys)})"
         row_id = await self._db.execute_insert(sql_query, values)
         await self._db.commit()
+        assert row_id is not None  # for type checking
+        assert isinstance(row_id[0], int)  # for type checking
         return row_id[0]
 
-    async def insert_or_replace(self, table: str, values: dict[str, Any]) -> Mapping:
+    async def insert_or_replace(self, table: str, values: dict[str, Any]) -> int:
         """Insert or replace data in given table."""
         return await self.insert(table=table, values=values, allow_replace=True)
 
@@ -196,7 +209,7 @@ class DatabaseConnection:
         table: str,
         match: dict[str, Any],
         values: dict[str, Any],
-    ) -> Mapping:
+    ) -> Mapping[str, Any]:
         """Update record."""
         keys = tuple(values.keys())
         sql_query = f"UPDATE {table} SET {','.join(f'{x}=:{x}' for x in keys)} WHERE "
@@ -204,9 +217,13 @@ class DatabaseConnection:
         await self.execute(sql_query, {**match, **values})
         await self._db.commit()
         # return updated item
-        return await self.get_row(table, match)
+        updated_item = await self.get_row(table, match)
+        assert updated_item is not None  # for type checking
+        return updated_item
 
-    async def delete(self, table: str, match: dict | None = None, query: str | None = None) -> None:
+    async def delete(
+        self, table: str, match: dict[str, Any] | None = None, query: str | None = None
+    ) -> None:
         """Delete data in given table."""
         assert not (query and "where" in query.lower())
         sql_query = f"DELETE FROM {table} "
@@ -225,7 +242,7 @@ class DatabaseConnection:
         await self.execute(sql_query)
         await self._db.commit()
 
-    async def execute(self, query: str, values: dict | None = None) -> Any:
+    async def execute(self, query: str, values: dict[str, Any] | None = None) -> Any:
         """Execute command on the database."""
         return await self._db.execute(query, values)
 
@@ -236,8 +253,8 @@ class DatabaseConnection:
     async def iter_items(
         self,
         table: str,
-        match: dict | None = None,
-    ) -> AsyncGenerator[Mapping, None]:
+        match: dict[str, Any] | None = None,
+    ) -> AsyncGenerator[Mapping[str, Any], None]:
         """Iterate all items within a table."""
         limit: int = 500
         offset: int = 0
index 1d741fb1953846a1ca0ed4fef713959b99ba64ec..b22ce8307e0afba64988a465e48cb0d9577b9a66 100644 (file)
@@ -9,7 +9,7 @@ from music_assistant_models.dsp import (
     ParametricEQFilter,
     ToneControlFilter,
 )
-from music_assistant_models.streamdetails import AudioFormat
+from music_assistant_models.media_items.audio_format import AudioFormat
 
 # ruff: noqa: PLR0915
 
@@ -38,7 +38,9 @@ def filter_to_ffmpeg_params(dsp_filter: DSPFilter, input_format: AudioFormat) ->
                 # Get gain for this channel, default to 0 if not specified
                 gain_db = dsp_filter.per_channel_preamp.get(channel_id, 0)
                 # Apply both the overall preamp and the per-channel preamp
-                total_gain_db = dsp_filter.preamp + gain_db
+                total_gain_db = (
+                    dsp_filter.preamp + gain_db if dsp_filter.preamp is not None else gain_db
+                )
                 if total_gain_db != 0:
                     # Convert dB to linear gain
                     gain = 10 ** (total_gain_db / 20)
index fa9c2b8ae89ef5a91012ac432c19379a97870639..e936844668d1515664be725a8e0f49afd238a90a 100644 (file)
@@ -57,9 +57,10 @@ class FFMpeg(AsyncProcess):
         self.input_format = input_format
         self.collect_log_history = collect_log_history
         self.log_history: deque[str] = deque(maxlen=100)
-        self._stdin_task: asyncio.Task | None = None
-        self._logger_task: asyncio.Task | None = None
+        self._stdin_task: asyncio.Task[None] | None = None
+        self._logger_task: asyncio.Task[None] | None = None
         self._input_codec_parsed = False
+        stdin: bool | int
         if audio_input == "-" or isinstance(audio_input, AsyncGenerator):
             stdin = True
         else:
@@ -161,8 +162,8 @@ class FFMpeg(AsyncProcess):
 
     async def _feed_stdin(self) -> None:
         """Feed stdin with audio chunks from an AsyncGenerator."""
-        if TYPE_CHECKING:
-            self.audio_input: AsyncGenerator[bytes, None]
+        assert not isinstance(self.audio_input, str | int)
+
         generator_exhausted = False
         cancelled = False
         try:
index 99e34c406353085368b9db231d9d04ce17b9816f..2fd31b9e0d49a64659fce8070ef82418db663615 100644 (file)
@@ -9,7 +9,7 @@ import random
 from base64 import b64decode
 from collections.abc import Iterable
 from io import BytesIO
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import aiofiles
 from aiohttp.client_exceptions import ClientError
@@ -17,19 +17,20 @@ from PIL import Image, UnidentifiedImageError
 
 from music_assistant.helpers.tags import get_embedded_image
 from music_assistant.models.metadata_provider import MetadataProvider
+from music_assistant.models.music_provider import MusicProvider
 
 if TYPE_CHECKING:
     from music_assistant_models.media_items import MediaItemImage
+    from PIL.Image import Image as ImageClass
 
-    from music_assistant import MusicAssistant
-    from music_assistant.models.music_provider import MusicProvider
+    from music_assistant.mass import MusicAssistant
 
 
 async def get_image_data(mass: MusicAssistant, path_or_url: str, provider: str) -> bytes:
     """Create thumbnail from image url."""
     # TODO: add local cache here !
     if prov := mass.get_provider(provider):
-        prov: MusicProvider | MetadataProvider
+        assert isinstance(prov, MusicProvider | MetadataProvider)
         if resolved_image := await prov.resolve_image(path_or_url):
             if isinstance(resolved_image, bytes):
                 return resolved_image
@@ -49,7 +50,7 @@ async def get_image_data(mass: MusicAssistant, path_or_url: str, provider: str)
     if path_or_url.endswith(("jpg", "JPG", "png", "PNG", "jpeg")):
         if await asyncio.to_thread(os.path.isfile, path_or_url):
             async with aiofiles.open(path_or_url, "rb") as _file:
-                return await _file.read()
+                return cast("bytes", await _file.read())
     # use ffmpeg for embedded images
     if img_data := await get_embedded_image(path_or_url):
         return img_data
@@ -72,7 +73,7 @@ async def get_image_thumb(
     if not size and image_format.encode() in img_data:
         return img_data
 
-    def _create_image():
+    def _create_image() -> bytes:
         data = BytesIO()
         try:
             img = Image.open(BytesIO(img_data))
@@ -90,12 +91,14 @@ async def get_image_thumb(
 
 
 async def create_collage(
-    mass: MusicAssistant, images: Iterable[MediaItemImage], dimensions: tuple[int] = (1500, 1500)
+    mass: MusicAssistant,
+    images: Iterable[MediaItemImage],
+    dimensions: tuple[int, int] = (1500, 1500),
 ) -> bytes:
     """Create a basic collage image from multiple image urls."""
     image_size = 250
 
-    def _new_collage():
+    def _new_collage() -> ImageClass:
         return Image.new("RGB", (dimensions[0], dimensions[1]), color=(255, 255, 255, 255))
 
     collage = await asyncio.to_thread(_new_collage)
@@ -122,7 +125,7 @@ async def create_collage(
                     del img_data
                     break
 
-    def _save_collage():
+    def _save_collage() -> bytes:
         final_data = BytesIO()
         collage.convert("RGB").save(final_data, "JPEG", optimize=True)
         return final_data.getvalue()
@@ -136,4 +139,5 @@ async def get_icon_string(icon_path: str) -> str:
     assert ext == "svg"
     async with aiofiles.open(icon_path) as _file:
         xml_data = await _file.read()
+        assert isinstance(xml_data, str)  # for type checking
         return xml_data.replace("\n", "").strip()
index f639d26f10d99420acba41755a280b0960a4a4a8..837d6f47af01b8457a923f856167f7d7b863c8e5 100644 (file)
@@ -8,13 +8,13 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING
 from urllib.parse import urlparse
 
-from aiohttp import client_exceptions
+from aiohttp import ClientTimeout, client_exceptions
 from music_assistant_models.errors import InvalidDataError
 
 from music_assistant.helpers.util import detect_charset
 
 if TYPE_CHECKING:
-    from music_assistant import MusicAssistant
+    from music_assistant.mass import MusicAssistant
 
 
 LOGGER = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ def parse_m3u(m3u_data: str) -> list[PlaylistItem]:
 
     length = None
     title = None
-    stream_info = None
+    stream_info: dict[str, str] | None = None
     key = None
 
     for line in m3u_lines:
@@ -148,7 +148,9 @@ async def fetch_playlist(
 ) -> list[PlaylistItem]:
     """Parse an online m3u or pls playlist."""
     try:
-        async with mass.http_session.get(url, allow_redirects=True, timeout=5) as resp:
+        async with mass.http_session.get(
+            url, allow_redirects=True, timeout=ClientTimeout(total=5)
+        ) as resp:
             try:
                 raw_data = await resp.content.read(64 * 1024)
                 # NOTE: using resp.charset is not reliable, we need to detect it ourselves
index b18f0a6b2f17a879beebf7124347298355919ca5..ff65ca7323b96a43e121c332fd8e2ac0d391daf1 100644 (file)
@@ -54,7 +54,7 @@ class AsyncProcess:
         self._stdout = None if stdout is False else stdout
         self._stderr = asyncio.subprocess.DEVNULL if stderr is False else stderr
         self._close_called = False
-        self._returncode: bool | None = None
+        self._returncode: int | None = None
 
     @property
     def closed(self) -> bool:
@@ -87,6 +87,7 @@ class AsyncProcess:
         # send interrupt signal to process when we're cancelled
         await self.close(send_signal=exc_type in (GeneratorExit, asyncio.CancelledError))
         self._returncode = self.returncode
+        return None
 
     async def start(self) -> None:
         """Perform Async init of process."""
@@ -120,6 +121,8 @@ class AsyncProcess:
         """Read exactly n bytes from the process stdout (or less if eof)."""
         if self._close_called:
             return b""
+        assert self.proc is not None  # for type checking
+        assert self.proc.stdout is not None  # for type checking
         try:
             return await self.proc.stdout.readexactly(n)
         except asyncio.IncompleteReadError as err:
@@ -134,12 +137,16 @@ class AsyncProcess:
         """
         if self._close_called:
             return b""
+        assert self.proc is not None  # for type checking
+        assert self.proc.stdout is not None  # for type checking
         return await self.proc.stdout.read(n)
 
     async def write(self, data: bytes) -> None:
         """Write data to process stdin."""
         if self.closed:
             raise RuntimeError("write called while process already done")
+        assert self.proc is not None  # for type checking
+        assert self.proc.stdin is not None  # for type checking
         self.proc.stdin.write(data)
         with suppress(BrokenPipeError, ConnectionResetError):
             await self.proc.stdin.drain()
@@ -148,6 +155,8 @@ class AsyncProcess:
         """Write end of file to to process stdin."""
         if self.closed:
             return
+        assert self.proc is not None  # for type checking
+        assert self.proc.stdin is not None  # for type checking
         try:
             if self.proc.stdin.can_write_eof():
                 self.proc.stdin.write_eof()
@@ -165,6 +174,8 @@ class AsyncProcess:
         """Read line from stderr."""
         if self.returncode is not None:
             return b""
+        assert self.proc is not None  # for type checking
+        assert self.proc.stderr is not None  # for type checking
         try:
             return await self.proc.stderr.readline()
         except ValueError as err:
@@ -180,6 +191,7 @@ class AsyncProcess:
 
     async def iter_stderr(self) -> AsyncGenerator[str, None]:
         """Iterate lines from the stderr stream as string."""
+        line: str | bytes
         while True:
             line = await self.read_stderr()
             if line == b"":
@@ -198,13 +210,15 @@ class AsyncProcess:
         if self.closed:
             raise RuntimeError("communicate called while process already done")
         # abort existing readers on stderr/stdout first before we send communicate
-        waiter: asyncio.Future
-        if self.proc.stdout and (waiter := self.proc.stdout._waiter):
-            self.proc.stdout._waiter = None
+        waiter: asyncio.Future[None]
+        assert self.proc is not None  # for type checking
+        # _waiter is attribute of StreamReader
+        if self.proc.stdout and (waiter := self.proc.stdout._waiter):  # type: ignore[attr-defined]
+            self.proc.stdout._waiter = None  # type: ignore[attr-defined]
             if waiter and not waiter.done():
                 waiter.set_exception(asyncio.CancelledError())
-        if self.proc.stderr and (waiter := self.proc.stderr._waiter):
-            self.proc.stderr._waiter = None
+        if self.proc.stderr and (waiter := self.proc.stderr._waiter):  # type: ignore[attr-defined]
+            self.proc.stderr._waiter = None  # type: ignore[attr-defined]
             if waiter and not waiter.done():
                 waiter.set_exception(asyncio.CancelledError())
         stdout, stderr = await asyncio.wait_for(self.proc.communicate(input), timeout)
@@ -220,13 +234,13 @@ class AsyncProcess:
         if self.proc.stdin and not self.proc.stdin.is_closing():
             self.proc.stdin.close()
         # abort existing readers on stderr/stdout first before we send communicate
-        waiter: asyncio.Future
-        if self.proc.stdout and (waiter := self.proc.stdout._waiter):
-            self.proc.stdout._waiter = None
+        waiter: asyncio.Future[None]
+        if self.proc.stdout and (waiter := self.proc.stdout._waiter):  # type: ignore[attr-defined]
+            self.proc.stdout._waiter = None  # type: ignore[attr-defined]
             if waiter and not waiter.done():
                 waiter.set_exception(asyncio.CancelledError())
-        if self.proc.stderr and (waiter := self.proc.stderr._waiter):
-            self.proc.stderr._waiter = None
+        if self.proc.stderr and (waiter := self.proc.stderr._waiter):  # type: ignore[attr-defined]
+            self.proc.stderr._waiter = None  # type: ignore[attr-defined]
             if waiter and not waiter.done():
                 waiter.set_exception(asyncio.CancelledError())
         await asyncio.sleep(0)  # yield to loop
@@ -261,6 +275,7 @@ class AsyncProcess:
     async def wait(self) -> int:
         """Wait for the process and return the returncode."""
         if self._returncode is None:
+            assert self.proc is not None
             self._returncode = await self.proc.wait()
         return self._returncode
 
@@ -275,6 +290,7 @@ async def check_output(*args: str, env: dict[str, str] | None = None) -> tuple[i
         *args, stderr=asyncio.subprocess.STDOUT, stdout=asyncio.subprocess.PIPE, env=env
     )
     stdout, _ = await proc.communicate()
+    assert proc.returncode is not None  # for type checking
     return (proc.returncode, stdout)
 
 
@@ -290,4 +306,5 @@ async def communicate(
         stdin=asyncio.subprocess.PIPE if input is not None else None,
     )
     stdout, stderr = await proc.communicate(input)
+    assert proc.returncode is not None  # for type checking
     return (proc.returncode, stdout, stderr)
index 2094fa7d70b91643da3edba9eaa8178e918524f6..647d90f5e56ffd71751c68d03561db7ab7a8f85b 100644 (file)
@@ -32,16 +32,18 @@ LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.tags")
 TAG_SPLITTER = ";"
 
 
-def clean_tuple(values: Iterable[str]) -> tuple:
+def clean_tuple(values: Iterable[str]) -> tuple[str, ...]:
     """Return a tuple with all empty values removed."""
     return tuple(x.strip() for x in values if x not in (None, "", " "))
 
 
-def split_items(org_str: str, allow_unsafe_splitters: bool = False) -> tuple[str, ...]:
+def split_items(
+    org_str: str | list[str] | tuple[str, ...] | None, allow_unsafe_splitters: bool = False
+) -> tuple[str, ...]:
     """Split up a tags string by common splitter."""
     if org_str is None:
         return ()
-    if isinstance(org_str, list):
+    if isinstance(org_str, tuple | list):
         final_items: list[str] = []
         for item in org_str:
             final_items.extend(split_items(item, allow_unsafe_splitters))
@@ -333,7 +335,7 @@ class AudioTags:
         return AlbumType.UNKNOWN
 
     @property
-    def isrc(self) -> tuple[str]:
+    def isrc(self) -> tuple[str, ...]:
         """Return isrc tag(s)."""
         for tag_name in ("isrc", "tsrc"):
             if tag := self.tags.get(tag_name):
@@ -398,7 +400,7 @@ class AudioTags:
         return None
 
     @classmethod
-    def parse(cls, raw: dict) -> AudioTags:
+    def parse(cls, raw: dict[str, Any]) -> AudioTags:
         """Parse instance from raw ffmpeg info output."""
         audio_stream = next((x for x in raw["streams"] if x["codec_type"] == "audio"), None)
         if audio_stream is None:
@@ -435,7 +437,7 @@ class AudioTags:
             filename=raw["format"]["filename"],
         )
 
-    def get(self, key: str, default=None) -> Any:
+    def get(self, key: str, default: Any | None = None) -> Any:
         """Get tag by key."""
         return self.tags.get(key, default)
 
@@ -532,7 +534,7 @@ def get_file_duration(input_file: str) -> float:
         # extract duration from ffmpeg output
         duration_str = res.split("time=")[-1].split(" ")[0].strip()
         duration_parts = duration_str.split(":")
-        duration = 0
+        duration = 0.0
         for part in duration_parts:
             duration = duration * 60 + float(part)
         return duration
@@ -547,10 +549,11 @@ def parse_tags_mutagen(input_file: str) -> dict[str, Any]:
 
     NOT Async friendly.
     """
-    result = {}
+    result: dict[str, Any] = {}
     try:
         # TODO: extend with more tags and file types!
-        tags = mutagen.File(input_file)
+        # https://mutagen.readthedocs.io/en/latest/user/gettingstarted.html
+        tags = mutagen.File(input_file)  # type: ignore[attr-defined]
         if tags is None or not tags.tags:
             return result
         tags = dict(tags.tags)
@@ -604,7 +607,7 @@ async def get_embedded_image(input_file: str) -> bytes | None:
 
     Input_file may be a (local) filename or URL accessible by ffmpeg.
     """
-    args = (
+    args = [
         "ffmpeg",
         "-hide_banner",
         "-loglevel",
@@ -617,7 +620,7 @@ async def get_embedded_image(input_file: str) -> bytes | None:
         "-f",
         "mjpeg",
         "-",
-    )
+    ]
     async with AsyncProcess(
         args, stdin=False, stdout=True, stderr=None, name="ffmpeg_image"
     ) as ffmpeg:
index 74a957387b421112fe0e5b351e5eb1c866c6d572..fff16ca2c0988b1a5a8355774d40dfcc20a97c06 100644 (file)
@@ -8,6 +8,7 @@ from collections import deque
 from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
 from contextlib import asynccontextmanager
 from contextvars import ContextVar
+from types import TracebackType
 from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
 
 from music_assistant_models.errors import ResourceTemporarilyUnavailable, RetriesExhausted
@@ -33,13 +34,13 @@ class Throttler:
     - Return the delay caused by acquire()
     """
 
-    def __init__(self, rate_limit: int, period=1.0):
+    def __init__(self, rate_limit: int, period: float = 1.0) -> None:
         """Initialize the Throttler."""
         self.rate_limit = rate_limit
         self.period = period
         self._task_logs: deque[float] = deque()
 
-    def _flush(self):
+    def _flush(self) -> None:
         now = time.monotonic()
         while self._task_logs:
             if now - self._task_logs[0] > self.period:
@@ -67,21 +68,28 @@ class Throttler:
         """Wait until the lock is acquired, return the time delay."""
         return await self.acquire()
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> bool | None:
         """Nothing to do on exit."""
 
 
 class ThrottlerManager:
     """Throttler manager that extends asyncio Throttle by retrying."""
 
-    def __init__(self, rate_limit: int, period: float = 1, retry_attempts=5, initial_backoff=5):
+    def __init__(
+        self, rate_limit: int, period: float = 1, retry_attempts: int = 5, initial_backoff: int = 5
+    ):
         """Initialize the AsyncThrottledContextManager."""
         self.retry_attempts = retry_attempts
         self.initial_backoff = initial_backoff
         self.throttler = Throttler(rate_limit, period)
 
     @asynccontextmanager
-    async def acquire(self) -> AsyncGenerator[None, float]:
+    async def acquire(self) -> AsyncGenerator[float, None]:
         """Acquire a free slot from the Throttler, returns the throttled time."""
         if BYPASS_THROTTLER.get():
             yield 0
@@ -92,10 +100,12 @@ class ThrottlerManager:
     async def bypass(self) -> AsyncGenerator[None, None]:
         """Bypass the throttler."""
         try:
-            token = BYPASS_THROTTLER.set(True)
+            BYPASS_THROTTLER.set(True)
             yield None
         finally:
-            BYPASS_THROTTLER.reset(token)
+            # TODO: token is unbound here
+            # BYPASS_THROTTLER.reset(token)
+            ...
 
 
 def throttle_with_retries(
@@ -107,7 +117,7 @@ def throttle_with_retries(
     async def wrapper(self: _ProviderT, *args: _P.args, **kwargs: _P.kwargs) -> _R:
         """Call async function using the throttler with retries."""
         # the trottler attribute must be present on the class
-        throttler: ThrottlerManager = self.throttler
+        throttler: ThrottlerManager = self.throttler  # type: ignore[attr-defined]
         backoff_time = throttler.initial_backoff
         async with throttler.acquire() as delay:
             if delay != 0:
index b560f3e556d29016301cbd465aea3c0842221d6d..47041a8cb0c783eb75b2f46c7d07c554518b9f43 100644 (file)
@@ -130,6 +130,8 @@ def create_didl_metadata(media: PlayerMedia) -> str:
         )
     duration_str = str(datetime.timedelta(seconds=media.duration or 0)) + ".000"
 
+    assert media.queue_item_id is not None  # for type checking
+
     return (
         '<DIDL-Lite xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:upnp="urn:schemas-upnp-org:metadata-1-0/upnp/" xmlns="urn:schemas-upnp-org:metadata-1-0/DIDL-Lite/" xmlns:r="urn:schemas-rinconnetworks-com:metadata-1-0/">'
         f'<item id="{media.queue_item_id or xmlescape(media.uri)}" restricted="true" parentID="{media.queue_id or ""}">'
index 7dca3724835703d08c2c860cb468f99005100a54..c38b3db6fd8e0d5ac39aa3b85275bbabe5d940e5 100644 (file)
@@ -13,13 +13,13 @@ import socket
 import urllib.error
 import urllib.parse
 import urllib.request
-from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
+from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine
 from contextlib import suppress
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError
 from importlib.metadata import version as pkg_version
 from types import TracebackType
-from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar
+from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar, cast
 from urllib.parse import urlparse
 
 import cchardet as chardet
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
 
     from zeroconf.asyncio import AsyncServiceInfo
 
-    from music_assistant import MusicAssistant
+    from music_assistant.mass import MusicAssistant
     from music_assistant.models import ProviderModuleType
 
 LOGGER = logging.getLogger(__name__)
@@ -221,33 +221,34 @@ def clean_stream_title(line: str) -> str:
     return line
 
 
-async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str]:
+async def get_ip_addresses(include_ipv6: bool = False) -> tuple[str, ...]:
     """Return all IP-adresses of all network interfaces."""
 
-    def call() -> set[str]:
+    def call() -> tuple[str, ...]:
         result: list[tuple[int, str]] = []
         adapters = ifaddr.get_adapters()
         for adapter in adapters:
             for ip in adapter.ips:
                 if ip.is_IPv6 and not include_ipv6:
                     continue
-                if ip.ip.startswith(("127", "169.254")):
+                ip_str = str(ip.ip)
+                if ip_str.startswith(("127", "169.254")):
                     # filter out IPv4 loopback/APIPA address
                     continue
-                if ip.ip.startswith(("::1", "::ffff:", "fe80")):
+                if ip_str.startswith(("::1", "::ffff:", "fe80")):
                     # filter out IPv6 loopback/link-local address
                     continue
-                if ip.ip.startswith(("192.168.",)):
+                if ip_str.startswith(("192.168.",)):
                     # we rank the 192.168 range a bit higher as its most
                     # often used as the private network subnet
                     score = 2
-                elif ip.ip.startswith(("172.", "10.", "192.")):
+                elif ip_str.startswith(("172.", "10.", "192.")):
                     # we rank the 172 range a bit lower as its most
                     # often used as the private docker network
                     score = 1
                 else:
                     score = 0
-                result.append((score, ip.ip))
+                result.append((score, ip_str))
         result.sort(key=lambda x: x[0], reverse=True)
         return tuple(ip[1] for ip in result)
 
@@ -406,7 +407,7 @@ async def get_package_version(pkg_name: str) -> str | None:
 async def is_hass_supervisor() -> bool:
     """Return if we're running inside the HA Supervisor (e.g. HAOS)."""
 
-    def _check():
+    def _check() -> bool:
         try:
             urllib.request.urlopen("http://supervisor/core", timeout=1)
         except urllib.error.URLError as err:
@@ -424,7 +425,9 @@ async def load_provider_module(domain: str, requirements: list[str]) -> Provider
 
     @lru_cache
     def _get_provider_module(domain: str) -> ProviderModuleType:
-        return importlib.import_module(f".{domain}", "music_assistant.providers")
+        return cast(
+            "ProviderModuleType", importlib.import_module(f".{domain}", "music_assistant.providers")
+        )
 
     # ensure module requirements are met
     for requirement in requirements:
@@ -474,8 +477,8 @@ async def get_free_space(folder: str) -> float:
     def _get_free_space(folder: str) -> float:
         """Return free space on given folderpath in GB."""
         try:
-            if res := shutil.disk_usage(folder):
-                return res.free / float(1 << 30)
+            res = shutil.disk_usage(folder)
+            return res.free / float(1 << 30)
         except (FileNotFoundError, OSError, PermissionError):
             return 0.0
 
@@ -488,8 +491,8 @@ async def get_free_space_percentage(folder: str) -> float:
     def _get_free_space(folder: str) -> float:
         """Return free space on given folderpath in GB."""
         try:
-            if res := shutil.disk_usage(folder):
-                return res.free / res.total * 100
+            res = shutil.disk_usage(folder)
+            return res.free / res.total * 100
         except (FileNotFoundError, OSError, PermissionError):
             return 0.0
 
@@ -528,7 +531,7 @@ def get_primary_ip_address_from_zeroconf(discovery_info: AsyncServiceInfo) -> st
     return None
 
 
-def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> str | None:
+def get_port_from_zeroconf(discovery_info: AsyncServiceInfo) -> int | None:
     """Get primary IP address from zeroconf discovery info."""
     return discovery_info.port
 
@@ -542,11 +545,12 @@ async def close_async_generator(agen: AsyncGenerator[Any, None]) -> None:
     await agen.aclose()
 
 
-async def detect_charset(data: bytes, fallback="utf-8") -> str:
+async def detect_charset(data: bytes, fallback: str = "utf-8") -> str:
     """Detect charset of raw data."""
     try:
-        detected = await asyncio.to_thread(chardet.detect, data)
+        detected: dict[str, Any] = await asyncio.to_thread(chardet.detect, data)
         if detected and detected["encoding"] and detected["confidence"] > 0.75:
+            assert isinstance(detected["encoding"], str)  # for type checking
             return detected["encoding"]
     except Exception as err:
         LOGGER.debug("Failed to detect charset: %s", err)
@@ -599,25 +603,26 @@ class TaskManager:
     def __init__(self, mass: MusicAssistant, limit: int = 0):
         """Initialize the TaskManager."""
         self.mass = mass
-        self._tasks: list[asyncio.Task] = []
+        self._tasks: list[asyncio.Task[None]] = []
         self._semaphore = asyncio.Semaphore(limit) if limit else None
 
-    def create_task(self, coro: Coroutine) -> asyncio.Task:
+    def create_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
         """Create a new task and add it to the manager."""
         task = self.mass.create_task(coro)
         self._tasks.append(task)
         return task
 
-    async def create_task_with_limit(self, coro: Coroutine) -> None:
+    async def create_task_with_limit(self, coro: Coroutine[Any, Any, None]) -> None:
         """Create a new task with semaphore limit."""
         assert self._semaphore is not None
 
-        def task_done_callback(_task: asyncio.Task) -> None:
+        def task_done_callback(_task: asyncio.Task[None]) -> None:
+            assert self._semaphore is not None  # for type checking
             self._tasks.remove(task)
             self._semaphore.release()
 
         await self._semaphore.acquire()
-        task: asyncio.Task = self.create_task(coro)
+        task: asyncio.Task[None] = self.create_task(coro)
         task.add_done_callback(task_done_callback)
 
     async def __aenter__(self) -> Self:
@@ -634,6 +639,7 @@ class TaskManager:
         if len(self._tasks) > 0:
             await asyncio.wait(self._tasks)
             self._tasks.clear()
+        return None
 
 
 _R = TypeVar("_R")
@@ -650,7 +656,7 @@ def lock(
         """Call async function using the throttler with retries."""
         if not (func_lock := getattr(func, "lock", None)):
             func_lock = asyncio.Lock()
-            func.lock = func_lock
+            func.lock = func_lock  # type: ignore[attr-defined]
         async with func_lock:
             return await func(*args, **kwargs)
 
@@ -664,7 +670,7 @@ class TimedAsyncGenerator:
     Source: https://medium.com/@dmitry8912/implementing-timeouts-in-pythons-asynchronous-generators-f7cbaa6dc1e9
     """
 
-    def __init__(self, iterable, timeout=0):
+    def __init__(self, iterable: AsyncIterator[Any], timeout: int = 0):
         """
         Initialize the AsyncTimedIterable.
 
@@ -674,10 +680,10 @@ class TimedAsyncGenerator:
         """
 
         class AsyncTimedIterator:
-            def __init__(self):
+            def __init__(self) -> None:
                 self._iterator = iterable.__aiter__()
 
-            async def __anext__(self):
+            async def __anext__(self) -> Any:
                 result = await asyncio.wait_for(self._iterator.__anext__(), int(timeout))
                 if not result:
                     raise StopAsyncIteration
@@ -685,6 +691,6 @@ class TimedAsyncGenerator:
 
         self._factory = AsyncTimedIterator
 
-    def __aiter__(self):
+    def __aiter__(self):  # type: ignore[no-untyped-def]
         """Return the async iterator."""
         return self._factory()
index 9c28a3014da7eca0511e590a3555d28b09f8bf2d..137ad9255c4ed6a03f7f80c82230105d7332bda1 100644 (file)
@@ -9,7 +9,9 @@ from aiohttp import web
 
 if TYPE_CHECKING:
     import logging
-    from collections.abc import Awaitable, Callable
+    from collections.abc import Callable
+
+    from aiohttp.typedefs import Handler
 
 
 MAX_CLIENT_SIZE: Final = 1024**2 * 16
@@ -30,7 +32,7 @@ class Webserver:
         self._apprunner: web.AppRunner | None = None
         self._webapp: web.Application | None = None
         self._tcp_site: web.TCPSite | None = None
-        self._static_routes: list[tuple[str, str, Awaitable]] | None = None
+        self._static_routes: list[tuple[str, str, Handler]] | None = None
         self._dynamic_routes: dict[str, Callable] | None = {} if enable_dynamic_routes else None
         self._bind_port: int | None = None
 
@@ -39,7 +41,7 @@ class Webserver:
         bind_ip: str | None,
         bind_port: int,
         base_url: str,
-        static_routes: list[tuple[str, str, Awaitable]] | None = None,
+        static_routes: list[tuple[str, str, Handler]] | None = None,
         static_content: tuple[str, str, str] | None = None,
     ) -> None:
         """Async initialize of module."""
@@ -95,12 +97,12 @@ class Webserver:
             await self._webapp.cleanup()
 
     @property
-    def base_url(self):
+    def base_url(self) -> str:
         """Return the base URL of this webserver."""
         return self._base_url
 
     @property
-    def port(self):
+    def port(self) -> int | None:
         """Return the port of this webserver."""
         return self._bind_port
 
@@ -121,6 +123,7 @@ class Webserver:
         self._dynamic_routes[key] = handler
 
         def _remove():
+            assert self._dynamic_routes is not None  # for type checking
             return self._dynamic_routes.pop(key)
 
         return _remove
@@ -142,6 +145,7 @@ class Webserver:
         """Redirect request to correct destination."""
         # find handler for the request
         for key in (f"{request.method}.{request.path}", f"*.{request.path}"):
+            assert self._dynamic_routes is not None  # for type checking
             if handler := self._dynamic_routes.get(key):
                 return await handler(request)
         # deny all other requests
index e12cbe23be36ff4675e3b1b56e1927a24f46ec24..c9fb15ef405814a3b70d85b5ba06fb8966b709ea 100644 (file)
@@ -123,7 +123,8 @@ enable_error_code = [
 ]
 exclude = [
   '^music_assistant/controllers/.*$',
-  '^music_assistant/helpers/.*$',
+  '^music_assistant/helpers/app_vars.py',
+  '^music_assistant/helpers/webserver.py',
   '^music_assistant/models/.*$',
   '^music_assistant/providers/apple_music/.*$',
   '^music_assistant/providers/bluesound/.*$',