Fix race condition when adding items to the library (#354)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Wed, 1 Jun 2022 11:17:03 +0000 (13:17 +0200)
committerGitHub <noreply@github.com>
Wed, 1 Jun 2022 11:17:03 +0000 (13:17 +0200)
* allow backgroundjobs to be awaited

* adjust controller to await the job if not lazy

music_assistant/mass.py
music_assistant/models/background_job.py
music_assistant/models/media_controller.py

index 2a7b0ce23e2449312d04dae7abafbde33a48571a..4cd451a5d34469f72bbce3d5b864ec8c329cb27b 100644 (file)
@@ -8,6 +8,7 @@ import threading
 from collections import deque
 from functools import partial
 from time import time
+from tkinter import NONE
 from types import TracebackType
 from typing import Any, Callable, Coroutine, Deque, List, Optional, Tuple, Type, Union
 from uuid import uuid4
@@ -137,20 +138,20 @@ class MusicAssistant:
 
     def add_job(
         self, coro: Coroutine, name: Optional[str] = None, allow_duplicate=False
-    ) -> None:
+    ) -> BackgroundJob:
         """Add job to be (slowly) processed in the background."""
         if not allow_duplicate:
-            # pylint: disable=protected-access
-            if any(x for x in self._jobs if x.name == name):
+            if existing := next((x for x in self._jobs if x.name == name), NONE):
                 self.logger.debug("Ignored duplicate job: %s", name)
                 coro.close()
-                return
+                return existing
         if not name:
             name = coro.__qualname__ or coro.__name__
         job = BackgroundJob(str(uuid4()), name=name, coro=coro)
         self._jobs.append(job)
         self._jobs_event.set()
         self.signal_event(MassEvent(EventType.BACKGROUND_JOB_UPDATED, data=job))
+        return job
 
     def create_task(
         self,
@@ -247,12 +248,15 @@ class MusicAssistant:
                 exc_info=err,
             )
         else:
+            job.result = task.result()
             job.status = JobStatus.FINISHED
             self.logger.info(
                 "Finished job [%s] in %s seconds.", job.name, execution_time
             )
         self._jobs.remove(job)
         self._jobs_event.set()
+        # mark job as done
+        job.done()
         self.signal_event(MassEvent(EventType.BACKGROUND_JOB_UPDATED, data=job))
 
     async def __aenter__(self) -> "MusicAssistant":
index bc61c3387f2b62a2ec0e2f067aa16922507d67a3..9d83a5f1eaab7389324b950dc80655c15224ba86 100644 (file)
@@ -1,7 +1,8 @@
 """Model for a Background Job."""
-from dataclasses import dataclass
+import asyncio
+from dataclasses import dataclass, field
 from time import time
-from typing import Coroutine
+from typing import Any, Coroutine
 
 from music_assistant.models.enums import JobStatus
 
@@ -15,6 +16,8 @@ class BackgroundJob:
     name: str
     timestamp: float = time()
     status: JobStatus = JobStatus.PENDING
+    result: Any = None
+    _evt: asyncio.Event = field(init=False, default_factory=asyncio.Event)
 
     def to_dict(self):
         """Return serializable dict from object."""
@@ -24,3 +27,11 @@ class BackgroundJob:
             "timestamp": self.status.value,
             "status": self.status.value,
         }
+
+    async def wait(self) -> None:
+        """Wait for the job to complete."""
+        await self._evt.wait()
+
+    def done(self) -> None:
+        """Mark job as done."""
+        self._evt.set()
index 001a009ffb2a496c2918fb101ee4bc8b24137899..babf4cf6b913f5fa2f5ac3d9cdb39def8473baf3 100644 (file)
@@ -91,12 +91,16 @@ class MediaControllerBase(Generic[ItemCls], metaclass=ABCMeta):
             provider_id=provider_id,
         )
         if db_item and (time() - db_item.last_refresh) > REFRESH_INTERVAL:
+            # it's been too long since the full metadata was last retrieved (or never at all)
             force_refresh = True
         if db_item and force_refresh:
+            # get (first) provider item id belonging to this db item
             provider_id, provider_item_id = await self.get_provider_id(db_item)
         elif db_item:
+            # we have a db item and no refreshing is needed, return the results!
             return db_item
         if not details and provider_id:
+            # no details provider nor in db, fetch them from the provider
             details = await self.get_provider_item(provider_item_id, provider_id)
         if not details and provider:
             # check providers for given provider type one by one
@@ -113,12 +117,19 @@ class MediaControllerBase(Generic[ItemCls], metaclass=ABCMeta):
                     else:
                         break
         if not details:
+            # we couldn't get a match from any of the providers, raise error
             raise MediaNotFoundError(
                 f"Item not found: {provider.value or provider_id}/{provider_item_id}"
             )
+        # create job to add the item to the db, including matching metadata etc. takes some time
+        # in 99% of the cases we just return lazy because we want the details as fast as possible
+        # only if we really need to wait for the result (e.g. to prevent race conditions), we
+        # can set lazy to false and we await to job to complete.
+        add_job = self.mass.add_job(self.add(details), f"Add {details.uri} to database")
         if not lazy:
-            return await self.add(details)
-        self.mass.add_job(self.add(details), f"Add {details.uri} to database")
+            await add_job.wait()
+            return add_job.result
+
         return db_item if db_item else details
 
     async def search(
@@ -155,6 +166,7 @@ class MediaControllerBase(Generic[ItemCls], metaclass=ABCMeta):
     ) -> None:
         """Add an item to the library."""
         # make sure we have a valid full item
+        # note that we set 'lazy' to False because we need a full db item
         db_item = await self.get(
             provider_item_id, provider=provider, provider_id=provider_id, lazy=False
         )
@@ -180,10 +192,11 @@ class MediaControllerBase(Generic[ItemCls], metaclass=ABCMeta):
     ) -> None:
         """Remove item from the library."""
         # make sure we have a valid full item
+        # note that we set 'lazy' to False because we need a full db item
         db_item = await self.get(
             provider_item_id, provider=provider, provider_id=provider_id, lazy=False
         )
-        # add to provider's libraries
+        # remove from provider's libraries
         for prov_id in db_item.provider_ids:
             if prov := self.mass.music.get_provider(prov_id.prov_id):
                 await prov.library_remove(prov_id.item_id, self.media_type)