Enhancement: Use pure memory cache for audio caching
author Marcel van der Veldt <m.vanderveldt@outlook.com>
Thu, 6 Mar 2025 13:59:13 +0000 (14:59 +0100)
committer Marcel van der Veldt <m.vanderveldt@outlook.com>
Thu, 6 Mar 2025 13:59:13 +0000 (14:59 +0100)
- Only cache in memory and not with intermediate file to prevent weird issues
- Add config toggle to enable/disable audio caching
- Use intermediate NUT container when needed to enable seeking in e.g. ogg streams

music_assistant/constants.py
music_assistant/controllers/streams.py
music_assistant/helpers/audio.py
music_assistant/helpers/ffmpeg.py

index e05ffea13613cb016b89b54664d2a2f0634d9d21..a553e103987db4fa9830347953de50d5acdb0182 100644 (file)
@@ -84,10 +84,12 @@ CONF_POWER_CONTROL: Final[str] = "power_control"
 CONF_VOLUME_CONTROL: Final[str] = "volume_control"
 CONF_MUTE_CONTROL: Final[str] = "mute_control"
 CONF_OUTPUT_CODEC: Final[str] = "output_codec"
+CONF_ALLOW_MEMORY_CACHE: Final[str] = "allow_memory_cache"
 
 # config default values
 DEFAULT_HOST: Final[str] = "0.0.0.0"
 DEFAULT_PORT: Final[int] = 8095
+DEFAULT_ALLOW_MEMORY_CACHE: Final[bool] = True
 
 # common db tables
 DB_TABLE_PLAYLOG: Final[str] = "playlog"
index 1c9775876d42914538ea303218e25eeb5462ea24..edf77ab86e5fca240d022d3c049584e9de76ec0f 100644 (file)
@@ -30,6 +30,7 @@ from music_assistant_models.player_queue import PlayLogEntry
 
 from music_assistant.constants import (
     ANNOUNCE_ALERT_FILE,
+    CONF_ALLOW_MEMORY_CACHE,
     CONF_BIND_IP,
     CONF_BIND_PORT,
     CONF_CROSSFADE,
@@ -44,6 +45,7 @@ from music_assistant.constants import (
     CONF_VOLUME_NORMALIZATION_FIXED_GAIN_TRACKS,
     CONF_VOLUME_NORMALIZATION_RADIO,
     CONF_VOLUME_NORMALIZATION_TRACKS,
+    DEFAULT_ALLOW_MEMORY_CACHE,
     DEFAULT_PCM_FORMAT,
     DEFAULT_STREAM_HEADERS,
     ICY_HEADERS,
@@ -195,6 +197,18 @@ class StreamsController(CoreController):
                 category="advanced",
                 required=False,
             ),
+            ConfigEntry(
+                key=CONF_ALLOW_MEMORY_CACHE,
+                type=ConfigEntryType.BOOLEAN,
+                default_value=DEFAULT_ALLOW_MEMORY_CACHE,
+                label="Allow (in-memory) caching of audio streams",
+                description="To ensure smooth playback as well as fast seeking, "
+                "Music Assistant by default caches audio streams (in memory). "
+                "On systems with limited memory, this can be disabled, "
+                "but may result in less smooth playback.",
+                category="advanced",
+                required=False,
+            ),
         )
 
     async def setup(self, config: CoreConfig) -> None:
