Add caching of audio data to fix streams not starting fast enough (#1989)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Fri, 28 Feb 2025 14:21:12 +0000 (15:21 +0100)
committerGitHub <noreply@github.com>
Fri, 28 Feb 2025 14:21:12 +0000 (15:21 +0100)
Add caching of audio data, which fixes issues with providers that are slow to send audio data or that use a file format requiring the data to be stored in a file before further processing

music_assistant/controllers/player_queues.py
music_assistant/controllers/streams.py
music_assistant/helpers/audio.py
music_assistant/helpers/ffmpeg.py
music_assistant/helpers/util.py
music_assistant/providers/apple_music/__init__.py
pyproject.toml
requirements_all.txt

index e582b2cba94e597b9cb20cac09d4ff83495fa377..facac8604a576a789005e0493899bc35e37d367a 100644 (file)
@@ -788,21 +788,18 @@ class PlayerQueuesController(CoreController):
         ):
             seek_position = max(0, int((resume_position_ms - 500) / 1000))
 
-        # load item (which also fetches the streamdetails)
-        # do this here to catch unavailable items early
-        next_index = self._get_next_index(queue_id, index, allow_repeat=False)
-        await self._load_item(
-            queue_item,
-            next_index,
-            is_start=True,
-            seek_position=seek_position,
-            fade_in=fade_in,
-        )
-
         # send play_media request to player
         # NOTE that we debounce this a bit to account for someone hitting the next button
         # like a madman. This will prevent the player from being overloaded with requests.
-        async def play_media():
+        async def play_media() -> None:
+            next_index = self._get_next_index(queue_id, index, allow_repeat=False)
+            await self._load_item(
+                queue_item,
+                next_index,
+                is_start=True,
+                seek_position=seek_position,
+                fade_in=fade_in,
+            )
             await self.mass.players.play_media(
                 player_id=queue_id,
                 media=await self.player_media_from_queue_item(queue_item, queue.flow_mode),
@@ -813,7 +810,7 @@ class PlayerQueuesController(CoreController):
         # we set a flag to notify the update logic that we're transitioning to a new track
         self._transitioning_players.add(queue_id)
         self.mass.call_later(
-            1.5 if debounce else 0.1,
+            1 if debounce else 0,
             play_media,
             task_id=f"play_media_{queue_id}",
         )
index 383c3d09e2b928ed4bc41d1b2f865e31a806921b..d45dc904cafbcaadbaf98fd6d6319488259a6131 100644 (file)
@@ -343,7 +343,6 @@ class StreamsController(CoreController):
         # inform the queue that the track is now loaded in the buffer
         # so for example the next track can be enqueued
         self.mass.player_queues.track_loaded_in_buffer(queue_id, queue_item_id)
-
         async for chunk in get_ffmpeg_stream(
             audio_input=self.get_queue_item_stream(
                 queue_item=queue_item,
@@ -903,8 +902,10 @@ class StreamsController(CoreController):
         # collect all arguments for ffmpeg
         streamdetails = queue_item.streamdetails
         assert streamdetails
+        stream_type = streamdetails.stream_type
         filter_params = []
         extra_input_args = streamdetails.extra_input_args or []
+
         # handle volume normalization
         gain_correct: float | None = None
         if streamdetails.volume_normalization_mode == VolumeNormalizationMode.DYNAMIC:
@@ -933,14 +934,14 @@ class StreamsController(CoreController):
         streamdetails.volume_normalization_gain_correct = gain_correct
 
         # work out audio source for these streamdetails
-        if streamdetails.stream_type == StreamType.CUSTOM:
+        if stream_type == StreamType.CUSTOM:
             audio_source = self.mass.get_provider(streamdetails.provider).get_audio_stream(
                 streamdetails,
                 seek_position=streamdetails.seek_position,
             )
-        elif streamdetails.stream_type == StreamType.ICY:
+        elif stream_type == StreamType.ICY:
             audio_source = get_icy_radio_stream(self.mass, streamdetails.path, streamdetails)
-        elif streamdetails.stream_type == StreamType.HLS:
+        elif stream_type == StreamType.HLS:
             substream = await get_hls_substream(self.mass, streamdetails.path)
             audio_source = substream.path
             if streamdetails.media_type == MediaType.RADIO:
@@ -951,10 +952,6 @@ class StreamsController(CoreController):
         else:
             audio_source = streamdetails.path
 
-        # add support for decryption key provided in streamdetails
-        if streamdetails.decryption_key:
-            extra_input_args += ["-decryption_key", streamdetails.decryption_key]
-
         # handle seek support
         if (
             streamdetails.seek_position
@@ -962,7 +959,7 @@ class StreamsController(CoreController):
             and streamdetails.allow_seek
             # allow seeking for custom streams,
             # but only for custom streams that can't seek theirselves
-            and (streamdetails.stream_type != StreamType.CUSTOM or not streamdetails.can_seek)
+            and (stream_type != StreamType.CUSTOM or not streamdetails.can_seek)
         ):
             extra_input_args += ["-ss", str(int(streamdetails.seek_position))]
 
index 482ab8e8cf3749a1c305b1d163206999d9891492..68d2c9d475ada8fadc4bc6299526d1ad82ddf51c 100644 (file)
@@ -10,9 +10,10 @@ import struct
 import time
 from collections.abc import AsyncGenerator
 from io import BytesIO
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import aiofiles
+import shortuuid
 from aiohttp import ClientTimeout
 from music_assistant_models.dsp import DSPConfig, DSPDetails, DSPState
 from music_assistant_models.enums import (
@@ -30,6 +31,7 @@ from music_assistant_models.errors import (
     MusicAssistantError,
     ProviderUnavailableError,
 )
+from music_assistant_models.helpers import get_global_cache_value, set_global_cache_values
 from music_assistant_models.streamdetails import AudioFormat
 
 from music_assistant.constants import (
@@ -48,7 +50,7 @@ from .dsp import filter_to_ffmpeg_params
 from .ffmpeg import FFMpeg, get_ffmpeg_stream
 from .playlists import IsHLSPlaylist, PlaylistItem, fetch_playlist, parse_m3u
 from .process import AsyncProcess, communicate
-from .util import create_tempfile, detect_charset
+from .util import detect_charset, has_tmpfs_mount
 
 if TYPE_CHECKING:
     from music_assistant_models.config_entries import CoreConfig, PlayerConfig
@@ -66,6 +68,124 @@ HTTP_HEADERS = {"User-Agent": "Lavf/60.16.100.MusicAssistant"}
 HTTP_HEADERS_ICY = {**HTTP_HEADERS, "Icy-MetaData": "1"}
 
 
async def remove_file(file_path: str) -> None:
    """Remove file path (if it exists)."""

    def _remove() -> bool:
        # EAFP: attempt the removal directly instead of a separate
        # exists-check, which is racy (file may vanish in between)
        # and costs an extra thread hop.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            return False
        return True

    # os.remove is blocking I/O, so run it in a worker thread
    if await asyncio.to_thread(_remove):
        LOGGER.log(VERBOSE_LOG_LEVEL, "Removed cache file: %s", file_path)
+
+
class StreamCache:
    """
    StreamCache.

    Basic class to handle temporary caching of audio streams.
    For now, based on a (in-memory) tempfile and ffmpeg.
    """

    def __init__(self, mass: MusicAssistant, streamdetails: StreamDetails) -> None:
        """Initialize the StreamCache."""
        self.mass = mass
        self.streamdetails = streamdetails
        ext = streamdetails.audio_format.output_format_str
        self._temp_path = f"/tmp/{shortuuid.random(20)}.{ext}"  # noqa: S108
        self._fetch_task: asyncio.Task | None = None
        # keep references to the original source details;
        # the streamdetails themselves are redirected to the cache file below
        self.org_path: str | None = streamdetails.path
        self.org_stream_type: StreamType | None = streamdetails.stream_type
        self.org_extra_input_args: list[str] | None = streamdetails.extra_input_args
        streamdetails.path = self._temp_path
        streamdetails.stream_type = StreamType.CACHE_FILE
        streamdetails.extra_input_args = []

    def acquire(self) -> str:
        """Acquire the cache and return the cache file path."""
        # for the edge case where the cache file is not released,
        # set a fallback timer to remove the file after 20 minutes
        self.mass.call_later(
            20 * 60, remove_file, self._temp_path, task_id=f"remove_file_{self._temp_path}"
        )
        return self._temp_path

    def release(self) -> None:
        """Release the cache file."""
        # edge case: MA is closing, clean down the file immediately
        if self.mass.closing:
            try:
                os.remove(self._temp_path)
            except FileNotFoundError:
                # file was never created or already cleaned up
                pass
            return
        # set a timer to remove the file after 1 minute
        # if the file is accessed again within this 1 minute, the timer will be cancelled
        self.mass.call_later(
            60, remove_file, self._temp_path, task_id=f"remove_file_{self._temp_path}"
        )

    async def create(self) -> None:
        """Create the cache file (if needed)."""
        if await asyncio.to_thread(os.path.exists, self._temp_path):
            return
        if self._fetch_task is not None and not self._fetch_task.done():
            # fetch task is already busy
            return
        self._fetch_task = self.mass.create_task(self._create_cache_file())
        # for the edge case where the cache file is not consumed at all,
        # set a fallback timer to remove the file after 1 hour
        self.mass.call_later(
            3600, remove_file, self._temp_path, task_id=f"remove_file_{self._temp_path}"
        )

    async def wait(self, require_complete_file: bool) -> None:
        """
        Wait until the cache is ready.

        Optionally wait until the full file is available (e.g. when seeking).
        """
        # if 'require_complete_file' is specified, we wait until the fetch task is ready
        # NOTE: guard against _fetch_task being None (file already existed in create());
        # in that case we fall through to the exists-poll below which returns immediately
        if require_complete_file and self._fetch_task is not None:
            await self._fetch_task
            return
        # wait until the file is created
        while not await asyncio.to_thread(os.path.exists, self._temp_path):
            await asyncio.sleep(0.2)

    async def _create_cache_file(self) -> None:
        """Fetch the source audio into the cache file using ffmpeg."""
        time_start = time.time()
        LOGGER.log(VERBOSE_LOG_LEVEL, "Fetching audio stream to cache file %s", self._temp_path)

        if self.org_stream_type == StreamType.CUSTOM:
            audio_source = self.mass.get_provider(self.streamdetails.provider).get_audio_stream(
                self.streamdetails,
            )
        elif self.org_stream_type in (StreamType.HTTP, StreamType.ENCRYPTED_HTTP):
            audio_source = self.org_path
        else:
            raise NotImplementedError("Caching of this streamtype is not supported")

        # copy the list so appending the decryption args below can not
        # mutate the (shared) org_extra_input_args on repeated fetches
        extra_input_args = list(self.org_extra_input_args or [])
        if self.streamdetails.decryption_key:
            extra_input_args += ["-decryption_key", self.streamdetails.decryption_key]

        ffmpeg = FFMpeg(
            audio_input=audio_source,
            input_format=self.streamdetails.audio_format,
            output_format=self.streamdetails.audio_format,
            # -y: allow ffmpeg to overwrite the (possibly pre-created) temp file
            extra_input_args=["-y", *extra_input_args],
            audio_output=self._temp_path,
        )
        await ffmpeg.start()
        await ffmpeg.wait()
        process_time = int((time.time() - time_start) * 1000)
        LOGGER.log(
            VERBOSE_LOG_LEVEL,
            "Writing cache file %s done in %s milliseconds",
            self._temp_path,
            process_time,
        )

    def __del__(self) -> None:
        """Ensure the temp file gets cleaned up."""
        coro = remove_file(self._temp_path)
        try:
            self.mass.loop.call_soon_threadsafe(self.mass.create_task, coro)
        except RuntimeError:
            # event loop already closed: close the never-scheduled coroutine
            # (avoids a 'never awaited' warning) and fall back to a
            # best-effort synchronous removal
            coro.close()
            try:
                os.remove(self._temp_path)
            except OSError:
                pass
+
+
 async def crossfade_pcm_parts(
     fade_in_part: bytes,
     fade_out_part: bytes,
@@ -75,8 +195,8 @@ async def crossfade_pcm_parts(
     sample_size = pcm_format.pcm_sample_size
     # calculate the fade_length from the smallest chunk
     fade_length = min(len(fade_in_part), len(fade_out_part)) / sample_size
-    fadeoutfile = create_tempfile()
-    async with aiofiles.open(fadeoutfile.name, "wb") as outfile:
+    fadeout_filename = f"/tmp/{shortuuid.random(20)}.pcm"  # noqa: S108
+    async with aiofiles.open(fadeout_filename, "wb") as outfile:
         await outfile.write(fade_out_part)
     args = [
         # generic args
@@ -94,7 +214,7 @@ async def crossfade_pcm_parts(
         "-ar",
         str(pcm_format.sample_rate),
         "-i",
-        fadeoutfile.name,
+        fadeout_filename,
         # fade_in part (stdin)
         "-acodec",
         pcm_format.content_type.name.lower(),
@@ -114,7 +234,8 @@ async def crossfade_pcm_parts(
         pcm_format.content_type.value,
         "-",
     ]
-    _returncode, crossfaded_audio, _stderr = await communicate(args, fade_in_part)
+    _, crossfaded_audio, _ = await communicate(args, fade_in_part)
+    await remove_file(fadeout_filename)
     if crossfaded_audio:
         LOGGER.log(
             VERBOSE_LOG_LEVEL,
@@ -294,8 +415,11 @@ async def get_stream_details(
         raise MediaNotFoundError(
             f"Unable to retrieve streamdetails for {queue_item.name} ({queue_item.uri})"
         )
-    if queue_item.streamdetails and not queue_item.streamdetails.seconds_streamed:
-        # already got a fresh/unused streamdetails
+    if queue_item.streamdetails and (
+        not queue_item.streamdetails.seconds_streamed
+        or queue_item.streamdetails.stream_type == StreamType.CACHE_FILE
+    ):
+        # already got a fresh/unused (or cached) streamdetails
         streamdetails = queue_item.streamdetails
     else:
         media_item = queue_item.media_item
@@ -367,6 +491,42 @@ async def get_stream_details(
         queue_item.uri,
         process_time,
     )
+
+    if streamdetails.decryption_key:
+        # using intermediate cache is mandatory for decryption
+        streamdetails.enable_cache = True
+
+    # determine if we may use a temporary cache for the audio stream
+    if streamdetails.enable_cache is None:
+        tmpfs_present = get_global_cache_value("tmpfs_present")
+        if tmpfs_present is None:
+            tmpfs_present = await has_tmpfs_mount()
+            await set_global_cache_values({"tmpfs_present": tmpfs_present})
+        streamdetails.enable_cache = (
+            tmpfs_present
+            and streamdetails.duration is not None
+            and streamdetails.duration < 1800
+            and streamdetails.stream_type
+            in (StreamType.HTTP, StreamType.ENCRYPTED_HTTP, StreamType.CUSTOM, StreamType.HLS)
+        )
+
+    # handle temporary cache support of audio stream
+    if streamdetails.enable_cache:
+        if streamdetails.cache is None:
+            streamdetails.cache = StreamCache(mass, streamdetails)
+        else:
+            streamdetails.cache = cast(StreamCache, streamdetails.cache)
+        await streamdetails.cache.create()
+        require_complete_file = (
+            # require the complete file when seeking, to prevent seeking beyond the cached data
+            streamdetails.seek_position > 0
+            or streamdetails.audio_format.content_type
+            # m4a/mp4 files often have their moov/atom at the end of the file
+            # so we need the whole file to be available
+            in (ContentType.M4A, ContentType.M4B, ContentType.MP4)
+        )
+        await streamdetails.cache.wait(require_complete_file=require_complete_file)
+
     return streamdetails
 
 
@@ -386,6 +546,11 @@ async def get_media_stream(
     if streamdetails.fade_in:
         filter_params.append("afade=type=in:start_time=0:duration=3")
         strip_silence_begin = False
+
+    if streamdetails.stream_type == StreamType.CACHE_FILE:
+        cache = cast(StreamCache, streamdetails.cache)
+        audio_source = cache.acquire()
+
     bytes_sent = 0
     chunk_number = 0
     buffer: bytes = b""
@@ -415,8 +580,10 @@ async def get_media_stream(
             pcm_format.content_type.value,
             ffmpeg_proc.proc.pid,
         )
-        async for chunk in ffmpeg_proc.iter_chunked(pcm_format.pcm_sample_size):
-            if chunk_number == 0:
+        # use 1 second chunks
+        chunk_size = pcm_format.pcm_sample_size
+        async for chunk in ffmpeg_proc.iter_chunked(chunk_size):
+            if chunk_number == 1:
                 # At this point ffmpeg has started and should now know the codec used
                 # for encoding the audio.
                 streamdetails.audio_format.codec_type = ffmpeg_proc.input_format.codec_type
@@ -430,11 +597,17 @@ async def get_media_stream(
             chunk_number += 1
             # determine buffer size dynamically
             if chunk_number < 5 and strip_silence_begin:
-                req_buffer_size = int(pcm_format.pcm_sample_size * 4)
-            elif chunk_number > 30 and strip_silence_end:
+                req_buffer_size = int(pcm_format.pcm_sample_size * 5)
+            elif chunk_number > 240 and strip_silence_end:
+                req_buffer_size = int(pcm_format.pcm_sample_size * 10)
+            elif chunk_number > 60 and strip_silence_end:
                 req_buffer_size = int(pcm_format.pcm_sample_size * 8)
-            else:
+            elif chunk_number > 30:
+                req_buffer_size = int(pcm_format.pcm_sample_size * 4)
+            elif chunk_number > 10 and strip_silence_end:
                 req_buffer_size = int(pcm_format.pcm_sample_size * 2)
+            else:
+                req_buffer_size = pcm_format.pcm_sample_size
 
             # always append to buffer
             buffer += chunk
@@ -494,7 +667,7 @@ async def get_media_stream(
 
         # try to determine how many seconds we've streamed
         seconds_streamed = bytes_sent / pcm_format.pcm_sample_size if bytes_sent else 0
-        if not cancelled and ffmpeg_proc.returncode != 0:
+        if not cancelled and ffmpeg_proc.returncode not in (0, 255):
             # dump the last 5 lines of the log in case of an unclean exit
             log_tail = "\n" + "\n".join(list(ffmpeg_proc.log_history)[-5:])
         else:
@@ -558,6 +731,11 @@ async def get_media_stream(
         ):
             mass.create_task(music_prov.on_streamed(streamdetails))
 
+        # schedule removal of cache file
+        if streamdetails.stream_type == StreamType.CACHE_FILE:
+            cache = cast(StreamCache, streamdetails.cache)
+            cache.release()
+
 
 def create_wave_header(samplerate=44100, channels=2, bitspersample=16, duration=None):
     """Generate a wave header from given params."""
index 7c1c8d83cd37f274161d30da66cc028edfd79ca9..3f4571077b408a19c4b27b11287dd0af509e388a 100644 (file)
@@ -321,6 +321,9 @@ def get_ffmpeg_args(
             "-f",
             output_format.content_type.value,
         ]
+    elif input_format == output_format:
+        # passthrough
+        output_args = ["-c", "copy"]
     else:
         raise RuntimeError("Invalid/unsupported output format specified")
 
index 8daf159da1909ab658b6acb0d9e8efa10147b49d..a0fe11f5bb4d5b9fd4b05c5dc6f6c9a0cda24efc 100644 (file)
@@ -7,10 +7,8 @@ import functools
 import importlib
 import logging
 import os
-import platform
 import re
 import socket
-import tempfile
 import urllib.error
 import urllib.parse
 import urllib.request
@@ -23,9 +21,9 @@ from types import TracebackType
 from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar
 from urllib.parse import urlparse
 
+import aiofiles
 import cchardet as chardet
 import ifaddr
-import memory_tempfile
 from zeroconf import IPVersion
 
 from music_assistant.helpers.process import check_output
@@ -454,12 +452,16 @@ async def load_provider_module(domain: str, requirements: list[str]) -> Provider
     return await asyncio.to_thread(_get_provider_module, domain)
 
 
-def create_tempfile():
-    """Return a (named) temporary file."""
-    # ruff: noqa: SIM115
-    if platform.system() == "Linux":
-        return memory_tempfile.MemoryTempfile(fallback=True).NamedTemporaryFile(buffering=0)
-    return tempfile.NamedTemporaryFile(buffering=0)
async def has_tmpfs_mount() -> bool:
    """Check if we have a writable tmpfs mount on /tmp."""

    def _check() -> bool:
        # parse the whitespace-separated /proc/mounts fields
        # (<device> <mountpoint> <fstype> <options> ...) instead of a
        # brittle substring match on the raw line
        try:
            with open("/proc/mounts", encoding="utf-8") as file:
                for line in file:
                    fields = line.split()
                    if len(fields) < 4:
                        continue
                    if (
                        fields[1] == "/tmp"  # noqa: S108
                        and fields[2] == "tmpfs"
                        and "rw" in fields[3].split(",")
                    ):
                        return True
        except OSError:
            # covers FileNotFoundError/PermissionError too:
            # /proc not available (non-linux) or not readable
            pass
        return False

    # the file read is blocking I/O, so run it in a worker thread
    return await asyncio.to_thread(_check)
 
 
 def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
index bfee2038c8615614142d59827f34d16ffb745a22..df388b22470005ae9435490604926515560a2b28 100644 (file)
@@ -8,6 +8,7 @@ import os
 from typing import TYPE_CHECKING, Any
 
 import aiofiles
+from aiohttp.client_exceptions import ClientError
 from music_assistant_models.config_entries import ConfigEntry, ConfigValueType
 from music_assistant_models.enums import (
     AlbumType,
@@ -363,14 +364,14 @@ class AppleMusicProvider(MusicProvider):
         return StreamDetails(
             item_id=item_id,
             provider=self.lookup_key,
-            audio_format=AudioFormat(
-                content_type=ContentType.UNKNOWN,
-            ),
+            audio_format=AudioFormat(content_type=ContentType.M4A, codec_type=ContentType.AAC),
             stream_type=StreamType.ENCRYPTED_HTTP,
-            path=stream_url,
             decryption_key=await self._get_decryption_key(license_url, key_id, uri, item_id),
+            path=stream_url,
             can_seek=True,
             allow_seek=True,
+            # enforce caching because the apple streams are m4a files with moov atom at the end
+            enable_cache=True,
         )
 
     def _parse_artist(self, artist_obj):
@@ -714,12 +715,23 @@ class AppleMusicProvider(MusicProvider):
         data = {
             "salableAdamId": song_id,
         }
-        async with self.mass.http_session.post(
-            playback_url, headers=self._get_decryption_headers(), json=data, ssl=True
-        ) as response:
-            response.raise_for_status()
-            content = await response.json(loads=json_loads)
-            return content["songList"][0]
+        for retry in (True, False):
+            try:
+                async with self.mass.http_session.post(
+                    playback_url, headers=self._get_decryption_headers(), json=data, ssl=True
+                ) as response:
+                    response.raise_for_status()
+                    content = await response.json(loads=json_loads)
+                    if content.get("failureType"):
+                        message = content.get("failureMessage")
+                        raise MediaNotFoundError(f"Failed to get song stream metadata: {message}")
+                    return content["songList"][0]
+            except (MediaNotFoundError, ClientError) as exc:
+                if retry:
+                    self.logger.warning("Failed to get song stream metadata: %s", exc)
+                    continue
+                raise
+        raise MediaNotFoundError(f"Failed to get song stream metadata for {song_id}")
 
     async def _parse_stream_url_and_uri(self, stream_assets: list[dict]) -> str:
         """Parse the Stream URL and Key URI from the song."""
@@ -755,7 +767,7 @@ class AppleMusicProvider(MusicProvider):
         }
 
     async def _get_decryption_key(
-        self, license_url: str, key_id: str, uri: str, item_id: str
+        self, license_url: str, key_id: bytes, uri: str, item_id: str
     ) -> str:
         """Get the decryption key for a song."""
         cache_key = f"decryption_key.{item_id}"
index 6bbcd6b47011558676bebb215e614d2645c89d68..23fffbee591268f7b1ba5db0da3efd7998e56119 100644 (file)
@@ -23,9 +23,8 @@ dependencies = [
   "faust-cchardet>=2.1.18",
   "ifaddr==0.2.0",
   "mashumaro==3.15",
-  "memory-tempfile==2.2.3",
   "music-assistant-frontend==2.11.11",
-  "music-assistant-models==1.1.30",
+  "music-assistant-models==1.1.31",
   "mutagen==1.47.0",
   "orjson==3.10.12",
   "pillow==11.1.0",
index 371ee686f0aec2916e310c113cfef9103f5b64ab..7eda67f8b34efd75af6d0514f2f00ee2714c2e05 100644 (file)
@@ -25,9 +25,8 @@ hass-client==1.2.0
 ibroadcastaio==0.4.0
 ifaddr==0.2.0
 mashumaro==3.15
-memory-tempfile==2.2.3
 music-assistant-frontend==2.11.11
-music-assistant-models==1.1.30
+music-assistant-models==1.1.31
 mutagen==1.47.0
 orjson==3.10.12
 pillow==11.1.0