feat(P1/P2): 完成TDD开发及P1/P2设计文档

## 设计文档
- multi_role_permission_design: 多角色权限设计 (CONDITIONAL GO)
- audit_log_enhancement_design: 审计日志增强 (CONDITIONAL GO)
- routing_strategy_template_design: 路由策略模板 (CONDITIONAL GO)
- sso_saml_technical_research: SSO/SAML调研 (CONDITIONAL GO)
- compliance_capability_package_design: 合规能力包设计 (CONDITIONAL GO)

## TDD开发成果
- IAM模块: supply-api/internal/iam/ (111个测试)
- 审计日志模块: supply-api/internal/audit/ (40+测试)
- 路由策略模块: gateway/internal/router/ (33+测试)
- 合规能力包: gateway/internal/compliance/ + scripts/ci/compliance/

## 规范文档
- parallel_agent_output_quality_standards: 并行Agent产出质量规范
- project_experience_summary: 项目经验总结 (v2)
- 2026-04-02-p1-p2-tdd-execution-plan: TDD执行计划

## 评审报告
- 5个CONDITIONAL GO设计文档评审报告
- fix_verification_report: 修复验证报告
- full_verification_report: 全面质量验证报告
- tdd_module_quality_verification: TDD模块质量验证
- tdd_execution_summary: TDD执行总结

依据: Superpowers执行框架 + TDD规范
This commit is contained in:
Your Name
2026-04-02 23:35:53 +08:00
parent ed0961d486
commit 89104bd0db
94 changed files with 24738 additions and 5 deletions

View File

@@ -0,0 +1,71 @@
package strategy
import (
"fmt"
"hash/fnv"
"time"
)
// ABStrategy A/B测试策略
type ABStrategy struct {
controlStrategy *RoutingStrategyTemplate
experimentStrategy *RoutingStrategyTemplate
trafficSplit int // 实验组流量百分比 (0-100)
bucketKey string // 分桶key
experimentID string
startTime *time.Time
endTime *time.Time
}
// NewABStrategy 创建A/B测试策略
func NewABStrategy(control, experiment *RoutingStrategyTemplate, split int, bucketKey string) *ABStrategy {
return &ABStrategy{
controlStrategy: control,
experimentStrategy: experiment,
trafficSplit: split,
bucketKey: bucketKey,
}
}
// ShouldApplyToRequest 判断请求是否应该使用实验组策略
func (a *ABStrategy) ShouldApplyToRequest(req *RoutingRequest) bool {
// 检查时间范围
now := time.Now()
if a.startTime != nil && now.Before(*a.startTime) {
return false
}
if a.endTime != nil && now.After(*a.endTime) {
return false
}
// 一致性哈希分桶
bucket := a.hashString(fmt.Sprintf("%s:%s", a.bucketKey, req.UserID)) % 100
return bucket < a.trafficSplit
}
// hashString 计算字符串哈希值 (用于一致性分桶)
func (a *ABStrategy) hashString(s string) int {
h := fnv.New32a()
h.Write([]byte(s))
return int(h.Sum32())
}
// GetControlStrategy 获取对照组策略
func (a *ABStrategy) GetControlStrategy() *RoutingStrategyTemplate {
return a.controlStrategy
}
// GetExperimentStrategy 获取实验组策略
func (a *ABStrategy) GetExperimentStrategy() *RoutingStrategyTemplate {
return a.experimentStrategy
}
// RoutingStrategyTemplate 路由策略模板
type RoutingStrategyTemplate struct {
ID string
Name string
Type string
Priority int
Enabled bool
Description string
}

View File

