Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/proxy/client/cli/commands/keys.py
2026-03-26 20:06:14 +08:00

416 lines
14 KiB
Python

import json
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple

import click
import requests
import rich
from rich.table import Table

from ...keys import KeysManagementClient
@click.group(name="keys")
def keys():
    """Manage API keys for the LiteLLM proxy server"""
    # Intentionally empty: this is only the command group; the individual
    # subcommands are registered on it with ``@keys.command()``.
@keys.command()
@click.option("--page", type=int, help="Page number for pagination")
@click.option("--size", type=int, help="Number of items per page")
@click.option("--user-id", type=str, help="Filter keys by user ID")
@click.option("--team-id", type=str, help="Filter keys by team ID")
@click.option("--organization-id", type=str, help="Filter keys by organization ID")
@click.option("--key-hash", type=str, help="Filter by specific key hash")
@click.option("--key-alias", type=str, help="Filter by key alias")
# A bare `is_flag=True, default=True` option can never be switched off.
# The on/off pair keeps `--return-full-object` working as before and adds
# `--no-return-full-object` so the value is actually controllable.
@click.option(
    "--return-full-object/--no-return-full-object",
    default=True,
    help="Return the full key object",
)
@click.option(
    "--include-team-keys", is_flag=True, help="Include team keys in the response"
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (table or json)",
)
@click.pass_context
def list(
    ctx: click.Context,
    page: Optional[int],
    size: Optional[int],
    user_id: Optional[str],
    team_id: Optional[str],
    organization_id: Optional[str],
    key_hash: Optional[str],
    key_alias: Optional[str],
    include_team_keys: bool,
    output_format: Literal["table", "json"],
    return_full_object: bool,
):
    """List all API keys"""
    # NOTE: this command is named `list`, which shadows the builtin at module
    # scope; code in this module must not rely on calling builtin `list(...)`.
    client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    try:
        response = client.list(
            page=page,
            size=size,
            user_id=user_id,
            team_id=team_id,
            organization_id=organization_id,
            key_hash=key_hash,
            key_alias=key_alias,
            return_full_object=return_full_object,
            include_team_keys=include_team_keys,
        )
    except requests.exceptions.HTTPError as e:
        # Report like the other subcommands: status line, then the error body
        # as JSON when it parses, raw text otherwise. `Response.json()` may
        # raise json's, simplejson's or requests' JSONDecodeError depending on
        # the installed libraries; all of them subclass ValueError.
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except ValueError:
            click.echo(e.response.text, err=True)
        raise click.Abort()
    assert isinstance(response, dict)

    if output_format == "json":
        rich.print_json(data=response)
        return

    rich.print(
        f"Showing {len(response.get('keys', []))} keys out of {response.get('total_count', 0)}"
    )
    table = Table(title="API Keys")
    table.add_column("Key Hash", style="cyan")
    table.add_column("Alias", style="green")
    table.add_column("User ID", style="magenta")
    table.add_column("Team ID", style="yellow")
    table.add_column("Spend", style="red")
    for key in response.get("keys", []):
        table.add_row(
            str(key.get("token", "")),
            str(key.get("key_alias", "")),
            str(key.get("user_id", "")),
            str(key.get("team_id", "")),
            str(key.get("spend", "")),
        )
    rich.print(table)
