Add JWT-based authentication with backward compatibility (#2891)
authorZtripez <ztripez@vonmatern.org>
Wed, 28 Jan 2026 07:58:34 +0000 (08:58 +0100)
committerGitHub <noreply@github.com>
Wed, 28 Jan 2026 07:58:34 +0000 (08:58 +0100)
* Add JWT-based authentication with backward compatibility

Migrates from hash-based tokens to JWT (JSON Web Tokens) while maintaining
full backward compatibility with existing tokens. This enables stateless
authentication with embedded user claims for better integration with
external systems and OAuth2/OIDC compliance.

JWT Claims Structure:
- Standard claims: sub (user_id), jti (token_id), iat, exp
- Custom claims: username, role, player_filter, provider_filter,
  token_name, is_long_lived

Token Types:
- Short-lived (30 days, auto-renewing on use, sliding window)
- Long-lived (10 years, no auto-renewal for API integrations)

Implementation:
- JWTHelper class for encoding/decoding with HS256 algorithm
- JWT secret key generated and stored in auth.db settings table
- Token verification tries JWT first, falls back to legacy hash lookup
- Database still stores tokens for revocation checking

Benefits:
- Stateless authentication (user info embedded in token)
- Permission scopes available without database lookup
- OAuth2/OIDC compatibility for external integrations
- Standard JWT format for third-party verification

Migration Strategy:
- Automatic: Old tokens work until expiration
- New logins get JWT tokens automatically
- No breaking changes for existing clients

* Fix JWT token expiration check to honor database expiration

Database expiration is the source of truth for token validity, not just
the JWT expiration claim. This ensures manual token expiration (via
database update) works correctly even when JWT exp is still valid.

* Add documentation for future OIDC support

Notes on consuming external OIDC vs acting as OIDC provider, and
refresh token requirements for the latter.

* Clean up JWT implementation: remove dead code and verbose comments

- Remove unused methods: refresh_short_lived_token(), get_user_from_token()
- Remove unused imports: timedelta, UserRole
- Move User import to TYPE_CHECKING block
- Remove TODO comment about refresh tokens (not implementing)
- Simplify inline comments and reduce verbosity
- All tests still passing (36/36)

* Remove player_filter and provider_filter from JWT claims

These values can be dynamically updated, so storing them in the token
would result in stale data. The current values are available from the
database lookup during token validation.

---------

Co-authored-by: Ztripez von Matérn <ztripez@bobby.se>
music_assistant/controllers/webserver/auth.py
music_assistant/helpers/jwt_auth.py [new file with mode: 0644]
pyproject.toml
requirements_all.txt

index 754d37e32124f333aada04bb5063b9e032f99f5f..622663dd986cfb8193bf680417092a3b0694dbf5 100644 (file)
@@ -10,6 +10,7 @@ from datetime import datetime, timedelta
 from sqlite3 import OperationalError
 from typing import TYPE_CHECKING, Any
 
+import jwt as pyjwt
 from music_assistant_models.auth import (
     AuthProviderType,
     AuthToken,
@@ -46,6 +47,7 @@ from music_assistant.helpers.api import api_command
 from music_assistant.helpers.database import DatabaseConnection
 from music_assistant.helpers.datetime import utc
 from music_assistant.helpers.json import json_dumps, json_loads
+from music_assistant.helpers.jwt_auth import JWTHelper
 
 if TYPE_CHECKING:
     from music_assistant.controllers.webserver import WebserverController
@@ -75,6 +77,7 @@ class AuthenticationManager:
         self.login_providers: dict[str, LoginProvider] = {}
         self.logger = LOGGER
         self._has_users: bool = False
+        self.jwt_helper: JWTHelper = None  # type: ignore[assignment]
 
     async def setup(self) -> None:
         """Initialize the authentication manager."""
@@ -90,6 +93,10 @@ class AuthenticationManager:
         # Create database schema and handle migrations
         await self._setup_database()
 
+        # Initialize JWT helper with secret key
+        jwt_secret = await self._get_or_create_jwt_secret()
+        self.jwt_helper = JWTHelper(jwt_secret)
+
         # Setup login providers based on config
         await self._setup_login_providers(allow_self_registration)
 
@@ -257,6 +264,28 @@ class AuthenticationManager:
             await self.database.execute("UPDATE users SET username = LOWER(username)")
             await self.database.commit()
 
+    async def _get_or_create_jwt_secret(self) -> str:
+        """Get or create JWT secret key from database.
+
+        :return: JWT secret key for signing tokens.
+        """
+        # Try to get existing secret
+        if secret_row := await self.database.get_row("settings", {"key": "jwt_secret"}):
+            return str(secret_row["value"])
+
+        # Generate new secret
+        jwt_secret = JWTHelper.generate_secret_key()
+
+        # Store in database
+        await self.database.insert_or_replace(
+            "settings",
+            {"key": "jwt_secret", "value": jwt_secret, "type": "string"},
+        )
+        await self.database.commit()
+
+        self.logger.info("Generated new JWT secret key")
+        return jwt_secret
+
     async def _setup_login_providers(self, allow_self_registration: bool) -> None:
         """
         Set up available login providers based on configuration.
@@ -350,12 +379,60 @@ class AuthenticationManager:
         """
         Authenticate a user with an access token.
 
-        :param token: The access token.
+        Supports both JWT tokens and legacy hash-based tokens for backward compatibility.
+
+        :param token: The access token (JWT or legacy hash token).
         """
-        # Hash the token to look it up
-        token_hash = hashlib.sha256(token.encode()).hexdigest()
+        # Try to decode as JWT first
+        try:
+            payload = self.jwt_helper.decode_token(token, verify_exp=True)
+            token_id = payload.get("jti")
+            user_id = payload.get("sub")
+            is_long_lived = payload.get("is_long_lived", False)
+
+            if not token_id or not user_id:
+                return None
+
+            token_row = await self.database.get_row("auth_tokens", {"token_id": token_id})
+            if not token_row:
+                return None
 
-        # Find token in database
+            # Database expiration is source of truth
+            if token_row["expires_at"]:
+                db_expires_at = datetime.fromisoformat(token_row["expires_at"])
+                if utc() > db_expires_at:
+                    await self.database.delete("auth_tokens", {"token_id": token_id})
+                    return None
+
+            # Update last used timestamp
+            now = utc()
+            updates = {"last_used_at": now.isoformat()}
+
+            if not is_long_lived:
+                # Short-lived token: extend expiration on each use (sliding window)
+                new_expires_at = now + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION)
+                updates["expires_at"] = new_expires_at.isoformat()
+
+            # Update database
+            await self.database.update(
+                "auth_tokens",
+                {"token_id": token_id},
+                updates,
+            )
+
+            return await self.get_user(user_id)
+
+        except pyjwt.ExpiredSignatureError:
+            if token_id := self.jwt_helper.get_token_id(token):
+                await self.database.delete("auth_tokens", {"token_id": token_id})
+            return None
+        except pyjwt.InvalidTokenError:
+            self.logger.debug("Token is not a valid JWT, trying legacy hash lookup")
+        except Exception as err:
+            self.logger.debug("Error decoding JWT token: %s, trying legacy hash lookup", err)
+
+        # Fallback to legacy hash-based token lookup
+        token_hash = hashlib.sha256(token.encode()).hexdigest()
         token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash})
         if not token_row:
             return None
@@ -392,13 +469,15 @@ class AuthenticationManager:
         """
         Get token_id from a token string (for tracking revocation).
 
-        :param token: The access token.
+        :param token: The access token (JWT or legacy hash token).
         :return: The token_id or None if token not found.
         """
-        # Hash the token to look it up
-        token_hash = hashlib.sha256(token.encode()).hexdigest()
+        # Try to extract from JWT first
+        if token_id := self.jwt_helper.get_token_id(token):
+            return token_id
 
-        # Find token in database
+        # Fallback: Hash-based lookup for legacy tokens
+        token_hash = hashlib.sha256(token.encode()).hexdigest()
         token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash})
         if not token_row:
             return None
@@ -783,9 +862,8 @@ class AuthenticationManager:
             Short-lived tokens (False): Auto-renewing on use, expire after 30 days of inactivity.
             Long-lived tokens (True): No auto-renewal, expire after 10 years.
         """
-        # Generate token
-        token = secrets.token_urlsafe(48)
-        token_hash = hashlib.sha256(token.encode()).hexdigest()
+        # Generate unique token ID
+        token_id = secrets.token_urlsafe(32)
 
         # Calculate expiration based on token type
         created_at = utc()
@@ -796,9 +874,19 @@ class AuthenticationManager:
             # Short-lived tokens expire after 30 days (with auto-renewal on use)
             expires_at = created_at + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION)
 
-        # Store token
+        # Generate JWT token
+        token = self.jwt_helper.encode_token(
+            user=user,
+            token_id=token_id,
+            token_name=name,
+            expires_at=expires_at,
+            is_long_lived=is_long_lived,
+        )
+
+        # Store token hash in database for revocation checking
+        token_hash = hashlib.sha256(token.encode()).hexdigest()
         token_data = {
-            "token_id": secrets.token_urlsafe(32),
+            "token_id": token_id,
             "user_id": user.user_id,
             "token_hash": token_hash,
             "name": name,
diff --git a/music_assistant/helpers/jwt_auth.py b/music_assistant/helpers/jwt_auth.py
new file mode 100644 (file)
index 0000000..8f00db8
--- /dev/null
@@ -0,0 +1,109 @@
+"""JWT token helper for Music Assistant authentication.
+
+Future OIDC Support:
+- Consuming external OIDC providers (Google, Keycloak, etc.): Can be added without
+  changes to token structure. MA would validate external OIDC tokens and issue its
+  own JWT tokens (similar to current Home Assistant OAuth flow).
+
+- Acting as OIDC provider for third parties: Would require implementing OAuth2
+  refresh token flow with a dedicated /auth/token endpoint for token refresh.
+  Short-lived access tokens (15 min) + long-lived refresh tokens would be needed
+  for proper OIDC compliance.
+"""
+
+from __future__ import annotations
+
+import secrets
+from datetime import datetime
+from typing import TYPE_CHECKING, Any
+
+import jwt
+
+from music_assistant.helpers.datetime import utc
+
+if TYPE_CHECKING:
+    from music_assistant_models.auth import User
+
+
+class JWTHelper:
+    """Helper class for JWT token operations."""
+
+    def __init__(self, secret_key: str) -> None:
+        """Initialize JWT helper.
+
+        :param secret_key: Secret key for signing JWTs.
+        """
+        self.secret_key = secret_key
+        self.algorithm = "HS256"
+
+    def encode_token(
+        self,
+        user: User,
+        token_id: str,
+        token_name: str,
+        expires_at: datetime,
+        is_long_lived: bool = False,
+    ) -> str:
+        """Encode a JWT token for a user.
+
+        :param user: User object to create token for.
+        :param token_id: Unique token identifier.
+        :param token_name: Human-readable token name.
+        :param expires_at: Token expiration datetime.
+        :param is_long_lived: Whether this is a long-lived token.
+        :return: Encoded JWT token string.
+        """
+        now = utc()
+        payload = {
+            "sub": user.user_id,
+            "jti": token_id,
+            "iat": int(now.timestamp()),
+            "exp": int(expires_at.timestamp()),
+            "username": user.username,
+            "role": user.role.value,
+            "token_name": token_name,
+            "is_long_lived": is_long_lived,
+        }
+
+        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
+
+    def decode_token(self, token: str, verify_exp: bool = True) -> dict[str, Any]:
+        """Decode and verify a JWT token.
+
+        :param token: JWT token string to decode.
+        :param verify_exp: Whether to verify token expiration.
+        :return: Decoded token payload.
+        :raises jwt.InvalidTokenError: If token is invalid or expired.
+        """
+        options = {"verify_exp": verify_exp}
+        payload: dict[str, Any] = jwt.decode(
+            token,
+            self.secret_key,
+            algorithms=[self.algorithm],
+            options=options,
+        )
+        return payload
+
+    @staticmethod
+    def generate_secret_key() -> str:
+        """Generate a secure random secret key for JWT signing.
+
+        :return: Base64-encoded 256-bit random key.
+        """
+        return secrets.token_urlsafe(32)  # 32 bytes = 256 bits
+
+    def get_token_id(self, token: str) -> str | None:
+        """Extract token ID (jti) from JWT without full validation.
+
+        :param token: JWT token string.
+        :return: Token ID or None if invalid.
+        """
+        try:
+            payload: dict[str, Any] = jwt.decode(
+                token,
+                options={"verify_signature": False, "verify_exp": False},
+            )
+            jti = payload.get("jti")
+            return str(jti) if jti else None
+        except Exception:
+            return None
index 1cb5f5f1484298bfe49117425e068ee74de79095..fbda22a133ef367ca8727a8996a805b9f8b96e31 100644 (file)
@@ -46,6 +46,7 @@ dependencies = [
   "gql[all]==4.0.0",
   "aiovban>=0.6.3",
   "aiortc>=1.6.0",
+  "pyjwt[crypto]>=2.10.1",
 ]
 description = "Music Assistant"
 license = {text = "Apache-2.0"}
index b7795c5d3c1fa279abd678f2a12d6bc3f5913094..4f14d08e9f8527618a638478879e12090e692c81 100644 (file)
@@ -57,6 +57,7 @@ pycares==4.11.0
 PyChromecast==14.0.9
 pycryptodome==3.23.0
 pyheos==1.0.6
+pyjwt[crypto]>=2.10.1
 pylast==6.0.0
 python-fullykiosk==0.0.14
 python-slugify==8.0.4