Add update_provider_mapping function (#3037)
authorJozef Kruszynski <60214390+jozefKruszynski@users.noreply.github.com>
Sun, 1 Feb 2026 11:00:29 +0000 (12:00 +0100)
committerGitHub <noreply@github.com>
Sun, 1 Feb 2026 11:00:29 +0000 (12:00 +0100)
music_assistant/controllers/media/base.py
music_assistant/controllers/music.py
music_assistant/providers/tidal/streaming.py
tests/providers/tidal/test_streaming.py

index 2a895d85b99d74c088096b91126e8ca151abbd90..737deec95dfde15e77ba01620b6fca323624d381 100644 (file)
@@ -16,11 +16,18 @@ from music_assistant_models.errors import (
     MediaNotFoundError,
     ProviderUnavailableError,
 )
-from music_assistant_models.media_items import ItemMapping, MediaItemType, ProviderMapping, Track
+from music_assistant_models.media_items import (
+    AudioFormat,
+    ItemMapping,
+    MediaItemType,
+    ProviderMapping,
+    Track,
+)
 
 from music_assistant.constants import DB_TABLE_PLAYLOG, DB_TABLE_PROVIDER_MAPPINGS, MASS_LOGGER_NAME
 from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
 from music_assistant.helpers.compare import compare_media_item, create_safe_string
+from music_assistant.helpers.database import UNSET
 from music_assistant.helpers.json import json_loads, serialize_to_json
 from music_assistant.helpers.util import guard_single_request
 
@@ -614,6 +621,75 @@ class MediaControllerBase[ItemCls: "MediaItemType"](metaclass=ABCMeta):
         await self.set_provider_mappings(db_id, library_item.provider_mappings)
         self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, library_item.uri, library_item)
 