@keys.command()
@click.option("--models", type=str, help="Comma-separated list of allowed models")
@click.option("--aliases", type=str, help="JSON string of model alias mappings")
@click.option("--spend", type=float, help="Maximum spend limit for this key")
@click.option(
    "--duration",
    type=str,
    help="Duration for which the key is valid (e.g. '24h', '7d')",
)
@click.option("--key-alias", type=str, help="Alias/name for the key")
@click.option("--team-id", type=str, help="Team ID to associate the key with")
@click.option("--user-id", type=str, help="User ID to associate the key with")
@click.option("--budget-id", type=str, help="Budget ID to associate the key with")
@click.option(
    "--config", type=str, help="JSON string of additional configuration parameters"
)
@click.pass_context
def generate(
    ctx: click.Context,
    models: Optional[str],
    aliases: Optional[str],
    spend: Optional[float],
    duration: Optional[str],
    key_alias: Optional[str],
    team_id: Optional[str],
    user_id: Optional[str],
    budget_id: Optional[str],
    config: Optional[str],
):
    """Generate a new API key"""
    # All free-form CLI input is validated before the request is sent, so a
    # typo surfaces as a usage error (raises click.BadParameter) rather than
    # an opaque server-side failure. HTTP errors raise click.Abort.
    client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])

    # Blank entries (from "a,,b" or a trailing comma) are dropped; if nothing
    # is left the option is treated as not given.
    models_list: Optional[List[str]] = None
    if models:
        models_list = [m.strip() for m in models.split(",") if m.strip()] or None

    try:
        aliases_dict = json.loads(aliases) if aliases else None
        config_dict = json.loads(config) if config else None
    except json.JSONDecodeError as e:
        raise click.BadParameter(f"Invalid JSON: {str(e)}") from e
    # Both options are documented as mappings; reject other JSON values
    # (lists, strings, numbers) here instead of forwarding them.
    for option_name, parsed in (("--aliases", aliases_dict), ("--config", config_dict)):
        if parsed is not None and not isinstance(parsed, dict):
            raise click.BadParameter(
                "expected a JSON object", param_hint=f"'{option_name}'"
            )

    try:
        response = client.generate(
            models=models_list,
            aliases=aliases_dict,
            spend=spend,
            duration=duration,
            key_alias=key_alias,
            team_id=team_id,
            user_id=user_id,
            budget_id=budget_id,
            config=config_dict,
        )
        rich.print_json(data=response)
    except requests.exceptions.HTTPError as e:
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except ValueError:
            # Response.json() may raise json's, simplejson's or requests'
            # JSONDecodeError; only ValueError is common to all of them.
            click.echo(e.response.text, err=True)
        raise click.Abort()
@keys.command()
@click.option("--keys", type=str, help="Comma-separated list of API keys to delete")
@click.option(
    "--key-aliases", type=str, help="Comma-separated list of key aliases to delete"
)
@click.pass_context
def delete(ctx: click.Context, keys: Optional[str], key_aliases: Optional[str]):
    """Delete API keys by key or alias"""
    # Raises click.UsageError when no key/alias is given and click.Abort on an
    # HTTP error. The `keys` parameter shadows the module-level group inside
    # this function only; the group is not needed here.
    client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    # Split comma-separated inputs, ignoring blank entries such as a
    # trailing comma.
    keys_list = [k.strip() for k in keys.split(",") if k.strip()] if keys else []
    aliases_list = (
        [a.strip() for a in key_aliases.split(",") if a.strip()] if key_aliases else []
    )
    if not keys_list and not aliases_list:
        # Refuse to send a delete request that names nothing.
        raise click.UsageError(
            "Provide at least one key via --keys or --key-aliases."
        )
    try:
        response = client.delete(
            keys=keys_list or None, key_aliases=aliases_list or None
        )
        rich.print_json(data=response)
    except requests.exceptions.HTTPError as e:
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except ValueError:
            # Response.json() may raise json's, simplejson's or requests'
            # JSONDecodeError; only ValueError is common to all of them.
            click.echo(e.response.text, err=True)
        raise click.Abort()
def _parse_created_since_filter(created_since: Optional[str]) -> Optional[datetime]:
    """Turn the ``--created-since`` CLI value into a naive ``datetime``.

    The layout is chosen by the presence of an underscore:
    ``YYYY-MM-DD_HH:MM`` when one is present, ``YYYY-MM-DD`` otherwise.

    Returns:
        ``None`` when no filter was supplied (``None`` or empty string),
        otherwise the parsed ``datetime``.

    Raises:
        click.Abort: after printing an error to stderr when the value does
            not match the selected layout.
    """
    if not created_since:
        return None
    layout = "%Y-%m-%d_%H:%M" if "_" in created_since else "%Y-%m-%d"
    try:
        parsed = datetime.strptime(created_since, layout)
    except ValueError:
        click.echo(
            f"Error: Invalid date format '{created_since}'. Use YYYY-MM-DD_HH:MM or YYYY-MM-DD",
            err=True,
        )
        raise click.Abort()
    return parsed