index ce06e561b7d9ae16ecd1b72c2ec65e80aa146cfd..b2b5c42a923413cfe3142aaed85e95ebafeb2082 100644 (file)
@@ -34,12 +34,14 @@ from music_assistant_models.errors import (
 from music_assistant_models.streamdetails import AudioFormat
 
 from music_assistant.constants import (
+    CONF_ALLOW_MEMORY_CACHE,
     CONF_ENTRY_OUTPUT_LIMITER,
     CONF_OUTPUT_CHANNELS,
     CONF_VOLUME_NORMALIZATION,
     CONF_VOLUME_NORMALIZATION_RADIO,
     CONF_VOLUME_NORMALIZATION_TARGET,
     CONF_VOLUME_NORMALIZATION_TRACKS,
+    DEFAULT_ALLOW_MEMORY_CACHE,
     MASS_LOGGER_NAME,
     VERBOSE_LOG_LEVEL,
 )
@@ -51,7 +53,7 @@ from .dsp import filter_to_ffmpeg_params
 from .ffmpeg import FFMpeg, get_ffmpeg_stream
 from .playlists import IsHLSPlaylist, PlaylistItem, fetch_playlist, parse_m3u
 from .process import AsyncProcess, communicate
-from .util import detect_charset, get_tmp_free_space
+from .util import detect_charset
 
 if TYPE_CHECKING:
     from music_assistant_models.config_entries import CoreConfig, PlayerConfig
@@ -81,81 +83,50 @@ class StreamCache:
     """
     StreamCache.
 
-    Basic class to handle (temporary) caching of audio streams.
+    Basic class to handle (temporary) in-memory caching of audio streams.
     Useful in case of slow or unreliable network connections, faster seeking,
     or when the audio stream is slow itself.
-
-    The cache is stored in a file on disk so ffmpeg can access it directly.
-    After 1 minute of inactivity, the cache file will be removed.
-
-    Because we use /tmp as the cache location, and on our systems /tmp is mounted as tmpfs,
-    the cache will be stored in memory and not on the disk.
     """
 
-    @property
-    def data_complete(self) -> bool:
-        """Return if the cache is complete."""
-        return self._fetch_task is not None and self._fetch_task.done()
-
-    async def acquire(self) -> str | AsyncGenerator[bytes, None]:
-        """Acquire the cache and return the cache file path."""
-        self.mass.cancel_timer(f"clear_cache_{self._temp_path}")
-        if not self.data_complete and not self._first_part_received.is_set():
-            # handle the situation where the cache
-            # file is not created yet or already removed
-            await self.create()
-        self._subscribers += 1
-        if self._all_data_written.is_set():
-            # cache is completely written, return the path
-            return self._temp_path
-        return self._stream_from_cache()
-
-    def release(self) -> None:
-        """Release the cache file."""
-        self._subscribers -= 1
-        if self._subscribers == 0:
-            # set a timer to remove the tempfile after 1 minute
-            # if the file is accessed again within this period,
-            # the timer will be cancelled
-            self.mass.call_later(60, self._clear, task_id=f"clear_cache_{self._temp_path}")
-
-    def __init__(self, mass: MusicAssistant, streamdetails: StreamDetails) -> None:
-        """Initialize the StreamCache."""
-        self.mass = mass
-        self.streamdetails = streamdetails
-        self.logger = LOGGER.getChild("cache")
-        self._temp_path = f"/tmp/{shortuuid.random(20)}"  # noqa: S108
-        self._fetch_task: asyncio.Task | None = None
-        self._subscribers: int = 0
-        self._first_part_received = asyncio.Event()
-        self._all_data_written = asyncio.Event()
-        self.org_path: str | None = streamdetails.path
-        self.org_stream_type: StreamType | None = streamdetails.stream_type
-        self.org_extra_input_args: list[str] | None = streamdetails.extra_input_args
-        streamdetails.path = self._temp_path
-        streamdetails.stream_type = StreamType.CACHE
-        streamdetails.can_seek = True
-        streamdetails.allow_seek = True
-        streamdetails.extra_input_args = []
-
     async def create(self) -> None:
         """Create the cache file (if needed)."""
-        self.mass.cancel_timer(f"clear_cache_{self._temp_path}")
-        if await asyncio.to_thread(os.path.exists, self._temp_path):
-            return
-        if self._fetch_task is not None and not self._fetch_task.done():
-            # fetch task is already busy
-            return
-        self._fetch_task = self.mass.create_task(self._create_cache_file())
+        self.mass.cancel_timer(f"clear_cache_{self.cache_id}")
+        if self._fetch_task is None:
+            self._fetch_task = self.mass.create_task(self._fill_cache())
         # wait until the first part of the file is received
         await self._first_part_received.wait()
 
-    async def _create_cache_file(self) -> None:
+    async def get_audio_stream(self) -> AsyncGenerator[bytes, None]:
+        """Stream audio from cachedata (while it might even still being written)."""
+        try:
+            self._subscribers += 1
+            bytes_read = 0
+            chunksize = 64000
+            await self.create()
+            while True:
+                async with self._lock:
+                    chunk = self._data[bytes_read : bytes_read + chunksize]
+                    bytes_read += len(chunk)
+                if len(chunk) < chunksize and self._all_data_written.is_set():
+                    # reached EOF
+                    break
+                elif not chunk:
+                    # data is not yet available, wait a bit
+                    await asyncio.sleep(0.05)
+                else:
+                    yield chunk
+                del chunk
+        finally:
+            self._subscribers -= 1
+            if self._subscribers == 0:
+                # set a timer to remove the tempfile after 1 minute
+                # if the file is accessed again within this period,
+                # the timer will be cancelled
+                self.mass.call_later(60, self._clear, task_id=f"clear_cache_{self.cache_id}")
+
+    async def _fill_cache(self) -> None:
         time_start = time.time()
-        self.logger.debug(
-            "Fetching audio stream for %s",
-            self.streamdetails.uri,
-        )
+        self.logger.debug("Fetching audio stream for %s", self.streamdetails.uri)
         if self.org_stream_type == StreamType.CUSTOM:
             audio_source = self.mass.get_provider(self.streamdetails.provider).get_audio_stream(
                 self.streamdetails,
@@ -180,65 +151,69 @@ class StreamCache:
         # ffmpeg will produce a lossless copy of the original codec to stdout.
         self._first_part_received.clear()
         self._all_data_written.clear()
-        required_bytes = get_chunksize(self.streamdetails.audio_format, 2)
-        async with FFMpeg(
+        self._data = b""
+        async for chunk in get_ffmpeg_stream(
             audio_input=audio_source,
             input_format=self.streamdetails.audio_format,
             output_format=self.streamdetails.audio_format,
             extra_input_args=extra_input_args,
-            audio_output=self._temp_path,
-        ) as ffmpeg_proc:
-            # wait until the first part of the file is received
-            while ffmpeg_proc.returncode is None:
-                await asyncio.sleep(0.05)
-                if not await asyncio.to_thread(os.path.exists, self._temp_path):
-                    continue
-                if await asyncio.to_thread(os.path.getsize, self._temp_path) >= required_bytes:
-                    break
-            self._first_part_received.set()
-            self.logger.debug(
-                "First part received for %s after %.2fs",
-                self.streamdetails.uri,
-                time.time() - time_start,
-            )
-            # wait until ffmpeg is done
-            await ffmpeg_proc.wait()
-            self._all_data_written.set()
-
-        LOGGER.debug(
+        ):
+            async with self._lock:
+                self._data += chunk
+                del chunk
+            if not self._first_part_received.is_set():
+                self._first_part_received.set()
+                self.logger.debug(
+                    "First part received for %s after %.2fs",
+                    self.streamdetails.uri,
+                    time.time() - time_start,
+                )
+        self._all_data_written.set()
+        self.logger.debug(
             "Writing all data for %s done in %.2fs",
             self.streamdetails.uri,
             time.time() - time_start,
         )
 
-    async def _stream_from_cache(self) -> AsyncGenerator[bytes, None]:
-        """Stream audio from cachefile (while its still being written)."""
-        async with aiofiles.open(self._temp_path, "rb", buffering=0) as _file:
-            while True:
-                chunk = await _file.read(64000)
-                if not chunk and self._all_data_written.is_set():
-                    break
-                elif not chunk:
-                    await asyncio.sleep(0.05)
-                else:
-                    yield chunk
+    def __init__(self, mass: MusicAssistant, streamdetails: StreamDetails) -> None:
+        """Initialize the StreamCache."""
+        self.mass = mass
+        self.streamdetails = streamdetails
+        self.cache_id = shortuuid.random(20)
+        self.logger = LOGGER.getChild("cache")
+        self._fetch_task: asyncio.Task | None = None
+        self._subscribers: int = 0
+        self._first_part_received = asyncio.Event()
+        self._all_data_written = asyncio.Event()
+        self._data: bytes = b""
+        self._lock: asyncio.Lock = asyncio.Lock()
+        self.org_path: str | None = streamdetails.path
+        self.org_stream_type: StreamType | None = streamdetails.stream_type
+        self.org_extra_input_args: list[str] | None = streamdetails.extra_input_args
+        streamdetails.path = "-"
+        streamdetails.stream_type = StreamType.CACHE
+        streamdetails.can_seek = True
+        streamdetails.allow_seek = True
+        streamdetails.extra_input_args = []
 
     async def _clear(self) -> None:
         """Clear the cache."""
+        self.logger.debug("Cleaning up cache %s", self.streamdetails.uri)
+        if self._fetch_task and not self._fetch_task.done():
+            self._fetch_task.cancel()
+        self._fetch_task = None
         self._first_part_received.clear()
         self._all_data_written.clear()
-        self._fetch_task = None
-        await remove_file(self._temp_path)
+        del self._data
+        self._data = b""
 
     def __del__(self) -> None:
-        """Ensure the temp file gets cleaned up."""
+        """Ensure the cache data gets cleaned up."""
         if self.mass.closing:
-            # edge case: MA is closing, clean down the file immediately
-            if os.path.isfile(self._temp_path):
-                os.remove(self._temp_path)
+            # edge case: MA is closing
             return
-        self.mass.loop.call_soon_threadsafe(self.mass.create_task, remove_file(self._temp_path))
-        self.mass.cancel_timer(f"remove_file_{self._temp_path}")
+        self.mass.cancel_timer(f"remove_file_{self.cache_id}")
+        del self._data
 
 
 async def crossfade_pcm_parts(
@@ -500,9 +475,8 @@ async def get_stream_details(
             else:
                 break
         else:
-            raise MediaNotFoundError(
-                f"Unable to retrieve streamdetails for {queue_item.name} ({queue_item.uri})"
-            )
+            msg = f"Unable to retrieve streamdetails for {queue_item.name} ({queue_item.uri})"
+            raise MediaNotFoundError(msg)
 
         # work out how to handle radio stream
         if (
@@ -548,19 +522,22 @@ async def get_stream_details(
     )
 
     if streamdetails.decryption_key:
-        # using intermediate cache is mandatory for decryption
+        # using intermediate cache is mandatory for encrypted streams
         streamdetails.enable_cache = True
 
-    # determine if we may use a temporary cache for the audio stream
+    # determine if we may use caching for the audio stream
     if streamdetails.enable_cache is None:
+        allow_cache = mass.config.get_raw_core_config_value(
+            "streams", CONF_ALLOW_MEMORY_CACHE, DEFAULT_ALLOW_MEMORY_CACHE
+        )
         streamdetails.enable_cache = (
-            streamdetails.duration is not None
+            allow_cache
+            and streamdetails.duration is not None
             and streamdetails.media_type
             in (MediaType.TRACK, MediaType.AUDIOBOOK, MediaType.PODCAST_EPISODE)
             and streamdetails.stream_type
             in (StreamType.HTTP, StreamType.ENCRYPTED_HTTP, StreamType.CUSTOM, StreamType.HLS)
             and streamdetails.audio_format.content_type != ContentType.UNKNOWN
-            and await get_tmp_free_space() > 512 * 1024 * 1024
             and get_chunksize(streamdetails.audio_format, streamdetails.duration) < 100000000
         )
 
@@ -601,11 +578,11 @@ async def get_media_stream(
     stream_type = streamdetails.stream_type
     if stream_type == StreamType.CACHE:
         cache = cast(StreamCache, streamdetails.cache)
-        audio_source = await cache.acquire()
+        audio_source = cache.get_audio_stream()
     elif stream_type == StreamType.CUSTOM:
         audio_source = mass.get_provider(streamdetails.provider).get_audio_stream(
             streamdetails,
-            seek_position=streamdetails.seek_position,
+            seek_position=streamdetails.seek_position if streamdetails.can_seek else 0,
         )
     elif stream_type == StreamType.ICY:
         audio_source = get_icy_radio_stream(mass, streamdetails.path, streamdetails)
@@ -627,7 +604,7 @@ async def get_media_stream(
         and streamdetails.allow_seek
         # allow seeking for custom streams,
         # but only for custom streams that can't seek theirselves
-        and (stream_type != StreamType.CUSTOM or not streamdetails.can_seek)
+        and not (stream_type == StreamType.CUSTOM and streamdetails.can_seek)
     ):
         extra_input_args += ["-ss", str(int(streamdetails.seek_position))]
 
@@ -812,11 +789,6 @@ async def get_media_stream(
         ):
             mass.create_task(music_prov.on_streamed(streamdetails))
 
-        # release cache file
-        if streamdetails.stream_type == StreamType.CACHE:
-            cache = cast(StreamCache, streamdetails.cache)
-            cache.release()
-
 
 def create_wave_header(samplerate=44100, channels=2, bitspersample=16, duration=None):
     """Generate a wave header from given params."""
@@ -1370,7 +1342,7 @@ async def analyze_loudness(
     ]
     if streamdetails.stream_type == StreamType.CACHE:
         cache = cast(StreamCache, streamdetails.cache)
-        audio_source = await cache.acquire()
+        audio_source = cache.get_audio_stream()
     elif streamdetails.stream_type == StreamType.CUSTOM:
         audio_source = mass.get_provider(streamdetails.provider).get_audio_stream(
             streamdetails,
@@ -1422,10 +1394,6 @@ async def analyze_loudness(
                 streamdetails.uri,
                 loudness,
             )
-    # release cache file
-    if streamdetails.stream_type == StreamType.CACHE:
-        cache = cast(StreamCache, streamdetails.cache)
-        cache.release()
 
 
 def _get_normalization_mode(
index 0118b1399c2784856195f1b1adf84574440adb49..151fcd16e6af208810564d2fbf7b9f404b9ca817 100644 (file)
@@ -290,7 +290,7 @@ def get_ffmpeg_args(  # noqa: PLR0915
             output_format.content_type.value,
         ]
     elif input_format == output_format and not extra_args:
-        # passthrough
+        # passthrough-mode (e.g. for creating the cache)
         if output_format.content_type in (
             ContentType.MP4,
             ContentType.MP4A,
@@ -298,8 +298,8 @@ def get_ffmpeg_args(  # noqa: PLR0915
             ContentType.M4B,
         ):
             fmt = "adts"
-        elif output_format.codec_type != ContentType.UNKNOWN:
-            fmt = output_format.codec_type.name.lower()
+        elif output_format.codec_type in (ContentType.UNKNOWN, ContentType.OGG):
+            fmt = "nut"  # use special nut container
         else:
             fmt = output_format.content_type.name.lower()
         output_args = [