Fix SMB Music provider (#540)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Fri, 17 Mar 2023 08:25:29 +0000 (09:25 +0100)
committerGitHub <noreply@github.com>
Fri, 17 Mar 2023 08:25:29 +0000 (09:25 +0100)
Create a single connection per action. This a bit slower but much more
reliable.
Now it seems to handle all test cases I throw at it just fine.

Also adjust the configuration a bit and split out the path into
server/host, share and subfolder

music_assistant/common/helpers/util.py
music_assistant/common/models/media_items.py
music_assistant/server/helpers/audio.py
music_assistant/server/providers/filesystem_local/base.py
music_assistant/server/providers/filesystem_smb/__init__.py
music_assistant/server/providers/filesystem_smb/helpers.py
music_assistant/server/providers/filesystem_smb/manifest.json

index 3d7d18666fc7ae5df1bb4d1d3dad3b9c54754606..82d674d5c4cc5f6e8e647d4be45b9b801fb0151e 100755 (executable)
@@ -178,7 +178,7 @@ async def select_free_port(range_start: int, range_end: int) -> int:
     return await asyncio.to_thread(_select_free_port)
 
 
-async def get_ip_from_host(dns_name: str) -> str:
+async def get_ip_from_host(dns_name: str) -> str | None:
     """Resolve (first) IP-address for given dns name."""
 
     def _resolve():
@@ -186,7 +186,7 @@ async def get_ip_from_host(dns_name: str) -> str:
             return socket.gethostbyname(dns_name)
         except Exception:  # pylint: disable=broad-except
             # fail gracefully!
-            return dns_name
+            return None
 
     return await asyncio.to_thread(_resolve)
 
index 88246f23be043ffb60479ce2361e39c901f7cc73..87fd46aef50f64d7a176ef9489262fb422f867d1 100755 (executable)
@@ -462,6 +462,8 @@ class StreamDetails(DataClassDictMixin):
     data: Any = None
     # if the url/file is supported by ffmpeg directly, use direct stream
     direct: str | None = None
+    # bool to indicate that the providers 'get_audio_stream' supports seeking of the item
+    can_seek: bool = True
     # callback: optional callback function (or coroutine) to call when the stream completes.
     # needed for streaming provivders to report what is playing
     # receives the streamdetails as only argument from which to grab
index 814ec598bef6e02184211fd9e0956419ce072173..4ee853e031e281b5f35cdb72f423d39ed3f57d37 100644 (file)
@@ -397,11 +397,13 @@ async def get_media_stream(
         strip_silence_end = False
 
     # collect all arguments for ffmpeg
+    seek_pos = seek_position if (streamdetails.direct or not streamdetails.can_seek) else 0
     args = await _get_ffmpeg_args(
         streamdetails=streamdetails,
         sample_rate=sample_rate,
         bit_depth=bit_depth,
-        seek_position=seek_position,
+        # only use ffmpeg seeking if the provider stream does not support seeking
+        seek_position=seek_pos,
         fade_in=fade_in,
     )
 
@@ -412,7 +414,8 @@ async def get_media_stream(
             """Task that grabs the source audio and feeds it to ffmpeg."""
             LOGGER.debug("writer started for %s", streamdetails.uri)
             music_prov = mass.get_provider(streamdetails.provider)
-            async for audio_chunk in music_prov.get_audio_stream(streamdetails, seek_position):
+            seek_pos = seek_position if streamdetails.can_seek else 0
+            async for audio_chunk in music_prov.get_audio_stream(streamdetails, seek_pos):
                 await ffmpeg_proc.write(audio_chunk)
             # write eof when last packet is received
             ffmpeg_proc.write_eof()
@@ -745,6 +748,8 @@ async def _get_ffmpeg_args(
     ]
     # collect input args
     input_args = []
+    if seek_position:
+        input_args += ["-ss", str(seek_position)]
     if streamdetails.direct:
         # ffmpeg can access the inputfile (or url) directly
         if streamdetails.direct.startswith("http"):
@@ -766,8 +771,6 @@ async def _get_ffmpeg_args(
                     "5xx",
                 ]
 
-        if seek_position:
-            input_args += ["-ss", str(seek_position)]
         input_args += ["-i", streamdetails.direct]
     else:
         # the input is received from pipe/stdin
index 9814cb19f3ce1538a53869b8345689a82ce68893..4b8e74a4718ddedc8a7b027dc7762c6c9121b7df 100644 (file)
@@ -42,6 +42,7 @@ TRACK_EXTENSIONS = ("mp3", "m4a", "mp4", "flac", "wav", "ogg", "aiff", "wma", "d
 PLAYLIST_EXTENSIONS = ("m3u", "pls")
 SUPPORTED_EXTENSIONS = TRACK_EXTENSIONS + PLAYLIST_EXTENSIONS
 IMAGE_EXTENSIONS = ("jpg", "jpeg", "JPG", "JPEG", "png", "PNG", "gif", "GIF")
+SEEKABLE_FILES = (ContentType.MP3, ContentType.WAV, ContentType.FLAC)
 
 SUPPORTED_FEATURES = (
     ProviderFeature.LIBRARY_ARTISTS,
@@ -253,8 +254,9 @@ class FileSystemProviderBase(MusicProvider):
                 continue
 
             try:
-                cur_checksums[item.path] = item.checksum
+                # continue if the item did not change (checksum still the same)
                 if item.checksum == prev_checksums.get(item.path):
+                    cur_checksums[item.path] = item.checksum
                     continue
 
                 if item.ext in TRACK_EXTENSIONS:
@@ -275,6 +277,9 @@ class FileSystemProviderBase(MusicProvider):
             except Exception as err:  # pylint: disable=broad-except
                 # we don't want the whole sync to crash on one file so we catch all exceptions here
                 self.logger.exception("Error processing %s - %s", item.path, str(err))
+            else:
+                # save item's checksum only if the parse succeeded
+                cur_checksums[item.path] = item.checksum
 
             # save checksums every 100 processed items
             # this allows us to pickup where we leftoff when initial scan gets interrupted
@@ -624,6 +629,7 @@ class FileSystemProviderBase(MusicProvider):
             sample_rate=prov_mapping.sample_rate,
             bit_depth=prov_mapping.bit_depth,
             direct=file_item.local_path,
+            can_seek=prov_mapping.content_type in SEEKABLE_FILES,
         )
 
     async def get_audio_stream(
index 581f3fb13280a17755d0056578ed34dd6048cd08..aa77683b3514ffb30aa7a57f52dabae7a99ea1e8 100644 (file)
@@ -1,6 +1,5 @@
 """SMB filesystem provider for Music Assistant."""
 
-import contextvars
 import logging
 import os
 from collections.abc import AsyncGenerator
@@ -9,7 +8,9 @@ from contextlib import asynccontextmanager
 from smb.base import SharedFile
 
 from music_assistant.common.helpers.util import get_ip_from_host
-from music_assistant.constants import CONF_PASSWORD, CONF_PATH, CONF_USERNAME
+from music_assistant.common.models.errors import LoginFailed
+from music_assistant.constants import CONF_PASSWORD, CONF_USERNAME
+from music_assistant.server.controllers.cache import use_cache
 from music_assistant.server.providers.filesystem_local.base import (
     FileSystemItem,
     FileSystemProviderBase,
@@ -21,6 +22,10 @@ from music_assistant.server.providers.filesystem_local.helpers import (
 
 from .helpers import AsyncSMB
 
+CONF_HOST = "host"
+CONF_SHARE = "share"
+CONF_SUBFOLDER = "subfolder"
+
 
 async def create_item(file_path: str, entry: SharedFile, root_path: str) -> FileSystemItem:
     """Create FileSystemItem from smb.SharedFile."""
@@ -37,9 +42,6 @@ async def create_item(file_path: str, entry: SharedFile, root_path: str) -> File
     )
 
 
-smb_conn_ctx = contextvars.ContextVar("smb_conn_ctx", default=None)
-
-
 class SMBFileSystemProvider(FileSystemProviderBase):
     """Implementation of an SMB File System Provider."""
 
@@ -51,25 +53,31 @@ class SMBFileSystemProvider(FileSystemProviderBase):
     async def setup(self) -> None:
         """Handle async initialization of the provider."""
         # silence SMB.SMBConnection logger a bit
-        logging.getLogger("SMB.SMBConnection").setLevel("INFO")
-        # extract params from path
-        if self.config.get_value(CONF_PATH).startswith("\\\\"):
-            path_parts = self.config.get_value(CONF_PATH)[2:].split("\\", 2)
-        elif self.config.get_value(CONF_PATH).startswith("//"):
-            path_parts = self.config.get_value(CONF_PATH)[2:].split("/", 2)
-        elif self.config.get_value(CONF_PATH).startswith("smb://"):
-            path_parts = self.config.get_value(CONF_PATH)[6:].split("/", 2)
-        else:
-            path_parts = self.config.get_value(CONF_PATH).split(os.sep)
-        self._remote_name = path_parts[0]
-        self._service_name = path_parts[1]
-        if len(path_parts) > 2:
-            self._root_path = os.sep + path_parts[2]
-
-        default_target_ip = await get_ip_from_host(self._remote_name)
-        self._target_ip = self.config.get_value("target_ip") or default_target_ip
+        logging.getLogger("SMB.SMBConnection").setLevel("WARNING")
+
+        self._remote_name = self.config.get_value(CONF_HOST)
+        self._service_name = self.config.get_value(CONF_SHARE)
+
+        # validate provided path
+        subfolder: str = self.config.get_value(CONF_SUBFOLDER)
+        subfolder.replace("\\", "/")
+        if not subfolder.startswith("/"):
+            subfolder = "/" + subfolder
+        if not subfolder.endswith("/"):
+            subfolder += "/"
+        self._root_path = subfolder
+
+        # resolve dns name to IP
+        target_ip = await get_ip_from_host(self._remote_name)
+        if target_ip is None:
+            raise LoginFailed(
+                f"Unable to resolve {self._remote_name}, maybe use an IP address as remote host ?"
+            )
+        self._target_ip = target_ip
+
+        # test connection and return
+        # this code will raise if the connection did not succeed
         async with self._get_smb_connection():
-            # test connection and return
             return
 
     async def listdir(
@@ -93,21 +101,23 @@ class SMBFileSystemProvider(FileSystemProviderBase):
         abs_path = get_absolute_path(self._root_path, path)
         async with self._get_smb_connection() as smb_conn:
             path_result: list[SharedFile] = await smb_conn.list_path(abs_path)
-            for entry in path_result:
-                if entry.filename.startswith("."):
-                    # skip invalid/system files and dirs
-                    continue
-                file_path = os.path.join(path, entry.filename)
-                item = await create_item(file_path, entry, self._root_path)
-                if recursive and item.is_dir:
-                    # yield sublevel recursively
-                    try:
-                        async for subitem in self.listdir(file_path, True):
-                            yield subitem
-                    except (OSError, PermissionError) as err:
-                        self.logger.warning("Skip folder %s: %s", item.path, str(err))
-                elif item.is_file or item.is_dir:
-                    yield item
+
+        for entry in path_result:
+            if entry.filename.startswith("."):
+                # skip invalid/system files and dirs
+                continue
+            file_path = os.path.join(path, entry.filename)
+            item = await create_item(file_path, entry, self._root_path)
+            if recursive and item.is_dir:
+                # yield sublevel recursively
+                try:
+                    async for subitem in self.listdir(file_path, True):
+                        yield subitem
+                except (OSError, PermissionError) as err:
+                    self.logger.warning("Skip folder %s: %s", item.path, str(err))
+            else:
+                # yield single item (file or directory)
+                yield item
 
     async def resolve(self, file_path: str) -> FileSystemItem:
         """Resolve (absolute or relative) path to FileSystemItem."""
@@ -124,6 +134,7 @@ class SMBFileSystemProvider(FileSystemProviderBase):
                 file_size=entry.file_size,
             )
 
+    @use_cache(15 * 60)
     async def exists(self, file_path: str) -> bool:
         """Return bool if this FileSystem musicprovider has given file/dir."""
         abs_path = get_absolute_path(self._root_path, file_path)
@@ -147,11 +158,10 @@ class SMBFileSystemProvider(FileSystemProviderBase):
     @asynccontextmanager
     async def _get_smb_connection(self) -> AsyncGenerator[AsyncSMB, None]:
         """Get instance of AsyncSMB."""
-        # for a task that consists of multiple steps,
-        # the smb connection may be reused (shared through a contextvar)
-        if existing := smb_conn_ctx.get():
-            yield existing
-            return
+        # For now we just create a connection per call
+        # as that is the most reliable (but a bit slower)
+        # this could be improved by creating a connection pool
+        # if really needed
 
         async with AsyncSMB(
             remote_name=self._remote_name,
@@ -159,8 +169,8 @@ class SMBFileSystemProvider(FileSystemProviderBase):
             username=self.config.get_value(CONF_USERNAME),
             password=self.config.get_value(CONF_PASSWORD),
             target_ip=self._target_ip,
-            options={key: value.value for key, value in self.config.values.items()},
+            use_ntlm_v2=self.config.get_value("use_ntlm_v2"),
+            sign_options=self.config.get_value("sign_options"),
+            is_direct_tcp=self.config.get_value("is_direct_tcp"),
         ) as smb_conn:
-            token = smb_conn_ctx.set(smb_conn)
             yield smb_conn
-        smb_conn_ctx.reset(token)
index 8898bb5e4de555bc1a744533225a11693613985b..8baed531b719f82d21551050996c3a8bf2a915e3 100644 (file)
@@ -4,10 +4,8 @@ from __future__ import annotations
 import asyncio
 from collections.abc import AsyncGenerator
 from io import BytesIO
-from typing import Any
 
-from smb.base import SharedFile, SMBTimeout
-from smb.smb_structs import OperationFailure
+from smb.base import OperationFailure, SharedFile
 from smb.SMBConnection import SMBConnection
 
 from music_assistant.common.models.errors import LoginFailed
@@ -25,7 +23,9 @@ class AsyncSMB:
         username: str,
         password: str,
         target_ip: str,
-        options: dict[str, Any],
+        use_ntlm_v2: bool = True,
+        sign_options: int = 2,
+        is_direct_tcp: bool = False,
     ) -> None:
         """Initialize instance."""
         self._service_name = service_name
@@ -38,67 +38,77 @@ class AsyncSMB:
             password=self._password,
             my_name=SERVICE_NAME,
             remote_name=self._remote_name,
-            # choose sane default options but allow user to override them via the options dict
-            domain=options.get("domain", ""),
-            use_ntlm_v2=options.get("use_ntlm_v2", False),
-            sign_options=options.get("sign_options", 2),
-            is_direct_tcp=options.get("is_direct_tcp", False),
+            use_ntlm_v2=use_ntlm_v2,
+            sign_options=sign_options,
+            is_direct_tcp=is_direct_tcp,
         )
+        # the smb connection may not be used asynchronously and
+        # each operation should take sequentially.
+        # to support this, we use a Lock and we create a new.
+        self._lock = asyncio.Lock()
 
     async def list_path(self, path: str) -> list[SharedFile]:
         """Retrieve a directory listing of files/folders at *path*."""
-        return await asyncio.to_thread(self._conn.listPath, self._service_name, path)
+        async with self._lock:
+            return await asyncio.to_thread(self._conn.listPath, self._service_name, path)
 
     async def get_attributes(self, path: str) -> SharedFile:
         """Retrieve information about the file at *path* on the *service_name*."""
-        return await asyncio.to_thread(self._conn.getAttributes, self._service_name, path)
+        async with self._lock:
+            return await asyncio.to_thread(self._conn.getAttributes, self._service_name, path)
 
     async def retrieve_file(self, path: str, offset: int = 0) -> AsyncGenerator[bytes, None]:
         """Retrieve file contents."""
         chunk_size = 256000
         while True:
-            with BytesIO() as file_obj:
-                await asyncio.to_thread(
-                    self._conn.retrieveFileFromOffset,
-                    self._service_name,
-                    path,
-                    file_obj,
-                    offset,
-                    chunk_size,
-                )
-                file_obj.seek(0)
-                chunk = file_obj.read()
-                yield chunk
-                offset += len(chunk)
-                if len(chunk) < chunk_size:
-                    break
+            async with self._lock:
+                with BytesIO() as file_obj:
+                    await asyncio.to_thread(
+                        self._conn.retrieveFileFromOffset,
+                        self._service_name,
+                        path,
+                        file_obj,
+                        offset,
+                        chunk_size,
+                    )
+                    file_obj.seek(0)
+                    chunk = file_obj.read()
+                    yield chunk
+                    offset += len(chunk)
+                    if len(chunk) < chunk_size:
+                        break
 
     async def write_file(self, path: str, data: bytes) -> SharedFile:
         """Store the contents to the file at *path*."""
         with BytesIO() as file_obj:
             file_obj.write(data)
             file_obj.seek(0)
-            await asyncio.to_thread(
-                self._conn.storeFile,
-                self._service_name,
-                path,
-                file_obj,
-            )
+            async with self._lock:
+                await asyncio.to_thread(
+                    self._conn.storeFile,
+                    self._service_name,
+                    path,
+                    file_obj,
+                )
 
     async def path_exists(self, path: str) -> bool:
         """Return bool is this FileSystem musicprovider has given file/dir."""
-        try:
-            await asyncio.to_thread(self._conn.getAttributes, self._service_name, path)
-        except (OperationFailure, SMBTimeout):
-            return False
-        return True
+        async with self._lock:
+            try:
+                await asyncio.to_thread(self._conn.getAttributes, self._service_name, path)
+            except (OperationFailure,):
+                return False
+            except IndexError:
+                return False
+            return True
 
     async def connect(self) -> None:
         """Connect to the SMB server."""
-        try:
-            assert await asyncio.to_thread(self._conn.connect, self._target_ip) is True
-        except Exception as exc:
-            raise LoginFailed(f"SMB Connect failed to {self._remote_name}") from exc
+        async with self._lock:
+            try:
+                assert await asyncio.to_thread(self._conn.connect, self._target_ip) is True
+            except Exception as exc:
+                raise LoginFailed(f"SMB Connect failed to {self._remote_name}") from exc
 
     async def __aenter__(self) -> AsyncSMB:
         """Enter context manager."""
index ae7844d89309d0590bffeae1018211ba515bd19e..2ef2a4224244db568af6e376d95b98272dcf8768 100644 (file)
@@ -6,29 +6,39 @@
   "codeowners": ["@MarvinSchenkel", "@marcelveldt"],
   "config_entries": [
     {
-      "key": "path",
+      "key": "host",
       "type": "string",
-      "label": "Path",
-      "description": "Full SMB path to the files, e.g. \\\\server\\share\\folder or smb://server/share"
+      "label": "Remote host",
+      "description": "The hostname of the SMB/CIFS server to connect to. For example mynas.local. You may need to use the IP address instead of DNS name.",
+      "required": true
+    },
+    {
+      "key": "share",
+      "type": "string",
+      "label": "Share",
+      "description": "The name of the share/service you'd like to connect to on the remote host, For example 'media'.",
+      "required": true
+    },
+    {
+      "key": "subfolder",
+      "type": "string",
+      "label": "Subfolder",
+      "description": "[optional] Use if your music is stored in a sublevel of the share. E.g. 'music' or 'music/collection'.",
+      "default_value": "",
+      "required": false
     },
     {
       "key": "username",
       "type": "string",
-      "label": "Username"
+      "label": "Username",
+      "default_value": "anonymous",
+      "required": true
     },
     {
       "key": "password",
       "type": "secure_string",
       "label": "Password"
     },
-    {
-      "key": "target_ip",
-      "type": "string",
-      "label": "Target IP",
-      "description": "Use in case of DNS resolve issues. Connect to this IP instead of the DNS name.",
-      "advanced": true,
-      "required": false
-    },
     {
       "key": "domain",
       "type": "string",
@@ -43,7 +53,7 @@
       "type": "boolean",
       "label": "Use NTLM v2",
       "default_value": true,
-      "description": "Indicates whether pysmb should be NTLMv1 or NTLMv2 authentication algorithm for authentication. The choice of NTLMv1 and NTLMv2 is configured on the remote server, and there is no mechanism to auto-detect which algorithm has been configured. Hence, we can only “guess” or try both algorithms. On Sambda, Windows Vista and Windows 7, NTLMv2 is enabled by default. On Windows XP, we can use NTLMv1 before NTLMv2.",
+      "description": "Indicates whether NTLMv1 or NTLMv2 authentication algorithm should be used for authentication. The choice of NTLMv1 and NTLMv2 is configured on the remote server, and there is no mechanism to auto-detect which algorithm has been configured. Hence, we can only “guess” or try both algorithms. On Sambda, Windows Vista and Windows 7, NTLMv2 is enabled by default. On Windows XP, we can use NTLMv1 before NTLMv2.",
       "advanced": true,
       "required": false
     },
@@ -65,7 +75,7 @@
       "key": "is_direct_tcp",
       "type": "boolean",
       "label": "Use Direct TCP",
-      "default_value": true,
+      "default_value": false,
       "description": "Controls whether the NetBIOS over TCP/IP (is_direct_tcp=False) or the newer Direct hosting of SMB over TCP/IP (is_direct_tcp=True) will be used for the communication. The default parameter is False which will use NetBIOS over TCP/IP for wider compatibility (TCP port: 139).",
       "advanced": true,
       "required": false