From: Marcel van der Veldt Date: Sat, 13 Dec 2025 00:35:35 +0000 (+0100) Subject: Some final tweaks for user filtering X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=442a913acd9ca45e85da11e3503b978457400242;p=music-assistant-server.git Some final tweaks for user filtering --- diff --git a/music_assistant/controllers/music.py b/music_assistant/controllers/music.py index 485e2afa..348a8f77 100644 --- a/music_assistant/controllers/music.py +++ b/music_assistant/controllers/music.py @@ -203,8 +203,19 @@ class MusicController(CoreController): @property def providers(self) -> list[MusicProvider]: - """Return all loaded/running MusicProviders (instances).""" - return self.mass.get_providers(ProviderType.MUSIC) + """ + Return all loaded/running MusicProviders (instances). + + Note that this applies user provider filters (for all user types). + """ + user = get_current_user() + user_provider_filter = user.provider_filter if user else None + return [ + x + for x in self.mass.providers + if x.type == ProviderType.MUSIC + and (not user_provider_filter or x.instance_id in user_provider_filter) + ] @api_command("music/sync") async def start_sync( @@ -256,10 +267,10 @@ class MusicController(CoreController): :param media_types: A list of media_types to include. :param limit: number of items to return in the search (per type). """ - # use a (short-lived) cache to avoid repeated searches - cache_key = f"{search_query}{'-'.join(sorted([mt.value for mt in media_types]))}-{limit}-{library_only}" # noqa: E501 - if user := get_current_user(): - cache_key += user.user_id + # use cache to avoid repeated searches + search_providers = sorted(self.get_unique_providers()) + cache_provider_key = "library" if library_only else ",".join(search_providers) + cache_key = f"{search_query}{'-'.join(sorted([mt.value for mt in media_types]))}-{limit}-{library_only}-{cache_provider_key}" # noqa: E501 if cache := await self.mass.cache.get( key=cache_key, provider=self.domain, category=CACHE_CATEGORY_SEARCH_RESULTS ): @@ -325,7 +336,6 @@ class MusicController(CoreController): for prov_mapping in item.provider_mappings } # include results from library + all (unique) music providers - search_providers = self.get_unique_providers() results_per_provider += await asyncio.gather( *[ self._search_provider( @@ -398,7 +408,7 @@ class MusicController(CoreController): await self.mass.cache.set( key=cache_key, data=result, - expiration=300, + expiration=600, provider=self.domain, category=CACHE_CATEGORY_SEARCH_RESULTS, ) @@ -1363,34 +1373,36 @@ class MusicController(CoreController): def get_provider_instances( self, domain: str, return_unavailable: bool = False ) -> list[MusicProvider]: - """Return all provider instances for a given domain.""" - return [ - prov - # don't use self.providers here as that applies user filters - for prov in self.mass.providers - if isinstance(prov, MusicProvider) - and prov.domain == domain - and (return_unavailable or prov.available) - ] + """ + Return all provider instances for a given domain. + + Note that this skips user filters so may only be called from internal code. + """ + return cast( + "list[MusicProvider]", + self.mass.get_provider_instances(domain, return_unavailable, ProviderType.MUSIC), + ) - def get_unique_providers(self) -> set[str]: + def get_unique_providers(self) -> list[str]: """ Return all unique MusicProvider (instance or domain) ids. This will return a set of provider instance ids but will only return a single instance_id per streaming provider domain. + + Applies user provider filters (for non-admin users). """ processed_domains: set[str] = set() # Get user provider filter if set user = get_current_user() user_provider_filter = user.provider_filter if user and user.provider_filter else None - result: set[str] = set() + result: list[str] = [] for provider in self.providers: if provider.is_streaming_provider and provider.domain in processed_domains: continue if user_provider_filter and provider.instance_id not in user_provider_filter: continue - result.add(provider.instance_id) + result.append(provider.instance_id) processed_domains.add(provider.domain) return result diff --git a/music_assistant/controllers/players/player_controller.py b/music_assistant/controllers/players/player_controller.py index 85d36c96..3a3704b0 100644 --- a/music_assistant/controllers/players/player_controller.py +++ b/music_assistant/controllers/players/player_controller.py @@ -23,6 +23,7 @@ from collections.abc import Awaitable, Callable, Coroutine from contextlib import suppress from typing import TYPE_CHECKING, Any, Concatenate, TypedDict, cast, overload +from music_assistant_models.auth import UserRole from music_assistant_models.constants import ( PLAYER_CONTROL_FAKE, PLAYER_CONTROL_NATIVE, @@ -249,6 +250,8 @@ class PlayerController(CoreController): """ Return all registered players. + Note that this applies user filters for players (for non admin users). + :param return_unavailable [bool]: Include unavailable players. :param return_disabled [bool]: Include disabled players. :param provider_filter [str]: Optional filter by provider lookup key. @@ -256,7 +259,11 @@ class PlayerController(CoreController): :return: List of Player objects. """ current_user = get_current_user() - user_filter = current_user.player_filter if current_user else [] + user_filter = ( + current_user.player_filter + if current_user and current_user.role != UserRole.ADMIN + else None + ) return [ player for player in self._players.values() @@ -331,6 +338,15 @@ class PlayerController(CoreController): :raises PlayerUnavailableError: If player is unavailable and raise_unavailable is True. :return: Player object or None. """ + current_user = get_current_user() + user_filter = ( + current_user.player_filter + if current_user and current_user.role != UserRole.ADMIN + else None + ) + if current_user and user_filter and player_id not in user_filter: + msg = f"{current_user.username} does not have access to player {player_id}" + raise InsufficientPermissions(msg) if player := self.get(player_id, raise_unavailable): return player.state return None @@ -352,7 +368,16 @@ class PlayerController(CoreController): :param name: Name of the player. :return: PlayerState object or None. """ + current_user = get_current_user() + user_filter = ( + current_user.player_filter + if current_user and current_user.role != UserRole.ADMIN + else None + ) if player := self.get_player_by_name(name): + if current_user and user_filter and player.player_id not in user_filter: + msg = f"{current_user.username} does not have access to player {player.player_id}" + raise InsufficientPermissions(msg) return player.state return None diff --git a/music_assistant/mass.py b/music_assistant/mass.py index 584ab43b..74e57c50 100644 --- a/music_assistant/mass.py +++ b/music_assistant/mass.py @@ -14,6 +14,7 @@ from uuid import uuid4 import aiofiles from aiofiles.os import wrap from music_assistant_models.api import ServerInfoMessage +from music_assistant_models.auth import UserRole from music_assistant_models.enums import EventType, ProviderType from music_assistant_models.errors import MusicAssistantError, SetupFailedError from music_assistant_models.event import MassEvent @@ -277,15 +278,21 @@ class MusicAssistant: def get_providers( self, provider_type: ProviderType | None = None ) -> list[ProviderInstanceType]: - """Return all loaded/running Providers (instances), optionally filtered by ProviderType.""" - user = get_current_user() - user_provider_filter = user.provider_filter if user else None + """ + Return all loaded/running Providers (instances). + Optionally filtered by ProviderType. + Note that this applies user filters for music providers (for non admin users). + """ + user = get_current_user() + user_provider_filter = ( + user.provider_filter if user and user.role != UserRole.ADMIN else None + ) return [ x for x in self._providers.values() if (provider_type is None or provider_type == x.type) - # handle optional user (music) provider filter + # apply user provider filter and ( not user_provider_filter or x.instance_id in user_provider_filter @@ -293,7 +300,7 @@ class MusicAssistant: ) ] - @api_command("logging/get") + @api_command("logging/get", required_role=UserRole.ADMIN) async def get_application_log(self) -> str: """Return the application log from file.""" logfile = os.path.join(self.storage_path, "musicassistant.log") @@ -302,7 +309,11 @@ class MusicAssistant: @property def providers(self) -> list[ProviderInstanceType]: - """Return all loaded/running Providers (instances).""" + """ + Return all loaded/running Providers (instances). + + Note that this skips user filters so may only be called from internal code. + """ return list(self._providers.values()) @overload @@ -350,6 +361,25 @@ class MusicAssistant: return prov return None + def get_provider_instances( + self, + domain: str, + return_unavailable: bool = False, + provider_type: ProviderType | None = None, + ) -> list[ProviderInstanceType]: + """ + Return all provider instances for a given domain. + + Note that this skips user filters so may only be called from internal code. + """ + return [ + prov + for prov in self._providers.values() + if (provider_type is None or provider_type == prov.type) + and prov.domain == domain + and (return_unavailable or prov.available) + ] + def signal_event( self, event: EventType,