import time
from collections.abc import AsyncGenerator
from io import BytesIO
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
import aiofiles
+import shortuuid
from aiohttp import ClientTimeout
from music_assistant_models.dsp import DSPConfig, DSPDetails, DSPState
from music_assistant_models.enums import (
MusicAssistantError,
ProviderUnavailableError,
)
+from music_assistant_models.helpers import get_global_cache_value, set_global_cache_values
from music_assistant_models.streamdetails import AudioFormat
from music_assistant.constants import (
from .ffmpeg import FFMpeg, get_ffmpeg_stream
from .playlists import IsHLSPlaylist, PlaylistItem, fetch_playlist, parse_m3u
from .process import AsyncProcess, communicate
-from .util import create_tempfile, detect_charset
+from .util import detect_charset, has_tmpfs_mount
if TYPE_CHECKING:
from music_assistant_models.config_entries import CoreConfig, PlayerConfig
HTTP_HEADERS_ICY = {**HTTP_HEADERS, "Icy-MetaData": "1"}
+async def remove_file(file_path: str) -> None:
+ """Remove file path (if it exists).
+
+ Both the existence check and the removal are pushed to a worker thread
+ via asyncio.to_thread so the event loop is never blocked by disk I/O.
+ """
+ if not await asyncio.to_thread(os.path.exists, file_path):
+ return
+ # NOTE(review): exists+remove is not atomic — a concurrent removal between
+ # the two calls would raise FileNotFoundError here; confirm callers tolerate that
+ await asyncio.to_thread(os.remove, file_path)
+ LOGGER.log(VERBOSE_LOG_LEVEL, "Removed cache file: %s", file_path)
+
+
+class StreamCache:
+ """
+ StreamCache.
+
+ Basic class to handle temporary caching of audio streams.
+ For now, based on an (in-memory) temp file in /tmp and ffmpeg.
+ """
+
+ def acquire(self) -> str:
+ """Acquire the cache and return the cache file path."""
+ # for the edge case where the cache file is never released,
+ # (re)arm a fallback timer to remove the file after 20 minutes
+ # NOTE(review): assumes call_later with the same task_id replaces any
+ # pending removal timer (e.g. one armed by release()) — confirm semantics
+ self.mass.call_later(
+ 20 * 60, remove_file, self._temp_path, task_id=f"remove_file_{self._temp_path}"
+ )
+ return self._temp_path
+
+ def release(self) -> None:
+ """Release the cache file (schedule its delayed removal)."""
+ # edge case: MA is closing, clean up the file immediately
+ if self.mass.closing:
+ # NOTE(review): unlike remove_file(), this raises FileNotFoundError
+ # when the file is already gone — confirm that is acceptable here
+ os.remove(self._temp_path)
+ return
+ # set a timer to remove the file after 1 minute;
+ # if the file is accessed again within this 1 minute, the timer will be cancelled
+ # (both paths share the same task_id, so acquire() supersedes this timer)
+ self.mass.call_later(
+ 60, remove_file, self._temp_path, task_id=f"remove_file_{self._temp_path}"
+ )
+
+ def __init__(self, mass: MusicAssistant, streamdetails: StreamDetails) -> None:
+ """Initialize the StreamCache.
+
+ Side effect: mutates the given streamdetails in place to point at the
+ (future) cache file, after saving the original path/stream type/input
+ args so the fetch task can still reach the real source.
+ """
+ self.mass = mass
+ self.streamdetails = streamdetails
+ # keep the stream's output format as the temp file extension
+ ext = streamdetails.audio_format.output_format_str
+ self._temp_path = f"/tmp/{shortuuid.random(20)}.{ext}" # noqa: S108
+ self._fetch_task: asyncio.Task | None = None
+ # preserve the original source details before redirecting to the cache file
+ self.org_path: str | None = streamdetails.path
+ self.org_stream_type: StreamType | None = streamdetails.stream_type
+ self.org_extra_input_args: list[str] | None = streamdetails.extra_input_args
+ streamdetails.path = self._temp_path
+ streamdetails.stream_type = StreamType.CACHE_FILE
+ streamdetails.extra_input_args = []
+
+ async def create(self) -> None:
+ """Create the cache file (if needed)."""
+ # file already (at least partially) present: nothing to start
+ if await asyncio.to_thread(os.path.exists, self._temp_path):
+ return
+ if self._fetch_task is not None and not self._fetch_task.done():
+ # fetch task is already busy
+ return
+ self._fetch_task = self.mass.create_task(self._create_cache_file())
+ # for the edge case where the cache file is not consumed at all,
+ # set a fallback timer to remove the file after 1 hour
+ self.mass.call_later(
+ 3600, remove_file, self._temp_path, task_id=f"remove_file_{self._temp_path}"
+ )
+
+ async def wait(self, require_complete_file: bool) -> None:
+ """
+ Wait until the cache is ready.
+
+ Optionally wait until the full file is available (e.g. when seeking).
+ """
+ # if 'require_complete_file' is specified, we wait until the fetch task is ready
+ if require_complete_file:
+ # NOTE(review): _fetch_task can still be None here when create()
+ # returned early because the file already existed — awaiting None
+ # raises TypeError; verify create() always armed a task first
+ await self._fetch_task
+ return
+ # otherwise just poll until the file has been created (ffmpeg keeps
+ # appending to it while the consumer reads)
+ while not await asyncio.to_thread(os.path.exists, self._temp_path):
+ await asyncio.sleep(0.2)
+
+ async def _create_cache_file(self) -> None:
+ # fetch the *original* audio source into the temp file using ffmpeg
+ time_start = time.time()
+ LOGGER.log(VERBOSE_LOG_LEVEL, "Fetching audio stream to cache file %s", self._temp_path)
+
+ if self.org_stream_type == StreamType.CUSTOM:
+ audio_source = self.mass.get_provider(self.streamdetails.provider).get_audio_stream(
+ self.streamdetails,
+ )
+ elif self.org_stream_type in (StreamType.HTTP, StreamType.ENCRYPTED_HTTP):
+ audio_source = self.org_path
+ else:
+ raise NotImplementedError("Caching of this streamtype is not supported")
+
+ extra_input_args = self.org_extra_input_args or []
+ if self.streamdetails.decryption_key:
+ extra_input_args += ["-decryption_key", self.streamdetails.decryption_key]
+
+ # identical input/output format; -y lets ffmpeg overwrite a stale temp file
+ ffmpeg = FFMpeg(
+ audio_input=audio_source,
+ input_format=self.streamdetails.audio_format,
+ output_format=self.streamdetails.audio_format,
+ extra_input_args=["-y", *extra_input_args],
+ audio_output=self._temp_path,
+ )
+ await ffmpeg.start()
+ await ffmpeg.wait()
+ process_time = int((time.time() - time_start) * 1000)
+ LOGGER.log(
+ VERBOSE_LOG_LEVEL,
+ "Writing cache file %s done in %s milliseconds",
+ self._temp_path,
+ process_time,
+ )
+
+ def __del__(self) -> None:
+ """Ensure the temp file gets cleaned up."""
+ # __del__ may run on any thread, so hop onto the event loop to remove the file
+ # NOTE(review): if the loop is already closed this leaves an un-awaited
+ # coroutine object behind — confirm shutdown ordering covers this
+ self.mass.loop.call_soon_threadsafe(self.mass.create_task, remove_file(self._temp_path))
+
+
async def crossfade_pcm_parts(
fade_in_part: bytes,
fade_out_part: bytes,
sample_size = pcm_format.pcm_sample_size
# calculate the fade_length from the smallest chunk
fade_length = min(len(fade_in_part), len(fade_out_part)) / sample_size
- fadeoutfile = create_tempfile()
- async with aiofiles.open(fadeoutfile.name, "wb") as outfile:
+ fadeout_filename = f"/tmp/{shortuuid.random(20)}.pcm" # noqa: S108
+ async with aiofiles.open(fadeout_filename, "wb") as outfile:
await outfile.write(fade_out_part)
args = [
# generic args
"-ar",
str(pcm_format.sample_rate),
"-i",
- fadeoutfile.name,
+ fadeout_filename,
# fade_in part (stdin)
"-acodec",
pcm_format.content_type.name.lower(),
pcm_format.content_type.value,
"-",
]
- _returncode, crossfaded_audio, _stderr = await communicate(args, fade_in_part)
+ _, crossfaded_audio, _ = await communicate(args, fade_in_part)
+ await remove_file(fadeout_filename)
if crossfaded_audio:
LOGGER.log(
VERBOSE_LOG_LEVEL,
raise MediaNotFoundError(
f"Unable to retrieve streamdetails for {queue_item.name} ({queue_item.uri})"
)
- if queue_item.streamdetails and not queue_item.streamdetails.seconds_streamed:
- # already got a fresh/unused streamdetails
+ if queue_item.streamdetails and (
+ not queue_item.streamdetails.seconds_streamed
+ or queue_item.streamdetails.stream_type == StreamType.CACHE_FILE
+ ):
+ # already got a fresh/unused (or cached) streamdetails
streamdetails = queue_item.streamdetails
else:
media_item = queue_item.media_item
queue_item.uri,
process_time,
)
+
+ if streamdetails.decryption_key:
+ # using intermediate cache is mandatory for decryption
+ streamdetails.enable_cache = True
+
+ # determine if we may use a temporary cache for the audio stream
+ if streamdetails.enable_cache is None:
+ tmpfs_present = get_global_cache_value("tmpfs_present")
+ if tmpfs_present is None:
+ tmpfs_present = await has_tmpfs_mount()
+ await set_global_cache_values({"tmpfs_present": tmpfs_present})
+ streamdetails.enable_cache = (
+ tmpfs_present
+ and streamdetails.duration is not None
+ and streamdetails.duration < 1800
+ and streamdetails.stream_type
+ in (StreamType.HTTP, StreamType.ENCRYPTED_HTTP, StreamType.CUSTOM, StreamType.HLS)
+ )
+
+ # handle temporary cache support of audio stream
+ if streamdetails.enable_cache:
+ if streamdetails.cache is None:
+ streamdetails.cache = StreamCache(mass, streamdetails)
+ else:
+ streamdetails.cache = cast(StreamCache, streamdetails.cache)
+ await streamdetails.cache.create()
+ require_complete_file = (
+ # require the complete file when seeking, to prevent seeking beyond the cached data
+ streamdetails.seek_position > 0
+ or streamdetails.audio_format.content_type
+ # m4a/mp4 files often have their moov/atom at the end of the file
+ # so we need the whole file to be available
+ in (ContentType.M4A, ContentType.M4B, ContentType.MP4)
+ )
+ await streamdetails.cache.wait(require_complete_file=require_complete_file)
+
return streamdetails
if streamdetails.fade_in:
filter_params.append("afade=type=in:start_time=0:duration=3")
strip_silence_begin = False
+
+ if streamdetails.stream_type == StreamType.CACHE_FILE:
+ cache = cast(StreamCache, streamdetails.cache)
+ audio_source = cache.acquire()
+
bytes_sent = 0
chunk_number = 0
buffer: bytes = b""
pcm_format.content_type.value,
ffmpeg_proc.proc.pid,
)
- async for chunk in ffmpeg_proc.iter_chunked(pcm_format.pcm_sample_size):
- if chunk_number == 0:
+ # use 1 second chunks
+ chunk_size = pcm_format.pcm_sample_size
+ async for chunk in ffmpeg_proc.iter_chunked(chunk_size):
+ if chunk_number == 1:
# At this point ffmpeg has started and should now know the codec used
# for encoding the audio.
streamdetails.audio_format.codec_type = ffmpeg_proc.input_format.codec_type
chunk_number += 1
# determine buffer size dynamically
if chunk_number < 5 and strip_silence_begin:
- req_buffer_size = int(pcm_format.pcm_sample_size * 4)
- elif chunk_number > 30 and strip_silence_end:
+ req_buffer_size = int(pcm_format.pcm_sample_size * 5)
+ elif chunk_number > 240 and strip_silence_end:
+ req_buffer_size = int(pcm_format.pcm_sample_size * 10)
+ elif chunk_number > 60 and strip_silence_end:
req_buffer_size = int(pcm_format.pcm_sample_size * 8)
- else:
+ elif chunk_number > 30:
+ req_buffer_size = int(pcm_format.pcm_sample_size * 4)
+ elif chunk_number > 10 and strip_silence_end:
req_buffer_size = int(pcm_format.pcm_sample_size * 2)
+ else:
+ req_buffer_size = pcm_format.pcm_sample_size
# always append to buffer
buffer += chunk
# try to determine how many seconds we've streamed
seconds_streamed = bytes_sent / pcm_format.pcm_sample_size if bytes_sent else 0
- if not cancelled and ffmpeg_proc.returncode != 0:
+ if not cancelled and ffmpeg_proc.returncode not in (0, 255):
# dump the last 5 lines of the log in case of an unclean exit
log_tail = "\n" + "\n".join(list(ffmpeg_proc.log_history)[-5:])
else:
):
mass.create_task(music_prov.on_streamed(streamdetails))
+ # schedule removal of cache file
+ if streamdetails.stream_type == StreamType.CACHE_FILE:
+ cache = cast(StreamCache, streamdetails.cache)
+ cache.release()
+
def create_wave_header(samplerate=44100, channels=2, bitspersample=16, duration=None):
"""Generate a wave header from given params."""
from typing import TYPE_CHECKING, Any
import aiofiles
+from aiohttp.client_exceptions import ClientError
from music_assistant_models.config_entries import ConfigEntry, ConfigValueType
from music_assistant_models.enums import (
AlbumType,
return StreamDetails(
item_id=item_id,
provider=self.lookup_key,
- audio_format=AudioFormat(
- content_type=ContentType.UNKNOWN,
- ),
+ audio_format=AudioFormat(content_type=ContentType.M4A, codec_type=ContentType.AAC),
stream_type=StreamType.ENCRYPTED_HTTP,
- path=stream_url,
decryption_key=await self._get_decryption_key(license_url, key_id, uri, item_id),
+ path=stream_url,
can_seek=True,
allow_seek=True,
+ # enforce caching because the apple streams are m4a files with moov atom at the end
+ enable_cache=True,
)
def _parse_artist(self, artist_obj):
data = {
"salableAdamId": song_id,
}
- async with self.mass.http_session.post(
- playback_url, headers=self._get_decryption_headers(), json=data, ssl=True
- ) as response:
- response.raise_for_status()
- content = await response.json(loads=json_loads)
- return content["songList"][0]
+ for retry in (True, False):
+ try:
+ async with self.mass.http_session.post(
+ playback_url, headers=self._get_decryption_headers(), json=data, ssl=True
+ ) as response:
+ response.raise_for_status()
+ content = await response.json(loads=json_loads)
+ if content.get("failureType"):
+ message = content.get("failureMessage")
+ raise MediaNotFoundError(f"Failed to get song stream metadata: {message}")
+ return content["songList"][0]
+ except (MediaNotFoundError, ClientError) as exc:
+ if retry:
+ self.logger.warning("Failed to get song stream metadata: %s", exc)
+ continue
+ raise
+ raise MediaNotFoundError(f"Failed to get song stream metadata for {song_id}")
async def _parse_stream_url_and_uri(self, stream_assets: list[dict]) -> str:
"""Parse the Stream URL and Key URI from the song."""
}
async def _get_decryption_key(
- self, license_url: str, key_id: str, uri: str, item_id: str
+ self, license_url: str, key_id: bytes, uri: str, item_id: str
) -> str:
"""Get the decryption key for a song."""
cache_key = f"decryption_key.{item_id}"