Files
llm-intelligence/scripts/fetch_openrouter.go
phamnazage-jpg f5b373caf4
Some checks failed
CI / go-test (push) Has been cancelled
CI / scripts-regression (push) Has been cancelled
CI / frontend-build (push) Has been cancelled
CI / docker-build (push) Has been cancelled
feat(report): improve daily intelligence UX and price tracking
2026-05-27 17:23:08 +08:00

659 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build llm_script
// fetch_openrouter.go - OpenRouter model data collector, v2.0.
// Sprint 2 enhanced version: exponential-backoff retry + batched inserts + ProviderMapper + audit_log + price-change detection + slog.
package main
import (
	"bufio"
	"context"
	"database/sql"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log/slog"
	"math"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"llm-intelligence/internal/collectors"
	"llm-intelligence/internal/retry"

	_ "github.com/lib/pq"
)
// Config holds the collector settings. parseArgs fills every field from a
// command-line flag of the name shown in parentheses.
type Config struct {
APIKey string // OpenRouter API key (-api-key); when empty, fetchModels returns mock data unless StrictReal is set
APIURL string // URL of the models endpoint (-api-url)
OutPath string // path of the JSON summary file (-out)
MaxRetries int // maximum retry count handed to the retry strategy (-retry)
TimeoutSec int // HTTP client timeout in seconds (-timeout)
BatchSize int // number of models written per database transaction (-batch)
DBConn string // PostgreSQL connection string (-db, defaults to $DATABASE_URL); empty skips all database writes
StrictReal bool // strict mode (-strict-real): a missing API key or a failed database write is returned as an error
}
// ModelInfo is one model record as produced by parseModels. The same struct is
// serialized into the JSON output file and into the raw_payload column.
// NOTE(review): the original comment states it is kept compatible with the
// collectors package; that package is not visible from this file — confirm.
type ModelInfo struct {
ID string `json:"id"` // model identifier, e.g. "openai/gpt-4o"; records with an empty ID are dropped by parseModels
Name string `json:"name,omitempty"` // display name taken from the "name" key
Created int64 `json:"created,omitempty"` // value of the "created" key; presumably a Unix timestamp — TODO confirm
Description string `json:"description,omitempty"` // free-text description taken from the "description" key
ContextLength int `json:"context_length,omitempty"` // value of the "context_length" key
Capabilities []string `json:"capabilities,omitempty"` // string entries of the "capabilities" array, used by deriveModality
Pricing ModelPricing `json:"pricing,omitempty"` // input/output prices; the zero value is treated as "free" when writing to the database
}
// ModelPricing carries the two prices tracked per model. They are stored in the
// input_price_per_mtok / output_price_per_mtok columns.
// NOTE(review): the unit of the upstream values is not established in this
// file — confirm it matches the per-million-token column names.
type ModelPricing struct {
Input float64 `json:"input,omitempty"` // read from the "input" key, falling back to "prompt"
Output float64 `json:"output,omitempty"` // read from the "output" key, falling back to "completion"
}
// Package-level state shared by the functions in this file.
var (
collectorVersion = "v2.0" // logged at startup and written to models.collector_version
logger *slog.Logger // structured logger; assigned in init
)
// init installs the package logger: JSON-formatted records written to stderr
// at Info level and above.
func init() {
	opts := slog.HandlerOptions{Level: slog.LevelInfo}
	handler := slog.NewJSONHandler(os.Stderr, &opts)
	logger = slog.New(handler)
}
// main parses the configuration, runs one collection pass, records the outcome
// in collector_stats when a database connection is configured, and exits with
// status 1 if the pass failed.
func main() {
	cfg := parseArgs()
	begin := time.Now()
	logger.Info("采集器启动", "collector", "openrouter", "version", collectorVersion, "batch_size", cfg.BatchSize)

	runErr := run(cfg)
	if runErr != nil {
		logger.Error("采集失败", "error", runErr, "duration", time.Since(begin))
	}
	elapsed := time.Since(begin)

	// Statistics are written for failed runs too, so this happens before the
	// exit below.
	if cfg.DBConn != "" {
		statsErr := recordCollectorStats(cfg.DBConn, runErr, elapsed)
		if statsErr != nil {
			logger.Warn("采集统计写入失败", "error", statsErr)
		}
	}

	if runErr != nil {
		os.Exit(1)
	}
	logger.Info("采集完成", "collector", "openrouter", "duration_ms", elapsed.Milliseconds())
}
// parseArgs loads the project .env files and then parses the command-line
// flags into a Config. The env files are loaded first because the -db flag
// takes its default from $DATABASE_URL.
func parseArgs() Config {
	loadProjectEnv()

	var cfg Config
	flag.StringVar(&cfg.APIKey, "api-key", "", "OpenRouter API Key")
	flag.StringVar(&cfg.APIURL, "api-url", "https://openrouter.ai/api/v1/models", "API 地址")
	flag.StringVar(&cfg.OutPath, "out", "models.json", "输出文件路径")
	flag.IntVar(&cfg.MaxRetries, "retry", 3, "最大重试次数")
	flag.IntVar(&cfg.TimeoutSec, "timeout", 30, "请求超时(秒)")
	flag.IntVar(&cfg.BatchSize, "batch", 100, "批量插入批次大小")
	flag.StringVar(&cfg.DBConn, "db", os.Getenv("DATABASE_URL"), "PostgreSQL 连接字符串")
	flag.BoolVar(&cfg.StrictReal, "strict-real", false, "严格真实模式:缺少 API Key 或数据库写入失败时返回错误")
	flag.Parse()

	return cfg
}
// loadProjectEnv loads ".env.local" and then ".env" from the working
// directory. loadEnvFile never overwrites a variable that is already set, so
// values from ".env.local" take precedence over ".env".
func loadProjectEnv() {
	loadEnvFile(".env.local")
	loadEnvFile(".env")
}
// loadEnvFile reads KEY=VALUE lines from path and exports them into the
// process environment. It is best-effort: a file that cannot be opened is
// ignored. Blank lines, lines starting with "#", lines without "=", and lines
// with an empty key are skipped. Surrounding whitespace and any leading or
// trailing quote characters (" or ') are stripped from the value. Variables
// that already exist in the environment are left untouched.
func loadEnvFile(path string) {
	file, err := os.Open(path)
	if err != nil {
		return
	}
	defer file.Close()

	for lines := bufio.NewScanner(file); lines.Scan(); {
		entry := strings.TrimSpace(lines.Text())
		if entry == "" || strings.HasPrefix(entry, "#") {
			continue
		}

		// Split on the first "=" only, so values may themselves contain "=".
		sep := strings.Index(entry, "=")
		if sep < 0 {
			continue
		}
		name := strings.TrimSpace(entry[:sep])
		if name == "" {
			continue
		}
		if _, alreadySet := os.LookupEnv(name); alreadySet {
			continue
		}

		val := strings.Trim(strings.TrimSpace(entry[sep+1:]), `"'`)
		_ = os.Setenv(name, val)
	}
}
// run performs one collection pass: fetch the model list, write it to
// PostgreSQL when a connection string is configured, then write the JSON
// summary file. A database failure is fatal only in strict mode; otherwise the
// run degrades to JSON-only output.
func run(cfg Config) error {
	models, err := fetchModels(cfg)
	if err != nil {
		return err
	}
	logger.Info("API 数据获取完成", "records", len(models))

	if cfg.DBConn != "" {
		dbErr := summarizeDB(cfg.DBConn, models, cfg.BatchSize)
		switch {
		case dbErr == nil:
			logger.Info("PostgreSQL 写入完成", "records", len(models))
		default:
			logger.Error("PostgreSQL 写入失败", "error", dbErr)
			if cfg.StrictReal {
				return fmt.Errorf("PostgreSQL 写入失败: %w", dbErr)
			}
			logger.Warn("降级为仅写入 JSON")
		}
	}

	return summarize(cfg.OutPath, models)
}
// fetchModels downloads the model list from cfg.APIURL, retrying through the
// retry package with exponential backoff and jitter (1s base delay, 30s cap,
// multiplier 2).
//
// Without an API key it returns two hard-coded mock models, or an error when
// cfg.StrictReal is set. On failure the returned error wraps the underlying
// cause of the last attempt.
func fetchModels(cfg Config) ([]ModelInfo, error) {
	if cfg.APIKey == "" {
		if cfg.StrictReal {
			return nil, fmt.Errorf("严格真实模式下必须提供 API Key")
		}
		logger.Warn("未提供 API Key使用模拟数据")
		return []ModelInfo{
			{ID: "openai/gpt-4o", ContextLength: 128000, Pricing: ModelPricing{Input: 2.5, Output: 10.0}},
			{ID: "anthropic/claude-3.5-sonnet:free", ContextLength: 200000, Pricing: ModelPricing{}},
		}, nil
	}

	strategy := retry.Strategy{
		MaxRetries: cfg.MaxRetries,
		BaseDelay:  1 * time.Second,
		MaxDelay:   30 * time.Second,
		Multiplier: 2.0,
		Jitter:     true,
		Retryable:  retry.IsRetryable,
	}

	ctx := context.Background()
	// One client for all attempts so the transport's connections can be reused
	// across retries instead of being rebuilt each time.
	client := &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second}

	var models []ModelInfo
	var lastErr error // unwrapped cause of the most recent failed attempt
	err := retry.Do(ctx, strategy, func() error {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, cfg.APIURL, nil)
		if err != nil {
			lastErr = err
			return fmt.Errorf("构造请求失败: %w", err)
		}
		req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
		req.Header.Set("Content-Type", "application/json")

		resp, err := client.Do(req)
		if err != nil {
			lastErr = err
			return fmt.Errorf("请求失败: %w", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			// The body is read best-effort, only to enrich the error.
			body, _ := io.ReadAll(resp.Body)
			lastErr = retry.HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)}
			return lastErr
		}

		body, err := io.ReadAll(resp.Body)
		if err != nil {
			lastErr = err
			return fmt.Errorf("读取响应失败: %w", err)
		}

		parsed, err := parseModels(body)
		if err != nil {
			lastErr = err
			return fmt.Errorf("JSON 解析失败: %w", err)
		}
		models = parsed
		return nil
	})
	if err != nil {
		// retry.Do may fail without the callback having recorded a cause;
		// never wrap a nil error, which would render as "%!w(<nil>)".
		if lastErr == nil {
			lastErr = err
		}
		return nil, fmt.Errorf("采集失败(%d次尝试: %w", strategy.MaxRetries+1, lastErr)
	}
	return models, nil
}
// parseModels decodes an API response of the form {"data": [...]} into a slice
// of ModelInfo. Array entries that are not JSON objects, or that lack a
// non-empty string "id", are skipped. It returns an error when the envelope or
// the "data" array cannot be decoded.
func parseModels(raw []byte) ([]ModelInfo, error) {
	var envelope struct {
		Data json.RawMessage `json:"data"`
	}
	if err := json.Unmarshal(raw, &envelope); err != nil {
		return nil, fmt.Errorf("解析 data 字段失败: %w", err)
	}

	// Entries are decoded generically so that one malformed element does not
	// reject the whole response.
	var entries []any
	if err := json.Unmarshal(envelope.Data, &entries); err != nil {
		return nil, fmt.Errorf("解析模型数组失败: %w", err)
	}

	result := make([]ModelInfo, 0, len(entries))
	for idx := range entries {
		fields, isObject := entries[idx].(map[string]any)
		if !isObject {
			continue
		}
		id := getString(fields, "id")
		if id == "" {
			continue
		}

		info := ModelInfo{
			ID:            id,
			Name:          getString(fields, "name"),
			Description:   getString(fields, "description"),
			ContextLength: getInt(fields, "context_length"),
			Created:       getInt64(fields, "created"),
		}
		if pricing, found := fields["pricing"].(map[string]any); found {
			info.Pricing = ModelPricing{
				Input:  getPrice(pricing, "input", "prompt"),
				Output: getPrice(pricing, "output", "completion"),
			}
		}
		if list, found := fields["capabilities"].([]any); found {
			for _, entry := range list {
				if text, isString := entry.(string); isString {
					info.Capabilities = append(info.Capabilities, text)
				}
			}
		}
		result = append(result, info)
	}
	return result, nil
}
// deriveModality classifies a model as "multimodal", "audio", "video", "code"
// or "text". Declared capabilities are checked first, in order, and the first
// capability that matches a keyword decides. Otherwise keywords are searched
// in the lower-cased ID, name and description. All matching is by substring.
func deriveModality(model ModelInfo) string {
	containsAny := func(text string, needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(text, needle) {
				return true
			}
		}
		return false
	}

	for _, capability := range model.Capabilities {
		lowered := strings.ToLower(capability)
		if containsAny(lowered, "vision", "image") {
			return "multimodal"
		}
		if strings.Contains(lowered, "audio") {
			return "audio"
		}
		if strings.Contains(lowered, "video") {
			return "video"
		}
		if strings.Contains(lowered, "code") {
			return "code"
		}
	}

	hints := strings.ToLower(model.ID + " " + model.Name + " " + model.Description)

	// "video" combined with omni/vision/multimodal also yields "multimodal";
	// that combination is already covered by the keyword list below.
	if containsAny(hints, "vision", "image", "vl", "omni", "multimodal") {
		return "multimodal"
	}
	if containsAny(hints, "audio", "speech", "voice") {
		return "audio"
	}
	if strings.Contains(hints, "video") {
		return "video"
	}
	if strings.Contains(hints, "code") {
		return "code"
	}
	return "text"
}
// getString returns m[key] when it holds a string, and "" when the key is
// missing or holds any other type.
func getString(m map[string]any, key string) string {
	text, _ := m[key].(string)
	return text
}
// getInt returns m[key] converted to int when it holds a float64 (the type
// encoding/json uses for numbers decoded into any), and 0 otherwise.
func getInt(m map[string]any, key string) int {
	number, isNumber := m[key].(float64)
	if !isNumber {
		return 0
	}
	return int(number)
}
// getInt64 returns m[key] converted to int64 when it holds a float64, and 0
// when the key is missing or holds any other type.
func getInt64(m map[string]any, key string) int64 {
	switch number := m[key].(type) {
	case float64:
		return int64(number)
	default:
		return 0
	}
}
// getPrice returns the first usable price found in m under the given keys,
// tried in order, or 0 when none of them holds one.
//
// A value is usable when it is a JSON number (float64) or a string containing
// a finite decimal number. The OpenRouter models endpoint encodes prices as
// strings, so accepting only float64 would read every real price as 0.
// A string that does not parse, or parses to NaN/±Inf, is skipped and the
// next key is tried.
func getPrice(m map[string]any, keys ...string) float64 {
	for _, k := range keys {
		switch v := m[k].(type) {
		case float64:
			return v
		case string:
			parsed, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
			if err != nil || math.IsNaN(parsed) || math.IsInf(parsed, 0) {
				continue
			}
			return parsed
		}
	}
	return 0
}
// summarize writes the collected models to the JSON file at outPath by
// delegating to writeJSON.
func summarize(outPath string, models []ModelInfo) error {
	if err := writeJSON(outPath, models); err != nil {
		return err
	}
	return nil
}
// summarizeDB persists the collected models to PostgreSQL in batches of
// batchSize, one transaction per batch. For every model it:
//
//   - maps the OpenRouter ID to a provider and finds or creates the
//     model_provider row,
//   - upserts the models row (keyed by external_id),
//   - appends an audit_log entry,
//   - upserts today's region_pricing row,
//   - inserts a pricing_history row when either price moved by more than 5%
//     relative to the most recent prices stored for the OpenRouter operator.
//
// It returns an error when the database cannot be reached, a transaction
// cannot be started or committed, or a models / region_pricing write fails.
// In the latter case the current batch is rolled back; batches that were
// already committed stay in place. A non-positive batchSize falls back to 100.
func summarizeDB(connStr string, models []ModelInfo, batchSize int) error {
	const sourceURL = "https://openrouter.ai/api/v1/models"

	// With batchSize == 0 the batch loop would never advance, and a negative
	// value would slice out of range; fall back to the -batch flag default.
	if batchSize <= 0 {
		batchSize = 100
	}

	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return fmt.Errorf("连接数据库失败: %w", err)
	}
	defer db.Close()
	if err := db.Ping(); err != nil {
		return fmt.Errorf("ping 数据库失败: %w", err)
	}

	batchID := fmt.Sprintf("batch-%d", time.Now().Unix())
	now := time.Now()
	effectiveDate := now.Format("2006-01-02")

	// Look up the default operator (OpenRouter).
	// NOTE(review): the log message says NULL, but the value actually written
	// is 0 — confirm which one the schema expects.
	var operatorID int64
	err = db.QueryRow("SELECT id FROM operator WHERE name = 'OpenRouter' LIMIT 1").Scan(&operatorID)
	if err != nil {
		logger.Warn("未找到 OpenRouter operator使用 NULL", "error", err)
		operatorID = 0
	}

	// Previous prices, used for change detection below.
	lastPrices := loadLastPrices(db, operatorID)

	insertedModels := 0
	insertedPrices := 0
	priceChanges := 0

	// Process the models batch by batch.
	for i := 0; i < len(models); i += batchSize {
		end := i + batchSize
		if end > len(models) {
			end = len(models)
		}
		batch := models[i:end]

		tx, err := db.Begin()
		if err != nil {
			return fmt.Errorf("开启事务失败: %w", err)
		}

		for _, m := range batch {
			// Map the OpenRouter ID to a provider; unknown IDs get a
			// placeholder provider instead of failing the batch.
			mapping, err := collectors.MapOpenRouterID(m.ID)
			if err != nil {
				logger.Warn("Provider 映射失败", "id", m.ID, "error", err)
				mapping = collectors.ModelMapping{
					Provider:  collectors.ProviderInfo{ID: "unknown", Name: "Unknown"},
					ModelName: m.Name,
					RawID:     m.ID,
					IsFree:    false,
				}
			}

			// Find or create the provider row.
			var providerID int64
			err = tx.QueryRow("SELECT id FROM model_provider WHERE name = $1 LIMIT 1", mapping.Provider.Name).Scan(&providerID)
			if err != nil {
				// Provider not present yet: insert it.
				err = tx.QueryRow(`
INSERT INTO model_provider (name, name_cn, country, status)
VALUES ($1, $2, $3, 'active')
ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name
RETURNING id
`, mapping.Provider.Name, mapping.Provider.NameCN, mapping.Provider.Country).Scan(&providerID)
				if err != nil {
					logger.Warn("创建 provider 失败", "name", mapping.Provider.Name, "error", err)
					providerID = 0
				}
			}

			// A model is free when the mapper says so or both prices are zero.
			isFree := mapping.IsFree || (m.Pricing.Input == 0 && m.Pricing.Output == 0)

			// Upsert the models row.
			var modelID int64
			err = tx.QueryRow(`
INSERT INTO models (
source, external_id, name, description, context_length,
capabilities, created_at_source, is_free, status,
raw_payload, provider_id, version, modality,
data_confidence, retrieved_at, batch_id, collector_version,
source_url, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $19)
ON CONFLICT (external_id) DO UPDATE SET
name = EXCLUDED.name,
description = EXCLUDED.description,
context_length = EXCLUDED.context_length,
capabilities = EXCLUDED.capabilities,
created_at_source = EXCLUDED.created_at_source,
is_free = EXCLUDED.is_free,
status = EXCLUDED.status,
raw_payload = EXCLUDED.raw_payload,
provider_id = EXCLUDED.provider_id,
data_confidence = 'official',
retrieved_at = EXCLUDED.retrieved_at,
batch_id = EXCLUDED.batch_id,
collector_version = EXCLUDED.collector_version,
updated_at = EXCLUDED.updated_at
RETURNING id
`,
				"openrouter", m.ID, m.Name, m.Description, m.ContextLength,
				jsonCapabilities(m.Capabilities), m.Created, isFree, "active",
				rawPayload(m), providerID, "", deriveModality(m),
				"official", now, batchID, collectorVersion,
				sourceURL, now).Scan(&modelID)
			if err != nil {
				_ = tx.Rollback()
				return fmt.Errorf("写入 models 失败 (%s): %w", m.ID, err)
			}
			insertedModels++

			// Audit trail is best-effort, but the failure is logged: once a
			// statement fails PostgreSQL aborts the transaction, and without
			// this log only the follow-up error would be visible.
			if _, err := tx.Exec(`
INSERT INTO audit_log (table_name, record_id, field_name, old_value, new_value, operation, operator, batch_id, source_url)
VALUES ('models', $1, 'external_id', NULL, $2, 'INSERT', 'fetch_openrouter', $3, $4)
`, modelID, m.ID, batchID, sourceURL); err != nil {
				logger.Warn("写入 audit_log 失败", "id", m.ID, "error", err)
			}

			// Upsert the region_pricing row (replaces the old model_prices table).
			sourceType := "reseller"
			freeQuota := ""
			freeLimitations := "[]"
			rateLimit := "{}"
			if isFree {
				sourceType = "free_tier"
				freeQuota = "Imported free-tier pricing entry"
				freeLimitations = `["See source_url for current quota and policy"]`
			}
			var pricingID int64
			err = tx.QueryRow(`
INSERT INTO region_pricing (
model_id, operator_id, region, currency,
input_price_per_mtok, output_price_per_mtok,
is_free, effective_date, source_url, source_type,
free_quota, free_limitations, rate_limit,
created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $14)
ON CONFLICT (model_id, operator_id, region, currency, effective_date) DO UPDATE SET
input_price_per_mtok = EXCLUDED.input_price_per_mtok,
output_price_per_mtok = EXCLUDED.output_price_per_mtok,
is_free = EXCLUDED.is_free,
source_type = EXCLUDED.source_type,
free_quota = EXCLUDED.free_quota,
free_limitations = EXCLUDED.free_limitations,
rate_limit = EXCLUDED.rate_limit,
updated_at = EXCLUDED.updated_at
RETURNING id
`, modelID, operatorID, "global", "USD", m.Pricing.Input, m.Pricing.Output,
				isFree, effectiveDate, sourceURL, sourceType,
				freeQuota, freeLimitations, rateLimit, now).Scan(&pricingID)
			if err != nil {
				_ = tx.Rollback()
				return fmt.Errorf("写入 region_pricing 失败 (%s): %w", m.ID, err)
			}
			insertedPrices++

			// Price-change detection: record moves larger than 5%.
			if lastPrice, ok := lastPrices[modelID]; ok {
				inputChange := calcChangePercent(lastPrice.Input, m.Pricing.Input)
				outputChange := calcChangePercent(lastPrice.Output, m.Pricing.Output)
				if abs(inputChange) > 5 || abs(outputChange) > 5 {
					if _, err := tx.Exec(`
INSERT INTO pricing_history (
model_id, region, currency,
old_input_price, new_input_price,
old_output_price, new_output_price,
change_percent, changed_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`, modelID, "global", "USD",
						lastPrice.Input, m.Pricing.Input,
						lastPrice.Output, m.Pricing.Output,
						max(abs(inputChange), abs(outputChange)), now); err != nil {
						logger.Warn("写入 pricing_history 失败", "id", m.ID, "error", err)
					}
					priceChanges++
				}
			}
		}

		if err := tx.Commit(); err != nil {
			return fmt.Errorf("提交事务失败: %w", err)
		}
		logger.Info("批次完成", "batch", i/batchSize+1, "records", len(batch))
	}

	logger.Info("PostgreSQL 写入完成",
		"models", insertedModels,
		"prices", insertedPrices,
		"price_changes", priceChanges,
		"batch_id", batchID)
	return nil
}

// loadLastPrices returns, per model_id, the prices stored for operatorID on
// its most recent effective_date. It is best-effort: on a query or iteration
// failure it logs a warning and returns whatever was read so far, which only
// reduces price-change detection. Rows that cannot be scanned are skipped.
func loadLastPrices(db *sql.DB, operatorID int64) map[int64]ModelPricing {
	lastPrices := make(map[int64]ModelPricing)

	rows, err := db.Query(`
SELECT model_id, input_price_per_mtok, output_price_per_mtok
FROM region_pricing
WHERE operator_id = $1 AND effective_date = (
SELECT MAX(effective_date) FROM region_pricing WHERE operator_id = $1
)
`, operatorID)
	if err != nil {
		logger.Warn("读取上次价格失败", "error", err)
		return lastPrices
	}
	defer rows.Close()

	for rows.Next() {
		var mid int64
		var p ModelPricing
		if err := rows.Scan(&mid, &p.Input, &p.Output); err != nil {
			continue
		}
		lastPrices[mid] = p
	}
	// Next returning false can mean an error rather than end of data.
	if err := rows.Err(); err != nil {
		logger.Warn("读取上次价格中断", "error", err)
	}
	return lastPrices
}
// calcChangePercent returns the signed percentage change from old to new.
// When old is zero the ratio is undefined, so the result is 0 if new is also
// zero and 100 otherwise.
func calcChangePercent(old, new float64) float64 {
	switch {
	case old != 0:
		return ((new - old) / old) * 100
	case new == 0:
		return 0
	default:
		return 100
	}
}
// abs returns the absolute value of v.
func abs(v float64) float64 {
	if v >= 0 {
		return v
	}
	return -v
}
// max returns the larger of a and b; b is returned when they are equal.
// This package-level definition shadows the Go 1.21 built-in max within the
// package.
func max(a, b float64) float64 {
	larger := b
	if a > b {
		larger = a
	}
	return larger
}
// jsonCapabilities encodes caps as a JSON array. A nil or empty slice yields
// "[]" rather than "null", so the stored value is always an array.
func jsonCapabilities(caps []string) []byte {
	if len(caps) > 0 {
		// Marshalling a []string cannot fail, so the error is ignored.
		encoded, _ := json.Marshal(caps)
		return encoded
	}
	return []byte("[]")
}
// rawPayload returns m encoded as JSON, for storage in the raw_payload column.
// A marshalling error is ignored, in which case the result is nil.
func rawPayload(m ModelInfo) []byte {
	payload, _ := json.Marshal(m)
	return payload
}
// writeJSON prints a one-line summary to stdout and writes the models,
// together with the generation time and the total/free/paid counts, as
// indented JSON to outPath (created or truncated).
//
// A model counts as free when its ID ends in ":free", and as paid when either
// price is positive; models matching neither rule are included in total only.
// It returns an error when the file cannot be created, written, or closed.
func writeJSON(outPath string, models []ModelInfo) error {
	total := len(models)
	var freeCnt, paidCnt int
	for _, m := range models {
		switch {
		case len(m.ID) > 5 && strings.HasSuffix(m.ID, ":free"):
			freeCnt++
		case m.Pricing.Input > 0 || m.Pricing.Output > 0:
			paidCnt++
		}
	}
	fmt.Printf("采集完成: 共 %d 模型(免费 %d / 付费 %d\n", total, freeCnt, paidCnt)

	out, err := os.Create(outPath)
	if err != nil {
		return fmt.Errorf("创建输出文件失败: %w", err)
	}

	enc := json.NewEncoder(out)
	enc.SetIndent("", " ")
	encodeErr := enc.Encode(map[string]any{
		"generated_at": time.Now().Format(time.RFC3339),
		"total":        total,
		"free":         freeCnt,
		"paid":         paidCnt,
		"models":       models,
	})
	// Close explicitly instead of deferring: for a file that was just written,
	// a failed Close can be the only report of a write that did not reach disk.
	closeErr := out.Close()
	if encodeErr != nil {
		return fmt.Errorf("写入 JSON 失败: %w", encodeErr)
	}
	if closeErr != nil {
		return fmt.Errorf("关闭输出文件失败: %w", closeErr)
	}

	fmt.Printf("结果已写入: %s\n", outPath)
	return nil
}
// recordCollectorStats inserts one row into collector_stats describing this
// run: whether it succeeded, how long it took in milliseconds, and the error
// message if it failed. The batch_id is derived from the Unix time at the
// moment of this call. It returns any error from opening the database or
// executing the insert.
func recordCollectorStats(connStr string, runErr error, duration time.Duration) error {
	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return err
	}
	defer db.Close()

	var failure string
	if runErr != nil {
		failure = runErr.Error()
	}
	batchID := fmt.Sprintf("batch-%d", time.Now().Unix())

	_, err = db.Exec(`
INSERT INTO collector_stats (source, batch_id, success, duration_ms, error_message, created_at)
VALUES ('openrouter', $1, $2, $3, $4, $5)
`, batchID, runErr == nil, int(duration.Milliseconds()), failure, time.Now())
	return err
}