From 51be73d56da90e14baa2c66294ab52874f0627bc Mon Sep 17 00:00:00 2001 From: Jozef Kruszynski <60214390+jozefKruszynski@users.noreply.github.com> Date: Sun, 1 Feb 2026 12:00:29 +0100 Subject: [PATCH] Add update_provider_mapping function (#3037) --- music_assistant/controllers/media/base.py | 78 +++++++- music_assistant/controllers/music.py | 31 ++- music_assistant/providers/tidal/streaming.py | 70 ++++++- tests/providers/tidal/test_streaming.py | 195 ++++++++++++++++++- 4 files changed, 365 insertions(+), 9 deletions(-) diff --git a/music_assistant/controllers/media/base.py b/music_assistant/controllers/media/base.py index 2a895d85..737deec9 100644 --- a/music_assistant/controllers/media/base.py +++ b/music_assistant/controllers/media/base.py @@ -16,11 +16,18 @@ from music_assistant_models.errors import ( MediaNotFoundError, ProviderUnavailableError, ) -from music_assistant_models.media_items import ItemMapping, MediaItemType, ProviderMapping, Track +from music_assistant_models.media_items import ( + AudioFormat, + ItemMapping, + MediaItemType, + ProviderMapping, + Track, +) from music_assistant.constants import DB_TABLE_PLAYLOG, DB_TABLE_PROVIDER_MAPPINGS, MASS_LOGGER_NAME from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user from music_assistant.helpers.compare import compare_media_item, create_safe_string +from music_assistant.helpers.database import UNSET from music_assistant.helpers.json import json_loads, serialize_to_json from music_assistant.helpers.util import guard_single_request @@ -614,6 +621,75 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta): await self.set_provider_mappings(db_id, library_item.provider_mappings) self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, library_item.uri, library_item) + @final + async def update_provider_mapping( + self, + item_id: str | int, + provider_instance_id: str, + provider_item_id: str, + *, + available: bool | Any = UNSET, + in_library: bool | Any = UNSET, + is_unique: bool | None | Any = UNSET, + url: str | None | Any = UNSET, + details: str | None | Any = UNSET, + audio_format: AudioFormat | Any = UNSET, + ) -> None: + """Update an existing provider mapping for a library item.""" + db_id = int(item_id) # ensure integer + library_item = await self.get_library_item(db_id) + + # find the current mapping (strictly by provider instance + provider item id) + cur_mapping: ProviderMapping | None = None + for mapping in library_item.provider_mappings: + if ( + mapping.provider_instance == provider_instance_id + and mapping.item_id == provider_item_id + ): + cur_mapping = mapping + break + if cur_mapping is None: + msg = ( + f"Provider mapping {provider_instance_id}/{provider_item_id} " + f"not found for item {db_id}" + ) + raise MediaNotFoundError(msg) + + # guard against nulls for NOT NULL columns + if available is None: + available = UNSET + if in_library is None: + in_library = UNSET + + updates: dict[str, Any] = {} + if available is not UNSET: + updates["available"] = bool(available) + if in_library is not UNSET: + updates["in_library"] = bool(in_library) + if is_unique is not UNSET: + updates["is_unique"] = is_unique + if url is not UNSET: + updates["url"] = url + if details is not UNSET: + updates["details"] = details + if audio_format is not UNSET: + updates["audio_format"] = serialize_to_json(audio_format) + + if not updates: + return + + match = { + "media_type": self.media_type.value, + "item_id": db_id, + "provider_instance": provider_instance_id, + "provider_item_id": provider_item_id, + } + await self.mass.music.database.update(DB_TABLE_PROVIDER_MAPPINGS, match, updates) + + # Re-fetch the updated item so the event payload reflects persisted DB state. + updated_item = await self.get_library_item(db_id) + self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, updated_item.uri, updated_item) + @final async def remove_provider_mapping( self, item_id: str | int, provider_instance_id: str, provider_item_id: str diff --git a/music_assistant/controllers/music.py b/music_assistant/controllers/music.py index 1677f054..c8149b0a 100644 --- a/music_assistant/controllers/music.py +++ b/music_assistant/controllers/music.py @@ -33,6 +33,7 @@ from music_assistant_models.errors import ( from music_assistant_models.helpers import get_global_cache_value from music_assistant_models.media_items import ( Artist, + AudioFormat, BrowseFolder, ItemMapping, MediaItemType, @@ -66,7 +67,7 @@ from music_assistant.controllers.streams.smart_fades.fades import SMART_CROSSFAD from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user from music_assistant.helpers.api import api_command from music_assistant.helpers.compare import compare_strings, compare_version, create_safe_string -from music_assistant.helpers.database import DatabaseConnection +from music_assistant.helpers.database import UNSET, DatabaseConnection from music_assistant.helpers.datetime import utc_timestamp from music_assistant.helpers.json import json_dumps, json_loads, serialize_to_json from music_assistant.helpers.tags import split_artists @@ -1673,6 +1674,34 @@ class MusicController(CoreController): db_item = await ctrl.get_library_item(db_id) await ctrl.match_providers(db_item) + async def update_provider_mapping( + self, + media_type: MediaType, + db_id: str | int, + provider_instance_id: str, + provider_item_id: str, + *, + available: bool | Any = UNSET, + in_library: bool | Any = UNSET, + is_unique: bool | None | Any = UNSET, + url: str | None | Any = UNSET, + details: str | None | Any = UNSET, + audio_format: AudioFormat | Any = UNSET, + ) -> None: + """Update an existing provider mapping for a library item.""" + ctrl = self.get_controller(media_type) + await ctrl.update_provider_mapping( + item_id=db_id, + provider_instance_id=provider_instance_id, + provider_item_id=provider_item_id, + available=available, + in_library=in_library, + is_unique=is_unique, + url=url, + details=details, + audio_format=audio_format, + ) + async def _get_default_recommendations(self) -> list[RecommendationFolder]: """Return default recommendations.""" return [ diff --git a/music_assistant/providers/tidal/streaming.py b/music_assistant/providers/tidal/streaming.py index 3830f64a..f9233442 100644 --- a/music_assistant/providers/tidal/streaming.py +++ b/music_assistant/providers/tidal/streaming.py @@ -2,6 +2,7 @@ from __future__ import annotations +from sqlite3 import OperationalError from typing import TYPE_CHECKING from music_assistant_models.enums import ContentType, ExternalID, StreamType @@ -72,15 +73,25 @@ class TidalStreamingManager: else: content_type = ContentType.MP4 + resolved_audio_format = AudioFormat( + content_type=content_type, + sample_rate=stream_data.get("sampleRate", 44100), + bit_depth=stream_data.get("bitDepth", 16), + channels=2, + ) + + # Never block or fail playback on DB issues. + self.mass.create_task( + self._async_update_provider_mapping_audio_format( + provider_track_id=track.item_id, + resolved_audio_format=resolved_audio_format, + ) + ) + return StreamDetails( item_id=track.item_id, provider=self.provider.instance_id, - audio_format=AudioFormat( - content_type=content_type, - sample_rate=stream_data.get("sampleRate", 44100), - bit_depth=stream_data.get("bitDepth", 16), - channels=2, - ), + audio_format=resolved_audio_format, stream_type=StreamType.HTTP, duration=track.duration, path=url, @@ -88,6 +99,53 @@ class TidalStreamingManager: allow_seek=True, ) + async def _async_update_provider_mapping_audio_format( + self, + provider_track_id: str, + resolved_audio_format: AudioFormat, + ) -> None: + """Persist resolved audio format on the provider mapping (best-effort).""" + try: + lib_track = await self.mass.music.tracks.get_library_item_by_prov_id( + provider_track_id, self.provider.instance_id + ) + if not lib_track: + return + + cur_mapping = next( + ( + m + for m in lib_track.provider_mappings + if m.provider_instance == self.provider.instance_id + and m.item_id == provider_track_id + ), + None, + ) + if not cur_mapping or cur_mapping.audio_format == resolved_audio_format: + return + + await self.mass.music.tracks.update_provider_mapping( + item_id=lib_track.item_id, + provider_instance_id=self.provider.instance_id, + provider_item_id=provider_track_id, + audio_format=resolved_audio_format, + ) + except (MediaNotFoundError, OperationalError, AssertionError) as err: + self.provider.logger.debug( + "Failed to persist audio_format on provider mapping for Tidal track %s " + "(provider_instance=%s): %s", + provider_track_id, + self.provider.instance_id, + err, + ) + except Exception: + self.provider.logger.exception( + "Unexpected error while persisting audio_format on provider mapping for " + "Tidal track %s (provider_instance=%s)", + provider_track_id, + self.provider.instance_id, + ) + async def _get_track_by_isrc(self, item_id: str) -> Track | None: """Lookup track by ISRC with caching.""" # Check cache diff --git a/tests/providers/tidal/test_streaming.py b/tests/providers/tidal/test_streaming.py index d10618ba..9acecf2a 100644 --- a/tests/providers/tidal/test_streaming.py +++ b/tests/providers/tidal/test_streaming.py @@ -1,11 +1,14 @@ """Test Tidal Streaming Manager.""" +from collections.abc import Coroutine +from sqlite3 import OperationalError +from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock import pytest from music_assistant_models.enums import ContentType, ExternalID, StreamType from music_assistant_models.errors import MediaNotFoundError -from music_assistant_models.media_items import Track +from music_assistant_models.media_items import AudioFormat, Track from music_assistant.providers.tidal.streaming import TidalStreamingManager @@ -361,3 +364,193 @@ async def test_get_stream_details_with_isrc_fallback( assert stream_details.item_id == "123" assert stream_details.path == "https://example.com/stream.flac" + + +async def test_get_stream_details_schedules_background_mapping_update( + streaming_manager: TidalStreamingManager, + provider_mock: Mock, + mock_track: Mock, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Ensure get_stream_details schedules the background mapping update task.""" + provider_mock.get_track.return_value = mock_track + provider_mock.api.get.return_value = { + "urls": ["https://example.com/stream.flac"], + "audioQuality": "LOSSLESS", + "sampleRate": 44100, + "bitDepth": 16, + } + + created: list[tuple[str, AudioFormat]] = [] + + async def _fake_worker(provider_track_id: str, resolved_audio_format: AudioFormat) -> None: + created.append((provider_track_id, resolved_audio_format)) + + # Patch the worker method so we can validate the coroutine is created with expected args + monkeypatch.setattr( + streaming_manager, "_async_update_provider_mapping_audio_format", _fake_worker + ) + + captured_coros: list[Coroutine[Any, Any, None]] = [] + + def _fake_create_task(coro: Coroutine[Any, Any, None]) -> None: + # Don't schedule; just capture the coroutine so the test can await it. + captured_coros.append(coro) + + provider_mock.mass.create_task = _fake_create_task + + stream_details = await streaming_manager.get_stream_details("123") + + assert len(captured_coros) == 1 + + # Execute the captured coroutine (safe because we patched the worker) + await captured_coros[0] + + assert created == [("123", stream_details.audio_format)] + + +async def test_async_update_provider_mapping_audio_format_no_library_item( + streaming_manager: TidalStreamingManager, provider_mock: Mock +) -> None: + """Ensure no update occurs when no library item is found.""" + provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = None + provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock() + + await streaming_manager._async_update_provider_mapping_audio_format( + provider_track_id="123", + resolved_audio_format=AudioFormat( + content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16 + ), + ) + + provider_mock.mass.music.tracks.update_provider_mapping.assert_not_called() + + +async def test_async_update_provider_mapping_audio_format_no_mapping( + streaming_manager: TidalStreamingManager, provider_mock: Mock +) -> None: + """Ensure no update occurs when no provider mapping is found.""" + lib_track = Mock() + lib_track.item_id = 1 + lib_track.provider_mappings = set() + provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = lib_track + provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock() + + await streaming_manager._async_update_provider_mapping_audio_format( + provider_track_id="123", + resolved_audio_format=AudioFormat( + content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16 + ), + ) + + provider_mock.mass.music.tracks.update_provider_mapping.assert_not_called() + + +async def test_async_update_provider_mapping_audio_format_same_format_no_update( + streaming_manager: TidalStreamingManager, provider_mock: Mock +) -> None: + """Ensure no update occurs when the audio format is unchanged.""" + fmt = AudioFormat(content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16) + mapping = Mock() + mapping.provider_instance = provider_mock.instance_id + mapping.item_id = "123" + mapping.audio_format = fmt + + lib_track = Mock() + lib_track.item_id = 1 + lib_track.provider_mappings = {mapping} + provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = lib_track + provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock() + + await streaming_manager._async_update_provider_mapping_audio_format( + provider_track_id="123", + resolved_audio_format=fmt, + ) + + provider_mock.mass.music.tracks.update_provider_mapping.assert_not_called() + + +async def test_async_update_provider_mapping_audio_format_different_format_updates( + streaming_manager: TidalStreamingManager, provider_mock: Mock +) -> None: + """Ensure update occurs when the audio format is different.""" + old_fmt = AudioFormat(content_type=ContentType.MP4, sample_rate=44100, bit_depth=16) + new_fmt = AudioFormat(content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16) + + mapping = Mock() + mapping.provider_instance = provider_mock.instance_id + mapping.item_id = "123" + mapping.audio_format = old_fmt + + lib_track = Mock() + lib_track.item_id = 1 + lib_track.provider_mappings = {mapping} + provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = lib_track + provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock() + + await streaming_manager._async_update_provider_mapping_audio_format( + provider_track_id="123", + resolved_audio_format=new_fmt, + ) + + provider_mock.mass.music.tracks.update_provider_mapping.assert_awaited_once() + provider_mock.mass.music.tracks.update_provider_mapping.assert_awaited_with( + item_id=1, + provider_instance_id=provider_mock.instance_id, + provider_item_id="123", + audio_format=new_fmt, + ) + + +async def test_async_update_provider_mapping_audio_format_sqlite_operational_error_logs_debug( + streaming_manager: TidalStreamingManager, provider_mock: Mock +) -> None: + """Ensure OperationalError is logged at debug level.""" + provider_mock.logger = Mock() + provider_mock.mass.music.tracks.get_library_item_by_prov_id.side_effect = OperationalError( + "database is locked" + ) + + await streaming_manager._async_update_provider_mapping_audio_format( + provider_track_id="123", + resolved_audio_format=AudioFormat( + content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16 + ), + ) + + provider_mock.logger.debug.assert_called() + + +async def test_async_update_provider_mapping_audio_format_unexpected_error_logs_exception( + streaming_manager: TidalStreamingManager, provider_mock: Mock +) -> None: + """Ensure unexpected errors are logged at exception level.""" + provider_mock.logger = Mock() + + lib_track = Mock() + lib_track.item_id = 1 + lib_track.provider_mappings = set() + provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = lib_track + + # Force an unexpected error after resolving lib_track + provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock( + side_effect=RuntimeError("boom") + ) + + # Create a mapping that triggers the update path + mapping = Mock() + mapping.provider_instance = provider_mock.instance_id + mapping.item_id = "123" + mapping.audio_format = AudioFormat( + content_type=ContentType.MP4, sample_rate=44100, bit_depth=16 + ) + lib_track.provider_mappings = {mapping} + + await streaming_manager._async_update_provider_mapping_audio_format( + provider_track_id="123", + resolved_audio_format=AudioFormat( + content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16 + ), + ) + + provider_mock.logger.exception.assert_called() -- 2.34.1