+    @final
+    async def update_provider_mapping(
+        self,
+        item_id: str | int,
+        provider_instance_id: str,
+        provider_item_id: str,
+        *,
+        available: bool | Any = UNSET,
+        in_library: bool | Any = UNSET,
+        is_unique: bool | None | Any = UNSET,
+        url: str | None | Any = UNSET,
+        details: str | None | Any = UNSET,
+        audio_format: AudioFormat | Any = UNSET,
+    ) -> None:
+        """Update an existing provider mapping for a library item."""
+        db_id = int(item_id)  # ensure integer
+        library_item = await self.get_library_item(db_id)
+
+        # find the current mapping (strictly by provider instance + provider item id)
+        cur_mapping: ProviderMapping | None = None
+        for mapping in library_item.provider_mappings:
+            if (
+                mapping.provider_instance == provider_instance_id
+                and mapping.item_id == provider_item_id
+            ):
+                cur_mapping = mapping
+                break
+        if cur_mapping is None:
+            msg = (
+                f"Provider mapping {provider_instance_id}/{provider_item_id} "
+                f"not found for item {db_id}"
+            )
+            raise MediaNotFoundError(msg)
+
+        # guard against nulls for NOT NULL columns
+        if available is None:
+            available = UNSET
+        if in_library is None:
+            in_library = UNSET
+
+        updates: dict[str, Any] = {}
+        if available is not UNSET:
+            updates["available"] = bool(available)
+        if in_library is not UNSET:
+            updates["in_library"] = bool(in_library)
+        if is_unique is not UNSET:
+            updates["is_unique"] = is_unique
+        if url is not UNSET:
+            updates["url"] = url
+        if details is not UNSET:
+            updates["details"] = details
+        if audio_format is not UNSET:
+            updates["audio_format"] = serialize_to_json(audio_format)
+
+        if not updates:
+            return
+
+        match = {
+            "media_type": self.media_type.value,
+            "item_id": db_id,
+            "provider_instance": provider_instance_id,
+            "provider_item_id": provider_item_id,
+        }
+        await self.mass.music.database.update(DB_TABLE_PROVIDER_MAPPINGS, match, updates)
+
+        # Re-fetch the updated item so the event payload reflects persisted DB state.
+        updated_item = await self.get_library_item(db_id)
+        self.mass.signal_event(EventType.MEDIA_ITEM_UPDATED, updated_item.uri, updated_item)
+
     @final
     async def remove_provider_mapping(
         self, item_id: str | int, provider_instance_id: str, provider_item_id: str
index 1677f054a54e7aad20b0715532a7862711ba3be3..c8149b0a2fc9bfa56530968cd07708f5068d3511 100644 (file)
@@ -33,6 +33,7 @@ from music_assistant_models.errors import (
 from music_assistant_models.helpers import get_global_cache_value
 from music_assistant_models.media_items import (
     Artist,
+    AudioFormat,
     BrowseFolder,
     ItemMapping,
     MediaItemType,
@@ -66,7 +67,7 @@ from music_assistant.controllers.streams.smart_fades.fades import SMART_CROSSFAD
 from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
 from music_assistant.helpers.api import api_command
 from music_assistant.helpers.compare import compare_strings, compare_version, create_safe_string
-from music_assistant.helpers.database import DatabaseConnection
+from music_assistant.helpers.database import UNSET, DatabaseConnection
 from music_assistant.helpers.datetime import utc_timestamp
 from music_assistant.helpers.json import json_dumps, json_loads, serialize_to_json
 from music_assistant.helpers.tags import split_artists
@@ -1673,6 +1674,34 @@ class MusicController(CoreController):
         db_item = await ctrl.get_library_item(db_id)
         await ctrl.match_providers(db_item)
 
+    async def update_provider_mapping(
+        self,
+        media_type: MediaType,
+        db_id: str | int,
+        provider_instance_id: str,
+        provider_item_id: str,
+        *,
+        available: bool | Any = UNSET,
+        in_library: bool | Any = UNSET,
+        is_unique: bool | None | Any = UNSET,
+        url: str | None | Any = UNSET,
+        details: str | None | Any = UNSET,
+        audio_format: AudioFormat | Any = UNSET,
+    ) -> None:
+        """Update an existing provider mapping for a library item."""
+        ctrl = self.get_controller(media_type)
+        await ctrl.update_provider_mapping(
+            item_id=db_id,
+            provider_instance_id=provider_instance_id,
+            provider_item_id=provider_item_id,
+            available=available,
+            in_library=in_library,
+            is_unique=is_unique,
+            url=url,
+            details=details,
+            audio_format=audio_format,
+        )
+
     async def _get_default_recommendations(self) -> list[RecommendationFolder]:
         """Return default recommendations."""
         return [
index 3830f64af0a957dda1cf370ffe533adf5b10fdd8..f9233442c4adb9f6144bc5ab7aee690a74274815 100644 (file)
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+from sqlite3 import OperationalError
 from typing import TYPE_CHECKING
 
 from music_assistant_models.enums import ContentType, ExternalID, StreamType
@@ -72,15 +73,25 @@ class TidalStreamingManager:
         else:
             content_type = ContentType.MP4
 
+        resolved_audio_format = AudioFormat(
+            content_type=content_type,
+            sample_rate=stream_data.get("sampleRate", 44100),
+            bit_depth=stream_data.get("bitDepth", 16),
+            channels=2,
+        )
+
+        # Never block or fail playback on DB issues.
+        self.mass.create_task(
+            self._async_update_provider_mapping_audio_format(
+                provider_track_id=track.item_id,
+                resolved_audio_format=resolved_audio_format,
+            )
+        )
+
         return StreamDetails(
             item_id=track.item_id,
             provider=self.provider.instance_id,
-            audio_format=AudioFormat(
-                content_type=content_type,
-                sample_rate=stream_data.get("sampleRate", 44100),
-                bit_depth=stream_data.get("bitDepth", 16),
-                channels=2,
-            ),
+            audio_format=resolved_audio_format,
             stream_type=StreamType.HTTP,
             duration=track.duration,
             path=url,
@@ -88,6 +99,53 @@ class TidalStreamingManager:
             allow_seek=True,
         )
 
+    async def _async_update_provider_mapping_audio_format(
+        self,
+        provider_track_id: str,
+        resolved_audio_format: AudioFormat,
+    ) -> None:
+        """Persist resolved audio format on the provider mapping (best-effort)."""
+        try:
+            lib_track = await self.mass.music.tracks.get_library_item_by_prov_id(
+                provider_track_id, self.provider.instance_id
+            )
+            if not lib_track:
+                return
+
+            cur_mapping = next(
+                (
+                    m
+                    for m in lib_track.provider_mappings
+                    if m.provider_instance == self.provider.instance_id
+                    and m.item_id == provider_track_id
+                ),
+                None,
+            )
+            if not cur_mapping or cur_mapping.audio_format == resolved_audio_format:
+                return
+
+            await self.mass.music.tracks.update_provider_mapping(
+                item_id=lib_track.item_id,
+                provider_instance_id=self.provider.instance_id,
+                provider_item_id=provider_track_id,
+                audio_format=resolved_audio_format,
+            )
+        except (MediaNotFoundError, OperationalError, AssertionError) as err:
+            self.provider.logger.debug(
+                "Failed to persist audio_format on provider mapping for Tidal track %s "
+                "(provider_instance=%s): %s",
+                provider_track_id,
+                self.provider.instance_id,
+                err,
+            )
+        except Exception:
+            self.provider.logger.exception(
+                "Unexpected error while persisting audio_format on provider mapping for "
+                "Tidal track %s (provider_instance=%s)",
+                provider_track_id,
+                self.provider.instance_id,
+            )
+
     async def _get_track_by_isrc(self, item_id: str) -> Track | None:
         """Lookup track by ISRC with caching."""
         # Check cache
index d10618bad39145fb98cc5c890edbd763f055bac1..9acecf2ae7337e26100c8ff4aba00b1b389db224 100644 (file)
@@ -1,11 +1,14 @@
 """Test Tidal Streaming Manager."""
 
+from collections.abc import Coroutine
+from sqlite3 import OperationalError
+from typing import Any
 from unittest.mock import AsyncMock, MagicMock, Mock
 
 import pytest
 from music_assistant_models.enums import ContentType, ExternalID, StreamType
 from music_assistant_models.errors import MediaNotFoundError
-from music_assistant_models.media_items import Track
+from music_assistant_models.media_items import AudioFormat, Track
 
 from music_assistant.providers.tidal.streaming import TidalStreamingManager
 
@@ -361,3 +364,193 @@ async def test_get_stream_details_with_isrc_fallback(
 
     assert stream_details.item_id == "123"
     assert stream_details.path == "https://example.com/stream.flac"
+
+
+async def test_get_stream_details_schedules_background_mapping_update(
+    streaming_manager: TidalStreamingManager,
+    provider_mock: Mock,
+    mock_track: Mock,
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Ensure get_stream_details schedules the background mapping update task."""
+    provider_mock.get_track.return_value = mock_track
+    provider_mock.api.get.return_value = {
+        "urls": ["https://example.com/stream.flac"],
+        "audioQuality": "LOSSLESS",
+        "sampleRate": 44100,
+        "bitDepth": 16,
+    }
+
+    created: list[tuple[str, AudioFormat]] = []
+
+    async def _fake_worker(provider_track_id: str, resolved_audio_format: AudioFormat) -> None:
+        created.append((provider_track_id, resolved_audio_format))
+
+    # Patch the worker method so we can validate the coroutine is created with expected args
+    monkeypatch.setattr(
+        streaming_manager, "_async_update_provider_mapping_audio_format", _fake_worker
+    )
+
+    captured_coros: list[Coroutine[Any, Any, None]] = []
+
+    def _fake_create_task(coro: Coroutine[Any, Any, None]) -> None:
+        # Don't schedule; just capture the coroutine so the test can await it.
+        captured_coros.append(coro)
+
+    provider_mock.mass.create_task = _fake_create_task
+
+    stream_details = await streaming_manager.get_stream_details("123")
+
+    assert len(captured_coros) == 1
+
+    # Execute the captured coroutine (safe because we patched the worker)
+    await captured_coros[0]
+
+    assert created == [("123", stream_details.audio_format)]
+
+
+async def test_async_update_provider_mapping_audio_format_no_library_item(
+    streaming_manager: TidalStreamingManager, provider_mock: Mock
+) -> None:
+    """Ensure no update occurs when no library item is found."""
+    provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = None
+    provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock()
+
+    await streaming_manager._async_update_provider_mapping_audio_format(
+        provider_track_id="123",
+        resolved_audio_format=AudioFormat(
+            content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16
+        ),
+    )
+
+    provider_mock.mass.music.tracks.update_provider_mapping.assert_not_called()
+
+
+async def test_async_update_provider_mapping_audio_format_no_mapping(
+    streaming_manager: TidalStreamingManager, provider_mock: Mock
+) -> None:
+    """Ensure no update occurs when no provider mapping is found."""
+    lib_track = Mock()
+    lib_track.item_id = 1
+    lib_track.provider_mappings = set()
+    provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = lib_track
+    provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock()
+
+    await streaming_manager._async_update_provider_mapping_audio_format(
+        provider_track_id="123",
+        resolved_audio_format=AudioFormat(
+            content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16
+        ),
+    )
+
+    provider_mock.mass.music.tracks.update_provider_mapping.assert_not_called()
+
+
+async def test_async_update_provider_mapping_audio_format_same_format_no_update(
+    streaming_manager: TidalStreamingManager, provider_mock: Mock
+) -> None:
+    """Ensure no update occurs when the audio format is unchanged."""
+    fmt = AudioFormat(content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16)
+    mapping = Mock()
+    mapping.provider_instance = provider_mock.instance_id
+    mapping.item_id = "123"
+    mapping.audio_format = fmt
+
+    lib_track = Mock()
+    lib_track.item_id = 1
+    lib_track.provider_mappings = {mapping}
+    provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = lib_track
+    provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock()
+
+    await streaming_manager._async_update_provider_mapping_audio_format(
+        provider_track_id="123",
+        resolved_audio_format=fmt,
+    )
+
+    provider_mock.mass.music.tracks.update_provider_mapping.assert_not_called()
+
+
+async def test_async_update_provider_mapping_audio_format_different_format_updates(
+    streaming_manager: TidalStreamingManager, provider_mock: Mock
+) -> None:
+    """Ensure update occurs when the audio format is different."""
+    old_fmt = AudioFormat(content_type=ContentType.MP4, sample_rate=44100, bit_depth=16)
+    new_fmt = AudioFormat(content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16)
+
+    mapping = Mock()
+    mapping.provider_instance = provider_mock.instance_id
+    mapping.item_id = "123"
+    mapping.audio_format = old_fmt
+
+    lib_track = Mock()
+    lib_track.item_id = 1
+    lib_track.provider_mappings = {mapping}
+    provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = lib_track
+    provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock()
+
+    await streaming_manager._async_update_provider_mapping_audio_format(
+        provider_track_id="123",
+        resolved_audio_format=new_fmt,
+    )
+
+    provider_mock.mass.music.tracks.update_provider_mapping.assert_awaited_once()
+    provider_mock.mass.music.tracks.update_provider_mapping.assert_awaited_with(
+        item_id=1,
+        provider_instance_id=provider_mock.instance_id,
+        provider_item_id="123",
+        audio_format=new_fmt,
+    )
+
+
+async def test_async_update_provider_mapping_audio_format_sqlite_operational_error_logs_debug(
+    streaming_manager: TidalStreamingManager, provider_mock: Mock
+) -> None:
+    """Ensure OperationalError is logged at debug level."""
+    provider_mock.logger = Mock()
+    provider_mock.mass.music.tracks.get_library_item_by_prov_id.side_effect = OperationalError(
+        "database is locked"
+    )
+
+    await streaming_manager._async_update_provider_mapping_audio_format(
+        provider_track_id="123",
+        resolved_audio_format=AudioFormat(
+            content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16
+        ),
+    )
+
+    provider_mock.logger.debug.assert_called()
+
+
+async def test_async_update_provider_mapping_audio_format_unexpected_error_logs_exception(
+    streaming_manager: TidalStreamingManager, provider_mock: Mock
+) -> None:
+    """Ensure unexpected errors are logged at exception level."""
+    provider_mock.logger = Mock()
+
+    lib_track = Mock()
+    lib_track.item_id = 1
+    lib_track.provider_mappings = set()
+    provider_mock.mass.music.tracks.get_library_item_by_prov_id.return_value = lib_track
+
+    # Force an unexpected error after resolving lib_track
+    provider_mock.mass.music.tracks.update_provider_mapping = AsyncMock(
+        side_effect=RuntimeError("boom")
+    )
+
+    # Create a mapping that triggers the update path
+    mapping = Mock()
+    mapping.provider_instance = provider_mock.instance_id
+    mapping.item_id = "123"
+    mapping.audio_format = AudioFormat(
+        content_type=ContentType.MP4, sample_rate=44100, bit_depth=16
+    )
+    lib_track.provider_mappings = {mapping}
+
+    await streaming_manager._async_update_provider_mapping_audio_format(
+        provider_track_id="123",
+        resolved_audio_format=AudioFormat(
+            content_type=ContentType.FLAC, sample_rate=44100, bit_depth=16
+        ),
+    )
+
+    provider_mock.logger.exception.assert_called()