def _fetch_all_keys_with_pagination(
    source_client: KeysManagementClient,
    source_base_url: str,
    page_size: int = 100,
) -> List[Dict[str, Any]]:
    """Fetch every key from the source instance, one page at a time.

    Args:
        source_client: Client bound to the source proxy.
        source_base_url: Source URL; used only in the progress message.
        page_size: Keys requested per page. The default of 100 keeps the
            number of API calls low; presumably it must not exceed the
            server's own page-size limit — confirm for custom values.

    Returns:
        All key objects returned by the source, in page order.

    Raises:
        ValueError: if ``page_size`` is not positive.
        TypeError: if the list endpoint returns something other than a dict.
    """
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size}")
    click.echo(f"Fetching keys from source server: {source_base_url}")
    source_keys: List[Dict[str, Any]] = []
    page = 1
    while True:
        source_response = source_client.list(
            return_full_object=True, page=page, size=page_size
        )
        # list() returns a dict unless a raw request is requested; check
        # explicitly because `assert` is stripped under `python -O`.
        if not isinstance(source_response, dict):
            raise TypeError("Expected dict response from list API")
        page_keys = source_response.get("keys") or []
        if not page_keys:
            break
        source_keys.extend(page_keys)
        click.echo(f"Fetched page {page}: {len(page_keys)} keys")

        total_count = source_response.get("total_count")
        if isinstance(total_count, int) and not isinstance(total_count, bool):
            # Prefer the server-reported total: a server that caps the page
            # size below `page_size` would otherwise look like it had ended
            # after the first (short) page and the rest would be skipped.
            if len(source_keys) >= total_count:
                break
        elif len(page_keys) < page_size:
            # No total available: a short page is the only end-of-data signal.
            break
        page += 1
    return source_keys
def _parse_key_created_at(value: Any) -> Optional[datetime]:
    """Parse a key's ``created_at`` value into a naive UTC ``datetime``.

    Accepts ``datetime`` instances and ISO-8601 strings, where a trailing
    ``Z`` means UTC. Offset-aware values are converted to UTC before the
    offset is dropped, so they compare correctly with the naive
    ``--created-since`` value. Naive values are returned unchanged.

    Returns:
        The parsed datetime, or ``None`` when the value is missing, empty,
        of an unsupported type, or not valid ISO-8601.
    """
    if isinstance(value, datetime):
        parsed = value
    elif isinstance(value, str) and value:
        # fromisoformat() before Python 3.11 does not accept a "Z" suffix.
        text = value[:-1] + "+00:00" if value.endswith("Z") else value
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            return None
    else:
        return None
    offset = parsed.utcoffset()
    if offset is not None:
        # Shift to UTC *before* discarding tzinfo; dropping it directly would
        # keep the local wall-clock time of e.g. a "+08:00" timestamp.
        parsed = (parsed - offset).replace(tzinfo=None)
    return parsed


def _filter_keys_by_created_since(
    source_keys: List[Dict[str, Any]],
    created_since_dt: Optional[datetime],
    created_since: str,
) -> List[Dict[str, Any]]:
    """Keep only the keys created at or after ``created_since_dt``.

    Args:
        source_keys: Key objects fetched from the source instance.
        created_since_dt: Naive cutoff, treated as UTC; ``None`` disables
            filtering and returns ``source_keys`` unchanged.
        created_since: The original CLI text, used only in the summary line.

    Returns:
        The matching keys, in their original order. Keys whose
        ``created_at`` is missing or cannot be parsed are excluded; a
        warning on stderr counts the unparseable ones.
    """
    if not created_since_dt:
        return source_keys
    filtered_keys: List[Dict[str, Any]] = []
    unparseable = 0
    for key in source_keys:
        raw_created_at = key.get("created_at")
        # Parsed fresh for every key, so a bad value can never reuse the
        # previous key's timestamp.
        key_dt = _parse_key_created_at(raw_created_at)
        if key_dt is None:
            if raw_created_at:
                unparseable += 1
            continue
        if key_dt >= created_since_dt:
            filtered_keys.append(key)
    if unparseable:
        click.echo(
            f"Warning: skipped {unparseable} keys with an unrecognized created_at value",
            err=True,
        )
    click.echo(
        f"Filtered {len(source_keys)} keys to {len(filtered_keys)} keys created since {created_since}"
    )
    return filtered_keys
def _format_created_at(created_at: Any) -> str:
    """Render a key's ``created_at`` value for the dry-run table.

    ISO-8601 strings containing a ``T`` separator are shortened to
    ``YYYY-MM-DD HH:MM`` (in the timestamp's own offset). Anything else,
    including strings that fail to parse, is shown as-is via ``str``. A
    missing value (``None``) is shown as an empty string.
    """
    if created_at is None:
        return ""
    if isinstance(created_at, str) and "T" in created_at:
        try:
            parsed = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
        except ValueError:
            # Prettifying is best-effort; one odd value must not abort the
            # whole dry run, so fall back to the raw text.
            return created_at
        return parsed.strftime("%Y-%m-%d %H:%M")
    return str(created_at)


def _display_dry_run_table(source_keys: List[Dict[str, Any]]) -> None:
    """Print a table of the keys that would be imported (dry-run mode).

    Shows alias, user ID and creation time for each key; nothing is sent to
    the destination.
    """
    click.echo("\n--- DRY RUN MODE ---")
    table = Table(title="Keys that would be imported")
    table.add_column("Key Alias", style="green")
    table.add_column("User ID", style="magenta")
    table.add_column("Created", style="cyan")
    for key in source_keys:
        table.add_row(
            str(key.get("key_alias", "")),
            str(key.get("user_id", "")),
            _format_created_at(key.get("created_at")),
        )
    rich.print(table)