@@ -0,0 +1,161 @@
package strategy
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestABStrategy_TrafficSplit 测试A/B测试流量分配
func TestABStrategy_TrafficSplit(t *testing.T) {
ab := &ABStrategy{
controlStrategy: &RoutingStrategyTemplate{ID: "control"},
experimentStrategy: &RoutingStrategyTemplate{ID: "experiment"},
trafficSplit: 20, // 20%实验组
bucketKey: "user_id",
}
// 验证流量分配
// 一致性哈希同一user_id始终分配到同一组
controlCount := 0
experimentCount := 0
for i := 0; i < 100; i++ {
userID := string(rune('0' + i))
isExperiment := ab.ShouldApplyToRequest(&RoutingRequest{UserID: userID})
if isExperiment {
experimentCount++
} else {
controlCount++
}
}
// 验证一致性同一user_id应该始终在同一组
for i := 0; i < 10; i++ {
userID := "test_user_123"
first := ab.ShouldApplyToRequest(&RoutingRequest{UserID: userID})
for j := 0; j < 10; j++ {
second := ab.ShouldApplyToRequest(&RoutingRequest{UserID: userID})
assert.Equal(t, first, second, "Same user_id should always be in same group")
}
}
// 验证分配比例大约是80:20
assert.InDelta(t, 80, controlCount, 15, "Control should be around 80%%")
assert.InDelta(t, 20, experimentCount, 15, "Experiment should be around 20%%")
}
// TestRollout_Percentage 测试灰度发布百分比递增
func TestRollout_Percentage(t *testing.T) {
rollout := &RolloutStrategy{
percentage: 10,
bucketKey: "user_id",
}
// 统计10%时的用户数
count10 := 0
for i := 0; i < 100; i++ {
userID := string(rune('0' + i))
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
count10++
}
}
assert.InDelta(t, 10, count10, 5, "10%% rollout should have around 10 users")
// 增加百分比到20%
rollout.SetPercentage(20)
// 统计20%时的用户数
count20 := 0
for i := 0; i < 100; i++ {
userID := string(rune('0' + i))
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
count20++
}
}
assert.InDelta(t, 20, count20, 5, "20%% rollout should have around 20 users")
// 增加百分比到50%
rollout.SetPercentage(50)
// 统计50%时的用户数
count50 := 0
for i := 0; i < 100; i++ {
userID := string(rune('0' + i))
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
count50++
}
}
assert.InDelta(t, 50, count50, 10, "50%% rollout should have around 50 users")
// 增加百分比到100%
rollout.SetPercentage(100)
// 验证100%时所有用户都在
for i := 0; i < 100; i++ {
userID := string(rune('0' + i))
assert.True(t, rollout.ShouldApply(&RoutingRequest{UserID: userID}), "All users should be in 100% rollout")
}
}
// TestRollout_Consistency 测试灰度发布一致性
func TestRollout_Consistency(t *testing.T) {
rollout := &RolloutStrategy{
percentage: 30,
bucketKey: "user_id",
}
// 同一用户应该始终被同样对待
userID := "consistent_user"
firstResult := rollout.ShouldApply(&RoutingRequest{UserID: userID})
for i := 0; i < 100; i++ {
result := rollout.ShouldApply(&RoutingRequest{UserID: userID})
assert.Equal(t, firstResult, result, "Same user should always have same result")
}
}
// TestRollout_PercentageIncrease 测试百分比递增
func TestRollout_PercentageIncrease(t *testing.T) {
rollout := &RolloutStrategy{
percentage: 10,
bucketKey: "user_id",
}
// 收集10%时的用户
var in10Percent []string
for i := 0; i < 100; i++ {
userID := string(rune('a' + i))
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
in10Percent = append(in10Percent, userID)
}
}
// 增加百分比到50%
rollout.SetPercentage(50)
// 收集50%时的用户
var in50Percent []string
for i := 0; i < 100; i++ {
userID := string(rune('a' + i))
if rollout.ShouldApply(&RoutingRequest{UserID: userID}) {
in50Percent = append(in50Percent, userID)
}
}
// 50%的用户应该包含10%的用户(一致性)
for _, userID := range in10Percent {
found := false
for _, id := range in50Percent {
if userID == id {
found = true
break
}
}
assert.True(t, found, "10%% users should be included in 50%% rollout")
}
// 50%应该包含更多用户
assert.Greater(t, len(in50Percent), len(in10Percent), "50%% should have more users than 10%%")
}

View File

