From f0312afb3dc123ce06c9865fccd0d36c6a59726d Mon Sep 17 00:00:00 2001 From: Marcel van der Veldt Date: Thu, 11 Apr 2024 16:58:34 +0200 Subject: [PATCH] fix some hidden blocking IO --- music_assistant/common/helpers/uri.py | 5 +++-- music_assistant/server/controllers/music.py | 4 ++-- music_assistant/server/providers/spotify/__init__.py | 2 +- music_assistant/server/server.py | 11 +++++++---- tests/test_helpers.py | 12 ++++++------ 5 files changed, 19 insertions(+), 15 deletions(-) diff --git a/music_assistant/common/helpers/uri.py b/music_assistant/common/helpers/uri.py index 5ed68e9c..8c6bc90e 100644 --- a/music_assistant/common/helpers/uri.py +++ b/music_assistant/common/helpers/uri.py @@ -1,5 +1,6 @@ """Helpers for creating/parsing URI's.""" +import asyncio import os import re @@ -22,7 +23,7 @@ def valid_id(provider: str, item_id: str) -> bool: return True -def parse_uri(uri: str, validate_id: bool = False) -> tuple[MediaType, str, str]: +async def parse_uri(uri: str, validate_id: bool = False) -> tuple[MediaType, str, str]: """Try to parse URI to Mass identifiers. Returns Tuple: MediaType, provider_instance_id_or_domain, item_id @@ -51,7 +52,7 @@ def parse_uri(uri: str, validate_id: bool = False) -> tuple[MediaType, str, str] # spotify new-style uri provider_instance_id_or_domain, media_type_str, item_id = uri.split(":") media_type = MediaType(media_type_str) - elif os.path.isfile(uri): + elif "/" in uri and await asyncio.to_thread(os.path.isfile, uri): # Translate a local file (which is not from file provider) to the URL provider provider_instance_id_or_domain = "url" media_type = MediaType.TRACK diff --git a/music_assistant/server/controllers/music.py b/music_assistant/server/controllers/music.py index 49c8ff58..ad95fa98 100644 --- a/music_assistant/server/controllers/music.py +++ b/music_assistant/server/controllers/music.py @@ -165,7 +165,7 @@ class MusicController(CoreController): """ # Check if the search query is a streaming provider public shareable URL try: - media_type, provider_instance_id_or_domain, item_id = parse_uri( + media_type, provider_instance_id_or_domain, item_id = await parse_uri( search_query, validate_id=True ) except InvalidProviderURI: @@ -343,7 +343,7 @@ class MusicController(CoreController): @api_command("music/item_by_uri") async def get_item_by_uri(self, uri: str) -> MediaItemType: """Fetch MediaItem by uri.""" - media_type, provider_instance_id_or_domain, item_id = parse_uri(uri) + media_type, provider_instance_id_or_domain, item_id = await parse_uri(uri) return await self.get_item( media_type=media_type, item_id=item_id, diff --git a/music_assistant/server/providers/spotify/__init__.py b/music_assistant/server/providers/spotify/__init__.py index 1ba4173e..ee25a68e 100644 --- a/music_assistant/server/providers/spotify/__init__.py +++ b/music_assistant/server/providers/spotify/__init__.py @@ -609,7 +609,7 @@ class SpotifyProvider(MusicProvider): # return existing token if we have one in memory if ( self._auth_token - and os.path.isdir(self._cache_dir) + and asyncio.to_thread(os.path.isdir, self._cache_dir) and (self._auth_token["expiresAt"] > int(time.time()) + 600) ): return self._auth_token diff --git a/music_assistant/server/server.py b/music_assistant/server/server.py index 25da5f74..2e233bdb 100644 --- a/music_assistant/server/server.py +++ b/music_assistant/server/server.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Self from uuid import uuid4 import aiofiles +from aiofiles.os import wrap from aiohttp import ClientSession, TCPConnector from zeroconf import IPVersion, NonUniqueNameException, ServiceStateChange, Zeroconf from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf @@ -55,6 +56,8 @@ if TYPE_CHECKING: from music_assistant.common.models.config_entries import ProviderConfig from music_assistant.server.models.core_controller import CoreController +isdir = wrap(os.path.isdir) +isfile = wrap(os.path.isfile) EventCallBackType = Callable[[MassEvent], None] EventSubscriptionType = tuple[ @@ -561,7 +564,7 @@ class MusicAssistant: # get files in subdirectory for file_str in os.listdir(provider_path): file_path = os.path.join(provider_path, file_str) - if not os.path.isfile(file_path): + if not await isfile(file_path): continue if file_str != "manifest.json": continue @@ -570,12 +573,12 @@ class MusicAssistant: # check for icon.svg file if not provider_manifest.icon_svg: icon_path = os.path.join(provider_path, "icon.svg") - if os.path.isfile(icon_path): + if await isfile(icon_path): provider_manifest.icon_svg = await get_icon_string(icon_path) # check for dark_icon file if not provider_manifest.icon_svg_dark: icon_path = os.path.join(provider_path, "icon_dark.svg") - if os.path.isfile(icon_path): + if await isfile(icon_path): provider_manifest.icon_svg_dark = await get_icon_string(icon_path) self._provider_manifests[provider_manifest.domain] = provider_manifest LOGGER.debug("Loaded manifest for provider %s", provider_manifest.name) @@ -589,7 +592,7 @@ class MusicAssistant: async with asyncio.TaskGroup() as tg: for dir_str in os.listdir(PROVIDERS_PATH): dir_path = os.path.join(PROVIDERS_PATH, dir_str) - if not os.path.isdir(dir_path): + if not await isdir(dir_path): continue tg.create_task(load_provider_manifest(dir_str, dir_path)) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 25903ace..24f0ab81 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -15,32 +15,32 @@ def test_version_extract() -> None: assert version == "Karaoke Version" -def test_uri_parsing() -> None: +async def test_uri_parsing() -> None: """Test parsing of URI.""" # test regular uri test_uri = "spotify://track/123456789" - media_type, provider, item_id = uri.parse_uri(test_uri) + media_type, provider, item_id = await uri.parse_uri(test_uri) assert media_type == media_items.MediaType.TRACK assert provider == "spotify" assert item_id == "123456789" # test spotify uri test_uri = "spotify:track:123456789" - media_type, provider, item_id = uri.parse_uri(test_uri) + media_type, provider, item_id = await uri.parse_uri(test_uri) assert media_type == media_items.MediaType.TRACK assert provider == "spotify" assert item_id == "123456789" # test public play/open url test_uri = "https://open.spotify.com/playlist/5lH9NjOeJvctAO92ZrKQNB?si=04a63c8234ac413e" - media_type, provider, item_id = uri.parse_uri(test_uri) + media_type, provider, item_id = await uri.parse_uri(test_uri) assert media_type == media_items.MediaType.PLAYLIST assert provider == "spotify" assert item_id == "5lH9NjOeJvctAO92ZrKQNB" # test filename with slashes as item_id test_uri = "filesystem://track/Artist/Album/Track.flac" - media_type, provider, item_id = uri.parse_uri(test_uri) + media_type, provider, item_id = await uri.parse_uri(test_uri) assert media_type == media_items.MediaType.TRACK assert provider == "filesystem" assert item_id == "Artist/Album/Track.flac" # test invalid uri with pytest.raises(MusicAssistantError): - uri.parse_uri("invalid://blah") + await uri.parse_uri("invalid://blah") -- 2.34.1