From: Marcel van der Veldt Date: Fri, 22 Mar 2024 10:50:04 +0000 (+0100) Subject: lazy load provider requirements X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=ba8919c7bd770d2603283c7f99817dc9dc5fea79;p=music-assistant-server.git lazy load provider requirements --- diff --git a/music_assistant/server/controllers/config.py b/music_assistant/server/controllers/config.py index b06a9493..5f6aa913 100644 --- a/music_assistant/server/controllers/config.py +++ b/music_assistant/server/controllers/config.py @@ -37,7 +37,7 @@ from music_assistant.constants import ( ENCRYPT_SUFFIX, ) from music_assistant.server.helpers.api import api_command -from music_assistant.server.helpers.util import get_provider_module +from music_assistant.server.helpers.util import load_provider_module from music_assistant.server.models.player_provider import PlayerProvider if TYPE_CHECKING: @@ -231,7 +231,7 @@ class ConfigController: # lookup provider manifest and module for prov in self.mass.get_provider_manifests(): if prov.domain == provider_domain: - prov_mod = await get_provider_module(provider_domain) + prov_mod = await load_provider_module(provider_domain, prov.requirements) break else: msg = f"Unknown provider domain: {provider_domain}" diff --git a/music_assistant/server/helpers/util.py b/music_assistant/server/helpers/util.py index 2c8c6b9c..2236936c 100644 --- a/music_assistant/server/helpers/util.py +++ b/music_assistant/server/helpers/util.py @@ -30,6 +30,7 @@ HA_WHEELS = "https://wheels.home-assistant.io/musllinux/" async def install_package(package: str) -> None: """Install package with pip, raise when install failed.""" + LOGGER.debug("Installing python package %s", package) cmd = f"python3 -m pip install --find-links {HA_WHEELS} {package}" proc = await asyncio.create_subprocess_shell( cmd, stderr=asyncio.subprocess.STDOUT, stdout=asyncio.subprocess.PIPE @@ -91,13 +92,32 @@ async def is_hass_supervisor() -> bool: return await asyncio.to_thread(_check) -async def get_provider_module(domain: str) -> ProviderModuleType: - """Return module for given provider domain.""" +async def load_provider_module(domain: str, requirements: list[str]) -> ProviderModuleType: + """Return module for given provider domain and make sure the requirements are met.""" @lru_cache def _get_provider_module(domain: str) -> ProviderModuleType: return importlib.import_module(f".{domain}", "music_assistant.server.providers") + # ensure module requirements are met + for requirement in requirements: + if "==" not in requirement: + # we should really get rid of unpinned requirements + continue + package_name, version = requirement.split("==", 1) + installed_version = await get_package_version(package_name) + if installed_version != version: + await install_package(requirement) + + # try to load the module + try: + return await asyncio.to_thread(_get_provider_module, domain) + except ImportError: + # (re)install ALL requirements + for requirement in requirements: + await install_package(requirement) + # try loading the provider again to be safe + # this will fail if something else is wrong (as it should) return await asyncio.to_thread(_get_provider_module, domain) diff --git a/music_assistant/server/server.py b/music_assistant/server/server.py index 645e3abb..d087b22e 100644 --- a/music_assistant/server/server.py +++ b/music_assistant/server/server.py @@ -43,9 +43,8 @@ from music_assistant.server.helpers.api import APICommandHandler, api_command from music_assistant.server.helpers.images import get_icon_string from music_assistant.server.helpers.util import ( get_package_version, - get_provider_module, - install_package, is_hass_supervisor, + load_provider_module, ) from .models import ProviderInstanceType @@ -432,7 +431,7 @@ class MusicAssistant: raise SetupFailedError(msg) # try to setup the module - prov_mod = await get_provider_module(domain) + prov_mod = await load_provider_module(domain, prov_manifest.requirements) try: async with asyncio.timeout(30): provider = await prov_mod.setup(self, prov_manifest, conf) @@ -556,15 +555,6 @@ class MusicAssistant: icon_path = os.path.join(provider_path, "icon_dark.svg") if os.path.isfile(icon_path): provider_manifest.icon_svg_dark = await get_icon_string(icon_path) - # try to load the module - try: - await get_provider_module(provider_manifest.domain) - except ImportError: - # install requirements - for requirement in provider_manifest.requirements: - await install_package(requirement) - # try loading the provider again to be safe - await get_provider_module(provider_manifest.domain) self._provider_manifests[provider_manifest.domain] = provider_manifest LOGGER.debug("Loaded manifest for provider %s", provider_manifest.name) except Exception as exc: # pylint: disable=broad-except