import logging
import os
from collections.abc import Awaitable, Callable, Coroutine
-from typing import TYPE_CHECKING, Any, Self, TypeVar
+from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar
from uuid import uuid4
import aiofiles
load_provider_module,
)
from music_assistant.models import ProviderInstanceType
+from music_assistant.models.music_provider import MusicProvider
+from music_assistant.models.player_provider import PlayerProvider
if TYPE_CHECKING:
from types import TracebackType
_R = TypeVar("_R")
+def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]:
+ """Type guard that returns true if a provider is a music provider."""
+ return provider.type == ProviderType.MUSIC
+
+
+def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]:
+ """Type guard that returns true if a provider is a player provider."""
+ return provider.type == ProviderType.PLAYER
+
+
class MusicAssistant:
"""Main MusicAssistant (Server) object."""
self._subscribers: set[EventSubscriptionType] = set()
self._provider_manifests: dict[str, ProviderManifest] = {}
self._providers: dict[str, ProviderInstanceType] = {}
- self._tracked_tasks: dict[str, asyncio.Task] = {}
+ self._tracked_tasks: dict[str, asyncio.Task[Any]] = {}
self._tracked_timers: dict[str, asyncio.TimerHandle] = {}
self.closing = False
self.running_as_hass_addon: bool = False
"""Return the application log from file."""
logfile = os.path.join(self.storage_path, "musicassistant.log")
async with aiofiles.open(logfile) as _file:
- return await _file.read()
+ return str(await _file.read())
@property
def providers(self) -> list[ProviderInstanceType]:
cb_func: EventCallBackType,
event_filter: EventType | tuple[EventType, ...] | None = None,
id_filter: str | tuple[str, ...] | None = None,
- ) -> Callable:
+ ) -> Callable[[], None]:
"""Add callback to event listeners.
Returns function to remove the listener.
Tasks created by this helper will be properly cancelled on stop.
"""
- if target is None:
- msg = "Target is missing"
- raise RuntimeError(msg)
if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done():
# prevent duplicate tasks if task_id is given and already present
if abort_existing:
elif asyncio.iscoroutine(target):
# coroutine
task = self.loop.create_task(target)
- else:
+ elif callable(target):
task = self.loop.create_task(asyncio.to_thread(target, *args, **kwargs))
+ else:
+ raise RuntimeError("Target is missing")
- def task_done_callback(_task: asyncio.Task) -> None:
- _task_id = task.task_id
- self._tracked_tasks.pop(_task_id, None)
+ if task_id is None:
+ task_id = uuid4().hex
+
+ def task_done_callback(_task: asyncio.Task[Any]) -> None:
+ self._tracked_tasks.pop(task_id, None)
# log unhandled exceptions
if (
LOGGER.isEnabledFor(logging.DEBUG)
exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None,
)
- if task_id is None:
- task_id = uuid4().hex
- task.task_id = task_id
self._tracked_tasks[task_id] = task
task.add_done_callback(task_done_callback)
return task
def call_later(
self,
delay: float,
- target: Coroutine | Awaitable | Callable,
+ target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R],
*args: Any,
task_id: str | None = None,
**kwargs: Any,
self._tracked_timers[task_id] = handle
return handle
- def get_task(self, task_id: str) -> asyncio.Task:
+ def get_task(self, task_id: str) -> asyncio.Task[Any]:
"""Get existing scheduled task."""
if existing := self._tracked_tasks.get(task_id):
# prevent duplicate tasks if task_id is given and already present
def register_api_command(
self,
command: str,
- handler: Callable,
- ) -> None:
+ handler: Callable[..., Coroutine[Any, Any, Any]],
+ ) -> Callable[[], None]:
"""
Dynamically register a command on the API.
# make sure to stop any running sync tasks first
for sync_task in self.music.in_progress_syncs:
if sync_task.provider_instance == instance_id:
- sync_task.task.cancel()
+ if sync_task.task:
+ sync_task.task.cancel()
# check if there are no other providers dependent of this provider
for dep_prov in self.providers:
if dep_prov.manifest.depends_on == provider.domain:
await self.unload_provider(dep_prov.instance_id)
- if provider.type == ProviderType.PLAYER:
+ if is_player_provider(provider):
# mark all players of this provider as unavailable
for player in provider.players:
player.available = False
prov_manifest = self._provider_manifests.get(domain)
# check for other instances of this provider
existing = next((x for x in self.providers if x.domain == domain), None)
- if existing and not prov_manifest.multi_instance:
+ if existing and prov_manifest and not prov_manifest.multi_instance:
msg = f"Provider {domain} already loaded and only one instance allowed."
raise SetupFailedError(msg)
# check valid manifest (just in case)
) -> None:
"""Handle MDNS service state callback."""
- async def process_mdns_state_change(prov: ProviderInstanceType):
+ async def process_mdns_state_change(prov: ProviderInstanceType) -> None:
if state_change == ServiceStateChange.Removed:
info = None
else:
) -> bool | None:
"""Exit context manager."""
await self.stop()
+ return None
async def _update_available_providers_cache(self) -> None:
"""Update the global cache variable of loaded/available providers."""
"streaming_providers": {
x.lookup_key
for x in self.providers
- if x.type == ProviderType.MUSIC and x.is_streaming_provider
+ if is_music_provider(x) and x.is_streaming_provider
},
"non_streaming_providers": {
x.lookup_key
for x in self.providers
- if not (x.type == ProviderType.MUSIC and x.is_streaming_provider)
+ if not (is_music_provider(x) and x.is_streaming_provider)
},
}
)