From: Marcel van der Veldt Date: Fri, 28 Feb 2025 14:21:12 +0000 (+0100) Subject: Add caching of audio data to fix streams not starting fast enough (#1989) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=ecdab7a5559814c1ce296a15bb50cb8ac5d0e596;p=music-assistant-server.git Add caching of audio data to fix streams not starting fast enough (#1989) Add caching of audio data which fixes issues with providers which are slow with sending audio data or have a file format which requires the data to be stored in a file first before further processing --- diff --git a/music_assistant/controllers/player_queues.py b/music_assistant/controllers/player_queues.py index e582b2cb..facac860 100644 --- a/music_assistant/controllers/player_queues.py +++ b/music_assistant/controllers/player_queues.py @@ -788,21 +788,18 @@ class PlayerQueuesController(CoreController): ): seek_position = max(0, int((resume_position_ms - 500) / 1000)) - # load item (which also fetches the streamdetails) - # do this here to catch unavailable items early - next_index = self._get_next_index(queue_id, index, allow_repeat=False) - await self._load_item( - queue_item, - next_index, - is_start=True, - seek_position=seek_position, - fade_in=fade_in, - ) - # send play_media request to player # NOTE that we debounce this a bit to account for someone hitting the next button # like a madman. This will prevent the player from being overloaded with requests. - async def play_media(): + async def play_media() -> None: + next_index = self._get_next_index(queue_id, index, allow_repeat=False) + await self._load_item( + queue_item, + next_index, + is_start=True, + seek_position=seek_position, + fade_in=fade_in, + ) await self.mass.players.play_media( player_id=queue_id, media=await self.player_media_from_queue_item(queue_item, queue.flow_mode), @@ -813,7 +810,7 @@ class PlayerQueuesController(CoreController): # we set a flag to notify the update logic that we're transitioning to a new track self._transitioning_players.add(queue_id) self.mass.call_later( - 1.5 if debounce else 0.1, + 1 if debounce else 0, play_media, task_id=f"play_media_{queue_id}", ) diff --git a/music_assistant/controllers/streams.py b/music_assistant/controllers/streams.py index 383c3d09..d45dc904 100644 --- a/music_assistant/controllers/streams.py +++ b/music_assistant/controllers/streams.py @@ -343,7 +343,6 @@ class StreamsController(CoreController): # inform the queue that the track is now loaded in the buffer # so for example the next track can be enqueued self.mass.player_queues.track_loaded_in_buffer(queue_id, queue_item_id) - async for chunk in get_ffmpeg_stream( audio_input=self.get_queue_item_stream( queue_item=queue_item, @@ -903,8 +902,10 @@ class StreamsController(CoreController): # collect all arguments for ffmpeg streamdetails = queue_item.streamdetails assert streamdetails + stream_type = streamdetails.stream_type filter_params = [] extra_input_args = streamdetails.extra_input_args or [] + # handle volume normalization gain_correct: float | None = None if streamdetails.volume_normalization_mode == VolumeNormalizationMode.DYNAMIC: @@ -933,14 +934,14 @@ class StreamsController(CoreController): streamdetails.volume_normalization_gain_correct = gain_correct # work out audio source for these streamdetails - if streamdetails.stream_type == StreamType.CUSTOM: + if stream_type == StreamType.CUSTOM: audio_source = self.mass.get_provider(streamdetails.provider).get_audio_stream( streamdetails, seek_position=streamdetails.seek_position, ) - elif streamdetails.stream_type == StreamType.ICY: + elif stream_type == StreamType.ICY: audio_source = get_icy_radio_stream(self.mass, streamdetails.path, streamdetails) - elif streamdetails.stream_type == StreamType.HLS: + elif stream_type == StreamType.HLS: substream = await get_hls_substream(self.mass, streamdetails.path) audio_source = substream.path if streamdetails.media_type == MediaType.RADIO: @@ -951,10 +952,6 @@ class StreamsController(CoreController): else: audio_source = streamdetails.path - # add support for decryption key provided in streamdetails - if streamdetails.decryption_key: - extra_input_args += ["-decryption_key", streamdetails.decryption_key] - # handle seek support if ( streamdetails.seek_position @@ -962,7 +959,7 @@ class StreamsController(CoreController): and streamdetails.allow_seek # allow seeking for custom streams, # but only for custom streams that can't seek theirselves - and (streamdetails.stream_type != StreamType.CUSTOM or not streamdetails.can_seek) + and (stream_type != StreamType.CUSTOM or not streamdetails.can_seek) ): extra_input_args += ["-ss", str(int(streamdetails.seek_position))] diff --git a/music_assistant/helpers/audio.py b/music_assistant/helpers/audio.py index 482ab8e8..68d2c9d4 100644 --- a/music_assistant/helpers/audio.py +++ b/music_assistant/helpers/audio.py @@ -10,9 +10,10 @@ import struct import time from collections.abc import AsyncGenerator from io import BytesIO -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import aiofiles +import shortuuid from aiohttp import ClientTimeout from music_assistant_models.dsp import DSPConfig, DSPDetails, DSPState from music_assistant_models.enums import ( @@ -30,6 +31,7 @@ from music_assistant_models.errors import ( MusicAssistantError, ProviderUnavailableError, ) +from music_assistant_models.helpers import get_global_cache_value, set_global_cache_values from music_assistant_models.streamdetails import AudioFormat from music_assistant.constants import ( @@ -48,7 +50,7 @@ from .dsp import filter_to_ffmpeg_params from .ffmpeg import FFMpeg, get_ffmpeg_stream from .playlists import IsHLSPlaylist, PlaylistItem, fetch_playlist, parse_m3u from .process import AsyncProcess, communicate -from .util import create_tempfile, detect_charset +from .util import detect_charset, has_tmpfs_mount if TYPE_CHECKING: from music_assistant_models.config_entries import CoreConfig, PlayerConfig @@ -66,6 +68,124 @@ HTTP_HEADERS = {"User-Agent": "Lavf/60.16.100.MusicAssistant"} HTTP_HEADERS_ICY = {**HTTP_HEADERS, "Icy-MetaData": "1"} +async def remove_file(file_path: str) -> None: + """Remove file path (if it exists).""" + if not await asyncio.to_thread(os.path.exists, file_path): + return + await asyncio.to_thread(os.remove, file_path) + LOGGER.log(VERBOSE_LOG_LEVEL, "Removed cache file: %s", file_path) + + +class StreamCache: + """ + StreamCache. + + Basic class to handle temporary caching of audio streams. + For now, based on a (in-memory) tempfile and ffmpeg. + """ + + def acquire(self) -> str: + """Acquire the cache and return the cache file path.""" + # for the edge case where the cache file is not released, + # set a fallback timer to remove the file after 20 minutes + self.mass.call_later( + 20 * 60, remove_file, self._temp_path, task_id=f"remove_file_{self._temp_path}" + ) + return self._temp_path + + def release(self) -> None: + """Release the cache file.""" + # edge case: MA is closing, clean down the file immediately + if self.mass.closing: + os.remove(self._temp_path) + return + # set a timer to remove the file after 1 minute + # if the file is accessed again within this 1 minute, the timer will be cancelled + self.mass.call_later( + 60, remove_file, self._temp_path, task_id=f"remove_file_{self._temp_path}" + ) + + def __init__(self, mass: MusicAssistant, streamdetails: StreamDetails) -> None: + """Initialize the StreamCache.""" + self.mass = mass + self.streamdetails = streamdetails + ext = streamdetails.audio_format.output_format_str + self._temp_path = f"/tmp/{shortuuid.random(20)}.{ext}" # noqa: S108 + self._fetch_task: asyncio.Task | None = None + self.org_path: str | None = streamdetails.path + self.org_stream_type: StreamType | None = streamdetails.stream_type + self.org_extra_input_args: list[str] | None = streamdetails.extra_input_args + streamdetails.path = self._temp_path + streamdetails.stream_type = StreamType.CACHE_FILE + streamdetails.extra_input_args = [] + + async def create(self) -> None: + """Create the cache file (if needed).""" + if await asyncio.to_thread(os.path.exists, self._temp_path): + return + if self._fetch_task is not None and not self._fetch_task.done(): + # fetch task is already busy + return + self._fetch_task = self.mass.create_task(self._create_cache_file()) + # for the edge case where the cache file is not consumed at all, + # set a fallback timer to remove the file after 1 hour + self.mass.call_later( + 3600, remove_file, self._temp_path, task_id=f"remove_file_{self._temp_path}" + ) + + async def wait(self, require_complete_file: bool) -> None: + """ + Wait until the cache is ready. + + Optionally wait until the full file is available (e.g. when seeking). + """ + # if 'require_complete_file' is specified, we wait until the fetch task is ready + if require_complete_file: + await self._fetch_task + return + # wait until the file is created + while not await asyncio.to_thread(os.path.exists, self._temp_path): + await asyncio.sleep(0.2) + + async def _create_cache_file(self) -> None: + time_start = time.time() + LOGGER.log(VERBOSE_LOG_LEVEL, "Fetching audio stream to cache file %s", self._temp_path) + + if self.org_stream_type == StreamType.CUSTOM: + audio_source = self.mass.get_provider(self.streamdetails.provider).get_audio_stream( + self.streamdetails, + ) + elif self.org_stream_type in (StreamType.HTTP, StreamType.ENCRYPTED_HTTP): + audio_source = self.org_path + else: + raise NotImplementedError("Caching of this streamtype is not supported") + + extra_input_args = self.org_extra_input_args or [] + if self.streamdetails.decryption_key: + extra_input_args += ["-decryption_key", self.streamdetails.decryption_key] + + ffmpeg = FFMpeg( + audio_input=audio_source, + input_format=self.streamdetails.audio_format, + output_format=self.streamdetails.audio_format, + extra_input_args=["-y", *extra_input_args], + audio_output=self._temp_path, + ) + await ffmpeg.start() + await ffmpeg.wait() + process_time = int((time.time() - time_start) * 1000) + LOGGER.log( + VERBOSE_LOG_LEVEL, + "Writing cache file %s done in %s milliseconds", + self._temp_path, + process_time, + ) + + def __del__(self) -> None: + """Ensure the temp file gets cleaned up.""" + self.mass.loop.call_soon_threadsafe(self.mass.create_task, remove_file(self._temp_path)) + + async def crossfade_pcm_parts( fade_in_part: bytes, fade_out_part: bytes, @@ -75,8 +195,8 @@ async def crossfade_pcm_parts( sample_size = pcm_format.pcm_sample_size # calculate the fade_length from the smallest chunk fade_length = min(len(fade_in_part), len(fade_out_part)) / sample_size - fadeoutfile = create_tempfile() - async with aiofiles.open(fadeoutfile.name, "wb") as outfile: + fadeout_filename = f"/tmp/{shortuuid.random(20)}.pcm" # noqa: S108 + async with aiofiles.open(fadeout_filename, "wb") as outfile: await outfile.write(fade_out_part) args = [ # generic args @@ -94,7 +214,7 @@ async def crossfade_pcm_parts( "-ar", str(pcm_format.sample_rate), "-i", - fadeoutfile.name, + fadeout_filename, # fade_in part (stdin) "-acodec", pcm_format.content_type.name.lower(), @@ -114,7 +234,8 @@ async def crossfade_pcm_parts( pcm_format.content_type.value, "-", ] - _returncode, crossfaded_audio, _stderr = await communicate(args, fade_in_part) + _, crossfaded_audio, _ = await communicate(args, fade_in_part) + await remove_file(fadeout_filename) if crossfaded_audio: LOGGER.log( VERBOSE_LOG_LEVEL, @@ -294,8 +415,11 @@ async def get_stream_details( raise MediaNotFoundError( f"Unable to retrieve streamdetails for {queue_item.name} ({queue_item.uri})" ) - if queue_item.streamdetails and not queue_item.streamdetails.seconds_streamed: - # already got a fresh/unused streamdetails + if queue_item.streamdetails and ( + not queue_item.streamdetails.seconds_streamed + or queue_item.streamdetails.stream_type == StreamType.CACHE_FILE + ): + # already got a fresh/unused (or cached) streamdetails streamdetails = queue_item.streamdetails else: media_item = queue_item.media_item @@ -367,6 +491,42 @@ async def get_stream_details( queue_item.uri, process_time, ) + + if streamdetails.decryption_key: + # using intermediate cache is mandatory for decryption + streamdetails.enable_cache = True + + # determine if we may use a temporary cache for the audio stream + if streamdetails.enable_cache is None: + tmpfs_present = get_global_cache_value("tmpfs_present") + if tmpfs_present is None: + tmpfs_present = await has_tmpfs_mount() + await set_global_cache_values({"tmpfs_present": tmpfs_present}) + streamdetails.enable_cache = ( + tmpfs_present + and streamdetails.duration is not None + and streamdetails.duration < 1800 + and streamdetails.stream_type + in (StreamType.HTTP, StreamType.ENCRYPTED_HTTP, StreamType.CUSTOM, StreamType.HLS) + ) + + # handle temporary cache support of audio stream + if streamdetails.enable_cache: + if streamdetails.cache is None: + streamdetails.cache = StreamCache(mass, streamdetails) + else: + streamdetails.cache = cast(StreamCache, streamdetails.cache) + await streamdetails.cache.create() + require_complete_file = ( + # require complete file if we're seeking to prevent we're seeking beyond the cached data + streamdetails.seek_position > 0 + or streamdetails.audio_format.content_type + # m4a/mp4 files often have their moov/atom at the end of the file + # so we need the whole file to be available + in (ContentType.M4A, ContentType.M4B, ContentType.MP4) + ) + await streamdetails.cache.wait(require_complete_file=require_complete_file) + return streamdetails @@ -386,6 +546,11 @@ async def get_media_stream( if streamdetails.fade_in: filter_params.append("afade=type=in:start_time=0:duration=3") strip_silence_begin = False + + if streamdetails.stream_type == StreamType.CACHE_FILE: + cache = cast(StreamCache, streamdetails.cache) + audio_source = cache.acquire() + bytes_sent = 0 chunk_number = 0 buffer: bytes = b"" @@ -415,8 +580,10 @@ async def get_media_stream( pcm_format.content_type.value, ffmpeg_proc.proc.pid, ) - async for chunk in ffmpeg_proc.iter_chunked(pcm_format.pcm_sample_size): - if chunk_number == 0: + # use 1 second chunks + chunk_size = pcm_format.pcm_sample_size + async for chunk in ffmpeg_proc.iter_chunked(chunk_size): + if chunk_number == 1: # At this point ffmpeg has started and should now know the codec used # for encoding the audio. streamdetails.audio_format.codec_type = ffmpeg_proc.input_format.codec_type @@ -430,11 +597,17 @@ async def get_media_stream( chunk_number += 1 # determine buffer size dynamically if chunk_number < 5 and strip_silence_begin: - req_buffer_size = int(pcm_format.pcm_sample_size * 4) - elif chunk_number > 30 and strip_silence_end: + req_buffer_size = int(pcm_format.pcm_sample_size * 5) + elif chunk_number > 240 and strip_silence_end: + req_buffer_size = int(pcm_format.pcm_sample_size * 10) + elif chunk_number > 60 and strip_silence_end: req_buffer_size = int(pcm_format.pcm_sample_size * 8) - else: + elif chunk_number > 30: + req_buffer_size = int(pcm_format.pcm_sample_size * 4) + elif chunk_number > 10 and strip_silence_end: req_buffer_size = int(pcm_format.pcm_sample_size * 2) + else: + req_buffer_size = pcm_format.pcm_sample_size # always append to buffer buffer += chunk @@ -494,7 +667,7 @@ async def get_media_stream( # try to determine how many seconds we've streamed seconds_streamed = bytes_sent / pcm_format.pcm_sample_size if bytes_sent else 0 - if not cancelled and ffmpeg_proc.returncode != 0: + if not cancelled and ffmpeg_proc.returncode not in (0, 255): # dump the last 5 lines of the log in case of an unclean exit log_tail = "\n" + "\n".join(list(ffmpeg_proc.log_history)[-5:]) else: @@ -558,6 +731,11 @@ async def get_media_stream( ): mass.create_task(music_prov.on_streamed(streamdetails)) + # schedule removal of cache file + if streamdetails.stream_type == StreamType.CACHE_FILE: + cache = cast(StreamCache, streamdetails.cache) + cache.release() + def create_wave_header(samplerate=44100, channels=2, bitspersample=16, duration=None): """Generate a wave header from given params.""" diff --git a/music_assistant/helpers/ffmpeg.py b/music_assistant/helpers/ffmpeg.py index 7c1c8d83..3f457107 100644 --- a/music_assistant/helpers/ffmpeg.py +++ b/music_assistant/helpers/ffmpeg.py @@ -321,6 +321,9 @@ def get_ffmpeg_args( "-f", output_format.content_type.value, ] + elif input_format == output_format: + # passthrough + output_args = ["-c", "copy"] else: raise RuntimeError("Invalid/unsupported output format specified") diff --git a/music_assistant/helpers/util.py b/music_assistant/helpers/util.py index 8daf159d..a0fe11f5 100644 --- a/music_assistant/helpers/util.py +++ b/music_assistant/helpers/util.py @@ -7,10 +7,8 @@ import functools import importlib import logging import os -import platform import re import socket -import tempfile import urllib.error import urllib.parse import urllib.request @@ -23,9 +21,9 @@ from types import TracebackType from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar from urllib.parse import urlparse +import aiofiles import cchardet as chardet import ifaddr -import memory_tempfile from zeroconf import IPVersion from music_assistant.helpers.process import check_output @@ -454,12 +452,16 @@ async def load_provider_module(domain: str, requirements: list[str]) -> Provider return await asyncio.to_thread(_get_provider_module, domain) -def create_tempfile(): - """Return a (named) temporary file.""" - # ruff: noqa: SIM115 - if platform.system() == "Linux": - return memory_tempfile.MemoryTempfile(fallback=True).NamedTemporaryFile(buffering=0) - return tempfile.NamedTemporaryFile(buffering=0) +async def has_tmpfs_mount() -> bool: + """Check if we have a tmpfs mount.""" + try: + async with aiofiles.open("/proc/mounts") as file: + async for line in file: + if "tmpfs /tmp tmpfs rw" in line: + return True + except (FileNotFoundError, OSError, PermissionError): + pass + return False def divide_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]: diff --git a/music_assistant/providers/apple_music/__init__.py b/music_assistant/providers/apple_music/__init__.py index bfee2038..df388b22 100644 --- a/music_assistant/providers/apple_music/__init__.py +++ b/music_assistant/providers/apple_music/__init__.py @@ -8,6 +8,7 @@ import os from typing import TYPE_CHECKING, Any import aiofiles +from aiohttp.client_exceptions import ClientError from music_assistant_models.config_entries import ConfigEntry, ConfigValueType from music_assistant_models.enums import ( AlbumType, @@ -363,14 +364,14 @@ class AppleMusicProvider(MusicProvider): return StreamDetails( item_id=item_id, provider=self.lookup_key, - audio_format=AudioFormat( - content_type=ContentType.UNKNOWN, - ), + audio_format=AudioFormat(content_type=ContentType.M4A, codec_type=ContentType.AAC), stream_type=StreamType.ENCRYPTED_HTTP, - path=stream_url, decryption_key=await self._get_decryption_key(license_url, key_id, uri, item_id), + path=stream_url, can_seek=True, allow_seek=True, + # enforce caching because the apple streams are m4a files with moov atom at the end + enable_cache=True, ) def _parse_artist(self, artist_obj): @@ -714,12 +715,23 @@ class AppleMusicProvider(MusicProvider): data = { "salableAdamId": song_id, } - async with self.mass.http_session.post( - playback_url, headers=self._get_decryption_headers(), json=data, ssl=True - ) as response: - response.raise_for_status() - content = await response.json(loads=json_loads) - return content["songList"][0] + for retry in (True, False): + try: + async with self.mass.http_session.post( + playback_url, headers=self._get_decryption_headers(), json=data, ssl=True + ) as response: + response.raise_for_status() + content = await response.json(loads=json_loads) + if content.get("failureType"): + message = content.get("failureMessage") + raise MediaNotFoundError(f"Failed to get song stream metadata: {message}") + return content["songList"][0] + except (MediaNotFoundError, ClientError) as exc: + if retry: + self.logger.warning("Failed to get song stream metadata: %s", exc) + continue + raise + raise MediaNotFoundError(f"Failed to get song stream metadata for {song_id}") async def _parse_stream_url_and_uri(self, stream_assets: list[dict]) -> str: """Parse the Stream URL and Key URI from the song.""" @@ -755,7 +767,7 @@ class AppleMusicProvider(MusicProvider): } async def _get_decryption_key( - self, license_url: str, key_id: str, uri: str, item_id: str + self, license_url: str, key_id: bytes, uri: str, item_id: str ) -> str: """Get the decryption key for a song.""" cache_key = f"decryption_key.{item_id}" diff --git a/pyproject.toml b/pyproject.toml index 6bbcd6b4..23fffbee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,9 +23,8 @@ dependencies = [ "faust-cchardet>=2.1.18", "ifaddr==0.2.0", "mashumaro==3.15", - "memory-tempfile==2.2.3", "music-assistant-frontend==2.11.11", - "music-assistant-models==1.1.30", + "music-assistant-models==1.1.31", "mutagen==1.47.0", "orjson==3.10.12", "pillow==11.1.0", diff --git a/requirements_all.txt b/requirements_all.txt index 371ee686..7eda67f8 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -25,9 +25,8 @@ hass-client==1.2.0 ibroadcastaio==0.4.0 ifaddr==0.2.0 mashumaro==3.15 -memory-tempfile==2.2.3 music-assistant-frontend==2.11.11 -music-assistant-models==1.1.30 +music-assistant-models==1.1.31 mutagen==1.47.0 orjson==3.10.12 pillow==11.1.0