fix some hidden blocking IO
authorMarcel van der Veldt <m.vanderveldt@outlook.com>
Thu, 11 Apr 2024 14:58:34 +0000 (16:58 +0200)
committerMarcel van der Veldt <m.vanderveldt@outlook.com>
Thu, 11 Apr 2024 14:58:34 +0000 (16:58 +0200)
music_assistant/common/helpers/uri.py
music_assistant/server/controllers/music.py
music_assistant/server/providers/spotify/__init__.py
music_assistant/server/server.py
tests/test_helpers.py

index 5ed68e9c1b8c11d424305d9bc445300f99d0a02f..8c6bc90e16c89c28a36fcf84c62595419729fc7d 100644 (file)
@@ -1,5 +1,6 @@
 """Helpers for creating/parsing URI's."""
 
+import asyncio
 import os
 import re
 
@@ -22,7 +23,7 @@ def valid_id(provider: str, item_id: str) -> bool:
         return True
 
 
-def parse_uri(uri: str, validate_id: bool = False) -> tuple[MediaType, str, str]:
+async def parse_uri(uri: str, validate_id: bool = False) -> tuple[MediaType, str, str]:
     """Try to parse URI to Mass identifiers.
 
     Returns Tuple: MediaType, provider_instance_id_or_domain, item_id
@@ -51,7 +52,7 @@ def parse_uri(uri: str, validate_id: bool = False) -> tuple[MediaType, str, str]
             # spotify new-style uri
             provider_instance_id_or_domain, media_type_str, item_id = uri.split(":")
             media_type = MediaType(media_type_str)
-        elif os.path.isfile(uri):
+        elif "/" in uri and await asyncio.to_thread(os.path.isfile, uri):
             # Translate a local file (which is not from file provider) to the URL provider
             provider_instance_id_or_domain = "url"
             media_type = MediaType.TRACK
index 49c8ff589271db7a69e058c5a0f6f7f765b7e075..ad95fa98a985cad71f4f685bd8531e0a36295c9b 100644 (file)
@@ -165,7 +165,7 @@ class MusicController(CoreController):
         """
         # Check if the search query is a streaming provider public shareable URL
         try:
