import logging
import secrets
from datetime import datetime, timedelta
-from sqlite3 import OperationalError
+from sqlite3 import IntegrityError, OperationalError
from typing import TYPE_CHECKING, Any
import jwt as pyjwt
InvalidDataError,
)
-from music_assistant.constants import (
- DB_TABLE_PLAYLOG,
- HOMEASSISTANT_SYSTEM_USER,
- MASS_LOGGER_NAME,
-)
+from music_assistant.constants import DB_TABLE_PLAYLOG, HOMEASSISTANT_SYSTEM_USER, MASS_LOGGER_NAME
from music_assistant.controllers.webserver.helpers.auth_middleware import (
get_current_token,
get_current_user,
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.auth")
# Database schema version
-DB_SCHEMA_VERSION = 4
+DB_SCHEMA_VERSION = 5
# Token expiration constants (in days)
TOKEN_SHORT_LIVED_EXPIRATION = 30 # Short-lived tokens (auto-renewing on use)
TOKEN_LONG_LIVED_EXPIRATION = 3650 # Long-lived tokens (10 years, no auto-renewal)
+# Join code constants (short codes for QR/link-based login)
+JOIN_CODE_LENGTH = 6
+JOIN_CODE_CHARSET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # No I/O/0/1 for readability
+JOIN_CODE_DEFAULT_EXPIRY_HOURS = 8
+
class AuthenticationManager:
"""Manager for authentication and user management (part of webserver controller)."""
self._has_users = await self._has_non_system_users()
+ self._schedule_join_code_cleanup()
+
self.logger.info(
"Authentication manager initialized (providers=%d)", len(self.login_providers)
)
)
"""
)
+ # Join codes table (for short code to JWT exchange, used by providers like party mode)
+ await self.database.execute(
+ """
+ CREATE TABLE IF NOT EXISTS join_codes (
+ code_id TEXT PRIMARY KEY,
+ code TEXT NOT NULL UNIQUE,
+ user_id TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ expires_at TEXT NOT NULL,
+ max_uses INTEGER DEFAULT 0,
+ use_count INTEGER DEFAULT 0,
+ last_used_at TEXT,
+ device_name TEXT,
+ provider_name TEXT,
+ FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
+ )
+ """
+ )
await self.database.commit()
async def _create_database_indexes(self) -> None:
await self.database.execute(
"CREATE INDEX IF NOT EXISTS idx_tokens_hash ON auth_tokens(token_hash)"
)
+ await self.database.execute(
+ "CREATE INDEX IF NOT EXISTS idx_join_codes_user ON join_codes(user_id)"
+ )
async def _migrate_database(self, from_version: int) -> None:
"""Perform database migration.
await self.database.execute("UPDATE users SET username = LOWER(username)")
await self.database.commit()
+ # Migration to version 5: Add join codes table
+ if from_version < 5:
+ await self.database.execute(
+ """
+ CREATE TABLE IF NOT EXISTS join_codes (
+ code_id TEXT PRIMARY KEY,
+ code TEXT NOT NULL UNIQUE,
+ user_id TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ expires_at TEXT NOT NULL,
+ max_uses INTEGER DEFAULT 0,
+ use_count INTEGER DEFAULT 0,
+ last_used_at TEXT,
+ device_name TEXT,
+ provider_name TEXT,
+ FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
+ )
+ """
+ )
+ await self.database.commit()
+
async def _get_or_create_jwt_secret(self) -> str:
"""Get or create JWT secret key from database.
async def authenticate_with_token(self, token: str) -> User | None:
"""
- Authenticate a user with an access token.
+ Authenticate a user with an access token (JWT or legacy).
Supports both JWT tokens and legacy hash-based tokens for backward compatibility.
# Implement sliding expiration for short-lived tokens
is_long_lived = bool(token_row["is_long_lived"])
now = utc()
- updates = {"last_used_at": now.isoformat()}
+ legacy_updates: dict[str, str] = {"last_used_at": now.isoformat()}
if not is_long_lived and token_row["expires_at"]:
# Short-lived token: extend expiration on each use (sliding window)
new_expires_at = now + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION)
- updates["expires_at"] = new_expires_at.isoformat()
+ legacy_updates["expires_at"] = new_expires_at.isoformat()
# Update last used timestamp and potentially expiration
await self.database.update(
"auth_tokens",
{"token_id": token_row["token_id"]},
- updates,
+ legacy_updates,
)
# Get user
token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash})
if not token_row:
return None
-
return str(token_row["token_id"])
@api_command("auth/user", required_role="admin")
# Create new link
await self.link_user_to_provider(user, provider_type, provider_user_id)
- async def create_token(self, user: User, name: str, is_long_lived: bool = False) -> str:
+ async def create_token(
+ self, user: User, name: str, is_long_lived: bool = False, provider_name: str | None = None
+ ) -> str:
"""
- Create a new access token for a user.
+ Create a new JWT access token for a user.
:param user: The user to create the token for.
:param name: A name/description for the token (e.g., device name).
:param is_long_lived: Whether this is a long-lived token (default: False).
Short-lived tokens (False): Auto-renewing on use, expire after 30 days of inactivity.
Long-lived tokens (True): No auto-renewal, expire after 10 years.
+ :param provider_name: Optional provider name that created this token (e.g., "party_mode").
+ :return: JWT token string.
"""
# Generate unique token ID
token_id = secrets.token_urlsafe(32)
token_name=name,
expires_at=expires_at,
is_long_lived=is_long_lived,
+ provider_name=provider_name,
)
# Store token hash in database for revocation checking
)
await self.database.commit()
return True
+
+ # ==================== Join Code Methods ====================
+
+ async def generate_join_code(
+ self,
+ user_id: str,
+ expires_in_hours: int = JOIN_CODE_DEFAULT_EXPIRY_HOURS,
+ max_uses: int = 1,
+ device_name: str = "Short Code Login",
+ provider_name: str | None = None,
+ ) -> tuple[str, datetime]:
+ """Generate a short join code for link/QR-based login.
+
+ This creates a short alphanumeric code that can be exchanged for a JWT token.
+ Providers can use this to implement features like party mode guest access,
+ device pairing, or other short-code authentication flows.
+
+ :param user_id: The user ID that tokens created from this code will belong to.
+ :param expires_in_hours: Hours until code expires (default: 8).
+ :param max_uses: Maximum number of uses (0 = unlimited).
+ :param device_name: Device name for tokens created with this code.
+ :param provider_name: Optional provider name identifier (e.g., "party_mode").
+ :return: Tuple of (code, expires_at datetime).
+ """
+ if expires_in_hours <= 0:
+ raise ValueError("expires_in_hours must be positive")
+ if max_uses < 0:
+ raise ValueError("max_uses must be non-negative (0 = unlimited)")
+ user = await self.get_user(user_id)
+ if not user:
+ raise ValueError(f"User not found: {user_id}")
+ if user.role != UserRole.GUEST:
+ raise ValueError("Join codes can only be generated for guest accounts")
+
+ now = utc()
+ expires_at = now + timedelta(hours=expires_in_hours)
+
+ for _ in range(3): # Try up to 3 times to avoid code collisions
+ code = "".join(secrets.choice(JOIN_CODE_CHARSET) for _ in range(JOIN_CODE_LENGTH))
+ code_data = {
+ "code_id": secrets.token_urlsafe(32),
+ "code": code,
+ "user_id": user_id,
+ "created_at": now.isoformat(),
+ "expires_at": expires_at.isoformat(),
+ "max_uses": max_uses,
+ "use_count": 0,
+ "device_name": device_name,
+ "provider_name": provider_name,
+ }
+ try:
+ await self.database.insert("join_codes", code_data)
+ await self.database.commit()
+ self.logger.info(
+ "Join code generated for user %s (expires: %s, max_uses: %s, provider: %s)",
+ user.username,
+ expires_at,
+ max_uses,
+ provider_name,
+ )
+ return code, expires_at
+ except IntegrityError:
+ self.logger.warning("Join code collision, retrying...")
+ continue
+
+ raise RuntimeError("Failed to generate a unique join code after 3 attempts")
+
+ async def _exchange_join_code(self, code: str) -> str | None:
+ """Exchange a join code for a JWT access token.
+
+ The token is created for the user associated with the join code,
+ using the provider_name that was specified when the code was generated.
+
+ :param code: The short join code.
+ :return: JWT token string if valid, None otherwise.
+ """
+ now = utc()
+
+ cursor = await self.database.execute(
+ """
+ UPDATE join_codes
+ SET use_count = use_count + 1,
+ last_used_at = :now
+ WHERE code = :code
+ AND expires_at > :now
+ AND (max_uses = 0 OR use_count < max_uses)
+ RETURNING user_id, provider_name, device_name
+ """,
+ {"now": now.isoformat(), "code": code.upper()},
+ )
+ row = await cursor.fetchone()
+ await self.database.commit()
+
+ if not row:
+ self.logger.warning("Join code exchange rejected (code=%s)", code.upper())
+ return None
+
+ user = await self.get_user(row["user_id"])
+ if not user:
+ self.logger.error(
+ "User not found for join code despite FK constraint (user_id=%s)", row["user_id"]
+ )
+ return None
+
+ device_name = row["device_name"] or "Short Code Login"
+ token = await self.create_token(
+ user,
+ device_name,
+ is_long_lived=False,
+ provider_name=row["provider_name"],
+ )
+
+ self.logger.info(
+ "Join code exchanged for token (user=%s, provider=%s)",
+ user.username,
+ row["provider_name"],
+ )
+ return token
+
+ async def revoke_join_codes(
+ self,
+ user_id: str | None = None,
+ provider_name: str | None = None,
+ ) -> int:
+ """Revoke join codes filtered by user and/or provider.
+
+ At least one filter parameter must be provided to prevent accidental deletion of all codes.
+
+ :param user_id: User ID to revoke codes for.
+ :param provider_name: Provider name to revoke codes for.
+ :return: Number of codes revoked.
+ """
+ if not user_id and not provider_name:
+ raise ValueError("At least one of user_id or provider_name must be provided")
+
+ conditions = []
+ params = {}
+
+ if user_id:
+ conditions.append("user_id = :user_id")
+ params["user_id"] = user_id
+ if provider_name:
+ conditions.append("provider_name = :provider_name")
+ params["provider_name"] = provider_name
+
+ cursor = await self.database.execute(
+ f"DELETE FROM join_codes WHERE {' AND '.join(conditions)}", params
+ )
+ await self.database.commit()
+
+ count = int(cursor.rowcount)
+ if count > 0:
+ self.logger.info("Revoked %d join code(s)", count)
+ return count
+
+ async def _cleanup_expired_join_codes(self) -> None:
+ """Delete expired and exhausted join codes from the database."""
+ now = utc()
+ cursor = await self.database.execute(
+ """
+ DELETE FROM join_codes
+ WHERE expires_at < :now
+ OR (max_uses > 0 AND use_count >= max_uses)
+ """,
+ {"now": now.isoformat()},
+ )
+ await self.database.commit()
+ count = int(cursor.rowcount)
+ if count > 0:
+ self.logger.debug("Cleaned up %d expired/exhausted join code(s)", count)
+
+ def _schedule_join_code_cleanup(self) -> None:
+ """Schedule periodic cleanup of expired join codes."""
+ self.mass.create_task(self._cleanup_expired_join_codes())
+ self.mass.call_later(86400, self._schedule_join_code_cleanup)
+
+ @api_command("auth/join_code/exchange", authenticated=False)
+ async def exchange_join_code(self, code: str) -> dict[str, Any]:
+ """Exchange a join code for an access token (public API).
+
+ This is the public API endpoint for short-code authentication.
+ Clients call this with a code (e.g., from QR scan or link) to receive a JWT token.
+
+ :param code: The short join code.
+ :return: Authentication result with access token if successful.
+ """
+ token = await self._exchange_join_code(code)
+
+ if not token:
+ return {
+ "success": False,
+ "error": "Invalid or expired join code",
+ }
+
+ # Decode token to get user info
+ try:
+ payload = self.jwt_helper.decode_token(token)
+ return {
+ "success": True,
+ "access_token": token,
+ "user": {
+ "user_id": payload.get("sub"),
+ "username": payload.get("username"),
+ "role": payload.get("role"),
+ },
+ }
+ except pyjwt.InvalidTokenError:
+ return {
+ "success": False,
+ "error": "Failed to create access token",
+ }
+
+ @api_command("auth/join_codes", required_role="admin")
+ async def list_join_codes(self, user_id: str | None = None) -> list[dict[str, Any]]:
+ """List join codes, optionally filtered by user (admin only).
+
+ :param user_id: Optional user ID to filter codes for.
+ :return: List of join code records.
+ """
+ filter_args = {"user_id": user_id} if user_id else None
+ rows = await self.database.get_rows("join_codes", filter_args, limit=100)
+ return [dict(row) for row in rows]
+
+ @api_command("auth/join_code/revoke", required_role="admin")
+ async def revoke_join_code(self, code_id: str) -> None:
+ """Revoke a specific join code (admin only).
+
+ :param code_id: The code ID to revoke.
+ """
+ code_row = await self.database.get_row("join_codes", {"code_id": code_id})
+ if not code_row:
+ raise InvalidDataError("Join code not found")
+
+ await self.database.delete("join_codes", {"code_id": code_id})
+ await self.database.commit()
+ self.logger.info("Join code revoked (code_id=%s)", code_id)
token_name: str,
expires_at: datetime,
is_long_lived: bool = False,
+ provider_name: str | None = None,
) -> str:
"""Encode a JWT token for a user.
:param token_name: Human-readable token name.
:param expires_at: Token expiration datetime.
:param is_long_lived: Whether this is a long-lived token.
+ :param provider_name: Optional provider name that created this token (e.g., "party_mode").
:return: Encoded JWT token string.
"""
now = utc()
"is_long_lived": is_long_lived,
}
+ if provider_name:
+ payload["provider_name"] = provider_name
+
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
def decode_token(self, token: str, verify_exp: bool = True) -> dict[str, Any]:
<script>
const API_BASE = window.location.origin;
- // Get return_url and device_name from query string
+ // Get parameters from query string
const urlParams = new URLSearchParams(window.location.search);
const returnUrl = urlParams.get('return_url');
const deviceName = urlParams.get('device_name');
+ const joinCode = urlParams.get('join');
// Show error message
function showError(message) {
input.addEventListener('input', hideError);
});
- // Load providers on page load
- loadProviders();
+ // Handle short code authentication (e.g., from QR code or link)
+ async function handleJoinCode() {
+ if (!joinCode) {
+ return false;
+ }
+
+ // Show loading state
+ const container = document.querySelector('.login-container');
+ container.innerHTML = `
+ <div class="logo">
+ <img src="logo.png" alt="Music Assistant">
+ </div>
+ <h1>Music Assistant</h1>
+ <p class="subtitle">Connecting...</p>
+ <div style="text-align: center; padding: 20px;">
+ <div class="loading" style="display: inline-block;"></div>
+ </div>
+ <div id="error" class="error"></div>
+ `;
+
+ try {
+ // Exchange the join code for a JWT token via JSON-RPC API
+ const response = await fetch(`${API_BASE}/api`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json'
+ },
+ body: JSON.stringify({
+ message_id: 'join_code_auth',
+ command: 'auth/join_code/exchange',
+ args: { code: joinCode.toUpperCase() }
+ })
+ });
+
+ const response_data = await response.json();
+
+ // JSON-RPC wraps results in 'result' field
+ const data = response_data.result || response_data;
+
+ if (data.success && data.access_token) {
+ // Redirect with the token
+ let redirectUrl = '/';
+ // Only allow same-origin relative paths (e.g. "/path"), not external URLs
+ if (typeof returnUrl === 'string' && returnUrl.startsWith('/') && !returnUrl.startsWith('//')) {
+ redirectUrl = returnUrl;
+ }
+ const separator = redirectUrl.includes('?') ? '&' : '?';
+ window.location.href = `${redirectUrl}${separator}code=${encodeURIComponent(data.access_token)}`;
+ return true;
+ } else {
+ // Show error - check both data and response_data for error info
+ const errorMsg = data.error || response_data.error_message || 'Invalid or expired code';
+ const errorEl = document.getElementById('error');
+ if (errorEl) {
+ errorEl.textContent = errorMsg;
+ errorEl.classList.add('show');
+ }
+ // Reload page without join code to show login form
+ setTimeout(() => {
+ urlParams.delete('join');
+ const newUrl = window.location.pathname + (urlParams.toString() ? '?' + urlParams.toString() : '');
+ window.location.href = newUrl;
+ }, 2000);
+ return false;
+ }
+ } catch (error) {
+ console.error('Join code authentication failed:', error);
+ // Show error and reload without join code
+ const errorEl = document.getElementById('error');
+ if (errorEl) {
+ errorEl.textContent = 'Authentication failed. Please try again.';
+ errorEl.classList.add('show');
+ }
+ setTimeout(() => {
+ urlParams.delete('join');
+ const newUrl = window.location.pathname + (urlParams.toString() ? '?' + urlParams.toString() : '');
+ window.location.href = newUrl;
+ }, 2000);
+ return false;
+ }
+ }
+
+ // On page load: try join code first, then load providers
+ (async function() {
+ if (joinCode) {
+ const handled = await handleJoinCode();
+ if (handled) return; // Successfully joined, redirecting
+ }
+ // No join code or it failed, show normal login
+ loadProviders();
+ })();
</script>
</body>
</html>
assert retrieved_user is not None
assert retrieved_user.user_id == existing_user.user_id
assert retrieved_user.username == "admin"
+
+
+# ==================== Join Code Tests ====================
+
+
+async def test_generate_join_code(auth_manager: AuthenticationManager) -> None:
+ """Test generating a join code for a user.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ user = await auth_manager.create_user(username="joincodeuser", role=UserRole.GUEST)
+
+ code, expires_at = await auth_manager.generate_join_code(
+ user_id=user.user_id,
+ expires_in_hours=24,
+ max_uses=0,
+ device_name="Test Device",
+ provider_name="test_provider",
+ )
+
+ assert code is not None
+ assert len(code) == 6 # JOIN_CODE_LENGTH
+ assert code.isalnum()
+ assert expires_at is not None
+ assert expires_at > utc()
+
+
+async def test_generate_join_code_non_guest_rejected(auth_manager: AuthenticationManager) -> None:
+ """Test that generating a join code for non-guest users is rejected.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ admin = await auth_manager.create_user(username="joinadmin", role=UserRole.ADMIN)
+ user = await auth_manager.create_user(username="joinuser", role=UserRole.USER)
+
+ with pytest.raises(ValueError, match="guest accounts"):
+ await auth_manager.generate_join_code(user_id=admin.user_id)
+
+ with pytest.raises(ValueError, match="guest accounts"):
+ await auth_manager.generate_join_code(user_id=user.user_id)
+
+
+async def test_generate_join_code_invalid_user(auth_manager: AuthenticationManager) -> None:
+ """Test that generating a join code for non-existent user raises error.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ with pytest.raises(ValueError, match="User not found"):
+ await auth_manager.generate_join_code(
+ user_id="nonexistent-user-id",
+ expires_in_hours=24,
+ )
+
+
+async def test_exchange_join_code(auth_manager: AuthenticationManager) -> None:
+ """Test exchanging a valid join code for a JWT token.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ user = await auth_manager.create_user(username="exchangeuser", role=UserRole.GUEST)
+
+ code, _ = await auth_manager.generate_join_code(
+ user_id=user.user_id,
+ expires_in_hours=24,
+ device_name="Exchange Test",
+ )
+
+ # Exchange code for token
+ token = await auth_manager._exchange_join_code(code)
+
+ assert token is not None
+ assert len(token) > 0
+
+ # Verify token works for authentication
+ authenticated_user = await auth_manager.authenticate_with_token(token)
+ assert authenticated_user is not None
+ assert authenticated_user.user_id == user.user_id
+ assert authenticated_user.username == user.username
+
+
+async def test_exchange_join_code_case_insensitive(auth_manager: AuthenticationManager) -> None:
+ """Test that join codes are case-insensitive.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ user = await auth_manager.create_user(username="caseuser", role=UserRole.GUEST)
+
+ code, _ = await auth_manager.generate_join_code(
+ user_id=user.user_id,
+ expires_in_hours=24,
+ )
+
+ # Exchange with lowercase version
+ token = await auth_manager._exchange_join_code(code.lower())
+ assert token is not None
+
+ # Verify token works
+ authenticated_user = await auth_manager.authenticate_with_token(token)
+ assert authenticated_user is not None
+ assert authenticated_user.user_id == user.user_id
+
+
+async def test_exchange_join_code_invalid(auth_manager: AuthenticationManager) -> None:
+ """Test that invalid join codes are rejected.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ token = await auth_manager._exchange_join_code("INVALID")
+ assert token is None
+
+
+async def test_exchange_join_code_expired(auth_manager: AuthenticationManager) -> None:
+ """Test that expired join codes are rejected.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ user = await auth_manager.create_user(username="expiredcodeuser", role=UserRole.GUEST)
+
+ code, _ = await auth_manager.generate_join_code(
+ user_id=user.user_id,
+ expires_in_hours=24,
+ )
+
+ # Manually expire the code by updating expires_at in database
+ code_row = await auth_manager.database.get_row("join_codes", {"code": code})
+ assert code_row is not None
+
+ past_time = utc() - timedelta(hours=1)
+ await auth_manager.database.update(
+ "join_codes",
+ {"code_id": code_row["code_id"]},
+ {"expires_at": past_time.isoformat()},
+ )
+
+ # Try to exchange expired code
+ token = await auth_manager._exchange_join_code(code)
+ assert token is None
+
+
+async def test_exchange_join_code_max_uses(auth_manager: AuthenticationManager) -> None:
+ """Test that join codes respect max_uses limit.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ user = await auth_manager.create_user(username="maxusesuser", role=UserRole.GUEST)
+
+ code, _ = await auth_manager.generate_join_code(
+ user_id=user.user_id,
+ expires_in_hours=24,
+ max_uses=2, # Only allow 2 uses
+ )
+
+ # First use should succeed
+ token1 = await auth_manager._exchange_join_code(code)
+ assert token1 is not None
+
+ # Second use should succeed
+ token2 = await auth_manager._exchange_join_code(code)
+ assert token2 is not None
+
+ # Third use should fail (max_uses=2 exceeded)
+ token3 = await auth_manager._exchange_join_code(code)
+ assert token3 is None
+
+
+async def test_exchange_join_code_unlimited_uses(auth_manager: AuthenticationManager) -> None:
+ """Test that join codes with max_uses=0 have unlimited uses.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ user = await auth_manager.create_user(username="unlimiteduser", role=UserRole.GUEST)
+
+ code, _ = await auth_manager.generate_join_code(
+ user_id=user.user_id,
+ expires_in_hours=24,
+ max_uses=0, # Unlimited
+ )
+
+ # Should be able to use multiple times
+ for _ in range(5):
+ token = await auth_manager._exchange_join_code(code)
+ assert token is not None
+
+
+async def test_exchange_join_code_provider_name_in_token(
+ auth_manager: AuthenticationManager,
+) -> None:
+ """Test that provider_name is included in the JWT token claims.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ user = await auth_manager.create_user(username="provideruser", role=UserRole.GUEST)
+
+ code, _ = await auth_manager.generate_join_code(
+ user_id=user.user_id,
+ expires_in_hours=24,
+ provider_name="party_mode",
+ )
+
+ token = await auth_manager._exchange_join_code(code)
+ assert token is not None
+
+ # Decode token and verify provider_name claim
+ payload = auth_manager.jwt_helper.decode_token(token)
+ assert payload.get("provider_name") == "party_mode"
+
+
+async def test_revoke_join_codes_for_user(auth_manager: AuthenticationManager) -> None:
+ """Test revoking join codes for a specific user.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ user1 = await auth_manager.create_user(username="revokeuser1", role=UserRole.GUEST)
+ user2 = await auth_manager.create_user(username="revokeuser2", role=UserRole.GUEST)
+
+ # Create codes for both users
+ code1, _ = await auth_manager.generate_join_code(user_id=user1.user_id)
+ code2, _ = await auth_manager.generate_join_code(user_id=user2.user_id)
+
+ # Revoke codes for user1 only
+ revoked_count = await auth_manager.revoke_join_codes(user_id=user1.user_id)
+ assert revoked_count == 1
+
+ # User1's code should no longer work
+ token1 = await auth_manager._exchange_join_code(code1)
+ assert token1 is None
+
+ # User2's code should still work
+ token2 = await auth_manager._exchange_join_code(code2)
+ assert token2 is not None
+
+
+async def test_revoke_join_codes_requires_filter(auth_manager: AuthenticationManager) -> None:
+ """Test that revoking join codes requires at least one filter parameter.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ with pytest.raises(ValueError, match="At least one of"):
+ await auth_manager.revoke_join_codes()
+
+
+async def test_authenticate_with_join_code_api(auth_manager: AuthenticationManager) -> None:
+ """Test the public API endpoint for join code authentication.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ user = await auth_manager.create_user(
+ username="apijoincodeuser",
+ role=UserRole.GUEST,
+ display_name="API Guest",
+ )
+
+ code, _ = await auth_manager.generate_join_code(
+ user_id=user.user_id,
+ expires_in_hours=24,
+ provider_name="party_mode",
+ )
+
+ # Call the API endpoint
+ result = await auth_manager.exchange_join_code(code)
+
+ assert result["success"] is True
+ assert "access_token" in result
+ assert result["user"]["user_id"] == user.user_id
+ assert result["user"]["username"] == user.username
+ assert result["user"]["role"] == "guest"
+
+
+async def test_authenticate_with_join_code_api_invalid(
+ auth_manager: AuthenticationManager,
+) -> None:
+ """Test the API endpoint with invalid join code.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ result = await auth_manager.exchange_join_code("BADCODE")
+
+ assert result["success"] is False
+ assert "error" in result
+ assert "access_token" not in result
+
+
+async def test_list_join_codes(auth_manager: AuthenticationManager) -> None:
+ """Test listing active join codes (admin only).
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ admin = await auth_manager.create_user(username="listcodesadmin", role=UserRole.ADMIN)
+ guest1 = await auth_manager.create_user(username="listguest1", role=UserRole.GUEST)
+ guest2 = await auth_manager.create_user(username="listguest2", role=UserRole.GUEST)
+ set_current_user(admin)
+
+ # Create codes for both guests
+ await auth_manager.generate_join_code(user_id=guest1.user_id, provider_name="party_mode")
+ await auth_manager.generate_join_code(user_id=guest2.user_id, provider_name="party_mode")
+
+ # List all codes
+ codes = await auth_manager.list_join_codes()
+ assert len(codes) == 2
+
+ # List codes for specific user
+ codes = await auth_manager.list_join_codes(user_id=guest1.user_id)
+ assert len(codes) == 1
+ assert codes[0]["user_id"] == guest1.user_id
+
+
+async def test_revoke_join_code_api(auth_manager: AuthenticationManager) -> None:
+ """Test revoking a specific join code by code_id (admin only).
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ admin = await auth_manager.create_user(username="revokecodeadmin", role=UserRole.ADMIN)
+ guest = await auth_manager.create_user(username="revokeguest", role=UserRole.GUEST)
+ set_current_user(admin)
+
+ code, _ = await auth_manager.generate_join_code(user_id=guest.user_id)
+
+ # Get the code_id from the database
+ codes = await auth_manager.list_join_codes(user_id=guest.user_id)
+ assert len(codes) == 1
+ code_id = codes[0]["code_id"]
+
+ # Revoke the specific code
+ await auth_manager.revoke_join_code(code_id)
+
+ # Code should no longer work
+ token = await auth_manager._exchange_join_code(code)
+ assert token is None
+
+ # List should be empty
+ codes = await auth_manager.list_join_codes(user_id=guest.user_id)
+ assert len(codes) == 0
+
+
+async def test_revoke_join_code_api_not_found(auth_manager: AuthenticationManager) -> None:
+ """Test revoking a non-existent join code raises error.
+
+ :param auth_manager: AuthenticationManager instance.
+ """
+ admin = await auth_manager.create_user(username="revokenotfound", role=UserRole.ADMIN)
+ set_current_user(admin)
+
+ with pytest.raises(InvalidDataError, match="Join code not found"):
+ await auth_manager.revoke_join_code("nonexistent-code-id")