Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/proxy/management_endpoints/ui_sso.py
2026-03-26 20:06:14 +08:00

3646 lines
146 KiB
Python

"""
Has all /sso/* routes
/sso/key/generate - handles user signing in with SSO and redirects to /sso/callback
/sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI
/sso/debug/login - handles user signing in with SSO and redirects to /sso/debug/callback
/sso/debug/callback - returns the OpenID object returned by the SSO provider
"""
import asyncio
import base64
import hashlib
import inspect
import os
import secrets
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
if TYPE_CHECKING:
import httpx
import jwt
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
import litellm
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.caching import DualCache
from litellm.constants import (
LITELLM_UI_SESSION_DURATION,
MAX_SPENDLOG_ROWS_TO_QUERY,
MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE,
MICROSOFT_USER_EMAIL_ATTRIBUTE,
MICROSOFT_USER_FIRST_NAME_ATTRIBUTE,
MICROSOFT_USER_ID_ATTRIBUTE,
MICROSOFT_USER_LAST_NAME_ATTRIBUTE,
)
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import (
CommonProxyErrors,
LiteLLM_UserTable,
LitellmUserRoles,
Member,
NewTeamRequest,
NewUserRequest,
NewUserResponse,
ProxyErrorTypes,
ProxyException,
SSOUserDefinedValues,
TeamMemberAddRequest,
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken, get_user_object
from litellm.proxy.auth.auth_utils import _has_user_setup_sso
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.admin_ui_utils import (
admin_ui_disabled,
show_missing_vars_in_env,
)
from litellm.proxy.common_utils.html_forms.jwt_display_template import (
jwt_display_template,
)
from litellm.proxy.common_utils.html_forms.ui_login import html_form
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.proxy.management_endpoints.sso import CustomMicrosoftSSO
from litellm.proxy.management_endpoints.sso_helper_utils import (
check_is_admin_only_access,
has_admin_ui_access,
)
from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add
from litellm.proxy.management_endpoints.types import (
CustomOpenID,
get_litellm_user_role,
is_valid_litellm_user_role,
)
from litellm.proxy.utils import (
PrismaClient,
ProxyLogging,
get_custom_url,
get_server_root_path,
)
from litellm.secret_managers.main import get_secret_bool, str_to_bool
from litellm.types.proxy.management_endpoints.ui_sso import * # noqa: F403, F401
from litellm.types.proxy.management_endpoints.ui_sso import (
DefaultTeamSSOParams,
MicrosoftGraphAPIUserGroupDirectoryObject,
MicrosoftGraphAPIUserGroupResponse,
MicrosoftServicePrincipalTeam,
RoleMappings,
TeamMappings,
)
from litellm.types.proxy.ui_sso import ParsedOpenIDResult
if TYPE_CHECKING:
from fastapi_sso.sso.base import OpenID
else:
from typing import Any as OpenID
# Single APIRouter instance that every /sso/* endpoint in this module attaches to.
router = APIRouter()
# OAuth bearer credential fields that must not appear in SSO debug responses
# (received_response is included in restricted-group error messages).
# Metadata fields (token_type, expires_in, scope) are intentionally kept so
# response convertors see the same fields in the PKCE path as in the non-PKCE path.
_OAUTH_TOKEN_FIELDS = frozenset({"access_token", "id_token", "refresh_token"})
def normalize_email(email: Optional[str]) -> Optional[str]:
    """
    Lowercase an email address for consistent storage and comparison.

    SSO providers may return emails with different casing than what is
    stored in the database, so emails are treated as case-insensitive here
    even though RFC 5321 technically allows case-sensitive local parts.

    Args:
        email: Email address to normalize; may be None.

    Returns:
        The lowercased email, or the input unchanged when it is None
        (or not a string).
    """
    if not isinstance(email, str):
        # Covers None and any non-string value: returned untouched.
        return email
    return email.lower()
def determine_role_from_groups(
    user_groups: List[str],
    role_mappings: "RoleMappings",
) -> Optional[LitellmUserRoles]:
    """
    Pick the highest-privilege role whose mapped groups intersect the user's groups.

    Role hierarchy (highest to lowest):
    - proxy_admin
    - proxy_admin_viewer
    - internal_user
    - internal_user_viewer

    Args:
        user_groups: List of group names from the SSO token
        role_mappings: RoleMappings configuration object

    Returns:
        The highest-privilege matching role, or role_mappings.default_role
        when nothing matches or no mappings are configured.
    """
    if not role_mappings.roles:
        # No role mappings configured, return default_role
        return role_mappings.default_role
    # Set membership makes the group-intersection checks O(1) per group.
    groups = set(user_groups) if isinstance(user_groups, list) else set()
    # Walk the hierarchy from most to least privileged; first hit wins.
    for candidate in (
        LitellmUserRoles.PROXY_ADMIN,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
        LitellmUserRoles.INTERNAL_USER,
        LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
    ):
        mapped_groups = role_mappings.roles.get(candidate)
        if isinstance(mapped_groups, list) and not groups.isdisjoint(mapped_groups):
            verbose_proxy_logger.debug(
                f"User groups {user_groups} matched role '{candidate.value}' via groups: {mapped_groups}"
            )
            return candidate
    # No matching groups found, return default_role
    verbose_proxy_logger.debug(
        f"User groups {user_groups} did not match any role mappings, using default_role: {role_mappings.default_role}"
    )
    return role_mappings.default_role
def process_sso_jwt_access_token(
    access_token_str: Optional[str],
    sso_jwt_handler: Optional[JWTHandler],
    result: Union[OpenID, dict, None],
    role_mappings: Optional["RoleMappings"] = None,
) -> None:
    """
    Process SSO JWT access token and extract team IDs and user role if available.

    This function decodes the JWT access token and extracts team IDs and user
    role, then sets them on the result object (mutated in place — nothing is
    returned). Role extraction from the access token is needed because some
    SSO providers (e.g., Keycloak) do not include role claims in the UserInfo
    endpoint response.

    Args:
        access_token_str: The JWT access token string
        sso_jwt_handler: SSO-specific JWT handler for team ID extraction
        result: The SSO result object to update with team IDs and role
        role_mappings: Optional role mappings configuration for group-based role determination
    """
    if access_token_str and result:
        import jwt
        try:
            # Signature is deliberately not verified here: only claims are
            # read, and the token was obtained directly from the provider.
            access_token_payload = jwt.decode(
                access_token_str, options={"verify_signature": False}
            )
        except jwt.exceptions.DecodeError:
            # Some providers issue opaque (non-JWT) access tokens; that is
            # not an error — just skip JWT-based extraction entirely.
            verbose_proxy_logger.debug(
                "Access token is not a valid JWT (possibly an opaque token), skipping JWT-based extraction"
            )
            return
        # Extract team IDs from access token if sso_jwt_handler is available.
        # team_ids already present on the result (e.g. from UserInfo) take
        # precedence; the JWT is only consulted when they are empty.
        if sso_jwt_handler:
            if isinstance(result, dict):
                result_team_ids: Optional[List[str]] = result.get("team_ids", [])
                if not result_team_ids:
                    team_ids = sso_jwt_handler.get_team_ids_from_jwt(
                        access_token_payload
                    )
                    result["team_ids"] = team_ids
            else:
                result_team_ids = getattr(result, "team_ids", []) if result else []
                if not result_team_ids:
                    team_ids = sso_jwt_handler.get_team_ids_from_jwt(
                        access_token_payload
                    )
                    setattr(result, "team_ids", team_ids)
        # Extract user role from access token if not already set from UserInfo
        existing_role = (
            result.get("user_role")
            if isinstance(result, dict)
            else getattr(result, "user_role", None)
        )
        if existing_role is None:
            user_role: Optional[LitellmUserRoles] = None
            # Try role_mappings first (group-based role determination)
            if role_mappings is not None and role_mappings.roles:
                group_claim = role_mappings.group_claim
                user_groups_raw: Any = get_nested_value(
                    access_token_payload, group_claim
                )
                # Group claims may arrive as a list, a comma-separated
                # string, or a single scalar — normalize to List[str].
                user_groups: List[str] = []
                if isinstance(user_groups_raw, list):
                    user_groups = [str(g) for g in user_groups_raw]
                elif isinstance(user_groups_raw, str):
                    user_groups = [
                        g.strip() for g in user_groups_raw.split(",") if g.strip()
                    ]
                elif user_groups_raw is not None:
                    user_groups = [str(user_groups_raw)]
                if user_groups:
                    user_role = determine_role_from_groups(user_groups, role_mappings)
                    verbose_proxy_logger.debug(
                        f"Determined role '{user_role}' from access token groups '{user_groups}' using role_mappings"
                    )
                elif role_mappings.default_role:
                    user_role = role_mappings.default_role
            # Fallback: try GENERIC_USER_ROLE_ATTRIBUTE on the access token payload
            if user_role is None:
                generic_user_role_attribute_name = os.getenv(
                    "GENERIC_USER_ROLE_ATTRIBUTE", "role"
                )
                user_role_from_token = get_nested_value(
                    access_token_payload, generic_user_role_attribute_name
                )
                if user_role_from_token is not None:
                    user_role = get_litellm_user_role(user_role_from_token)
                    verbose_proxy_logger.debug(
                        f"Extracted role '{user_role}' from access token field '{generic_user_role_attribute_name}'"
                    )
            if user_role is not None:
                if isinstance(result, dict):
                    result["user_role"] = user_role
                else:
                    setattr(result, "user_role", user_role)
                verbose_proxy_logger.debug(
                    f"Set user_role='{user_role}' from JWT access token"
                )
@router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False)
async def google_login(
    request: Request,
    source: Optional[str] = None,
    key: Optional[str] = None,
    existing_key: Optional[str] = None,
):  # noqa: PLR0915
    """
    Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env

    PROXY_BASE_URL should be your deployed proxy endpoint, e.g.
    PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"

    Flow: checks whether the admin UI is disabled, enforces the free-tier SSO
    user limit for non-premium deployments, then either delegates to a custom
    enterprise SSO sign-in handler, redirects to the configured SSO provider,
    or falls back to the username/password login form.

    Args:
        request: Incoming FastAPI request, used to derive the redirect URL.
        source: Optional login-source marker (e.g. CLI), carried via OAuth state.
        key: Optional CLI key, carried via OAuth state.
        existing_key: Optional existing key, carried via OAuth state.
    """
    from litellm.proxy.proxy_server import (
        premium_user,
        prisma_client,
        user_custom_ui_sso_sign_in_handler,
    )
    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
    ####### Check if UI is disabled #######
    _disable_ui_flag = os.getenv("DISABLE_ADMIN_UI")
    if _disable_ui_flag is not None:
        is_disabled = str_to_bool(value=_disable_ui_flag)
        if is_disabled:
            return admin_ui_disabled()
    ####### Check if user is a Enterprise / Premium User #######
    if (
        microsoft_client_id is not None
        or google_client_id is not None
        or generic_client_id is not None
    ):
        if premium_user is not True:
            # Check if under 'free SSO user' limit
            if prisma_client is not None:
                total_users = await prisma_client.db.litellm_usertable.count()
                if total_users and total_users > 5:
                    raise ProxyException(
                        message="You must be a LiteLLM Enterprise user to use SSO for more than 5 users. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/cx9p-5yf-2nm/litellm-introductions You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
                        type=ProxyErrorTypes.auth_error,
                        param="premium_user",
                        code=status.HTTP_403_FORBIDDEN,
                    )
            else:
                # SSO requested but no DB connected — cannot enforce the limit.
                raise ProxyException(
                    message=CommonProxyErrors.db_not_connected_error.value,
                    type=ProxyErrorTypes.auth_error,
                    param="premium_user",
                    code=status.HTTP_403_FORBIDDEN,
                )
    ####### Detect DB + MASTER KEY in .env #######
    missing_env_vars = show_missing_vars_in_env()
    if missing_env_vars is not None:
        return missing_env_vars
    ui_username = os.getenv("UI_USERNAME")
    # get url from request - always use regular callback, but set state for CLI
    redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
        request=request,
        sso_callback_route="sso/callback",
        existing_key=existing_key,
    )
    # Store CLI key in state for OAuth flow
    cli_state: Optional[str] = SSOAuthenticationHandler._get_cli_state(
        source=source,
        key=key,
        existing_key=existing_key,
    )
    # check if user defined a custom auth sso sign in handler, if yes, use it
    if user_custom_ui_sso_sign_in_handler is not None:
        try:
            from litellm_enterprise.proxy.auth.custom_sso_handler import (  # type: ignore[import-untyped]
                EnterpriseCustomSSOHandler,
            )
            return await EnterpriseCustomSSOHandler.handle_custom_ui_sso_sign_in(
                request=request,
            )
        except ImportError:
            raise ValueError(
                "Enterprise features are not available. Custom UI SSO sign-in requires LiteLLM Enterprise."
            )
    # Check if we should use SSO handler
    if (
        SSOAuthenticationHandler.should_use_sso_handler(
            microsoft_client_id=microsoft_client_id,
            google_client_id=google_client_id,
            generic_client_id=generic_client_id,
        )
        is True
    ):
        verbose_proxy_logger.info(f"Redirecting to SSO login for {redirect_url}")
        return await SSOAuthenticationHandler.get_sso_login_redirect(
            redirect_url=redirect_url,
            microsoft_client_id=microsoft_client_id,
            google_client_id=google_client_id,
            generic_client_id=generic_client_id,
            state=cli_state,
        )
    elif ui_username is not None:
        # No Google, Microsoft SSO
        # Use UI Credentials set in .env
        from fastapi.responses import HTMLResponse
        return HTMLResponse(content=html_form, status_code=200)
    else:
        # NOTE(review): this branch is identical to the one above — both fall
        # back to the username/password form; the split appears intentional
        # for readability only. Confirm before collapsing.
        from fastapi.responses import HTMLResponse
        return HTMLResponse(content=html_form, status_code=200)
