352 lines
9.9 KiB
Go
352 lines
9.9 KiB
Go
|
|
// fetch_openrouter.go - OpenRouter 模型数据采集器
|
|||
|
|
// Phase 1 单数据源采集器,抓取模型基础信息与价格信息
|
|||
|
|
package main
|
|||
|
|
|
|||
|
|
import (
	"database/sql"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"net/http"
	"os"
	"sort"
	"strconv"
	"time"

	_ "github.com/lib/pq"
)
|
|||
|
|
|
|||
|
|
// Config holds the collector's runtime settings as produced by parseArgs.
type Config struct {
	APIKey     string // OpenRouter API key; when empty, fetchModels returns mock data
	APIURL     string // endpoint queried for the model list
	OutPath    string // path of the JSON output file
	MaxRetries int    // number of additional attempts after a failed request
	TimeoutSec int    // HTTP client timeout, in seconds
	// PostgreSQL connection parameters (newly added).
	// When empty, run skips the database write.
	DBConn string // e.g. "host=/var/run/postgresql dbname=llm_intelligence sslmode=disable"
}
|
|||
|
|
|
|||
|
|
// APIResponse mirrors the OpenRouter API response envelope (key fields only).
// NOTE(review): nothing in the visible code references this type — parseModels
// decodes the envelope through its own anonymous struct; confirm it is still needed.
type APIResponse struct {
	Data []ModelInfo `json:"data"` // model entries found under the top-level "data" key
}
|
|||
|
|
|
|||
|
|
// ModelInfo is one model entry as collected by parseModels and written out by
// writeJSON and summarizeDB.
type ModelInfo struct {
	// ID is the model identifier; parseModels skips entries without one.
	// An ID ending in ":free" is treated as a free model by the writers.
	ID string `json:"id"`
	// Name is the model's "name" field, empty when absent.
	Name string `json:"name,omitempty"`
	// Created is the "created" field truncated to an integer.
	// NOTE(review): presumably a Unix timestamp — confirm against the API.
	Created int64 `json:"created,omitempty"`
	// Description is the model's "description" field, empty when absent.
	Description string `json:"description,omitempty"`
	// ContextLength is the "context_length" field truncated to an integer.
	ContextLength int `json:"context_length,omitempty"`
	// Capabilities holds the string entries of the "capabilities" array;
	// parseModels drops non-string entries.
	Capabilities []string `json:"capabilities,omitempty"`
	// Pricing holds the input/output prices read from the "pricing" object.
	Pricing ModelPricing `json:"pricing,omitempty"`
}
|
|||
|
|
|
|||
|
|
// ModelPricing holds a model's input and output prices.
// NOTE(review): the price unit is not established by this type. summarizeDB
// stores these values in *_per_mtok columns — confirm the API reports prices
// in the same unit.
type ModelPricing struct {
	Input  float64 `json:"input,omitempty"`  // price for input (prompt) tokens
	Output float64 `json:"output,omitempty"` // price for output (completion) tokens
}
|
|||
|
|
|
|||
|
|
func main() {
|
|||
|
|
cfg := parseArgs()
|
|||
|
|
if err := run(cfg); err != nil {
|
|||
|
|
fmt.Fprintf(os.Stderr, "采集失败: %v\n", err)
|
|||
|
|
os.Exit(1)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func parseArgs() Config {
|
|||
|
|
apiKey := flag.String("api-key", "", "OpenRouter API Key(建议通过环境变量注入)")
|
|||
|
|
apiURL := flag.String("api-url", "https://openrouter.ai/api/v1/models", "API 地址")
|
|||
|
|
outPath := flag.String("out", "models.json", "输出文件路径")
|
|||
|
|
maxRetries := flag.Int("retry", 3, "最大重试次数")
|
|||
|
|
timeoutSec := flag.Int("timeout", 30, "请求超时(秒)")
|
|||
|
|
dbConn := flag.String("db", os.Getenv("DATABASE_URL"), "PostgreSQL 连接字符串(默认从 DATABASE_URL 环境变量读取)")
|
|||
|
|
flag.Parse()
|
|||
|
|
return Config{
|
|||
|
|
APIKey: *apiKey,
|
|||
|
|
APIURL: *apiURL,
|
|||
|
|
OutPath: *outPath,
|
|||
|
|
MaxRetries: *maxRetries,
|
|||
|
|
TimeoutSec: *timeoutSec,
|
|||
|
|
DBConn: *dbConn,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func run(cfg Config) error {
|
|||
|
|
models, err := fetchModels(cfg)
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
// 优先写入 PostgreSQL;若配置了 DBConn 则入库
|
|||
|
|
if cfg.DBConn != "" {
|
|||
|
|
if err := summarizeDB(cfg.DBConn, models); err != nil {
|
|||
|
|
fmt.Fprintf(os.Stderr, "警告: PostgreSQL 写入失败: %v\n", err)
|
|||
|
|
fmt.Fprintln(os.Stderr, "降级为仅写入 JSON")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return summarize(cfg.OutPath, models)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// fetchModels 抓取 OpenRouter 模型列表
|
|||
|
|
func fetchModels(cfg Config) ([]ModelInfo, error) {
|
|||
|
|
// 无 API Key 时返回模拟数据(写入由后续 summarize 统一处理)
|
|||
|
|
if cfg.APIKey == "" {
|
|||
|
|
fmt.Println("警告: 未提供 API Key,使用模拟数据")
|
|||
|
|
return []ModelInfo{
|
|||
|
|
{ID: "openai/gpt-4o", ContextLength: 128000,
|
|||
|
|
Pricing: ModelPricing{Input: 2.5, Output: 10.0}},
|
|||
|
|
{ID: "anthropic/claude-3.5-sonnet:free", ContextLength: 200000,
|
|||
|
|
Pricing: ModelPricing{}},
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
client := &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second}
|
|||
|
|
req, err := http.NewRequest("GET", cfg.APIURL, nil)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("构造请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
|
|||
|
|
var resp *http.Response
|
|||
|
|
for i := 0; i <= cfg.MaxRetries; i++ {
|
|||
|
|
resp, err = client.Do(req)
|
|||
|
|
if err == nil {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
if i < cfg.MaxRetries {
|
|||
|
|
time.Sleep(time.Duration(i+1) * time.Second)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
defer resp.Body.Close()
|
|||
|
|
|
|||
|
|
if resp.StatusCode != http.StatusOK {
|
|||
|
|
body, _ := io.ReadAll(resp.Body)
|
|||
|
|
return nil, fmt.Errorf("非 200 响应: %d %s", resp.StatusCode, string(body))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
body, err := io.ReadAll(resp.Body)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("读取响应失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 健壮解析,兼容字段缺失和结构差异
|
|||
|
|
models, err := parseModels(body)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("JSON 解析失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TODO: 字段标准化映射(OpenRouter id → 标准厂商名、模型名)
|
|||
|
|
return models, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// parseModels 健壮解析模型列表,兼容字段缺失/类型不一致/嵌套结构差异
|
|||
|
|
func parseModels(raw []byte) ([]ModelInfo, error) {
|
|||
|
|
var wrapper struct {
|
|||
|
|
Data json.RawMessage `json:"data"`
|
|||
|
|
}
|
|||
|
|
if err := json.Unmarshal(raw, &wrapper); err != nil {
|
|||
|
|
return nil, fmt.Errorf("解析 data 字段失败: %w", err)
|
|||
|
|
}
|
|||
|
|
// data 为数组,每元素字段可能不同,统一用 map[string]any 兼容
|
|||
|
|
var rawItems []any
|
|||
|
|
if err := json.Unmarshal(wrapper.Data, &rawItems); err != nil {
|
|||
|
|
return nil, fmt.Errorf("解析模型数组失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
models := make([]ModelInfo, 0, len(rawItems))
|
|||
|
|
for _, item := range rawItems {
|
|||
|
|
m, ok := item.(map[string]any)
|
|||
|
|
if !ok {
|
|||
|
|
continue // 跳过非法条目
|
|||
|
|
}
|
|||
|
|
model := ModelInfo{
|
|||
|
|
ID: getString(m, "id"),
|
|||
|
|
Name: getString(m, "name"),
|
|||
|
|
}
|
|||
|
|
if model.ID == "" {
|
|||
|
|
continue // id 为必填
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// pricing 可能为嵌套对象(如 {openrouter: {input: 1}}),尝试多路径取值
|
|||
|
|
if p, ok := m["pricing"].(map[string]any); ok {
|
|||
|
|
model.Pricing.Input = getPrice(p, "input", "prompt")
|
|||
|
|
model.Pricing.Output = getPrice(p, "output", "completion")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
model.ContextLength = getInt(m, "context_length")
|
|||
|
|
model.Description = getString(m, "description")
|
|||
|
|
model.Created = getInt64(m, "created")
|
|||
|
|
|
|||
|
|
if caps, ok := m["capabilities"].([]any); ok {
|
|||
|
|
for _, c := range caps {
|
|||
|
|
if s, ok := c.(string); ok {
|
|||
|
|
model.Capabilities = append(model.Capabilities, s)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
models = append(models, model)
|
|||
|
|
}
|
|||
|
|
return models, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getString returns m[key] when it holds a string, and "" when the key is
// absent or holds any other type.
func getString(m map[string]any, key string) string {
	// A failed type assertion yields the zero value, which is the empty string.
	s, _ := m[key].(string)
	return s
}
|
|||
|
|
|
|||
|
|
// getInt returns m[key] truncated to an int when it holds a float64 (the type
// encoding/json uses for numbers decoded into any), and 0 when the key is
// absent or holds any other type.
func getInt(m map[string]any, key string) int {
	f, isNumber := m[key].(float64)
	if !isNumber {
		return 0
	}
	return int(f)
}
|
|||
|
|
|
|||
|
|
// getInt64 returns m[key] truncated to an int64 when it holds a float64 (the
// type encoding/json uses for numbers decoded into any), and 0 when the key is
// absent or holds any other type.
func getInt64(m map[string]any, key string) int64 {
	switch n := m[key].(type) {
	case float64:
		return int64(n)
	default:
		return 0
	}
}
|
|||
|
|
|
|||
|
|
// getPrice returns the first usable price found in m under any of keys, or 0
// when none is found.
//
// The keys are tried in the order given, first directly on m (e.g.
// {"input": 1}). If no key matches there, each object nested one level inside
// m is tried as well (e.g. {"openrouter": {"input": 1}}). Nested objects are
// visited in sorted key order so the result does not depend on map iteration
// order.
//
// A value is usable when it is a JSON number or a string holding a finite
// decimal number (e.g. "0.000005"); anything else is skipped.
func getPrice(m map[string]any, keys ...string) float64 {
	// Flat layout: the price sits directly under one of the keys.
	for _, k := range keys {
		if price, ok := priceValue(m[k]); ok {
			return price
		}
	}

	// Nested layout: collect the names of nested objects, then visit them in a
	// deterministic order.
	nested := make([]string, 0, len(m))
	for name, v := range m {
		if _, isObject := v.(map[string]any); isObject {
			nested = append(nested, name)
		}
	}
	sort.Strings(nested)
	for _, name := range nested {
		inner := m[name].(map[string]any)
		for _, k := range keys {
			if price, ok := priceValue(inner[k]); ok {
				return price
			}
		}
	}
	return 0
}

// priceValue converts one decoded JSON value into a price. It reports false
// when v is neither a number nor a string that parses as a finite number.
func priceValue(v any) (float64, bool) {
	switch p := v.(type) {
	case float64:
		return p, true
	case string:
		f, err := strconv.ParseFloat(p, 64)
		// ParseFloat accepts "NaN" and "Inf"; f-f is non-zero exactly for
		// those, and neither can be written back out as JSON.
		if err != nil || f-f != 0 {
			return 0, false
		}
		return f, true
	default:
		return 0, false
	}
}
|
|||
|
|
|
|||
|
|
// summarize writes the collection summary and the models to the JSON file at
// outPath (kept for backward compatibility). It delegates entirely to
// writeJSON and returns its error.
func summarize(outPath string, models []ModelInfo) error {
	return writeJSON(outPath, models)
}
|
|||
|
|
|
|||
|
|
// summarizeDB 将采集结果写入 PostgreSQL(models + model_prices 表)
|
|||
|
|
func summarizeDB(connStr string, models []ModelInfo) error {
|
|||
|
|
db, err := sql.Open("postgres", connStr)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("连接数据库失败: %w", err)
|
|||
|
|
}
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
if err := db.Ping(); err != nil {
|
|||
|
|
return fmt.Errorf("ping 数据库失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
tx, err := db.Begin()
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("开启事务失败: %w", err)
|
|||
|
|
}
|
|||
|
|
defer tx.Rollback()
|
|||
|
|
|
|||
|
|
now := time.Now()
|
|||
|
|
insertedModels := 0
|
|||
|
|
insertedPrices := 0
|
|||
|
|
|
|||
|
|
for _, m := range models {
|
|||
|
|
isFree := len(m.ID) > 5 && m.ID[len(m.ID)-5:] == ":free"
|
|||
|
|
// upsert models 表
|
|||
|
|
var modelID int64
|
|||
|
|
err := tx.QueryRow(`
|
|||
|
|
INSERT INTO models (source, external_id, name, description, context_length, capabilities, created_at_source, is_free, status, raw_payload, created_at, updated_at)
|
|||
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
|||
|
|
ON CONFLICT (external_id) DO UPDATE SET
|
|||
|
|
name = EXCLUDED.name,
|
|||
|
|
description = EXCLUDED.description,
|
|||
|
|
context_length = EXCLUDED.context_length,
|
|||
|
|
capabilities = EXCLUDED.capabilities,
|
|||
|
|
created_at_source = EXCLUDED.created_at_source,
|
|||
|
|
is_free = EXCLUDED.is_free,
|
|||
|
|
status = EXCLUDED.status,
|
|||
|
|
raw_payload = EXCLUDED.raw_payload,
|
|||
|
|
updated_at = $12
|
|||
|
|
RETURNING id
|
|||
|
|
`, "openrouter", m.ID, m.Name, m.Description, m.ContextLength,
|
|||
|
|
jsonCapabilities(m.Capabilities), m.Created, isFree, "active",
|
|||
|
|
rawPayload(m), now, now).Scan(&modelID)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("写入 models 失败 (%s): %w", m.ID, err)
|
|||
|
|
}
|
|||
|
|
insertedModels++
|
|||
|
|
|
|||
|
|
// upsert model_prices 表(当天有效日期)
|
|||
|
|
effectiveDate := now.Format("2006-01-02")
|
|||
|
|
_, err = tx.Exec(`
|
|||
|
|
INSERT INTO model_prices (model_id, source, currency, input_price_per_mtok, output_price_per_mtok, effective_date, source_url, created_at)
|
|||
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|||
|
|
ON CONFLICT (model_id, source, currency, effective_date) DO UPDATE SET
|
|||
|
|
input_price_per_mtok = EXCLUDED.input_price_per_mtok,
|
|||
|
|
output_price_per_mtok = EXCLUDED.output_price_per_mtok,
|
|||
|
|
created_at = EXCLUDED.created_at
|
|||
|
|
`, modelID, "openrouter", "USD", m.Pricing.Input, m.Pricing.Output, effectiveDate, "https://openrouter.ai/api/v1/models", now)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("写入 model_prices 失败 (%s): %w", m.ID, err)
|
|||
|
|
}
|
|||
|
|
insertedPrices++
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := tx.Commit(); err != nil {
|
|||
|
|
return fmt.Errorf("提交事务失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fmt.Printf("PostgreSQL 写入完成: %d models, %d prices\n", insertedModels, insertedPrices)
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// jsonCapabilities encodes caps as a JSON array. A nil or empty slice is
// encoded as "[]" rather than "null".
func jsonCapabilities(caps []string) []byte {
	if len(caps) > 0 {
		// The marshal error is ignored, as before.
		encoded, _ := json.Marshal(caps)
		return encoded
	}
	return []byte("[]")
}
|
|||
|
|
|
|||
|
|
// rawPayload returns m encoded as JSON, for the raw_payload column written by
// summarizeDB. The marshal error is ignored.
// NOTE(review): this re-serializes the parsed ModelInfo, so fields of the
// original API entry that ModelInfo does not model are not included — confirm
// that is what raw_payload is meant to hold.
func rawPayload(m ModelInfo) []byte {
	b, _ := json.Marshal(m)
	return b
}
|
|||
|
|
|
|||
|
|
// writeJSON 统一写入 JSON 文件(含摘要信息)
|
|||
|
|
func writeJSON(outPath string, models []ModelInfo) error {
|
|||
|
|
total := len(models)
|
|||
|
|
var freeCnt, paidCnt int
|
|||
|
|
for _, m := range models {
|
|||
|
|
if len(m.ID) > 5 && m.ID[len(m.ID)-5:] == ":free" {
|
|||
|
|
freeCnt++
|
|||
|
|
} else if m.Pricing.Input > 0 || m.Pricing.Output > 0 {
|
|||
|
|
paidCnt++
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
summary := fmt.Sprintf("采集完成: 共 %d 模型(免费 %d / 付费 %d)\n", total, freeCnt, paidCnt)
|
|||
|
|
fmt.Print(summary)
|
|||
|
|
|
|||
|
|
out, err := os.Create(outPath)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("创建输出文件失败: %w", err)
|
|||
|
|
}
|
|||
|
|
defer out.Close()
|
|||
|
|
|
|||
|
|
enc := json.NewEncoder(out)
|
|||
|
|
enc.SetIndent("", " ")
|
|||
|
|
if err := enc.Encode(map[string]any{
|
|||
|
|
"generated_at": time.Now().Format(time.RFC3339),
|
|||
|
|
"total": total,
|
|||
|
|
"free": freeCnt,
|
|||
|
|
"paid": paidCnt,
|
|||
|
|
"models": models,
|
|||
|
|
}); err != nil {
|
|||
|
|
return fmt.Errorf("写入 JSON 失败: %w", err)
|
|||
|
|
}
|
|||
|
|
fmt.Printf("结果已写入: %s\n", outPath)
|
|||
|
|
return nil
|
|||
|
|
}
|