self.database: DatabaseConnection = None # type: ignore[assignment]
self.login_providers: dict[str, LoginProvider] = {}
self.logger = LOGGER
- # Pending OAuth sessions for remote clients (session_id -> token)
- self._pending_oauth_sessions: dict[str, str | None] = {}
self._has_users: bool = False
async def setup(self) -> None:
@api_command("auth/authorization_url", authenticated=False)
async def get_auth_url(
- self, provider_id: str, for_remote_client: bool = False
+ self,
+ provider_id: str,
+ return_url: str | None = None,
) -> dict[str, str | None]:
- """Get OAuth authorization URL for remote authentication.
+ """Get OAuth authorization URL for authentication.
For OAuth providers (like Home Assistant), this returns the URL that
the user should visit in their browser to authorize the application.
:param provider_id: The provider ID (e.g., "hass").
- :param for_remote_client: If True, creates a pending session for remote OAuth flow.
- :return: Dictionary with authorization_url and session_id (if remote).
+ :param return_url: URL to redirect to after OAuth completes.
+ :return: Dictionary with authorization_url.
"""
- # Generate session ID for remote clients
- session_id = None
- return_url = None
-
- if for_remote_client:
- session_id = secrets.token_urlsafe(32)
- # Mark session as pending
- self._pending_oauth_sessions[session_id] = None
- # Use special return URL that will capture the token
- return_url = f"urn:ietf:wg:oauth:2.0:oob:auto:{session_id}"
-
auth_url = await self.get_authorization_url(provider_id, return_url)
if not auth_url:
return {
return {
"authorization_url": auth_url,
- "session_id": session_id, # Only set for remote clients
- }
-
- @api_command("auth/oauth_status", authenticated=False)
- async def check_oauth_status(self, session_id: str) -> dict[str, Any]:
- """Check status of pending OAuth authentication.
-
- Remote clients use this to poll for completion of the OAuth flow.
-
- :param session_id: The session ID from get_auth_url.
- :return: Status and token if authentication completed.
- """
- if session_id not in self._pending_oauth_sessions:
- return {
- "status": "invalid",
- "error": "Invalid or expired session ID",
- }
-
- token = self._pending_oauth_sessions.get(session_id)
- if token is None:
- return {
- "status": "pending",
- "message": "Waiting for user to complete authentication",
- }
-
- # Authentication completed, return token and clean up
- del self._pending_oauth_sessions[session_id]
- return {
- "status": "completed",
- "access_token": token,
}
async def get_authorization_url(
self.logger.exception("Error during OAuth authorization")
return web.json_response({"error": "Authorization failed"}, status=500)
- async def _handle_auth_callback(self, request: web.Request) -> web.Response: # noqa: PLR0915
+ async def _handle_auth_callback(self, request: web.Request) -> web.Response:
"""Handle OAuth callback."""
try:
code = request.query.get("code")
device_name = f"OAuth ({provider_id})"
token = await self.auth.create_token(auth_result.user, device_name)
- if auth_result.return_url and auth_result.return_url.startswith(
- "urn:ietf:wg:oauth:2.0:oob:auto:"
- ):
- session_id = auth_result.return_url.split(":")[-1]
- if session_id in self.auth._pending_oauth_sessions:
- self.auth._pending_oauth_sessions[session_id] = token
- oauth_callback_html_path = str(RESOURCES_DIR.joinpath("oauth_callback.html"))
- async with aiofiles.open(oauth_callback_html_path) as f:
- success_html = await f.read()
-
- success_html = success_html.replace("{TOKEN}", token)
- success_html = success_html.replace("{REDIRECT_URL}", "about:blank")
- success_html = success_html.replace("{REQUIRES_CONSENT}", "false")
-
- return web.Response(text=success_html, content_type="text/html")
-
# Determine redirect URL (use return_url from OAuth flow or default to root)
final_redirect_url = auth_result.return_url or "/"
requires_consent = False