def generic_response_convertor(
    response,
    jwt_handler: JWTHandler,
    sso_jwt_handler: Optional[JWTHandler] = None,
    role_mappings: Optional["RoleMappings"] = None,
    team_mappings: Optional["TeamMappings"] = None,
) -> CustomOpenID:
    """
    Convert a generic OIDC provider response into a CustomOpenID object.

    Field locations are configurable via GENERIC_USER_*_ATTRIBUTE env vars and
    may be dot-notation paths (resolved with get_nested_value). Team IDs are
    extracted via the SSO-specific JWT handler when provided (optionally
    merged with teams from DB-configured team_mappings), otherwise via the
    default jwt_handler. The user role comes from role_mappings (generic/okta
    providers only), falling back to the GENERIC_USER_ROLE_ATTRIBUTE claim.

    Args:
        response: Raw provider response (dict-like) to convert.
        jwt_handler: Default JWT handler for team-ID extraction.
        sso_jwt_handler: SSO-specific handler; takes precedence when set.
        role_mappings: Optional group-to-role mapping configuration.
        team_mappings: Optional DB-configured team mapping (JWT field path).

    Returns:
        CustomOpenID populated from the response (email lowercased).
    """
    generic_user_id_attribute_name = os.getenv(
        "GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
    )
    generic_user_display_name_attribute_name = os.getenv(
        "GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
    )
    generic_user_email_attribute_name = os.getenv(
        "GENERIC_USER_EMAIL_ATTRIBUTE", "email"
    )
    generic_user_first_name_attribute_name = os.getenv(
        "GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
    )
    generic_user_last_name_attribute_name = os.getenv(
        "GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
    )
    generic_provider_attribute_name = os.getenv(
        "GENERIC_USER_PROVIDER_ATTRIBUTE", "provider"
    )
    generic_user_role_attribute_name = os.getenv("GENERIC_USER_ROLE_ATTRIBUTE", "role")
    generic_user_extra_attributes = os.getenv("GENERIC_USER_EXTRA_ATTRIBUTES", None)
    verbose_proxy_logger.debug(
        f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}"
    )
    all_teams = []
    if sso_jwt_handler is not None:
        # SSO-specific handler takes precedence for team extraction.
        team_ids = sso_jwt_handler.get_team_ids_from_jwt(cast(dict, response))
        all_teams.extend(team_ids)
        if team_mappings is not None and team_mappings.team_ids_jwt_field is not None:
            team_ids_from_db_mapping: Optional[List[str]] = get_nested_value(
                data=cast(dict, response),
                key_path=team_mappings.team_ids_jwt_field,
                default=[],
            )
            if team_ids_from_db_mapping:
                all_teams.extend(team_ids_from_db_mapping)
                verbose_proxy_logger.debug(
                    f"Loaded team_ids from DB team_mappings.team_ids_jwt_field='{team_mappings.team_ids_jwt_field}': {team_ids_from_db_mapping}"
                )
    else:
        team_ids = jwt_handler.get_team_ids_from_jwt(cast(dict, response))
        all_teams.extend(team_ids)
    # Determine user role based on role_mappings if available
    # Only apply role_mappings for GENERIC SSO provider
    user_role: Optional[LitellmUserRoles] = None
    if role_mappings is not None and role_mappings.provider.lower() in [
        "generic",
        "okta",
    ]:
        # Use role_mappings to determine role from groups
        group_claim = role_mappings.group_claim
        user_groups_raw: Any = get_nested_value(response, group_claim)
        # Handle different formats: could be a list, string (comma-separated), or single value
        user_groups: List[str] = []
        if isinstance(user_groups_raw, list):
            user_groups = [str(g) for g in user_groups_raw]
        elif isinstance(user_groups_raw, str):
            # Handle comma-separated string
            user_groups = [g.strip() for g in user_groups_raw.split(",") if g.strip()]
        elif user_groups_raw is not None:
            # Single value
            user_groups = [str(user_groups_raw)]
        if user_groups:
            user_role = determine_role_from_groups(user_groups, role_mappings)
            verbose_proxy_logger.debug(
                f"Determined role '{user_role.value if user_role else None}' from groups '{user_groups}' using role_mappings"
            )
        else:
            # No groups found, use default_role
            user_role = role_mappings.default_role
            verbose_proxy_logger.debug(
                f"No groups found in '{group_claim}', using default_role: {role_mappings.default_role}"
            )
    # Fallback to existing logic if role_mappings not used
    if user_role is None:
        user_role_from_sso = get_nested_value(
            response, generic_user_role_attribute_name
        )
        if user_role_from_sso is not None:
            role = get_litellm_user_role(user_role_from_sso)
            if role is not None:
                user_role = role
                verbose_proxy_logger.debug(
                    f"Found valid LitellmUserRoles '{role.value}' from SSO attribute '{generic_user_role_attribute_name}'"
                )
    # Build extra_fields dict from GENERIC_USER_EXTRA_ATTRIBUTES if specified
    extra_fields: Optional[Dict[str, Any]] = None
    if generic_user_extra_attributes:
        extra_fields = {}
        for attr_name in generic_user_extra_attributes.split(","):
            attr_name = attr_name.strip()
            extra_fields[attr_name] = get_nested_value(response, attr_name)
    return CustomOpenID(
        id=get_nested_value(response, generic_user_id_attribute_name),
        display_name=get_nested_value(
            response, generic_user_display_name_attribute_name
        ),
        email=normalize_email(
            get_nested_value(response, generic_user_email_attribute_name)
        ),
        first_name=get_nested_value(response, generic_user_first_name_attribute_name),
        last_name=get_nested_value(response, generic_user_last_name_attribute_name),
        provider=get_nested_value(response, generic_provider_attribute_name),
        team_ids=all_teams,
        user_role=user_role,
        extra_fields=extra_fields,
    )
def _setup_generic_sso_env_vars(
    generic_client_id: str, redirect_url: str
) -> Tuple[str, List[str], str, str, str, bool]:
    """Setup and validate Generic SSO environment variables.

    Args:
        generic_client_id: Value of GENERIC_CLIENT_ID (used only for debug logging here).
        redirect_url: OAuth redirect URL (used only for debug logging here).

    Returns:
        Tuple of (client_secret, scope list, authorization endpoint,
        token endpoint, userinfo endpoint, include_client_id flag).

    Raises:
        ProxyException: If any required GENERIC_* environment variable is unset.
    """
    generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
    generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
    generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None)
    generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
    generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
    generic_include_client_id = (
        os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
    )

    def _missing(var_name: str) -> ProxyException:
        # Uniform error shape for every required env var (was four
        # copy-pasted raise blocks).
        return ProxyException(
            message=f"{var_name} not set. Set it in .env file",
            type=ProxyErrorTypes.auth_error,
            param=var_name,
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    # Validate required environment variables. Kept as explicit `is None`
    # checks (rather than a loop) so type checkers narrow each variable to
    # str before the return below.
    if generic_client_secret is None:
        raise _missing("GENERIC_CLIENT_SECRET")
    if generic_authorization_endpoint is None:
        raise _missing("GENERIC_AUTHORIZATION_ENDPOINT")
    if generic_token_endpoint is None:
        raise _missing("GENERIC_TOKEN_ENDPOINT")
    if generic_userinfo_endpoint is None:
        raise _missing("GENERIC_USERINFO_ENDPOINT")
    verbose_proxy_logger.debug(
        f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
    )
    verbose_proxy_logger.debug(
        f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
    )
    return (
        generic_client_secret,
        generic_scope,
        generic_authorization_endpoint,
        generic_token_endpoint,
        generic_userinfo_endpoint,
        generic_include_client_id,
    )
async def _setup_team_mappings() -> Optional["TeamMappings"]:
    """Load team mappings from the SSO config stored in the database.

    Best-effort: any failure (no DB connection, missing record, unparseable
    data) is logged at debug level and None is returned, so config-based
    team mapping continues to apply.
    """
    loaded: Optional["TeamMappings"] = None
    try:
        from litellm.proxy.utils import get_prisma_client_or_throw
        prisma_client = get_prisma_client_or_throw(
            "Prisma client is None, connect a database to your proxy"
        )
        sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
            where={"id": "sso_config"}
        )
        if sso_db_record and sso_db_record.sso_settings:
            raw_mappings = dict(sso_db_record.sso_settings).get("team_mappings")
            if raw_mappings:
                from litellm.types.proxy.management_endpoints.ui_sso import TeamMappings
                if isinstance(raw_mappings, dict):
                    loaded = TeamMappings(**raw_mappings)
                elif isinstance(raw_mappings, TeamMappings):
                    loaded = raw_mappings
                if loaded and loaded.team_ids_jwt_field:
                    verbose_proxy_logger.debug(
                        f"Loaded team_mappings with team_ids_jwt_field: '{loaded.team_ids_jwt_field}'"
                    )
    except Exception as e:
        verbose_proxy_logger.debug(
            f"Could not load team_mappings from database: {e}. Continuing with config-based team mapping."
        )
    return loaded
async def _setup_role_mappings() -> Optional["RoleMappings"]:
    """Setup role mappings from SSO database settings or environment variables.

    Resolution order:
    1. The ``role_mappings`` entry of the SSO config stored in the database.
    2. The GENERIC_ROLE_MAPPINGS_* environment variables, which override the
       DB value when set and parseable.

    All failures are non-fatal: they are logged and the previously resolved
    value (or None) is returned so existing role logic still applies.

    Returns:
        A RoleMappings object, or None when nothing is configured.
    """
    role_mappings: Optional["RoleMappings"] = None
    try:
        from litellm.proxy.utils import get_prisma_client_or_throw
        prisma_client = get_prisma_client_or_throw(
            "Prisma client is None, connect a database to your proxy"
        )
        sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
            where={"id": "sso_config"}
        )
        if sso_db_record and sso_db_record.sso_settings:
            sso_settings_dict = dict(sso_db_record.sso_settings)
            role_mappings_data = sso_settings_dict.get("role_mappings")
            if role_mappings_data:
                from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings
                if isinstance(role_mappings_data, dict):
                    role_mappings = RoleMappings(**role_mappings_data)
                elif isinstance(role_mappings_data, RoleMappings):
                    role_mappings = role_mappings_data
                if role_mappings:
                    verbose_proxy_logger.debug(
                        f"Loaded role_mappings for provider '{role_mappings.provider}'"
                    )
    except Exception as e:
        verbose_proxy_logger.debug(
            f"Could not load role_mappings from database: {e}. Continuing with existing role logic."
        )
    generic_role_mappings = os.getenv("GENERIC_ROLE_MAPPINGS_ROLES", None)
    generic_role_mappings_group_claim = os.getenv(
        "GENERIC_ROLE_MAPPINGS_GROUP_CLAIM", None
    )
    generic_role_mappings_default_role = os.getenv(
        "GENERIC_ROLE_MAPPINGS_DEFAULT_ROLE", None
    )
    if generic_role_mappings is not None:
        verbose_proxy_logger.debug(
            "Found role_mappings for generic provider in environment variables"
        )
        import ast
        try:
            generic_user_role_mappings_data: Dict[
                LitellmUserRoles, List[str]
            ] = ast.literal_eval(generic_role_mappings)
            if isinstance(generic_user_role_mappings_data, dict):
                from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings
                role_mappings = RoleMappings(
                    provider="generic",
                    group_claim=generic_role_mappings_group_claim,
                    default_role=generic_role_mappings_default_role,
                    roles=generic_user_role_mappings_data,
                )
                verbose_proxy_logger.debug(
                    f"Loaded role_mappings from environments for provider '{role_mappings.provider}'."
                )
                return role_mappings
        # BUGFIX: ast.literal_eval raises ValueError/SyntaxError on malformed
        # input and TypeError on non-string input. Previously only TypeError
        # was caught, so a malformed GENERIC_ROLE_MAPPINGS_ROLES crashed the
        # whole SSO flow instead of falling back to existing role logic.
        except (ValueError, SyntaxError, TypeError) as e:
            verbose_proxy_logger.warning(
                f"Error decoding role mappings from environment variables: {e}. Continuing with existing role logic."
            )
    return role_mappings
def _parse_generic_sso_headers() -> dict:
"""Parse comma-separated GENERIC_SSO_HEADERS env var into a dict."""
raw = os.getenv("GENERIC_SSO_HEADERS", None)
if raw is None:
return {}
result: Dict[str, str] = {}
for header in raw.split(","):
header = header.strip()
if header:
key, value = header.split("=")
result[key] = value
return result
def _handle_generic_sso_error(
    e: Exception,
    generic_authorization_endpoint: Optional[str],
    generic_token_endpoint: Optional[str],
    additional_headers: dict,
) -> None:
    """Handle errors from generic SSO verify_and_process. Always re-raises.

    When the error looks like a PKCE misconfiguration (and PKCE is not
    enabled via GENERIC_CLIENT_USE_PKCE), the original exception is replaced
    by a ProxyException carrying setup instructions. Otherwise the error is
    logged together with the custom headers that were sent, then re-raised
    unchanged.

    Args:
        e: The exception raised during SSO verification / token exchange.
        generic_authorization_endpoint: Used only to detect Okta for messaging.
        generic_token_endpoint: Used only to detect Okta for messaging.
        additional_headers: Custom headers sent to the provider (logged).

    Raises:
        ProxyException: With PKCE guidance when misconfiguration is detected.
        Exception: The original exception ``e`` in all other cases.
    """
    error_message = str(e)
    # Surface a helpful PKCE misconfiguration hint only when:
    # 1. The error mentions PKCE/code verifier, AND
    # 2. PKCE is not currently configured (GENERIC_CLIENT_USE_PKCE != true)
    pkce_configured = os.getenv("GENERIC_CLIENT_USE_PKCE", "false").lower() == "true"
    if not pkce_configured and (
        "PKCE" in error_message or "code verifier" in error_message.lower()
    ):
        is_okta = (
            generic_authorization_endpoint
            and "okta" in generic_authorization_endpoint.lower()
        ) or (generic_token_endpoint and "okta" in generic_token_endpoint.lower())
        provider_name = "Okta" if is_okta else "Your OAuth provider"
        detailed_message = (
            f"SSO authentication failed: {provider_name} requires PKCE (Proof Key for Code Exchange) "
            f"but it's not enabled in your LiteLLM configuration.\n\n"
            f"SOLUTION: Add this environment variable and restart your proxy:\n"
            f" GENERIC_CLIENT_USE_PKCE=true\n\n"
        )
        if is_okta:
            detailed_message += (
                "For AWS ECS: Add the environment variable to your task definition.\n"
                "For Docker: Add -e GENERIC_CLIENT_USE_PKCE=true to your docker run command.\n"
                "For .env file: Add GENERIC_CLIENT_USE_PKCE=true to your .env file.\n\n"
            )
        detailed_message += f"Original error: {error_message}"
        raise ProxyException(
            message=detailed_message,
            type=ProxyErrorTypes.auth_error,
            param="GENERIC_CLIENT_USE_PKCE",
            code=status.HTTP_401_UNAUTHORIZED,
        )
    if isinstance(e, ProxyException):
        # Already a well-formed proxy error: log at error level without a traceback.
        verbose_proxy_logger.error(
            "SSO authentication failed: %s. Passed in headers: %s",
            e,
            additional_headers,
        )
    else:
        # Unexpected error: log with full traceback for debugging.
        verbose_proxy_logger.exception(
            "Error verifying and processing generic SSO: %s. Passed in headers: %s",
            e,
            additional_headers,
        )
    raise e
async def get_generic_sso_response(
    request: Request,
    jwt_handler: JWTHandler,
    sso_jwt_handler: Optional[
        JWTHandler
    ],  # sso specific jwt handler - used for restricted sso group access control
    generic_client_id: str,
    redirect_url: str,
) -> Tuple[Union[OpenID, dict], Optional[dict]]:  # return received response
    """
    Run the generic (OIDC) SSO callback flow and return the parsed identity.

    Builds a fastapi-sso provider from the GENERIC_* env vars, then either:
    - PKCE path: exchanges the authorization code manually (with the cached
      code_verifier) and converts the combined token+userinfo response, or
    - non-PKCE path: delegates to fastapi-sso's verify_and_process.

    Args:
        request: The SSO callback request (carries code/state query params).
        jwt_handler: Default JWT handler for team-ID extraction.
        sso_jwt_handler: SSO-specific JWT handler used for restricted SSO
            group access control; takes precedence when set.
        generic_client_id: GENERIC_CLIENT_ID value.
        redirect_url: OAuth redirect URL registered with the provider.

    Returns:
        Tuple of (converted SSO result, raw provider response for debugging —
        with bearer credential fields stripped in the PKCE path).

    Raises:
        ProxyException: On missing configuration or failed verification
            (via _handle_generic_sso_error, which always re-raises).
    """
    # make generic sso provider
    from fastapi_sso.sso.base import DiscoveryDocument
    from fastapi_sso.sso.generic import create_provider
    received_response: Optional[dict] = None
    # Setup environment variables
    (
        generic_client_secret,
        generic_scope,
        generic_authorization_endpoint,
        generic_token_endpoint,
        generic_userinfo_endpoint,
        generic_include_client_id,
    ) = _setup_generic_sso_env_vars(generic_client_id, redirect_url)
    discovery = DiscoveryDocument(
        authorization_endpoint=generic_authorization_endpoint,
        token_endpoint=generic_token_endpoint,
        userinfo_endpoint=generic_userinfo_endpoint,
    )
    role_mappings = await _setup_role_mappings()
    team_mappings = await _setup_team_mappings()
    def response_convertor(response, client):
        # Capture the raw provider response before conversion so callers can
        # surface it for debugging.
        nonlocal received_response  # return for user debugging
        received_response = response
        return generic_response_convertor(
            response=response,
            jwt_handler=jwt_handler,
            sso_jwt_handler=sso_jwt_handler,
            role_mappings=role_mappings,
            team_mappings=team_mappings,
        )
    SSOProvider = create_provider(
        name="oidc",
        discovery_document=discovery,
        response_convertor=response_convertor,
    )
    generic_sso = SSOProvider(
        client_id=generic_client_id,
        client_secret=generic_client_secret,
        redirect_uri=redirect_url,
        allow_insecure_http=True,
        scope=generic_scope,
    )
    verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
    additional_generic_sso_headers_dict = _parse_generic_sso_headers()
    code_verifier: Optional[str] = None  # assigned inside try; initialized for type tracking
    try:
        token_exchange_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters(
            request=request,
            generic_include_client_id=generic_include_client_id,
        )
        # Extract code_verifier (and the cache key for deferred deletion) before calling fastapi-sso
        code_verifier = token_exchange_params.pop("code_verifier", None)
        pkce_cache_key = token_exchange_params.pop("_pkce_cache_key", None)
        # Get authorization code from query params (only used in the PKCE path below;
        # the non-PKCE path delegates to verify_and_process which handles OAuth error
        # callbacks — user-denied, CSRF mismatch — internally).
        authorization_code = request.query_params.get("code")
        if code_verifier:
            if not authorization_code:
                raise ProxyException(
                    message="Missing authorization code in callback",
                    type=ProxyErrorTypes.auth_error,
                    param="code",
                    code=status.HTTP_400_BAD_REQUEST,
                )
            if not generic_client_id:
                raise ProxyException(
                    message="GENERIC_CLIENT_ID must be set when PKCE is enabled",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_CLIENT_ID",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            if not generic_token_endpoint:
                raise ProxyException(
                    message="GENERIC_TOKEN_ENDPOINT must be set when PKCE is enabled",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_TOKEN_ENDPOINT",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            # All guards above raise, so authorization_code is a non-empty str here.
            # Use an explicit type guard rather than assert (assert is a no-op with -O).
            if not isinstance(authorization_code, str):
                raise ProxyException(
                    message="Missing authorization code in callback",
                    type=ProxyErrorTypes.auth_error,
                    param="code",
                    code=status.HTTP_400_BAD_REQUEST,
                )
            combined_response = await SSOAuthenticationHandler._pkce_token_exchange(
                authorization_code=authorization_code,
                code_verifier=code_verifier,
                client_id=generic_client_id,
                client_secret=generic_client_secret,
                token_endpoint=generic_token_endpoint,
                userinfo_endpoint=generic_userinfo_endpoint,
                include_client_id=generic_include_client_id,
                redirect_url=redirect_url,
                additional_headers=additional_generic_sso_headers_dict,
            )
            # Pass the full response so custom response_convertor implementations
            # can access all fields (including id_token for claim extraction).
            result = response_convertor(combined_response, generic_sso)
            # Strip bearer credentials from combined_response before storing in
            # received_response. received_response may appear in restricted-group
            # error messages — bearer tokens (access_token, id_token, refresh_token)
            # must not be exposed to callers.
            # Assign directly rather than relying on nonlocal mutation so that Pyright
            # can track that received_response is non-None from this point on.
            received_response = {
                k: v for k, v in combined_response.items() if k not in _OAUTH_TOKEN_FIELDS
            }
            # In the PKCE path verify_and_process is skipped, so generic_sso.access_token
            # is never set. Read the token directly from the exchange response instead so
            # process_sso_jwt_access_token can extract JWT-embedded roles/teams.
            access_token_str: Optional[str] = combined_response.get("access_token")
        else:
            result = await generic_sso.verify_and_process(
                request,
                params=token_exchange_params,
                headers=additional_generic_sso_headers_dict,
            )
            access_token_str = generic_sso.access_token
        process_sso_jwt_access_token(
            access_token_str, sso_jwt_handler, result, role_mappings=role_mappings
        )
        # Delete the single-use PKCE verifier only after all downstream processing
        # (response_convertor and process_sso_jwt_access_token) has completed
        # successfully. Deleting earlier would consume the verifier on a transient
        # failure, forcing the user to restart the entire OAuth flow from scratch.
        if pkce_cache_key:
            await SSOAuthenticationHandler._delete_pkce_verifier(pkce_cache_key)
    except Exception as e:
        # _handle_generic_sso_error always re-raises, so the lines below are
        # only reached on the success path (where `result` is bound).
        _handle_generic_sso_error(
            e,
            generic_authorization_endpoint,
            generic_token_endpoint,
            additional_generic_sso_headers_dict,
        )
    verbose_proxy_logger.debug("generic result: %s", result)
    return result or {}, received_response
async def create_team_member_add_task(team_id, user_info):
    """Add `user_info` to `team_id` as a regular member (non-blocking on failure)."""
    try:
        add_request = TeamMemberAddRequest(
            member=Member(user_id=user_info.user_id, role="user"),
            team_id=team_id,
        )
        # Performed with admin privileges since this runs during SSO sign-in.
        return await team_member_add(
            data=add_request,
            user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        )
    except Exception as e:
        # Deliberately swallowed: team membership sync must not block login.
        verbose_proxy_logger.debug(
            f"[Non-Blocking] Error trying to add sso user to db: {e}"
        )
async def add_missing_team_member(
    user_info: Union[NewUserResponse, LiteLLM_UserTable], sso_teams: List[str]
):
    """
    Add the user to any SSO-provided teams they are not yet a member of.

    - Get missing teams (diff b/w user_info.teams and sso_teams)
    - Add missing user to missing teams

    Failures are non-blocking: they are logged at debug level and do not
    interrupt the SSO login flow.
    """
    # Handle None as empty list for new users
    user_teams = user_info.teams if user_info.teams is not None else []
    missing_teams = set(sso_teams) - set(user_teams)
    # Fix: removed the redundant `tasks = []` assignment that was immediately
    # overwritten by the comprehension below.
    tasks = [
        create_team_member_add_task(team_id, user_info)
        for team_id in missing_teams
    ]
    try:
        await asyncio.gather(*tasks)
    except Exception as e:
        verbose_proxy_logger.debug(
            f"[Non-Blocking] Error trying to add sso user to db: {e}"
        )
def get_disabled_non_admin_personal_key_creation():
    """Return True when personal key creation is restricted to proxy admins.

    Reads ``litellm.key_generation_settings``; when unset, personal key
    creation is not restricted and False is returned.
    """
    settings = litellm.key_generation_settings
    if settings is None:
        return False
    personal_cfg = settings.get("personal_key_generation") or {}
    allowed_roles = personal_cfg.get("allowed_user_roles") or []
    return "proxy_admin" in allowed_roles
async def get_existing_user_info_from_db(
    user_id: Optional[str],
    user_email: Optional[str],
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
    proxy_logging_obj: ProxyLogging,
) -> Optional[LiteLLM_UserTable]:
    """Look up an existing user by id/email; return None on any lookup failure."""
    try:
        return await get_user_object(
            user_id=user_id,
            user_email=user_email,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            user_id_upsert=False,
            parent_otel_span=None,
            proxy_logging_obj=proxy_logging_obj,
            sso_user_id=user_id,
        )
    except Exception as e:
        # Non-blocking: a failed lookup simply means "no existing user found"
        verbose_proxy_logger.debug(f"Error getting user object: {e}")
        return None
async def get_user_info_from_db(
    result: Union[CustomOpenID, OpenID, dict],
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
    proxy_logging_obj: ProxyLogging,
    user_email: Optional[str],
    user_defined_values: Optional[SSOUserDefinedValues],
    alternate_user_id: Optional[str] = None,
) -> Optional[Union[LiteLLM_UserTable, NewUserResponse]]:
    """Resolve (or create) the LiteLLM DB user for an SSO response.

    Tries each candidate user id (alternate id first, then the SSO response's
    own ``id``) against the DB, then upserts the user and syncs team
    membership from the SSO response. Returns None on any failure
    (non-blocking for the login flow).
    """
    try:
        result_is_dict = isinstance(result, dict)

        candidate_ids: List[str] = []
        if alternate_user_id is not None:
            candidate_ids.append(alternate_user_id)
        raw_id = (
            result.get("id", None) if result_is_dict else getattr(result, "id", None)
        )
        if isinstance(raw_id, str):
            candidate_ids.append(raw_id)

        # The SSO response's (normalized) email intentionally replaces the
        # caller-supplied `user_email`.
        raw_email = (
            result.get("email", None)
            if result_is_dict
            else getattr(result, "email", None)
        )
        user_email = normalize_email(raw_email)

        user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]] = None
        for candidate_id in candidate_ids:
            user_info = await get_existing_user_info_from_db(
                user_id=candidate_id,
                user_email=user_email,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                proxy_logging_obj=proxy_logging_obj,
            )
            if user_info is not None:
                break

        verbose_proxy_logger.debug(
            f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}"
        )

        # Upsert SSO User to LiteLLM DB
        user_info = await SSOAuthenticationHandler.upsert_sso_user(
            result=result,
            user_info=user_info,
            user_email=user_email,
            user_defined_values=user_defined_values,
            prisma_client=prisma_client,
        )
        await SSOAuthenticationHandler.add_user_to_teams_from_sso_response(
            result=result,
            user_info=user_info,
        )
        return user_info
    except Exception as e:
        verbose_proxy_logger.exception(
            f"[Non-Blocking] Error trying to add sso user to db: {e}"
        )
        return None
def _should_use_role_from_sso_response(sso_role: Optional[str]) -> bool:
    """returns true if SSO upsert should use the 'role' defined on the SSO response"""
    if sso_role is None:
        return False
    if is_valid_litellm_user_role(sso_role):
        return True
    verbose_proxy_logger.debug(
        f"SSO role '{sso_role}' is not a valid LiteLLM user role. "
        "Ignoring role from SSO response. See LitellmUserRoles enum for valid roles."
    )
    return False
def _build_sso_user_update_data(
    result: Optional[Union["CustomOpenID", OpenID, dict]],
    user_email: Optional[str],
    user_id: Optional[str],
) -> dict:
    """
    Build the update data dictionary for SSO user upsert.

    Args:
        result: The SSO response containing user information. May be an
            OpenID-style object or a plain dict (generic SSO providers).
        user_email: The user's email from SSO
        user_id: The user's ID for logging purposes

    Returns:
        dict: Update data containing user_email and optionally user_role if valid
    """
    update_data: dict = {"user_email": normalize_email(user_email)}
    # Get SSO role from result and include if valid.
    # Fix: `result` can be a plain dict (see the Union type above) and
    # getattr() never finds dict keys, so dict-based SSO responses silently
    # lost their role — read via .get() for dicts.
    if isinstance(result, dict):
        sso_role = result.get("user_role", None)
    else:
        sso_role = getattr(result, "user_role", None)
    if sso_role is not None:
        # Convert enum to string if needed
        sso_role_str = (
            sso_role.value if isinstance(sso_role, LitellmUserRoles) else sso_role
        )
        # Only include if it's a valid LiteLLM role
        if _should_use_role_from_sso_response(sso_role_str):
            update_data["user_role"] = sso_role_str
            verbose_proxy_logger.info(
                f"Updating user {user_id} role from SSO: {sso_role_str}"
            )
    return update_data
def apply_user_info_values_to_sso_user_defined_values(
    user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]],
    user_defined_values: Optional[SSOUserDefinedValues],
) -> Optional[SSOUserDefinedValues]:
    """Merge DB user info into the SSO-derived user values.

    A valid SSO-provided role always wins; otherwise fall back to the DB
    role, and finally to INTERNAL_USER_VIEW_ONLY. Existing DB model
    permissions are preserved.
    """
    if user_defined_values is None:
        return None

    if user_info is not None and user_info.user_id is not None:
        user_defined_values["user_id"] = user_info.user_id

    # SSO role takes precedence - only use DB role if SSO didn't provide one
    # This ensures SSO is the authoritative source for user roles
    sso_role = user_defined_values.get("user_role")
    db_role = user_info.user_role if user_info else None
    if _should_use_role_from_sso_response(sso_role):
        # SSO provided a valid role, keep it and log that we're using it
        verbose_proxy_logger.info(
            f"Using SSO role: {sso_role} (DB role was: {db_role})"
        )
    elif db_role is None:
        # Neither SSO nor the DB provided a role: apply the default
        user_defined_values[
            "user_role"
        ] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
        verbose_proxy_logger.debug(
            "No SSO or DB role found, using default: INTERNAL_USER_VIEW_ONLY"
        )
    else:
        # SSO didn't provide a valid role, fall back to the DB role
        user_defined_values["user_role"] = db_role
        verbose_proxy_logger.debug(f"Using DB role: {db_role}")

    # Preserve the user's existing models from the database
    if user_info is not None and hasattr(user_info, "models") and user_info.models:
        user_defined_values["models"] = user_info.models

    return user_defined_values
async def check_and_update_if_proxy_admin_id(
    user_role: str, user_id: str, prisma_client: Optional[PrismaClient]
):
    """
    Promote the user identified by the PROXY_ADMIN_ID env var to proxy admin.

    - Check if user role in DB is admin
    - If not, update user role in DB to admin role

    Returns:
        str: the (possibly upgraded) user_role value.
    """
    proxy_admin_id = os.getenv("PROXY_ADMIN_ID")
    if proxy_admin_id is not None and proxy_admin_id == user_id:
        # Already an admin — nothing to do
        if user_role and user_role == LitellmUserRoles.PROXY_ADMIN.value:
            return user_role
        # Persist the promotion when a DB connection is available
        if prisma_client:
            await prisma_client.db.litellm_usertable.update(
                where={"user_id": user_id},
                data={"user_role": LitellmUserRoles.PROXY_ADMIN.value},
            )
        # NOTE(review): the in-memory role is upgraded even when no
        # prisma_client is available (DB left untouched) — confirm intended.
        user_role = LitellmUserRoles.PROXY_ADMIN.value
    return user_role
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
async def auth_callback(request: Request, state: Optional[str] = None):  # noqa: PLR0915
    """
    Verify login.

    Handles the OAuth callback for Google / Microsoft / Generic SSO, then
    either completes a CLI login (when `state` carries the CLI prefix) or
    returns the redirect response that sends the browser back to the UI.
    """
    verbose_proxy_logger.info(f"Starting SSO callback with state: {state}")
    # Check if this is a CLI login (state starts with our CLI prefix)
    from litellm.constants import LITELLM_CLI_SESSION_TOKEN_PREFIX
    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.auth.handle_jwt import JWTHandler
    from litellm.proxy.proxy_server import (
        general_settings,
        jwt_handler,
        master_key,
        prisma_client,
        user_api_key_cache,
    )

    # SSO login needs the DB to look up / upsert the user
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    # When ui_access_mode is a dict, build a dedicated JWT handler so team ids
    # can be read out of the provider's JWT access token.
    sso_jwt_handler: Optional[JWTHandler] = None
    ui_access_mode = general_settings.get("ui_access_mode", None)
    if ui_access_mode is not None and isinstance(ui_access_mode, dict):
        sso_jwt_handler = JWTHandler()
        sso_jwt_handler.update_environment(
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            litellm_jwtauth=LiteLLM_JWTAuth(
                team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get(
                    "sso_group_jwt_field", None
                ),
            ),
            leeway=0,
        )

    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
    # Sanitized token-exchange response (generic SSO PKCE path only)
    received_response: Optional[dict] = None

    # get url from request
    if master_key is None:
        raise ProxyException(
            message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
            type=ProxyErrorTypes.auth_error,
            param="master_key",
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
    redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
        request=request, sso_callback_route="sso/callback"
    )
    verbose_proxy_logger.info(f"Redirecting to {redirect_url}")

    # Dispatch to the configured provider (google > microsoft > generic)
    result = None
    if google_client_id is not None:
        result = await GoogleSSOHandler.get_google_callback_response(
            request=request,
            google_client_id=google_client_id,
            redirect_url=redirect_url,
        )
    elif microsoft_client_id is not None:
        result = await MicrosoftSSOHandler.get_microsoft_callback_response(
            request=request,
            microsoft_client_id=microsoft_client_id,
            redirect_url=redirect_url,
        )
    elif generic_client_id is not None:
        result, received_response = await get_generic_sso_response(
            request=request,
            jwt_handler=jwt_handler,
            generic_client_id=generic_client_id,
            redirect_url=redirect_url,
            sso_jwt_handler=sso_jwt_handler,
        )

    if result is None:
        raise HTTPException(
            status_code=401,
            detail="Result not returned by SSO provider.",
        )

    if state and state.startswith(f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:"):
        # Extract the key ID and existing_key from the state
        # State format: {PREFIX}:{key}:{existing_key} or {PREFIX}:{key}
        state_parts = state.split(":", 2)  # Split into max 3 parts
        key_id = state_parts[1] if len(state_parts) > 1 else None
        existing_key = state_parts[2] if len(state_parts) > 2 else None
        verbose_proxy_logger.info(
            f"CLI SSO callback detected for key: {key_id}, existing_key: {existing_key}"
        )
        return await cli_sso_callback(
            request=request, key=key_id, existing_key=existing_key, result=result
        )

    # Browser flow: build the UI redirect (sets the session JWT cookie/token)
    return await SSOAuthenticationHandler.get_redirect_response_from_openid(
        result=result,
        request=request,
        received_response=received_response,
        generic_client_id=generic_client_id,
        ui_access_mode=ui_access_mode,
    )
async def cli_sso_callback(
    request: Request,
    key: Optional[str] = None,
    existing_key: Optional[str] = None,
    result: Optional[Union[OpenID, dict]] = None,
):
    """
    CLI SSO callback - stores session info for JWT generation on polling.

    Args:
        request: Incoming request (unused beyond FastAPI wiring).
        key: CLI-supplied key id; doubles as the cache session identifier.
        existing_key: Optional pre-existing key carried through the state.
        result: The SSO provider's OpenID response.

    Returns:
        HTMLResponse: success page telling the user to return to the CLI.
    """
    verbose_proxy_logger.info(
        f"CLI SSO callback for key: {key}, existing_key: {existing_key}"
    )
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    if not key or not key.startswith("sk-"):
        raise HTTPException(
            status_code=400,
            detail="Invalid key parameter. Must be a valid key ID starting with 'sk-'",
        )
    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )
    if result is None:
        raise HTTPException(
            status_code=500,
            detail="SSO authentication failed - no result returned from provider",
        )
    # After None check, cast to non-None type for type checker
    result_non_none: Union[OpenID, dict] = cast(Union[OpenID, dict], result)

    parsed_openid_result = SSOAuthenticationHandler._get_user_email_and_id_from_result(
        result=result_non_none
    )
    verbose_proxy_logger.debug(f"parsed_openid_result: {parsed_openid_result}")

    try:
        # Get full user info from DB
        user_info = await get_user_info_from_db(
            result=result_non_none,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            user_email=parsed_openid_result.get("user_email"),
            user_defined_values=None,
            alternate_user_id=parsed_openid_result.get("user_id"),
        )
        if user_info is None:
            raise HTTPException(
                status_code=500, detail="Failed to retrieve user information from SSO"
            )

        # Store session info in cache (10 min TTL)
        from litellm.constants import CLI_SSO_SESSION_CACHE_KEY_PREFIX

        # Get all teams from user_info - CLI will let user select which one
        teams: List[str] = []
        if hasattr(user_info, "teams") and user_info.teams:
            teams = user_info.teams if isinstance(user_info.teams, list) else []

        # Also fetch team aliases for a better CLI UX. We keep the original
        # "teams" list of IDs for backwards compatibility and add an
        # optional "team_details" field containing objects with both
        # team_id and team_alias.
        team_details: List[Dict[str, Any]] = []
        try:
            if teams:
                prisma_teams = await prisma_client.db.litellm_teamtable.find_many(
                    where={"team_id": {"in": teams}}
                )
                for team_row in prisma_teams:
                    team_dict = team_row.model_dump()
                    team_details.append(
                        {
                            "team_id": team_dict.get("team_id"),
                            "team_alias": team_dict.get("team_alias"),
                        }
                    )
        except Exception as e:
            # If anything goes wrong here, fall back gracefully without
            # impacting the SSO flow.
            verbose_proxy_logger.error(
                f"Error fetching team details for CLI SSO session: {e}"
            )

        session_data = {
            "user_id": user_info.user_id,
            "user_role": user_info.user_role,
            "models": user_info.models if hasattr(user_info, "models") else [],
            "user_email": parsed_openid_result.get("user_email"),
            "teams": teams,
            # Optional rich metadata for clients that want nicer display
            "team_details": team_details,
        }
        cache_key = f"{CLI_SSO_SESSION_CACHE_KEY_PREFIX}:{key}"
        # Single-use session: consumed (and deleted) by /sso/cli/poll
        user_api_key_cache.set_cache(key=cache_key, value=session_data, ttl=600)
        verbose_proxy_logger.info(
            f"Stored CLI SSO session for user: {user_info.user_id}, teams: {teams}, num_teams: {len(teams)}"
        )

        # Return success page
        from fastapi.responses import HTMLResponse

        from litellm.proxy.common_utils.html_forms.cli_sso_success import (
            render_cli_sso_success_page,
        )

        html_content = render_cli_sso_success_page()
        return HTMLResponse(content=html_content, status_code=200)
    except HTTPException:
        # Bug fix: don't let the blanket handler below swallow the deliberate
        # HTTPException raised above and re-wrap its detail in a generic 500.
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error with CLI SSO callback: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to process CLI SSO: {str(e)}"
        )
