From cd6c22e9195cf1c9f5252a4ee23863284dcd39b4 Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Mon, 15 Dec 2025 00:43:40 +0100 Subject: [PATCH] Use preferred provider steering also for radio mode --- music_assistant/controllers/media/albums.py | 13 +- music_assistant/controllers/media/artists.py | 15 ++- .../controllers/media/audiobooks.py | 14 ++- music_assistant/controllers/media/base.py | 12 +- music_assistant/controllers/media/genres.py | 12 +- .../controllers/media/playlists.py | 13 +- music_assistant/controllers/media/podcasts.py | 12 +- music_assistant/controllers/media/radio.py | 12 +- music_assistant/controllers/media/tracks.py | 113 +++++++++++++----- music_assistant/controllers/player_queues.py | 14 ++- 10 files changed, 164 insertions(+), 66 deletions(-) diff --git a/music_assistant/controllers/media/albums.py b/music_assistant/controllers/media/albums.py index 1a764788..a7a4f785 100644 --- a/music_assistant/controllers/media/albums.py +++ b/music_assistant/controllers/media/albums.py @@ -454,11 +454,16 @@ class AlbumsController(MediaControllerBase[Album]): async def radio_mode_base_tracks( self, - item_id: str, - provider_instance_id_or_domain: str, + item: Album, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get the list of base tracks from the controller used to calculate the dynamic radio.""" - return await self.tracks(item_id, provider_instance_id_or_domain, in_library_only=False) + """ + Get the list of base tracks from the controller used to calculate the dynamic radio. + + :param item: The Album to get base tracks for. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + """ + return await self.tracks(item.item_id, item.provider, in_library_only=False) async def _set_album_artists( self, diff --git a/music_assistant/controllers/media/artists.py b/music_assistant/controllers/media/artists.py index 4042d22b..d9acdd9e 100644 --- a/music_assistant/controllers/media/artists.py +++ b/music_assistant/controllers/media/artists.py @@ -368,13 +368,18 @@ class ArtistsController(MediaControllerBase[Artist]): async def radio_mode_base_tracks( self, - item_id: str, - provider_instance_id_or_domain: str, + item: Artist, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get the list of base tracks from the controller used to calculate the dynamic radio.""" + """ + Get the list of base tracks from the controller used to calculate the dynamic radio. + + :param item: The Artist to get base tracks for. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + """ return await self.tracks( - item_id, - provider_instance_id_or_domain, + item.item_id, + item.provider, in_library_only=False, ) diff --git a/music_assistant/controllers/media/audiobooks.py b/music_assistant/controllers/media/audiobooks.py index 5403c033..54d37657 100644 --- a/music_assistant/controllers/media/audiobooks.py +++ b/music_assistant/controllers/media/audiobooks.py @@ -204,12 +204,16 @@ class AudiobooksController(MediaControllerBase[Audiobook]): async def radio_mode_base_tracks( self, - item_id: str, - provider_instance_id_or_domain: str, - limit: int = 25, + item: Audiobook, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get the list of base tracks from the controller used to calculate the dynamic radio.""" - msg = "Dynamic tracks not supported for Radio MediaItem" + """ + Get the list of base tracks from the controller used to calculate the dynamic radio. + + :param item: The Audiobook to get base tracks for. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + """ + msg = "Dynamic tracks not supported for Audiobook MediaItem" raise NotImplementedError(msg) async def match_provider( diff --git a/music_assistant/controllers/media/base.py b/music_assistant/controllers/media/base.py index 2e44ecc9..a60d9f79 100644 --- a/music_assistant/controllers/media/base.py +++ b/music_assistant/controllers/media/base.py @@ -736,10 +736,16 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): @abstractmethod async def radio_mode_base_tracks( self, - item_id: str, - provider_instance_id_or_domain: str, + item: ItemCls, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get the list of base tracks from the controller used to calculate the dynamic radio.""" + """ + Get the list of base tracks from the controller used to calculate the dynamic radio. + + :param item: The MediaItem to get base tracks for. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + When provided, these providers will be tried first before falling back to others. + """ @final async def _get_library_items_by_query( diff --git a/music_assistant/controllers/media/genres.py b/music_assistant/controllers/media/genres.py index 0ea6f2f8..bb940dd7 100644 --- a/music_assistant/controllers/media/genres.py +++ b/music_assistant/controllers/media/genres.py @@ -48,11 +48,15 @@ class GenreController(MediaControllerBase[Genre]): async def radio_mode_base_tracks( self, - item_id: str, - provider_instance_id_or_domain: str, - limit: int = 25, + item: Genre, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get the list of base tracks from the controller - stub implementation.""" + """ + Get the list of base tracks from the controller - stub implementation. + + :param item: The Genre to get base tracks for. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + """ raise NotImplementedError("Genre support is not yet implemented") async def match_providers(self, db_item: Genre) -> None: diff --git a/music_assistant/controllers/media/playlists.py b/music_assistant/controllers/media/playlists.py index 80b27192..b4ec18a7 100644 --- a/music_assistant/controllers/media/playlists.py +++ b/music_assistant/controllers/media/playlists.py @@ -440,13 +440,18 @@ class PlaylistController(MediaControllerBase[Playlist]): async def radio_mode_base_tracks( self, - item_id: str, - provider_instance_id_or_domain: str, + item: Playlist, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get the list of base tracks from the controller used to calculate the dynamic radio.""" + """ + Get the list of base tracks from the controller used to calculate the dynamic radio. + + :param item: The Playlist to get base tracks for. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + """ return [ x - async for x in self.tracks(item_id, provider_instance_id_or_domain) + async for x in self.tracks(item.item_id, item.provider) # filter out unavailable tracks if x.available ] diff --git a/music_assistant/controllers/media/podcasts.py b/music_assistant/controllers/media/podcasts.py index d8138ee4..7530b958 100644 --- a/music_assistant/controllers/media/podcasts.py +++ b/music_assistant/controllers/media/podcasts.py @@ -243,11 +243,15 @@ class PodcastsController(MediaControllerBase[Podcast]): async def radio_mode_base_tracks( self, - item_id: str, - provider_instance_id_or_domain: str, - limit: int = 25, + item: Podcast, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get the list of base tracks from the controller used to calculate the dynamic radio.""" + """ + Get the list of base tracks from the controller used to calculate the dynamic radio. + + :param item: The Podcast to get base tracks for. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + """ msg = "Dynamic tracks not supported for Podcast MediaItem" raise NotImplementedError(msg) diff --git a/music_assistant/controllers/media/radio.py b/music_assistant/controllers/media/radio.py index cad66626..d5e902a8 100644 --- a/music_assistant/controllers/media/radio.py +++ b/music_assistant/controllers/media/radio.py @@ -124,11 +124,15 @@ class RadioController(MediaControllerBase[Radio]): async def radio_mode_base_tracks( self, - item_id: str, - provider_instance_id_or_domain: str, - limit: int = 25, + item: Radio, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get the list of base tracks from the controller used to calculate the dynamic radio.""" + """ + Get the list of base tracks from the controller used to calculate the dynamic radio. + + :param item: The Radio to get base tracks for. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + """ msg = "Dynamic tracks not supported for Radio MediaItem" raise NotImplementedError(msg) diff --git a/music_assistant/controllers/media/tracks.py b/music_assistant/controllers/media/tracks.py index 47bc97d6..7a9e555f 100644 --- a/music_assistant/controllers/media/tracks.py +++ b/music_assistant/controllers/media/tracks.py @@ -6,7 +6,7 @@ import urllib.parse from collections.abc import Iterable from typing import TYPE_CHECKING, Any -from music_assistant_models.enums import MediaType, ProviderFeature, ProviderType +from music_assistant_models.enums import MediaType, ProviderFeature from music_assistant_models.errors import ( InvalidDataError, MusicAssistantError, @@ -305,36 +305,79 @@ class TracksController(MediaControllerBase[Track]): provider_instance_id_or_domain: str, limit: int = 25, allow_lookup: bool = False, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get a list of similar tracks for the given track.""" + """ + Get a list of similar tracks for the given track. + + :param item_id: The item ID of the track. + :param provider_instance_id_or_domain: The provider instance ID or domain. + :param limit: Maximum number of similar tracks to return. + :param allow_lookup: Allow lookup on other providers if not found. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + When provided, these providers will be tried first before falling back to others. + """ ref_item = await self.get(item_id, provider_instance_id_or_domain) - for prov_mapping in ref_item.provider_mappings: - prov = self.mass.get_provider(prov_mapping.provider_instance) - if prov is None: - continue - if not isinstance(prov, MusicProvider): - continue - if ProviderFeature.SIMILAR_TRACKS not in prov.supported_features: - continue - # Grab similar tracks from the music provider - return await prov.get_similar_tracks(prov_track_id=prov_mapping.item_id, limit=limit) + + # Sort provider mappings to prefer user's provider instances + def sort_key(mapping: ProviderMapping) -> tuple[int, int]: + # Primary sort: preferred providers first (0), then others (1) + preferred = ( + 0 + if preferred_provider_instances + and mapping.provider_instance in preferred_provider_instances + else 1 + ) + # Secondary sort: by quality (higher is better, so negate) + quality = -(mapping.quality or 0) + return (preferred, quality) + + sorted_mappings = sorted(ref_item.provider_mappings, key=sort_key) + + # Try preferred providers first, then fall back to others + for allow_other_provider in (False, True): + for prov_mapping in sorted_mappings: + if ( + not allow_other_provider + and preferred_provider_instances + and prov_mapping.provider_instance not in preferred_provider_instances + ): + continue + prov = self.mass.get_provider(prov_mapping.provider_instance) + if prov is None: + continue + if not isinstance(prov, MusicProvider): + continue + if ProviderFeature.SIMILAR_TRACKS not in prov.supported_features: + continue + # Grab similar tracks from the music provider + return await prov.get_similar_tracks( + prov_track_id=prov_mapping.item_id, limit=limit + ) + if not allow_lookup: return [] # check if we have any provider that supports dynamic tracks # TODO: query metadata provider(s) (such as lastfm?) # to get similar tracks (or tracks from similar artists) - for prov in self.mass.get_providers(ProviderType.MUSIC): + music_prov: MusicProvider | None = None + for prov in self.mass.music.providers: if ProviderFeature.SIMILAR_TRACKS in prov.supported_features: + music_prov = prov break - else: + if music_prov is None: msg = "No Music Provider found that supports requesting similar tracks." raise UnsupportedFeaturedException(msg) - if ref_item.provider == "library": - await self.mass.metadata.update_metadata(ref_item) - else: - await self.match_providers(ref_item) + if mappings := await self.match_provider(ref_item, music_prov): + if ref_item.provider == "library": + # update database with new provider mappings + await self.add_provider_mappings(ref_item.item_id, mappings) + ref_item.provider_mappings.update(mappings) + return await music_prov.get_similar_tracks( + prov_track_id=mappings[0].item_id, limit=limit + ) return [] @@ -375,30 +418,30 @@ class TracksController(MediaControllerBase[Track]): async def match_provider( self, - db_track: Track, + base_track: Track, provider: MusicProvider, strict: bool = True, ref_albums: list[Album] | None = None, ) -> list[ProviderMapping]: """ - Try to find match on (streaming) provider for the provided (database) track. + Try to find match on (streaming) provider for the provided track. This is used to link objects of different providers/qualities together. """ if ref_albums is None: - ref_albums = await self.albums(db_track.item_id, db_track.provider) - self.logger.debug("Trying to match track %s on provider %s", db_track.name, provider.name) + ref_albums = await self.albums(base_track.item_id, base_track.provider) + self.logger.debug("Trying to match track %s on provider %s", base_track.name, provider.name) matches: list[ProviderMapping] = [] - for artist in db_track.artists: + for artist in base_track.artists: if matches: break - search_str = f"{artist.name} - {db_track.name}" + search_str = f"{artist.name} - {base_track.name}" search_result = await self.search(search_str, provider.domain) for search_result_item in search_result: if not search_result_item.available: continue # do a basic compare first - if not compare_media_item(db_track, search_result_item, strict=False): + if not compare_media_item(base_track, search_result_item, strict=False): continue # we must fetch the full version, search results can be simplified objects prov_track = await self.get_provider_item( @@ -406,19 +449,20 @@ class TracksController(MediaControllerBase[Track]): search_result_item.provider, fallback=search_result_item, ) - if compare_track(db_track, prov_track, strict=strict, track_albums=ref_albums): + if compare_track(base_track, prov_track, strict=strict, track_albums=ref_albums): matches.extend(search_result_item.provider_mappings) if not matches: self.logger.debug( "Could not find match for Track %s on provider %s", - db_track.name, + base_track.name, provider.name, ) return matches async def match_providers(self, db_track: Track) -> None: - """Try to find matching track on all providers for the provided (database) track_id. + """ + Try to find matching track on all providers for the provided (database) track_id. This is used to link objects of different providers/qualities together. """ @@ -447,11 +491,16 @@ class TracksController(MediaControllerBase[Track]): async def radio_mode_base_tracks( self, - item_id: str, - provider_instance_id_or_domain: str, + item: Track, + preferred_provider_instances: list[str] | None = None, ) -> list[Track]: - """Get the list of base tracks from the controller used to calculate the dynamic radio.""" - return [await self.get(item_id, provider_instance_id_or_domain)] + """ + Get the list of base tracks from the controller used to calculate the dynamic radio. + + :param item: The Track to get base tracks for. + :param preferred_provider_instances: List of preferred provider instance IDs to use. + """ + return [item] async def _add_library_item(self, item: Track, overwrite_existing: bool = False) -> int: """Add a new item record to the database.""" diff --git a/music_assistant/controllers/player_queues.py b/music_assistant/controllers/player_queues.py index a843709b..9499243a 100644 --- a/music_assistant/controllers/player_queues.py +++ b/music_assistant/controllers/player_queues.py @@ -1832,6 +1832,16 @@ class PlayerQueuesController(CoreController): queue.display_name, ", ".join([x.name for x in queue.radio_source]), ) + + # Get user's preferred provider instances for steering provider selection + preferred_provider_instances: list[str] | None = None + if ( + queue.userid + and (playback_user := await self.mass.webserver.auth.get_user(queue.userid)) + and playback_user.provider_filter + ): + preferred_provider_instances = playback_user.provider_filter + available_base_tracks: list[Track] = [] base_track_sample_size = 5 # Some providers have very deterministic similar track algorithms when providing @@ -1852,7 +1862,8 @@ class PlayerQueuesController(CoreController): available_base_tracks += [ track for track in await ctrl.radio_mode_base_tracks( - radio_item.item_id, radio_item.provider + radio_item, # type: ignore[arg-type] + preferred_provider_instances, ) # Avoid duplicate base tracks if track not in available_base_tracks @@ -1883,6 +1894,7 @@ class PlayerQueuesController(CoreController): base_track.item_id, base_track.provider, allow_lookup=allow_lookup, + preferred_provider_instances=preferred_provider_instances, ) except MediaNotFoundError: # Some providers don't have similar tracks for all items. For example, -- 2.34.1