Typing fixes for the stream controller (#2540)
authorOzGav <gavnosp@hotmail.com>
Tue, 18 Nov 2025 09:53:26 +0000 (19:53 +1000)
committerGitHub <noreply@github.com>
Tue, 18 Nov 2025 09:53:26 +0000 (10:53 +0100)
* work in progress

* mypy fixes for streams.py

* Fix typos

* Another typo

* PR review changes

* Fixes after merge conflicts

* Revert unnecessary change

* Remove unnecessary comments

* More changes post conflict resolve

* Remove exclude

* organise imports

* Updates post other commits

* Catch possible no stream details

* Updates for new commits

* Fix type error post conflict merge

* Simplify

* Simplify with type hints to config controller

* Typing adjustments

* Fix missing return type

* improve code clarity

* switch to assert for none check

* Fix return type mypy error

* Revert unnecessary bracket

* Use default parameter

music_assistant/controllers/streams.py
pyproject.toml

index 4a8de858789cb88bdeb6502a2cfaae2611a201b9..9b3f99d8545675cd02c11718f69d702acfc4afc9 100644 (file)
@@ -28,8 +28,13 @@ from music_assistant_models.enums import (
     StreamType,
     VolumeNormalizationMode,
 )
-from music_assistant_models.errors import AudioError, QueueEmpty
-from music_assistant_models.media_items import AudioFormat
+from music_assistant_models.errors import (
+    AudioError,
+    InvalidDataError,
+    ProviderUnavailableError,
+    QueueEmpty,
+)
+from music_assistant_models.media_items import AudioFormat, Track
 from music_assistant_models.player_queue import PlayLogEntry
 
 from music_assistant.constants import (
@@ -71,7 +76,6 @@ from music_assistant.helpers.ffmpeg import check_ffmpeg_version, get_ffmpeg_stre
 from music_assistant.helpers.smart_fades import (
     SMART_CROSSFADE_DURATION,
     SmartFadesMixer,
-    SmartFadesMode,
 )
 from music_assistant.helpers.util import (
     divide_chunks,
@@ -83,6 +87,7 @@ from music_assistant.helpers.webserver import Webserver
 from music_assistant.models.core_controller import CoreController
 from music_assistant.models.music_provider import MusicProvider
 from music_assistant.models.plugin import PluginProvider, PluginSource
+from music_assistant.models.smart_fades import SmartFadesMode
 from music_assistant.providers.universal_group.constants import UGP_PREFIX
 from music_assistant.providers.universal_group.player import UniversalGroupPlayer
 
@@ -295,7 +300,7 @@ class StreamsController(CoreController):
         )
         await self._server.setup(
             bind_ip=bind_ip,
-            bind_port=self.publish_port,
+            bind_port=cast("int", self.publish_port),
             base_url=f"http://{self.publish_ip}:{self.publish_port}",
             static_routes=[
                 (
@@ -343,13 +348,10 @@ class StreamsController(CoreController):
         """Resolve the stream URL for the given QueueItem."""
         if not player_id:
             player_id = queue_item.queue_id
-        try:
-            conf_output_codec = await self.mass.config.get_player_config_value(
-                player_id, CONF_OUTPUT_CODEC
-            )
-        except KeyError:
-            conf_output_codec = "flac"
-        output_codec = ContentType.try_parse(conf_output_codec)
+        conf_output_codec = await self.mass.config.get_player_config_value(
+            player_id, CONF_OUTPUT_CODEC, default="flac", return_type=str
+        )
+        output_codec = ContentType.try_parse(conf_output_codec or "flac")
         fmt = output_codec.value
         # handle raw pcm without exact format specifiers
         if output_codec.is_pcm() and ";" not in fmt:
@@ -369,7 +371,7 @@ class StreamsController(CoreController):
             fmt = plugin_source.audio_format.content_type.value
         return f"{self._server.base_url}/pluginsource/{plugin_source.id}/{player_id}.{fmt}"
 
-    async def serve_queue_item_stream(self, request: web.Request) -> web.Response:
+    async def serve_queue_item_stream(self, request: web.Request) -> web.StreamResponse:
         """Stream single queueitem audio to a player."""
         self._log_request(request)
         queue_id = request.match_info["queue_id"]
@@ -397,6 +399,9 @@ class StreamsController(CoreController):
                 raise web.HTTPNotFound(reason=f"No streamdetails for Queue item: {queue_item_id}")
 
         # pick output format based on the streamdetails and player capabilities
+        if not queue_player:
+            raise web.HTTPNotFound(reason=f"Unknown Player: {queue_id}")
+
         output_format = await self.get_output_format(
             output_format_str=request.match_info["fmt"],
             player=queue_player,
@@ -419,13 +424,13 @@ class StreamsController(CoreController):
             headers=headers,
         )
         resp.content_type = f"audio/{output_format.output_format_str}"
-        http_profile: str = await self.mass.config.get_player_config_value(
-            queue_id, CONF_HTTP_PROFILE
+        http_profile = await self.mass.config.get_player_config_value(
+            queue_id, CONF_HTTP_PROFILE, default="default", return_type=str
         )
         if http_profile == "forced_content_length" and not queue_item.duration:
             # just set an insane high content length to make sure the player keeps playing
             resp.content_length = get_chunksize(output_format, 12 * 3600)
-        elif http_profile == "forced_content_length":
+        elif http_profile == "forced_content_length" and queue_item.duration:
             # guess content length based on duration
             resp.content_length = get_chunksize(output_format, queue_item.duration)
         elif http_profile == "chunked":
@@ -442,7 +447,7 @@ class StreamsController(CoreController):
             smart_fades_mode = SmartFadesMode.DISABLED
         else:
             smart_fades_mode = await self.mass.config.get_player_config_value(
-                queue.queue_id, CONF_SMART_FADES_MODE
+                queue.queue_id, CONF_SMART_FADES_MODE, return_type=SmartFadesMode
             )
             standard_crossfade_duration = self.mass.config.get_raw_player_config_value(
                 queue.queue_id, CONF_CROSSFADE_DURATION, 10
@@ -560,7 +565,7 @@ class StreamsController(CoreController):
             self.mass.call_later(5, self.mass.player_queues.next(queue_id))
         return resp
 
-    async def serve_queue_flow_stream(self, request: web.Request) -> web.Response:
+    async def serve_queue_flow_stream(self, request: web.Request) -> web.StreamResponse:
         """Stream Queue Flow audio to player."""
         self._log_request(request)
         queue_id = request.match_info["queue_id"]
@@ -609,10 +614,9 @@ class StreamsController(CoreController):
             reason="OK",
             headers=headers,
         )
-        http_profile_value = await self.mass.config.get_player_config_value(
-            queue_id, CONF_HTTP_PROFILE
+        http_profile = await self.mass.config.get_player_config_value(
+            queue_id, CONF_HTTP_PROFILE, default="default", return_type=str
         )
-        http_profile = str(http_profile_value) if http_profile_value is not None else "default"
         if http_profile == "forced_content_length":
             # just set an insane high content length to make sure the player keeps playing
             resp.content_length = get_chunksize(output_format, 12 * 3600)
@@ -683,7 +687,7 @@ class StreamsController(CoreController):
 
         return resp
 
-    async def serve_command_request(self, request: web.Request) -> web.Response:
+    async def serve_command_request(self, request: web.Request) -> web.FileResponse:
         """Handle special 'command' request for a player."""
         self._log_request(request)
         queue_id = request.match_info["queue_id"]
@@ -692,7 +696,7 @@ class StreamsController(CoreController):
             self.mass.create_task(self.mass.player_queues.next(queue_id))
         return web.FileResponse(SILENCE_FILE, headers={"icy-name": "Music Assistant"})
 
-    async def serve_announcement_stream(self, request: web.Request) -> web.Response:
+    async def serve_announcement_stream(self, request: web.Request) -> web.StreamResponse:
         """Stream announcement audio to a player."""
         self._log_request(request)
         player_id = request.match_info["player_id"]
@@ -706,10 +710,9 @@ class StreamsController(CoreController):
         fmt = request.match_info["fmt"]
         audio_format = AudioFormat(content_type=ContentType.try_parse(fmt))
 
-        http_profile_value = await self.mass.config.get_player_config_value(
-            player_id, CONF_HTTP_PROFILE
+        http_profile = await self.mass.config.get_player_config_value(
+            player_id, CONF_HTTP_PROFILE, default="default", return_type=str
         )
-        http_profile = str(http_profile_value) if http_profile_value is not None else "default"
         if http_profile == "forced_content_length":
             # given the fact that an announcement is just a short audio clip,
             # just send it over completely at once so we have a fixed content length
@@ -767,13 +770,13 @@ class StreamsController(CoreController):
 
         return resp
 
-    async def serve_plugin_source_stream(self, request: web.Request) -> web.Response:
+    async def serve_plugin_source_stream(self, request: web.Request) -> web.StreamResponse:
         """Stream PluginSource audio to a player."""
         self._log_request(request)
         plugin_source_id = request.match_info["plugin_source"]
-        provider: PluginProvider | None
-        if not (provider := self.mass.get_provider(plugin_source_id)):
-            raise web.HTTPNotFound(reason=f"Unknown PluginSource: {plugin_source_id}")
+        provider = cast("PluginProvider", self.mass.get_provider(plugin_source_id))
+        if not provider:
+            raise ProviderUnavailableError(f"Unknown PluginSource: {plugin_source_id}")
         # work out output format/details
         player_id = request.match_info["player_id"]
         player = self.mass.players.get(player_id)
@@ -800,10 +803,9 @@ class StreamsController(CoreController):
             headers=headers,
         )
         resp.content_type = f"audio/{output_format.output_format_str}"
-        http_profile_value = await self.mass.config.get_player_config_value(
-            player_id, CONF_HTTP_PROFILE
+        http_profile = await self.mass.config.get_player_config_value(
+            player_id, CONF_HTTP_PROFILE, default="default", return_type=str
         )
-        http_profile = str(http_profile_value) if http_profile_value is not None else "default"
         if http_profile == "forced_content_length":
             # just set an insanely high content length to make sure the player keeps playing
             resp.content_length = get_chunksize(output_format, 12 * 3600)
@@ -817,6 +819,8 @@ class StreamsController(CoreController):
             return resp
 
         # all checks passed, start streaming!
+        if not plugin_source.audio_format:
+            raise InvalidDataError(f"No audio format for plugin source {plugin_source_id}")
         async for chunk in self.get_plugin_source_stream(
             plugin_source_id=plugin_source_id,
             output_format=output_format,
@@ -958,7 +962,7 @@ class StreamsController(CoreController):
             standard_crossfade_duration = 0
         else:
             smart_fades_mode = await self.mass.config.get_player_config_value(
-                queue.queue_id, CONF_SMART_FADES_MODE
+                queue.queue_id, CONF_SMART_FADES_MODE, return_type=SmartFadesMode
             )
             standard_crossfade_duration = self.mass.config.get_raw_player_config_value(
                 queue.queue_id, CONF_CROSSFADE_DURATION, 10
@@ -987,7 +991,7 @@ class StreamsController(CoreController):
                     break
 
             if queue_track.streamdetails is None:
-                raise RuntimeError(
+                raise InvalidDataError(
                     "No Streamdetails known for queue item %s",
                     queue_track.queue_item_id,
                 )
@@ -1001,7 +1005,6 @@ class StreamsController(CoreController):
             # append to play log so the queue controller can work out which track is playing
             play_log_entry = PlayLogEntry(queue_track.queue_item_id)
             queue.flow_mode_stream_log.append(play_log_entry)
-
             # calculate crossfade buffer size
             crossfade_buffer_duration = (
                 SMART_CROSSFADE_DURATION
@@ -1081,8 +1084,9 @@ class StreamsController(CoreController):
                     # we need to correct the bytes_written accordingly so the duration
                     # calculations at the end of the track are correct
                     crossfade_part_len = len(crossfade_part)
-                    bytes_written += crossfade_part_len / 2
+                    bytes_written += int(crossfade_part_len / 2)
                     if last_play_log_entry:
+                        assert last_play_log_entry.seconds_streamed is not None
                         last_play_log_entry.seconds_streamed += (
                             crossfade_part_len / 2 / pcm_sample_size
                         )
@@ -1140,7 +1144,7 @@ class StreamsController(CoreController):
             # this also accounts for crossfade and silence stripping
             seconds_streamed = bytes_written / pcm_sample_size
             queue_track.streamdetails.seconds_streamed = seconds_streamed
-            queue_track.streamdetails.duration = (
+            queue_track.streamdetails.duration = int(
                 queue_track.streamdetails.seek_position + seconds_streamed
             )
             play_log_entry.seconds_streamed = seconds_streamed
@@ -1160,8 +1164,12 @@ class StreamsController(CoreController):
                 del _chunk
             # correct seconds streamed/duration
             last_part_seconds = len(last_fadeout_part) / pcm_sample_size
-            queue_track.streamdetails.seconds_streamed += last_part_seconds
-            queue_track.streamdetails.duration += last_part_seconds
+            streamdetails = queue_track.streamdetails
+            assert streamdetails is not None
+            streamdetails.seconds_streamed = (
+                streamdetails.seconds_streamed or 0
+            ) + last_part_seconds
+            streamdetails.duration = int((streamdetails.duration or 0) + last_part_seconds)
             last_fadeout_part = b""
         total_bytes_sent += bytes_written
         self.logger.info("Finished Queue Flow stream for Queue %s", queue.display_name)
@@ -1216,7 +1224,10 @@ class StreamsController(CoreController):
         player_filter_params: list[str] | None = None,
     ) -> AsyncGenerator[bytes, None]:
         """Get the special plugin source stream."""
-        plugin_prov: PluginProvider = self.mass.get_provider(plugin_source_id)
+        plugin_prov = cast("PluginProvider", self.mass.get_provider(plugin_source_id))
+        if not plugin_prov:
+            raise ProviderUnavailableError(f"Unknown PluginSource: {plugin_source_id}")
+
         plugin_source = plugin_prov.get_source()
         self.logger.debug(
             "Start streaming PluginSource %s to %s using output format %s",
@@ -1229,10 +1240,11 @@ class StreamsController(CoreController):
 
         try:
             async for chunk in get_ffmpeg_stream(
-                audio_input=(
+                audio_input=cast(
+                    "str | AsyncGenerator[bytes, None]",
                     plugin_prov.get_audio_stream(player_id)
                     if plugin_source.stream_type == StreamType.CUSTOM
-                    else plugin_source.path
+                    else plugin_source.path,
                 ),
                 input_format=plugin_source.audio_format,
                 output_format=output_format,
@@ -1275,22 +1287,31 @@ class StreamsController(CoreController):
             filter_rule += ":print_format=json"
             filter_params.append(filter_rule)
         elif streamdetails.volume_normalization_mode == VolumeNormalizationMode.FIXED_GAIN:
-            # apply used defined fixed volume/gain correction
-            gain_correct: float = await self.mass.config.get_core_config_value(
-                self.domain,
+            # apply user defined fixed volume/gain correction
+            config_key = (
                 CONF_VOLUME_NORMALIZATION_FIXED_GAIN_TRACKS
                 if streamdetails.media_type == MediaType.TRACK
-                else CONF_VOLUME_NORMALIZATION_FIXED_GAIN_RADIO,
+                else CONF_VOLUME_NORMALIZATION_FIXED_GAIN_RADIO
             )
-            gain_correct = round(gain_correct, 2)
+            gain_value = await self.mass.config.get_core_config_value(
+                self.domain, config_key, default=0.0, return_type=float
+            )
+            gain_correct = round(gain_value, 2)
             filter_params.append(f"volume={gain_correct}dB")
         elif streamdetails.volume_normalization_mode == VolumeNormalizationMode.MEASUREMENT_ONLY:
             # volume normalization with known loudness measurement
             # apply volume/gain correction
+            target_loudness = (
+                float(streamdetails.target_loudness)
+                if streamdetails.target_loudness is not None
+                else 0.0
+            )
             if streamdetails.prefer_album_loudness and streamdetails.loudness_album is not None:
-                gain_correct = streamdetails.target_loudness - streamdetails.loudness_album
+                gain_correct = target_loudness - float(streamdetails.loudness_album)
+            elif streamdetails.loudness is not None:
+                gain_correct = target_loudness - float(streamdetails.loudness)
             else:
-                gain_correct = streamdetails.target_loudness - streamdetails.loudness
+                gain_correct = 0.0
             gain_correct = round(gain_correct, 2)
             filter_params.append(f"volume={gain_correct}dB")
         streamdetails.volume_normalization_gain_correct = gain_correct
@@ -1308,7 +1329,7 @@ class StreamsController(CoreController):
             " - using fade-in: %s"
             " - using volume normalization: %s",
             queue_item.name,
-            queue_item.streamdetails.uri,
+            streamdetails.uri,
             allow_buffer,
             streamdetails.fade_in,
             streamdetails.volume_normalization_mode,
@@ -1343,7 +1364,7 @@ class StreamsController(CoreController):
                     self.logger.debug(
                         "First audio chunk received for %s (%s) after %.2f seconds",
                         queue_item.name,
-                        queue_item.streamdetails.uri,
+                        streamdetails.uri,
                         asyncio.get_event_loop().time() - stream_started_at,
                     )
                 # handle optional fade-in
@@ -1352,7 +1373,10 @@ class StreamsController(CoreController):
                         fade_in_buffer += chunk
                     elif fade_in_buffer:
                         async for fade_chunk in get_ffmpeg_stream(
-                            audio_input=fade_in_buffer + chunk,
+                            # NOTE: get_ffmpeg_stream signature says str | AsyncGenerator
+                            # but FFMpeg class actually accepts bytes too. This works at
+                            # runtime but needs type: ignore for mypy.
+                            audio_input=fade_in_buffer + chunk,  # type: ignore[arg-type]
                             input_format=pcm_format,
                             output_format=pcm_format,
                             filter_params=["afade=type=in:start_time=0:duration=3"],
@@ -1366,7 +1390,7 @@ class StreamsController(CoreController):
                 del chunk
             finished = True
         except AudioError as err:
-            queue_item.streamdetails.stream_error = True
+            streamdetails.stream_error = True
             queue_item.available = False
             if raise_on_error:
                 raise
@@ -1375,7 +1399,7 @@ class StreamsController(CoreController):
             self.logger.error(
                 "AudioError while streaming queue item %s (%s): %s",
                 queue_item.name,
-                queue_item.streamdetails.uri,
+                streamdetails.uri,
                 err,
             )
         finally:
@@ -1596,6 +1620,10 @@ class StreamsController(CoreController):
                     queue.queue_id, queue_item.queue_item_id
                 )
                 # set index_in_buffer to prevent our next track is overwritten while preloading
+                if next_queue_item.streamdetails is None:
+                    raise InvalidDataError(
+                        f"No streamdetails for next queue item {next_queue_item.queue_item_id}"
+                    )
                 queue.index_in_buffer = self.mass.player_queues.index_by_id(
                     queue.queue_id, next_queue_item.queue_item_id
                 )
@@ -1657,7 +1685,7 @@ class StreamsController(CoreController):
                         fade_in_part=buffer,
                         fade_out_part=fade_out_data,
                         fade_in_streamdetails=next_queue_item.streamdetails,
-                        fade_out_streamdetails=queue_item.streamdetails,
+                        fade_out_streamdetails=streamdetails,
                         pcm_format=pcm_format,
                         standard_crossfade_duration=standard_crossfade_duration,
                         mode=smart_fades_mode,
@@ -1708,11 +1736,11 @@ class StreamsController(CoreController):
         # this also accounts for crossfade and silence stripping
         seconds_streamed = bytes_written / pcm_format.pcm_sample_size
         streamdetails.seconds_streamed = seconds_streamed
-        streamdetails.duration = streamdetails.seek_position + seconds_streamed
+        streamdetails.duration = int(streamdetails.seek_position + seconds_streamed)
         self.logger.debug(
             "Finished Streaming queue track: %s (%s) on queue %s "
             "- crossfade data prepared for next track: %s",
-            queue_item.streamdetails.uri,
+            streamdetails.uri,
             queue_item.name,
             queue.display_name,
             next_queue_item.name if next_queue_item else "N/A",
@@ -1746,16 +1774,17 @@ class StreamsController(CoreController):
     ) -> AudioFormat:
         """Parse (player specific) output format details for given format string."""
         content_type: ContentType = ContentType.try_parse(output_format_str)
-        supported_rates_conf: list[
-            tuple[str, str]
-        ] = await self.mass.config.get_player_config_value(
-            player.player_id, CONF_SAMPLE_RATES, unpack_splitted_values=True
+        supported_rates_conf = cast(
+            "list[tuple[str, str]]",
+            await self.mass.config.get_player_config_value(
+                player.player_id, CONF_SAMPLE_RATES, unpack_splitted_values=True
+            ),
         )
         output_channels_str = self.mass.config.get_raw_player_config_value(
             player.player_id, CONF_OUTPUT_CHANNELS, "stereo"
         )
-        supported_sample_rates: tuple[int] = tuple(int(x[0]) for x in supported_rates_conf)
-        supported_bit_depths: tuple[int] = tuple(int(x[1]) for x in supported_rates_conf)
+        supported_sample_rates = tuple(int(x[0]) for x in supported_rates_conf)
+        supported_bit_depths = tuple(int(x[1]) for x in supported_rates_conf)
 
         player_max_bit_depth = max(supported_bit_depths)
         output_bit_depth = min(content_bit_depth, player_max_bit_depth)
@@ -1785,12 +1814,13 @@ class StreamsController(CoreController):
         player: Player,
     ) -> AudioFormat:
         """Parse (player specific) flow stream PCM format."""
-        supported_rates_conf: list[
-            tuple[str, str]
-        ] = await self.mass.config.get_player_config_value(
-            player.player_id, CONF_SAMPLE_RATES, unpack_splitted_values=True
+        supported_rates_conf = cast(
+            "list[tuple[str, str]]",
+            await self.mass.config.get_player_config_value(
+                player.player_id, CONF_SAMPLE_RATES, unpack_splitted_values=True
+            ),
         )
-        supported_sample_rates: tuple[int] = tuple(int(x[0]) for x in supported_rates_conf)
+        supported_sample_rates = tuple(int(x[0]) for x in supported_rates_conf)
         output_sample_rate = INTERNAL_PCM_FORMAT.sample_rate
         for sample_rate in (192000, 96000, 48000, 44100):
             if sample_rate in supported_sample_rates:
@@ -1826,11 +1856,9 @@ class StreamsController(CoreController):
             self.logger.debug("Skipping crossfade: next item is not a track")
             return False
         if (
-            queue_item.media_type == MediaType.TRACK
-            and next_item.media_type == MediaType.TRACK
-            and queue_item.media_item
+            isinstance(queue_item.media_item, Track)
+            and isinstance(next_item.media_item, Track)
             and queue_item.media_item.album
-            and next_item.media_item
             and next_item.media_item.album
             and queue_item.media_item.album == next_item.media_item.album
             and not self.mass.config.get_raw_core_config_value(
@@ -1847,6 +1875,9 @@ class StreamsController(CoreController):
         if (
             not flow_mode
             and next_item.streamdetails
+            and queue_item.streamdetails
+            and next_item.streamdetails.audio_format
+            and queue_item.streamdetails.audio_format
             and (
                 queue_item.streamdetails.audio_format.sample_rate
                 != next_item.streamdetails.audio_format.sample_rate
@@ -1863,7 +1894,6 @@ class StreamsController(CoreController):
         ):
             self.logger.debug("Skipping crossfade: sample rate mismatch")
             return False
-
         return True
 
     async def _periodic_garbage_collection(self) -> None:
index 4792b3968a31817021f09896c3af8717746ae80e..a65ad10891af368ad634c7e25a4fbe88dd264c64 100644 (file)
@@ -138,7 +138,6 @@ exclude = [
   '^music_assistant/controllers/media/tracks.py*$',
   '^music_assistant/controllers/music.py$',
   '^music_assistant/controllers/player_queues.py$',
-  '^music_assistant/controllers/streams.py$',
   '^music_assistant/helpers/app_vars.py',
   '^music_assistant/providers/apple_music/.*$',
   '^music_assistant/providers/bluesound/.*$',