One last attempt to get a stable smb provider
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Wed, 29 Mar 2023 00:07:08 +0000 (02:07 +0200)
committerMarcel van der Veldt <m.vanderveldt@outlook.com>
Wed, 29 Mar 2023 00:07:08 +0000 (02:07 +0200)
music_assistant/server/helpers/process.py
music_assistant/server/helpers/util.py
music_assistant/server/providers/filesystem_local/__init__.py
music_assistant/server/providers/filesystem_smb/__init__.py

index 164cc44e2f44a0873fa9ace06bf133e7a253b9e2..fafa623950efa5014c115cf376979c1ce8228325 100644 (file)
@@ -8,6 +8,7 @@ from __future__ import annotations
 import asyncio
 import logging
 from collections.abc import AsyncGenerator, Coroutine
+from contextlib import suppress
 
 LOGGER = logging.getLogger(__name__)
 
@@ -68,6 +69,9 @@ class AsyncProcess:
                 await self._proc.communicate()
         if self._proc.returncode is None:
             self._proc.kill()
+        if self._attached_task and not self._attached_task.done():
+            with suppress(asyncio.CancelledError):
+                self._attached_task.cancel()
 
     async def iter_chunked(self, n: int = DEFAULT_CHUNKSIZE) -> AsyncGenerator[bytes, None]:
         """Yield chunks of n size from the process stdout."""
index 0ed59a549ef72cecf5656cfb7deca27f989383c2..ac265e347905155bce2d4ca0ad99d6525959636f 100644 (file)
@@ -9,6 +9,8 @@ from contextlib import suppress
 from functools import lru_cache
 from typing import TYPE_CHECKING, Any
 
+from music_assistant.common.helpers.util import empty_queue
+
 if TYPE_CHECKING:
     from music_assistant.server.models import ProviderModuleType
 
@@ -39,35 +41,3 @@ async def get_provider_module(domain: str) -> ProviderModuleType:
         return importlib.import_module(f".{domain}", "music_assistant.server.providers")
 
     return await asyncio.to_thread(_get_provider_module, domain)
-
-
-async def async_iter(sync_iterator: Iterator, *args, **kwargs) -> AsyncGenerator[Any, None]:
-    """Wrap blocking iterator into an asynchronous one."""
-    # inspired by: https://stackoverflow.com/questions/62294385/synchronous-generator-in-asyncio
-    loop = asyncio.get_running_loop()
-    queue = asyncio.Queue(1)
-    _exit = asyncio.Event()
-    _end_ = object()
-
-    def iter_to_queue():
-        for item in sync_iterator(*args, **kwargs):
-            if _exit.is_set():
-                return
-            asyncio.run_coroutine_threadsafe(queue.put(item), loop).result()
-        asyncio.run_coroutine_threadsafe(queue.put(_end_), loop).result()
-
-    iter_fut = loop.run_in_executor(None, iter_to_queue)
-    try:
-        while True:
-            next_item = await queue.get()
-            if next_item is _end_:
-                break
-            yield next_item
-    finally:
-        # cleanup
-        _exit.set()
-        if not iter_fut.done():
-            iter_fut.cancel()
-            await iter_fut
-        with suppress(asyncio.QueueEmpty):
-            queue.get_nowait()
index 8469ce16817233fa01e9a95681e1650c5198db7f..3cb695cac45c40d735008ed237562a8c83f1cefe 100644 (file)
@@ -14,7 +14,6 @@ from music_assistant.common.models.config_entries import ConfigEntry
 from music_assistant.common.models.enums import ConfigEntryType
 from music_assistant.common.models.errors import SetupFailedError
 from music_assistant.constants import CONF_PATH
-from music_assistant.server.helpers.util import async_iter
 
 from .base import (
     CONF_ENTRY_MISSING_ALBUM_ARTIST,
@@ -104,11 +103,12 @@ class LocalFileSystemProvider(FileSystemProviderBase):
 
         """
         abs_path = get_absolute_path(self.config.get_value(CONF_PATH), path)
-        async for entry in async_iter(os.scandir, abs_path):
+        self.logger.debug("Processing: %s", abs_path)
+        entries = await asyncio.to_thread(os.scandir, abs_path)
+        for entry in entries:
             if entry.name.startswith(".") or any(x in entry.name for x in IGNORE_DIRS):
                 # skip invalid/system files and dirs
                 continue
-
             item = await create_item(self.config.get_value(CONF_PATH), entry)
             if recursive and item.is_dir:
                 try:
index fa29cc6296e494bc97294c141dcdd3b9deafee52..46970060b92be4c4101055bf34afec60fa9296ad 100644 (file)
@@ -5,19 +5,19 @@ import asyncio
 import logging
 import os
 from collections.abc import AsyncGenerator
+from contextlib import suppress
 from os.path import basename
 from typing import TYPE_CHECKING
 
 import smbclient
 from smbclient import path as smbpath
 
-from music_assistant.common.helpers.util import get_ip_from_host
+from music_assistant.common.helpers.util import empty_queue, get_ip_from_host
 from music_assistant.common.models.config_entries import ConfigEntry
 from music_assistant.common.models.enums import ConfigEntryType
 from music_assistant.common.models.errors import LoginFailed
 from music_assistant.constants import CONF_PASSWORD, CONF_USERNAME
 from music_assistant.server.controllers.cache import use_cache
-from music_assistant.server.helpers.util import async_iter
 from music_assistant.server.providers.filesystem_local.base import (
     CONF_ENTRY_MISSING_ALBUM_ARTIST,
     IGNORE_DIRS,
@@ -188,7 +188,9 @@ class SMBFileSystemProvider(FileSystemProviderBase):
 
         """
         abs_path = get_absolute_path(self._root_path, path)
-        async for entry in async_iter(smbclient.scandir, abs_path):
+        self.logger.debug("Processing: %s", abs_path)
+        entries = await asyncio.to_thread(smbclient.scandir, abs_path)
+        for entry in entries:
             if entry.name.startswith(".") or any(x in entry.name for x in IGNORE_DIRS):
                 # skip invalid/system files and dirs
                 continue
@@ -244,6 +246,8 @@ class SMBFileSystemProvider(FileSystemProviderBase):
         file_path = file_path.replace("\\", os.sep)
         absolute_path = get_absolute_path(self._root_path, file_path)
 
+        queue = asyncio.Queue(1)
+
         def _reader():
             self.logger.debug("Reading file contents for %s", absolute_path)
             try:
@@ -257,18 +261,32 @@ class SMBFileSystemProvider(FileSystemProviderBase):
                     while True:
                         chunk = _file.read(chunk_size)
                         if not chunk:
-                            break
-                        yield chunk
+                            return
+                        asyncio.run_coroutine_threadsafe(queue.put(chunk), self.mass.loop).result()
                         bytes_sent += len(chunk)
             finally:
+                asyncio.run_coroutine_threadsafe(queue.put(b""), self.mass.loop).result()
                 self.logger.debug(
                     "Finished Reading file contents for %s - bytes transferred: %s",
                     absolute_path,
                     bytes_sent,
                 )
 
-        async for chunk in async_iter(_reader):
-            yield chunk
+        try:
+            task = self.mass.create_task(_reader)
+
+            while True:
+                chunk = await queue.get()
+                if not chunk:
+                    break
+                yield chunk
+        finally:
+            empty_queue(queue)
+            if task and not task.done():
+                task.cancel()
+                with suppress(asyncio.CancelledError):
+                    await task
+            del queue
 
     async def write_file_content(self, file_path: str, data: bytes) -> None:
         """Write entire file content as bytes (e.g. for playlists)."""