Fix several issues with scrobble plugins
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Wed, 26 Mar 2025 23:32:14 +0000 (00:32 +0100)
committerMarcel van der Veldt <m.vanderveldt@outlook.com>
Wed, 26 Mar 2025 23:32:14 +0000 (00:32 +0100)
Fix thread safety, initialization and typing issues with the scrobble plugins

music_assistant/helpers/scrobbler.py
music_assistant/providers/lastfm_scrobble/__init__.py
music_assistant/providers/listenbrainz_scrobble/__init__.py
tests/core/test_scrobbler.py

index b2995c7fd5c1fc755aa08665849c6fe7da838262..3ae9d0a2e8ca6c5725f3261222b04f449c12f2cf 100644 (file)
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import asyncio
 import logging
 from typing import TYPE_CHECKING
 
@@ -26,10 +25,10 @@ class ScrobblerHelper:
         """Override if subclass needs specific configuration."""
         return True
 
-    def _update_now_playing(self, report: MediaItemPlaybackProgressReport) -> None:
+    async def _update_now_playing(self, report: MediaItemPlaybackProgressReport) -> None:
         """Send a Now Playing update to the scrobbling service."""
 
-    def _scrobble(self, report: MediaItemPlaybackProgressReport) -> None:
+    async def _scrobble(self, report: MediaItemPlaybackProgressReport) -> None:
         """Scrobble."""
 
     async def _on_mass_media_item_played(self, event: MassEvent) -> None:
@@ -48,29 +47,31 @@ class ScrobblerHelper:
             # reset currently playing to avoid it expiring when looping single songs
             self.currently_playing = None
 
-        def update_now_playing() -> None:
+        async def update_now_playing() -> None:
             try:
-                self._update_now_playing(report)
+                await self._update_now_playing(report)
                 self.logger.debug(f"track {report.uri} marked as 'now playing'")
                 self.currently_playing = report.uri
             except Exception as err:
+                # TODO: try to make this a more specific exception instead of a generic one
                 self.logger.exception(err)
 
-        def scrobble() -> None:
+        async def scrobble() -> None:
             try:
-                self._scrobble(report)
+                await self._scrobble(report)
                 self.last_scrobbled = report.uri
             except Exception as err:
+                # TODO: try to make this a more specific exception instead of a generic one
                 self.logger.exception(err)
 
         # update now playing if needed
         if report.is_playing and (
             self.currently_playing is None or self.currently_playing != report.uri
         ):
-            await asyncio.to_thread(update_now_playing)
+            update_now_playing()
 
         if self.should_scrobble(report):
-            await asyncio.to_thread(scrobble)
+            await scrobble()
 
     def should_scrobble(self, report: MediaItemPlaybackProgressReport) -> bool:
         """Determine if a track should be scrobbled, to be extended later."""
index 3ff359312303884067b0be843778b06d5b00dcd5..a515136a307a2aaacda34014f949d79711eda96a 100644 (file)
@@ -1,8 +1,10 @@
 """Allows scrobbling of tracks with the help of PyLast."""
 
+import asyncio
 import logging
 import time
 from collections.abc import Callable
