A collection of small tweaks and bugfixes (#1603)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Sat, 24 Aug 2024 17:29:35 +0000 (19:29 +0200)
committerGitHub <noreply@github.com>
Sat, 24 Aug 2024 17:29:35 +0000 (19:29 +0200)
13 files changed:
music_assistant/common/models/queue_item.py
music_assistant/server/controllers/cache.py
music_assistant/server/controllers/media/albums.py
music_assistant/server/controllers/media/artists.py
music_assistant/server/controllers/media/playlists.py
music_assistant/server/controllers/media/radio.py
music_assistant/server/controllers/media/tracks.py
music_assistant/server/controllers/music.py
music_assistant/server/controllers/player_queues.py
music_assistant/server/helpers/audio.py
music_assistant/server/helpers/database.py
music_assistant/server/helpers/throttle_retry.py
music_assistant/server/providers/filesystem_local/helpers.py

index 3c416342f092956978daeae50a6758234b50b57c..149dfeefcbb2de0d67d769fd45e7a57512ef4519 100644 (file)
@@ -30,7 +30,7 @@ class QueueItem(DataClassDictMixin):
 
     def __post_init__(self) -> None:
         """Set default values."""
-        if self.streamdetails and self.streamdetails.stream_title:
+        if not self.name and self.streamdetails and self.streamdetails.stream_title:
             self.name = self.streamdetails.stream_title
         if not self.name:
             self.name = self.uri
@@ -42,6 +42,8 @@ class QueueItem(DataClassDictMixin):
             streamdetails.pop("data", None)
             streamdetails.pop("direct", None)
             streamdetails.pop("expires", None)
+            streamdetails.pop("path", None)
+            streamdetails.pop("decryption_key", None)
         return d
 
     @property
index 8c4e22bc5f3a7b72d83af4cf68c75cd2c5cabb8f..bf27f7f409e0fdfae9f7820ed84d89ec27da53bb 100644 (file)
@@ -143,7 +143,7 @@ class CacheController(CoreController):
             # do not cache items in db with short expiration
             return
         data = await asyncio.to_thread(json_dumps, data)
-        await self.database.insert(
+        await self.database.insert_or_replace(
             DB_TABLE_CACHE,
             {
                 "category": category,
@@ -153,7 +153,6 @@ class CacheController(CoreController):
                 "checksum": checksum,
                 "data": data,
             },
-            allow_replace=True,
         )
 
     async def delete(
index 6d3454eb97ee03831fbf4e10e7e36a2a3b5dbc32..6a41fcc9e18fedfba4c4d146fb29ab1f18511d4a 100644 (file)
@@ -298,7 +298,7 @@ class AlbumsController(MediaControllerBase[Album]):
         if not item.artists:
             msg = "Album is missing artist(s)"
             raise InvalidDataError(msg)
-        new_item = await self.mass.music.database.insert(
+        db_id = await self.mass.music.database.insert(
             self.db_table,
             {
                 "name": item.name,
@@ -311,7 +311,6 @@ class AlbumsController(MediaControllerBase[Album]):
                 "external_ids": serialize_to_json(item.external_ids),
             },
         )
-        db_id = new_item["item_id"]
         # update/set provider_mappings table
         await self._set_provider_mappings(db_id, item.provider_mappings)
         # set track artist(s)
index b6ba3fdc02b2d6574c497c1105f44c5a24c8c457..1e04fe2b01d6230679003d3e88118a4342d111fc 100644 (file)
@@ -347,7 +347,7 @@ class ArtistsController(MediaControllerBase[Artist]):
         if item.mbid == VARIOUS_ARTISTS_MBID:
             item.name = VARIOUS_ARTISTS_NAME
         # no existing item matched: insert item
-        new_item = await self.mass.music.database.insert(
+        db_id = await self.mass.music.database.insert(
             self.db_table,
             {
                 "name": item.name,
@@ -357,7 +357,6 @@ class ArtistsController(MediaControllerBase[Artist]):
                 "metadata": serialize_to_json(item.metadata),
             },
         )
-        db_id = new_item["item_id"]
         # update/set provider_mappings table
         await self._set_provider_mappings(db_id, item.provider_mappings)
         self.logger.debug("added %s to database (id: %s)", item.name, db_id)
index 5fbf96260ff10160a54679802c2c2472c0293a32..4d8779c991f9f3804917aa84cad066c2703c9aba 100644 (file)
@@ -281,7 +281,7 @@ class PlaylistController(MediaControllerBase[Playlist]):
 
     async def _add_library_item(self, item: Playlist) -> int:
         """Add a new record to the database."""
-        new_item = await self.mass.music.database.insert(
+        db_id = await self.mass.music.database.insert(
             self.db_table,
             {
                 "name": item.name,
@@ -294,7 +294,6 @@ class PlaylistController(MediaControllerBase[Playlist]):
                 "cache_checksum": item.cache_checksum,
             },
         )
-        db_id = new_item["item_id"]
         # update/set provider_mappings table
         await self._set_provider_mappings(db_id, item.provider_mappings)
         self.logger.debug("added %s to database (id: %s)", item.name, db_id)
index 6074ffaaa2f5427d32142d060c4e3515491b7d80..42df326ecb4bdc3a111e17be495aa876939b3e49 100644 (file)
@@ -55,7 +55,7 @@ class RadioController(MediaControllerBase[Radio]):
 
     async def _add_library_item(self, item: Radio) -> int:
         """Add a new item record to the database."""
-        new_item = await self.mass.music.database.insert(
+        db_id = await self.mass.music.database.insert(
             self.db_table,
             {
                 "name": item.name,
@@ -65,7 +65,6 @@ class RadioController(MediaControllerBase[Radio]):
                 "external_ids": serialize_to_json(item.external_ids),
             },
         )
-        db_id = new_item["item_id"]
         # update/set provider_mappings table
         await self._set_provider_mappings(db_id, item.provider_mappings)
         self.logger.debug("added %s to database (id: %s)", item.name, db_id)
index 377515f4ce4404214a3e3af45a944f7d343c2908..e4b7d853a6f46248826ecc9aa4fc422395b4f87a 100644 (file)
@@ -411,7 +411,7 @@ class TracksController(MediaControllerBase[Track]):
         if not item.artists:
             msg = "Track is missing artist(s)"
             raise InvalidDataError(msg)
-        new_item = await self.mass.music.database.insert(
+        db_id = await self.mass.music.database.insert(
             self.db_table,
             {
                 "name": item.name,
@@ -423,7 +423,6 @@ class TracksController(MediaControllerBase[Track]):
                 "metadata": serialize_to_json(item.metadata),
             },
         )
-        db_id = new_item["item_id"]
         # update/set provider_mappings table
         await self._set_provider_mappings(db_id, item.provider_mappings)
         # set track artist(s)
index 1a8717eebffa6cb10202d2984b5145ade9edf55a..58a13f3353ed1b5210f0439b1bbf5c31d68db66e 100644 (file)
@@ -640,12 +640,27 @@ class MusicController(CoreController):
         # fetch full (provider) item
         media_item = await ctrl.get_provider_item(item_id, provider, force_refresh=True)
         # update library item if needed (including refresh of the metadata etc.)
-        if library_id is not None:
-            library_item = await ctrl.update_item_in_library(library_id, media_item, overwrite=True)
-            await self.mass.metadata.update_metadata(library_item, force_refresh=True)
-            return library_item
-
-        return media_item
+        if library_id is None:
+            return media_item
+        library_item = await ctrl.update_item_in_library(library_id, media_item, overwrite=True)
+        if library_item.media_type == MediaType.ALBUM:
+            # update (local) album tracks
+            for album_track in await self.albums.tracks(
+                library_item.item_id, library_item.provider, True
+            ):
+                for prov_mapping in album_track.provider_mappings:
+                    if not (prov := self.mass.get_provider(prov_mapping.provider_instance)):
+                        continue
+                    if prov.is_streaming_provider:
+                        continue
+                    with suppress(MediaNotFoundError):
+                        prov_track = await prov.get_track(prov_mapping.item_id)
+                        await self.mass.music.tracks.update_item_in_library(
+                            album_track.item_id, prov_track
+                        )
+
+        await self.mass.metadata.update_metadata(library_item, force_refresh=True)
+        return library_item
 
     async def set_track_loudness(
         self, item_id: str, provider_instance_id_or_domain: str, loudness: LoudnessMeasurement
index 025202aecc89d17cff294d34bf47e440d7a5d581..ff66ac51a4cf260f1bcc8f0b50979afdd58751ed 100644 (file)
@@ -30,10 +30,11 @@ from music_assistant.common.models.errors import (
     PlayerUnavailableError,
     QueueEmpty,
 )
-from music_assistant.common.models.media_items import MediaItemType, media_from_dict
+from music_assistant.common.models.media_items import AudioFormat, MediaItemType, media_from_dict
 from music_assistant.common.models.player import PlayerMedia
 from music_assistant.common.models.player_queue import PlayerQueue
 from music_assistant.common.models.queue_item import QueueItem
+from music_assistant.common.models.streamdetails import StreamDetails
 from music_assistant.constants import CONF_FLOW_MODE, FALLBACK_DURATION, MASS_LOGO_ONLINE
 from music_assistant.server.helpers.api import api_command
 from music_assistant.server.helpers.audio import get_stream_details
@@ -966,23 +967,37 @@ class PlayerQueuesController(CoreController):
             cur_index = current_item_id_or_index
         idx = 0
         while True:
+            next_item: QueueItem | None = None
             next_index = self._get_next_index(queue_id, cur_index + idx, allow_repeat=allow_repeat)
             if next_index is None:
                 raise QueueEmpty("No more tracks left in the queue.")
-            next_item = self.get_item(queue_id, next_index)
+            queue_item = self.get_item(queue_id, next_index)
             try:
                 # Check if the QueueItem is playable. For example, YT Music returns Radio Items
                 # that are not playable which will stop playback.
-                next_item.streamdetails = await get_stream_details(
-                    mass=self.mass, queue_item=next_item
+                queue_item.streamdetails = await get_stream_details(
+                    mass=self.mass, queue_item=queue_item
                 )
-                # Lazy load the full MediaItem for the QueueItem, making sure to get the
+                # Preload the full MediaItem for the QueueItem, making sure to get the
                 # maximum quality of thumbs
-                next_item.media_item = await self.mass.music.get_item_by_uri(next_item.uri)
+                if queue_item.media_item:
+                    queue_item.media_item = await self.mass.music.get_item_by_uri(queue_item.uri)
+                # we're all set, this is our next item
+                next_item = queue_item
                 break
             except MediaNotFoundError:
                 # No stream details found, skip this QueueItem
-                next_item = None
+                self.logger.debug("Skipping unplayable item: %s", next_item)
+                # we need to set a fake streamdetails object on the item
+                # otherwise our flow mode logic will break that
+                # calculates where we are in the queue
+                queue_item.streamdetails = StreamDetails(
+                    provider=queue_item.media_item.provider if queue_item.media_item else "unknown",
+                    item_id=queue_item.media_item.item_id if queue_item.media_item else "unknown",
+                    audio_format=AudioFormat(),
+                    media_type=queue_item.media_type,
+                    seconds_streamed=0,
+                )
                 idx += 1
         if next_item is None:
             raise QueueEmpty("No more (playable) tracks left in the queue.")
index ec2d3f67afc1cf0881bff3e88b9b8d59c9a601b0..41943ca96f1acf92158d64becbed5e92e06dc652 100644 (file)
@@ -49,6 +49,7 @@ from music_assistant.server.helpers.playlists import (
     fetch_playlist,
     parse_m3u,
 )
+from music_assistant.server.helpers.throttle_retry import BYPASS_THROTTLER
 
 from .process import AsyncProcess, check_output, communicate
 from .util import create_tempfile
@@ -334,50 +335,47 @@ async def get_stream_details(
     if seek_position and (queue_item.media_type == MediaType.RADIO or not queue_item.duration):
         LOGGER.warning("seeking is not possible on duration-less streams!")
         seek_position = 0
-    if queue_item.streamdetails and seek_position:
-        LOGGER.debug(f"Using (pre)cached streamdetails from queue_item for {queue_item.uri}")
-        # we already have (fresh?) streamdetails stored on the queueitem, use these.
-        # only do this when we're seeking.
-        # we create a copy (using to/from dict) to ensure the one-time values are cleared
-        streamdetails = StreamDetails.from_dict(queue_item.streamdetails.to_dict())
-    else:
-        # always request the full item as there might be other qualities available
-        full_item = await mass.music.get_item_by_uri(queue_item.uri)
-        # sort by quality and check track availability
-        for prov_media in sorted(
-            full_item.provider_mappings, key=lambda x: x.quality or 0, reverse=True
-        ):
-            if not prov_media.available:
-                LOGGER.debug(f"Skipping unavailable {prov_media}")
-                continue
-            # guard that provider is available
-            music_prov = mass.get_provider(prov_media.provider_instance)
-            if not music_prov:
-                LOGGER.debug(f"Skipping {prov_media} - provider not available")
-                continue  # provider not available ?
-            # get streamdetails from provider
-            try:
-                streamdetails: StreamDetails = await music_prov.get_stream_details(
-                    prov_media.item_id
-                )
-            except MusicAssistantError as err:
-                LOGGER.warning(str(err))
-            else:
-                break
+    # we use a contextvar to bypass the throttler for this asyncio task/context
+    # this makes sure that playback has priority over other requests that may be
+    # happening in the background
+    BYPASS_THROTTLER.set(True)
+    # always request the full item as there might be other qualities available
+    full_item = await mass.music.get_item_by_uri(queue_item.uri)
+    # sort by quality and check track availability
+    for prov_media in sorted(
+        full_item.provider_mappings, key=lambda x: x.quality or 0, reverse=True
+    ):
+        if not prov_media.available:
+            LOGGER.debug(f"Skipping unavailable {prov_media}")
+            continue
+        # guard that provider is available
+        music_prov = mass.get_provider(prov_media.provider_instance)
+        if not music_prov:
+            LOGGER.debug(f"Skipping {prov_media} - provider not available")
+            continue  # provider not available ?
+        # get streamdetails from provider
+        try:
+            streamdetails: StreamDetails = await music_prov.get_stream_details(prov_media.item_id)
+        except MusicAssistantError as err:
+            LOGGER.warning(str(err))
         else:
-            raise MediaNotFoundError(f"Unable to retrieve streamdetails for {queue_item}")
+            break
+    else:
+        raise MediaNotFoundError(
+            f"Unable to retrieve streamdetails for {queue_item.name} ({queue_item.uri})"
+        )
 
-        # work out how to handle radio stream
-        if (
-            streamdetails.media_type in (MediaType.RADIO, StreamType.ICY, StreamType.HLS)
-            and streamdetails.stream_type == StreamType.HTTP
-        ):
-            resolved_url, is_icy, is_hls = await resolve_radio_stream(mass, streamdetails.path)
-            streamdetails.path = resolved_url
-            if is_hls:
-                streamdetails.stream_type = StreamType.HLS
-            elif is_icy:
-                streamdetails.stream_type = StreamType.ICY
+    # work out how to handle radio stream
+    if (
+        streamdetails.media_type in (MediaType.RADIO, StreamType.ICY, StreamType.HLS)
+        and streamdetails.stream_type == StreamType.HTTP
+    ):
+        resolved_url, is_icy, is_hls = await resolve_radio_stream(mass, streamdetails.path)
+        streamdetails.path = resolved_url
+        if is_hls:
+            streamdetails.stream_type = StreamType.HLS
+        elif is_icy:
+            streamdetails.stream_type = StreamType.ICY
     # set queue_id on the streamdetails so we know what is being streamed
     streamdetails.queue_id = queue_item.queue_id
     # handle skip/fade_in details
index 3c5985063bd69f61d48535bd51f73236d292ab16..7cada57137250678845660793f58a11d86fe8c81 100644 (file)
@@ -168,7 +168,7 @@ class DatabaseConnection:
         table: str,
         values: dict[str, Any],
         allow_replace: bool = False,
-    ) -> Mapping:
+    ) -> int:
         """Insert data in given table."""
         keys = tuple(values.keys())
         if allow_replace:
@@ -176,11 +176,9 @@ class DatabaseConnection:
         else:
             sql_query = f'INSERT INTO {table}({",".join(keys)})'
         sql_query += f' VALUES ({",".join(f":{x}" for x in keys)})'
-        await self.execute(sql_query, values)
+        row_id = await self._db.execute_insert(sql_query, values)
         await self._db.commit()
-        # return inserted/replaced item
-        lookup_vals = {key: value for key, value in values.items() if value not in (None, "")}
-        return await self.get_row(table, lookup_vals)
+        return row_id[0]
 
     async def insert_or_replace(self, table: str, values: dict[str, Any]) -> Mapping:
         """Insert or replace data in given table."""
index 149b41298116b6efed5207e47868625166e724a5..ce33b6f0e6eb4d7a9bcfd2956aa507f098bdb13a 100644 (file)
@@ -5,7 +5,9 @@ import functools
 import logging
 import time
 from collections import deque
-from collections.abc import Awaitable, Callable, Coroutine
+from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
+from contextlib import asynccontextmanager
+from contextvars import ContextVar
 from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
 
 from music_assistant.common.models.errors import ResourceTemporarilyUnavailable, RetriesExhausted
@@ -19,6 +21,8 @@ _R = TypeVar("_R")
 _P = ParamSpec("_P")
 LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.throttle_retry")
 
+BYPASS_THROTTLER: ContextVar[bool] = ContextVar("BYPASS_THROTTLER", default=False)
+
 
 class Throttler:
     """asyncio_throttle (https://github.com/hallazzang/asyncio-throttle).
@@ -32,7 +36,6 @@ class Throttler:
         """Initialize the Throttler."""
         self.rate_limit = rate_limit
         self.period = period
-
         self._task_logs: deque[float] = deque()
 
     def _flush(self):
@@ -43,14 +46,14 @@ class Throttler:
             else:
                 break
 
-    async def _acquire(self):
+    async def acquire(self) -> float:
+        """Acquire a free slot from the Throttler, returns the throttled time."""
         cur_time = time.monotonic()
         start_time = cur_time
         while True:
             self._flush()
             if len(self._task_logs) < self.rate_limit:
                 break
-
             # sleep the exact amount of time until the oldest task can be flushed
             time_to_release = self._task_logs[0] + self.period - cur_time
             await asyncio.sleep(time_to_release)
@@ -59,47 +62,39 @@ class Throttler:
         self._task_logs.append(cur_time)
         return cur_time - start_time  # exactly 0 if not throttled
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> float:
         """Wait until the lock is acquired, return the time delay."""
-        return await self._acquire()
+        return await self.acquire()
 
     async def __aexit__(self, exc_type, exc, tb):
         """Nothing to do on exit."""
 
 
-class ThrottlerManager(Throttler):
+class ThrottlerManager:
     """Throttler manager that extends asyncio Throttle by retrying."""
 
     def __init__(self, rate_limit: int, period: float = 1, retry_attempts=5, initial_backoff=5):
         """Initialize the AsyncThrottledContextManager."""
-        super().__init__(rate_limit=rate_limit, period=period)
         self.retry_attempts = retry_attempts
         self.initial_backoff = initial_backoff
-
-    async def wrap(
-        self,
-        func: Callable[_P, Awaitable[_R]],
-        *args: _P.args,
-        **kwargs: _P.kwargs,
-    ):
-        """Async function wrapper with retry logic."""
-        backoff_time = self.initial_backoff
-        for attempt in range(self.retry_attempts):
-            try:
-                async with self:
-                    return await func(self, *args, **kwargs)
-            except ResourceTemporarilyUnavailable as e:
-                if e.backoff_time:
-                    backoff_time = e.backoff_time
-                level = logging.DEBUG if attempt > 1 else logging.INFO
-                LOGGER.log(level, f"Attempt {attempt + 1}/{self.retry_attempts} failed: {e}")
-                if attempt < self.retry_attempts - 1:
-                    LOGGER.log(level, f"Retrying in {backoff_time} seconds...")
-                    await asyncio.sleep(backoff_time)
-                    backoff_time *= 2
-        else:  # noqa: PLW0120
-            msg = f"Retries exhausted, failed after {self.retry_attempts} attempts"
-            raise RetriesExhausted(msg)
+        self.throttler = Throttler(rate_limit, period)
+
+    @asynccontextmanager
+    async def acquire(self) -> AsyncGenerator[None, float]:
+        """Acquire a free slot from the Throttler, returns the throttled time."""
+        if BYPASS_THROTTLER.get():
+            yield 0
+        else:
+            yield await self.throttler.acquire()
+
+    @asynccontextmanager
+    async def bypass(self) -> AsyncGenerator[None, None]:
+        """Bypass the throttler."""
+        try:
+            token = BYPASS_THROTTLER.set(True)
+            yield None
+        finally:
+            BYPASS_THROTTLER.reset(token)
 
 
 def throttle_with_retries(
@@ -111,14 +106,13 @@ def throttle_with_retries(
     async def wrapper(self: _ProviderT, *args: _P.args, **kwargs: _P.kwargs) -> _R | None:
         """Call async function using the throttler with retries."""
         # the trottler attribute must be present on the class
-        throttler = self.throttler
+        throttler: ThrottlerManager = self.throttler
         backoff_time = throttler.initial_backoff
-        async with throttler as delay:
+        async with throttler.acquire() as delay:
             if delay != 0:
                 self.logger.debug(
                     "%s was delayed for %.3f secs due to throttling", func.__name__, delay
                 )
-
             for attempt in range(throttler.retry_attempts):
                 try:
                     return await func(self, *args, **kwargs)
index 4573de7c894ac8a28f27ce7d2730a32882dd0d54..bb575597c73ee4d1c60cd09fc70a8ce0e40499c9 100644 (file)
@@ -11,13 +11,16 @@ def get_artist_dir(album_or_track_dir: str, artist_name: str) -> str | None:
     """Look for (Album)Artist directory in path of a track (or album)."""
     parentdir = os.path.dirname(album_or_track_dir)
     # account for disc or album sublevel by ignoring (max) 2 levels if needed
+    matched_dir: str | None = None
     for _ in range(3):
         dirname = parentdir.rsplit(os.sep)[-1]
         if compare_strings(artist_name, dirname, False):
             # literal match
-            return parentdir
+            # we keep hunting further down to account for the
+            # edge case where the album name has the same name as the artist
+            matched_dir = parentdir
         parentdir = os.path.dirname(parentdir)
-    return None
+    return matched_dir
 
 
 def get_album_dir(track_dir: str, album_name: str) -> str | None: