// Path: llm-intelligence/scripts/fetch_openrouter.go

// fetch_openrouter.go - OpenRouter 模型数据采集器
// Phase 1 单数据源采集器,抓取模型基础信息与价格信息
package main
import (
"database/sql"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"time"
_ "github.com/lib/pq"
)
// Config holds the collector settings. It is populated by parseArgs and
// passed to run.
type Config struct {
// APIKey is sent as a Bearer token. When it is empty fetchModels returns
// built-in mock data instead of calling the API.
APIKey string
// APIURL is the endpoint that fetchModels requests with GET.
APIURL string
// OutPath is the path of the JSON file produced by writeJSON.
OutPath string
// MaxRetries is the number of additional attempts after a failed request.
MaxRetries int
// TimeoutSec is the HTTP client timeout, in seconds.
TimeoutSec int
// DBConn is the PostgreSQL connection string (added later). When it is
// empty run skips the database write.
DBConn string // e.g. "host=/var/run/postgresql dbname=llm_intelligence sslmode=disable"
}
// APIResponse mirrors the top-level envelope of the OpenRouter model-list
// response (key fields only).
// NOTE(review): nothing in the visible part of this file references this
// type; parseModels decodes the envelope through its own anonymous struct.
// Confirm whether it is still needed.
type APIResponse struct {
// Data is the model list found under the "data" key.
Data []ModelInfo `json:"data"`
}
// ModelInfo is the normalized record kept for one model. It is what
// parseModels produces, what writeJSON emits under "models", and what
// rawPayload serializes for the database.
type ModelInfo struct {
// ID is the model identifier. parseModels drops entries that lack it, and
// an ID ending in ":free" is counted as a free model.
ID string `json:"id"`
// Name is the display name; empty when the source has none.
Name string `json:"name,omitempty"`
// Created is copied from the "created" field.
// Presumably a Unix timestamp — TODO confirm against the API.
Created int64 `json:"created,omitempty"`
// Description is the free-text description of the model.
Description string `json:"description,omitempty"`
// ContextLength is copied from "context_length" (presumably in tokens).
ContextLength int `json:"context_length,omitempty"`
// Capabilities holds the string entries of the "capabilities" array.
Capabilities []string `json:"capabilities,omitempty"`
// Pricing holds the input/output prices. encoding/json never treats a
// struct value as empty, so "pricing" is always emitted despite omitempty.
Pricing ModelPricing `json:"pricing,omitempty"`
}
// ModelPricing carries the input and output price of a model. A zero price is
// omitted from the JSON output because of omitempty.
// NOTE(review): the unit is not established in this file. summarizeDB stores
// these values in columns named *_per_mtok — confirm that the values taken
// from the API use the same unit.
type ModelPricing struct {
// Input is the price for input (prompt) tokens.
Input float64 `json:"input,omitempty"`
// Output is the price for output (completion) tokens.
Output float64 `json:"output,omitempty"`
}
func main() {
cfg := parseArgs()
if err := run(cfg); err != nil {
fmt.Fprintf(os.Stderr, "采集失败: %v\n", err)
os.Exit(1)
}
}
// parseArgs registers the collector's command-line flags, parses the process
// arguments and returns the resulting Config.
//
// The API key comes from -api-key and, when that flag is empty, from the
// OPENROUTER_API_KEY environment variable, so the secret does not have to
// appear on the command line. The fallback is resolved after parsing instead
// of being installed as the flag default; a flag default would be printed in
// the -h usage output. -db defaults to the DATABASE_URL environment variable.
func parseArgs() Config {
	apiKey := flag.String("api-key", "", "OpenRouter API Key(建议通过 OPENROUTER_API_KEY 环境变量注入)")
	apiURL := flag.String("api-url", "https://openrouter.ai/api/v1/models", "API 地址")
	outPath := flag.String("out", "models.json", "输出文件路径")
	maxRetries := flag.Int("retry", 3, "最大重试次数")
	timeoutSec := flag.Int("timeout", 30, "请求超时(秒)")
	dbConn := flag.String("db", os.Getenv("DATABASE_URL"), "PostgreSQL 连接字符串(默认从 DATABASE_URL 环境变量读取)")
	flag.Parse()

	key := *apiKey
	if key == "" {
		key = os.Getenv("OPENROUTER_API_KEY")
	}

	return Config{
		APIKey:     key,
		APIURL:     *apiURL,
		OutPath:    *outPath,
		MaxRetries: *maxRetries,
		TimeoutSec: *timeoutSec,
		DBConn:     *dbConn,
	}
}
// run performs one collection pass: it fetches the model list, stores it in
// PostgreSQL when a connection string is configured, and always writes the
// JSON file at cfg.OutPath.
//
// A database failure is reported on stderr and otherwise ignored, so the
// pass degrades to JSON-only output. The returned error comes from the fetch
// or from the JSON write.
//
// NOTE(review): when no API key is configured fetchModels returns mock
// entries, and those are also sent to PostgreSQL here if DBConn is set.
// Confirm that this is intended.
func run(cfg Config) error {
	models, fetchErr := fetchModels(cfg)
	if fetchErr != nil {
		return fetchErr
	}

	if dsn := cfg.DBConn; dsn != "" {
		dbErr := summarizeDB(dsn, models)
		if dbErr != nil {
			fmt.Fprintf(os.Stderr, "警告: PostgreSQL 写入失败: %v\n", dbErr)
			fmt.Fprintln(os.Stderr, "降级为仅写入 JSON")
		}
	}

	return summarize(cfg.OutPath, models)
}
// fetchModels returns the OpenRouter model list.
//
// Without an API key it prints a warning to stdout and returns two built-in
// mock entries without touching the network; persisting them is left to the
// caller. Otherwise it sends a GET to cfg.APIURL with a Bearer token and
// decodes the body with parseModels.
//
// Transport-level failures are retried up to cfg.MaxRetries times, pausing
// 1s, 2s, ... between attempts. A negative cfg.MaxRetries is treated as zero
// so that exactly one request is still made. A non-200 status is not retried;
// it is returned as an error carrying the status code and the first 4 KiB of
// the response body.
func fetchModels(cfg Config) ([]ModelInfo, error) {
	if cfg.APIKey == "" {
		fmt.Println("警告: 未提供 API Key使用模拟数据")
		return []ModelInfo{
			{ID: "openai/gpt-4o", ContextLength: 128000,
				Pricing: ModelPricing{Input: 2.5, Output: 10.0}},
			{ID: "anthropic/claude-3.5-sonnet:free", ContextLength: 200000,
				Pricing: ModelPricing{}},
		}, nil
	}

	// A zero Timeout means the client applies no timeout at all.
	client := &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second}
	req, err := http.NewRequest(http.MethodGet, cfg.APIURL, nil)
	if err != nil {
		return nil, fmt.Errorf("构造请求失败: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
	// The request has no body, so the relevant negotiation header is Accept.
	req.Header.Set("Accept", "application/json")

	// Without this clamp a negative retry count would skip the loop entirely,
	// leaving resp nil with a nil error, and the Close below would panic.
	retries := cfg.MaxRetries
	if retries < 0 {
		retries = 0
	}
	var resp *http.Response
	for attempt := 0; ; attempt++ {
		resp, err = client.Do(req)
		if err == nil || attempt >= retries {
			break
		}
		time.Sleep(time.Duration(attempt+1) * time.Second)
	}
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Bound what ends up in the error text; the body is diagnostic only,
		// so a read error here is deliberately ignored.
		const maxErrBody = 4 << 10
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		return nil, fmt.Errorf("非 200 响应: %d %s", resp.StatusCode, string(snippet))
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}

	// Tolerant decoding: missing fields and unexpected types are absorbed by
	// parseModels instead of failing the whole response.
	models, err := parseModels(body)
	if err != nil {
		return nil, fmt.Errorf("JSON 解析失败: %w", err)
	}
	// TODO: normalize fields (OpenRouter id -> canonical vendor and model name).
	return models, nil
}
// parseModels decodes a model-list response of the form {"data": [...]} into
// ModelInfo values while tolerating missing fields and unexpected types.
//
// Array elements that are not JSON objects, and objects without a non-empty
// string "id", are skipped silently. Every other field falls back to its zero
// value when absent or of the wrong type. An error is returned only when the
// envelope or the "data" array itself cannot be decoded.
func parseModels(raw []byte) ([]ModelInfo, error) {
	var envelope struct {
		Data json.RawMessage `json:"data"`
	}
	if err := json.Unmarshal(raw, &envelope); err != nil {
		return nil, fmt.Errorf("解析 data 字段失败: %w", err)
	}

	// Elements may differ in shape, so they are decoded generically first.
	var entries []any
	if err := json.Unmarshal(envelope.Data, &entries); err != nil {
		return nil, fmt.Errorf("解析模型数组失败: %w", err)
	}

	out := make([]ModelInfo, 0, len(entries))
	for i := range entries {
		fields, isObject := entries[i].(map[string]any)
		if !isObject {
			continue // not a JSON object
		}
		id := getString(fields, "id")
		if id == "" {
			continue // id is mandatory
		}

		info := ModelInfo{
			ID:            id,
			Name:          getString(fields, "name"),
			Created:       getInt64(fields, "created"),
			Description:   getString(fields, "description"),
			ContextLength: getInt(fields, "context_length"),
		}

		// Prices are looked up under two alternative key names each.
		if pricing, hasPricing := fields["pricing"].(map[string]any); hasPricing {
			info.Pricing = ModelPricing{
				Input:  getPrice(pricing, "input", "prompt"),
				Output: getPrice(pricing, "output", "completion"),
			}
		}

		// Only string elements of the capabilities array are kept.
		if list, isList := fields["capabilities"].([]any); isList {
			for _, entry := range list {
				if name, isString := entry.(string); isString {
					info.Capabilities = append(info.Capabilities, name)
				}
			}
		}

		out = append(out, info)
	}
	return out, nil
}
// getString returns m[key] when it holds a string, and "" when the key is
// absent or holds a value of any other type.
func getString(m map[string]any, key string) string {
	// A failed assertion leaves s at its zero value, the empty string.
	s, _ := m[key].(string)
	return s
}
// getInt returns m[key] converted to int when it holds a float64 (the type
// encoding/json uses for numbers decoded into any), and 0 otherwise. The
// conversion truncates toward zero.
func getInt(m map[string]any, key string) int {
	// A failed assertion leaves f at 0, which converts to the 0 fallback.
	f, _ := m[key].(float64)
	return int(f)
}
// getInt64 returns m[key] converted to int64 when it holds a float64 (the
// type encoding/json uses for numbers decoded into any), and 0 otherwise. The
// conversion truncates toward zero.
func getInt64(m map[string]any, key string) int64 {
	// A failed assertion leaves f at 0, which converts to the 0 fallback.
	f, _ := m[key].(float64)
	return int64(f)
}
// getPrice returns the first usable price found in m under keys, which are
// tried in the order given; it returns 0 when none of them yields one.
//
// A value is usable when it is a JSON number (float64 or json.Number) or a
// string holding a finite decimal number such as "0.000003". The string form
// matters because an API may serialize prices as strings to preserve
// precision; accepting only float64 would turn every such price into 0.
// Values of any other type, unparsable strings, and NaN/Inf are skipped so
// the next key gets a chance. Nested objects are not traversed.
func getPrice(m map[string]any, keys ...string) float64 {
	for _, k := range keys {
		var num json.Number
		switch v := m[k].(type) {
		case float64:
			return v
		case string:
			num = json.Number(v)
		case json.Number:
			num = v
		default:
			continue
		}
		// The parser accepts "NaN" and "Inf", which encoding/json cannot
		// encode later on. f-f is zero only for finite values, which rejects
		// both.
		if f, err := num.Float64(); err == nil && f-f == 0 {
			return f
		}
	}
	return 0
}
// summarize writes the collected models and their summary to the JSON file
// at outPath by delegating to writeJSON. It is kept as a separate entry
// point for backward compatibility.
func summarize(outPath string, models []ModelInfo) error {
return writeJSON(outPath, models)
}
// summarizeDB stores the collected models in PostgreSQL: one upsert into the
// models table and one into the model_prices table per model, all inside a
// single transaction.
//
// A models row is keyed by external_id. A model_prices row is keyed by
// (model_id, source, currency, effective_date), with today's date as the
// effective date, so repeated runs on the same day overwrite that day's
// price. Any failure returns an error and the deferred rollback discards
// everything written so far. On success a one-line summary is printed to
// stdout.
//
// NOTE(review): the prices are stored unchanged in columns named *_per_mtok.
// Confirm that the values produced upstream already use that unit.
func summarizeDB(connStr string, models []ModelInfo) error {
	const upsertModel = `
INSERT INTO models (source, external_id, name, description, context_length, capabilities, created_at_source, is_free, status, raw_payload, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (external_id) DO UPDATE SET
name = EXCLUDED.name,
description = EXCLUDED.description,
context_length = EXCLUDED.context_length,
capabilities = EXCLUDED.capabilities,
created_at_source = EXCLUDED.created_at_source,
is_free = EXCLUDED.is_free,
status = EXCLUDED.status,
raw_payload = EXCLUDED.raw_payload,
updated_at = $12
RETURNING id
`
	const upsertPrice = `
INSERT INTO model_prices (model_id, source, currency, input_price_per_mtok, output_price_per_mtok, effective_date, source_url, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (model_id, source, currency, effective_date) DO UPDATE SET
input_price_per_mtok = EXCLUDED.input_price_per_mtok,
output_price_per_mtok = EXCLUDED.output_price_per_mtok,
created_at = EXCLUDED.created_at
`

	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return fmt.Errorf("连接数据库失败: %w", err)
	}
	defer db.Close()

	// sql.Open does not necessarily connect; Ping surfaces a bad DSN here.
	if err = db.Ping(); err != nil {
		return fmt.Errorf("ping 数据库失败: %w", err)
	}

	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("开启事务失败: %w", err)
	}
	// After a successful Commit this rollback is a no-op.
	defer tx.Rollback()

	// One timestamp and one effective date are shared by every row of the run.
	now := time.Now()
	effectiveDate := now.Format("2006-01-02")

	for i := range models {
		m := models[i]
		idLen := len(m.ID)
		isFree := idLen > 5 && m.ID[idLen-5:] == ":free"

		// Upsert the model and read back its primary key for the price row.
		var modelID int64
		row := tx.QueryRow(upsertModel,
			"openrouter", m.ID, m.Name, m.Description, m.ContextLength,
			jsonCapabilities(m.Capabilities), m.Created, isFree, "active",
			rawPayload(m), now, now)
		if scanErr := row.Scan(&modelID); scanErr != nil {
			return fmt.Errorf("写入 models 失败 (%s): %w", m.ID, scanErr)
		}

		// Upsert today's price for that model.
		if _, execErr := tx.Exec(upsertPrice,
			modelID, "openrouter", "USD", m.Pricing.Input, m.Pricing.Output,
			effectiveDate, "https://openrouter.ai/api/v1/models", now); execErr != nil {
			return fmt.Errorf("写入 model_prices 失败 (%s): %w", m.ID, execErr)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}

	// Every model produced exactly one row in each table, otherwise the
	// function would already have returned.
	fmt.Printf("PostgreSQL 写入完成: %d models, %d prices\n", len(models), len(models))
	return nil
}
// jsonCapabilities encodes caps as a JSON array. A nil or empty slice yields
// "[]" rather than "null", so the result is always a JSON array.
func jsonCapabilities(caps []string) []byte {
	if len(caps) > 0 {
		// The marshal error is discarded, as it always has been here.
		encoded, _ := json.Marshal(caps)
		return encoded
	}
	return []byte("[]")
}
// rawPayload returns m serialized as JSON, for the raw_payload column written
// by summarizeDB. It re-serializes the normalized ModelInfo, not the bytes
// originally received from the API. The marshal error is discarded, so a
// failed encoding yields a nil slice.
func rawPayload(m ModelInfo) []byte {
b, _ := json.Marshal(m)
return b
}
// writeJSON prints a one-line summary to stdout and writes the models,
// together with that summary, to a JSON file at outPath. An existing file is
// truncated.
//
// The document has the keys "generated_at" (RFC 3339), "total", "free",
// "paid" and "models". A model counts as free when its ID ends in ":free";
// otherwise it counts as paid when either price is positive. Models that
// match neither rule appear only in the total.
//
// The file is closed explicitly and the close error is returned. A write
// that only fails when the file is flushed on close would otherwise be
// reported as success.
func writeJSON(outPath string, models []ModelInfo) error {
	const freeSuffix = ":free"

	total := len(models)
	var freeCnt, paidCnt int
	for _, m := range models {
		switch {
		case len(m.ID) > len(freeSuffix) && m.ID[len(m.ID)-len(freeSuffix):] == freeSuffix:
			freeCnt++
		case m.Pricing.Input > 0 || m.Pricing.Output > 0:
			paidCnt++
		}
	}
	fmt.Printf("采集完成: 共 %d 模型(免费 %d / 付费 %d\n", total, freeCnt, paidCnt)

	out, err := os.Create(outPath)
	if err != nil {
		return fmt.Errorf("创建输出文件失败: %w", err)
	}

	enc := json.NewEncoder(out)
	enc.SetIndent("", " ")
	encodeErr := enc.Encode(map[string]any{
		"generated_at": time.Now().Format(time.RFC3339),
		"total":        total,
		"free":         freeCnt,
		"paid":         paidCnt,
		"models":       models,
	})
	// Close before reporting so the descriptor is released on every path.
	// The encode error takes precedence because it is the root cause.
	closeErr := out.Close()
	if encodeErr != nil {
		return fmt.Errorf("写入 JSON 失败: %w", encodeErr)
	}
	if closeErr != nil {
		return fmt.Errorf("关闭输出文件失败: %w", closeErr)
	}

	fmt.Printf("结果已写入: %s\n", outPath)
	return nil
}