-            media_type, provider_instance_id_or_domain, item_id = parse_uri(
+            media_type, provider_instance_id_or_domain, item_id = await parse_uri(
                 search_query, validate_id=True
             )
         except InvalidProviderURI:
@@ -343,7 +343,7 @@ class MusicController(CoreController):
     @api_command("music/item_by_uri")
     async def get_item_by_uri(self, uri: str) -> MediaItemType:
         """Fetch MediaItem by uri."""
-        media_type, provider_instance_id_or_domain, item_id = parse_uri(uri)
+        media_type, provider_instance_id_or_domain, item_id = await parse_uri(uri)
         return await self.get_item(
             media_type=media_type,
             item_id=item_id,
index 1ba4173e22ac68709d4ac5e2c99646d6b3be67e1..ee25a68e7f0cba7e1dd5ce409e1ad9169eaaaf8c 100644 (file)
@@ -609,7 +609,7 @@ class SpotifyProvider(MusicProvider):
         # return existing token if we have one in memory
         if (
             self._auth_token
-            and os.path.isdir(self._cache_dir)
+            and asyncio.to_thread(os.path.isdir, self._cache_dir)
             and (self._auth_token["expiresAt"] > int(time.time()) + 600)
         ):
             return self._auth_token
index 25da5f74ab5db00cfda23702c2b533d06e78c459..2e233bdbf9866aba918d89b0638eb64110ce3f7a 100644 (file)
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Self
 from uuid import uuid4
 
 import aiofiles
+from aiofiles.os import wrap
 from aiohttp import ClientSession, TCPConnector
 from zeroconf import IPVersion, NonUniqueNameException, ServiceStateChange, Zeroconf
 from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
@@ -55,6 +56,8 @@ if TYPE_CHECKING:
     from music_assistant.common.models.config_entries import ProviderConfig
     from music_assistant.server.models.core_controller import CoreController
 
+isdir = wrap(os.path.isdir)
+isfile = wrap(os.path.isfile)
 
 EventCallBackType = Callable[[MassEvent], None]
 EventSubscriptionType = tuple[
@@ -561,7 +564,7 @@ class MusicAssistant:
             # get files in subdirectory
             for file_str in os.listdir(provider_path):
                 file_path = os.path.join(provider_path, file_str)
-                if not os.path.isfile(file_path):
+                if not await isfile(file_path):
                     continue
                 if file_str != "manifest.json":
                     continue
@@ -570,12 +573,12 @@ class MusicAssistant:
                     # check for icon.svg file
                     if not provider_manifest.icon_svg:
                         icon_path = os.path.join(provider_path, "icon.svg")
-                        if os.path.isfile(icon_path):
+                        if await isfile(icon_path):
                             provider_manifest.icon_svg = await get_icon_string(icon_path)
                     # check for dark_icon file
                     if not provider_manifest.icon_svg_dark:
                         icon_path = os.path.join(provider_path, "icon_dark.svg")
-                        if os.path.isfile(icon_path):
+                        if await isfile(icon_path):
                             provider_manifest.icon_svg_dark = await get_icon_string(icon_path)
                     self._provider_manifests[provider_manifest.domain] = provider_manifest
                     LOGGER.debug("Loaded manifest for provider %s", provider_manifest.name)
@@ -589,7 +592,7 @@ class MusicAssistant:
         async with asyncio.TaskGroup() as tg:
             for dir_str in os.listdir(PROVIDERS_PATH):
                 dir_path = os.path.join(PROVIDERS_PATH, dir_str)
-                if not os.path.isdir(dir_path):
+                if not await isdir(dir_path):
                     continue
                 tg.create_task(load_provider_manifest(dir_str, dir_path))
 
index 25903aceb7be5966162345e3e524dbaee2d81616..24f0ab818c4b2300a8f47794e8de1d49f8ebf170 100644 (file)
@@ -15,32 +15,32 @@ def test_version_extract() -> None:
     assert version == "Karaoke Version"
 
 
-def test_uri_parsing() -> None:
+async def test_uri_parsing() -> None:
     """Test parsing of URI."""
     # test regular uri
     test_uri = "spotify://track/123456789"
-    media_type, provider, item_id = uri.parse_uri(test_uri)
+    media_type, provider, item_id = await uri.parse_uri(test_uri)
     assert media_type == media_items.MediaType.TRACK
     assert provider == "spotify"
     assert item_id == "123456789"
     # test spotify uri
     test_uri = "spotify:track:123456789"
-    media_type, provider, item_id = uri.parse_uri(test_uri)
+    media_type, provider, item_id = await uri.parse_uri(test_uri)
     assert media_type == media_items.MediaType.TRACK
     assert provider == "spotify"
     assert item_id == "123456789"
     # test public play/open url
     test_uri = "https://open.spotify.com/playlist/5lH9NjOeJvctAO92ZrKQNB?si=04a63c8234ac413e"
-    media_type, provider, item_id = uri.parse_uri(test_uri)
+    media_type, provider, item_id = await uri.parse_uri(test_uri)
     assert media_type == media_items.MediaType.PLAYLIST
     assert provider == "spotify"
     assert item_id == "5lH9NjOeJvctAO92ZrKQNB"
     # test filename with slashes as item_id
     test_uri = "filesystem://track/Artist/Album/Track.flac"
-    media_type, provider, item_id = uri.parse_uri(test_uri)
+    media_type, provider, item_id = await uri.parse_uri(test_uri)
     assert media_type == media_items.MediaType.TRACK
     assert provider == "filesystem"
     assert item_id == "Artist/Album/Track.flac"
     # test invalid uri
     with pytest.raises(MusicAssistantError):
-        uri.parse_uri("invalid://blah")
+        await uri.parse_uri("invalid://blah")