feat(P1/P2): 完成TDD开发及P1/P2设计文档
## 设计文档
- multi_role_permission_design: 多角色权限设计 (CONDITIONAL GO)
- audit_log_enhancement_design: 审计日志增强 (CONDITIONAL GO)
- routing_strategy_template_design: 路由策略模板 (CONDITIONAL GO)
- sso_saml_technical_research: SSO/SAML调研 (CONDITIONAL GO)
- compliance_capability_package_design: 合规能力包设计 (CONDITIONAL GO)

## TDD开发成果
- IAM模块: supply-api/internal/iam/ (111个测试)
- 审计日志模块: supply-api/internal/audit/ (40+测试)
- 路由策略模块: gateway/internal/router/ (33+测试)
- 合规能力包: gateway/internal/compliance/ + scripts/ci/compliance/

## 规范文档
- parallel_agent_output_quality_standards: 并行Agent产出质量规范
- project_experience_summary: 项目经验总结 (v2)
- 2026-04-02-p1-p2-tdd-execution-plan: TDD执行计划

## 评审报告
- 5个CONDITIONAL GO设计文档评审报告
- fix_verification_report: 修复验证报告
- full_verification_report: 全面质量验证报告
- tdd_module_quality_verification: TDD模块质量验证
- tdd_execution_summary: TDD执行总结