@router.get("/sso/cli/poll/{key_id}", tags=["experimental"], include_in_schema=False)
async def cli_poll_key(key_id: str, team_id: Optional[str] = None):
    """
    CLI polling endpoint - retrieves session from cache and generates JWT.

    Flow:
    1. First poll (no team_id): Returns teams list without generating JWT
    2. Second poll (with team_id): Generates JWT with selected team and deletes session

    Args:
        key_id: The session key ID
        team_id: Optional team ID to assign to the JWT. If provided, must be one of user's teams.

    Raises:
        HTTPException: 400 for a malformed key id, 403 when team_id is not one
            of the user's teams, 500 on unexpected errors.
    """
    from litellm.constants import CLI_SSO_SESSION_CACHE_KEY_PREFIX
    from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken
    from litellm.proxy.proxy_server import user_api_key_cache

    if not key_id.startswith("sk-"):
        raise HTTPException(status_code=400, detail="Invalid key ID format")
    try:
        # Look up session in cache
        cache_key = f"{CLI_SSO_SESSION_CACHE_KEY_PREFIX}:{key_id}"
        session_data = user_api_key_cache.get_cache(key=cache_key)
        if not session_data:
            # SSO flow not completed yet (or session expired)
            return {"status": "pending"}

        user_teams = session_data.get("teams", [])
        user_team_details = session_data.get("team_details")
        user_id = session_data["user_id"]
        verbose_proxy_logger.info(
            f"CLI poll: user={user_id}, team_id={team_id}, user_teams={user_teams}, num_teams={len(user_teams)}"
        )

        # If no team_id provided and user has teams, return teams list for selection
        # Don't generate JWT yet - let CLI select a team first. For newer
        # clients we return rich team details (id + alias); older clients
        # can continue to rely on the simple "teams" list.
        if team_id is None and len(user_teams) > 1:
            verbose_proxy_logger.info(
                f"Returning teams list for user {user_id} to select from: {user_teams}"
            )
            # Best-effort construction of team_details if it wasn't
            # already cached for some reason.
            team_details_response: Optional[List[Dict[str, Any]]] = None
            if isinstance(user_team_details, list) and user_team_details:
                team_details_response = user_team_details
            elif user_teams:
                team_details_response = [
                    {"team_id": t, "team_alias": None} for t in user_teams
                ]
            return {
                "status": "ready",
                "user_id": user_id,
                "teams": user_teams,
                "team_details": team_details_response,
                "requires_team_selection": True,
            }

        # Validate team_id if provided
        if team_id is not None:
            if team_id not in user_teams:
                raise HTTPException(
                    status_code=403,
                    detail=f"User does not belong to team: {team_id}. Available teams: {user_teams}",
                )
        else:
            # If no team_id provided and user has 0 or 1 team, use first team (or None)
            team_id = user_teams[0] if len(user_teams) > 0 else None

        # Create user object for JWT generation
        user_info = LiteLLM_UserTable(
            user_id=user_id,
            user_role=session_data["user_role"],
            models=session_data.get("models", []),
            max_budget=litellm.max_ui_session_budget,
        )

        # Generate CLI JWT on-demand (expiration configurable via LITELLM_CLI_JWT_EXPIRATION_HOURS)
        # Pass selected team_id to ensure JWT has correct team
        jwt_token = ExperimentalUIJWTToken.get_cli_jwt_auth_token(
            user_info=user_info, team_id=team_id
        )

        # Delete cache entry (single-use)
        user_api_key_cache.delete_cache(key=cache_key)
        verbose_proxy_logger.info(
            f"CLI JWT generated for user: {user_id}, team: {team_id}"
        )
        return {
            "status": "ready",
            "key": jwt_token,
            "user_id": user_id,
            "team_id": team_id,
            "teams": user_teams,
            # Echo back any team details we have so clients can
            # present nicer information if needed.
            "team_details": user_team_details,
        }
    except HTTPException:
        # Bug fix: the deliberate 403 raised above was previously caught by the
        # blanket handler below and converted into a generic 500 — re-raise it.
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error polling for CLI JWT: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error checking session status: {str(e)}"
        )
async def insert_sso_user(
    result_openid: Optional[Union[OpenID, dict]],
    user_defined_values: Optional[SSOUserDefinedValues] = None,
) -> NewUserResponse:
    """
    Helper function to create a New User in LiteLLM DB after a successful SSO login

    Args:
        result_openid (OpenID): User information in OpenID format if the login was successful.
        user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read

    Returns:
        NewUserResponse: the newly created LiteLLM user record.

    Raises:
        ValueError: if result_openid or user_defined_values is None.
    """
    verbose_proxy_logger.debug(
        f"Inserting SSO user into DB. User values: {user_defined_values}"
    )
    if result_openid is None:
        raise ValueError("result_openid is None")
    if isinstance(result_openid, dict):
        result_openid = OpenID(**result_openid)
    if user_defined_values is None:
        raise ValueError("user_defined_values is None")

    # Apply default_internal_user_params
    if litellm.default_internal_user_params:
        # Preserve the SSO-extracted role if it's a valid LiteLLM role,
        # regardless of how it was determined (role_mappings, Microsoft app_roles,
        # GENERIC_USER_ROLE_ATTRIBUTE, custom SSO handler, etc.)
        sso_role = user_defined_values.get("user_role")
        if _should_use_role_from_sso_response(sso_role):
            # Preserve the SSO-extracted role, but apply other defaults
            preserved_role = sso_role
            user_defined_values.update(litellm.default_internal_user_params)  # type: ignore
            user_defined_values["user_role"] = preserved_role  # Restore preserved role
            verbose_proxy_logger.debug(
                f"Preserved SSO-extracted role '{preserved_role}'"
            )
        else:
            # SSO didn't provide a valid role, apply all defaults including role
            user_defined_values.update(litellm.default_internal_user_params)  # type: ignore

    # Set budget for internal users
    if user_defined_values.get("user_role") == LitellmUserRoles.INTERNAL_USER.value:
        if user_defined_values.get("max_budget") is None:
            user_defined_values["max_budget"] = litellm.max_internal_user_budget
        if user_defined_values.get("budget_duration") is None:
            user_defined_values[
                "budget_duration"
            ] = litellm.internal_user_budget_duration

    # Fallback role when none came from SSO or defaults.
    # NOTE(review): assigns the enum member itself (not .value), unlike the
    # rest of this module — NewUserRequest presumably coerces it; confirm.
    if user_defined_values["user_role"] is None:
        user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY

    new_user_request = NewUserRequest(
        user_id=user_defined_values["user_id"],
        user_email=normalize_email(user_defined_values["user_email"]),
        user_role=user_defined_values["user_role"],  # type: ignore
        max_budget=user_defined_values["max_budget"],
        budget_duration=user_defined_values["budget_duration"],
        sso_user_id=user_defined_values["user_id"],
        auto_create_key=False,
    )
    # Record which provider authenticated the user, when available
    if result_openid and hasattr(result_openid, "provider"):
        new_user_request.metadata = {
            "auth_provider": getattr(result_openid, "provider")
        }
    response = await new_user(
        data=new_user_request,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )
    return response