def _prepare_key_import_data(key: Dict[str, Any]) -> Dict[str, Any]:
    """Select the fields of a source key that are forwarded to ``generate``.

    Only the whitelisted fields below are considered, and only those with a
    truthy value are kept, so ``None``, empty collections and a zero spend
    are omitted from the result. Field order follows the whitelist.
    """
    copyable_fields = (
        "models",
        "aliases",
        "spend",
        "key_alias",
        "team_id",
        "user_id",
        "budget_id",
        "config",
    )
    return {name: key[name] for name in copyable_fields if key.get(name)}
def _import_keys_to_destination(
    source_keys: List[Dict[str, Any]], dest_client: KeysManagementClient
) -> Tuple[int, int]:
    """Re-create each source key on the destination proxy.

    Each key is handled independently (best effort): a failure is printed to
    stderr and counted, and processing continues with the next key.

    Args:
        source_keys: Key objects fetched from the source instance.
        dest_client: Client bound to the destination proxy.

    Returns:
        ``(imported_count, failed_count)``.
    """
    # `Tuple[...]` rather than `tuple[...]`: the annotation is evaluated at
    # definition time, and builtin generics fail on Python 3.8.
    imported_count = 0
    failed_count = 0
    for key in source_keys:
        key_alias = key.get("key_alias", "N/A")
        try:
            response = dest_client.generate(**_prepare_key_import_data(key))
        except Exception as e:  # deliberate: one bad key must not stop the rest
            failed_count += 1
            click.echo(f"✗ Failed to import key {key_alias}: {str(e)}", err=True)
            continue
        # generate() returns the parsed JSON data, not a Response object.
        # NOTE(review): this presumably includes the newly generated secret
        # key — it may be the only place the user sees it; confirm before
        # trimming this output.
        click.echo(f"Generated key: {response}")
        imported_count += 1
        click.echo(f"✓ Imported key: {key_alias}")
    return imported_count, failed_count
@keys.command(name="import")
@click.option(
    "--source-base-url",
    required=True,
    help="Base URL of the source LiteLLM proxy server to import keys from",
)
@click.option(
    "--source-api-key", help="API key for authentication to the source server"
)
@click.option(
    "--dry-run",
    is_flag=True,
    help="Show what would be imported without actually importing",
)
@click.option(
    "--created-since",
    help="Only import keys created at or after this date/time (format: YYYY-MM-DD_HH:MM or YYYY-MM-DD)",
)
@click.pass_context
def import_keys(
    ctx: click.Context,
    source_base_url: str,
    source_api_key: Optional[str],
    dry_run: bool,
    created_since: Optional[str],
):
    """Import API keys from another LiteLLM instance"""
    # Flow: parse the optional date filter, fetch every key from the source,
    # filter, then either print a dry-run table or generate each key on the
    # destination (the proxy configured on the CLI context). Any failure
    # outside the per-key loop is reported and ends in click.Abort.
    created_since_dt = _parse_created_since_filter(created_since)

    source_client = KeysManagementClient(source_base_url, source_api_key)
    dest_client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
    try:
        source_keys = _fetch_all_keys_with_pagination(source_client, source_base_url)
        if created_since:
            source_keys = _filter_keys_by_created_since(
                source_keys, created_since_dt, created_since
            )
        if not source_keys:
            click.echo("No keys found in source instance.")
            return
        click.echo(f"Found {len(source_keys)} keys in source instance.")

        if dry_run:
            _display_dry_run_table(source_keys)
            return

        imported_count, failed_count = _import_keys_to_destination(
            source_keys, dest_client
        )

        click.echo("\nImport completed:")
        click.echo(f"  Successfully imported: {imported_count}")
        click.echo(f"  Failed to import: {failed_count}")
        click.echo(f"  Total keys processed: {len(source_keys)}")
    except requests.exceptions.HTTPError as e:
        click.echo(f"Error: HTTP {e.response.status_code}", err=True)
        try:
            error_body = e.response.json()
            rich.print_json(data=error_body)
        except ValueError:
            # Response.json() may raise json's, simplejson's or requests'
            # JSONDecodeError; only ValueError is common to all of them.
            click.echo(e.response.text, err=True)
        raise click.Abort()
    except Exception as e:
        click.echo(f"Error: {str(e)}", err=True)
        raise click.Abort()