package models
import (
	"context"
	"fmt"
	"io"
	"reflect"
	"time"
)

// =============================================================================
// Type Definitions
// =============================================================================

// Provider 模型提供商接口
|
|
|
|
|
|
// 所有模型提供商都必须实现此接口
|
|
|
|
|
|
type Provider interface {
|
|
|
|
|
|
// Name 返回提供商名称 (如 "deepseek", "qwen", "baidu")
|
|
|
|
|
|
Name() string
|
|
|
|
|
|
|
|
|
|
|
|
// BaseURL 返回 API 基础地址
|
|
|
|
|
|
BaseURL() string
|
|
|
|
|
|
|
|
|
|
|
|
// Models 返回支持的模型列表
|
|
|
|
|
|
Models() []Model
|
|
|
|
|
|
|
|
|
|
|
|
// Chat 发起聊天请求 (非流式)
|
|
|
|
|
|
Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
|
|
|
|
|
|
|
|
|
|
|
|
// ChatStream 发起流式聊天请求
|
|
|
|
|
|
ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error)
|
|
|
|
|
|
|
|
|
|
|
|
// Embeddings 获取嵌入向量
|
|
|
|
|
|
Embeddings(ctx context.Context, req *EmbeddingsRequest) (*EmbeddingsResponse, error)
|
|
|
|
|
|
|
|
|
|
|
|
// ValidateKey 验证 API 密钥有效性
|
|
|
|
|
|
ValidateKey(ctx context.Context, key string) error
|
|
|
|
|
|
|
|
|
|
|
|
// Close 关闭 provider,释放资源
|
|
|
|
|
|
Close() error
|
|
|
|
|
|
}
// Model describes a single model offered by a provider.
type Model struct {
	ID           string   `json:"id"`           // Model ID (e.g. "deepseek-chat").
	Name         string   `json:"name"`         // Display name (e.g. "DeepSeek Chat").
	Provider     string   `json:"provider"`     // Provider name.
	Type         string   `json:"type"`         // Model type: "chat", "embedding", "image".
	ContextSize  int      `json:"context_size"` // Context length, in tokens.
	MaxTokens    int      `json:"max_tokens"`   // Maximum number of output tokens.
	Capabilities []string `json:"capabilities"` // Capability list: "streaming", "vision", "function_call".
}
// ChatRequest 聊天请求
|
|
|
|
|
|
type ChatRequest struct {
|
|
|
|
|
|
Model string `json:"model"`
|
|
|
|
|
|
Messages []ChatMessage `json:"messages"`
|
|
|
|
|
|
Temperature float64 `json:"temperature,omitempty"`
|
|
|
|
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
|
|
|
|
Stream bool `json:"stream,omitempty"`
|
|
|
|
|
|
TopP float64 `json:"top_p,omitempty"`
|
|
|
|
|
|
Stop []string `json:"stop,omitempty"`
|
|
|
|
|
|
Tools []Tool `json:"tools,omitempty"`
|
|
|
|
|
|
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
|
|
|
|
|
ResponseFormat interface{} `json:"response_format,omitempty"`
|
|
|
|
|
|
// Provider 特定参数
|
|
|
|
|
|
Extra map[string]interface{} `json:"-"`
|
|
|
|
|
|
}
// ChatMessage 聊天消息
|
|
|
|
|
|
type ChatMessage struct {
|
|
|
|
|
|
Role string `json:"role"` // "system", "user", "assistant", "tool"
|
|
|
|
|
|
Content string `json:"content"`
|
|
|
|
|
|
Name string `json:"name,omitempty"`
|
|
|
|
|
|
// Tool call 相关
|
|
|
|
|
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
|
|
|
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
|
|
|
|
}
// ToolCall 工具调用
|
|
|
|
|
|
type ToolCall struct {
|
|
|
|
|
|
ID string `json:"id"`
|
|
|
|
|
|
Type string `json:"type"` // "function"
|
|
|
|
|
|
Function FunctionCall `json:"function"`
|
|
|
|
|
|
}
// FunctionCall is a function invocation: the function name plus its
// arguments carried as a string.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}
// Tool 工具定义
|
|
|
|
|
|
type Tool struct {
|
|
|
|
|
|
Type string `json:"type"` // "function"
|
|
|
|
|
|
Function FunctionDefinition `json:"function"`
|
|
|
|
|
|
}
// FunctionDefinition is the definition of a callable function: its name,
// a description, and a free-form parameters object.
type FunctionDefinition struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Parameters  map[string]interface{} `json:"parameters"`
}
// ChatResponse 聊天响应 (非流式)
|
|
|
|
|
|
type ChatResponse struct {
|
|
|
|
|
|
ID string `json:"id"`
|
|
|
|
|
|
Object string `json:"object"`
|
|
|
|
|
|
Created int64 `json:"created"`
|
|
|
|
|
|
Model string `json:"model"`
|
|
|
|
|
|
Choices []Choice `json:"choices"`
|
|
|
|
|
|
Usage Usage `json:"usage"`
|
|
|
|
|
|
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
|
|
|
|
|
}
// Choice 选项
|
|
|
|
|
|
type Choice struct {
|
|
|
|
|
|
Index int `json:"index"`
|
|
|
|
|
|
Message ChatMessage `json:"message"`
|
|
|
|
|
|
FinishReason string `json:"finish_reason"`
|
|
|
|
|
|
}
// Usage 用量统计
|
|
|
|
|
|
type Usage struct {
|
|
|
|
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
|
|
|
|
CompletionTokens int `json:"completion_tokens"`
|
|
|
|
|
|
TotalTokens int `json:"total_tokens"`
|
|
|
|
|
|
// Provider 特定
|
|
|
|
|
|
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
|
|
|
|
|
|
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
|
|
|
|
|
|
}
// PromptTokensDetails is the optional prompt-token breakdown attached to
// Usage. CachedTokens is omitted from JSON when zero.
type PromptTokensDetails struct {
	CachedTokens int `json:"cached_tokens,omitempty"`
}
// CompletionTokensDetails is the optional completion-token breakdown
// attached to Usage. ReasoningTokens is omitted from JSON when zero.
type CompletionTokensDetails struct {
	ReasoningTokens int `json:"reasoning_tokens,omitempty"`
}
// EmbeddingsRequest is the payload of an embeddings request.
// InputType, EncodingFormat and Dimensions are omitted from JSON when
// they hold their zero value.
type EmbeddingsRequest struct {
	Model          string   `json:"model"`
	Input          []string `json:"input"`
	InputType      string   `json:"input_type,omitempty"`
	EncodingFormat string   `json:"encoding_format,omitempty"`
	Dimensions     int      `json:"dimensions,omitempty"`
}
// EmbeddingsResponse 嵌入响应
|
|
|
|
|
|
type EmbeddingsResponse struct {
|
|
|
|
|
|
Object string `json:"object"`
|
|
|
|
|
|
Data []Embedding `json:"data"`
|
|
|
|
|
|
Model string `json:"model"`
|
|
|
|
|
|
Usage Usage `json:"usage"`
|
|
|
|
|
|
}
// Embedding is a single embedding vector together with its index.
type Embedding struct {
	Object    string    `json:"object"`
	Index     int       `json:"index"`
	Embedding []float64 `json:"embedding"`
}
// ProviderConfig is the configuration of a provider.
type ProviderConfig struct {
	APIKey       string
	BaseURL      string
	Organization string
	Timeout      time.Duration

	// Authentication.
	AuthType string // "bearer", "api_key", "oauth"

	// Provider-specific settings.
	Extra map[string]interface{}
}
// HTTPConfig is the HTTP client configuration.
type HTTPConfig struct {
	Timeout      time.Duration
	MaxRetries   int
	RetryDelay   time.Duration
	MaxKeepAlive int
}
// PoolConfig is the connection-pool configuration.
type PoolConfig struct {
	MaxIdle     int
	MaxActive   int
	IdleTimeout time.Duration
}
// Errors
|
|
|
|
|
|
var (
|
|
|
|
|
|
ErrUnknownProvider = NewError("unknown provider: %s")
|
|
|
|
|
|
ErrInvalidAPIKey = NewError("invalid API key")
|
|
|
|
|
|
ErrRateLimited = NewError("rate limited")
|
|
|
|
|
|
ErrInsufficientQuota = NewError("insufficient quota")
|
|
|
|
|
|
ErrModelNotFound = NewError("model not found: %s")
|
|
|
|
|
|
ErrInvalidRequest = NewError("invalid request: %s")
|
|
|
|
|
|
ErrContextCancelled = NewError("context cancelled")
|
|
|
|
|
|
ErrTimeout = NewError("request timeout")
|
|
|
|
|
|
)
// NewError 创建错误
|
|
|
|
|
|
func NewError(format string, args ...interface{}) error {
|
|
|
|
|
|
return &ModelError{Message: format, Args: args}
|
|
|
|
|
|
}
// Errorf 创建格式化错误
|
|
|
|
|
|
func Errorf(format string, args ...interface{}) error {
|
|
|
|
|
|
return &ModelError{Message: format, Args: args}
|
|
|
|
|
|
}
// ModelError is the error type of this package: a message (used as a
// format string when Args is non-empty) plus its arguments.
type ModelError struct {
	Message string
	Args    []interface{}
}
func (e *ModelError) Error() string {
|
|
|
|
|
|
if len(e.Args) > 0 {
|
2026-04-01 13:39:37 +08:00
|
|
|
|
return fmt.Sprintf(e.Message, e.Args...)
|
2026-03-31 11:39:18 +08:00
|
|
|
|
}
|
|
|
|
|
|
return e.Message
|
|
|
|
|
|
}
func (e *ModelError) Unwrap() error {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
// Is performs error identity comparison.
|
|
|
|
|
|
// Only returns true when target is the exact same ModelError instance
|
|
|
|
|
|
// (including format string and arguments), not for substring matches.
|
|
|
|
|
|
// This ensures errors.Is(ErrRateLimited, ErrInvalidRequest) returns false
|
|
|
|
|
|
// even when message strings happen to be substrings of each other.
|
|
|
|
|
|
func (e *ModelError) Is(target error) bool {
|
|
|
|
|
|
targetME, ok := target.(*ModelError)
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
if e == targetME {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
if e.Message != targetME.Message {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
if len(e.Args) != len(targetME.Args) {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
for i, a := range e.Args {
|
|
|
|
|
|
if a != targetME.Args[i] {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return true
|
|
|
|
|
|
}