from sqlite3 import OperationalError
from typing import TYPE_CHECKING, Any
+import jwt as pyjwt
from music_assistant_models.auth import (
AuthProviderType,
AuthToken,
from music_assistant.helpers.database import DatabaseConnection
from music_assistant.helpers.datetime import utc
from music_assistant.helpers.json import json_dumps, json_loads
+from music_assistant.helpers.jwt_auth import JWTHelper
if TYPE_CHECKING:
from music_assistant.controllers.webserver import WebserverController
self.login_providers: dict[str, LoginProvider] = {}
self.logger = LOGGER
self._has_users: bool = False
+ self.jwt_helper: JWTHelper = None # type: ignore[assignment]
async def setup(self) -> None:
"""Initialize the authentication manager."""
# Create database schema and handle migrations
await self._setup_database()
+ # Initialize JWT helper with secret key
+ jwt_secret = await self._get_or_create_jwt_secret()
+ self.jwt_helper = JWTHelper(jwt_secret)
+
# Setup login providers based on config
await self._setup_login_providers(allow_self_registration)
await self.database.execute("UPDATE users SET username = LOWER(username)")
await self.database.commit()
+ async def _get_or_create_jwt_secret(self) -> str:
+ """Get or create JWT secret key from database.
+
+ :return: JWT secret key for signing tokens.
+ """
+ # Try to get existing secret
+ if secret_row := await self.database.get_row("settings", {"key": "jwt_secret"}):
+ return str(secret_row["value"])
+
+ # Generate new secret
+ jwt_secret = JWTHelper.generate_secret_key()
+
+ # Store in database
+ await self.database.insert_or_replace(
+ "settings",
+ {"key": "jwt_secret", "value": jwt_secret, "type": "string"},
+ )
+ await self.database.commit()
+
+ self.logger.info("Generated new JWT secret key")
+ return jwt_secret
+
async def _setup_login_providers(self, allow_self_registration: bool) -> None:
"""
Set up available login providers based on configuration.
"""
Authenticate a user with an access token.
- :param token: The access token.
+ Supports both JWT tokens and legacy hash-based tokens for backward compatibility.
+
+ :param token: The access token (JWT or legacy hash token).
"""
- # Hash the token to look it up
- token_hash = hashlib.sha256(token.encode()).hexdigest()
+ # Try to decode as JWT first
+ try:
+ payload = self.jwt_helper.decode_token(token, verify_exp=True)
+ token_id = payload.get("jti")
+ user_id = payload.get("sub")
+ is_long_lived = payload.get("is_long_lived", False)
+
+ if not token_id or not user_id:
+ return None
+
+ token_row = await self.database.get_row("auth_tokens", {"token_id": token_id})
+ if not token_row:
+ return None
- # Find token in database
+ # Database expiration is source of truth
+ if token_row["expires_at"]:
+ db_expires_at = datetime.fromisoformat(token_row["expires_at"])
+ if utc() > db_expires_at:
+ await self.database.delete("auth_tokens", {"token_id": token_id})
+ return None
+
+ # Update last used timestamp
+ now = utc()
+ updates = {"last_used_at": now.isoformat()}
+
+ if not is_long_lived:
+ # Short-lived token: extend expiration on each use (sliding window)
+ new_expires_at = now + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION)
+ updates["expires_at"] = new_expires_at.isoformat()
+
+ # Update database
+ await self.database.update(
+ "auth_tokens",
+ {"token_id": token_id},
+ updates,
+ )
+
+ return await self.get_user(user_id)
+
+ except pyjwt.ExpiredSignatureError:
+ if token_id := self.jwt_helper.get_token_id(token):
+ await self.database.delete("auth_tokens", {"token_id": token_id})
+ return None
+ except pyjwt.InvalidTokenError:
+ self.logger.debug("Token is not a valid JWT, trying legacy hash lookup")
+ except Exception as err:
+ self.logger.debug("Error decoding JWT token: %s, trying legacy hash lookup", err)
+
+ # Fallback to legacy hash-based token lookup
+ token_hash = hashlib.sha256(token.encode()).hexdigest()
token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash})
if not token_row:
return None
"""
Get token_id from a token string (for tracking revocation).
- :param token: The access token.
+ :param token: The access token (JWT or legacy hash token).
:return: The token_id or None if token not found.
"""
- # Hash the token to look it up
- token_hash = hashlib.sha256(token.encode()).hexdigest()
+ # Try to extract from JWT first
+ if token_id := self.jwt_helper.get_token_id(token):
+ return token_id
- # Find token in database
+ # Fallback: Hash-based lookup for legacy tokens
+ token_hash = hashlib.sha256(token.encode()).hexdigest()
token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash})
if not token_row:
return None
Short-lived tokens (False): Auto-renewing on use, expire after 30 days of inactivity.
Long-lived tokens (True): No auto-renewal, expire after 10 years.
"""
- # Generate token
- token = secrets.token_urlsafe(48)
- token_hash = hashlib.sha256(token.encode()).hexdigest()
+ # Generate unique token ID
+ token_id = secrets.token_urlsafe(32)
# Calculate expiration based on token type
created_at = utc()
# Short-lived tokens expire after 30 days (with auto-renewal on use)
expires_at = created_at + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION)
- # Store token
+ # Generate JWT token
+ token = self.jwt_helper.encode_token(
+ user=user,
+ token_id=token_id,
+ token_name=name,
+ expires_at=expires_at,
+ is_long_lived=is_long_lived,
+ )
+
+ # Store token hash in database for revocation checking
+ token_hash = hashlib.sha256(token.encode()).hexdigest()
token_data = {
- "token_id": secrets.token_urlsafe(32),
+ "token_id": token_id,
"user_id": user.user_id,
"token_hash": token_hash,
"name": name,
--- /dev/null
+"""JWT token helper for Music Assistant authentication.
+
+Future OIDC Support:
+- Consuming external OIDC providers (Google, Keycloak, etc.): Can be added without
+ changes to token structure. MA would validate external OIDC tokens and issue its
+ own JWT tokens (similar to current Home Assistant OAuth flow).
+
+- Acting as OIDC provider for third parties: Would require implementing OAuth2
+ refresh token flow with a dedicated /auth/token endpoint for token refresh.
+ Short-lived access tokens (15 min) + long-lived refresh tokens would be needed
+ for proper OIDC compliance.
+"""
+
+from __future__ import annotations
+
+import secrets
+from datetime import datetime
+from typing import TYPE_CHECKING, Any
+
+import jwt
+
+from music_assistant.helpers.datetime import utc
+
+if TYPE_CHECKING:
+ from music_assistant_models.auth import User
+
+
+class JWTHelper:
+ """Helper class for JWT token operations."""
+
+ def __init__(self, secret_key: str) -> None:
+ """Initialize JWT helper.
+
+ :param secret_key: Secret key for signing JWTs.
+ """
+ self.secret_key = secret_key
+ self.algorithm = "HS256"
+
+ def encode_token(
+ self,
+ user: User,
+ token_id: str,
+ token_name: str,
+ expires_at: datetime,
+ is_long_lived: bool = False,
+ ) -> str:
+ """Encode a JWT token for a user.
+
+ :param user: User object to create token for.
+ :param token_id: Unique token identifier.
+ :param token_name: Human-readable token name.
+ :param expires_at: Token expiration datetime.
+ :param is_long_lived: Whether this is a long-lived token.
+ :return: Encoded JWT token string.
+ """
+ now = utc()
+ payload = {
+ "sub": user.user_id,
+ "jti": token_id,
+ "iat": int(now.timestamp()),
+ "exp": int(expires_at.timestamp()),
+ "username": user.username,
+ "role": user.role.value,
+ "token_name": token_name,
+ "is_long_lived": is_long_lived,
+ }
+
+ return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
+
+ def decode_token(self, token: str, verify_exp: bool = True) -> dict[str, Any]:
+ """Decode and verify a JWT token.
+
+ :param token: JWT token string to decode.
+ :param verify_exp: Whether to verify token expiration.
+ :return: Decoded token payload.
+ :raises jwt.InvalidTokenError: If token is invalid or expired.
+ """
+ options = {"verify_exp": verify_exp}
+ payload: dict[str, Any] = jwt.decode(
+ token,
+ self.secret_key,
+ algorithms=[self.algorithm],
+ options=options,
+ )
+ return payload
+
+ @staticmethod
+ def generate_secret_key() -> str:
+ """Generate a secure random secret key for JWT signing.
+
+ :return: Base64-encoded 256-bit random key.
+ """
+ return secrets.token_urlsafe(32) # 32 bytes = 256 bits
+
+ def get_token_id(self, token: str) -> str | None:
+ """Extract token ID (jti) from JWT without full validation.
+
+ :param token: JWT token string.
+ :return: Token ID or None if invalid.
+ """
+ try:
+ payload: dict[str, Any] = jwt.decode(
+ token,
+ options={"verify_signature": False, "verify_exp": False},
+ )
+ jti = payload.get("jti")
+ return str(jti) if jti else None
+ except Exception:
+ return None