chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Command groups for the LiteLLM proxy CLI."""

View File

@@ -0,0 +1,623 @@
import json
import os
import sys
import time
import webbrowser
from pathlib import Path
from typing import Any, Dict, List, Optional
import click
import requests
from rich.console import Console
from rich.table import Table
from litellm.constants import CLI_JWT_EXPIRATION_HOURS
# Token storage utilities
def get_token_file_path() -> str:
"""Get the path to store the authentication token"""
home_dir = Path.home()
config_dir = home_dir / ".litellm"
config_dir.mkdir(exist_ok=True)
return str(config_dir / "token.json")
def save_token(token_data: Dict[str, Any]) -> None:
"""Save token data to file"""
token_file = get_token_file_path()
with open(token_file, "w") as f:
json.dump(token_data, f, indent=2)
# Set file permissions to be readable only by owner
os.chmod(token_file, 0o600)
def load_token() -> Optional[Dict[str, Any]]:
"""Load token data from file"""
token_file = get_token_file_path()
if not os.path.exists(token_file):
return None
try:
with open(token_file, "r") as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
return None
def clear_token() -> None:
"""Clear stored token"""
token_file = get_token_file_path()
if os.path.exists(token_file):
os.remove(token_file)
def get_stored_api_key() -> Optional[str]:
"""Get the stored API key from token file"""
# Use the SDK-level utility
from litellm.litellm_core_utils.cli_token_utils import get_litellm_gateway_api_key
return get_litellm_gateway_api_key()
# Team selection utilities
def display_teams_table(teams: List[Dict[str, Any]]) -> None:
"""Display teams in a formatted table"""
console = Console()
if not teams:
console.print("❌ No teams found for your user.")
return
table = Table(title="Available Teams")
table.add_column("Index", style="cyan", no_wrap=True)
table.add_column("Team Alias", style="magenta")
table.add_column("Team ID", style="green")
table.add_column("Models", style="yellow")
table.add_column("Max Budget", style="blue")
for i, team in enumerate(teams):
team_alias = team.get("team_alias") or "N/A"
team_id = team.get("team_id", "N/A")
models = team.get("models", [])
max_budget = team.get("max_budget")
# Format models list
if models:
if len(models) > 3:
models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
else:
models_str = ", ".join(models)
else:
models_str = "All models"
# Format budget
budget_str = f"${max_budget}" if max_budget else "Unlimited"
table.add_row(str(i + 1), team_alias, team_id, models_str, budget_str)
console.print(table)
def get_key_input():
"""Get a single key input from the user (cross-platform)"""
try:
if sys.platform == "win32":
import msvcrt
key = msvcrt.getch()
if key == b"\xe0": # Arrow keys on Windows
key = msvcrt.getch()
if key == b"H": # Up arrow
return "up"
elif key == b"P": # Down arrow
return "down"
elif key == b"\r": # Enter key
return "enter"
elif key == b"\x1b": # Escape key
return "escape"
elif key == b"q":
return "quit"
return None
else:
import termios
import tty
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(sys.stdin.fileno())
key = sys.stdin.read(1)
if key == "\x1b": # Escape sequence
key += sys.stdin.read(2)
if key == "\x1b[A": # Up arrow
return "up"
elif key == "\x1b[B": # Down arrow
return "down"
elif key == "\x1b": # Just escape
return "escape"
elif key == "\r" or key == "\n": # Enter key
return "enter"
elif key == "q":
return "quit"
return None
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
except ImportError:
# Fallback to simple input if termios/msvcrt not available
return None
def display_interactive_team_selection(
teams: List[Dict[str, Any]], selected_index: int = 0
) -> None:
"""Display teams with one highlighted for selection"""
console = Console()
# Clear the screen using Rich's method
console.clear()
console.print("🎯 Select a Team (Use ↑↓ arrows, Enter to select, 'q' to skip):\n")
for i, team in enumerate(teams):
team_alias = team.get("team_alias") or "N/A"
team_id = team.get("team_id", "N/A")
models = team.get("models", [])
max_budget = team.get("max_budget")
# Format models list
if models:
if len(models) > 3:
models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
else:
models_str = ", ".join(models)
else:
models_str = "All models"
# Format budget
budget_str = f"${max_budget}" if max_budget else "Unlimited"
# Highlight the selected item
if i == selected_index:
console.print(f"➤ [bold cyan]{team_alias}[/bold cyan] ({team_id})")
console.print(f" Models: [yellow]{models_str}[/yellow]")
console.print(f" Budget: [blue]{budget_str}[/blue]\n")
else:
console.print(f" [dim]{team_alias}[/dim] ({team_id})")
console.print(f" Models: [dim]{models_str}[/dim]")
console.print(f" Budget: [dim]{budget_str}[/dim]\n")
def prompt_team_selection(teams: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Interactive team selection with arrow keys"""
if not teams:
return None
selected_index = 0
try:
# Check if we can use interactive mode
if not sys.stdin.isatty():
# Fallback to simple selection for non-interactive environments
return prompt_team_selection_fallback(teams)
while True:
display_interactive_team_selection(teams, selected_index)
key = get_key_input()
if key == "up":
selected_index = (selected_index - 1) % len(teams)
elif key == "down":
selected_index = (selected_index + 1) % len(teams)
elif key == "enter":
selected_team = teams[selected_index]
# Clear screen and show selection
console = Console()
console.clear()
click.echo(
f"✅ Selected team: {selected_team.get('team_alias', 'N/A')} ({selected_team.get('team_id')})"
)
return selected_team
elif key == "quit" or key == "escape":
# Clear screen
console = Console()
console.clear()
click.echo(" Team selection skipped.")
return None
elif key is None:
# If we can't get key input, fall back to simple selection
return prompt_team_selection_fallback(teams)
except KeyboardInterrupt:
console = Console()
console.clear()
click.echo("\n❌ Team selection cancelled.")
return None
except Exception:
# If interactive mode fails, fall back to simple selection
return prompt_team_selection_fallback(teams)
def prompt_team_selection_fallback(
teams: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""Fallback team selection for non-interactive environments"""
if not teams:
return None
while True:
try:
choice = click.prompt(
"\nSelect a team by entering the index number (or 'skip' to continue without a team)",
type=str,
).strip()
if choice.lower() == "skip":
return None
index = int(choice) - 1
if 0 <= index < len(teams):
selected_team = teams[index]
click.echo(
f"\n✅ Selected team: {selected_team.get('team_alias', 'N/A')} ({selected_team.get('team_id')})"
)
return selected_team
else:
click.echo(
f"❌ Invalid selection. Please enter a number between 1 and {len(teams)}"
)
except ValueError:
click.echo("❌ Invalid input. Please enter a number or 'skip'")
except KeyboardInterrupt:
click.echo("\n❌ Team selection cancelled.")
return None
# Polling-based authentication - no local server needed
def _poll_for_ready_data(
url: str,
*,
total_timeout: int = 300,
poll_interval: int = 2,
request_timeout: int = 10,
pending_message: Optional[str] = None,
pending_log_every: int = 10,
other_status_message: Optional[str] = None,
other_status_log_every: int = 10,
http_error_log_every: int = 10,
connection_error_log_every: int = 10,
) -> Optional[Dict[str, Any]]:
for attempt in range(total_timeout // poll_interval):
try:
response = requests.get(url, timeout=request_timeout)
if response.status_code == 200:
data = response.json()
status = data.get("status")
if status == "ready":
return data
if status == "pending":
if (
pending_message
and pending_log_every > 0
and attempt % pending_log_every == 0
):
click.echo(pending_message)
elif (
other_status_message
and other_status_log_every > 0
and attempt % other_status_log_every == 0
):
click.echo(other_status_message)
elif http_error_log_every > 0 and attempt % http_error_log_every == 0:
click.echo(f"Polling error: HTTP {response.status_code}")
except requests.RequestException as e:
if (
connection_error_log_every > 0
and attempt % connection_error_log_every == 0
):
click.echo(f"Connection error (will retry): {e}")
time.sleep(poll_interval)
return None
def _normalize_teams(teams, team_details):
"""If team_details are a
Args:
teams (_type_): _description_
team_details (_type_): _description_
Returns:
_type_: _description_
"""
if isinstance(team_details, list) and team_details:
return [
{
"team_id": i.get("team_id") or i.get("id"),
"team_alias": i.get("team_alias"),
}
for i in team_details
if isinstance(i, dict) and (i.get("team_id") or i.get("id"))
]
if isinstance(teams, list):
return [{"team_id": str(t), "team_alias": None} for t in teams]
return []
def _poll_for_authentication(base_url: str, key_id: str) -> Optional[dict]:
"""
Poll the server for authentication completion and handle team selection.
Returns:
Dictionary with authentication data if successful, None otherwise
"""
poll_url = f"{base_url}/sso/cli/poll/{key_id}"
data = _poll_for_ready_data(
poll_url,
pending_message="Still waiting for authentication...",
)
if not data:
return None
if data.get("requires_team_selection"):
teams = data.get("teams", [])
team_details = data.get("team_details")
user_id = data.get("user_id")
normalized_teams: List[Dict[str, Any]] = _normalize_teams(teams, team_details)
if not normalized_teams:
click.echo("⚠️ No teams available for selection.")
return None
# User has multiple teams - let them select
jwt_with_team = _handle_team_selection_during_polling(
base_url=base_url,
key_id=key_id,
teams=normalized_teams,
)
# Use the team-specific JWT if selection succeeded
if jwt_with_team:
return {
"api_key": jwt_with_team,
"user_id": user_id,
"teams": teams,
"team_id": None, # Set by server in JWT
}
click.echo("❌ Team selection cancelled or JWT generation failed.")
return None
# JWT is ready (single team or team already selected)
api_key = data.get("key")
user_id = data.get("user_id")
teams = data.get("teams", [])
team_id = data.get("team_id")
# Show which team was assigned
if team_id and len(teams) == 1:
click.echo(f"\n✅ Automatically assigned to team: {team_id}")
if api_key:
return {
"api_key": api_key,
"user_id": user_id,
"teams": teams,
"team_id": team_id,
}
return None
def _handle_team_selection_during_polling(
base_url: str, key_id: str, teams: List[Dict[str, Any]]
) -> Optional[str]:
"""
Handle team selection and re-poll with selected team_id.
Args:
teams: List of team IDs (strings)
Returns:
The JWT token with the selected team, or None if selection was skipped
"""
if not teams:
click.echo(
" No teams found. You can create or join teams using the web interface."
)
return None
click.echo("\n" + "=" * 60)
click.echo("📋 Select a team for your CLI session...")
team_id = _render_and_prompt_for_team_selection(teams)
if not team_id:
click.echo(" No team selected.")
return None
click.echo(f"\n🔄 Generating JWT for team: {team_id}")
poll_url = f"{base_url}/sso/cli/poll/{key_id}?team_id={team_id}"
data = _poll_for_ready_data(
poll_url,
pending_message="Still waiting for team authentication...",
other_status_message="Waiting for team authentication to complete...",
http_error_log_every=10,
)
if not data:
return None
jwt_token = data.get("key")
if jwt_token:
click.echo(f"✅ Successfully generated JWT for team: {team_id}")
return jwt_token
return None
def _render_and_prompt_for_team_selection(teams: List[Dict[str, Any]]) -> Optional[str]:
"""Render teams table and prompt user for a team selection.
Returns the selected team_id as a string, or None if selection was
cancelled or skipped without any teams available.
"""
# Display teams as a simple list, but prefer showing aliases where
# available while still keeping the underlying IDs intact.
console = Console()
table = Table(title="Available Teams")
table.add_column("Index", style="cyan", no_wrap=True)
table.add_column("Team Name", style="magenta")
table.add_column("Team ID", style="green")
for i, team in enumerate(teams):
team_id = str(team.get("team_id"))
team_alias = team.get("team_alias") or team_id
table.add_row(str(i + 1), team_alias, team_id)
console.print(table)
# Simple selection
while True:
try:
choice = click.prompt(
"\nSelect a team by entering the index number (or 'skip' to use first team)",
type=str,
).strip()
if choice.lower() == "skip":
# Default to the first team's ID if the user skips an
# explicit selection.
if teams:
first_team = teams[0]
return str(first_team.get("team_id"))
return None
index = int(choice) - 1
if 0 <= index < len(teams):
selected_team = teams[index]
team_id = str(selected_team.get("team_id"))
team_alias = selected_team.get("team_alias") or team_id
click.echo(f"\n✅ Selected team: {team_alias} ({team_id})")
return team_id
click.echo(
f"❌ Invalid selection. Please enter a number between 1 and {len(teams)}"
)
except ValueError:
click.echo("❌ Invalid input. Please enter a number or 'skip'")
except KeyboardInterrupt:
click.echo("\n❌ Team selection cancelled.")
return None
@click.command(name="login")
@click.pass_context
def login(ctx: click.Context):
"""Login to LiteLLM proxy using SSO authentication"""
from litellm._uuid import uuid
from litellm.constants import LITELLM_CLI_SOURCE_IDENTIFIER
from litellm.proxy.client.cli.interface import show_commands
base_url = ctx.obj["base_url"]
# Check if we have an existing key to regenerate
existing_key = get_stored_api_key()
# Generate unique key ID for this login session
key_id = f"sk-{str(uuid.uuid4())}"
try:
# Construct SSO login URL with CLI source and pre-generated key
sso_url = f"{base_url}/sso/key/generate?source={LITELLM_CLI_SOURCE_IDENTIFIER}&key={key_id}"
# If we have an existing key, include it as a parameter to the login endpoint
# The server will encode it in the OAuth state parameter for the SSO flow
if existing_key:
sso_url += f"&existing_key={existing_key}"
click.echo(f"Opening browser to: {sso_url}")
click.echo("Please complete the SSO authentication in your browser...")
click.echo(f"Session ID: {key_id}")
# Open browser
webbrowser.open(sso_url)
# Poll for authentication completion
click.echo("Waiting for authentication...")
auth_result = _poll_for_authentication(base_url=base_url, key_id=key_id)
if auth_result:
api_key = auth_result["api_key"]
user_id = auth_result["user_id"]
# Save token data (simplified for CLI - we just need the key)
save_token(
{
"key": api_key,
"user_id": user_id or "cli-user",
"user_email": "unknown",
"user_role": "cli",
"auth_header_name": "Authorization",
"jwt_token": "",
"timestamp": time.time(),
}
)
click.echo("\n✅ Login successful!")
click.echo(f"JWT Token: {api_key[:20]}...")
click.echo("You can now use the CLI without specifying --api-key")
# Show available commands after successful login
click.echo("\n" + "=" * 60)
show_commands()
return
else:
click.echo("❌ Authentication timed out. Please try again.")
return
except KeyboardInterrupt:
click.echo("\n❌ Authentication cancelled by user.")
return
except Exception as e:
click.echo(f"❌ Authentication failed: {e}")
return
@click.command(name="logout")
def logout():
"""Logout and clear stored authentication"""
clear_token()
click.echo("✅ Logged out successfully. Authentication token cleared.")
@click.command(name="whoami")
def whoami():
"""Show current authentication status"""
token_data = load_token()
if not token_data:
click.echo("❌ Not authenticated. Run 'litellm-proxy login' to authenticate.")
return
click.echo("✅ Authenticated")
click.echo(f"User Email: {token_data.get('user_email', 'Unknown')}")
click.echo(f"User ID: {token_data.get('user_id', 'Unknown')}")
click.echo(f"User Role: {token_data.get('user_role', 'Unknown')}")
# Check if token is still valid (basic timestamp check)
timestamp = token_data.get("timestamp", 0)
age_hours = (time.time() - timestamp) / 3600
click.echo(f"Token age: {age_hours:.1f} hours")
if age_hours > CLI_JWT_EXPIRATION_HOURS:
click.echo(
f"⚠️ Warning: Token is more than {CLI_JWT_EXPIRATION_HOURS} hours old and may have expired."
)
# Export functions for use by other CLI commands
__all__ = ["login", "logout", "whoami", "prompt_team_selection"]
# Export individual commands instead of grouping them
# login, logout, and whoami will be added as top-level commands

View File

@@ -0,0 +1,406 @@
import json
import sys
from typing import Any, Dict, List, Optional
import click
import requests
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt
from rich.table import Table
from ... import Client
from ...chat import ChatClient
def _get_available_models(ctx: click.Context) -> List[Dict[str, Any]]:
"""Get list of available models from the proxy server"""
try:
client = Client(base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"])
models_list = client.models.list()
# Ensure we return a list of dictionaries
if isinstance(models_list, list):
# Filter to ensure all items are dictionaries
return [model for model in models_list if isinstance(model, dict)]
return []
except Exception as e:
click.echo(f"Warning: Could not fetch models list: {e}", err=True)
return []
def _select_model(
console: Console, available_models: List[Dict[str, Any]]
) -> Optional[str]:
"""Interactive model selection"""
if not available_models:
console.print(
"[yellow]No models available or could not fetch models list.[/yellow]"
)
model_name = Prompt.ask("Please enter a model name")
return model_name if model_name.strip() else None
# Display available models in a table
table = Table(title="Available Models")
table.add_column("Index", style="cyan", no_wrap=True)
table.add_column("Model ID", style="green")
table.add_column("Owned By", style="yellow")
MAX_MODELS_TO_DISPLAY = 200
models_to_display: List[Dict[str, Any]] = available_models[:MAX_MODELS_TO_DISPLAY]
for i, model in enumerate(models_to_display): # Limit to first 200 models
table.add_row(
str(i + 1), str(model.get("id", "")), str(model.get("owned_by", ""))
)
if len(available_models) > MAX_MODELS_TO_DISPLAY:
console.print(
f"\n[dim]... and {len(available_models) - MAX_MODELS_TO_DISPLAY} more models[/dim]"
)
console.print(table)
while True:
try:
choice = Prompt.ask(
"\nSelect a model by entering the index number (or type a model name directly)",
default="1",
).strip()
# Try to parse as index
try:
index = int(choice) - 1
if 0 <= index < len(available_models):
return available_models[index]["id"]
else:
console.print(
f"[red]Invalid index. Please enter a number between 1 and {len(available_models)}[/red]"
)
continue
except ValueError:
# Not a number, treat as model name
if choice:
return choice
else:
console.print("[red]Please enter a valid model name or index[/red]")
continue
except KeyboardInterrupt:
console.print("\n[yellow]Model selection cancelled.[/yellow]")
return None
@click.command()
@click.argument("model", required=False)
@click.option(
"--temperature",
"-t",
type=float,
default=0.7,
help="Sampling temperature between 0 and 2 (default: 0.7)",
)
@click.option(
"--max-tokens",
type=int,
help="Maximum number of tokens to generate",
)
@click.option(
"--system",
"-s",
type=str,
help="System message to set the behavior of the assistant",
)
@click.pass_context
def chat(
ctx: click.Context,
model: Optional[str],
temperature: float,
max_tokens: Optional[int] = None,
system: Optional[str] = None,
):
"""Interactive chat with streaming responses
Examples:
# Chat with a specific model
litellm-proxy chat gpt-4
# Chat without specifying model (will show model selection)
litellm-proxy chat
# Chat with custom settings
litellm-proxy chat gpt-4 --temperature 0.9 --system "You are a helpful coding assistant"
"""
console = Console()
# If no model specified, show model selection
if not model:
available_models = _get_available_models(ctx)
model = _select_model(console, available_models)
if not model:
console.print("[red]No model selected. Exiting.[/red]")
return
client = ChatClient(ctx.obj["base_url"], ctx.obj["api_key"])
# Initialize conversation history
messages: List[Dict[str, Any]] = []
# Add system message if provided
if system:
messages.append({"role": "system", "content": system})
# Display welcome message
console.print(
Panel.fit(
f"[bold blue]LiteLLM Interactive Chat[/bold blue]\n"
f"Model: [green]{model}[/green]\n"
f"Temperature: [yellow]{temperature}[/yellow]\n"
f"Max Tokens: [yellow]{max_tokens or 'unlimited'}[/yellow]\n\n"
f"Type your messages and press Enter. Type '/quit' or '/exit' to end the session.\n"
f"Type '/help' for more commands.",
title="🤖 Chat Session",
)
)
try:
while True:
# Get user input
try:
user_input = console.input("\n[bold cyan]You:[/bold cyan] ").strip()
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Chat session ended.[/yellow]")
break
# Handle special commands
should_exit, messages, new_model = _handle_special_commands(
console, user_input, messages, system, ctx
)
if should_exit:
break
if new_model:
model = new_model
# Check if this was a special command that was handled (not a normal message)
if (
user_input.lower().startswith(
(
"/quit",
"/exit",
"/q",
"/help",
"/clear",
"/history",
"/save",
"/load",
"/model",
)
)
or not user_input
):
continue
# Add user message to conversation
messages.append({"role": "user", "content": user_input})
# Display assistant label
console.print("\n[bold green]Assistant:[/bold green]")
# Stream the response
assistant_content = _stream_response(
console=console,
client=client,
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
# Add assistant message to conversation history
if assistant_content:
messages.append({"role": "assistant", "content": assistant_content})
else:
console.print("[red]Error: No content received from the model[/red]")
except KeyboardInterrupt:
console.print("\n[yellow]Chat session interrupted.[/yellow]")
def _show_help(console: Console):
"""Show help for interactive chat commands"""
help_text = """
[bold]Interactive Chat Commands:[/bold]
[cyan]/help[/cyan] - Show this help message
[cyan]/quit[/cyan] - Exit the chat session (also /exit, /q)
[cyan]/clear[/cyan] - Clear conversation history
[cyan]/history[/cyan] - Show conversation history
[cyan]/model[/cyan] - Switch to a different model
[cyan]/save <name>[/cyan] - Save conversation to file
[cyan]/load <name>[/cyan] - Load conversation from file
[bold]Tips:[/bold]
- Your conversation history is maintained during the session
- Use Ctrl+C to interrupt at any time
- Responses are streamed in real-time
- You can switch models mid-conversation with /model
"""
console.print(Panel(help_text, title="Help"))
def _show_history(console: Console, messages: List[Dict[str, Any]]):
"""Show conversation history"""
if not messages:
console.print("[yellow]No conversation history.[/yellow]")
return
console.print(Panel.fit("[bold]Conversation History[/bold]", title="History"))
for i, message in enumerate(messages, 1):
role = message["role"]
content = message["content"]
if role == "system":
console.print(
f"[dim]{i}. [bold magenta]System:[/bold magenta] {content}[/dim]"
)
elif role == "user":
console.print(f"{i}. [bold cyan]You:[/bold cyan] {content}")
elif role == "assistant":
console.print(
f"{i}. [bold green]Assistant:[/bold green] {content[:100]}{'...' if len(content) > 100 else ''}"
)
def _save_conversation(console: Console, messages: List[Dict[str, Any]], command: str):
"""Save conversation to a file"""
parts = command.split()
if len(parts) < 2:
console.print("[red]Usage: /save <filename>[/red]")
return
filename = parts[1]
if not filename.endswith(".json"):
filename += ".json"
try:
with open(filename, "w") as f:
json.dump(messages, f, indent=2)
console.print(f"[green]Conversation saved to {filename}[/green]")
except Exception as e:
console.print(f"[red]Error saving conversation: {e}[/red]")
def _load_conversation(
console: Console, command: str, system: Optional[str]
) -> List[Dict[str, Any]]:
"""Load conversation from a file"""
parts = command.split()
if len(parts) < 2:
console.print("[red]Usage: /load <filename>[/red]")
return []
filename = parts[1]
if not filename.endswith(".json"):
filename += ".json"
try:
with open(filename, "r") as f:
messages = json.load(f)
console.print(f"[green]Conversation loaded from {filename}[/green]")
return messages
except FileNotFoundError:
console.print(f"[red]File not found: {filename}[/red]")
except Exception as e:
console.print(f"[red]Error loading conversation: {e}[/red]")
# Return empty list or just system message if load failed
if system:
return [{"role": "system", "content": system}]
return []
def _handle_special_commands(
console: Console,
user_input: str,
messages: List[Dict[str, Any]],
system: Optional[str],
ctx: click.Context,
) -> tuple[bool, List[Dict[str, Any]], Optional[str]]:
"""Handle special chat commands. Returns (should_exit, updated_messages, updated_model)"""
if user_input.lower() in ["/quit", "/exit", "/q"]:
console.print("[yellow]Chat session ended.[/yellow]")
return True, messages, None
elif user_input.lower() == "/help":
_show_help(console)
return False, messages, None
elif user_input.lower() == "/clear":
new_messages = []
if system:
new_messages.append({"role": "system", "content": system})
console.print("[green]Conversation history cleared.[/green]")
return False, new_messages, None
elif user_input.lower() == "/history":
_show_history(console, messages)
return False, messages, None
elif user_input.lower().startswith("/save"):
_save_conversation(console, messages, user_input)
return False, messages, None
elif user_input.lower().startswith("/load"):
new_messages = _load_conversation(console, user_input, system)
return False, new_messages, None
elif user_input.lower() == "/model":
available_models = _get_available_models(ctx)
new_model = _select_model(console, available_models)
if new_model:
console.print(f"[green]Switched to model: {new_model}[/green]")
return False, messages, new_model
return False, messages, None
elif not user_input:
return False, messages, None
# Not a special command
return False, messages, None
def _stream_response(
console: Console,
client: ChatClient,
model: str,
messages: List[Dict[str, Any]],
temperature: float,
max_tokens: Optional[int],
) -> Optional[str]:
"""Stream the model response and return the complete content"""
try:
assistant_content = ""
for chunk in client.completions_stream(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
):
if "choices" in chunk and len(chunk["choices"]) > 0:
delta = chunk["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
assistant_content += content
console.print(content, end="")
sys.stdout.flush()
console.print() # Add newline after streaming
return assistant_content if assistant_content else None
except requests.exceptions.HTTPError as e:
console.print(f"\n[red]Error: HTTP {e.response.status_code}[/red]")
try:
error_body = e.response.json()
console.print(
f"[red]{error_body.get('error', {}).get('message', 'Unknown error')}[/red]"
)
except json.JSONDecodeError:
console.print(f"[red]{e.response.text}[/red]")
return None
except Exception as e:
console.print(f"\n[red]Error: {str(e)}[/red]")
return None

View File

@@ -0,0 +1,116 @@
import json
from typing import Literal
import click
import rich
import requests
from rich.table import Table
from ...credentials import CredentialsManagementClient
@click.group()
def credentials():
"""Manage credentials for the LiteLLM proxy server"""
pass
@credentials.command()
@click.option(
"--format",
"output_format",
type=click.Choice(["table", "json"]),
default="table",
help="Output format (table or json)",
)
@click.pass_context
def list(ctx: click.Context, output_format: Literal["table", "json"]):
"""List all credentials"""
client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
response = client.list()
assert isinstance(response, dict)
if output_format == "json":
rich.print_json(data=response)
else: # table format
table = Table(title="Credentials")
# Add columns
table.add_column("Credential Name", style="cyan")
table.add_column("Custom LLM Provider", style="green")
# Add rows
for cred in response.get("credentials", []):
info = cred.get("credential_info", {})
table.add_row(
str(cred.get("credential_name", "")),
str(info.get("custom_llm_provider", "")),
)
rich.print(table)
@credentials.command()
@click.argument("credential_name")
@click.option(
"--info",
type=str,
help="JSON string containing credential info",
required=True,
)
@click.option(
"--values",
type=str,
help="JSON string containing credential values",
required=True,
)
@click.pass_context
def create(ctx: click.Context, credential_name: str, info: str, values: str):
"""Create a new credential"""
client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
try:
credential_info = json.loads(info)
credential_values = json.loads(values)
except json.JSONDecodeError as e:
raise click.BadParameter(f"Invalid JSON: {str(e)}")
try:
response = client.create(credential_name, credential_info, credential_values)
rich.print_json(data=response)
except requests.exceptions.HTTPError as e:
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
try:
error_body = e.response.json()
rich.print_json(data=error_body)
except json.JSONDecodeError:
click.echo(e.response.text, err=True)
raise click.Abort()
@credentials.command()
@click.argument("credential_name")
@click.pass_context
def delete(ctx: click.Context, credential_name: str):
"""Delete a credential by name"""
client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
try:
response = client.delete(credential_name)
rich.print_json(data=response)
except requests.exceptions.HTTPError as e:
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
try:
error_body = e.response.json()
rich.print_json(data=error_body)
except json.JSONDecodeError:
click.echo(e.response.text, err=True)
raise click.Abort()
@credentials.command()
@click.argument("credential_name")
@click.pass_context
def get(ctx: click.Context, credential_name: str):
"""Get a credential by name"""
client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
response = client.get(credential_name)
rich.print_json(data=response)

View File

@@ -0,0 +1,102 @@
import json as json_lib
from typing import Optional
import click
import rich
import requests
from ...http_client import HTTPClient
@click.group()
def http():
"""Make HTTP requests to the LiteLLM proxy server"""
pass
@http.command()
@click.argument("method")
@click.argument("uri")
@click.option(
"--data",
"-d",
type=str,
help="Data to send in the request body (as JSON string)",
)
@click.option(
"--json",
"-j",
type=str,
help="JSON data to send in the request body (as JSON string)",
)
@click.option(
"--header",
"-H",
multiple=True,
help="HTTP headers in 'key:value' format. Can be specified multiple times.",
)
@click.pass_context
def request(
ctx: click.Context,
method: str,
uri: str,
data: Optional[str] = None,
json: Optional[str] = None,
header: tuple[str, ...] = (),
):
"""Make an HTTP request to the LiteLLM proxy server
METHOD: HTTP method (GET, POST, PUT, DELETE, etc.)
URI: URI path (will be appended to base_url)
Examples:
litellm http request GET /models
litellm http request POST /chat/completions -j '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}'
litellm http request GET /health/test_connection -H "X-Custom-Header:value"
"""
# Parse headers from key:value format
headers = {}
for h in header:
try:
key, value = h.split(":", 1)
headers[key.strip()] = value.strip()
except ValueError:
raise click.BadParameter(
f"Invalid header format: {h}. Expected format: 'key:value'"
)
# Parse JSON data if provided
json_data = None
if json:
try:
json_data = json_lib.loads(json)
except ValueError as e:
raise click.BadParameter(f"Invalid JSON format: {e}")
# Parse data if provided
request_data = None
if data:
try:
request_data = json_lib.loads(data)
except ValueError:
# If not JSON, use as raw data
request_data = data
client = HTTPClient(ctx.obj["base_url"], ctx.obj["api_key"])
try:
response = client.request(
method=method,
uri=uri,
data=request_data,
json=json_data,
headers=headers,
)
rich.print_json(data=response)
except requests.exceptions.HTTPError as e:
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
try:
error_body = e.response.json()
rich.print_json(data=error_body)
except json_lib.JSONDecodeError:
click.echo(e.response.text, err=True)
raise click.Abort()

View File

@@ -0,0 +1,415 @@
import json
from datetime import datetime
from typing import Literal, Optional, List, Dict, Any
import click
import rich
import requests
from rich.table import Table
from ...keys import KeysManagementClient
@click.group()
def keys():
"""Manage API keys for the LiteLLM proxy server"""
pass
@keys.command()
@click.option("--page", type=int, help="Page number for pagination")
@click.option("--size", type=int, help="Number of items per page")
@click.option("--user-id", type=str, help="Filter keys by user ID")
@click.option("--team-id", type=str, help="Filter keys by team ID")
@click.option("--organization-id", type=str, help="Filter keys by organization ID")
@click.option("--key-hash", type=str, help="Filter by specific key hash")
@click.option("--key-alias", type=str, help="Filter by key alias")
@click.option(
"--return-full-object",
is_flag=True,
default=True,
help="Return the full key object",
)
@click.option(
"--include-team-keys", is_flag=True, help="Include team keys in the response"
)
@click.option(
"--format",
"output_format",
type=click.Choice(["table", "json"]),
default="table",
help="Output format (table or json)",
)
@click.pass_context
def list(
ctx: click.Context,
page: Optional[int],
size: Optional[int],
user_id: Optional[str],
team_id: Optional[str],
organization_id: Optional[str],
key_hash: Optional[str],
key_alias: Optional[str],
include_team_keys: bool,
output_format: Literal["table", "json"],
return_full_object: bool,
):
"""List all API keys"""
client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
response = client.list(
page=page,
size=size,
user_id=user_id,
team_id=team_id,
organization_id=organization_id,
key_hash=key_hash,
key_alias=key_alias,
return_full_object=return_full_object,
include_team_keys=include_team_keys,
)
assert isinstance(response, dict)
if output_format == "json":
rich.print_json(data=response)
else:
rich.print(
f"Showing {len(response.get('keys', []))} keys out of {response.get('total_count', 0)}"
)
table = Table(title="API Keys")
table.add_column("Key Hash", style="cyan")
table.add_column("Alias", style="green")
table.add_column("User ID", style="magenta")
table.add_column("Team ID", style="yellow")
table.add_column("Spend", style="red")
for key in response.get("keys", []):
table.add_row(
str(key.get("token", "")),
str(key.get("key_alias", "")),
str(key.get("user_id", "")),
str(key.get("team_id", "")),
str(key.get("spend", "")),
)
rich.print(table)
@keys.command()
@click.option("--models", type=str, help="Comma-separated list of allowed models")
@click.option("--aliases", type=str, help="JSON string of model alias mappings")
@click.option("--spend", type=float, help="Maximum spend limit for this key")
@click.option(
"--duration",
type=str,
help="Duration for which the key is valid (e.g. '24h', '7d')",
)
@click.option("--key-alias", type=str, help="Alias/name for the key")
@click.option("--team-id", type=str, help="Team ID to associate the key with")
@click.option("--user-id", type=str, help="User ID to associate the key with")
@click.option("--budget-id", type=str, help="Budget ID to associate the key with")
@click.option(
"--config", type=str, help="JSON string of additional configuration parameters"
)
@click.pass_context
def generate(
ctx: click.Context,
models: Optional[str],
aliases: Optional[str],
spend: Optional[float],
duration: Optional[str],
key_alias: Optional[str],
team_id: Optional[str],
user_id: Optional[str],
budget_id: Optional[str],
config: Optional[str],
):
"""Generate a new API key"""
client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
try:
models_list = [m.strip() for m in models.split(",")] if models else None
aliases_dict = json.loads(aliases) if aliases else None
config_dict = json.loads(config) if config else None
except json.JSONDecodeError as e:
raise click.BadParameter(f"Invalid JSON: {str(e)}")
try:
response = client.generate(
models=models_list,
aliases=aliases_dict,
spend=spend,
duration=duration,
key_alias=key_alias,
team_id=team_id,
user_id=user_id,
budget_id=budget_id,
config=config_dict,
)
rich.print_json(data=response)
except requests.exceptions.HTTPError as e:
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
try:
error_body = e.response.json()
rich.print_json(data=error_body)
except json.JSONDecodeError:
click.echo(e.response.text, err=True)
raise click.Abort()
@keys.command()
@click.option("--keys", type=str, help="Comma-separated list of API keys to delete")
@click.option(
"--key-aliases", type=str, help="Comma-separated list of key aliases to delete"
)
@click.pass_context
def delete(ctx: click.Context, keys: Optional[str], key_aliases: Optional[str]):
"""Delete API keys by key or alias"""
client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
keys_list = [k.strip() for k in keys.split(",")] if keys else None
aliases_list = [a.strip() for a in key_aliases.split(",")] if key_aliases else None
try:
response = client.delete(keys=keys_list, key_aliases=aliases_list)
rich.print_json(data=response)
except requests.exceptions.HTTPError as e:
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
try:
error_body = e.response.json()
rich.print_json(data=error_body)
except json.JSONDecodeError:
click.echo(e.response.text, err=True)
raise click.Abort()
def _parse_created_since_filter(created_since: Optional[str]) -> Optional[datetime]:
"""Parse and validate the created_since date filter."""
if not created_since:
return None
try:
# Support formats: YYYY-MM-DD_HH:MM or YYYY-MM-DD
if "_" in created_since:
return datetime.strptime(created_since, "%Y-%m-%d_%H:%M")
else:
return datetime.strptime(created_since, "%Y-%m-%d")
except ValueError:
click.echo(
f"Error: Invalid date format '{created_since}'. Use YYYY-MM-DD_HH:MM or YYYY-MM-DD",
err=True,
)
raise click.Abort()
def _fetch_all_keys_with_pagination(
source_client: KeysManagementClient, source_base_url: str
) -> List[Dict[str, Any]]:
"""Fetch all keys from source instance using pagination."""
click.echo(f"Fetching keys from source server: {source_base_url}")
source_keys = []
page = 1
page_size = 100 # Use a larger page size to minimize API calls
while True:
source_response = source_client.list(
return_full_object=True, page=page, size=page_size
)
# source_client.list() returns Dict[str, Any] when return_request is False (default)
assert isinstance(source_response, dict), "Expected dict response from list API"
page_keys = source_response.get("keys", [])
if not page_keys:
break
source_keys.extend(page_keys)
click.echo(f"Fetched page {page}: {len(page_keys)} keys")
# Check if we got fewer keys than the page size, indicating last page
if len(page_keys) < page_size:
break
page += 1
return source_keys
def _filter_keys_by_created_since(
source_keys: List[Dict[str, Any]],
created_since_dt: Optional[datetime],
created_since: str,
) -> List[Dict[str, Any]]:
"""Filter keys by created_since date if specified."""
if not created_since_dt:
return source_keys
filtered_keys = []
for key in source_keys:
key_created_at = key.get("created_at")
if key_created_at:
# Parse the key's created_at timestamp
if isinstance(key_created_at, str):
if "T" in key_created_at:
key_dt = datetime.fromisoformat(
key_created_at.replace("Z", "+00:00")
)
else:
key_dt = datetime.fromisoformat(key_created_at)
# Convert to naive datetime for comparison (assuming UTC)
if key_dt.tzinfo:
key_dt = key_dt.replace(tzinfo=None)
if key_dt >= created_since_dt:
filtered_keys.append(key)
click.echo(
f"Filtered {len(source_keys)} keys to {len(filtered_keys)} keys created since {created_since}"
)
return filtered_keys
def _display_dry_run_table(source_keys: List[Dict[str, Any]]) -> None:
"""Display a table of keys that would be imported in dry-run mode."""
click.echo("\n--- DRY RUN MODE ---")
table = Table(title="Keys that would be imported")
table.add_column("Key Alias", style="green")
table.add_column("User ID", style="magenta")
table.add_column("Created", style="cyan")
for key in source_keys:
created_at = key.get("created_at", "")
# Format the timestamp if it exists
if created_at:
# Try to parse and format the timestamp for better readability
if isinstance(created_at, str):
# Handle common timestamp formats
if "T" in created_at:
dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
created_at = dt.strftime("%Y-%m-%d %H:%M")
table.add_row(
str(key.get("key_alias", "")), str(key.get("user_id", "")), str(created_at)
)
rich.print(table)
def _prepare_key_import_data(key: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare key data for import by extracting relevant fields."""
import_data = {}
# Copy relevant fields if they exist
for field in [
"models",
"aliases",
"spend",
"key_alias",
"team_id",
"user_id",
"budget_id",
"config",
]:
if key.get(field):
import_data[field] = key[field]
return import_data
def _import_keys_to_destination(
source_keys: List[Dict[str, Any]], dest_client: KeysManagementClient
) -> tuple[int, int]:
"""Import each key to the destination instance and return counts."""
imported_count = 0
failed_count = 0
for key in source_keys:
try:
# Prepare key data for import
import_data = _prepare_key_import_data(key)
# Generate the key in destination instance
response = dest_client.generate(**import_data)
click.echo(f"Generated key: {response}")
# The generate method returns JSON data directly, not a Response object
imported_count += 1
key_alias = key.get("key_alias", "N/A")
click.echo(f"✓ Imported key: {key_alias}")
except Exception as e:
failed_count += 1
key_alias = key.get("key_alias", "N/A")
click.echo(f"✗ Failed to import key {key_alias}: {str(e)}", err=True)
return imported_count, failed_count
@keys.command(name="import")
@click.option(
"--source-base-url",
required=True,
help="Base URL of the source LiteLLM proxy server to import keys from",
)
@click.option(
"--source-api-key", help="API key for authentication to the source server"
)
@click.option(
"--dry-run",
is_flag=True,
help="Show what would be imported without actually importing",
)
@click.option(
"--created-since",
help="Only import keys created after this date/time (format: YYYY-MM-DD_HH:MM or YYYY-MM-DD)",
)
@click.pass_context
def import_keys(
ctx: click.Context,
source_base_url: str,
source_api_key: Optional[str],
dry_run: bool,
created_since: Optional[str],
):
"""Import API keys from another LiteLLM instance"""
# Parse created_since filter if provided
created_since_dt = _parse_created_since_filter(created_since)
# Create clients for both source and destination
source_client = KeysManagementClient(source_base_url, source_api_key)
dest_client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
try:
# Get all keys from source instance with pagination
source_keys = _fetch_all_keys_with_pagination(source_client, source_base_url)
# Filter keys by created_since if specified
if created_since:
source_keys = _filter_keys_by_created_since(
source_keys, created_since_dt, created_since
)
if not source_keys:
click.echo("No keys found in source instance.")
return
click.echo(f"Found {len(source_keys)} keys in source instance.")
if dry_run:
_display_dry_run_table(source_keys)
return
# Import each key
imported_count, failed_count = _import_keys_to_destination(
source_keys, dest_client
)
# Summary
click.echo("\nImport completed:")
click.echo(f" Successfully imported: {imported_count}")
click.echo(f" Failed to import: {failed_count}")
click.echo(f" Total keys processed: {len(source_keys)}")
except requests.exceptions.HTTPError as e:
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
try:
error_body = e.response.json()
rich.print_json(data=error_body)
except json.JSONDecodeError:
click.echo(e.response.text, err=True)
raise click.Abort()
except Exception as e:
click.echo(f"Error: {str(e)}", err=True)
raise click.Abort()

View File

@@ -0,0 +1,485 @@
# stdlib imports
from datetime import datetime
import re
from typing import Optional, Literal, Any
import yaml
from dataclasses import dataclass
from collections import defaultdict
# third party imports
import click
import rich
# local imports
from ... import Client
@dataclass
class ModelYamlInfo:
model_name: str
model_params: dict[str, Any]
model_info: dict[str, Any]
model_id: str
access_groups: list[str]
provider: str
@property
def access_groups_str(self) -> str:
return ", ".join(self.access_groups) if self.access_groups else ""
def _get_model_info_obj_from_yaml(model: dict[str, Any]) -> ModelYamlInfo:
"""Extract model info from a model dict and return as ModelYamlInfo dataclass."""
model_name: str = model["model_name"]
model_params: dict[str, Any] = model["litellm_params"]
model_info: dict[str, Any] = model.get("model_info", {})
model_id: str = model_params["model"]
access_groups = model_info.get("access_groups", [])
provider = model_id.split("/", 1)[0] if "/" in model_id else model_id
return ModelYamlInfo(
model_name=model_name,
model_params=model_params,
model_info=model_info,
model_id=model_id,
access_groups=access_groups,
provider=provider,
)
def format_iso_datetime_str(iso_datetime_str: Optional[str]) -> str:
"""Format an ISO format datetime string to human-readable date with minute resolution."""
if not iso_datetime_str:
return ""
try:
# Parse ISO format datetime string
dt = datetime.fromisoformat(iso_datetime_str.replace("Z", "+00:00"))
return dt.strftime("%Y-%m-%d %H:%M")
except (TypeError, ValueError):
return str(iso_datetime_str)
def format_timestamp(timestamp: Optional[int]) -> str:
"""Format a Unix timestamp (integer) to human-readable date with minute resolution."""
if timestamp is None:
return ""
try:
dt = datetime.fromtimestamp(timestamp)
return dt.strftime("%Y-%m-%d %H:%M")
except (TypeError, ValueError):
return str(timestamp)
def format_cost_per_1k_tokens(cost: Optional[float]) -> str:
"""Format a per-token cost to cost per 1000 tokens."""
if cost is None:
return ""
try:
# Convert string to float if needed
cost_float = float(cost)
# Multiply by 1000 and format to 4 decimal places
return f"${cost_float * 1000:.4f}"
except (TypeError, ValueError):
return str(cost)
def create_client(ctx: click.Context) -> Client:
"""Helper function to create a client from context."""
return Client(base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"])
@click.group()
def models() -> None:
"""Manage models on your LiteLLM proxy server"""
pass
@models.command("list")
@click.option(
"--format",
"output_format",
type=click.Choice(["table", "json"]),
default="table",
help="Output format (table or json)",
)
@click.pass_context
def list_models(ctx: click.Context, output_format: Literal["table", "json"]) -> None:
"""List all available models"""
client = create_client(ctx)
models_list = client.models.list()
assert isinstance(models_list, list)
if output_format == "json":
rich.print_json(data=models_list)
else: # table format
table = rich.table.Table(title="Available Models")
# Add columns based on the data structure
table.add_column("ID", style="cyan")
table.add_column("Object", style="green")
table.add_column("Created", style="magenta")
table.add_column("Owned By", style="yellow")
# Add rows
for model in models_list:
created = model.get("created")
# Convert string timestamp to integer if needed
if isinstance(created, str) and created.isdigit():
created = int(created)
table.add_row(
str(model.get("id", "")),
str(model.get("object", "model")),
format_timestamp(created)
if isinstance(created, int)
else format_iso_datetime_str(created),
str(model.get("owned_by", "")),
)
rich.print(table)
@models.command("add")
@click.argument("model-name")
@click.option(
"--param",
"-p",
multiple=True,
help="Model parameters in key=value format (can be specified multiple times)",
)
@click.option(
"--info",
"-i",
multiple=True,
help="Model info in key=value format (can be specified multiple times)",
)
@click.pass_context
def add_model(
ctx: click.Context, model_name: str, param: tuple[str, ...], info: tuple[str, ...]
) -> None:
"""Add a new model to the proxy"""
# Convert parameters from key=value format to dict
model_params = dict(p.split("=", 1) for p in param)
model_info = dict(i.split("=", 1) for i in info) if info else None
client = create_client(ctx)
result = client.models.new(
model_name=model_name,
model_params=model_params,
model_info=model_info,
)
rich.print_json(data=result)
@models.command("delete")
@click.argument("model-id")
@click.pass_context
def delete_model(ctx: click.Context, model_id: str) -> None:
"""Delete a model from the proxy"""
client = create_client(ctx)
result = client.models.delete(model_id=model_id)
rich.print_json(data=result)
@models.command("get")
@click.option("--id", "model_id", help="ID of the model to retrieve")
@click.option("--name", "model_name", help="Name of the model to retrieve")
@click.pass_context
def get_model(
ctx: click.Context, model_id: Optional[str], model_name: Optional[str]
) -> None:
"""Get information about a specific model"""
if not model_id and not model_name:
raise click.UsageError("Either --id or --name must be provided")
client = create_client(ctx)
result = client.models.get(model_id=model_id, model_name=model_name)
rich.print_json(data=result)
@models.command("info")
@click.option(
"--format",
"output_format",
type=click.Choice(["table", "json"]),
default="table",
help="Output format (table or json)",
)
@click.option(
"--columns",
"columns",
default="public_model,upstream_model,updated_at",
help="Comma-separated list of columns to display. Valid columns: public_model, upstream_model, credential_name, created_at, updated_at, id, input_cost, output_cost. Default: public_model,upstream_model,updated_at",
)
@click.pass_context
def get_models_info(
ctx: click.Context, output_format: Literal["table", "json"], columns: str
) -> None:
"""Get detailed information about all models"""
client = create_client(ctx)
models_info = client.models.info()
assert isinstance(models_info, list)
if output_format == "json":
rich.print_json(data=models_info)
else: # table format
table = rich.table.Table(title="Models Information")
# Define all possible columns with their configurations
column_configs: dict[str, dict[str, Any]] = {
"public_model": {
"header": "Public Model",
"style": "cyan",
"get_value": lambda m: str(m.get("model_name", "")),
},
"upstream_model": {
"header": "Upstream Model",
"style": "green",
"get_value": lambda m: str(
m.get("litellm_params", {}).get("model", "")
),
},
"credential_name": {
"header": "Credential Name",
"style": "yellow",
"get_value": lambda m: str(
m.get("litellm_params", {}).get("litellm_credential_name", "")
),
},
"created_at": {
"header": "Created At",
"style": "magenta",
"get_value": lambda m: format_iso_datetime_str(
m.get("model_info", {}).get("created_at")
),
},
"updated_at": {
"header": "Updated At",
"style": "magenta",
"get_value": lambda m: format_iso_datetime_str(
m.get("model_info", {}).get("updated_at")
),
},
"id": {
"header": "ID",
"style": "blue",
"get_value": lambda m: str(m.get("model_info", {}).get("id", "")),
},
"input_cost": {
"header": "Input Cost",
"style": "green",
"justify": "right",
"get_value": lambda m: format_cost_per_1k_tokens(
m.get("model_info", {}).get("input_cost_per_token")
),
},
"output_cost": {
"header": "Output Cost",
"style": "green",
"justify": "right",
"get_value": lambda m: format_cost_per_1k_tokens(
m.get("model_info", {}).get("output_cost_per_token")
),
},
}
# Add requested columns
requested_columns = [col.strip() for col in columns.split(",")]
for col_name in requested_columns:
if col_name in column_configs:
config = column_configs[col_name]
table.add_column(
config["header"],
style=config["style"],
justify=config.get("justify", "left"),
)
else:
click.echo(f"Warning: Unknown column '{col_name}'", err=True)
# Add rows with only the requested columns
for model in models_info:
row_values = []
for col_name in requested_columns:
if col_name in column_configs:
row_values.append(column_configs[col_name]["get_value"](model))
if row_values:
table.add_row(*row_values)
rich.print(table)
@models.command("update")
@click.argument("model-id")
@click.option(
"--param",
"-p",
multiple=True,
help="Model parameters in key=value format (can be specified multiple times)",
)
@click.option(
"--info",
"-i",
multiple=True,
help="Model info in key=value format (can be specified multiple times)",
)
@click.pass_context
def update_model(
ctx: click.Context, model_id: str, param: tuple[str, ...], info: tuple[str, ...]
) -> None:
"""Update an existing model's configuration"""
# Convert parameters from key=value format to dict
model_params = dict(p.split("=", 1) for p in param)
model_info = dict(i.split("=", 1) for i in info) if info else None
client = create_client(ctx)
result = client.models.update(
model_id=model_id,
model_params=model_params,
model_info=model_info,
)
rich.print_json(data=result)
def _filter_model(model, model_regex, access_group_regex):
model_name = model.get("model_name")
model_params = model.get("litellm_params")
model_info = model.get("model_info", {})
if not model_name or not model_params:
return False
model_id = model_params.get("model")
if not model_id or not isinstance(model_id, str):
return False
if model_regex and not model_regex.search(model_id):
return False
access_groups = model_info.get("access_groups", [])
if access_group_regex:
if not isinstance(access_groups, list):
return False
if not any(
isinstance(group, str) and access_group_regex.search(group)
for group in access_groups
):
return False
return True
def _print_models_table(added_models: list[ModelYamlInfo], table_title: str):
if not added_models:
return
table = rich.table.Table(title=table_title)
table.add_column("Model Name", style="cyan")
table.add_column("Upstream Model", style="green")
table.add_column("Access Groups", style="magenta")
for m in added_models:
table.add_row(m.model_name, m.model_id, m.access_groups_str)
rich.print(table)
def _print_summary_table(provider_counts):
summary_table = rich.table.Table(title="Model Import Summary")
summary_table.add_column("Provider", style="cyan")
summary_table.add_column("Count", style="green")
for provider, count in provider_counts.items():
summary_table.add_row(str(provider), str(count))
total = sum(provider_counts.values())
summary_table.add_row("[bold]Total[/bold]", f"[bold]{total}[/bold]")
rich.print(summary_table)
def get_model_list_from_yaml_file(yaml_file: str) -> list[dict[str, Any]]:
"""Load and validate the model list from a YAML file."""
with open(yaml_file, "r") as f:
data = yaml.safe_load(f)
if not data or "model_list" not in data:
raise click.ClickException(
"YAML file must contain a 'model_list' key with a list of models."
)
model_list = data["model_list"]
if not isinstance(model_list, list):
raise click.ClickException("'model_list' must be a list of model definitions.")
return model_list
def _get_filtered_model_list(
model_list, only_models_matching_regex, only_access_groups_matching_regex
):
"""Return a list of models that pass the filter criteria."""
model_regex = (
re.compile(only_models_matching_regex) if only_models_matching_regex else None
)
access_group_regex = (
re.compile(only_access_groups_matching_regex)
if only_access_groups_matching_regex
else None
)
return [
model
for model in model_list
if _filter_model(model, model_regex, access_group_regex)
]
def _import_models_get_table_title(dry_run: bool) -> str:
if dry_run:
return "Models that would be imported if [yellow]--dry-run[/yellow] was not provided"
else:
return "Models Imported"
@models.command("import")
@click.argument(
"yaml_file", type=click.Path(exists=True, dir_okay=False, readable=True)
)
@click.option(
"--dry-run",
is_flag=True,
help="Show what would be imported without making any changes.",
)
@click.option(
"--only-models-matching-regex",
default=None,
help="Only import models where litellm_params.model matches the given regex.",
)
@click.option(
"--only-access-groups-matching-regex",
default=None,
help="Only import models where at least one item in model_info.access_groups matches the given regex.",
)
@click.pass_context
def import_models(
ctx: click.Context,
yaml_file: str,
dry_run: bool,
only_models_matching_regex: Optional[str],
only_access_groups_matching_regex: Optional[str],
) -> None:
"""Import models from a YAML file and add them to the proxy."""
provider_counts: dict[str, int] = defaultdict(int)
added_models: list[ModelYamlInfo] = []
model_list = get_model_list_from_yaml_file(yaml_file)
filtered_model_list = _get_filtered_model_list(
model_list, only_models_matching_regex, only_access_groups_matching_regex
)
if not dry_run:
client = create_client(ctx)
for model in filtered_model_list:
model_info_obj = _get_model_info_obj_from_yaml(model)
if not dry_run:
try:
client.models.new(
model_name=model_info_obj.model_name,
model_params=model_info_obj.model_params,
model_info=model_info_obj.model_info,
)
except Exception:
pass # For summary, ignore errors
added_models.append(model_info_obj)
provider_counts[model_info_obj.provider] += 1
table_title = _import_models_get_table_title(dry_run)
_print_models_table(added_models, table_title)
_print_summary_table(provider_counts)

View File

@@ -0,0 +1,167 @@
"""Team management commands for LiteLLM CLI."""
from typing import Any, Dict, List, Optional
import click
import requests
from rich.console import Console
from rich.table import Table
from litellm.proxy.client import Client
@click.group()
def teams():
"""Manage teams and team assignments"""
pass
def display_teams_table(teams: List[Dict[str, Any]]) -> None:
"""Display teams in a formatted table"""
console = Console()
if not teams:
console.print("❌ No teams found for your user.")
return
table = Table(title="Available Teams")
table.add_column("Index", style="cyan", no_wrap=True)
table.add_column("Team Alias", style="magenta")
table.add_column("Team ID", style="green")
table.add_column("Models", style="yellow")
table.add_column("Max Budget", style="blue")
table.add_column("Role", style="red")
for i, team in enumerate(teams):
team_alias = team.get("team_alias") or "N/A"
team_id = team.get("team_id", "N/A")
models = team.get("models", [])
max_budget = team.get("max_budget")
# Format models list
if models:
if len(models) > 3:
models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
else:
models_str = ", ".join(models)
else:
models_str = "All models"
# Format budget
budget_str = f"${max_budget}" if max_budget else "Unlimited"
# Try to determine role (this might vary based on API response structure)
role = "Member" # Default role
if (
isinstance(team, dict)
and "members_with_roles" in team
and team["members_with_roles"]
):
# This would need to be implemented based on actual API response structure
pass
table.add_row(str(i + 1), team_alias, team_id, models_str, budget_str, role)
console.print(table)
@teams.command()
@click.pass_context
def list(ctx: click.Context):
"""List teams that you belong to"""
client = Client(ctx.obj["base_url"], ctx.obj["api_key"])
try:
# Use list() for simpler response structure (returns array directly)
teams = client.teams.list()
display_teams_table(teams)
except requests.exceptions.HTTPError as e:
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
error_body = e.response.json()
click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
raise click.Abort()
except Exception as e:
click.echo(f"Error: {str(e)}", err=True)
raise click.Abort()
@teams.command()
@click.pass_context
def available(ctx: click.Context):
"""List teams that are available to join"""
client = Client(ctx.obj["base_url"], ctx.obj["api_key"])
try:
teams = client.teams.get_available()
if teams:
console = Console()
console.print("\n🎯 Available Teams to Join:")
display_teams_table(teams)
else:
click.echo(" No available teams to join.")
except requests.exceptions.HTTPError as e:
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
error_body = e.response.json()
click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
except Exception as e:
click.echo(f"Error: {str(e)}", err=True)
raise click.Abort()
@teams.command()
@click.option("--team-id", type=str, help="Team ID to assign the key to")
@click.pass_context
def assign_key(ctx: click.Context, team_id: Optional[str]):
"""Assign your current CLI key to a team"""
client = Client(ctx.obj["base_url"], ctx.obj["api_key"])
api_key = ctx.obj["api_key"]
if not api_key:
click.echo("❌ No API key found. Please login first using 'litellm login'")
raise click.Abort()
try:
# If no team_id provided, show teams and let user select
if not team_id:
teams = client.teams.list()
if not teams:
click.echo("❌ No teams found for your user.")
return
# Use interactive selection from auth module
from .auth import prompt_team_selection
selected_team = prompt_team_selection(teams)
if selected_team:
team_id = selected_team.get("team_id")
else:
click.echo("❌ Operation cancelled.")
return
# Update the key with the selected team
if team_id:
click.echo(f"\n🔄 Assigning your key to team: {team_id}")
client.keys.update(key=api_key, team_id=team_id)
click.echo(f"✅ Successfully assigned key to team: {team_id}")
# Show team details if available
teams = client.teams.list()
for team in teams:
if team.get("team_id") == team_id:
models = team.get("models", [])
if models:
click.echo(f"🎯 You can now access models: {', '.join(models)}")
else:
click.echo("🎯 You can now access all available models")
break
except requests.exceptions.HTTPError as e:
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
error_body = e.response.json()
click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
raise click.Abort()
except Exception as e:
click.echo(f"Error: {str(e)}", err=True)
raise click.Abort()

View File

@@ -0,0 +1,91 @@
import click
import rich
from ... import UsersManagementClient
@click.group()
def users():
"""Manage users on your LiteLLM proxy server"""
pass
@users.command("list")
@click.pass_context
def list_users(ctx: click.Context):
"""List all users"""
client = UsersManagementClient(
base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
)
users = client.list_users()
if isinstance(users, dict) and "users" in users:
users = users["users"]
if not users:
click.echo("No users found.")
return
from rich.table import Table
from rich.console import Console
table = Table(title="Users")
table.add_column("User ID", style="cyan")
table.add_column("Email", style="green")
table.add_column("Role", style="magenta")
table.add_column("Teams", style="yellow")
for user in users:
table.add_row(
str(user.get("user_id", "")),
str(user.get("user_email", "")),
str(user.get("user_role", "")),
", ".join(user.get("teams", []) or []),
)
console = Console()
console.print(table)
@users.command("get")
@click.option("--id", "user_id", help="ID of the user to retrieve")
@click.pass_context
def get_user(ctx: click.Context, user_id: str):
"""Get information about a specific user"""
client = UsersManagementClient(
base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
)
result = client.get_user(user_id=user_id)
rich.print_json(data=result)
@users.command("create")
@click.option("--email", required=True, help="User email")
@click.option("--role", default="internal_user", help="User role")
@click.option("--alias", default=None, help="User alias")
@click.option("--team", multiple=True, help="Team IDs (can specify multiple)")
@click.option("--max-budget", type=float, default=None, help="Max budget for user")
@click.pass_context
def create_user(ctx: click.Context, email, role, alias, team, max_budget):
"""Create a new user"""
client = UsersManagementClient(
base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
)
user_data = {
"user_email": email,
"user_role": role,
}
if alias:
user_data["user_alias"] = alias
if team:
user_data["teams"] = list(team)
if max_budget is not None:
user_data["max_budget"] = max_budget
result = client.create_user(user_data)
rich.print_json(data=result)
@users.command("delete")
@click.argument("user_ids", nargs=-1)
@click.pass_context
def delete_user(ctx: click.Context, user_ids):
"""Delete one or more users by user_id"""
client = UsersManagementClient(
base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
)
result = client.delete_user(list(user_ids))
rich.print_json(data=result)