@@ -0,0 +1,189 @@
package strategy
import (
"context"
"errors"
"lijiaoqiao/gateway/internal/adapter"
"lijiaoqiao/gateway/internal/router/scoring"
gwerror "lijiaoqiao/gateway/pkg/error"
)
// ErrNoQualifiedProvider 没有符合条件的Provider
var ErrNoQualifiedProvider = errors.New("no qualified provider available")
// CostAwareTemplate 成本感知策略模板
// 综合考虑成本、质量、延迟进行权衡
type CostAwareTemplate struct {
name string
maxCostPer1KTokens float64
maxLatencyMs int64
minQualityScore float64
providers map[string]adapter.ProviderAdapter
scoringModel *scoring.ScoringModel
}
// CostAwareParams 成本感知参数
type CostAwareParams struct {
MaxCostPer1KTokens float64
MaxLatencyMs int64
MinQualityScore float64
}
// NewCostAwareTemplate 创建成本感知策略模板
func NewCostAwareTemplate(name string, params CostAwareParams) *CostAwareTemplate {
return &CostAwareTemplate{
name: name,
maxCostPer1KTokens: params.MaxCostPer1KTokens,
maxLatencyMs: params.MaxLatencyMs,
minQualityScore: params.MinQualityScore,
providers: make(map[string]adapter.ProviderAdapter),
scoringModel: scoring.NewScoringModel(scoring.DefaultWeights),
}
}
// RegisterProvider 注册Provider
func (t *CostAwareTemplate) RegisterProvider(name string, provider adapter.ProviderAdapter) {
t.providers[name] = provider
}
// Name 获取策略名称
func (t *CostAwareTemplate) Name() string {
return t.name
}
// Type 获取策略类型
func (t *CostAwareTemplate) Type() string {
return "cost_aware"
}
// SelectProvider 选择最佳平衡的Provider
func (t *CostAwareTemplate) SelectProvider(ctx context.Context, req *RoutingRequest) (*RoutingDecision, error) {
if len(t.providers) == 0 {
return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no provider registered")
}
type candidate struct {
name string
cost float64
quality float64
latency int64
score float64
}
var candidates []candidate
maxCost := t.maxCostPer1KTokens
if req.MaxCost > 0 && req.MaxCost < maxCost {
maxCost = req.MaxCost
}
maxLatency := t.maxLatencyMs
if req.MaxLatency > 0 && req.MaxLatency < maxLatency {
maxLatency = req.MaxLatency
}
minQuality := t.minQualityScore
if req.MinQuality > 0 && req.MinQuality > minQuality {
minQuality = req.MinQuality
}
for name, provider := range t.providers {
// 检查provider是否支持该模型
supported := false
for _, m := range provider.SupportedModels() {
if m == req.Model || m == "*" {
supported = true
break
}
}
if !supported {
continue
}
// 检查健康状态
if !provider.HealthCheck(ctx) {
continue
}
// 获取provider指标
cost := t.getProviderCost(provider)
quality := t.getProviderQuality(provider)
latency := t.getProviderLatency(provider)
// 过滤不满足基本条件的provider
if cost > maxCost || latency > maxLatency || quality < minQuality {
continue
}
// 计算综合评分
metrics := scoring.ProviderMetrics{
Name: name,
LatencyMs: latency,
Availability: 1.0, // 假设可用
CostPer1KTokens: cost,
QualityScore: quality,
}
score := t.scoringModel.CalculateScore(metrics)
candidates = append(candidates, candidate{
name: name,
cost: cost,
quality: quality,
latency: latency,
score: score,
})
}
if len(candidates) == 0 {
return nil, ErrNoQualifiedProvider
}
// 选择评分最高的provider
best := &candidates[0]
for i := 1; i < len(candidates); i++ {
if candidates[i].score > best.score {
best = &candidates[i]
}
}
return &RoutingDecision{
Provider: best.name,
Strategy: t.Type(),
CostPer1KTokens: best.cost,
EstimatedLatency: best.latency,
QualityScore: best.quality,
TakeoverMark: true, // M-008: 标记为接管
}, nil
}
// getProviderCost 获取Provider的成本
func (t *CostAwareTemplate) getProviderCost(provider adapter.ProviderAdapter) float64 {
if cp, ok := provider.(CostAwareProvider); ok {
return cp.GetCostPer1KTokens()
}
return 0.5
}
// getProviderQuality 获取Provider的质量分数
func (t *CostAwareTemplate) getProviderQuality(provider adapter.ProviderAdapter) float64 {
if qp, ok := provider.(QualityProvider); ok {
return qp.GetQualityScore()
}
return 0.8 // 默认质量分数
}
// getProviderLatency 获取Provider的延迟
func (t *CostAwareTemplate) getProviderLatency(provider adapter.ProviderAdapter) int64 {
if lp, ok := provider.(LatencyProvider); ok {
return lp.GetLatencyMs()
}
return 100 // 默认延迟100ms
}
// QualityProvider 质量感知Provider接口
type QualityProvider interface {
GetQualityScore() float64
}
// LatencyProvider 延迟感知Provider接口
type LatencyProvider interface {
GetLatencyMs() int64
}

