from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterator, MutableMapping
from contextlib import asynccontextmanager
from contextvars import ContextVar
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, cast, get_type_hints
from music_assistant_models.config_entries import ConfigEntry, ConfigValueType
CONF_CLEAR_CACHE = "clear_cache"
DEFAULT_CACHE_EXPIRATION = 86400 * 30 # 30 days
DB_SCHEMA_VERSION = 7
+MAX_CACHE_DB_SIZE_MB = 2048
BYPASS_CACHE: ContextVar[bool] = ContextVar("BYPASS_CACHE", default=False)
finally:
BYPASS_CACHE.reset(token)
+ async def _check_and_reset_oversized_cache(self) -> bool:
+ """Check cache database size and remove it if it exceeds the max size.
+
+ Returns True if the cache database was removed.
+ """
+ db_path = os.path.join(self.mass.cache_path, "cache.db")
+ # also include the write ahead log and shared memory db files
+ db_files = [db_path + suffix for suffix in ("", "-wal", "-shm")]
+
+ def _get_db_size() -> float:
+ total = 0
+ for path in db_files:
+ if os.path.exists(path):
+ total += Path(path).stat().st_size
+ return total / (1024 * 1024)
+
+ db_size_mb = await asyncio.to_thread(_get_db_size)
+ if db_size_mb <= MAX_CACHE_DB_SIZE_MB:
+ return False
+ self.logger.warning(
+ "Cache database size %.2f MB exceeds maximum of %d MB, removing cache database",
+ db_size_mb,
+ MAX_CACHE_DB_SIZE_MB,
+ )
+ for path in db_files:
+ if await asyncio.to_thread(os.path.exists, path):
+ await asyncio.to_thread(os.remove, path)
+ return True
+
async def _setup_database(self) -> None:
"""Initialize database."""
+ cache_was_reset = await self._check_and_reset_oversized_cache()
db_path = os.path.join(self.mass.cache_path, "cache.db")
self.database = DatabaseConnection(db_path)
await self.database.setup()
# always create db tables if they don't exist to prevent errors trying to access them later
await self.__create_database_tables()
- try:
- if db_row := await self.database.get_row(DB_TABLE_SETTINGS, {"key": "version"}):
- prev_version = int(db_row["value"])
- else:
- prev_version = 0
- except (KeyError, ValueError):
- prev_version = 0
-
- if prev_version not in (0, DB_SCHEMA_VERSION):
- LOGGER.warning(
- "Performing database migration from %s to %s",
- prev_version,
- DB_SCHEMA_VERSION,
- )
+
+ if not cache_was_reset:
try:
- await self.__migrate_database(prev_version)
- except Exception as err:
- LOGGER.warning("Cache database migration failed: %s, resetting cache", err)
- await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_CACHE}")
- await self.__create_database_tables()
+ if db_row := await self.database.get_row(DB_TABLE_SETTINGS, {"key": "version"}):
+ prev_version = int(db_row["value"])
+ else:
+ prev_version = 0
+ except (KeyError, ValueError):
+ prev_version = 0
+
+ if prev_version not in (0, DB_SCHEMA_VERSION):
+ LOGGER.warning(
+ "Performing database migration from %s to %s",
+ prev_version,
+ DB_SCHEMA_VERSION,
+ )
+ try:
+ await self.__migrate_database(prev_version)
+ except Exception as err:
+ LOGGER.warning("Cache database migration failed: %s, resetting cache", err)
+ await self.database.execute(f"DROP TABLE IF EXISTS {DB_TABLE_CACHE}")
+ await self.__create_database_tables()
# store current schema version
await self.database.insert_or_replace(
{"key": "version", "value": str(DB_SCHEMA_VERSION), "type": "str"},
)
await self.__create_database_indexes()
- # compact db (vacuum) at startup
- self.logger.debug("Compacting database...")
- try:
- await self.database.vacuum()
- except Exception as err:
- self.logger.warning("Database vacuum failed: %s", str(err))
- else:
- self.logger.debug("Compacting database done")
+
+ if not cache_was_reset:
+ # compact db (vacuum) at startup
+ self.logger.debug("Compacting database...")
+ try:
+ await self.database.vacuum()
+ except Exception as err:
+ self.logger.warning("Database vacuum failed: %s", str(err))
+ else:
+ self.logger.debug("Compacting database done")
async def __create_database_tables(self) -> None:
"""Create database table(s)."""
"""Fixtures for testing Music Assistant."""
+import asyncio
import logging
import pathlib
from collections.abc import AsyncGenerator
import pytest
+from music_assistant.controllers.cache import CacheController
+from music_assistant.controllers.config import ConfigController
from music_assistant.mass import MusicAssistant
yield mass_instance
finally:
await mass_instance.stop()
+
+
+@pytest.fixture
+async def mass_minimal(tmp_path: pathlib.Path) -> AsyncGenerator[MusicAssistant, None]:
+ """Create a minimal Music Assistant instance without starting the full server.
+
+ Only initializes the event loop and config controller.
+ Useful for testing individual controllers without the overhead of the webserver.
+
+ :param tmp_path: Temporary directory for test data.
+ """
+ storage_path = tmp_path / "data"
+ cache_path = tmp_path / "cache"
+ storage_path.mkdir(parents=True)
+ cache_path.mkdir(parents=True)
+
+ logging.getLogger("aiosqlite").level = logging.INFO
+
+ mass_instance = MusicAssistant(str(storage_path), str(cache_path))
+
+ mass_instance.loop = asyncio.get_running_loop()
+ mass_instance.loop_thread_id = (
+ getattr(mass_instance.loop, "_thread_id", None)
+ if hasattr(mass_instance.loop, "_thread_id")
+ else id(mass_instance.loop)
+ )
+
+ mass_instance.config = ConfigController(mass_instance)
+ await mass_instance.config.setup()
+
+ mass_instance.cache = CacheController(mass_instance)
+
+ try:
+ yield mass_instance
+ finally:
+ if mass_instance.cache.database:
+ await mass_instance.cache.database.close()
+ await mass_instance.config.close()
--- /dev/null
+"""Tests for cache controller oversized cache detection and reset."""
+
+import os
+from collections.abc import Callable
+from typing import Any
+from unittest.mock import AsyncMock, patch
+
+import aiofiles
+import pytest
+
+from music_assistant.controllers.cache import MAX_CACHE_DB_SIZE_MB
+from music_assistant.mass import MusicAssistant
+
+
+async def _create_db_files(cache_path: str) -> list[str]:
+ """Create small cache.db, cache.db-wal, and cache.db-shm files.
+
+ :param cache_path: Path to the cache directory.
+ """
+ db_path = os.path.join(cache_path, "cache.db")
+ paths = [db_path + suffix for suffix in ("", "-wal", "-shm")]
+ for path in paths:
+ async with aiofiles.open(path, "wb") as f:
+ await f.write(b"\0")
+ return paths
+
+
+async def test_cache_reset_when_exceeding_limit(mass_minimal: MusicAssistant) -> None:
+ """Test that the cache database is removed when it exceeds MAX_CACHE_DB_SIZE_MB.
+
+ :param mass_minimal: Minimal MusicAssistant instance.
+ """
+ cache = mass_minimal.cache
+ db_files = await _create_db_files(mass_minimal.cache_path)
+
+ with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread:
+
+ async def _side_effect(func: Callable[..., Any], *args: Any) -> Any:
+ if getattr(func, "__name__", "") == "_get_db_size":
+ return float(MAX_CACHE_DB_SIZE_MB + 100)
+ return func(*args)
+
+ mock_to_thread.side_effect = _side_effect
+ result = await cache._check_and_reset_oversized_cache()
+
+ assert result is True
+ for path in db_files:
+ assert not os.path.exists(path)
+
+
+async def test_cache_not_reset_when_under_limit(mass_minimal: MusicAssistant) -> None:
+ """Test that the cache database is kept when it is under MAX_CACHE_DB_SIZE_MB.
+
+ :param mass_minimal: Minimal MusicAssistant instance.
+ """
+ cache = mass_minimal.cache
+ db_files = await _create_db_files(mass_minimal.cache_path)
+
+ with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread:
+
+ async def _side_effect(func: Callable[..., Any], *args: Any) -> Any:
+ if getattr(func, "__name__", "") == "_get_db_size":
+ return 1.0
+ return func(*args)
+
+ mock_to_thread.side_effect = _side_effect
+ result = await cache._check_and_reset_oversized_cache()
+
+ assert result is False
+ for path in db_files:
+ assert os.path.exists(path)
+
+
+async def test_all_three_db_files_included_in_size(mass_minimal: MusicAssistant) -> None:
+ """Test that cache.db, cache.db-wal, and cache.db-shm are all summed for size check.
+
+ :param mass_minimal: Minimal MusicAssistant instance.
+ """
+ cache = mass_minimal.cache
+ db_path = os.path.join(mass_minimal.cache_path, "cache.db")
+
+ # Create 3 files of 100 bytes each (300 bytes total)
+ for suffix in ("", "-wal", "-shm"):
+ async with aiofiles.open(db_path + suffix, "wb") as f:
+ await f.write(b"\0" * 100)
+
+ # Set threshold to ~200 bytes so 2 files pass but 3 files exceed it
+ size_threshold_mb = 0.0002
+ with patch("music_assistant.controllers.cache.MAX_CACHE_DB_SIZE_MB", size_threshold_mb):
+ result = await cache._check_and_reset_oversized_cache()
+
+ # 300 bytes exceeds the ~200 byte threshold, proving all 3 files are summed
+ assert result is True
+ assert not os.path.exists(db_path)
+ assert not os.path.exists(db_path + "-wal")
+ assert not os.path.exists(db_path + "-shm")
+
+
+async def test_skip_migration_when_cache_reset(
+ mass_minimal: MusicAssistant,
+ caplog: pytest.LogCaptureFixture,
+) -> None:
+ """Test that database migration is skipped when the cache was reset.
+
+ :param mass_minimal: Minimal MusicAssistant instance.
+ :param caplog: Log capture fixture.
+ """
+ cache = mass_minimal.cache
+
+ with patch.object(cache, "_check_and_reset_oversized_cache", return_value=True):
+ await cache._setup_database()
+
+ assert "Performing database migration" not in caplog.text
+
+
+async def test_skip_vacuum_when_cache_reset(
+ mass_minimal: MusicAssistant,
+ caplog: pytest.LogCaptureFixture,
+) -> None:
+ """Test that database vacuum is skipped when the cache was reset.
+
+ :param mass_minimal: Minimal MusicAssistant instance.
+ :param caplog: Log capture fixture.
+ """
+ cache = mass_minimal.cache
+
+ with patch.object(cache, "_check_and_reset_oversized_cache", return_value=True):
+ await cache._setup_database()
+
+ assert "Compacting database" not in caplog.text