Include a Snapserver in the snapcast provider (#1150)
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Mon, 18 Mar 2024 10:34:20 +0000 (11:34 +0100)
committerGitHub <noreply@github.com>
Mon, 18 Mar 2024 10:34:20 +0000 (11:34 +0100)
Include a snapserver by default

music_assistant/common/helpers/util.py
music_assistant/server/helpers/process.py
music_assistant/server/providers/snapcast/__init__.py

index 2dd0657ab01ccf53b2b9e93a79191e323b6aba45..9b7bd05597e70542169819df24cba5fe1ad48a4d 100644 (file)
@@ -151,11 +151,10 @@ async def get_ip():
     return await asyncio.to_thread(_get_ip)
 
 
-async def select_free_port(range_start: int, range_end: int) -> int:
-    """Automatically find available port within range."""
+async def is_port_in_use(port: int) -> bool:
+    """Check if port is in use."""
 
-    def is_port_in_use(port: int) -> bool:
-        """Check if port is in use."""
+    def _is_port_in_use() -> bool:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _sock:
             try:
                 _sock.bind(("0.0.0.0", port))
@@ -163,14 +162,16 @@ async def select_free_port(range_start: int, range_end: int) -> int:
                 return True
         return False
 
-    def _select_free_port():
-        for port in range(range_start, range_end):
-            if not is_port_in_use(port):
-                return port
-        msg = "No free port available"
-        raise OSError(msg)
+    return await asyncio.to_thread(_is_port_in_use)
+
 
-    return await asyncio.to_thread(_select_free_port)
+async def select_free_port(range_start: int, range_end: int) -> int:
+    """Automatically find available port within range."""
+    for port in range(range_start, range_end):
+        if not await is_port_in_use(port):
+            return port
+    msg = "No free port available"
+    raise OSError(msg)
 
 
 async def get_ip_from_host(dns_name: str) -> str | None:
index 8668b4992da55b9c1e99ce6b2f4cbcf8673d5b05..1e9c5272b5e70713d5aa9936300e5ccee0ca0a66 100644 (file)
@@ -133,11 +133,12 @@ class AsyncProcess:
         if self.returncode is not None:
             return self.returncode
         # make sure the process is cleaned up
+        self._proc.terminate()
         try:
             async with asyncio.timeout(10):
                 await self.communicate()
         except (TimeoutError, asyncio.CancelledError):
-            self._proc.terminate()
+            self._proc.kill()
         return await self.wait()
 
     async def wait(self) -> int:
@@ -151,14 +152,10 @@ class AsyncProcess:
         stdout, stderr = await self._proc.communicate(input_data)
         return (stdout, stderr)
 
-    async def read_stderr(self, n: int = -1) -> bytes:
-        """Read up to n bytes from the stderr stream.
-
-        If n is positive, this function try to read n bytes,
-        and may return less or equal bytes than requested, but at least one byte.
-        If EOF was received before any byte is read, this function returns empty byte object.
-        """
-        return await self._proc.stderr.read(n)
+    async def read_stderr(self) -> AsyncGenerator[bytes, None]:
+        """Read lines from the stderr stream."""
+        async for line in self._proc.stderr:
+            yield line
 
 
 async def check_output(shell_cmd: str) -> tuple[int, bytes]:
index 11d352270d726dc9fc25da90cad51e0a1bfb167c..7a29bf7950b9c22264c351b8cb865c60f37e7ea5 100644 (file)
@@ -30,6 +30,7 @@ from music_assistant.common.models.errors import SetupFailedError
 from music_assistant.common.models.media_items import AudioFormat
 from music_assistant.common.models.player import DeviceInfo, Player
 from music_assistant.server.helpers.audio import get_media_stream
+from music_assistant.server.helpers.process import AsyncProcess, check_output
 from music_assistant.server.models.player_provider import PlayerProvider
 from music_assistant.server.providers.ugp import UGP_PREFIX
 
@@ -44,14 +45,16 @@ if TYPE_CHECKING:
     from music_assistant.server import MusicAssistant
     from music_assistant.server.models import ProviderInstanceType
 
-CONF_SNAPCAST_SERVER_HOST = "snapcast_server_host"
-CONF_SNAPCAST_SERVER_CONTROL_PORT = "snapcast_server_control_port"
+CONF_SERVER_HOST = "snapcast_server_host"
+CONF_SERVER_CONTROL_PORT = "snapcast_server_control_port"
+CONF_USE_EXTERNAL_SERVER = "snapcast_use_external_server"
 
 SNAP_STREAM_STATUS_MAP = {
     "idle": PlayerState.IDLE,
     "playing": PlayerState.PLAYING,
     "unknown": PlayerState.IDLE,
 }
+DEFAULT_SNAPSERVER_PORT = 1705
 
 
 async def setup(
@@ -64,10 +67,10 @@ async def setup(
 
 
 async def get_config_entries(
-    mass: MusicAssistant,
-    instance_id: str | None = None,
-    action: str | None = None,
-    values: dict[str, ConfigValueType] | None = None,
+    mass: MusicAssistant,  # noqa: ARG001
+    instance_id: str | None = None,  # noqa: ARG001
+    action: str | None = None,  # noqa: ARG001
+    values: dict[str, ConfigValueType] | None = None,  # noqa: ARG001
 ) -> tuple[ConfigEntry, ...]:
     """
     Return Config entries to setup this provider.
@@ -76,21 +79,37 @@ async def get_config_entries(
     action: [optional] action key called from config entries UI.
     values: the (intermediate) raw values for config entries sent with the action.
     """
-    # ruff: noqa: ARG001
+    returncode, output = await check_output("snapserver -v")
+    snapserver_present = returncode == 0 and "snapserver v0.27.0" in output.decode()
     return (
         ConfigEntry(
-            key=CONF_SNAPCAST_SERVER_HOST,
+            key=CONF_USE_EXTERNAL_SERVER,
+            type=ConfigEntryType.BOOLEAN,
+            default_value=not snapserver_present,
+            label="Use existing Snapserver",
+            required=False,
+            description="Music Assistant by default already includes a Snapserver. \n\n"
+            "Checking this option allows you to connect to your own/external existing Snapserver "
+            "and not use the builtin one provided by Music Assistant.",
+            advanced=snapserver_present,
+        ),
+        ConfigEntry(
+            key=CONF_SERVER_HOST,
             type=ConfigEntryType.STRING,
             default_value="127.0.0.1",
             label="Snapcast server ip",
-            required=True,
+            required=False,
+            depends_on=CONF_USE_EXTERNAL_SERVER,
+            advanced=snapserver_present,
         ),
         ConfigEntry(
-            key=CONF_SNAPCAST_SERVER_CONTROL_PORT,
+            key=CONF_SERVER_CONTROL_PORT,
             type=ConfigEntryType.INTEGER,
-            default_value="1705",
+            default_value=DEFAULT_SNAPSERVER_PORT,
             label="Snapcast control port",
-            required=True,
+            required=False,
+            depends_on=CONF_USE_EXTERNAL_SERVER,
+            advanced=snapserver_present,
         ),
     )
 
@@ -99,9 +118,12 @@ class SnapCastProvider(PlayerProvider):
     """Player provider for Snapcast based players."""
 
     _snapserver: Snapserver
-    snapcast_server_host: str
-    snapcast_server_control_port: int
+    _snapcast_server_host: str
+    _snapcast_server_control_port: int
     _stream_tasks: dict[str, asyncio.Task]
+    _use_builtin_server: bool
+    _snapserver_runner: asyncio.Task | None
+    _snapserver_started = asyncio.Event | None
 
     @property
     def supported_features(self) -> tuple[ProviderFeature, ...]:
@@ -110,20 +132,29 @@ class SnapCastProvider(PlayerProvider):
 
     async def handle_async_init(self) -> None:
         """Handle async initialization of the provider."""
-        self.snapcast_server_host = self.config.get_value(CONF_SNAPCAST_SERVER_HOST)
-        self.snapcast_server_control_port = self.config.get_value(CONF_SNAPCAST_SERVER_CONTROL_PORT)
+        self._snapcast_server_host = self.config.get_value(CONF_SERVER_HOST)
+        self._snapcast_server_control_port = self.config.get_value(CONF_SERVER_CONTROL_PORT)
+        self._use_builtin_server = not self.config.get_value(CONF_USE_EXTERNAL_SERVER)
         self._stream_tasks = {}
+        if self._use_builtin_server:
+            # start our own builtin snapserver
+            self._snapserver_started = asyncio.Event()
+            self._snapserver_runner = asyncio.create_task(self._builtin_server_runner())
+            await asyncio.wait_for(self._snapserver_started.wait(), 10)
+        else:
+            self._snapserver_runner = None
+            self._snapserver_started = None
         try:
             self._snapserver = await create_server(
                 self.mass.loop,
-                self.snapcast_server_host,
-                port=self.snapcast_server_control_port,
+                self._snapcast_server_host,
+                port=self._snapcast_server_control_port,
                 reconnect=True,
             )
             self._snapserver.set_on_update_callback(self._handle_update)
             self.logger.info(
-                f"Started Snapserver connection on:"
-                f"{self.snapcast_server_host}:{self.snapcast_server_control_port}"
+                "Started connection to Snapserver %s",
+                f"{self._snapcast_server_host}:{self._snapcast_server_control_port}",
             )
         except OSError as err:
             msg = "Unable to start the Snapserver connection ?"
@@ -139,6 +170,10 @@ class SnapCastProvider(PlayerProvider):
         for client in self._snapserver.clients:
             await self.cmd_stop(client.identifier)
         await self._snapserver.stop()
+        self._snapserver_started.clear()
+        if self._snapserver_runner and not self._snapserver_runner.done():
+            self._snapserver_runner.cancel()
+        await asyncio.sleep(2)  # prevent race conditions when reloading
 
     def _handle_update(self) -> None:
         """Process Snapcast init Player/Group and set callback ."""
@@ -287,7 +322,7 @@ class SnapCastProvider(PlayerProvider):
             )
 
         async def _streamer() -> None:
-            host = self.snapcast_server_host
+            host = self._snapcast_server_host
             _, writer = await asyncio.open_connection(host, port)
             self.logger.debug("Opened connection to %s:%s", host, port)
             player.current_item_id = f"{queue_item.queue_id}.{queue_item.queue_item_id}"
@@ -374,3 +409,20 @@ class SnapCastProvider(PlayerProvider):
             player = self.mass.players.get(child_player_id)
             player.state = state
             self.mass.players.update(child_player_id)
+
+    async def _builtin_server_runner(self) -> None:
+        """Start running the builtin snapserver."""
+        if self._snapserver_started.is_set():
+            raise RuntimeError("Snapserver is already started!")
+        logger = self.logger.getChild("snapserver")
+        logger.info("Starting builtin Snapserver...")
+        async with AsyncProcess(
+            ["snapserver"], enable_stdin=False, enable_stdout=True, enable_stderr=False
+        ) as snapserver_proc:
+            # keep reading from stderr until exit
+            async for data in snapserver_proc.iter_any():
+                data = data.decode().strip()  # noqa: PLW2901
+                for line in data.split("\n"):
+                    logger.debug(line)
+                    if "Name now registered and active" in line:
+                        self._snapserver_started.set()