@router.get(
    "/sso/get/ui_settings",
    tags=["experimental"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def get_ui_settings(request: Request):
    """Return environment/config-driven settings consumed by the Admin UI."""
    from litellm.proxy.proxy_server import general_settings, proxy_state

    spend_logs_row_count = proxy_state.get_proxy_state_variable("spend_logs_row_count")

    # Default-team creation can be disabled via config or the env var override
    default_team_disabled = general_settings.get("default_team_disabled", False)
    if os.environ.get("PROXY_DEFAULT_TEAM_DISABLED", "").lower() == "true":
        default_team_disabled = True

    return {
        "PROXY_BASE_URL": os.getenv("PROXY_BASE_URL", None),
        "PROXY_LOGOUT_URL": os.getenv("PROXY_LOGOUT_URL", None),
        "LITELLM_UI_API_DOC_BASE_URL": os.getenv("LITELLM_UI_API_DOC_BASE_URL", None),
        "DEFAULT_TEAM_DISABLED": default_team_disabled,
        "SSO_ENABLED": _has_user_setup_sso(),
        "NUM_SPEND_LOGS_ROWS": spend_logs_row_count,
        # Expensive queries are disabled once the spend-log table grows too big
        "DISABLE_EXPENSIVE_DB_QUERIES": spend_logs_row_count
        > MAX_SPENDLOG_ROWS_TO_QUERY,
    }
@router.get(
    "/sso/readiness",
    tags=["experimental"],
    dependencies=[Depends(user_api_key_auth)],
)
async def sso_readiness():
    """
    Health endpoint for checking SSO readiness.
    Checks if the configured SSO provider has all required environment variables set in memory.
    """
    # Provider detection order matters (google > microsoft > generic),
    # mirroring the dispatch order used by the SSO login/callback routes.
    provider_client_ids = [
        ("google", "GOOGLE_CLIENT_ID"),
        ("microsoft", "MICROSOFT_CLIENT_ID"),
        ("generic", "GENERIC_CLIENT_ID"),
    ]
    configured_provider = None
    for provider_name, client_id_var in provider_client_ids:
        if os.getenv(client_id_var, None) is not None:
            configured_provider = provider_name
            break

    # If no SSO is configured, return healthy (SSO is optional)
    if configured_provider is None:
        return {
            "status": "healthy",
            "sso_configured": False,
            "message": "No SSO provider configured",
        }

    # Required environment variables per provider
    required_vars_by_provider = {
        "google": ["GOOGLE_CLIENT_SECRET"],
        "microsoft": ["MICROSOFT_CLIENT_SECRET", "MICROSOFT_TENANT"],
        "generic": [
            "GENERIC_CLIENT_SECRET",
            "GENERIC_AUTHORIZATION_ENDPOINT",
            "GENERIC_TOKEN_ENDPOINT",
            "GENERIC_USERINFO_ENDPOINT",
        ],
    }
    missing_vars = [
        env_var
        for env_var in required_vars_by_provider[configured_provider]
        if os.getenv(env_var, None) is None
    ]

    # If all required variables are present, return healthy
    if not missing_vars:
        return {
            "status": "healthy",
            "sso_configured": True,
            "provider": configured_provider,
            "message": f"{configured_provider.capitalize()} SSO is properly configured",
        }

    # If some variables are missing, return unhealthy
    raise HTTPException(
        status_code=503,
        detail={
            "status": "unhealthy",
            "sso_configured": True,
            "provider": configured_provider,
            "missing_environment_variables": missing_vars,
            "message": f"{configured_provider.capitalize()} SSO is configured but missing required environment variables: {', '.join(missing_vars)}",
        },
    )
class SSOAuthenticationHandler:
"""
Handler for SSO Authentication across all SSO providers
"""
    @staticmethod
    async def get_sso_login_redirect(
        redirect_url: str,
        google_client_id: Optional[str] = None,
        microsoft_client_id: Optional[str] = None,
        generic_client_id: Optional[str] = None,
        state: Optional[str] = None,
    ) -> Optional[RedirectResponse]:
        """
        Step 1. Call Get Login Redirect for the SSO provider. Send the redirect response to `redirect_url`

        Args:
            redirect_url (str): The URL to redirect the user to after login
            google_client_id (Optional[str], optional): The Google Client ID. Defaults to None.
            microsoft_client_id (Optional[str], optional): The Microsoft Client ID. Defaults to None.
            generic_client_id (Optional[str], optional): The Generic Client ID. Defaults to None.
            state (Optional[str], optional): Opaque state round-tripped through the
                provider (e.g. for CLI logins). Defaults to None.

        Returns:
            RedirectResponse: The redirect response from the SSO provider.

        Raises:
            ProxyException: when the matching provider's secret/endpoints are unset.
            ValueError: when no provider client ID was supplied.
        """
        # Google SSO Auth
        if google_client_id is not None:
            from fastapi_sso.sso.google import GoogleSSO

            google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
            if google_client_secret is None:
                raise ProxyException(
                    message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GOOGLE_CLIENT_SECRET",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            google_sso = GoogleSSO(
                client_id=google_client_id,
                client_secret=google_client_secret,
                redirect_uri=redirect_url,
            )
            verbose_proxy_logger.info(
                f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}"
            )
            with google_sso:
                return await google_sso.get_login_redirect(state=state)
        # Microsoft SSO Auth
        elif microsoft_client_id is not None:
            microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
            microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
            if microsoft_client_secret is None:
                raise ProxyException(
                    message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="MICROSOFT_CLIENT_SECRET",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            # NOTE(review): microsoft_tenant may be None here (not validated),
            # although /sso/readiness treats MICROSOFT_TENANT as required — confirm.
            microsoft_sso = CustomMicrosoftSSO(
                client_id=microsoft_client_id,
                client_secret=microsoft_client_secret,
                tenant=microsoft_tenant,
                redirect_uri=redirect_url,
                allow_insecure_http=True,
            )
            with microsoft_sso:
                return await microsoft_sso.get_login_redirect(state=state)
        elif generic_client_id is not None:
            from fastapi_sso.sso.base import DiscoveryDocument
            from fastapi_sso.sso.generic import create_provider

            generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
            # GENERIC_SCOPE is a space-separated list, e.g. "openid email profile"
            generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(
                " "
            )
            generic_authorization_endpoint = os.getenv(
                "GENERIC_AUTHORIZATION_ENDPOINT", None
            )
            generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
            generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
            if generic_client_secret is None:
                raise ProxyException(
                    message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_CLIENT_SECRET",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            if generic_authorization_endpoint is None:
                raise ProxyException(
                    message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_AUTHORIZATION_ENDPOINT",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            if generic_token_endpoint is None:
                raise ProxyException(
                    message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_TOKEN_ENDPOINT",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            if generic_userinfo_endpoint is None:
                raise ProxyException(
                    message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file",
                    type=ProxyErrorTypes.auth_error,
                    param="GENERIC_USERINFO_ENDPOINT",
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            verbose_proxy_logger.debug(
                f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}"
            )
            verbose_proxy_logger.debug(
                f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
            )
            # Build a one-off OIDC provider class from the discovery endpoints
            discovery = DiscoveryDocument(
                authorization_endpoint=generic_authorization_endpoint,
                token_endpoint=generic_token_endpoint,
                userinfo_endpoint=generic_userinfo_endpoint,
            )
            SSOProvider = create_provider(name="oidc", discovery_document=discovery)
            generic_sso = SSOProvider(
                client_id=generic_client_id,
                client_secret=generic_client_secret,
                redirect_uri=redirect_url,
                allow_insecure_http=True,
                scope=generic_scope,
            )
            return await SSOAuthenticationHandler.get_generic_sso_redirect_response(
                generic_sso=generic_sso,
                state=state,
                generic_authorization_endpoint=generic_authorization_endpoint,
            )
        raise ValueError(
            "Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
        )
    @staticmethod
    async def get_generic_sso_redirect_response(
        generic_sso: Any,
        state: Optional[str] = None,
        generic_authorization_endpoint: Optional[str] = None,
    ) -> Optional[RedirectResponse]:
        """
        Get the redirect response for Generic SSO.

        Builds the provider redirect via fastapi-sso, then — when PKCE is
        enabled — stores the code_verifier in cache and appends the PKCE
        query parameters to the authorization URL by hand, because
        fastapi-sso's get_login_redirect does not accept them.
        """
        from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

        from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache

        with generic_sso:
            # TODO: state should be a random string and added to the user session with cookie
            # or a cryptographicly signed state that we can verify stateless
            # For simplification we are using a static state, this is not perfect but some
            # SSO providers do not allow stateless verification
            (
                redirect_params,
                code_verifier,
            ) = SSOAuthenticationHandler._get_generic_sso_redirect_params(
                state=state,
                generic_authorization_endpoint=generic_authorization_endpoint,
            )
            # Separate PKCE params from state params (fastapi-sso doesn't accept code_challenge)
            pkce_params = {}
            state_only_params = {}
            for key, value in redirect_params.items():
                if key in ("code_challenge", "code_challenge_method"):
                    pkce_params[key] = value
                else:
                    state_only_params[key] = value
            # Get the redirect response from fastapi-sso with only state param
            redirect_response = await generic_sso.get_login_redirect(**state_only_params)  # type: ignore
            # If PKCE is enabled, add PKCE parameters to the redirect URL
            if code_verifier and "state" in redirect_params:
                # Store code_verifier in cache (10 min TTL). Wrap in dict for proper
                # JSON serialization in Redis. Use Redis when available so callbacks
                # landing on another pod can retrieve it (multi-pod SSO).
                cache_key = f"pkce_verifier:{redirect_params['state']}"
                if redis_usage_cache is not None:
                    await redis_usage_cache.async_set_cache(
                        key=cache_key,
                        value={"code_verifier": code_verifier},
                        ttl=600,
                    )
                else:
                    await user_api_key_cache.async_set_cache(
                        key=cache_key,
                        value={"code_verifier": code_verifier},
                        ttl=600,
                    )
                verbose_proxy_logger.debug(
                    "PKCE code_verifier stored in cache (TTL: 600s)"
                )
                # Add PKCE parameters to the authorization URL
                if pkce_params:
                    parsed_url = urlparse(str(redirect_response.headers["location"]))
                    query_params = parse_qs(parsed_url.query)
                    # Add PKCE parameters
                    for key, value in pkce_params.items():
                        query_params[key] = [value]
                    # Reconstruct the URL with PKCE parameters
                    new_query = urlencode(query_params, doseq=True)
                    new_url = urlunparse(
                        (
                            parsed_url.scheme,
                            parsed_url.netloc,
                            parsed_url.path,
                            parsed_url.params,
                            new_query,
                            parsed_url.fragment,
                        )
                    )
                    # Update the redirect response
                    redirect_response.headers["location"] = new_url
            return redirect_response
@staticmethod
def _get_generic_sso_redirect_params(
state: Optional[str] = None,
generic_authorization_endpoint: Optional[str] = None,
) -> Tuple[dict, Optional[str]]:
"""
Get redirect parameters for Generic SSO with proper state priority handling.
Optionally generates PKCE parameters if GENERIC_CLIENT_USE_PKCE is enabled.
Priority order:
1. CLI state (if provided)
2. GENERIC_CLIENT_STATE environment variable
3. Generated UUID (required by Okta and most OAuth providers)
Args:
state: Optional state parameter (e.g., CLI state)
generic_authorization_endpoint: Authorization endpoint URL
Returns:
Tuple[dict, Optional[str]]:
- Redirect parameters for SSO login (may include PKCE params)
- code_verifier (if PKCE is enabled, None otherwise)
"""
redirect_params = {}
code_verifier: Optional[str] = None
if state:
# CLI state takes priority
# the litellm proxy cli sends the "state" parameter to the proxy server for auth. We should maintain the state parameter for the cli if it is provided
redirect_params["state"] = state
else:
generic_client_state = os.getenv("GENERIC_CLIENT_STATE", None)
if generic_client_state:
redirect_params["state"] = generic_client_state
else:
redirect_params["state"] = uuid.uuid4().hex
# Handle PKCE (Proof Key for Code Exchange) if enabled
# Set GENERIC_CLIENT_USE_PKCE=true to enable PKCE for enhanced OAuth security
use_pkce = os.getenv("GENERIC_CLIENT_USE_PKCE", "false").lower() == "true"
if use_pkce:
(
code_verifier,
code_challenge,
) = SSOAuthenticationHandler.generate_pkce_params()
redirect_params["code_challenge"] = code_challenge
redirect_params["code_challenge_method"] = "S256"
verbose_proxy_logger.debug("PKCE enabled for authorization request")
return redirect_params, code_verifier
@staticmethod
def should_use_sso_handler(
google_client_id: Optional[str] = None,
microsoft_client_id: Optional[str] = None,
generic_client_id: Optional[str] = None,
) -> bool:
if (
google_client_id is not None
or microsoft_client_id is not None
or generic_client_id is not None
):
return True
return False
@staticmethod
def get_redirect_url_for_sso(
request: Request,
sso_callback_route: str,
existing_key: Optional[str] = None,
) -> str:
"""
Get the redirect URL for SSO
Note: existing_key is not added to the URL to avoid changing the callback URL.
It should be passed via the state parameter instead.
"""
from litellm.proxy.utils import get_custom_url
redirect_url = get_custom_url(request_base_url=str(request.base_url))
if redirect_url.endswith("/"):
redirect_url += sso_callback_route
else:
redirect_url += "/" + sso_callback_route
return redirect_url
@staticmethod
async def upsert_sso_user(
result: Optional[Union[CustomOpenID, OpenID, dict]],
user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
user_email: Optional[str],
user_defined_values: Optional[SSOUserDefinedValues],
prisma_client: PrismaClient,
):
"""
Connects the SSO Users to the User Table in LiteLLM DB
- If user on LiteLLM DB, update the user_email and user_role (if SSO provides valid role) with the SSO values
- If user not on LiteLLM DB, insert the user into LiteLLM DB
"""
try:
if user_info is not None:
user_id = user_info.user_id
update_data = _build_sso_user_update_data(
result=result,
user_email=user_email,
user_id=user_id,
)
await prisma_client.db.litellm_usertable.update_many(
where={"user_id": user_id}, data=update_data
)
else:
verbose_proxy_logger.info(
"user not in DB, inserting user into LiteLLM DB"
)
# user not in DB, insert User into LiteLLM DB
user_info = await insert_sso_user(
result_openid=result,
user_defined_values=user_defined_values,
)
return user_info
except Exception as e:
verbose_proxy_logger.exception(
f"Error upserting SSO user into LiteLLM DB: {e}"
)
return user_info
@staticmethod
async def add_user_to_teams_from_sso_response(
result: Optional[Union[CustomOpenID, OpenID, dict]],
user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
):
"""
Adds the user as a team member to the teams specified in the SSO responses `team_ids` field
The `team_ids` field is populated by litellm after processing the SSO response
"""
if user_info is None:
verbose_proxy_logger.debug(
"User not found in LiteLLM DB, skipping team member addition"
)
return
sso_teams = getattr(result, "team_ids", [])
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
@staticmethod
def verify_user_in_restricted_sso_group(
general_settings: Dict,
result: Optional[Union[CustomOpenID, OpenID, dict]],
received_response: Optional[dict],
) -> Literal[True]:
"""
when ui_access_mode.type == "restricted_sso_group":
- result.team_ids should contain the restricted_sso_group
- if not, raise a ProxyException
- if so, return True
- if result.team_ids is None, return False
- if result.team_ids is an empty list, return False
- if result.team_ids is a list, return True if the restricted_sso_group is in the list, otherwise return False
"""
ui_access_mode = cast(
Optional[Union[Dict, str]], general_settings.get("ui_access_mode")
)
if ui_access_mode is None:
return True
if isinstance(ui_access_mode, str):
return True
team_ids = getattr(result, "team_ids", [])
if ui_access_mode.get("type") == "restricted_sso_group":
restricted_sso_group = ui_access_mode.get("restricted_sso_group")
if restricted_sso_group not in team_ids:
raise ProxyException(
message=f"User is not in the restricted SSO group: {restricted_sso_group}. User groups: {team_ids}. Received SSO response: {received_response}",
type=ProxyErrorTypes.auth_error,
param="restricted_sso_group",
code=status.HTTP_403_FORBIDDEN,
)
return True
@staticmethod
async def create_litellm_team_from_sso_group(
litellm_team_id: str,
litellm_team_name: Optional[str] = None,
):
"""
Creates a Litellm Team from a SSO Group ID
Your SSO provider might have groups that should be created on LiteLLM
Use this helper to create a Litellm Team from a SSO Group ID
Args:
litellm_team_id (str): The ID of the Litellm Team
litellm_team_name (Optional[str]): The name of the Litellm Team
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise ProxyException(
message="Prisma client not found. Set it in the proxy_server.py file",
type=ProxyErrorTypes.auth_error,
param="prisma_client",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
try:
team_obj = await prisma_client.db.litellm_teamtable.find_first(
where={"team_id": litellm_team_id}
)
verbose_proxy_logger.debug(f"Team object: {team_obj}")
# only create a new team if it doesn't exist
if team_obj:
verbose_proxy_logger.debug(
f"Team already exists: {litellm_team_id} - {litellm_team_name}"
)
return
team_request: NewTeamRequest = NewTeamRequest(
team_id=litellm_team_id,
team_alias=litellm_team_name,
)
if litellm.default_team_params:
team_request = SSOAuthenticationHandler._cast_and_deepcopy_litellm_default_team_params(
default_team_params=litellm.default_team_params,
litellm_team_id=litellm_team_id,
litellm_team_name=litellm_team_name,
team_request=team_request,
)
await new_team(
data=team_request,
# params used for Audit Logging
http_request=Request(scope={"type": "http", "method": "POST"}),
user_api_key_dict=UserAPIKeyAuth(
token="",
key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
),
)
except Exception as e:
verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
@staticmethod
def _cast_and_deepcopy_litellm_default_team_params(
default_team_params: Union[DefaultTeamSSOParams, Dict],
team_request: NewTeamRequest,
litellm_team_id: str,
litellm_team_name: Optional[str] = None,
) -> NewTeamRequest:
"""
Casts and deepcopies the litellm.default_team_params to a NewTeamRequest object
- Ensures we create a new DefaultTeamSSOParams object
- Handle the case where litellm.default_team_params is a dict or a DefaultTeamSSOParams object
- Adds the litellm_team_id and litellm_team_name to the DefaultTeamSSOParams object
"""
if isinstance(default_team_params, dict):
_team_request = deepcopy(default_team_params)
_team_request["team_id"] = litellm_team_id
_team_request["team_alias"] = litellm_team_name
team_request = NewTeamRequest(**_team_request)
elif isinstance(litellm.default_team_params, DefaultTeamSSOParams):
_default_team_params = deepcopy(litellm.default_team_params)
_new_team_request = team_request.model_dump()
_new_team_request.update(_default_team_params)
team_request = NewTeamRequest(**_new_team_request)
return team_request
@staticmethod
def _get_cli_state(
source: Optional[str], key: Optional[str], existing_key: Optional[str] = None
) -> Optional[str]:
"""
Checks the request 'source' if a cli state token was passed in
This is used to authenticate through the CLI login flow.
The state parameter format is: {PREFIX}:{key}:{existing_key}
- If existing_key is provided, it's included in the state
- The state parameter is used to pass data through the OAuth flow without changing the callback URL
"""
from litellm.constants import (
LITELLM_CLI_SESSION_TOKEN_PREFIX,
LITELLM_CLI_SOURCE_IDENTIFIER,
)
if source == LITELLM_CLI_SOURCE_IDENTIFIER and key:
if existing_key:
return f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:{key}:{existing_key}"
else:
return f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:{key}"
else:
return None
    @staticmethod
    def _get_user_email_and_id_from_result(
        result: Optional[Union[OpenID, dict]],
        generic_client_id: Optional[str] = None,
    ) -> ParsedOpenIDResult:
        """
        Gets the user email and id from the OpenID result after validating the email domain.

        Resolution rules (visible in the code below):
        - email/id are read as *attributes* off ``result``. NOTE(review): a plain
          dict result yields None here, since ``getattr`` does not read dict
          keys — presumably dict results are normalized upstream; confirm.
        - If ``ALLOWED_EMAIL_DOMAINS`` is set, the email's domain (text after
          '@') must appear in the comma-separated list, else HTTP 401.
        - ``user_role`` comes from ``result.user_role`` when present; for
          generic SSO it falls back to the attribute named by
          ``GENERIC_USER_ROLE_ATTRIBUTE`` (default ``"role"``).
        - Missing id falls back to first_name+last_name, then to the email.

        Raises:
            HTTPException (401): email domain not in ``ALLOWED_EMAIL_DOMAINS``.
        """
        user_email: Optional[str] = normalize_email(getattr(result, "email", None))
        user_id: Optional[str] = (
            getattr(result, "id", None) if result is not None else None
        )
        user_role: Optional[str] = None
        # Optional allow-list: exact match on the part after '@' against the
        # comma-separated ALLOWED_EMAIL_DOMAINS env var (no whitespace stripping).
        if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
            email_domain = user_email.split("@")[1]
            allowed_domains = os.getenv("ALLOWED_EMAIL_DOMAINS").split(",")  # type: ignore
            if email_domain not in allowed_domains:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "message": "The email domain={}, is not an allowed email domain={}. Contact your admin to change this.".format(
                            email_domain, allowed_domains
                        )
                    },
                )
        # Extract user_role from result (works for all SSO providers)
        if result is not None:
            _user_role = getattr(result, "user_role", None)
            if _user_role is not None:
                # Convert enum to string if needed
                user_role = (
                    _user_role.value
                    if isinstance(_user_role, LitellmUserRoles)
                    else _user_role
                )
                verbose_proxy_logger.debug(
                    f"Extracted user_role from SSO result: {user_role}"
                )
        # generic client id - override with custom attribute name if specified
        if generic_client_id is not None and result is not None:
            generic_user_role_attribute_name = os.getenv(
                "GENERIC_USER_ROLE_ATTRIBUTE", "role"
            )
            # Generic path re-reads id/email from the same attributes (same values).
            user_id = getattr(result, "id", None)
            user_email = normalize_email(getattr(result, "email", None))
            # Only consult the custom role attribute when no role was found above.
            if user_role is None:
                _role_from_attr = getattr(result, generic_user_role_attribute_name, None)  # type: ignore
                if _role_from_attr is not None:
                    # Convert enum to string if needed
                    user_role = (
                        _role_from_attr.value
                        if isinstance(_role_from_attr, LitellmUserRoles)
                        else _role_from_attr
                    )
        # Fallbacks when the provider supplied no id: first+last name, then email.
        if user_id is None and result is not None:
            _first_name = getattr(result, "first_name", "") or ""
            _last_name = getattr(result, "last_name", "") or ""
            user_id = _first_name + _last_name
        if user_email is not None and (user_id is None or len(user_id) == 0):
            user_id = user_email
        return ParsedOpenIDResult(
            user_email=user_email,
            user_id=user_id,
            user_role=user_role,
        )
    @staticmethod
    async def get_redirect_response_from_openid(  # noqa: PLR0915
        result: Union[OpenID, dict, CustomOpenID],
        request: Request,
        received_response: Optional[dict] = None,
        generic_client_id: Optional[str] = None,
        ui_access_mode: Optional[Dict] = None,
    ) -> RedirectResponse:
        """
        Turn a verified SSO result into an authenticated UI redirect.

        Flow (in order):
        1. Parse email/id/role from the OpenID result (validates email domain).
        2. Enforce restricted-SSO-group access, if configured.
        3. Look up / reconcile the user in the DB, then mint a short-lived UI
           session key via ``generate_key_helper_fn``.
        4. Enforce admin-only UI access, if configured.
        5. Encode a UI session JWT (HS256, signed with the master key), set it
           as the ``token`` cookie, and 303-redirect to the dashboard.

        Raises:
            ValueError: ``user_custom_sso`` is set but not a coroutine function.
            HTTPException (401): role not allowed under admin-only mode, or
                missing user info for experimental UI login.
            Exception: no user identity could be derived from the SSO result.
        """
        import jwt

        from litellm.proxy.proxy_server import (
            general_settings,
            generate_key_helper_fn,
            master_key,
            premium_user,
            proxy_logging_obj,
            user_api_key_cache,
            user_custom_sso,
        )
        from litellm.proxy.utils import get_prisma_client_or_throw
        from litellm.types.proxy.ui_sso import ReturnedUITokenObject

        prisma_client = get_prisma_client_or_throw(
            "Prisma client is None, connect a database to your proxy"
        )
        # User is Authe'd in - generate key for the UI to access Proxy
        parsed_openid_result = (
            SSOAuthenticationHandler._get_user_email_and_id_from_result(
                result=result, generic_client_id=generic_client_id
            )
        )
        user_email = parsed_openid_result.get("user_email")
        user_id = parsed_openid_result.get("user_id")
        user_role = parsed_openid_result.get("user_role")
        verbose_proxy_logger.info(f"SSO callback result: {result}")
        user_info = None
        user_id_models: List = []
        max_internal_user_budget = litellm.max_internal_user_budget
        internal_user_budget_duration = litellm.internal_user_budget_duration
        # User might not be already created on first generation of key
        # But if it is, we want their models preferences
        default_ui_key_values: Dict[str, Any] = {
            "duration": LITELLM_UI_SESSION_DURATION,
            "key_max_budget": litellm.max_ui_session_budget,
            "aliases": {},
            "config": {},
            "spend": 0,
            "team_id": "litellm-dashboard",
        }
        user_defined_values: Optional[SSOUserDefinedValues] = None
        # Custom SSO hook (if configured) overrides the default identity mapping.
        if user_custom_sso is not None:
            if inspect.iscoroutinefunction(user_custom_sso):
                user_defined_values = await user_custom_sso(result)  # type: ignore
            else:
                raise ValueError("user_custom_sso must be a coroutine function")
        elif user_id is not None:
            user_defined_values = SSOUserDefinedValues(
                models=user_id_models,
                user_id=user_id,
                user_email=user_email,
                max_budget=max_internal_user_budget,
                user_role=user_role,
                budget_duration=internal_user_budget_duration,
            )
        # (IF SET) Verify user is in restricted SSO group
        SSOAuthenticationHandler.verify_user_in_restricted_sso_group(
            general_settings=general_settings,
            result=result,
            received_response=received_response,
        )
        user_info = await get_user_info_from_db(
            result=result,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            user_email=user_email,
            user_defined_values=user_defined_values,
            alternate_user_id=user_id,
        )
        # DB-stored values (models, role, budget) are folded into the SSO-derived ones.
        user_defined_values = apply_user_info_values_to_sso_user_defined_values(
            user_info=user_info, user_defined_values=user_defined_values
        )
        if user_defined_values is None:
            raise Exception(
                "Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
            )
        verbose_proxy_logger.info(
            f"user_defined_values for creating ui key: {user_defined_values}"
        )
        default_ui_key_values.update(user_defined_values)
        default_ui_key_values["request_type"] = "key"
        response = await generate_key_helper_fn(
            **default_ui_key_values,  # type: ignore
            table_name="key",
        )
        key = response["token"]  # type: ignore
        user_id = response["user_id"]  # type: ignore
        user_role = (
            user_defined_values["user_role"]
            or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
        )
        # Elevate to proxy admin when the user id matches the configured admin id.
        if user_id and isinstance(user_id, str):
            user_role = await check_and_update_if_proxy_admin_id(
                user_role=user_role, user_id=user_id, prisma_client=prisma_client
            )
        verbose_proxy_logger.debug(
            f"user_role: {user_role}; ui_access_mode: {ui_access_mode}"
        )
        ## CHECK IF ROLE ALLOWED TO USE PROXY ##
        is_admin_only_access = check_is_admin_only_access(ui_access_mode or {})
        if is_admin_only_access:
            has_access = has_admin_ui_access(user_role or "")
            if not has_access:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}"
                    },
                )
        disabled_non_admin_personal_key_creation = (
            get_disabled_non_admin_personal_key_creation()
        )
        litellm_dashboard_ui = get_custom_url(
            request_base_url=str(request.base_url), route="ui/"
        )
        # Experimental login swaps the DB-backed key for a short-lived JWT key.
        if get_secret_bool("EXPERIMENTAL_UI_LOGIN"):
            _user_info: Optional[LiteLLM_UserTable] = None
            if (
                user_defined_values is not None
                and user_defined_values["user_id"] is not None
            ):
                _user_info = LiteLLM_UserTable(
                    user_id=user_defined_values["user_id"],
                    user_role=user_defined_values["user_role"] or user_role,
                    models=[],
                    max_budget=litellm.max_ui_session_budget,
                )
            if _user_info is None:
                raise HTTPException(
                    status_code=401,
                    detail={
                        "error": "User Information is required for experimental UI login"
                    },
                )
            key = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
                _user_info
            )
        returned_ui_token_object = ReturnedUITokenObject(
            user_id=cast(str, user_id),
            key=key,
            user_email=user_email,
            user_role=user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value,
            login_method="sso",
            premium_user=premium_user,
            auth_header_name=general_settings.get(
                "litellm_key_header_name", "Authorization"
            ),
            disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation,
            server_root_path=get_server_root_path(),
        )
        # UI session token: HS256 JWT signed with the proxy master key.
        jwt_token = jwt.encode(
            cast(dict, returned_ui_token_object),
            master_key or "",
            algorithm="HS256",
        )
        if user_id is not None and isinstance(user_id, str):
            litellm_dashboard_ui += "?login=success"
        verbose_proxy_logger.info(f"Redirecting to {litellm_dashboard_ui}")
        redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
        # The dashboard reads the session token from this cookie.
        redirect_response.set_cookie(key="token", value=jwt_token)
        return redirect_response
    @staticmethod
    async def prepare_token_exchange_parameters(
        request: Request,
        generic_include_client_id: bool,
    ) -> dict:
        """
        Prepare token exchange parameters for Generic SSO.

        When PKCE is enabled (``GENERIC_CLIENT_USE_PKCE=true``), looks up the
        code_verifier stored at authorization time under the callback's
        ``state`` value (Redis when configured, else the in-memory cache) and
        adds it to the parameters. Also returns ``_pkce_cache_key`` so the
        caller can delete the single-use verifier after a successful exchange.

        Args:
            request: Callback request; its ``state`` query param keys the cache.
            generic_include_client_id: Whether to send client_id in the token body.

        Returns:
            dict: Token exchange parameters — always ``include_client_id``;
            plus ``code_verifier`` and ``_pkce_cache_key`` when PKCE applies.

        Raises:
            ProxyException: via ``_handle_missing_pkce_verifier`` when the
                verifier is missing/unusable and strict mode (or a corrupt
                entry) demands failure.
        """
        # Prepare token exchange parameters (may add code_verifier: str later)
        token_params: Dict[str, Any] = {"include_client_id": generic_include_client_id}
        # Retrieve PKCE code_verifier if PKCE was used in authorization.
        # Gate on GENERIC_CLIENT_USE_PKCE to avoid an unnecessary Redis round-trip
        # on every non-PKCE SSO callback.
        query_params = dict(request.query_params)
        state = query_params.get("state")
        use_pkce = os.getenv("GENERIC_CLIENT_USE_PKCE", "false").lower() == "true"
        if use_pkce and not state:
            verbose_proxy_logger.warning(
                "PKCE is enabled (GENERIC_CLIENT_USE_PKCE=true) but no 'state' parameter "
                "was found in the callback. The PKCE verifier cannot be retrieved without "
                "a state value — the token exchange will proceed without code_verifier, "
                "which the provider may reject. Ensure your OAuth provider returns 'state' "
                "in the callback redirect."
            )
        if state and use_pkce:
            from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache

            # Prefer Redis so a callback landing on a different pod can still
            # find the verifier stored by the authorizing pod.
            cache_key = f"pkce_verifier:{state}"
            if redis_usage_cache is not None:
                cached_data = await redis_usage_cache.async_get_cache(key=cache_key)
            else:
                cached_data = await user_api_key_cache.async_get_cache(key=cache_key)
            code_verifier = None
            # Track why code_verifier is absent for accurate strict-mode diagnostics.
            _empty_value_in_dict = False  # dict format correct but value is empty/null
            if cached_data:
                # Extract code_verifier from dict (stored as dict for JSON serialization)
                if isinstance(cached_data, dict) and "code_verifier" in cached_data:
                    code_verifier = cached_data["code_verifier"]
                    if not code_verifier:
                        # Dict format is correct but value is empty or null. This is
                        # a distinct case from an unrecognized format — the entry exists
                        # but was stored with an empty/null verifier (data integrity issue).
                        _empty_value_in_dict = True
                        verbose_proxy_logger.warning(
                            "PKCE verifier dict for state '%s' has an empty/null code_verifier "
                            "value — may indicate a storage bug. Treating as a cache miss.",
                            state,
                        )
                    else:
                        verbose_proxy_logger.debug("PKCE code_verifier retrieved from cache")
                elif isinstance(cached_data, str):
                    # Handle legacy format (plain string) for backward compatibility
                    code_verifier = cached_data
                    verbose_proxy_logger.warning(
                        "Retrieved code_verifier in legacy plain-string format. "
                        "Future storage will use dict format."
                    )
                else:
                    # Defer the detailed ERROR log to the strict-mode branch below
                    # (which includes state and a diagnostic message). Log at DEBUG
                    # here to avoid duplicate ERROR entries in the same request.
                    verbose_proxy_logger.debug(
                        "Unexpected PKCE verifier cache format (type=%s); skipping.",
                        type(cached_data).__name__,
                    )
            if code_verifier:
                # Add code_verifier to token exchange parameters.
                token_params["code_verifier"] = code_verifier
                # Return the cache key so the caller can delete it *after* a
                # successful token exchange (avoids losing the verifier on retry
                # if the exchange fails partway through).
                token_params["_pkce_cache_key"] = cache_key
            else:
                await SSOAuthenticationHandler._handle_missing_pkce_verifier(
                    state=state,
                    cache_key=cache_key,
                    cached_data=cached_data,
                    empty_value_in_dict=_empty_value_in_dict,
                    redis_usage_cache=redis_usage_cache,
                    user_api_key_cache=user_api_key_cache,
                )
        return token_params
    @staticmethod
    async def _handle_missing_pkce_verifier(
        state: Optional[str],
        cache_key: str,
        cached_data: object,
        empty_value_in_dict: bool,
        redis_usage_cache: object,
        user_api_key_cache: object,
    ) -> None:
        """Handle the case where PKCE verifier could not be extracted from cache.

        In strict mode (PKCE_STRICT_CACHE_MISS=true) raises ProxyException.
        Otherwise logs a warning and returns (token exchange proceeds without verifier).

        Args:
            state: OAuth ``state`` from the callback (keys the verifier entry).
            cache_key: Full cache key (``pkce_verifier:{state}``).
            cached_data: Raw cache payload, or None on a genuine miss.
            empty_value_in_dict: True when the dict entry existed but its
                ``code_verifier`` value was empty/null (data-integrity case).
            redis_usage_cache: Shared Redis cache, or None when not configured.
            user_api_key_cache: In-memory fallback cache.

        Raises:
            ProxyException (401): in strict mode, for all three sub-cases
                (empty value, unrecognized format, genuine miss).
        """
        active_cache = redis_usage_cache if redis_usage_cache is not None else user_api_key_cache
        strict_cache_miss = (
            os.getenv("PKCE_STRICT_CACHE_MISS", "false").lower() == "true"
        )
        if strict_cache_miss:
            if empty_value_in_dict:
                # Entry is single-use either way — remove the bad record first.
                await SSOAuthenticationHandler._delete_pkce_verifier(cache_key)
                raise ProxyException(
                    message=(
                        f"PKCE verifier for state '{state}' was found in cache but "
                        f"has an empty or null code_verifier value — possible storage bug."
                    ),
                    type=ProxyErrorTypes.auth_error,
                    param="PKCE_CACHE_MISS",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            elif cached_data is not None:
                # Unrecognized format: neither dict-with-key nor legacy string.
                await SSOAuthenticationHandler._delete_pkce_verifier(cache_key)
                verbose_proxy_logger.error(
                    "PKCE verifier for state '%s' has an unrecognized format (type=%s); "
                    "treating as a cache miss. Investigate the cached value — it may be "
                    "a corrupt or stale entry.",
                    state,
                    type(cached_data).__name__,
                )
                raise ProxyException(
                    message=(
                        f"PKCE verifier for state '{state}' has an unrecognized format "
                        f"(type={type(cached_data).__name__}). The cached entry may be corrupt."
                    ),
                    type=ProxyErrorTypes.auth_error,
                    param="PKCE_CACHE_MISS",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            else:
                # Genuine cache miss — explain the most likely cause for this
                # deployment shape (Redis-backed vs in-memory only).
                if redis_usage_cache is not None:
                    cause = (
                        "The authorization and callback were likely handled by different "
                        "instances — the verifier was stored on one pod but not found on another."
                    )
                else:
                    cause = (
                        "The verifier may have expired (TTL), been lost on a pod restart, "
                        "or the PKCE authorization step was never completed. "
                        "Configure Redis so all proxy instances share the PKCE verifier."
                    )
                verbose_proxy_logger.error(
                    "PKCE is enabled but no verifier found in cache for state '%s'. "
                    "%s Cache type: %s.",
                    state,
                    cause,
                    type(active_cache).__name__,
                )
                raise ProxyException(
                    message=f"PKCE verifier not found in cache for state '{state}'. {cause}",
                    type=ProxyErrorTypes.auth_error,
                    param="PKCE_CACHE_MISS",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
        else:
            # Lenient mode: clean up any unusable cached entry and continue
            # the token exchange without a code_verifier.
            if cached_data is not None:
                await SSOAuthenticationHandler._delete_pkce_verifier(cache_key)
            verbose_proxy_logger.warning(
                "PKCE is enabled but verifier not found in cache for state '%s' "
                "(cache type: %s, raw data present: %s). "
                "Continuing without code_verifier — set PKCE_STRICT_CACHE_MISS=true to fail fast instead.",
                state,
                type(active_cache).__name__,
                cached_data is not None,
            )
@staticmethod
async def _delete_pkce_verifier(cache_key: str) -> None:
"""Delete a single-use PKCE verifier from cache after a successful exchange.
Failure is non-fatal: a leftover verifier is a minor security concern
(unused key in cache) but not worth aborting an otherwise-successful login.
"""
from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache
try:
if redis_usage_cache is not None:
await redis_usage_cache.async_delete_cache(key=cache_key)
else:
await user_api_key_cache.async_delete_cache(key=cache_key)
except Exception as exc:
verbose_proxy_logger.warning(
"PKCE: failed to delete verifier cache key '%s' (best-effort cleanup): %s",
cache_key,
exc,
)
@staticmethod
def generate_pkce_params() -> Tuple[str, str]:
"""
Generate PKCE (Proof Key for Code Exchange) parameters for OAuth 2.0.
Returns:
Tuple[str, str]: (code_verifier, code_challenge)
- code_verifier: Random 43-128 character string (we use 43 for efficiency)
- code_challenge: Base64-URL-encoded SHA256 hash of the code_verifier
Reference: https://datatracker.ietf.org/doc/html/rfc7636
"""
# Generate a cryptographically random code_verifier (43 characters)
# Using 32 random bytes which becomes 43 characters when base64-url-encoded
code_verifier = (
base64.urlsafe_b64encode(secrets.token_bytes(32))
.decode("utf-8")
.rstrip("=")
)
# Generate code_challenge using S256 method (SHA256)
code_challenge_bytes = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = (
base64.urlsafe_b64encode(code_challenge_bytes).decode("utf-8").rstrip("=")
)
return code_verifier, code_challenge
@staticmethod
def _validate_token_response(response: "httpx.Response") -> dict:
"""
Parse and validate the token endpoint response.
Ensures the response is valid JSON, a dict, and contains a non-null
access_token string. Raises ProxyException on any validation failure.
"""
try:
token_response_raw = response.json()
except Exception as json_err:
verbose_proxy_logger.error(
"Failed to parse token response as JSON: %s. Body: %s",
json_err,
response.text[:500],
)
raise ProxyException(
message=f"Token endpoint returned invalid JSON: {json_err}",
type=ProxyErrorTypes.auth_error,
param="token_exchange",
code=status.HTTP_401_UNAUTHORIZED,
)
if not isinstance(token_response_raw, dict):
verbose_proxy_logger.error(
"Token endpoint returned non-dict JSON (type=%s). Body: %s",
type(token_response_raw).__name__,
response.text[:500],
)
raise ProxyException(
message=(
f"Token endpoint returned unexpected response format "
f"(expected JSON object, got {type(token_response_raw).__name__})"
),
type=ProxyErrorTypes.auth_error,
param="token_exchange",
code=status.HTTP_401_UNAUTHORIZED,
)
token_response: dict = token_response_raw
access_token_val = token_response.get("access_token")
if not isinstance(access_token_val, str) or not access_token_val:
error = token_response.get("error")
error_desc = token_response.get("error_description", "")
if error:
detail = f"{error} - {error_desc}" if error_desc else error
else:
detail = (
"token endpoint returned HTTP 200 but no access_token "
f"(response keys: {sorted(token_response.keys())})"
)
verbose_proxy_logger.error(
"Token response missing or null access_token. detail=%s", detail
)
raise ProxyException(
message=f"Token exchange failed: {detail}",
type=ProxyErrorTypes.auth_error,
param="token_exchange",
code=status.HTTP_401_UNAUTHORIZED,
)
return token_response
    @staticmethod
    async def _pkce_token_exchange(
        authorization_code: str,
        code_verifier: str,
        client_id: str,
        client_secret: Optional[str],
        token_endpoint: str,
        userinfo_endpoint: Optional[str],
        include_client_id: bool,
        redirect_url: Optional[str],
        additional_headers: Dict[str, str],
    ) -> dict:
        """
        Performs a direct OAuth token exchange including the PKCE code_verifier.

        fastapi-sso does not forward code_verifier, so when PKCE is enabled we
        bypass it and call the token endpoint ourselves, then fetch user info.

        Returns a combined dict of the token response and user info, suitable
        for passing to a response_convertor.

        Args:
            authorization_code: ``code`` from the provider callback.
            code_verifier: PKCE verifier matching the earlier code_challenge.
            client_id: OAuth client id.
            client_secret: OAuth client secret; None for public PKCE clients.
            token_endpoint: Provider token URL.
            userinfo_endpoint: Optional userinfo URL (id_token used as fallback).
            include_client_id: Send client credentials in the POST body instead
                of the HTTP Basic ``Authorization`` header.
            redirect_url: redirect_uri used during authorization, if any.
            additional_headers: Extra headers forwarded to provider requests.

        Raises:
            ProxyException (401): network-level failure, non-200 status, or an
                invalid token response body.
        """
        verbose_proxy_logger.debug(
            "PKCE: performing direct token exchange (code_verifier length=%d)",
            len(code_verifier),
        )
        token_data: Dict[str, str] = {
            "grant_type": "authorization_code",
            "code": authorization_code,
            "code_verifier": code_verifier,
        }
        # Only include redirect_uri when set — omitting it avoids sending the
        # literal string "None" to the provider if the env var is missing.
        if redirect_url:
            token_data["redirect_uri"] = redirect_url
        request_headers = {
            **additional_headers,
            "Content-Type": "application/x-www-form-urlencoded",  # must not be overridden
            "Accept": "application/json",
        }
        if not include_client_id:
            # Use Basic Auth only when a secret is available; public PKCE clients omit it.
            if client_secret:
                credentials = base64.b64encode(
                    f"{client_id}:{client_secret}".encode()
                ).decode()
                request_headers["Authorization"] = f"Basic {credentials}"
            else:
                token_data["client_id"] = client_id
        else:
            # Body-based client authentication (client_secret only when present).
            token_data["client_id"] = client_id
            if client_secret:
                token_data["client_secret"] = client_secret
        http_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SSO_HANDLER
        )
        try:
            response = await http_client.post(
                url=token_endpoint,
                data=token_data,
                headers=request_headers,
                timeout=30.0,
            )
        except Exception as exc:
            # Catch network-level errors (SSL, DNS, TCP, timeout, etc.) and
            # wrap them as a clean ProxyException rather than leaking raw
            # httpx or OS exceptions to callers.
            verbose_proxy_logger.error("PKCE token endpoint unreachable: %s", exc)
            raise ProxyException(
                message=f"Token endpoint request failed: {exc}",
                type=ProxyErrorTypes.auth_error,
                param="token_exchange",
                code=status.HTTP_401_UNAUTHORIZED,
            ) from exc
        if response.status_code != 200:
            verbose_proxy_logger.error(
                "PKCE token exchange failed. status=%s body=%s",
                response.status_code,
                response.text[:500],
            )
            raise ProxyException(
                message=f"Token exchange failed: {response.status_code} - {response.text[:500]}",
                type=ProxyErrorTypes.auth_error,
                param="token_exchange",
                code=status.HTTP_401_UNAUTHORIZED,
            )
        token_response = SSOAuthenticationHandler._validate_token_response(response)
        verbose_proxy_logger.debug(
            "PKCE token exchange successful. id_token_present=%s",
            bool(token_response.get("id_token")),
        )
        # Bearer credentials (access_token, id_token, refresh_token) are always sourced
        # from token_response — not from userinfo — in the merge step below.
        userinfo = await SSOAuthenticationHandler._get_pkce_userinfo(
            access_token=token_response["access_token"],
            id_token=token_response.get("id_token"),
            userinfo_endpoint=userinfo_endpoint,
            additional_headers=additional_headers,
        )
        # Merge: userinfo takes precedence for identity claims (sub, email, name, …) per
        # the OpenID Connect spec (userinfo is the authoritative source for identity).
        # Bearer credentials (access_token, id_token, refresh_token) from the token endpoint
        # take precedence over same-named fields in userinfo — non-standard providers sometimes
        # include token fields in userinfo, which must not shadow the real bearer token.
        # If a bearer field is absent from the token response, any userinfo-provided value
        # is preserved as a fallback (useful for non-standard providers that omit id_token
        # from the token response but include it in userinfo).
        #
        # Three-way merge semantics for each bearer-credential field:
        #   1. token_response has a non-null value → use it (token endpoint is authoritative)
        #   2. token_response explicitly sent null → remove the key so callers get a clean
        #      absence signal; the null from the token endpoint overrides userinfo too
        #   3. field absent from token_response → leave whatever userinfo provided as-is
        #      (e.g. userinfo-provided id_token from a non-standard provider)
        merged = {**token_response, **userinfo}
        for field in _OAUTH_TOKEN_FIELDS:
            if token_response.get(field) is not None:
                # Case 1: non-null in token_response — restore authoritative value.
                merged[field] = token_response[field]
            elif field in token_response:
                # Case 2: key exists but value is explicitly null — remove from merged.
                merged.pop(field, None)
            # Case 3: field absent from token_response — leave userinfo value as-is.
        return merged
@staticmethod
async def _get_pkce_userinfo(
access_token: str,
id_token: Optional[str],
userinfo_endpoint: Optional[str],
additional_headers: Dict[str, str],
) -> dict:
"""
Fetches user info from the userinfo endpoint.
Falls back to decoding the id_token if the endpoint is unavailable.
"""
# None = request not yet attempted, failed, or returned empty/null (treated as failure
# so the id_token fallback can be attempted instead of returning a session with no claims).
userinfo: Optional[dict] = None
if userinfo_endpoint:
try:
client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SSO_HANDLER
)
resp = await client.get(
url=userinfo_endpoint,
headers={
**additional_headers,
"Authorization": f"Bearer {access_token}", # must not be overridden
},
)
if resp.status_code == 200:
try:
userinfo_raw = resp.json()
if not userinfo_raw:
# JSON null (None) or empty dict ({}) — no identity claims.
# Treat as failure so id_token fallback can be attempted.
verbose_proxy_logger.warning(
"Userinfo endpoint returned an empty or null response "
"(type=%s); treating as failure and attempting id_token fallback. "
"Check your provider's userinfo endpoint configuration.",
type(userinfo_raw).__name__,
)
userinfo = None
else:
userinfo = userinfo_raw
except Exception as json_err:
verbose_proxy_logger.warning(
"Userinfo endpoint returned non-JSON response (status 200): %s",
json_err,
)
else:
verbose_proxy_logger.warning(
"Userinfo endpoint returned %s (body: %s), falling back to id_token",
resp.status_code,
resp.text[:500],
)
except Exception as e:
verbose_proxy_logger.warning(
"Userinfo endpoint error: %s, falling back to id_token", e
)
# Only fall back to id_token when the userinfo request failed (None).
# Empty dict ({}) and JSON null are both treated as failure (set to None above) since
# they contain no identity claims — id_token fallback is attempted in that case too.
# Explicitly check for a non-empty string to avoid attempting JWT decode on
# a blank or non-string id_token field from a misbehaving provider.
if userinfo is None and isinstance(id_token, str) and id_token:
try:
userinfo = jwt.decode(id_token, options={"verify_signature": False})
if not userinfo:
# jwt.decode returned an empty dict (payload-free JWT or provider bug).
# Treat this the same as a missing userinfo — the session would have no
# identity claims, which is equivalent to a broken session.
verbose_proxy_logger.warning(
"id_token decoded to an empty payload — treating as failure."
)
userinfo = None
except Exception as decode_err:
verbose_proxy_logger.error("Failed to decode id_token: %s", decode_err)
raise ProxyException(
message=f"Failed to decode id_token JWT: {decode_err}",
type=ProxyErrorTypes.auth_error,
param="userinfo",
code=status.HTTP_401_UNAUTHORIZED,
)
if userinfo is None:
id_token_attempted = isinstance(id_token, str) and bool(id_token)
if userinfo_endpoint:
if id_token_attempted:
detail = (
"userinfo endpoint failed and id_token was present but "
"decoded to an empty payload — no identity claims available"
)
else:
detail = "userinfo endpoint failed and no id_token was present in the token response"
else:
if id_token_attempted:
detail = (
"no userinfo endpoint is configured (GENERIC_USERINFO_ENDPOINT) "
"and id_token decoded to an empty payload — no identity claims available"
)
else:
detail = "no userinfo endpoint is configured (GENERIC_USERINFO_ENDPOINT) and no id_token was present"
raise ProxyException(
message=f"SSO user info unavailable: {detail}.",
type=ProxyErrorTypes.auth_error,
param="userinfo",
code=status.HTTP_401_UNAUTHORIZED,
)
return userinfo
class MicrosoftSSOHandler:
    """
    Handles Microsoft SSO callback response and returns a CustomOpenID object
    """

    graph_api_base_url = "https://graph.microsoft.com/v1.0"
    graph_api_user_groups_endpoint = f"{graph_api_base_url}/me/memberOf"

    """
    Constants
    """
    # Upper bound on Graph API pagination when walking a user's group
    # membership; prevents an unbounded loop on a bad/looping @odata.nextLink.
    MAX_GRAPH_API_PAGES = 200

    # used for debugging to show the user groups litellm found from Graph API
    GRAPH_API_RESPONSE_KEY = "graph_api_user_groups"

    @staticmethod
    async def get_microsoft_callback_response(
        request: Request,
        microsoft_client_id: str,
        redirect_url: str,
        return_raw_sso_response: bool = False,
    ) -> Union[CustomOpenID, OpenID, dict]:
        """
        Get the Microsoft SSO callback response.

        Verifies the callback with Entra ID, resolves the user's group
        membership via the Graph API, and extracts app roles from the id_token.

        Args:
            request: Incoming callback request from Microsoft.
            microsoft_client_id: Azure AD application (client) ID.
            redirect_url: Redirect URI registered for the application.
            return_raw_sso_response: If True, return the raw SSO response
                (augmented with the resolved groups and app roles) for debugging.

        Raises:
            ProxyException: 500 when MICROSOFT_CLIENT_SECRET or
                MICROSOFT_TENANT is not set in the environment.
        """
        microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None)
        microsoft_tenant = os.getenv("MICROSOFT_TENANT", None)
        if microsoft_client_secret is None:
            raise ProxyException(
                message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file",
                type=ProxyErrorTypes.auth_error,
                param="MICROSOFT_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        if microsoft_tenant is None:
            raise ProxyException(
                message="MICROSOFT_TENANT not set. Set it in .env file",
                type=ProxyErrorTypes.auth_error,
                param="MICROSOFT_TENANT",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        microsoft_sso = CustomMicrosoftSSO(
            client_id=microsoft_client_id,
            client_secret=microsoft_client_secret,
            tenant=microsoft_tenant,
            redirect_uri=redirect_url,
            allow_insecure_http=True,
        )
        # convert_response=False keeps the provider's raw claims dict so extra
        # attributes (e.g. "mail", display name) remain accessible below.
        original_msft_result = (
            await microsoft_sso.verify_and_process(
                request=request,
                convert_response=False,  # type: ignore
            )
            or {}
        )
        user_team_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token=microsoft_sso.access_token
        )

        # Extract app roles from the id_token JWT
        app_roles = MicrosoftSSOHandler.get_app_roles_from_id_token(
            id_token=microsoft_sso.id_token
        )
        verbose_proxy_logger.debug(f"Extracted app roles from id_token: {app_roles}")

        # Map the FIRST app role that matches a LitellmUserRoles value to the
        # user's proxy role; remaining roles are ignored.
        user_role: Optional[LitellmUserRoles] = None
        if app_roles:
            # Check if any app role is a valid LitellmUserRoles
            for role_str in app_roles:
                role = get_litellm_user_role(role_str)
                if role is not None:
                    user_role = role
                    verbose_proxy_logger.debug(
                        f"Found valid LitellmUserRoles '{role.value}' in app_roles"
                    )
                    break

        verbose_proxy_logger.debug(
            f"Combined team_ids (groups + app roles): {user_team_ids}"
        )

        # if user is trying to get the raw sso response for debugging, return the raw sso response
        if return_raw_sso_response:
            original_msft_result[
                MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY
            ] = user_team_ids
            original_msft_result["app_roles"] = app_roles
            return original_msft_result or {}

        result = MicrosoftSSOHandler.openid_from_response(
            response=original_msft_result,
            team_ids=user_team_ids,
            user_role=user_role,
        )
        return result

    @staticmethod
    def openid_from_response(
        response: Optional[dict],
        team_ids: List[str],
        user_role: Optional[LitellmUserRoles],
    ) -> CustomOpenID:
        """
        Build a CustomOpenID object from the raw Microsoft claims dict.

        Args:
            response: Raw claims dict from Microsoft (None is treated as {}).
            team_ids: Group IDs resolved via the Graph API.
            user_role: Proxy role derived from id_token app roles, if any.
        """
        response = response or {}
        verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
        openid_response = CustomOpenID(
            # Prefer the configured email attribute; fall back to the Graph
            # "mail" claim when the configured attribute is missing/empty.
            email=normalize_email(
                response.get(MICROSOFT_USER_EMAIL_ATTRIBUTE) or response.get("mail")
            ),
            display_name=response.get(MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE),
            provider="microsoft",
            id=response.get(MICROSOFT_USER_ID_ATTRIBUTE),
            first_name=response.get(MICROSOFT_USER_FIRST_NAME_ATTRIBUTE),
            last_name=response.get(MICROSOFT_USER_LAST_NAME_ATTRIBUTE),
            team_ids=team_ids,
            user_role=user_role,
        )
        verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
        return openid_response

    @staticmethod
    def get_app_roles_from_id_token(id_token: Optional[str]) -> List[str]:
        """
        Extract app roles from the Microsoft Entra ID (Azure AD) id_token JWT.

        App roles are assigned in the Azure AD Enterprise Application and appear
        in the 'app_roles' claim of the id_token.

        Args:
            id_token (Optional[str]): The JWT id_token from Microsoft SSO

        Returns:
            List[str]: List of app role names assigned to the user.
                Returns [] on any decode error (never raises).
        """
        if not id_token:
            verbose_proxy_logger.debug("No id_token provided for app role extraction")
            return []

        try:
            import jwt

            # Decode the JWT without signature verification
            # (signature is already verified by fastapi_sso)
            decoded_token = jwt.decode(id_token, options={"verify_signature": False})

            # Extract app_roles claim from the token
            ## check for both 'roles' and 'app_roles' claims
            roles = decoded_token.get("app_roles", []) or decoded_token.get("roles", [])

            if roles and isinstance(roles, list):
                verbose_proxy_logger.debug(
                    f"Found {len(roles)} app role(s) in id_token: {roles}"
                )
                return roles
            else:
                verbose_proxy_logger.debug(
                    "No app roles found in id_token or roles claim is not a list"
                )
                return []
        except Exception as e:
            verbose_proxy_logger.error(f"Error extracting app roles from id_token: {e}")
            return []

    @staticmethod
    async def get_user_groups_from_graph_api(
        access_token: Optional[str] = None,
    ) -> List[str]:
        """
        Returns a list of `team_ids` the user belongs to from the Microsoft Graph API

        Args:
            access_token (Optional[str]): Microsoft Graph API access token

        Returns:
            List[str]: List of group IDs the user belongs to.
                Returns [] on any error (never raises).
        """
        try:
            async_client = get_async_httpx_client(
                llm_provider=httpxSpecialProvider.SSO_HANDLER
            )

            # Handle MSFT Enterprise Application Groups
            service_principal_id = os.getenv("MICROSOFT_SERVICE_PRINCIPAL_ID", None)
            service_principal_group_ids: Optional[List[str]] = []
            service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = []
            if service_principal_id:
                (
                    service_principal_group_ids,
                    service_principal_teams,
                ) = await MicrosoftSSOHandler.get_group_ids_from_service_principal(
                    service_principal_id=service_principal_id,
                    async_client=async_client,
                    access_token=access_token,
                )
                verbose_proxy_logger.debug(
                    f"Service principal group IDs: {service_principal_group_ids}"
                )
                if len(service_principal_group_ids) > 0:
                    # Auto-provision LiteLLM teams for the enterprise app's groups.
                    await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
                        service_principal_teams=service_principal_teams,
                    )

            # Fetch user membership from Microsoft Graph API, following
            # @odata.nextLink pagination up to MAX_GRAPH_API_PAGES pages.
            all_group_ids = []
            next_link: Optional[
                str
            ] = MicrosoftSSOHandler.graph_api_user_groups_endpoint
            auth_headers = {"Authorization": f"Bearer {access_token}"}
            page_count = 0

            while (
                next_link is not None
                and page_count < MicrosoftSSOHandler.MAX_GRAPH_API_PAGES
            ):
                group_ids, next_link = await MicrosoftSSOHandler.fetch_and_parse_groups(
                    url=next_link, headers=auth_headers, async_client=async_client
                )
                all_group_ids.extend(group_ids)
                page_count += 1

            if (
                next_link is not None
                and page_count >= MicrosoftSSOHandler.MAX_GRAPH_API_PAGES
            ):
                verbose_proxy_logger.warning(
                    f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included."
                )

            # If service_principal_group_ids is not empty, only return group_ids that are in both all_group_ids and service_principal_group_ids
            if service_principal_group_ids and len(service_principal_group_ids) > 0:
                all_group_ids = [
                    group_id
                    for group_id in all_group_ids
                    if group_id in service_principal_group_ids
                ]

            return all_group_ids

        except Exception as e:
            verbose_proxy_logger.error(
                f"Error getting user groups from Microsoft Graph API: {e}"
            )
            return []

    @staticmethod
    async def fetch_and_parse_groups(
        url: str, headers: dict, async_client: AsyncHTTPHandler
    ) -> Tuple[List[str], Optional[str]]:
        """Helper function to fetch and parse group data from a URL.

        Returns:
            Tuple of (group IDs on this page, next-page URL or None).
        """
        response = await async_client.get(url, headers=headers)
        response_json = response.json()
        response_typed = await MicrosoftSSOHandler._cast_graph_api_response_dict(
            response=response_json
        )
        group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(
            response=response_typed
        )
        return group_ids, response_typed.get("odata_nextLink")

    @staticmethod
    def _get_group_ids_from_graph_api_response(
        response: MicrosoftGraphAPIUserGroupResponse,
    ) -> List[str]:
        """Collect the non-null directory-object IDs from one Graph API page."""
        group_ids = []
        for _object in response.get("value", []) or []:
            _group_id = _object.get("id")
            if _group_id is not None:
                group_ids.append(_group_id)
        return group_ids

    @staticmethod
    async def _cast_graph_api_response_dict(
        response: dict,
    ) -> MicrosoftGraphAPIUserGroupResponse:
        """Convert a raw Graph API JSON dict (with '@odata.*' keys) into the
        typed MicrosoftGraphAPIUserGroupResponse shape used internally."""
        directory_objects: List[MicrosoftGraphAPIUserGroupDirectoryObject] = []
        for _object in response.get("value", []):
            directory_objects.append(
                MicrosoftGraphAPIUserGroupDirectoryObject(
                    odata_type=_object.get("@odata.type"),
                    id=_object.get("id"),
                    deletedDateTime=_object.get("deletedDateTime"),
                    description=_object.get("description"),
                    displayName=_object.get("displayName"),
                    roleTemplateId=_object.get("roleTemplateId"),
                )
            )
        return MicrosoftGraphAPIUserGroupResponse(
            odata_context=response.get("@odata.context"),
            odata_nextLink=response.get("@odata.nextLink"),
            value=directory_objects,
        )

    @staticmethod
    async def get_group_ids_from_service_principal(
        service_principal_id: str,
        async_client: AsyncHTTPHandler,
        access_token: Optional[str] = None,
    ) -> Tuple[List[str], List[MicrosoftServicePrincipalTeam]]:
        """
        Gets the groups belonging to the Service Principal Application

        Service Principal Id is an `Enterprise Application` in Azure AD

        Users use Enterprise Applications to manage Groups and Users on Microsoft Entra ID

        Returns:
            Tuple of (group IDs assigned to the app, typed team entries with
            display names for team creation).
        """
        base_url = "https://graph.microsoft.com/v1.0"
        # Endpoint to get app role assignments for the given service principal
        endpoint = f"/servicePrincipals/{service_principal_id}/appRoleAssignedTo"
        url = base_url + endpoint

        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

        response = await async_client.get(url, headers=headers)
        response_json = response.json()
        verbose_proxy_logger.debug(
            f"Response from service principal app role assigned to: {response_json}"
        )
        group_ids: List[str] = []
        service_principal_teams: List[MicrosoftServicePrincipalTeam] = []
        for _object in response_json.get("value", []):
            # Only group assignments matter; user/app assignments are skipped.
            if _object.get("principalType") == "Group":
                # Append the group ID to the list
                group_ids.append(_object.get("principalId"))
                # Append the service principal team to the list
                service_principal_teams.append(
                    MicrosoftServicePrincipalTeam(
                        principalDisplayName=_object.get("principalDisplayName"),
                        principalId=_object.get("principalId"),
                    )
                )

        return group_ids, service_principal_teams

    @staticmethod
    async def create_litellm_teams_from_service_principal_team_ids(
        service_principal_teams: List[MicrosoftServicePrincipalTeam],
    ):
        """
        Creates Litellm Teams from the Service Principal Group IDs

        When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
        """
        verbose_proxy_logger.debug(
            f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
        )
        for service_principal_team in service_principal_teams:
            litellm_team_id: Optional[str] = service_principal_team.get("principalId")
            litellm_team_name: Optional[str] = service_principal_team.get(
                "principalDisplayName"
            )
            if not litellm_team_id:
                verbose_proxy_logger.debug(
                    f"Skipping team creation for {litellm_team_name} because it has no principalId"
                )
                continue
            await SSOAuthenticationHandler.create_litellm_team_from_sso_group(
                litellm_team_id=litellm_team_id,
                litellm_team_name=litellm_team_name,
            )
class GoogleSSOHandler:
    """
    Handles Google SSO callback response and returns a CustomOpenID object
    """

    @staticmethod
    async def get_google_callback_response(
        request: Request,
        google_client_id: str,
        redirect_url: str,
        return_raw_sso_response: bool = False,
    ) -> Union[OpenID, dict]:
        """
        Get the Google SSO callback response.

        Args:
            request: Incoming callback request from Google.
            google_client_id: Google OAuth client ID.
            redirect_url: Redirect URI registered for the OAuth client.
            return_raw_sso_response: If True, return the raw SSO response
                (for debugging) instead of the converted OpenID object.

        Raises:
            ProxyException: 500 when GOOGLE_CLIENT_SECRET is not set.
        """
        from fastapi_sso.sso.google import GoogleSSO

        client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None)
        if client_secret is None:
            raise ProxyException(
                message="GOOGLE_CLIENT_SECRET not set. Set it in .env file",
                type=ProxyErrorTypes.auth_error,
                param="GOOGLE_CLIENT_SECRET",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        sso_client = GoogleSSO(
            client_id=google_client_id,
            redirect_uri=redirect_url,
            client_secret=client_secret,
        )

        # Debug path: hand back the provider's raw payload untouched.
        if return_raw_sso_response:
            raw_result = await sso_client.verify_and_process(
                request=request,
                convert_response=False,  # type: ignore
            )
            return raw_result or {}

        processed = await sso_client.verify_and_process(request)
        return processed or {}
@router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False)
async def debug_sso_login(request: Request):
    """
    Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
    PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
    Example:
    """
    from litellm.proxy.proxy_server import premium_user

    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

    ####### Check if user is a Enterprise / Premium User #######
    # SSO is an enterprise feature: having any SSO client id configured
    # without a premium license is a hard 403.
    sso_configured = any(
        client_id is not None
        for client_id in (microsoft_client_id, google_client_id, generic_client_id)
    )
    if sso_configured and premium_user is not True:
        raise ProxyException(
            message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/cx9p-5yf-2nm/litellm-introductions You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
            type=ProxyErrorTypes.auth_error,
            param="premium_user",
            code=status.HTTP_403_FORBIDDEN,
        )

    # get url from request
    redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
        request=request,
        sso_callback_route="sso/debug/callback",
    )

    # Check if we should use SSO handler
    use_sso_handler = SSOAuthenticationHandler.should_use_sso_handler(
        microsoft_client_id=microsoft_client_id,
        google_client_id=google_client_id,
        generic_client_id=generic_client_id,
    )
    if use_sso_handler is True:
        return await SSOAuthenticationHandler.get_sso_login_redirect(
            redirect_url=redirect_url,
            microsoft_client_id=microsoft_client_id,
            google_client_id=google_client_id,
            generic_client_id=generic_client_id,
        )
    # NOTE(review): when no SSO client is configured this implicitly returns
    # None (HTTP 200 with a null body) — presumably intentional for a debug
    # route; confirm before changing.
@router.get("/sso/debug/callback", tags=["experimental"], include_in_schema=False)
async def debug_sso_callback(request: Request):
    """
    Returns the OpenID object returned by the SSO provider.

    Dispatches to the Google / Microsoft / generic handler (in that priority
    order, based on which client-id env var is set), requests the raw SSO
    response, and renders it into an HTML debug page. Returns a 400 HTML
    page when the provider returned no data.
    """
    import json

    from fastapi.responses import HTMLResponse

    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.auth.handle_jwt import JWTHandler
    from litellm.proxy.proxy_server import (
        general_settings,
        jwt_handler,
        prisma_client,
        user_api_key_cache,
    )

    # Build a dedicated JWT handler only when ui_access_mode is a dict —
    # it may carry an sso_group_jwt_field used to read groups from the JWT.
    sso_jwt_handler: Optional[JWTHandler] = None
    ui_access_mode = general_settings.get("ui_access_mode", None)
    if ui_access_mode is not None and isinstance(ui_access_mode, dict):
        sso_jwt_handler = JWTHandler()
        sso_jwt_handler.update_environment(
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            litellm_jwtauth=LiteLLM_JWTAuth(
                team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get(
                    "sso_group_jwt_field", None
                ),
            ),
            leeway=0,
        )

    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

    # Prefer PROXY_BASE_URL; fall back to the request's own base URL.
    redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
    if redirect_url.endswith("/"):
        redirect_url += "sso/debug/callback"
    else:
        redirect_url += "/sso/debug/callback"

    # Provider precedence: Google, then Microsoft, then generic.
    result = None
    if google_client_id is not None:
        result = await GoogleSSOHandler.get_google_callback_response(
            request=request,
            google_client_id=google_client_id,
            redirect_url=redirect_url,
            return_raw_sso_response=True,
        )
    elif microsoft_client_id is not None:
        result = await MicrosoftSSOHandler.get_microsoft_callback_response(
            request=request,
            microsoft_client_id=microsoft_client_id,
            redirect_url=redirect_url,
            return_raw_sso_response=True,
        )
    elif generic_client_id is not None:
        result, _ = await get_generic_sso_response(
            request=request,
            jwt_handler=jwt_handler,
            generic_client_id=generic_client_id,
            redirect_url=redirect_url,
            sso_jwt_handler=sso_jwt_handler,
        )

    # If result is None, return a basic error message
    if result is None:
        return HTMLResponse(
            content="<h1>SSO Authentication Failed</h1><p>No data was returned from the SSO provider.</p>",
            status_code=400,
        )

    # Convert the OpenID object to a dictionary
    if hasattr(result, "__dict__"):
        result_dict = result.__dict__
    else:
        result_dict = dict(result)

    # Filter out any None values and convert to JSON serializable format
    filtered_result = {}
    for key, value in result_dict.items():
        if value is not None and not key.startswith("_"):
            # NOTE(review): `value is None` below is unreachable (filtered by
            # the enclosing condition) — kept as-is to avoid behavior change.
            if isinstance(value, (str, int, float, bool)) or value is None:
                filtered_result[key] = value
            else:
                try:
                    # Try to convert to string or another JSON serializable format
                    filtered_result[key] = str(value)
                except Exception as e:
                    filtered_result[key] = f"Complex value (not displayable): {str(e)}"

    # Replace the placeholder in the template with the actual data
    html_content = jwt_display_template.replace(
        "const userData = SSO_DATA;",
        f"const userData = {json.dumps(filtered_result, indent=2)};",
    )
    return HTMLResponse(content=html_content)