lazy load provider requirements
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Fri, 22 Mar 2024 10:50:04 +0000 (11:50 +0100)
committerMarcel van der Veldt <m.vanderveldt@outlook.com>
Fri, 22 Mar 2024 10:50:04 +0000 (11:50 +0100)
music_assistant/server/controllers/config.py
music_assistant/server/helpers/util.py
music_assistant/server/server.py

index b06a949339da0b88bd8a2f43485b425e3af98cfb..5f6aa9137eeea98c35db54b06464068018266366 100644 (file)
@@ -37,7 +37,7 @@ from music_assistant.constants import (
     ENCRYPT_SUFFIX,
 )
 from music_assistant.server.helpers.api import api_command
-from music_assistant.server.helpers.util import get_provider_module
+from music_assistant.server.helpers.util import load_provider_module
 from music_assistant.server.models.player_provider import PlayerProvider
 
 if TYPE_CHECKING:
@@ -231,7 +231,7 @@ class ConfigController:
         # lookup provider manifest and module
         for prov in self.mass.get_provider_manifests():
             if prov.domain == provider_domain:
-                prov_mod = await get_provider_module(provider_domain)
+                prov_mod = await load_provider_module(provider_domain, prov.requirements)
                 break
         else:
             msg = f"Unknown provider domain: {provider_domain}"
index 2c8c6b9c9453c787a1ba17112640fcb8abc076aa..2236936ce532ff32d83b797e1ddee08a395841b8 100644 (file)
@@ -30,6 +30,7 @@ HA_WHEELS = "https://wheels.home-assistant.io/musllinux/"
 
 async def install_package(package: str) -> None:
     """Install package with pip, raise when install failed."""
+    LOGGER.debug("Installing python package %s", package)
     cmd = f"python3 -m pip install --find-links {HA_WHEELS} {package}"
     proc = await asyncio.create_subprocess_shell(
         cmd, stderr=asyncio.subprocess.STDOUT, stdout=asyncio.subprocess.PIPE
@@ -91,13 +92,32 @@ async def is_hass_supervisor() -> bool:
     return await asyncio.to_thread(_check)
 
 
-async def get_provider_module(domain: str) -> ProviderModuleType:
-    """Return module for given provider domain."""
+async def load_provider_module(domain: str, requirements: list[str]) -> ProviderModuleType:
+    """Return module for given provider domain and make sure the requirements are met."""
 
     @lru_cache
     def _get_provider_module(domain: str) -> ProviderModuleType:
         return importlib.import_module(f".{domain}", "music_assistant.server.providers")
 
+    # ensure module requirements are met
+    for requirement in requirements:
+        if "==" not in requirement:
+            # we should really get rid of unpinned requirements
+            continue
+        package_name, version = requirement.split("==", 1)
+        installed_version = await get_package_version(package_name)
+        if installed_version != version:
+            await install_package(requirement)
+
+    # try to load the module
+    try:
+        return await asyncio.to_thread(_get_provider_module, domain)
+    except ImportError:
+        # (re)install ALL requirements
+        for requirement in requirements:
+            await install_package(requirement)
+    # try loading the provider again to be safe
+    # this will fail if something else is wrong (as it should)
     return await asyncio.to_thread(_get_provider_module, domain)
 
 
index 645e3abbd37cd4b4fea5733e7c37c06cca8c6f2f..d087b22e2124c21fae804335ae0af39c2b7b98d7 100644 (file)
@@ -43,9 +43,8 @@ from music_assistant.server.helpers.api import APICommandHandler, api_command
 from music_assistant.server.helpers.images import get_icon_string
 from music_assistant.server.helpers.util import (
     get_package_version,
-    get_provider_module,
-    install_package,
     is_hass_supervisor,
+    load_provider_module,
 )
 
 from .models import ProviderInstanceType
@@ -432,7 +431,7 @@ class MusicAssistant:
                 raise SetupFailedError(msg)
 
         # try to setup the module
-        prov_mod = await get_provider_module(domain)
+        prov_mod = await load_provider_module(domain, prov_manifest.requirements)
         try:
             async with asyncio.timeout(30):
                 provider = await prov_mod.setup(self, prov_manifest, conf)
@@ -556,15 +555,6 @@ class MusicAssistant:
                         icon_path = os.path.join(provider_path, "icon_dark.svg")
                         if os.path.isfile(icon_path):
                             provider_manifest.icon_svg_dark = await get_icon_string(icon_path)
-                    # try to load the module
-                    try:
-                        await get_provider_module(provider_manifest.domain)
-                    except ImportError:
-                        # install requirements
-                        for requirement in provider_manifest.requirements:
-                            await install_package(requirement)
-                        # try loading the provider again to be safe
-                        await get_provider_module(provider_manifest.domain)
                     self._provider_manifests[provider_manifest.domain] = provider_manifest
                     LOGGER.debug("Loaded manifest for provider %s", provider_manifest.name)
                 except Exception as exc:  # pylint: disable=broad-except