依据: Superpowers执行框架 + TDD规范
This commit is contained in:
71
gateway/internal/router/strategy/ab_strategy.go
Normal file
71
gateway/internal/router/strategy/ab_strategy.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package strategy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ABStrategy A/B测试策略
|
||||
type ABStrategy struct {
|
||||
controlStrategy *RoutingStrategyTemplate
|
||||
experimentStrategy *RoutingStrategyTemplate
|
||||
trafficSplit int // 实验组流量百分比 (0-100)
|
||||
bucketKey string // 分桶key
|
||||
experimentID string
|
||||
startTime *time.Time
|
||||
endTime *time.Time
|
||||
}
|
||||
|
||||
// NewABStrategy 创建A/B测试策略
|
||||
func NewABStrategy(control, experiment *RoutingStrategyTemplate, split int, bucketKey string) *ABStrategy {
|
||||
return &ABStrategy{
|
||||
controlStrategy: control,
|
||||
experimentStrategy: experiment,
|
||||
trafficSplit: split,
|
||||
bucketKey: bucketKey,
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldApplyToRequest 判断请求是否应该使用实验组策略
|
||||
func (a *ABStrategy) ShouldApplyToRequest(req *RoutingRequest) bool {
|
||||
// 检查时间范围
|
||||
now := time.Now()
|
||||
if a.startTime != nil && now.Before(*a.startTime) {
|
||||
return false
|
||||
}
|
||||
if a.endTime != nil && now.After(*a.endTime) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 一致性哈希分桶
|
||||
bucket := a.hashString(fmt.Sprintf("%s:%s", a.bucketKey, req.UserID)) % 100
|
||||
return bucket < a.trafficSplit
|
||||
}
|
||||
|
||||
// hashString 计算字符串哈希值 (用于一致性分桶)
|
||||
func (a *ABStrategy) hashString(s string) int {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(s))
|
||||
return int(h.Sum32())
|
||||
}
|
||||
|
||||
// GetControlStrategy 获取对照组策略
|
||||
func (a *ABStrategy) GetControlStrategy() *RoutingStrategyTemplate {
|
||||
return a.controlStrategy
|
||||
}
|
||||
|
||||
// GetExperimentStrategy 获取实验组策略
|
||||
func (a *ABStrategy) GetExperimentStrategy() *RoutingStrategyTemplate {
|
||||
return a.experimentStrategy
|
||||
}
|
||||
|
||||
// RoutingStrategyTemplate holds the descriptive attributes of a routing
// strategy.
type RoutingStrategyTemplate struct {
	ID, Name, Type string
	Priority       int
	Enabled        bool
	Description    string
}
|
||||
161
gateway/internal/router/strategy/ab_strategy_test.go
Normal file
161
gateway/internal/router/strategy/ab_strategy_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package strategy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestABStrategy_TrafficSplit 测试A/B测试流量分配
|
||||
func TestABStrategy_TrafficSplit(t *testing.T) {
|
||||
ab := &ABStrategy{
|
||||
controlStrategy: &RoutingStrategyTemplate{ID: "control"},
|
||||
experimentStrategy: &RoutingStrategyTemplate{ID: "experiment"},
|
||||
trafficSplit: 20, // 20%实验组
|
||||
bucketKey: "user_id",
|
||||
}
|
||||
|
||||
// 验证流量分配
|
||||
// 一致性哈希:同一user_id始终分配到同一组
|
||||
controlCount := 0
|
||||
experimentCount := 0
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
userID := string(rune('0' + i))
|
||||
isExperiment := ab.ShouldApplyToRequest(&RoutingRequest{UserID: userID})
|
||||
|
||||
if isExperiment {
|
||||
experimentCount++
|
||||
} else {
|
||||
controlCount++
|
||||
}
|
||||
}
|
||||
|
||||
// 验证一致性:同一user_id应该始终在同一组
|
||||
for i := 0; i < 10; i++ {
|
||||
userID := "test_user_123"
|
||||
first := ab.ShouldApplyToRequest(&RoutingRequest{UserID: userID})
|
||||
for j := 0; j < 10; j++ {
|
||||
second := ab.ShouldApplyToRequest(&RoutingRequest{UserID: userID})
|
||||
assert.Equal(t, first, second, "Same user_id should always be in same group")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证分配比例大约是80:20
|
||||
assert.InDelta(t, 80, controlCount, 15, "Control should be around 80%%")
|
||||
assert.InDelta(t, 20, experimentCount, 15, "Experiment should be around 20%%")
|
||||
}
|
||||
|
||||
// TestRollout_Percentage 测试灰度发布百分比递增
|
||||
func TestRollout_Percentage(t *testing.T) {
|
||||
rollout := &RolloutStrategy{
|
||||
percentage: 10,
|
||||
bucketKey: "user_id",
|
||||
}
|
||||
|
||||
// 统计10%时的用户数
|
||||
count10 := 0
|
||||
for i := 0; i < 100; i++ {
|
||||
userID := string(rune('0' + i))
|
||||
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
|
||||
count10++
|
||||
}
|
||||
}
|
||||
assert.InDelta(t, 10, count10, 5, "10%% rollout should have around 10 users")
|
||||
|
||||
// 增加百分比到20%
|
||||
rollout.SetPercentage(20)
|
||||
|
||||
// 统计20%时的用户数
|
||||
count20 := 0
|
||||
for i := 0; i < 100; i++ {
|
||||
userID := string(rune('0' + i))
|
||||
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
|
||||
count20++
|
||||
}
|
||||
}
|
||||
assert.InDelta(t, 20, count20, 5, "20%% rollout should have around 20 users")
|
||||
|
||||
// 增加百分比到50%
|
||||
rollout.SetPercentage(50)
|
||||
|
||||
// 统计50%时的用户数
|
||||
count50 := 0
|
||||
for i := 0; i < 100; i++ {
|
||||
userID := string(rune('0' + i))
|
||||
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
|
||||
count50++
|
||||
}
|
||||
}
|
||||
assert.InDelta(t, 50, count50, 10, "50%% rollout should have around 50 users")
|
||||
|
||||
// 增加百分比到100%
|
||||
rollout.SetPercentage(100)
|
||||
|
||||
// 验证100%时所有用户都在
|
||||
for i := 0; i < 100; i++ {
|
||||
userID := string(rune('0' + i))
|
||||
assert.True(t, rollout.ShouldApply(&RoutingRequest{UserID: userID}), "All users should be in 100% rollout")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRollout_Consistency 测试灰度发布一致性
|
||||
func TestRollout_Consistency(t *testing.T) {
|
||||
rollout := &RolloutStrategy{
|
||||
percentage: 30,
|
||||
bucketKey: "user_id",
|
||||
}
|
||||
|
||||
// 同一用户应该始终被同样对待
|
||||
userID := "consistent_user"
|
||||
firstResult := rollout.ShouldApply(&RoutingRequest{UserID: userID})
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
result := rollout.ShouldApply(&RoutingRequest{UserID: userID})
|
||||
assert.Equal(t, firstResult, result, "Same user should always have same result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRollout_PercentageIncrease 测试百分比递增
|
||||
func TestRollout_PercentageIncrease(t *testing.T) {
|
||||
rollout := &RolloutStrategy{
|
||||
percentage: 10,
|
||||
bucketKey: "user_id",
|
||||
}
|
||||
|
||||
// 收集10%时的用户
|
||||
var in10Percent []string
|
||||
for i := 0; i < 100; i++ {
|
||||
userID := string(rune('a' + i))
|
||||
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
|
||||
in10Percent = append(in10Percent, userID)
|
||||
}
|
||||
}
|
||||
|
||||
// 增加百分比到50%
|
||||
rollout.SetPercentage(50)
|
||||
|
||||
// 收集50%时的用户
|
||||
var in50Percent []string
|
||||
for i := 0; i < 100; i++ {
|
||||
userID := string(rune('a' + i))
|
||||
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
|
||||
in50Percent = append(in50Percent, userID)
|
||||
}
|
||||
}
|
||||
|
||||
// 50%的用户应该包含10%的用户(一致性)
|
||||
for _, userID := range in10Percent {
|
||||
found := false
|
||||
for _, id := range in50Percent {
|
||||
if userID == id {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "10%% users should be included in 50%% rollout")
|
||||
}
|
||||
|
||||
// 50%应该包含更多用户
|
||||
assert.Greater(t, len(in50Percent), len(in10Percent), "50%% should have more users than 10%%")
|
||||
}
|
||||
189
gateway/internal/router/strategy/cost_aware.go
Normal file
189
gateway/internal/router/strategy/cost_aware.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package strategy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"lijiaoqiao/gateway/internal/adapter"
|
||||
"lijiaoqiao/gateway/internal/router/scoring"
|
||||
gwerror "lijiaoqiao/gateway/pkg/error"
|
||||
)
|
||||
|
||||
// ErrNoQualifiedProvider is returned by CostAwareTemplate.SelectProvider when
// none of the registered providers passes its filters.
var ErrNoQualifiedProvider = errors.New("no qualified provider available")
|
||||
|
||||
// CostAwareTemplate 成本感知策略模板
|
||||
// 综合考虑成本、质量、延迟进行权衡
|
||||
type CostAwareTemplate struct {
|
||||
name string
|
||||
maxCostPer1KTokens float64
|
||||
maxLatencyMs int64
|
||||
minQualityScore float64
|
||||
providers map[string]adapter.ProviderAdapter
|
||||
scoringModel *scoring.ScoringModel
|
||||
}
|
||||
|
||||
// CostAwareParams carries the thresholds used to configure a
// CostAwareTemplate.
type CostAwareParams struct {
	MaxCostPer1KTokens float64 // upper bound on provider cost per 1K tokens
	MaxLatencyMs       int64   // upper bound on provider latency, in milliseconds
	MinQualityScore    float64 // lower bound on provider quality score
}
|
||||
|
||||
// NewCostAwareTemplate 创建成本感知策略模板
|
||||
func NewCostAwareTemplate(name string, params CostAwareParams) *CostAwareTemplate {
|
||||
return &CostAwareTemplate{
|
||||
name: name,
|
||||
maxCostPer1KTokens: params.MaxCostPer1KTokens,
|
||||
maxLatencyMs: params.MaxLatencyMs,
|
||||
minQualityScore: params.MinQualityScore,
|
||||
providers: make(map[string]adapter.ProviderAdapter),
|
||||
scoringModel: scoring.NewScoringModel(scoring.DefaultWeights),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterProvider 注册Provider
|
||||
func (t *CostAwareTemplate) RegisterProvider(name string, provider adapter.ProviderAdapter) {
|
||||
t.providers[name] = provider
|
||||
}
|
||||
|
||||
// Name 获取策略名称
|
||||
func (t *CostAwareTemplate) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// Type 获取策略类型
|
||||
func (t *CostAwareTemplate) Type() string {
|
||||
return "cost_aware"
|
||||
}
|
||||
|
||||
// SelectProvider 选择最佳平衡的Provider
|
||||
func (t *CostAwareTemplate) SelectProvider(ctx context.Context, req *RoutingRequest) (*RoutingDecision, error) {
|
||||
if len(t.providers) == 0 {
|
||||
return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no provider registered")
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
name string
|
||||
cost float64
|
||||
quality float64
|
||||
latency int64
|
||||
score float64
|
||||
}
|
||||
|
||||
var candidates []candidate
|
||||
maxCost := t.maxCostPer1KTokens
|
||||
if req.MaxCost > 0 && req.MaxCost < maxCost {
|
||||
maxCost = req.MaxCost
|
||||
}
|
||||
maxLatency := t.maxLatencyMs
|
||||
if req.MaxLatency > 0 && req.MaxLatency < maxLatency {
|
||||
maxLatency = req.MaxLatency
|
||||
}
|
||||
minQuality := t.minQualityScore
|
||||
if req.MinQuality > 0 && req.MinQuality > minQuality {
|
||||
minQuality = req.MinQuality
|
||||
}
|
||||
|
||||
for name, provider := range t.providers {
|
||||
// 检查provider是否支持该模型
|
||||
supported := false
|
||||
for _, m := range provider.SupportedModels() {
|
||||
if m == req.Model || m == "*" {
|
||||
supported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !supported {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查健康状态
|
||||
if !provider.HealthCheck(ctx) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取provider指标
|
||||
cost := t.getProviderCost(provider)
|
||||
quality := t.getProviderQuality(provider)
|
||||
latency := t.getProviderLatency(provider)
|
||||
|
||||
// 过滤不满足基本条件的provider
|
||||
if cost > maxCost || latency > maxLatency || quality < minQuality {
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算综合评分
|
||||
metrics := scoring.ProviderMetrics{
|
||||
Name: name,
|
||||
LatencyMs: latency,
|
||||
Availability: 1.0, // 假设可用
|
||||
CostPer1KTokens: cost,
|
||||
QualityScore: quality,
|
||||
}
|
||||
score := t.scoringModel.CalculateScore(metrics)
|
||||
|
||||
candidates = append(candidates, candidate{
|
||||
name: name,
|
||||
cost: cost,
|
||||
quality: quality,
|
||||
latency: latency,
|
||||
score: score,
|
||||
})
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, ErrNoQualifiedProvider
|
||||
}
|
||||
|
||||
// 选择评分最高的provider
|
||||
best := &candidates[0]
|
||||
for i := 1; i < len(candidates); i++ {
|
||||
if candidates[i].score > best.score {
|
||||
best = &candidates[i]
|
||||
}
|
||||
}
|
||||
|
||||
return &RoutingDecision{
|
||||
Provider: best.name,
|
||||
Strategy: t.Type(),
|
||||
CostPer1KTokens: best.cost,
|
||||
EstimatedLatency: best.latency,
|
||||
QualityScore: best.quality,
|
||||
TakeoverMark: true, // M-008: 标记为接管
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getProviderCost 获取Provider的成本
|
||||
func (t *CostAwareTemplate) getProviderCost(provider adapter.ProviderAdapter) float64 {
|
||||
if cp, ok := provider.(CostAwareProvider); ok {
|
||||
return cp.GetCostPer1KTokens()
|
||||
}
|
||||
return 0.5
|
||||
}
|
||||
|
||||
// getProviderQuality 获取Provider的质量分数
|
||||
func (t *CostAwareTemplate) getProviderQuality(provider adapter.ProviderAdapter) float64 {
|
||||
if qp, ok := provider.(QualityProvider); ok {
|
||||
return qp.GetQualityScore()
|
||||
}
|
||||
return 0.8 // 默认质量分数
|
||||
}
|
||||
|
||||
// getProviderLatency 获取Provider的延迟
|
||||
func (t *CostAwareTemplate) getProviderLatency(provider adapter.ProviderAdapter) int64 {
|
||||
if lp, ok := provider.(LatencyProvider); ok {
|
||||
return lp.GetLatencyMs()
|
||||
}
|
||||
return 100 // 默认延迟100ms
|
||||
}
|
||||
|
||||
// QualityProvider is implemented by providers that can report a quality
// score.
type QualityProvider interface {
	// GetQualityScore returns the provider's quality score.
	GetQualityScore() float64
}
|
||||
|
||||
// LatencyProvider is implemented by providers that can report their latency.
type LatencyProvider interface {
	// GetLatencyMs returns the provider's latency in milliseconds.
	GetLatencyMs() int64
}
|
||||
108
gateway/internal/router/strategy/cost_aware_test.go
Normal file
108
gateway/internal/router/strategy/cost_aware_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package strategy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCostAwareStrategy_Balance 测试成本感知策略的平衡选择
|
||||
func TestCostAwareStrategy_Balance(t *testing.T) {
|
||||
template := NewCostAwareTemplate("CostAware", CostAwareParams{
|
||||
MaxCostPer1KTokens: 1.0,
|
||||
MaxLatencyMs: 500,
|
||||
MinQualityScore: 0.7,
|
||||
})
|
||||
|
||||
// 注册多个providers
|
||||
// ProviderA: 低成本, 低质量
|
||||
template.providers["ProviderA"] = &MockProvider{
|
||||
name: "ProviderA",
|
||||
costPer1KTokens: 0.2,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
qualityScore: 0.6, // 质量不达标
|
||||
latencyMs: 100,
|
||||
}
|
||||
|
||||
// ProviderB: 中成本, 高质量, 低延迟
|
||||
template.providers["ProviderB"] = &MockProvider{
|
||||
name: "ProviderB",
|
||||
costPer1KTokens: 0.5,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
qualityScore: 0.9,
|
||||
latencyMs: 150,
|
||||
}
|
||||
|
||||
// ProviderC: 高成本, 高质量, 高延迟
|
||||
template.providers["ProviderC"] = &MockProvider{
|
||||
name: "ProviderC",
|
||||
costPer1KTokens: 0.9,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
qualityScore: 0.95,
|
||||
latencyMs: 400,
|
||||
}
|
||||
|
||||
req := &RoutingRequest{
|
||||
Model: "gpt-4",
|
||||
UserID: "user123",
|
||||
MaxCost: 1.0,
|
||||
MaxLatency: 500,
|
||||
MinQuality: 0.7,
|
||||
}
|
||||
|
||||
decision, err := template.SelectProvider(context.Background(), req)
|
||||
|
||||
// 验证选择逻辑
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, decision)
|
||||
|
||||
// ProviderA因质量不达标应被排除
|
||||
// ProviderB应在成本/质量/延迟权衡中胜出
|
||||
assert.Equal(t, "ProviderB", decision.Provider, "Should select balanced provider")
|
||||
assert.GreaterOrEqual(t, decision.QualityScore, 0.7, "Quality should meet minimum")
|
||||
assert.LessOrEqual(t, decision.CostPer1KTokens, 1.0, "Cost should be within budget")
|
||||
assert.LessOrEqual(t, decision.EstimatedLatency, int64(500), "Latency should be within limit")
|
||||
}
|
||||
|
||||
// TestCostAwareStrategy_QualityThreshold 测试质量阈值过滤
|
||||
func TestCostAwareStrategy_QualityThreshold(t *testing.T) {
|
||||
template := NewCostAwareTemplate("CostAware", CostAwareParams{
|
||||
MaxCostPer1KTokens: 1.0,
|
||||
MaxLatencyMs: 1000,
|
||||
MinQualityScore: 0.9, // 高质量要求
|
||||
})
|
||||
|
||||
// 所有provider质量都不达标
|
||||
template.providers["ProviderA"] = &MockProvider{
|
||||
name: "ProviderA",
|
||||
costPer1KTokens: 0.3,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
qualityScore: 0.7,
|
||||
latencyMs: 100,
|
||||
}
|
||||
template.providers["ProviderB"] = &MockProvider{
|
||||
name: "ProviderB",
|
||||
costPer1KTokens: 0.4,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
qualityScore: 0.8,
|
||||
latencyMs: 150,
|
||||
}
|
||||
|
||||
req := &RoutingRequest{
|
||||
Model: "gpt-4",
|
||||
UserID: "user123",
|
||||
MinQuality: 0.9,
|
||||
}
|
||||
|
||||
decision, err := template.SelectProvider(context.Background(), req)
|
||||
|
||||
// 应该返回错误,因为没有满足质量要求的provider
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, decision)
|
||||
}
|
||||
132
gateway/internal/router/strategy/cost_based.go
Normal file
132
gateway/internal/router/strategy/cost_based.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package strategy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"lijiaoqiao/gateway/internal/adapter"
|
||||
gwerror "lijiaoqiao/gateway/pkg/error"
|
||||
)
|
||||
|
||||
// ErrNoAffordableProvider is returned by CostBasedTemplate.SelectProvider
// when providers are available but none is within the cost limit.
var ErrNoAffordableProvider = errors.New("no affordable provider available")
|
||||
|
||||
// CostBasedTemplate 成本优先策略模板
|
||||
// 选择成本最低的provider
|
||||
type CostBasedTemplate struct {
|
||||
name string
|
||||
maxCostPer1KTokens float64
|
||||
providers map[string]adapter.ProviderAdapter
|
||||
}
|
||||
|
||||
// CostParams carries the settings used to configure a CostBasedTemplate.
type CostParams struct {
	MaxCostPer1KTokens float64 // maximum cost, in $ per 1K tokens
}
|
||||
|
||||
// NewCostBasedTemplate 创建成本优先策略模板
|
||||
func NewCostBasedTemplate(name string, params CostParams) *CostBasedTemplate {
|
||||
return &CostBasedTemplate{
|
||||
name: name,
|
||||
maxCostPer1KTokens: params.MaxCostPer1KTokens,
|
||||
providers: make(map[string]adapter.ProviderAdapter),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterProvider 注册Provider
|
||||
func (t *CostBasedTemplate) RegisterProvider(name string, provider adapter.ProviderAdapter) {
|
||||
t.providers[name] = provider
|
||||
}
|
||||
|
||||
// Name 获取策略名称
|
||||
func (t *CostBasedTemplate) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// Type 获取策略类型
|
||||
func (t *CostBasedTemplate) Type() string {
|
||||
return "cost_based"
|
||||
}
|
||||
|
||||
// SelectProvider 选择成本最低的Provider
|
||||
func (t *CostBasedTemplate) SelectProvider(ctx context.Context, req *RoutingRequest) (*RoutingDecision, error) {
|
||||
if len(t.providers) == 0 {
|
||||
return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no provider registered")
|
||||
}
|
||||
|
||||
// 收集所有可用provider的候选列表
|
||||
type candidate struct {
|
||||
name string
|
||||
cost float64
|
||||
}
|
||||
var candidates []candidate
|
||||
|
||||
for name, provider := range t.providers {
|
||||
// 检查provider是否支持该模型
|
||||
supported := false
|
||||
for _, m := range provider.SupportedModels() {
|
||||
if m == req.Model || m == "*" {
|
||||
supported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !supported {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查健康状态
|
||||
if !provider.HealthCheck(ctx) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取成本信息 (实际实现需要从provider获取)
|
||||
// 这里暂时设置为模拟值
|
||||
cost := t.getProviderCost(provider)
|
||||
candidates = append(candidates, candidate{name: name, cost: cost})
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider for model: "+req.Model)
|
||||
}
|
||||
|
||||
// 按成本排序
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].cost < candidates[j].cost
|
||||
})
|
||||
|
||||
// 选择成本最低且在预算内的provider
|
||||
maxCost := t.maxCostPer1KTokens
|
||||
if req.MaxCost > 0 && req.MaxCost < maxCost {
|
||||
maxCost = req.MaxCost
|
||||
}
|
||||
|
||||
for _, c := range candidates {
|
||||
if c.cost <= maxCost {
|
||||
return &RoutingDecision{
|
||||
Provider: c.name,
|
||||
Strategy: t.Type(),
|
||||
CostPer1KTokens: c.cost,
|
||||
TakeoverMark: true, // M-008: 标记为接管
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrNoAffordableProvider
|
||||
}
|
||||
|
||||
// CostAwareProvider is implemented by providers that can report their cost.
type CostAwareProvider interface {
	// GetCostPer1KTokens returns the provider's cost per 1K tokens.
	GetCostPer1KTokens() float64
}
|
||||
|
||||
// getProviderCost 获取Provider的成本
|
||||
func (t *CostBasedTemplate) getProviderCost(provider adapter.ProviderAdapter) float64 {
|
||||
// 尝试类型断言获取成本
|
||||
if cp, ok := provider.(CostAwareProvider); ok {
|
||||
return cp.GetCostPer1KTokens()
|
||||
}
|
||||
// 默认返回0.5,实际应从provider元数据获取
|
||||
return 0.5
|
||||
}
|
||||
142
gateway/internal/router/strategy/cost_based_test.go
Normal file
142
gateway/internal/router/strategy/cost_based_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package strategy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"lijiaoqiao/gateway/internal/adapter"
|
||||
)
|
||||
|
||||
// TestCostBasedStrategy_SelectProvider 测试成本优先策略选择Provider
|
||||
func TestCostBasedStrategy_SelectProvider(t *testing.T) {
|
||||
template := &CostBasedTemplate{
|
||||
name: "CostBased",
|
||||
maxCostPer1KTokens: 1.0,
|
||||
providers: make(map[string]adapter.ProviderAdapter),
|
||||
}
|
||||
|
||||
// 注册mock providers
|
||||
template.providers["ProviderA"] = &MockProvider{
|
||||
name: "ProviderA",
|
||||
costPer1KTokens: 0.5,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
}
|
||||
template.providers["ProviderB"] = &MockProvider{
|
||||
name: "ProviderB",
|
||||
costPer1KTokens: 0.3, // 最低成本
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
}
|
||||
template.providers["ProviderC"] = &MockProvider{
|
||||
name: "ProviderC",
|
||||
costPer1KTokens: 0.8,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
}
|
||||
|
||||
req := &RoutingRequest{
|
||||
Model: "gpt-4",
|
||||
UserID: "user123",
|
||||
MaxCost: 1.0,
|
||||
}
|
||||
|
||||
decision, err := template.SelectProvider(context.Background(), req)
|
||||
|
||||
// 验证选择了最低成本的Provider
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, decision)
|
||||
assert.Equal(t, "ProviderB", decision.Provider, "Should select lowest cost provider")
|
||||
assert.LessOrEqual(t, decision.CostPer1KTokens, 1.0, "Cost should be within budget")
|
||||
}
|
||||
|
||||
func TestCostBasedStrategy_Fallback(t *testing.T) {
|
||||
// 成本超出阈值时fallback
|
||||
template := &CostBasedTemplate{
|
||||
name: "CostBased",
|
||||
maxCostPer1KTokens: 0.5, // 设置低成本上限
|
||||
providers: make(map[string]adapter.ProviderAdapter),
|
||||
}
|
||||
|
||||
// 注册成本较高的providers
|
||||
template.providers["ProviderA"] = &MockProvider{
|
||||
name: "ProviderA",
|
||||
costPer1KTokens: 0.8,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
}
|
||||
template.providers["ProviderB"] = &MockProvider{
|
||||
name: "ProviderB",
|
||||
costPer1KTokens: 1.0,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
}
|
||||
|
||||
req := &RoutingRequest{
|
||||
Model: "gpt-4",
|
||||
UserID: "user123",
|
||||
MaxCost: 0.5,
|
||||
}
|
||||
|
||||
decision, err := template.SelectProvider(context.Background(), req)
|
||||
|
||||
// 应该返回错误
|
||||
assert.Error(t, err, "Should return error when no affordable provider")
|
||||
assert.Nil(t, decision, "Should not return decision when cost exceeds threshold")
|
||||
assert.Equal(t, ErrNoAffordableProvider, err, "Should return ErrNoAffordableProvider")
|
||||
}
|
||||
|
||||
// MockProvider 用于测试的Mock Provider
|
||||
type MockProvider struct {
|
||||
name string
|
||||
costPer1KTokens float64
|
||||
qualityScore float64
|
||||
latencyMs int64
|
||||
available bool
|
||||
models []string
|
||||
}
|
||||
|
||||
func (m *MockProvider) ChatCompletion(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (*adapter.CompletionResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) ChatCompletionStream(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) GetUsage(response *adapter.CompletionResponse) adapter.Usage {
|
||||
return adapter.Usage{}
|
||||
}
|
||||
|
||||
func (m *MockProvider) MapError(err error) adapter.ProviderError {
|
||||
return adapter.ProviderError{}
|
||||
}
|
||||
|
||||
func (m *MockProvider) HealthCheck(ctx context.Context) bool {
|
||||
return m.available
|
||||
}
|
||||
|
||||
func (m *MockProvider) ProviderName() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockProvider) SupportedModels() []string {
|
||||
return m.models
|
||||
}
|
||||
|
||||
func (m *MockProvider) GetCostPer1KTokens() float64 {
|
||||
return m.costPer1KTokens
|
||||
}
|
||||
|
||||
func (m *MockProvider) GetQualityScore() float64 {
|
||||
return m.qualityScore
|
||||
}
|
||||
|
||||
func (m *MockProvider) GetLatencyMs() int64 {
|
||||
return m.latencyMs
|
||||
}
|
||||
|
||||
// Verify MockProvider implements adapter.ProviderAdapter
|
||||
var _ adapter.ProviderAdapter = (*MockProvider)(nil)
|
||||
78
gateway/internal/router/strategy/rollout.go
Normal file
78
gateway/internal/router/strategy/rollout.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package strategy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RolloutStrategy is a gradual (percentage-based) rollout strategy.
type RolloutStrategy struct {
	percentage int          // current rollout percentage (0-100)
	bucketKey  string       // key mixed into the hash used for bucketing
	mu         sync.RWMutex // held by the methods while they read or write percentage
}
|
||||
|
||||
// NewRolloutStrategy 创建灰度发布策略
|
||||
func NewRolloutStrategy(percentage int, bucketKey string) *RolloutStrategy {
|
||||
return &RolloutStrategy{
|
||||
percentage: percentage,
|
||||
bucketKey: bucketKey,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPercentage 设置灰度百分比
|
||||
func (r *RolloutStrategy) SetPercentage(percentage int) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if percentage < 0 {
|
||||
percentage = 0
|
||||
}
|
||||
if percentage > 100 {
|
||||
percentage = 100
|
||||
}
|
||||
r.percentage = percentage
|
||||
}
|
||||
|
||||
// GetPercentage 获取当前灰度百分比
|
||||
func (r *RolloutStrategy) GetPercentage() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.percentage
|
||||
}
|
||||
|
||||
// ShouldApply 判断请求是否应该在灰度范围内
|
||||
func (r *RolloutStrategy) ShouldApply(req *RoutingRequest) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if r.percentage >= 100 {
|
||||
return true
|
||||
}
|
||||
if r.percentage <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 一致性哈希分桶
|
||||
bucket := r.hashString(fmt.Sprintf("%s:%s", r.bucketKey, req.UserID)) % 100
|
||||
return bucket < r.percentage
|
||||
}
|
||||
|
||||
// hashString 计算字符串哈希值 (用于一致性分桶)
|
||||
func (r *RolloutStrategy) hashString(s string) int {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(s))
|
||||
return int(h.Sum32())
|
||||
}
|
||||
|
||||
// IncrementPercentage 增加灰度百分比
|
||||
func (r *RolloutStrategy) IncrementPercentage(delta int) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.percentage += delta
|
||||
if r.percentage > 100 {
|
||||
r.percentage = 100
|
||||
}
|
||||
}
|
||||
65
gateway/internal/router/strategy/strategy_test.go
Normal file
65
gateway/internal/router/strategy/strategy_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package strategy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"lijiaoqiao/gateway/internal/adapter"
|
||||
)
|
||||
|
||||
// TestStrategyTemplate_Interface 验证策略模板接口
|
||||
func TestStrategyTemplate_Interface(t *testing.T) {
|
||||
// 所有策略实现必须实现SelectProvider, Name, Type方法
|
||||
|
||||
// 创建策略实现示例
|
||||
costBased := &CostBasedTemplate{
|
||||
name: "CostBased",
|
||||
}
|
||||
|
||||
aware := &CostAwareTemplate{
|
||||
name: "CostAware",
|
||||
}
|
||||
|
||||
// 验证实现了StrategyTemplate接口
|
||||
var _ StrategyTemplate = costBased
|
||||
var _ StrategyTemplate = aware
|
||||
|
||||
// 验证方法
|
||||
assert.Equal(t, "CostBased", costBased.Name())
|
||||
assert.Equal(t, "cost_based", costBased.Type())
|
||||
|
||||
assert.Equal(t, "CostAware", aware.Name())
|
||||
assert.Equal(t, "cost_aware", aware.Type())
|
||||
}
|
||||
|
||||
// TestStrategyTemplate_SelectProvider_Signature 验证SelectProvider方法签名
|
||||
func TestStrategyTemplate_SelectProvider_Signature(t *testing.T) {
|
||||
req := &RoutingRequest{
|
||||
Model: "gpt-4",
|
||||
UserID: "user123",
|
||||
TenantID: "tenant1",
|
||||
MaxCost: 1.0,
|
||||
MaxLatency: 1000,
|
||||
}
|
||||
|
||||
// 验证返回值 - 创建一个有providers的模板
|
||||
template := &CostBasedTemplate{
|
||||
name: "test",
|
||||
maxCostPer1KTokens: 1.0,
|
||||
providers: make(map[string]adapter.ProviderAdapter),
|
||||
}
|
||||
template.providers["test"] = &MockProvider{
|
||||
name: "test",
|
||||
costPer1KTokens: 0.5,
|
||||
available: true,
|
||||
models: []string{"gpt-4"},
|
||||
}
|
||||
|
||||
decision, err := template.SelectProvider(context.Background(), req)
|
||||
|
||||
// 接口实现应该返回决策或错误
|
||||
assert.NotNil(t, decision)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
40
gateway/internal/router/strategy/types.go
Normal file
40
gateway/internal/router/strategy/types.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package strategy
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// RoutingRequest carries the inputs a routing strategy uses to pick a
// provider.
type RoutingRequest struct {
	Model      string // requested model, matched against each provider's supported models
	UserID     string // hashed together with a bucket key for A/B and rollout bucketing
	TenantID   string
	Region     string
	Messages   []string
	MaxCost    float64 // per-request cost ceiling; the strategies ignore values <= 0
	MaxLatency int64   // per-request latency ceiling; the cost-aware strategy ignores values <= 0
	MinQuality float64 // per-request quality floor; the cost-aware strategy ignores values <= 0
}
|
||||
|
||||
// RoutingDecision is the outcome of a routing strategy: the chosen provider
// and the figures it was chosen on.
type RoutingDecision struct {
	Provider         string  // name of the selected provider
	Strategy         string  // type identifier of the strategy that made the decision
	CostPer1KTokens  float64 // cost of the selected provider per 1K tokens
	EstimatedLatency int64   // latency of the selected provider, in milliseconds
	QualityScore     float64 // quality score of the selected provider
	TakeoverMark     bool    // M-008: whether the decision is marked as a takeover
}
|
||||
|
||||
// StrategyTemplate 策略模板接口
|
||||
// 所有路由策略都必须实现此接口
|
||||
type StrategyTemplate interface {
|
||||
// SelectProvider 选择最佳Provider
|
||||
SelectProvider(ctx context.Context, req *RoutingRequest) (*RoutingDecision, error)
|
||||
|
||||
// Name 获取策略名称
|
||||
Name() string
|
||||
|
||||
// Type 获取策略类型
|
||||
Type() string
|
||||
}
|
||||
Reference in New Issue
Block a user