Change SMB Provider to use OS-level mounts (#603)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Sat, 1 Apr 2023 21:19:45 +0000 (23:19 +0200)
committerGitHub <noreply@github.com>
Sat, 1 Apr 2023 21:19:45 +0000 (23:19 +0200)
wrap os-level mount commands for the smb provider instead of native python

---------

Co-authored-by: Marvin Schenkel <marvinschenkel@gmail.com>
Dockerfile
docker-compose.example.yml
music_assistant/server/providers/filesystem_local/__init__.py
music_assistant/server/providers/filesystem_smb/__init__.py
music_assistant/server/providers/filesystem_smb/manifest.json
requirements_all.txt

index 085a59ec287fcd5c38585b3082c94df2cddd538b..96680d9ca508537603524ee076d2cc259fea582b 100644 (file)
@@ -56,6 +56,7 @@ RUN set -x \
         libsox-fmt-all \
         libsox3 \
         sox \
+        cifs-utils \
     # cleanup
     && rm -rf /tmp/* \
     && rm -rf /var/lib/apt/lists/*
index 1b30a659ae11592fb0da670d75191e8fb5a99fa6..cc1ee13634df9c7153b27f0fa27f48c77d406e0b 100644 (file)
@@ -11,3 +11,8 @@ services:
     network_mode: host
     volumes:
       - ${USERDIR:-$HOME}/docker/music-assistant-server/data:/data/
+    # privileged caps needed to mount smb folders within the container
+    cap_add:
+      - SYS_ADMIN
+      - DAC_READ_SEARCH
+    privileged: true
index 3cb695cac45c40d735008ed237562a8c83f1cefe..e65896e5d73f907d340df45cfa5cb5e331b1e789 100644 (file)
@@ -34,12 +34,16 @@ listdir = wrap(os.listdir)
 isdir = wrap(os.path.isdir)
 isfile = wrap(os.path.isfile)
 exists = wrap(os.path.exists)
+makedirs = wrap(os.makedirs)
 
 
 async def setup(
     mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig
 ) -> ProviderInstanceType:
     """Initialize provider(instance) with given configuration."""
+    conf_path = config.get_value(CONF_PATH)
+    if not await isdir(conf_path):
+        raise SetupFailedError(f"Music Directory {conf_path} does not exist")
     prov = LocalFileSystemProvider(mass, manifest, config)
     await prov.handle_setup()
     return prov
@@ -80,11 +84,11 @@ async def create_item(base_path: str, entry: os.DirEntry) -> FileSystemItem:
 class LocalFileSystemProvider(FileSystemProviderBase):
     """Implementation of a musicprovider for local files."""
 
+    base_path: str
+
     async def handle_setup(self) -> None:
         """Handle async initialization of the provider."""
-        conf_path = self.config.get_value(CONF_PATH)
-        if not await isdir(conf_path):
-            raise SetupFailedError(f"Music Directory {conf_path} does not exist")
+        self.base_path = self.config.get_value(CONF_PATH)
 
     async def listdir(
         self, path: str, recursive: bool = False
@@ -102,14 +106,15 @@ class LocalFileSystemProvider(FileSystemProviderBase):
             AsyncGenerator yielding FileSystemItem objects.
 
         """
-        abs_path = get_absolute_path(self.config.get_value(CONF_PATH), path)
-        self.logger.debug("Processing: %s", abs_path)
+        abs_path = get_absolute_path(self.base_path, path)
+        rel_path = get_relative_path(self.base_path, path)
+        self.logger.debug("Processing: %s", rel_path)
         entries = await asyncio.to_thread(os.scandir, abs_path)
         for entry in entries:
             if entry.name.startswith(".") or any(x in entry.name for x in IGNORE_DIRS):
                 # skip invalid/system files and dirs
                 continue
-            item = await create_item(self.config.get_value(CONF_PATH), entry)
+            item = await create_item(self.base_path, entry)
             if recursive and item.is_dir:
                 try:
                     async for subitem in self.listdir(item.absolute_path, True):
@@ -127,13 +132,13 @@ class LocalFileSystemProvider(FileSystemProviderBase):
         If require_local is True, we prefer to have the `local_path` attribute filled
         (e.g. with a tempfile), if supported by the provider/item.
         """
-        absolute_path = get_absolute_path(self.config.get_value(CONF_PATH), file_path)
+        absolute_path = get_absolute_path(self.base_path, file_path)
 
         def _create_item():
             stat = os.stat(absolute_path, follow_symlinks=False)
             return FileSystemItem(
                 name=os.path.basename(file_path),
-                path=get_relative_path(self.config.get_value(CONF_PATH), file_path),
+                path=get_relative_path(self.base_path, file_path),
                 absolute_path=absolute_path,
                 is_dir=os.path.isdir(absolute_path),
                 is_file=os.path.isfile(absolute_path),
@@ -150,12 +155,12 @@ class LocalFileSystemProvider(FileSystemProviderBase):
         """Return bool is this FileSystem musicprovider has given file/dir."""
         if not file_path:
             return False  # guard
-        abs_path = get_absolute_path(self.config.get_value(CONF_PATH), file_path)
+        abs_path = get_absolute_path(self.base_path, file_path)
         return await exists(abs_path)
 
     async def read_file_content(self, file_path: str, seek: int = 0) -> AsyncGenerator[bytes, None]:
         """Yield (binary) contents of file in chunks of bytes."""
-        abs_path = get_absolute_path(self.config.get_value(CONF_PATH), file_path)
+        abs_path = get_absolute_path(self.base_path, file_path)
         chunk_size = 512000
         async with aiofiles.open(abs_path, "rb") as _file:
             if seek:
@@ -169,6 +174,6 @@ class LocalFileSystemProvider(FileSystemProviderBase):
 
     async def write_file_content(self, file_path: str, data: bytes) -> None:
         """Write entire file content as bytes (e.g. for playlists)."""
-        abs_path = get_absolute_path(self.config.get_value(CONF_PATH), file_path)
+        abs_path = get_absolute_path(self.base_path, file_path)
         async with aiofiles.open(abs_path, "wb") as _file:
             await _file.write(data)
index 49297b6a01e9f14e6dc90c2b994a57151c8f9a3a..2498fc4755595283fc343bec1e0864d5684e2156 100644 (file)
@@ -2,31 +2,19 @@
 from __future__ import annotations
 
 import asyncio
-import logging
-import os
-from collections.abc import AsyncGenerator
-from contextlib import suppress
-from os.path import basename
+import platform
 from typing import TYPE_CHECKING
 
-import smbclient
-from smbclient import path as smbpath
-
-from music_assistant.common.helpers.util import empty_queue, get_ip_from_host
+from music_assistant.common.helpers.util import get_ip_from_host
 from music_assistant.common.models.config_entries import ConfigEntry
 from music_assistant.common.models.enums import ConfigEntryType
 from music_assistant.common.models.errors import LoginFailed
 from music_assistant.constants import CONF_PASSWORD, CONF_USERNAME
-from music_assistant.server.controllers.cache import use_cache
-from music_assistant.server.providers.filesystem_local.base import (
+from music_assistant.server.providers.filesystem_local import (
     CONF_ENTRY_MISSING_ALBUM_ARTIST,
-    IGNORE_DIRS,
-    FileSystemItem,
-    FileSystemProviderBase,
-)
-from music_assistant.server.providers.filesystem_local.helpers import (
-    get_absolute_path,
-    get_relative_path,
+    LocalFileSystemProvider,
+    exists,
+    makedirs,
 )
 
 if TYPE_CHECKING:
@@ -38,20 +26,21 @@ if TYPE_CHECKING:
 CONF_HOST = "host"
 CONF_SHARE = "share"
 CONF_SUBFOLDER = "subfolder"
-CONF_CONN_LIMIT = "connection_limit"
+CONF_MOUNT_OPTIONS = "mount_options"
 
 
 async def setup(
     mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig
 ) -> ProviderInstanceType:
     """Initialize provider(instance) with given configuration."""
-    # silence logging a bit on smbprotocol
-    logging.getLogger("smbprotocol").setLevel("WARNING")
-    logging.getLogger("smbclient").setLevel("INFO")
-    # check if valid dns name is given
+    # check if valid dns name is given for the host
     server: str = config.get_value(CONF_HOST)
     if not await get_ip_from_host(server):
         raise LoginFailed(f"Unable to resolve {server}, make sure the address is resolveable.")
+    # check if share is valid
+    share: str = config.get_value(CONF_SHARE)
+    if not share or "/" in share or "\\" in share:
+        raise LoginFailed("Invalid share name")
     prov = SMBFileSystemProvider(mass, manifest, config)
     await prov.handle_setup()
     return prov
@@ -105,195 +94,101 @@ async def get_config_entries(
             description="[optional] Use if your music is stored in a sublevel of the share. "
             "E.g. 'collections' or 'albums/A-K'.",
         ),
+        ConfigEntry(
+            key=CONF_MOUNT_OPTIONS,
+            type=ConfigEntryType.STRING,
+            label="Mount options",
+            required=False,
+            advanced=True,
+            default_value="file_mode=0775,dir_mode=0775,uid=0,gid=0",
+            description="[optional] Any additional mount options you "
+            "want to pass to the mount command if needed for your particular setup.",
+        ),
         CONF_ENTRY_MISSING_ALBUM_ARTIST,
     )
 
 
-async def create_item(base_path: str, entry: smbclient.SMBDirEntry) -> FileSystemItem:
-    """Create FileSystemItem from smbclient.SMBDirEntry."""
-
-    def _create_item():
-        entry_path = entry.path.replace("/\\", os.sep).replace("\\", os.sep)
-        absolute_path = get_absolute_path(base_path, entry_path)
-        stat = entry.stat(follow_symlinks=False)
-        return FileSystemItem(
-            name=entry.name,
-            path=get_relative_path(base_path, entry_path),
-            absolute_path=absolute_path,
-            is_file=entry.is_file(follow_symlinks=False),
-            is_dir=entry.is_dir(follow_symlinks=False),
-            checksum=str(int(stat.st_mtime)),
-            file_size=stat.st_size,
-        )
-
-    # run in thread because strictly taken this may be blocking IO
-    return await asyncio.to_thread(_create_item)
-
+class SMBFileSystemProvider(LocalFileSystemProvider):
+    """
+    Implementation of an SMB File System Provider.
 
-class SMBFileSystemProvider(FileSystemProviderBase):
-    """Implementation of an SMB File System Provider."""
+    Basically this is just a wrapper around the regular local files provider,
+    except for the fact that it will mount a remote folder to a temporary location.
+    We went for this OS-depdendent approach because there is no solid async-compatible
+    smb library for Python (and we tried both pysmb and smbprotocol).
+    """
 
     async def handle_setup(self) -> None:
         """Handle async initialization of the provider."""
-        server: str = self.config.get_value(CONF_HOST)
-        share: str = self.config.get_value(CONF_SHARE)
-        subfolder: str = self.config.get_value(CONF_SUBFOLDER)
-
-        # create windows like path (\\server\share\subfolder)
-        if subfolder.endswith(os.sep):
-            subfolder = subfolder[:-1]
-        subfolder = subfolder.replace("\\", os.sep).replace("/", os.sep)
-        self._root_path = f"{os.sep}{os.sep}{server}{os.sep}{share}{os.sep}{subfolder}"
-        self.logger.debug("Using root path: %s", self._root_path)
+        # base_path will be the path where we're going to mount the remote share
+        self.base_path = f"/tmp/{self.instance_id}"
+        if not await exists(self.base_path):
+            await makedirs(self.base_path)
 
-        # register smb session
-        self.logger.info("Connecting to server %s", server)
         try:
-            self._session = await asyncio.to_thread(
-                smbclient.register_session,
-                server,
-                username=self.config.get_value(CONF_USERNAME),
-                password=self.config.get_value(CONF_PASSWORD),
-            )
-            # validate provided path
-            if not await asyncio.to_thread(smbpath.isdir, self._root_path):
-                raise LoginFailed(f"Invalid subfolder given: {subfolder}")
+            await self.mount()
         except Exception as err:
-            if "Unable to negotiate " in str(err):
-                detail = "Invalid credentials"
-            elif "refused " in str(err):
-                detail = "Invalid hostname (or host not reachable)"
-            elif "STATUS_NOT_FOUND" in str(err):
-                detail = "Share does not exist"
-            elif "Invalid argument" in str(err) and "." not in server:
-                detail = "Make sure to enter a FQDN hostname or IP-address"
-            else:
-                detail = str(err)
-            raise LoginFailed(f"Connection failed for the given details: {detail}") from err
-
-    async def listdir(
-        self, path: str, recursive: bool = False
-    ) -> AsyncGenerator[FileSystemItem, None]:
-        """List contents of a given provider directory/path.
-
-        Parameters
-        ----------
-        - path: path of the directory (relative or absolute) to list contents of.
-            Empty string for provider's root.
-        - recursive: If True will recursively keep unwrapping subdirectories (scandir equivalent).
-
-        Returns:
-        -------
-            AsyncGenerator yielding FileSystemItem objects.
+            raise LoginFailed(f"Connection failed for the given details: {err}") from err
 
+    async def unload(self) -> None:
         """
-        abs_path = get_absolute_path(self._root_path, path)
-        self.logger.debug("Processing: %s", abs_path)
-        entries = await asyncio.to_thread(smbclient.scandir, abs_path)
-        for entry in entries:
-            if entry.name.startswith(".") or any(x in entry.name for x in IGNORE_DIRS):
-                # skip invalid/system files and dirs
-                continue
-            item = await create_item(self._root_path, entry)
-            if recursive and item.is_dir:
-                async for subitem in self.listdir(item.absolute_path, True):
-                    yield subitem
-            else:
-                yield item
-
-    async def resolve(
-        self, file_path: str, require_local: bool = False  # noqa: ARG002
-    ) -> FileSystemItem:
-        """Resolve (absolute or relative) path to FileSystemItem.
+        Handle unload/close of the provider.
 
-        If require_local is True, we prefer to have the `local_path` attribute filled
-        (e.g. with a tempfile), if supported by the provider/item.
+        Called when provider is deregistered (e.g. MA exiting or config reloading).
         """
-        file_path = file_path.replace("\\", os.sep)
-        absolute_path = get_absolute_path(self._root_path, file_path)
+        await self.unmount()
 
-        def _create_item():
-            stat = smbclient.stat(absolute_path, follow_symlinks=False)
-            return FileSystemItem(
-                name=basename(file_path),
-                path=get_relative_path(self._root_path, file_path),
-                absolute_path=absolute_path,
-                is_dir=smbpath.isdir(absolute_path),
-                is_file=smbpath.isfile(absolute_path),
-                checksum=str(int(stat.st_mtime)),
-                file_size=stat.st_size,
-            )
-
-        # run in thread because strictly taken this may be blocking IO
-        return await asyncio.to_thread(_create_item)
-
-    @use_cache(120)
-    async def exists(self, file_path: str) -> bool:
-        """Return bool is this FileSystem musicprovider has given file/dir."""
-        if not file_path:
-            return False  # guard
-        file_path = file_path.replace("\\", os.sep)
-        abs_path = get_absolute_path(self._root_path, file_path)
-        try:
-            return await asyncio.to_thread(smbpath.exists, abs_path)
-        except Exception as err:
-            if "STATUS_OBJECT_NAME_INVALID" in str(err):
-                return False
-            raise err
-
-    async def read_file_content(self, file_path: str, seek: int = 0) -> AsyncGenerator[bytes, None]:
-        """Yield (binary) contents of file in chunks of bytes."""
-        file_path = file_path.replace("\\", os.sep)
-        absolute_path = get_absolute_path(self._root_path, file_path)
-
-        queue = asyncio.Queue(1)
-
-        def _reader():
-            self.logger.debug("Reading file contents for %s", absolute_path)
-            try:
-                chunk_size = 64000
-                bytes_sent = 0
-                with smbclient.open_file(
-                    absolute_path, "rb", buffering=chunk_size, share_access="r"
-                ) as _file:
-                    if seek:
-                        _file.seek(seek)
-                    while True:
-                        chunk = _file.read(chunk_size)
-                        if not chunk:
-                            return
-                        asyncio.run_coroutine_threadsafe(queue.put(chunk), self.mass.loop).result()
-                        bytes_sent += len(chunk)
-            finally:
-                asyncio.run_coroutine_threadsafe(queue.put(b""), self.mass.loop).result()
-                self.logger.debug(
-                    "Finished Reading file contents for %s - bytes transferred: %s",
-                    absolute_path,
-                    bytes_sent,
-                )
-
-        try:
-            task = self.mass.create_task(_reader)
-
-            while True:
-                chunk = await queue.get()
-                if not chunk:
-                    break
-                yield chunk
-        finally:
-            empty_queue(queue)
-            if task and not task.done():
-                task.cancel()
-                with suppress(asyncio.CancelledError):
-                    await task
-
-    async def write_file_content(self, file_path: str, data: bytes) -> None:
-        """Write entire file content as bytes (e.g. for playlists)."""
-        file_path = file_path.replace("\\", os.sep)
-        abs_path = get_absolute_path(self._root_path, file_path)
-
-        def _writer():
-            with smbclient.open_file(abs_path, "wb") as _file:
-                _file.write(data)
+    async def mount(self) -> None:
+        """Mount the SMB location to a temporary folder."""
+        server: str = self.config.get_value(CONF_HOST)
+        username: str = self.config.get_value(CONF_USERNAME)
+        password: str = self.config.get_value(CONF_PASSWORD)
+        share: str = self.config.get_value(CONF_SHARE)
 
-        await asyncio.to_thread(_writer)
+        # handle optional subfolder
+        subfolder: str = self.config.get_value(CONF_SUBFOLDER)
+        if subfolder:
+            subfolder = subfolder.replace("\\", "/")
+            if not subfolder.startswith("/"):
+                subfolder = "/" + subfolder
+            if subfolder.endswith("/"):
+                subfolder = subfolder[:-1]
+
+        if platform.system() == "Darwin":
+            password_str = f":{password}" if password else ""
+            mount_cmd = f"mount -t smbfs //{username}{password_str}@{server}/{share}{subfolder} {self.base_path}"  # noqa: E501
+
+        elif platform.system() == "Linux":
+            options = [
+                "rw",
+                f'username="{username}"',
+            ]
+            if password:
+                options.append(f'password="{password}"')
+            if mount_options := self.config.get_value(CONF_MOUNT_OPTIONS):
+                options += mount_options.split(",")
+            mount_cmd = f"mount -t cifs -o {','.join(options)} //{server}/{share}{subfolder} {self.base_path}"  # noqa: E501
+
+        else:
+            raise LoginFailed(f"SMB provider is not supported on {platform.system()}")
+
+        self.logger.info("Mounting //%s/%s%s to %s", server, share, subfolder, self.base_path)
+        self.logger.debug("Using mount command: %s", mount_cmd.replace(password, "########"))
+
+        proc = await asyncio.create_subprocess_shell(
+            mount_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
+        )
+        _, stderr = await proc.communicate()
+        if proc.returncode != 0:
+            raise LoginFailed(f"SMB mount failed with error: {stderr.decode()}")
+
+    async def unmount(self) -> None:
+        """Unmount the remote share."""
+        proc = await asyncio.create_subprocess_shell(
+            f"umount {self.base_path}",
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.PIPE,
+        )
+        _, stderr = await proc.communicate()
+        if proc.returncode != 0:
+            self.logger.warning("SMB unmount failed with error: %s", stderr.decode())
index 29c068281085d372b3038d745c5abd510f73499c..4566279e4e611b3f05e55b56d0c34e92839d3b45 100644 (file)
@@ -1,11 +1,11 @@
 {
-  "type": "music",
-  "domain": "filesystem_smb",
-  "name": "Filesystem (remote share)",
-  "description": "Support for music files that are present on remote SMB/CIFS or DFS share.",
-  "codeowners": ["@music-assistant"],
-  "requirements": ["smbprotocol==1.10.1"],
-  "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/820",
-  "multi_instance": true,
-  "icon": "mdi:mdi-network"
-}
+    "type": "music",
+    "domain": "filesystem_smb",
+    "name": "Filesystem (remote share)",
+    "description": "Support for music files that are present on remote SMB/CIFS.",
+    "codeowners": ["@music-assistant"],
+    "requirements": [],
+    "documentation": "https://github.com/music-assistant/hass-music-assistant/discussions/820",
+    "multi_instance": true,
+    "icon": "mdi:mdi-network"
+  }
index 552a9c1e76b57d4ded32869b7c27adfabedebadb..e76cacbd615ebef3e09a42a949938af9702f7a0c 100644 (file)
@@ -20,7 +20,6 @@ plexapi==4.13.2
 PyChromecast==13.0.6
 python-slugify==8.0.1
 shortuuid==1.0.11
-smbprotocol==1.10.1
 soco==0.29.1
 unidecode==1.3.6
 xmltodict==0.13.0