from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
-from music_assistant_models.enums import MediaType, ProviderFeature, ProviderType
+from music_assistant_models.enums import MediaType, ProviderFeature
from music_assistant_models.errors import (
InvalidDataError,
MusicAssistantError,
provider_instance_id_or_domain: str,
limit: int = 25,
allow_lookup: bool = False,
+ preferred_provider_instances: list[str] | None = None,
) -> list[Track]:
- """Get a list of similar tracks for the given track."""
+ """
+ Get a list of similar tracks for the given track.
+
+ :param item_id: The item ID of the track.
+ :param provider_instance_id_or_domain: The provider instance ID or domain.
+ :param limit: Maximum number of similar tracks to return.
+ :param allow_lookup: Allow lookup on other providers if not found.
+ :param preferred_provider_instances: List of preferred provider instance IDs to use.
+ When provided, these providers will be tried first before falling back to others.
+ """
ref_item = await self.get(item_id, provider_instance_id_or_domain)
- for prov_mapping in ref_item.provider_mappings:
- prov = self.mass.get_provider(prov_mapping.provider_instance)
- if prov is None:
- continue
- if not isinstance(prov, MusicProvider):
- continue
- if ProviderFeature.SIMILAR_TRACKS not in prov.supported_features:
- continue
- # Grab similar tracks from the music provider
- return await prov.get_similar_tracks(prov_track_id=prov_mapping.item_id, limit=limit)
+
+ # Sort provider mappings to prefer user's provider instances
+ def sort_key(mapping: ProviderMapping) -> tuple[int, int]:
+ # Primary sort: preferred providers first (0), then others (1)
+ preferred = (
+ 0
+ if preferred_provider_instances
+ and mapping.provider_instance in preferred_provider_instances
+ else 1
+ )
+ # Secondary sort: by quality (higher is better, so negate)
+ quality = -(mapping.quality or 0)
+ return (preferred, quality)
+
+ sorted_mappings = sorted(ref_item.provider_mappings, key=sort_key)
+
+ # Try preferred providers first, then fall back to others
+ for allow_other_provider in (False, True):
+ for prov_mapping in sorted_mappings:
+ if (
+ not allow_other_provider
+ and preferred_provider_instances
+ and prov_mapping.provider_instance not in preferred_provider_instances
+ ):
+ continue
+ prov = self.mass.get_provider(prov_mapping.provider_instance)
+ if prov is None:
+ continue
+ if not isinstance(prov, MusicProvider):
+ continue
+ if ProviderFeature.SIMILAR_TRACKS not in prov.supported_features:
+ continue
+ # Grab similar tracks from the music provider
+ return await prov.get_similar_tracks(
+ prov_track_id=prov_mapping.item_id, limit=limit
+ )
+
if not allow_lookup:
return []
# check if we have any provider that supports dynamic tracks
# TODO: query metadata provider(s) (such as lastfm?)
# to get similar tracks (or tracks from similar artists)
- for prov in self.mass.get_providers(ProviderType.MUSIC):
+ music_prov: MusicProvider | None = None
+ for prov in self.mass.music.providers:
if ProviderFeature.SIMILAR_TRACKS in prov.supported_features:
+ music_prov = prov
break
- else:
+ if music_prov is None:
msg = "No Music Provider found that supports requesting similar tracks."
raise UnsupportedFeaturedException(msg)
- if ref_item.provider == "library":
- await self.mass.metadata.update_metadata(ref_item)
- else:
- await self.match_providers(ref_item)
+ if mappings := await self.match_provider(ref_item, music_prov):
+ if ref_item.provider == "library":
+ # update database with new provider mappings
+ await self.add_provider_mappings(ref_item.item_id, mappings)
+ ref_item.provider_mappings.update(mappings)
+ return await music_prov.get_similar_tracks(
+ prov_track_id=mappings[0].item_id, limit=limit
+ )
return []
async def match_provider(
self,
- db_track: Track,
+ base_track: Track,
provider: MusicProvider,
strict: bool = True,
ref_albums: list[Album] | None = None,
) -> list[ProviderMapping]:
"""
- Try to find match on (streaming) provider for the provided (database) track.
+ Try to find match on (streaming) provider for the provided track.
This is used to link objects of different providers/qualities together.
"""
if ref_albums is None:
- ref_albums = await self.albums(db_track.item_id, db_track.provider)
- self.logger.debug("Trying to match track %s on provider %s", db_track.name, provider.name)
+ ref_albums = await self.albums(base_track.item_id, base_track.provider)
+ self.logger.debug("Trying to match track %s on provider %s", base_track.name, provider.name)
matches: list[ProviderMapping] = []
- for artist in db_track.artists:
+ for artist in base_track.artists:
if matches:
break
- search_str = f"{artist.name} - {db_track.name}"
+ search_str = f"{artist.name} - {base_track.name}"
search_result = await self.search(search_str, provider.domain)
for search_result_item in search_result:
if not search_result_item.available:
continue
# do a basic compare first
- if not compare_media_item(db_track, search_result_item, strict=False):
+ if not compare_media_item(base_track, search_result_item, strict=False):
continue
# we must fetch the full version, search results can be simplified objects
prov_track = await self.get_provider_item(
search_result_item.provider,
fallback=search_result_item,
)
- if compare_track(db_track, prov_track, strict=strict, track_albums=ref_albums):
+ if compare_track(base_track, prov_track, strict=strict, track_albums=ref_albums):
matches.extend(search_result_item.provider_mappings)
if not matches:
self.logger.debug(
"Could not find match for Track %s on provider %s",
- db_track.name,
+ base_track.name,
provider.name,
)
return matches
async def match_providers(self, db_track: Track) -> None:
- """Try to find matching track on all providers for the provided (database) track_id.
+ """
+ Try to find matching track on all providers for the provided (database) track_id.
This is used to link objects of different providers/qualities together.
"""
async def radio_mode_base_tracks(
self,
- item_id: str,
- provider_instance_id_or_domain: str,
+ item: Track,
+ preferred_provider_instances: list[str] | None = None,
) -> list[Track]:
- """Get the list of base tracks from the controller used to calculate the dynamic radio."""
- return [await self.get(item_id, provider_instance_id_or_domain)]
+ """
+ Get the list of base tracks from the controller used to calculate the dynamic radio.
+
+ :param item: The Track to get base tracks for.
+ :param preferred_provider_instances: List of preferred provider instance IDs to use.
+ """
+ return [item]
async def _add_library_item(self, item: Track, overwrite_existing: bool = False) -> int:
"""Add a new item record to the database."""