+from typing import TYPE_CHECKING, cast
 
 import pylast
 from music_assistant_models.config_entries import (
@@ -38,34 +40,37 @@ async def setup(
     else:
         logging.getLogger("httpcore").setLevel(logging.WARNING)
 
+    # run async setup of provider to catch any login issues early
+    await provider.async_setup()
     return provider
 
 
 class LastFMScrobbleProvider(PluginProvider):
     """Plugin provider to support scrobbling of tracks."""
 
-    _on_unload: list[Callable[[], None]] = []
+    network: pylast._Network
+    _on_unload: list[Callable[[], None]]
 
-    def _get_network_config(self) -> dict[str, ConfigValueType]:
-        return {
-            CONF_API_KEY: self.config.get_value(CONF_API_KEY),
-            CONF_API_SECRET: self.config.get_value(CONF_API_SECRET),
-            CONF_PROVIDER: self.config.get_value(CONF_PROVIDER),
-            CONF_USERNAME: self.config.get_value(CONF_USERNAME),
-            CONF_SESSION_KEY: self.config.get_value(CONF_SESSION_KEY),
-        }
+    async def async_setup(self) -> None:
+        """Handle async setup."""
+        self._on_unload: list[Callable[[], None]] = []
 
-    async def loaded_in_mass(self) -> None:
-        """Call after the provider has been loaded."""
-        await super().loaded_in_mass()
+        if not self.config.get_value(CONF_API_KEY) or not self.config.get_value(CONF_API_SECRET):
+            raise SetupFailedError("API Key and Secret need to be set")
 
         if not self.config.get_value(CONF_SESSION_KEY):
             self.logger.info("No session key available, don't forget to authenticate!")
             return
+        # creating the network instance is (potentially) blocking IO
+        # so run it in an executor thread to be safe
+        self.network = await asyncio.to_thread(get_network, self._get_network_config())
 
-        handler = LastFMEventHandler(_get_network(self._get_network_config()), self.logger)
+    async def loaded_in_mass(self) -> None:
+        """Call after the provider has been loaded."""
+        await super().loaded_in_mass()
 
-        # subscribe to internal event
+        # subscribe to media_item_played event
+        handler = LastFMEventHandler(self.network, self.logger)
         self._on_unload.append(
             self.mass.subscribe(handler._on_mass_media_item_played, EventType.MEDIA_ITEM_PLAYED)
         )
@@ -79,6 +84,15 @@ class LastFMScrobbleProvider(PluginProvider):
         for unload_cb in self._on_unload:
             unload_cb()
 
+    def _get_network_config(self) -> dict[str, ConfigValueType]:
+        return {
+            CONF_API_KEY: self.config.get_value(CONF_API_KEY),
+            CONF_API_SECRET: self.config.get_value(CONF_API_SECRET),
+            CONF_PROVIDER: self.config.get_value(CONF_PROVIDER),
+            CONF_USERNAME: self.config.get_value(CONF_USERNAME),
+            CONF_SESSION_KEY: self.config.get_value(CONF_SESSION_KEY),
+        }
+
 
 class LastFMEventHandler(ScrobblerHelper):
     """Handles the event handling."""
@@ -90,15 +104,11 @@ class LastFMEventHandler(ScrobblerHelper):
         super().__init__(logger)
         self.network = network
 
-    def _is_configured(self) -> bool:
-        if self.network is None:
-            self.logger.error("no network available during _on_mass_media_item_played")
-            return False
-
-        return True
-
-    def _update_now_playing(self, report: MediaItemPlaybackProgressReport) -> None:
-        self.network.update_now_playing(
+    async def _update_now_playing(self, report: MediaItemPlaybackProgressReport) -> None:
+        # the lastfm client is not async friendly,
+        # so we need to run it in a executor thread
+        await asyncio.to_thread(
+            self.network.update_now_playing,
             report.artist,
             report.name,
             report.album,
@@ -106,13 +116,16 @@ class LastFMEventHandler(ScrobblerHelper):
             mbid=report.mbid,
         )
 
-    def _scrobble(self, report: MediaItemPlaybackProgressReport) -> None:
-        # album artist and track number are not available without an extra API call
+    async def _scrobble(self, report: MediaItemPlaybackProgressReport) -> None:
+        # the listenbrainz client is not async friendly,
+        # so we need to run it in a executor thread
+        # NOTE: album artist and track number are not available without an extra API call
         # so they won't be scrobbled
-        self.network.scrobble(
-            report.artist,
+        await asyncio.to_thread(
+            self.network.scrobble,
+            report.artist or "unknown artist",
             report.name,
-            time.time(),
+            int(time.time()),
             report.album,
             duration=report.duration,
             mbid=report.mbid,
@@ -192,7 +205,7 @@ async def get_config_entries(
         session_id = str(values.get("session_id"))
 
         async with AuthenticationHelper(mass, session_id) as auth_helper:
-            network = _get_network(values)
+            network = get_network(values)
             skg = pylast.SessionKeyGenerator(network)
 
             # pylast says it does web auth, but actually does desktop auth
@@ -261,10 +274,12 @@ async def get_config_entries(
     return tuple(entries)
 
 
-def _get_network(config: dict[str, ConfigValueType]) -> pylast._Network:
+def get_network(config: dict[str, ConfigValueType]) -> pylast._Network:
+    """Create a network instance."""
     key = config.get(CONF_API_KEY)
     secret = config.get(CONF_API_SECRET)
     session_key = config.get(CONF_SESSION_KEY)
+    username = config.get(CONF_USERNAME)
 
     assert key
     assert key != SECURE_STRING_SUBSTITUTE
@@ -276,14 +291,16 @@ def _get_network(config: dict[str, ConfigValueType]) -> pylast._Network:
 
     provider: str = str(config.get(CONF_PROVIDER))
 
+    if TYPE_CHECKING:
+        key = cast(str, key)
+        secret = cast(str, secret)
+        session_key = cast(str, session_key)
+        username = cast(str, username)
+
     match provider.lower():
         case "lastfm":
-            return pylast.LastFMNetwork(
-                key, secret, username=config.get(CONF_USERNAME), session_key=session_key
-            )
+            return pylast.LastFMNetwork(key, secret, username=username, session_key=session_key)
         case "librefm":
-            return pylast.LibreFMNetwork(
-                key, secret, username=config.get(CONF_USERNAME), session_key=session_key
-            )
+            return pylast.LibreFMNetwork(key, secret, username=username, session_key=session_key)
         case _:
             raise SetupFailedError(f"unknown provider {provider} configured")
index f560ed6f257174a44538fe353bf16a1622c49743..467afcb07b1283e8527b4234496421e1a07833aa 100644 (file)
@@ -4,17 +4,13 @@
 # released under the Creative Commons Attribution-ShareAlike(BY-SA) 4.0 license.
 # https://creativecommons.org/licenses/by-sa/4.0/
 
+import asyncio
 import logging
 import time
 from collections.abc import Callable
-from typing import Any
 
 from liblistenbrainz import Listen, ListenBrainz
-from music_assistant_models.config_entries import (
-    ConfigEntry,
-    ConfigValueType,
-    ProviderConfig,
-)
+from music_assistant_models.config_entries import ConfigEntry, ConfigValueType, ProviderConfig
 from music_assistant_models.constants import SECURE_STRING_SUBSTITUTE
 from music_assistant_models.enums import ConfigEntryType, EventType
 from music_assistant_models.errors import SetupFailedError
@@ -48,9 +44,6 @@ async def setup(
 class ListenBrainzScrobbleProvider(PluginProvider):
     """Plugin provider to support scrobbling of tracks."""
 
-    _client: ListenBrainz = None
-    _on_unload: list[Callable[[], None]] = []
-
     def __init__(
         self,
         mass: MusicAssistant,
@@ -61,6 +54,7 @@ class ListenBrainzScrobbleProvider(PluginProvider):
         """Initialize MusicProvider."""
         super().__init__(mass, manifest, config)
         self._client = client
+        self._on_unload: list[Callable[[], None]] = []
 
     async def loaded_in_mass(self) -> None:
         """Call after the provider has been loaded."""
@@ -68,7 +62,7 @@ class ListenBrainzScrobbleProvider(PluginProvider):
 
         handler = ListenBrainzEventHandler(self._client, self.logger)
 
-        # subscribe to internal event
+        # subscribe to media_item_played event
         self._on_unload.append(
             self.mass.subscribe(handler._on_mass_media_item_played, EventType.MEDIA_ITEM_PLAYED)
         )
@@ -86,21 +80,12 @@ class ListenBrainzScrobbleProvider(PluginProvider):
 class ListenBrainzEventHandler(ScrobblerHelper):
     """Handles the event handling."""
 
-    _client: ListenBrainz = None
-
     def __init__(self, client: ListenBrainz, logger: logging.Logger) -> None:
         """Initialize."""
         super().__init__(logger)
         self._client = client
 
-    def _is_configured(self) -> bool:
-        """Check that we are configured."""
-        if self._client is None:
-            self.logger.error("no client available during _on_mass_media_item_played")
-            return False
-        return True
-
-    def _make_listen(self, report: Any) -> Listen:
+    def _make_listen(self, report: MediaItemPlaybackProgressReport) -> Listen:
         # album artist and track number are not available without an extra API call
         # so they won't be scrobbled
 
@@ -115,23 +100,33 @@ class ListenBrainzEventHandler(ScrobblerHelper):
             listening_from="music-assistant",
         )
 
-    def _update_now_playing(self, report: MediaItemPlaybackProgressReport) -> None:
-        try:
-            listen = self._make_listen(report)
-            self._client.submit_playing_now(listen)
-            self.logger.debug(f"track {report.uri} marked as 'now playing'")
-            self._currently_playing = report.uri
-        except Exception as err:
-            self.logger.exception(err)
-
-    def _scrobble(self, report: MediaItemPlaybackProgressReport) -> None:
-        try:
-            listen = self._make_listen(report)
-            listen.listened_at = int(time.time())
-            self._client.submit_single_listen(listen)
-            self._last_scrobbled = report.uri
-        except Exception as err:
-            self.logger.exception(err)
+    async def _update_now_playing(self, report: MediaItemPlaybackProgressReport) -> None:
+        def handler() -> None:
+            try:
+                listen = self._make_listen(report)
+                self._client.submit_playing_now(listen)
+                self.logger.debug(f"track {report.uri} marked as 'now playing'")
+                self._currently_playing = report.uri
+            except Exception as err:
+                self.logger.exception(err)
+
+        # the listenbrainz client is not async friendly,
+        # so we need to run it in a executor thread
+        await asyncio.to_thread(handler)
+
+    async def _scrobble(self, report: MediaItemPlaybackProgressReport) -> None:
+        def handler() -> None:
+            try:
+                listen = self._make_listen(report)
+                listen.listened_at = int(time.time())
+                self._client.submit_single_listen(listen)
+                self._last_scrobbled = report.uri
+            except Exception as err:
+                self.logger.exception(err)
+
+        # the listenbrainz client is not async friendly,
+        # so we need to run it in a executor thread
+        await asyncio.to_thread(handler)
 
 
 async def get_config_entries(
index 2f73c48e62e512035205b1fdee89cd240179a396..4ffd9cc238aff18161d72656ba903665d303920b 100644 (file)
@@ -22,10 +22,10 @@ class DummyHandler(ScrobblerHelper):
     def _is_configured(self) -> bool:
         return True
 
-    def _update_now_playing(self, report: MediaItemPlaybackProgressReport) -> None:
+    async def _update_now_playing(self, report: MediaItemPlaybackProgressReport) -> None:
         self._now_playing += 1
 
-    def _scrobble(self, report: MediaItemPlaybackProgressReport) -> None:
+    async def _scrobble(self, report: MediaItemPlaybackProgressReport) -> None:
         self._tracked += 1