Reduce memory usage of cache (#314)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Sun, 15 May 2022 23:14:45 +0000 (01:14 +0200)
committerGitHub <noreply@github.com>
Sun, 15 May 2022 23:14:45 +0000 (01:14 +0200)
* Reduce memory usage of cache

* ensure string for checksum

music_assistant/helpers/cache.py

index 0d39ebf0a79c11ab3fe256ae6d9200cc41230107..815b72853097586bf1fc46c4d11a4a2c8382b21d 100644 (file)
@@ -5,7 +5,9 @@ import asyncio
 import functools
 import json
 import time
-from typing import TYPE_CHECKING
+from collections import OrderedDict
+from collections.abc import MutableMapping
+from typing import TYPE_CHECKING, Any, Iterator
 
 from music_assistant.helpers.database import TABLE_CACHE
 
@@ -20,7 +22,7 @@ class Cache:
         """Initialize our caching class."""
         self.mass = mass
         self.logger = mass.logger.getChild("cache")
-        self._mem_cache = {}
+        self._mem_cache = MemoryCache(500)
 
     async def setup(self) -> None:
         """Async initialize of cache module."""
@@ -35,6 +37,8 @@ class Cache:
                     cacheobject matches the checkum provided
         """
         cur_time = int(time.time())
+        if not isinstance(checksum, str):
+            checksum = str(checksum)
 
         # try memory cache first
         cache_data = self._mem_cache.get(cache_key)
@@ -61,17 +65,18 @@ class Cache:
                     )
                 else:
                     # also store in memory cache for faster access
-                    if cache_key not in self._mem_cache:
-                        self._mem_cache[cache_key] = (
-                            data,
-                            db_row["checksum"],
-                            db_row["expires"],
-                        )
+                    self._mem_cache[cache_key] = (
+                        data,
+                        db_row["checksum"],
+                        db_row["expires"],
+                    )
                     return data
         return default
 
     async def set(self, cache_key, data, checksum="", expiration=(86400 * 30)):
         """Set data in cache."""
+        if not isinstance(checksum, str):
+            checksum = str(checksum)
         expires = int(time.time() + expiration)
         self._mem_cache[cache_key] = (data, checksum, expires)
         if (expires - time.time()) < 3600 * 4:
@@ -118,7 +123,7 @@ def use_cache(expiration=86400 * 30):
             method_class_name = method_class.__class__.__name__
             cache_key_parts = [method_class_name, func.__name__]
             skip_cache = kwargs.pop("skip_cache", False)
-            cache_checksum = kwargs.pop("cache_checksum", None)
+            cache_checksum = kwargs.pop("cache_checksum", "")
             if len(args) > 1:
                 cache_key_parts += args[1:]
             for key in sorted(kwargs.keys()):
@@ -138,3 +143,50 @@ def use_cache(expiration=86400 * 30):
         return wrapped
 
     return wrapper
+
+
+class MemoryCache(MutableMapping):
+    """Simple limited in-memory cache implementation."""
+
+    def __init__(self, maxlen: int):
+        """Initialize."""
+        self._maxlen = maxlen
+        self.d = OrderedDict()
+
+    @property
+    def maxlen(self) -> int:
+        """Return max length."""
+        return self._maxlen
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """Return item or default."""
+        return self.d.get(key, default)
+
+    def pop(self, key: str, default: Any = None) -> Any:
+        """Pop item from collection."""
+        return self.d.pop(key, default)
+
+    def __getitem__(self, key: str) -> Any:
+        """Get item."""
+        self.d.move_to_end(key)
+        return self.d[key]
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        """Set item."""
+        if key in self.d:
+            self.d.move_to_end(key)
+        elif len(self.d) == self.maxlen:
+            self.d.popitem(last=False)
+        self.d[key] = value
+
+    def __delitem__(self, key) -> None:
+        """Delete item."""
+        del self.d[key]
+
+    def __iter__(self) -> Iterator:
+        """Iterate items."""
+        return self.d.__iter__()
+
+    def __len__(self) -> int:
+        """Return length."""
+        return len(self.d)