View File

@@ -0,0 +1,108 @@
package strategy
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
// TestCostAwareStrategy_Balance 测试成本感知策略的平衡选择
func TestCostAwareStrategy_Balance(t *testing.T) {
template := NewCostAwareTemplate("CostAware", CostAwareParams{
MaxCostPer1KTokens: 1.0,
MaxLatencyMs: 500,
MinQualityScore: 0.7,
})
// 注册多个providers
// ProviderA: 低成本, 低质量
template.providers["ProviderA"] = &MockProvider{
name: "ProviderA",
costPer1KTokens: 0.2,
available: true,
models: []string{"gpt-4"},
qualityScore: 0.6, // 质量不达标
latencyMs: 100,
}
// ProviderB: 中成本, 高质量, 低延迟
template.providers["ProviderB"] = &MockProvider{
name: "ProviderB",
costPer1KTokens: 0.5,
available: true,
models: []string{"gpt-4"},
qualityScore: 0.9,
latencyMs: 150,
}
// ProviderC: 高成本, 高质量, 高延迟
template.providers["ProviderC"] = &MockProvider{
name: "ProviderC",
costPer1KTokens: 0.9,
available: true,
models: []string{"gpt-4"},
qualityScore: 0.95,
latencyMs: 400,
}
req := &RoutingRequest{
Model: "gpt-4",
UserID: "user123",
MaxCost: 1.0,
MaxLatency: 500,
MinQuality: 0.7,
}
decision, err := template.SelectProvider(context.Background(), req)
// 验证选择逻辑
assert.NoError(t, err)
assert.NotNil(t, decision)
// ProviderA因质量不达标应被排除
// ProviderB应在成本/质量/延迟权衡中胜出
assert.Equal(t, "ProviderB", decision.Provider, "Should select balanced provider")
assert.GreaterOrEqual(t, decision.QualityScore, 0.7, "Quality should meet minimum")
assert.LessOrEqual(t, decision.CostPer1KTokens, 1.0, "Cost should be within budget")
assert.LessOrEqual(t, decision.EstimatedLatency, int64(500), "Latency should be within limit")
}
// TestCostAwareStrategy_QualityThreshold 测试质量阈值过滤
func TestCostAwareStrategy_QualityThreshold(t *testing.T) {
template := NewCostAwareTemplate("CostAware", CostAwareParams{
MaxCostPer1KTokens: 1.0,
MaxLatencyMs: 1000,
MinQualityScore: 0.9, // 高质量要求
})
// 所有provider质量都不达标
template.providers["ProviderA"] = &MockProvider{
name: "ProviderA",
costPer1KTokens: 0.3,
available: true,
models: []string{"gpt-4"},
qualityScore: 0.7,
latencyMs: 100,
}
template.providers["ProviderB"] = &MockProvider{
name: "ProviderB",
costPer1KTokens: 0.4,
available: true,
models: []string{"gpt-4"},
qualityScore: 0.8,
latencyMs: 150,
}
req := &RoutingRequest{
Model: "gpt-4",
UserID: "user123",
MinQuality: 0.9,
}
decision, err := template.SelectProvider(context.Background(), req)
// 应该返回错误因为没有满足质量要求的provider
assert.Error(t, err)
assert.Nil(t, decision)
}

