Typing fixes for the Cache controller (#2569)
authorOzGav <gavnosp@hotmail.com>
Tue, 18 Nov 2025 07:37:55 +0000 (17:37 +1000)
committerGitHub <noreply@github.com>
Tue, 18 Nov 2025 07:37:55 +0000 (08:37 +0100)
* mypy fixes for cache.py

* Fix  __init__ signature to match CoreController

* Adjust checksum typing

* Guarantee checksum is a str

* Apply suggestion from @marcelveldt

---------

Co-authored-by: Marcel van der Veldt <m.vanderveldt@outlook.com>
music_assistant/controllers/cache.py
pyproject.toml

index f7b4d02dc3c4637487ce6a6567c305262df555d0..cc81b60598f894bf9616794724b86ee12301ba9b 100644 (file)
@@ -11,7 +11,7 @@ from collections import OrderedDict
 from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterator, MutableMapping
 from contextlib import asynccontextmanager
 from contextvars import ContextVar
-from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, get_type_hints
+from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, cast, get_type_hints
 
 from music_assistant_models.config_entries import ConfigEntry, ConfigValueType
 from music_assistant_models.enums import ConfigEntryType
@@ -25,9 +25,9 @@ from music_assistant.models.core_controller import CoreController
 if TYPE_CHECKING:
     from music_assistant_models.config_entries import CoreConfig
 
+    from music_assistant import MusicAssistant
     from music_assistant.models.provider import Provider
 
-
 LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.cache")
 CONF_CLEAR_CACHE = "clear_cache"
 DEFAULT_CACHE_EXPIRATION = 86400 * 30  # 30 days
@@ -41,9 +41,9 @@ class CacheController(CoreController):
 
     domain: str = "cache"
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, mass: MusicAssistant) -> None:
         """Initialize core controller."""
-        super().__init__(*args, **kwargs)
+        super().__init__(mass)
         self.database: DatabaseConnection | None = None
         self._mem_cache = MemoryCache(500)
         self.manifest.name = "Cache controller"
@@ -92,7 +92,7 @@ class CacheController(CoreController):
         key: str,
         provider: str = "default",
         category: int = 0,
-        checksum: str | None = None,
+        checksum: str | int | None = None,
         default: Any = None,
         allow_bypass: bool = True,
     ) -> Any:
@@ -105,6 +105,7 @@ class CacheController(CoreController):
                     cache object matches the checksum provided
         - default: value to return if no cache object is found
         """
+        assert self.database is not None
         assert key, "No key provided"
         if allow_bypass and BYPASS_CACHE.get():
             return default
@@ -162,9 +163,10 @@ class CacheController(CoreController):
         - checksum: optional argument to store with the cache object
         - persistent: if True the cache object will not be deleted when clearing the cache
         """
+        assert self.database is not None
         if not key:
             return
-        if checksum is not None and not isinstance(checksum, str):
+        if checksum is not None:
             checksum = str(checksum)
         expires = int(time.time() + expiration)
         memory_key = f"{provider}/{category}/{key}"
@@ -190,6 +192,7 @@ class CacheController(CoreController):
         self, key: str | None, category: int | None = None, provider: str | None = None
     ) -> None:
         """Delete data from cache."""
+        assert self.database is not None
         match: dict[str, str | int] = {}
         if key is not None:
             match["key"] = key
@@ -211,6 +214,7 @@ class CacheController(CoreController):
         include_persistent: bool = False,
     ) -> None:
         """Clear all/partial items from cache."""
+        assert self.database is not None
         self._mem_cache.clear()
         self.logger.info("Clearing database...")
         query_parts: list[str] = []
@@ -228,6 +232,7 @@ class CacheController(CoreController):
 
     async def auto_cleanup(self) -> None:
         """Run scheduled auto cleanup task."""
+        assert self.database is not None
         self.logger.debug("Running automatic cleanup...")
         # simply reset the memory cache
         self._mem_cache.clear()
@@ -297,6 +302,7 @@ class CacheController(CoreController):
 
     async def __create_database_tables(self) -> None:
         """Create database table(s)."""
+        assert self.database is not None
         await self.database.execute(
             f"""CREATE TABLE IF NOT EXISTS {DB_TABLE_SETTINGS}(
                     key TEXT PRIMARY KEY,
@@ -322,6 +328,7 @@ class CacheController(CoreController):
 
     async def __create_database_indexes(self) -> None:
         """Create database indexes."""
+        assert self.database is not None
         await self.database.execute(
             f"CREATE INDEX IF NOT EXISTS {DB_TABLE_CACHE}_category_idx "
             f"ON {DB_TABLE_CACHE}(category);"
@@ -402,7 +409,7 @@ def use_cache(
             )
             if cachedata is not None:
                 type_hints = get_type_hints(func)
-                return parse_value(func.__name__, cachedata, type_hints["return"])
+                return cast("R", parse_value(func.__name__, cachedata, type_hints["return"]))
             # get data from method/provider
             result = await func(self, *args, **kwargs)
             # store result in cache (but don't await)
@@ -424,13 +431,13 @@ def use_cache(
     return _decorator
 
 
-class MemoryCache(MutableMapping):
+class MemoryCache(MutableMapping[str, Any]):
     """Simple limited in-memory cache implementation."""
 
     def __init__(self, maxlen: int) -> None:
         """Initialize."""
         self._maxlen = maxlen
-        self.d = OrderedDict()
+        self.d: OrderedDict[str, Any] = OrderedDict()
 
     @property
     def maxlen(self) -> int:
@@ -458,11 +465,11 @@ class MemoryCache(MutableMapping):
             self.d.popitem(last=False)
         self.d[key] = value
 
-    def __delitem__(self, key) -> None:
+    def __delitem__(self, key: str) -> None:
         """Delete item."""
         del self.d[key]
 
-    def __iter__(self) -> Iterator:
+    def __iter__(self) -> Iterator[str]:
         """Iterate items."""
         return self.d.__iter__()
 
index 1e6ac17208aaef63395a39cb4bdbd0e05b8ef264..c9861c2aa42d268fa4d5deeb9a5393523fc81705 100644 (file)
@@ -129,7 +129,6 @@ enable_error_code = [
   "truthy-iterable",
 ]
 exclude = [
-  '^music_assistant/controllers/cache.py$',
   '^music_assistant/controllers/media/albums.py*$',
   '^music_assistant/controllers/media/artists.py*$',
   '^music_assistant/controllers/media/audiobooks.py*$',