From: apophisnow Date: Wed, 25 Feb 2026 17:57:51 +0000 (-0800) Subject: Add generic short code authentication system (#3078) X-Git-Url: https://git.kitaultman.com/?a=commitdiff_plain;h=d0e4114ea82350834e0c4e76c80002d6938b2b82;p=music-assistant-server.git Add generic short code authentication system (#3078) * feat: add generic short code authentication system Add a reusable short code authentication system that any provider can use for QR code login, device pairing, or similar flows. Changes: - Add join_codes database table (schema v6) - Add generate_join_code(user_id, provider_name, ...) method - Add exchange_join_code() to convert codes to JWT tokens - Add auth/code public API endpoint - Add revoke_join_codes() for cleanup - Update login.html to handle ?join= parameter - Add provider_name parameter to JWT token encoding Providers can implement short code auth flows like: code, expires = await auth.generate_join_code( user_id=my_user.user_id, provider_name="my_provider", expires_in_hours=24, ) The provider_name is stored in the join code and passed to the JWT token, allowing providers to identify their authenticated sessions. Co-Authored-By: Claude Opus 4.5 * fix: Revise db migration * test: add comprehensive tests for join code authentication Add tests covering the short code authentication system: - generate_join_code: basic functionality, invalid user handling - exchange_join_code: success, case-insensitivity, expired codes, max_uses limits, unlimited uses, provider_name in JWT claims - revoke_join_codes: per-user revocation, revoke all codes - authenticate_with_join_code API: success and error cases Co-Authored-By: Claude Opus 4.5 * Update music_assistant/controllers/webserver/auth.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update music_assistant/helpers/resources/login.html Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update music_assistant/controllers/webserver/auth.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update music_assistant/controllers/webserver/auth.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: Simplify join code schema * chore: Update join code test * fix: Fix revoke_join_codes() to use correct db count function * fix: Simplify revoke_join_codes and add provider handling * fix: Avoid useless rename * fix: Handle rare collision edge case * fix: Fix return type * fix: Use named params * fix: Rename prevents already defined error * Add input validation to generate_join_code. * Fix unhandled RunTimeError. * Limit short codes to users with guest role only. * Log security events. * Rename endpoint and add endpoints to list and revoke codes. * Make at least one parameter of revoke_join_codes required to prevent accidental deletion of all codes. * Schedule cleanup of expired join codes once a day. * Minor cleanup. * Fixes for active source and current media with linked protocols * Fix DSP not applying for AirPlay and Sendspin players (#3191) * OpenSubsonic: Use server provided version tag if present (#3200) * Fix use playerid for http profile * ⬆️ Update music-assistant-frontend to 2.17.92 (#3203) Co-authored-by: marcelveldt <6389780+marcelveldt@users.noreply.github.com> * Expand PIN based auth in airplay 2 (#3165) * add LG details * Make pin based auth work in other devices * remove reference to apple tv and macos in check * remove unused constant and adjust airplay2 filter * also apply pairing check to raop * add unit test * Revert MIN_SCHEMA_VERSION to maintain HA compatability. * Add comments to schema version constants * Fix some more issues with syncgroups * Fix HEOS source switching back to Local Music after starting stream (#3206) * Fix group mute for protocol-synced players (#3205) * Handle HEAD requests on root route (#3204) * Fix announcements typo * Some small code quality changes to DLNA Provider * Small simplification for GroupPlayer * Fix Sonos S2 announcement 404 error on cloud queue context endpoint (#3208) * Snapcast: Fixes for hard switching of group leaders (#3209) * Gracefully skip files/folders with emoji names on SMB mounts (#3183) * Add API to handle playback speed (#3198) * Simplify can_group_with logic * Airplay2-configurable-latency (#3210) * Validate queue item ID in Sonos pause path (#3194) * Add some additional guards to asyncprocess * Add a bunch of extra error handling and logging for flow streams * Properly cleanup stream buffers * ⬆️ Update music-assistant-frontend to 2.17.93 (#3214) Co-authored-by: marcelveldt <6389780+marcelveldt@users.noreply.github.com> * Fix bluesound volume jumping back after volume_set. * Speed-up core startup a bit * More gracefully handle DLNA errors * Lock set_members to avoid concurrent actions * Fix issue with subprocess pips closing * Fix ungroup command * Add note in docstring * Auto ungroup when trying to form syncgroup with already synced player * Fix accessing player.state.synced_to * Fix playback speed handling on queue item and not on queue * Fix for _cleanup_player_memberships * Fix race condition with enqueue_next_media on SyncGroup * Fix some edge cases with AirPlay DACP commands * Fix set_members with lock * Fix player not available in HA at startup * Fix fully_played should return boolean * Auto translate commands directed at protocol player id to visible parent * Some minor tweaks to handling prevent-playback of airplay * Speedup core controller startup * Pre-compile Python bytecode in Dockerimage for faster startup * Speedup is_hass_supervisor check * Fix _cleanup_player_memberships and _handle_set_members * Fix player config not fully persisting * ⬆️ Update music-assistant-frontend to 2.17.94 (#3218) Co-authored-by: marcelveldt <6389780+marcelveldt@users.noreply.github.com> * Bandcamp: validate login on init when credentials are configured (#3215) Co-authored-by: David Bishop Co-authored-by: Claude Opus 4.6 * Fix bluesound volume jumping back after volume_set. * Use ImageType.THUMB for Bandcamp artwork images (#3212) Bandcamp artwork is square, not landscape. All other music providers in the codebase use THUMB for standard album and artist art. Co-authored-by: David Bishop Co-authored-by: Claude Opus 4.6 * Fix inverted track_number condition in Bandcamp converter (#3211) The condition checked output.track_number instead of track.track_number, meaning track numbers from the API were only applied when the output already had a non-None default. Co-authored-by: David Bishop Co-authored-by: Claude Opus 4.6 * Clear internal HEOS queue before playing (#3219) * Update Alexa player provider (#3167) * Update Alexa player provider * Remove redundant try catch Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Address PR comments * Remove ActionUnavailble catches * Remove extra catch alls and add _on_player_media_updated * Remove catch all * Bump AlexaPy * Fix _upload_metadata when media is not available --------- Co-authored-by: Sameer Alam Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix race condition in player register flow wrt config * Fix select output protocol already in play_index to avoid race on flow mode * Fail job on test failures * Fix Radioparadise image URL (#3220) The change to the documentation repo moved the images * Fix flow mode determination * Fix player tests * Add genre icons and SVG handling to imageproxy (#3223) * Add genre icons and SVG handling to imageproxy * Cleanup * ⬆️ Update music-assistant-frontend to 2.17.95 (#3222) Co-authored-by: stvncode <25082266+stvncode@users.noreply.github.com> Co-authored-by: Marvin Schenkel * Fix static members of sync group * Fix last small issues with syncgroup * Fix issue with clearing output protocol during track changes * Use mass.call_later. --------- Co-authored-by: Claude Opus 4.5 Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Marvin Schenkel Co-authored-by: Marcel van der Veldt Co-authored-by: Maxim Raznatovski Co-authored-by: Eric Munson Co-authored-by: music-assistant-machine <141749843+music-assistant-machine@users.noreply.github.com> Co-authored-by: marcelveldt <6389780+marcelveldt@users.noreply.github.com> Co-authored-by: hmonteiro <1819451+hmonteiro@users.noreply.github.com> Co-authored-by: Tom Matheussen <13683094+Tommatheussen@users.noreply.github.com> Co-authored-by: scyto Co-authored-by: David Bishop Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: Mischa Siekmann <45062894+gnumpi@users.noreply.github.com> Co-authored-by: OzGav Co-authored-by: Andy Kelk Co-authored-by: Brad Keifer <15224368+bradkeifer@users.noreply.github.com> Co-authored-by: Bob Butler Co-authored-by: David Bishop Co-authored-by: Sameer Alam <31905246+alams154@users.noreply.github.com> Co-authored-by: Sameer Alam Co-authored-by: stvncode <25082266+stvncode@users.noreply.github.com> --- diff --git a/music_assistant/controllers/webserver/auth.py b/music_assistant/controllers/webserver/auth.py index 6bb721d2..c644511e 100644 --- a/music_assistant/controllers/webserver/auth.py +++ b/music_assistant/controllers/webserver/auth.py @@ -7,7 +7,7 @@ import hashlib import logging import secrets from datetime import datetime, timedelta -from sqlite3 import OperationalError +from sqlite3 import IntegrityError, OperationalError from typing import TYPE_CHECKING, Any import jwt as pyjwt @@ -24,11 +24,7 @@ from music_assistant_models.errors import ( InvalidDataError, ) -from music_assistant.constants import ( - DB_TABLE_PLAYLOG, - HOMEASSISTANT_SYSTEM_USER, - MASS_LOGGER_NAME, -) +from music_assistant.constants import DB_TABLE_PLAYLOG, HOMEASSISTANT_SYSTEM_USER, MASS_LOGGER_NAME from music_assistant.controllers.webserver.helpers.auth_middleware import ( get_current_token, get_current_user, @@ -53,12 +49,17 @@ if TYPE_CHECKING: LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.auth") # Database schema version -DB_SCHEMA_VERSION = 4 +DB_SCHEMA_VERSION = 5 # Token expiration constants (in days) TOKEN_SHORT_LIVED_EXPIRATION = 30 # Short-lived tokens (auto-renewing on use) TOKEN_LONG_LIVED_EXPIRATION = 3650 # Long-lived tokens (10 years, no auto-renewal) +# Join code constants (short codes for QR/link-based login) +JOIN_CODE_LENGTH = 6 +JOIN_CODE_CHARSET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" # No I/O/0/1 for readability +JOIN_CODE_DEFAULT_EXPIRY_HOURS = 8 + class AuthenticationManager: """Manager for authentication and user management (part of webserver controller).""" @@ -96,6 +97,8 @@ class AuthenticationManager: self._has_users = await self._has_non_system_users() + self._schedule_join_code_cleanup() + self.logger.info( "Authentication manager initialized (providers=%d)", len(self.login_providers) ) @@ -203,6 +206,24 @@ class AuthenticationManager: ) """ ) + # Join codes table (for short code to JWT exchange, used by providers like party mode) + await self.database.execute( + """ + CREATE TABLE IF NOT EXISTS join_codes ( + code_id TEXT PRIMARY KEY, + code TEXT NOT NULL UNIQUE, + user_id TEXT NOT NULL, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + max_uses INTEGER DEFAULT 0, + use_count INTEGER DEFAULT 0, + last_used_at TEXT, + device_name TEXT, + provider_name TEXT, + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ) + """ + ) await self.database.commit() async def _create_database_indexes(self) -> None: @@ -221,6 +242,9 @@ class AuthenticationManager: await self.database.execute( "CREATE INDEX IF NOT EXISTS idx_tokens_hash ON auth_tokens(token_hash)" ) + await self.database.execute( + "CREATE INDEX IF NOT EXISTS idx_join_codes_user ON join_codes(user_id)" + ) async def _migrate_database(self, from_version: int) -> None: """Perform database migration. @@ -258,6 +282,27 @@ class AuthenticationManager: await self.database.execute("UPDATE users SET username = LOWER(username)") await self.database.commit() + # Migration to version 5: Add join codes table + if from_version < 5: + await self.database.execute( + """ + CREATE TABLE IF NOT EXISTS join_codes ( + code_id TEXT PRIMARY KEY, + code TEXT NOT NULL UNIQUE, + user_id TEXT NOT NULL, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + max_uses INTEGER DEFAULT 0, + use_count INTEGER DEFAULT 0, + last_used_at TEXT, + device_name TEXT, + provider_name TEXT, + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ) + """ + ) + await self.database.commit() + async def _get_or_create_jwt_secret(self) -> str: """Get or create JWT secret key from database. @@ -355,7 +400,7 @@ class AuthenticationManager: async def authenticate_with_token(self, token: str) -> User | None: """ - Authenticate a user with an access token. + Authenticate a user with an access token (JWT or legacy). Supports both JWT tokens and legacy hash-based tokens for backward compatibility. @@ -426,18 +471,18 @@ class AuthenticationManager: # Implement sliding expiration for short-lived tokens is_long_lived = bool(token_row["is_long_lived"]) now = utc() - updates = {"last_used_at": now.isoformat()} + legacy_updates: dict[str, str] = {"last_used_at": now.isoformat()} if not is_long_lived and token_row["expires_at"]: # Short-lived token: extend expiration on each use (sliding window) new_expires_at = now + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION) - updates["expires_at"] = new_expires_at.isoformat() + legacy_updates["expires_at"] = new_expires_at.isoformat() # Update last used timestamp and potentially expiration await self.database.update( "auth_tokens", {"token_id": token_row["token_id"]}, - updates, + legacy_updates, ) # Get user @@ -459,7 +504,6 @@ class AuthenticationManager: token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash}) if not token_row: return None - return str(token_row["token_id"]) @api_command("auth/user", required_role="admin") @@ -830,15 +874,19 @@ class AuthenticationManager: # Create new link await self.link_user_to_provider(user, provider_type, provider_user_id) - async def create_token(self, user: User, name: str, is_long_lived: bool = False) -> str: + async def create_token( + self, user: User, name: str, is_long_lived: bool = False, provider_name: str | None = None + ) -> str: """ - Create a new access token for a user. + Create a new JWT access token for a user. :param user: The user to create the token for. :param name: A name/description for the token (e.g., device name). :param is_long_lived: Whether this is a long-lived token (default: False). Short-lived tokens (False): Auto-renewing on use, expire after 30 days of inactivity. Long-lived tokens (True): No auto-renewal, expire after 10 years. + :param provider_name: Optional provider name that created this token (e.g., "party_mode"). + :return: JWT token string. """ # Generate unique token ID token_id = secrets.token_urlsafe(32) @@ -859,6 +907,7 @@ class AuthenticationManager: token_name=name, expires_at=expires_at, is_long_lived=is_long_lived, + provider_name=provider_name, ) # Store token hash in database for revocation checking @@ -1487,3 +1536,239 @@ class AuthenticationManager: ) await self.database.commit() return True + + # ==================== Join Code Methods ==================== + + async def generate_join_code( + self, + user_id: str, + expires_in_hours: int = JOIN_CODE_DEFAULT_EXPIRY_HOURS, + max_uses: int = 1, + device_name: str = "Short Code Login", + provider_name: str | None = None, + ) -> tuple[str, datetime]: + """Generate a short join code for link/QR-based login. + + This creates a short alphanumeric code that can be exchanged for a JWT token. + Providers can use this to implement features like party mode guest access, + device pairing, or other short-code authentication flows. + + :param user_id: The user ID that tokens created from this code will belong to. + :param expires_in_hours: Hours until code expires (default: 8). + :param max_uses: Maximum number of uses (0 = unlimited). + :param device_name: Device name for tokens created with this code. + :param provider_name: Optional provider name identifier (e.g., "party_mode"). + :return: Tuple of (code, expires_at datetime). + """ + if expires_in_hours <= 0: + raise ValueError("expires_in_hours must be positive") + if max_uses < 0: + raise ValueError("max_uses must be non-negative (0 = unlimited)") + user = await self.get_user(user_id) + if not user: + raise ValueError(f"User not found: {user_id}") + if user.role != UserRole.GUEST: + raise ValueError("Join codes can only be generated for guest accounts") + + now = utc() + expires_at = now + timedelta(hours=expires_in_hours) + + for _ in range(3): # Try up to 3 times to avoid code collisions + code = "".join(secrets.choice(JOIN_CODE_CHARSET) for _ in range(JOIN_CODE_LENGTH)) + code_data = { + "code_id": secrets.token_urlsafe(32), + "code": code, + "user_id": user_id, + "created_at": now.isoformat(), + "expires_at": expires_at.isoformat(), + "max_uses": max_uses, + "use_count": 0, + "device_name": device_name, + "provider_name": provider_name, + } + try: + await self.database.insert("join_codes", code_data) + await self.database.commit() + self.logger.info( + "Join code generated for user %s (expires: %s, max_uses: %s, provider: %s)", + user.username, + expires_at, + max_uses, + provider_name, + ) + return code, expires_at + except IntegrityError: + self.logger.warning("Join code collision, retrying...") + continue + + raise RuntimeError("Failed to generate a unique join code after 3 attempts") + + async def _exchange_join_code(self, code: str) -> str | None: + """Exchange a join code for a JWT access token. + + The token is created for the user associated with the join code, + using the provider_name that was specified when the code was generated. + + :param code: The short join code. + :return: JWT token string if valid, None otherwise. + """ + now = utc() + + cursor = await self.database.execute( + """ + UPDATE join_codes + SET use_count = use_count + 1, + last_used_at = :now + WHERE code = :code + AND expires_at > :now + AND (max_uses = 0 OR use_count < max_uses) + RETURNING user_id, provider_name, device_name + """, + {"now": now.isoformat(), "code": code.upper()}, + ) + row = await cursor.fetchone() + await self.database.commit() + + if not row: + self.logger.warning("Join code exchange rejected (code=%s)", code.upper()) + return None + + user = await self.get_user(row["user_id"]) + if not user: + self.logger.error( + "User not found for join code despite FK constraint (user_id=%s)", row["user_id"] + ) + return None + + device_name = row["device_name"] or "Short Code Login" + token = await self.create_token( + user, + device_name, + is_long_lived=False, + provider_name=row["provider_name"], + ) + + self.logger.info( + "Join code exchanged for token (user=%s, provider=%s)", + user.username, + row["provider_name"], + ) + return token + + async def revoke_join_codes( + self, + user_id: str | None = None, + provider_name: str | None = None, + ) -> int: + """Revoke join codes filtered by user and/or provider. + + At least one filter parameter must be provided to prevent accidental deletion of all codes. + + :param user_id: User ID to revoke codes for. + :param provider_name: Provider name to revoke codes for. + :return: Number of codes revoked. + """ + if not user_id and not provider_name: + raise ValueError("At least one of user_id or provider_name must be provided") + + conditions = [] + params = {} + + if user_id: + conditions.append("user_id = :user_id") + params["user_id"] = user_id + if provider_name: + conditions.append("provider_name = :provider_name") + params["provider_name"] = provider_name + + cursor = await self.database.execute( + f"DELETE FROM join_codes WHERE {' AND '.join(conditions)}", params + ) + await self.database.commit() + + count = int(cursor.rowcount) + if count > 0: + self.logger.info("Revoked %d join code(s)", count) + return count + + async def _cleanup_expired_join_codes(self) -> None: + """Delete expired and exhausted join codes from the database.""" + now = utc() + cursor = await self.database.execute( + """ + DELETE FROM join_codes + WHERE expires_at < :now + OR (max_uses > 0 AND use_count >= max_uses) + """, + {"now": now.isoformat()}, + ) + await self.database.commit() + count = int(cursor.rowcount) + if count > 0: + self.logger.debug("Cleaned up %d expired/exhausted join code(s)", count) + + def _schedule_join_code_cleanup(self) -> None: + """Schedule periodic cleanup of expired join codes.""" + self.mass.create_task(self._cleanup_expired_join_codes()) + self.mass.call_later(86400, self._schedule_join_code_cleanup) + + @api_command("auth/join_code/exchange", authenticated=False) + async def exchange_join_code(self, code: str) -> dict[str, Any]: + """Exchange a join code for an access token (public API). + + This is the public API endpoint for short-code authentication. + Clients call this with a code (e.g., from QR scan or link) to receive a JWT token. + + :param code: The short join code. + :return: Authentication result with access token if successful. + """ + token = await self._exchange_join_code(code) + + if not token: + return { + "success": False, + "error": "Invalid or expired join code", + } + + # Decode token to get user info + try: + payload = self.jwt_helper.decode_token(token) + return { + "success": True, + "access_token": token, + "user": { + "user_id": payload.get("sub"), + "username": payload.get("username"), + "role": payload.get("role"), + }, + } + except pyjwt.InvalidTokenError: + return { + "success": False, + "error": "Failed to create access token", + } + + @api_command("auth/join_codes", required_role="admin") + async def list_join_codes(self, user_id: str | None = None) -> list[dict[str, Any]]: + """List join codes, optionally filtered by user (admin only). + + :param user_id: Optional user ID to filter codes for. + :return: List of join code records. + """ + filter_args = {"user_id": user_id} if user_id else None + rows = await self.database.get_rows("join_codes", filter_args, limit=100) + return [dict(row) for row in rows] + + @api_command("auth/join_code/revoke", required_role="admin") + async def revoke_join_code(self, code_id: str) -> None: + """Revoke a specific join code (admin only). + + :param code_id: The code ID to revoke. + """ + code_row = await self.database.get_row("join_codes", {"code_id": code_id}) + if not code_row: + raise InvalidDataError("Join code not found") + + await self.database.delete("join_codes", {"code_id": code_id}) + await self.database.commit() + self.logger.info("Join code revoked (code_id=%s)", code_id) diff --git a/music_assistant/helpers/jwt_auth.py b/music_assistant/helpers/jwt_auth.py index c7d6bb03..f7054f5a 100644 --- a/music_assistant/helpers/jwt_auth.py +++ b/music_assistant/helpers/jwt_auth.py @@ -43,6 +43,7 @@ class JWTHelper: token_name: str, expires_at: datetime, is_long_lived: bool = False, + provider_name: str | None = None, ) -> str: """Encode a JWT token for a user. @@ -51,6 +52,7 @@ class JWTHelper: :param token_name: Human-readable token name. :param expires_at: Token expiration datetime. :param is_long_lived: Whether this is a long-lived token. + :param provider_name: Optional provider name that created this token (e.g., "party_mode"). :return: Encoded JWT token string. """ now = utc() @@ -65,6 +67,9 @@ class JWTHelper: "is_long_lived": is_long_lived, } + if provider_name: + payload["provider_name"] = provider_name + return jwt.encode(payload, self.secret_key, algorithm=self.algorithm) def decode_token(self, token: str, verify_exp: bool = True) -> dict[str, Any]: diff --git a/music_assistant/helpers/resources/login.html b/music_assistant/helpers/resources/login.html index bfc92d3e..8ba17088 100644 --- a/music_assistant/helpers/resources/login.html +++ b/music_assistant/helpers/resources/login.html @@ -81,10 +81,11 @@ diff --git a/tests/test_webserver_auth.py b/tests/test_webserver_auth.py index c2275e5c..081a12c1 100644 --- a/tests/test_webserver_auth.py +++ b/tests/test_webserver_auth.py @@ -907,3 +907,347 @@ async def test_ingress_auth_existing_username(auth_manager: AuthenticationManage assert retrieved_user is not None assert retrieved_user.user_id == existing_user.user_id assert retrieved_user.username == "admin" + + +# ==================== Join Code Tests ==================== + + +async def test_generate_join_code(auth_manager: AuthenticationManager) -> None: + """Test generating a join code for a user. + + :param auth_manager: AuthenticationManager instance. + """ + user = await auth_manager.create_user(username="joincodeuser", role=UserRole.GUEST) + + code, expires_at = await auth_manager.generate_join_code( + user_id=user.user_id, + expires_in_hours=24, + max_uses=0, + device_name="Test Device", + provider_name="test_provider", + ) + + assert code is not None + assert len(code) == 6 # JOIN_CODE_LENGTH + assert code.isalnum() + assert expires_at is not None + assert expires_at > utc() + + +async def test_generate_join_code_non_guest_rejected(auth_manager: AuthenticationManager) -> None: + """Test that generating a join code for non-guest users is rejected. + + :param auth_manager: AuthenticationManager instance. + """ + admin = await auth_manager.create_user(username="joinadmin", role=UserRole.ADMIN) + user = await auth_manager.create_user(username="joinuser", role=UserRole.USER) + + with pytest.raises(ValueError, match="guest accounts"): + await auth_manager.generate_join_code(user_id=admin.user_id) + + with pytest.raises(ValueError, match="guest accounts"): + await auth_manager.generate_join_code(user_id=user.user_id) + + +async def test_generate_join_code_invalid_user(auth_manager: AuthenticationManager) -> None: + """Test that generating a join code for non-existent user raises error. + + :param auth_manager: AuthenticationManager instance. + """ + with pytest.raises(ValueError, match="User not found"): + await auth_manager.generate_join_code( + user_id="nonexistent-user-id", + expires_in_hours=24, + ) + + +async def test_exchange_join_code(auth_manager: AuthenticationManager) -> None: + """Test exchanging a valid join code for a JWT token. + + :param auth_manager: AuthenticationManager instance. + """ + user = await auth_manager.create_user(username="exchangeuser", role=UserRole.GUEST) + + code, _ = await auth_manager.generate_join_code( + user_id=user.user_id, + expires_in_hours=24, + device_name="Exchange Test", + ) + + # Exchange code for token + token = await auth_manager._exchange_join_code(code) + + assert token is not None + assert len(token) > 0 + + # Verify token works for authentication + authenticated_user = await auth_manager.authenticate_with_token(token) + assert authenticated_user is not None + assert authenticated_user.user_id == user.user_id + assert authenticated_user.username == user.username + + +async def test_exchange_join_code_case_insensitive(auth_manager: AuthenticationManager) -> None: + """Test that join codes are case-insensitive. + + :param auth_manager: AuthenticationManager instance. + """ + user = await auth_manager.create_user(username="caseuser", role=UserRole.GUEST) + + code, _ = await auth_manager.generate_join_code( + user_id=user.user_id, + expires_in_hours=24, + ) + + # Exchange with lowercase version + token = await auth_manager._exchange_join_code(code.lower()) + assert token is not None + + # Verify token works + authenticated_user = await auth_manager.authenticate_with_token(token) + assert authenticated_user is not None + assert authenticated_user.user_id == user.user_id + + +async def test_exchange_join_code_invalid(auth_manager: AuthenticationManager) -> None: + """Test that invalid join codes are rejected. + + :param auth_manager: AuthenticationManager instance. + """ + token = await auth_manager._exchange_join_code("INVALID") + assert token is None + + +async def test_exchange_join_code_expired(auth_manager: AuthenticationManager) -> None: + """Test that expired join codes are rejected. + + :param auth_manager: AuthenticationManager instance. + """ + user = await auth_manager.create_user(username="expiredcodeuser", role=UserRole.GUEST) + + code, _ = await auth_manager.generate_join_code( + user_id=user.user_id, + expires_in_hours=24, + ) + + # Manually expire the code by updating expires_at in database + code_row = await auth_manager.database.get_row("join_codes", {"code": code}) + assert code_row is not None + + past_time = utc() - timedelta(hours=1) + await auth_manager.database.update( + "join_codes", + {"code_id": code_row["code_id"]}, + {"expires_at": past_time.isoformat()}, + ) + + # Try to exchange expired code + token = await auth_manager._exchange_join_code(code) + assert token is None + + +async def test_exchange_join_code_max_uses(auth_manager: AuthenticationManager) -> None: + """Test that join codes respect max_uses limit. + + :param auth_manager: AuthenticationManager instance. + """ + user = await auth_manager.create_user(username="maxusesuser", role=UserRole.GUEST) + + code, _ = await auth_manager.generate_join_code( + user_id=user.user_id, + expires_in_hours=24, + max_uses=2, # Only allow 2 uses + ) + + # First use should succeed + token1 = await auth_manager._exchange_join_code(code) + assert token1 is not None + + # Second use should succeed + token2 = await auth_manager._exchange_join_code(code) + assert token2 is not None + + # Third use should fail (max_uses=2 exceeded) + token3 = await auth_manager._exchange_join_code(code) + assert token3 is None + + +async def test_exchange_join_code_unlimited_uses(auth_manager: AuthenticationManager) -> None: + """Test that join codes with max_uses=0 have unlimited uses. + + :param auth_manager: AuthenticationManager instance. + """ + user = await auth_manager.create_user(username="unlimiteduser", role=UserRole.GUEST) + + code, _ = await auth_manager.generate_join_code( + user_id=user.user_id, + expires_in_hours=24, + max_uses=0, # Unlimited + ) + + # Should be able to use multiple times + for _ in range(5): + token = await auth_manager._exchange_join_code(code) + assert token is not None + + +async def test_exchange_join_code_provider_name_in_token( + auth_manager: AuthenticationManager, +) -> None: + """Test that provider_name is included in the JWT token claims. + + :param auth_manager: AuthenticationManager instance. + """ + user = await auth_manager.create_user(username="provideruser", role=UserRole.GUEST) + + code, _ = await auth_manager.generate_join_code( + user_id=user.user_id, + expires_in_hours=24, + provider_name="party_mode", + ) + + token = await auth_manager._exchange_join_code(code) + assert token is not None + + # Decode token and verify provider_name claim + payload = auth_manager.jwt_helper.decode_token(token) + assert payload.get("provider_name") == "party_mode" + + +async def test_revoke_join_codes_for_user(auth_manager: AuthenticationManager) -> None: + """Test revoking join codes for a specific user. + + :param auth_manager: AuthenticationManager instance. + """ + user1 = await auth_manager.create_user(username="revokeuser1", role=UserRole.GUEST) + user2 = await auth_manager.create_user(username="revokeuser2", role=UserRole.GUEST) + + # Create codes for both users + code1, _ = await auth_manager.generate_join_code(user_id=user1.user_id) + code2, _ = await auth_manager.generate_join_code(user_id=user2.user_id) + + # Revoke codes for user1 only + revoked_count = await auth_manager.revoke_join_codes(user_id=user1.user_id) + assert revoked_count == 1 + + # User1's code should no longer work + token1 = await auth_manager._exchange_join_code(code1) + assert token1 is None + + # User2's code should still work + token2 = await auth_manager._exchange_join_code(code2) + assert token2 is not None + + +async def test_revoke_join_codes_requires_filter(auth_manager: AuthenticationManager) -> None: + """Test that revoking join codes requires at least one filter parameter. + + :param auth_manager: AuthenticationManager instance. + """ + with pytest.raises(ValueError, match="At least one of"): + await auth_manager.revoke_join_codes() + + +async def test_authenticate_with_join_code_api(auth_manager: AuthenticationManager) -> None: + """Test the public API endpoint for join code authentication. + + :param auth_manager: AuthenticationManager instance. + """ + user = await auth_manager.create_user( + username="apijoincodeuser", + role=UserRole.GUEST, + display_name="API Guest", + ) + + code, _ = await auth_manager.generate_join_code( + user_id=user.user_id, + expires_in_hours=24, + provider_name="party_mode", + ) + + # Call the API endpoint + result = await auth_manager.exchange_join_code(code) + + assert result["success"] is True + assert "access_token" in result + assert result["user"]["user_id"] == user.user_id + assert result["user"]["username"] == user.username + assert result["user"]["role"] == "guest" + + +async def test_authenticate_with_join_code_api_invalid( + auth_manager: AuthenticationManager, +) -> None: + """Test the API endpoint with invalid join code. + + :param auth_manager: AuthenticationManager instance. + """ + result = await auth_manager.exchange_join_code("BADCODE") + + assert result["success"] is False + assert "error" in result + assert "access_token" not in result + + +async def test_list_join_codes(auth_manager: AuthenticationManager) -> None: + """Test listing active join codes (admin only). + + :param auth_manager: AuthenticationManager instance. + """ + admin = await auth_manager.create_user(username="listcodesadmin", role=UserRole.ADMIN) + guest1 = await auth_manager.create_user(username="listguest1", role=UserRole.GUEST) + guest2 = await auth_manager.create_user(username="listguest2", role=UserRole.GUEST) + set_current_user(admin) + + # Create codes for both guests + await auth_manager.generate_join_code(user_id=guest1.user_id, provider_name="party_mode") + await auth_manager.generate_join_code(user_id=guest2.user_id, provider_name="party_mode") + + # List all codes + codes = await auth_manager.list_join_codes() + assert len(codes) == 2 + + # List codes for specific user + codes = await auth_manager.list_join_codes(user_id=guest1.user_id) + assert len(codes) == 1 + assert codes[0]["user_id"] == guest1.user_id + + +async def test_revoke_join_code_api(auth_manager: AuthenticationManager) -> None: + """Test revoking a specific join code by code_id (admin only). + + :param auth_manager: AuthenticationManager instance. + """ + admin = await auth_manager.create_user(username="revokecodeadmin", role=UserRole.ADMIN) + guest = await auth_manager.create_user(username="revokeguest", role=UserRole.GUEST) + set_current_user(admin) + + code, _ = await auth_manager.generate_join_code(user_id=guest.user_id) + + # Get the code_id from the database + codes = await auth_manager.list_join_codes(user_id=guest.user_id) + assert len(codes) == 1 + code_id = codes[0]["code_id"] + + # Revoke the specific code + await auth_manager.revoke_join_code(code_id) + + # Code should no longer work + token = await auth_manager._exchange_join_code(code) + assert token is None + + # List should be empty + codes = await auth_manager.list_join_codes(user_id=guest.user_id) + assert len(codes) == 0 + + +async def test_revoke_join_code_api_not_found(auth_manager: AuthenticationManager) -> None: + """Test revoking a non-existent join code raises error. + + :param auth_manager: AuthenticationManager instance. + """ + admin = await auth_manager.create_user(username="revokenotfound", role=UserRole.ADMIN) + set_current_user(admin) + + with pytest.raises(InvalidDataError, match="Join code not found"): + await auth_manager.revoke_join_code("nonexistent-code-id")