View File

@@ -0,0 +1,132 @@
package strategy
import (
"context"
"errors"
"sort"
"lijiaoqiao/gateway/internal/adapter"
gwerror "lijiaoqiao/gateway/pkg/error"
)
// ErrNoAffordableProvider 没有可负担的Provider
var ErrNoAffordableProvider = errors.New("no affordable provider available")
// CostBasedTemplate 成本优先策略模板
// 选择成本最低的provider
type CostBasedTemplate struct {
name string
maxCostPer1KTokens float64
providers map[string]adapter.ProviderAdapter
}
// CostParams 成本参数
type CostParams struct {
// 最大成本 ($/1K tokens)
MaxCostPer1KTokens float64
}
// NewCostBasedTemplate 创建成本优先策略模板
func NewCostBasedTemplate(name string, params CostParams) *CostBasedTemplate {
return &CostBasedTemplate{
name: name,
maxCostPer1KTokens: params.MaxCostPer1KTokens,
providers: make(map[string]adapter.ProviderAdapter),
}
}
// RegisterProvider 注册Provider
func (t *CostBasedTemplate) RegisterProvider(name string, provider adapter.ProviderAdapter) {
t.providers[name] = provider
}
// Name 获取策略名称
func (t *CostBasedTemplate) Name() string {
return t.name
}
// Type 获取策略类型
func (t *CostBasedTemplate) Type() string {
return "cost_based"
}
// SelectProvider 选择成本最低的Provider
func (t *CostBasedTemplate) SelectProvider(ctx context.Context, req *RoutingRequest) (*RoutingDecision, error) {
if len(t.providers) == 0 {
return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no provider registered")
}
// 收集所有可用provider的候选列表
type candidate struct {
name string
cost float64
}
var candidates []candidate
for name, provider := range t.providers {
// 检查provider是否支持该模型
supported := false
for _, m := range provider.SupportedModels() {
if m == req.Model || m == "*" {
supported = true
break
}
}
if !supported {
continue
}
// 检查健康状态
if !provider.HealthCheck(ctx) {
continue
}
// 获取成本信息 (实际实现需要从provider获取)
// 这里暂时设置为模拟值
cost := t.getProviderCost(provider)
candidates = append(candidates, candidate{name: name, cost: cost})
}
if len(candidates) == 0 {
return nil, gwerror.NewGatewayError(gwerror.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider for model: "+req.Model)
}
// 按成本排序
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].cost < candidates[j].cost
})
// 选择成本最低且在预算内的provider
maxCost := t.maxCostPer1KTokens
if req.MaxCost > 0 && req.MaxCost < maxCost {
maxCost = req.MaxCost
}
for _, c := range candidates {
if c.cost <= maxCost {
return &RoutingDecision{
Provider: c.name,
Strategy: t.Type(),
CostPer1KTokens: c.cost,
TakeoverMark: true, // M-008: 标记为接管
}, nil
}
}
return nil, ErrNoAffordableProvider
}
// CostAwareProvider 成本感知Provider接口
type CostAwareProvider interface {
GetCostPer1KTokens() float64
}
// getProviderCost 获取Provider的成本
func (t *CostBasedTemplate) getProviderCost(provider adapter.ProviderAdapter) float64 {
// 尝试类型断言获取成本
if cp, ok := provider.(CostAwareProvider); ok {
return cp.GetCostPer1KTokens()
}
// 默认返回0.5实际应从provider元数据获取
return 0.5
}

View File

@@ -0,0 +1,142 @@
package strategy
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"lijiaoqiao/gateway/internal/adapter"
)
// TestCostBasedStrategy_SelectProvider 测试成本优先策略选择Provider
func TestCostBasedStrategy_SelectProvider(t *testing.T) {
template := &CostBasedTemplate{
name: "CostBased",
maxCostPer1KTokens: 1.0,
providers: make(map[string]adapter.ProviderAdapter),
}
// 注册mock providers
template.providers["ProviderA"] = &MockProvider{
name: "ProviderA",
costPer1KTokens: 0.5,
available: true,
models: []string{"gpt-4"},
}
template.providers["ProviderB"] = &MockProvider{
name: "ProviderB",
costPer1KTokens: 0.3, // 最低成本
available: true,
models: []string{"gpt-4"},
}
template.providers["ProviderC"] = &MockProvider{
name: "ProviderC",
costPer1KTokens: 0.8,
available: true,
models: []string{"gpt-4"},
}
req := &RoutingRequest{
Model: "gpt-4",
UserID: "user123",
MaxCost: 1.0,
}
decision, err := template.SelectProvider(context.Background(), req)
// 验证选择了最低成本的Provider
assert.NoError(t, err)
assert.NotNil(t, decision)
assert.Equal(t, "ProviderB", decision.Provider, "Should select lowest cost provider")
assert.LessOrEqual(t, decision.CostPer1KTokens, 1.0, "Cost should be within budget")
}
func TestCostBasedStrategy_Fallback(t *testing.T) {
// 成本超出阈值时fallback
template := &CostBasedTemplate{
name: "CostBased",
maxCostPer1KTokens: 0.5, // 设置低成本上限
providers: make(map[string]adapter.ProviderAdapter),
}
// 注册成本较高的providers
template.providers["ProviderA"] = &MockProvider{
name: "ProviderA",
costPer1KTokens: 0.8,
available: true,
models: []string{"gpt-4"},
}
template.providers["ProviderB"] = &MockProvider{
name: "ProviderB",
costPer1KTokens: 1.0,
available: true,
models: []string{"gpt-4"},
}
req := &RoutingRequest{
Model: "gpt-4",
UserID: "user123",
MaxCost: 0.5,
}
decision, err := template.SelectProvider(context.Background(), req)
// 应该返回错误
assert.Error(t, err, "Should return error when no affordable provider")
assert.Nil(t, decision, "Should not return decision when cost exceeds threshold")
assert.Equal(t, ErrNoAffordableProvider, err, "Should return ErrNoAffordableProvider")
}
// MockProvider 用于测试的Mock Provider
type MockProvider struct {
name string
costPer1KTokens float64
qualityScore float64
latencyMs int64
available bool
models []string
}
func (m *MockProvider) ChatCompletion(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (*adapter.CompletionResponse, error) {
return nil, nil
}
func (m *MockProvider) ChatCompletionStream(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) {
return nil, nil
}
func (m *MockProvider) GetUsage(response *adapter.CompletionResponse) adapter.Usage {
return adapter.Usage{}
}
func (m *MockProvider) MapError(err error) adapter.ProviderError {
return adapter.ProviderError{}
}
func (m *MockProvider) HealthCheck(ctx context.Context) bool {
return m.available
}
func (m *MockProvider) ProviderName() string {
return m.name
}
func (m *MockProvider) SupportedModels() []string {
return m.models
}
func (m *MockProvider) GetCostPer1KTokens() float64 {
return m.costPer1KTokens
}
func (m *MockProvider) GetQualityScore() float64 {
return m.qualityScore
}
func (m *MockProvider) GetLatencyMs() int64 {
return m.latencyMs
}
// Verify MockProvider implements adapter.ProviderAdapter
var _ adapter.ProviderAdapter = (*MockProvider)(nil)

View File

@@ -0,0 +1,78 @@
package strategy
import (
"fmt"
"hash/fnv"
"sync"
)
// RolloutStrategy 灰度发布策略
type RolloutStrategy struct {
percentage int // 当前灰度百分比 (0-100)
bucketKey string // 分桶key
mu sync.RWMutex
}
// NewRolloutStrategy 创建灰度发布策略
func NewRolloutStrategy(percentage int, bucketKey string) *RolloutStrategy {
return &RolloutStrategy{
percentage: percentage,
bucketKey: bucketKey,
}
}
// SetPercentage 设置灰度百分比
func (r *RolloutStrategy) SetPercentage(percentage int) {
r.mu.Lock()
defer r.mu.Unlock()
if percentage < 0 {
percentage = 0
}
if percentage > 100 {
percentage = 100
}
r.percentage = percentage
}
// GetPercentage 获取当前灰度百分比
func (r *RolloutStrategy) GetPercentage() int {
r.mu.RLock()
defer r.mu.RUnlock()
return r.percentage
}
// ShouldApply 判断请求是否应该在灰度范围内
func (r *RolloutStrategy) ShouldApply(req *RoutingRequest) bool {
r.mu.RLock()
defer r.mu.RUnlock()
if r.percentage >= 100 {
return true
}
if r.percentage <= 0 {
return false
}
// 一致性哈希分桶
bucket := r.hashString(fmt.Sprintf("%s:%s", r.bucketKey, req.UserID)) % 100
return bucket < r.percentage
}
// hashString 计算字符串哈希值 (用于一致性分桶)
func (r *RolloutStrategy) hashString(s string) int {
h := fnv.New32a()
h.Write([]byte(s))
return int(h.Sum32())
}
// IncrementPercentage 增加灰度百分比
func (r *RolloutStrategy) IncrementPercentage(delta int) {
r.mu.Lock()
defer r.mu.Unlock()
r.percentage += delta
if r.percentage > 100 {
r.percentage = 100
}
}

View File

@@ -0,0 +1,65 @@
package strategy
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"lijiaoqiao/gateway/internal/adapter"
)
// TestStrategyTemplate_Interface 验证策略模板接口
func TestStrategyTemplate_Interface(t *testing.T) {
// 所有策略实现必须实现SelectProvider, Name, Type方法
// 创建策略实现示例
costBased := &CostBasedTemplate{
name: "CostBased",
}
aware := &CostAwareTemplate{
name: "CostAware",
}
// 验证实现了StrategyTemplate接口
var _ StrategyTemplate = costBased
var _ StrategyTemplate = aware
// 验证方法
assert.Equal(t, "CostBased", costBased.Name())
assert.Equal(t, "cost_based", costBased.Type())
assert.Equal(t, "CostAware", aware.Name())
assert.Equal(t, "cost_aware", aware.Type())
}
// TestStrategyTemplate_SelectProvider_Signature 验证SelectProvider方法签名
func TestStrategyTemplate_SelectProvider_Signature(t *testing.T) {
req := &RoutingRequest{
Model: "gpt-4",
UserID: "user123",
TenantID: "tenant1",
MaxCost: 1.0,
MaxLatency: 1000,
}
// 验证返回值 - 创建一个有providers的模板
template := &CostBasedTemplate{
name: "test",
maxCostPer1KTokens: 1.0,
providers: make(map[string]adapter.ProviderAdapter),
}
template.providers["test"] = &MockProvider{
name: "test",
costPer1KTokens: 0.5,
available: true,
models: []string{"gpt-4"},
}
decision, err := template.SelectProvider(context.Background(), req)
// 接口实现应该返回决策或错误
assert.NotNil(t, decision)
assert.Nil(t, err)
}

View File

@@ -0,0 +1,40 @@
package strategy
import (
"context"
)
// RoutingRequest 路由请求
type RoutingRequest struct {
Model string
UserID string
TenantID string
Region string
Messages []string
MaxCost float64
MaxLatency int64
MinQuality float64
}
// RoutingDecision 路由决策
type RoutingDecision struct {
Provider string
Strategy string
CostPer1KTokens float64
EstimatedLatency int64
QualityScore float64
TakeoverMark bool // M-008: 是否标记为接管
}
// StrategyTemplate 策略模板接口
// 所有路由策略都必须实现此接口
type StrategyTemplate interface {
// SelectProvider 选择最佳Provider
SelectProvider(ctx context.Context, req *RoutingRequest) (*RoutingDecision, error)
// Name 获取策略名称
Name() string
// Type 获取策略类型
Type() string
}