Files
lijiaoqiao/supply-api/internal/middleware/auth_test.go
2026-04-11 09:25:31 +08:00

1032 lines
26 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"lijiaoqiao/supply-api/internal/iam/model"
)
// stubTokenStatusBackend is a test double whose CheckTokenStatus always
// answers with the preconfigured status and error.
type stubTokenStatusBackend struct {
	status string // status returned by every lookup
	err    error  // error returned by every lookup
}

// CheckTokenStatus ignores both arguments and returns the canned
// status/err pair stored on the stub.
func (s *stubTokenStatusBackend) CheckTokenStatus(_ context.Context, _ string) (string, error) {
	return s.status, s.err
}
// TestTokenVerify runs verifyToken against a valid token, an expired token,
// a token minted for a different issuer, and a string that is not a JWT.
func TestTokenVerify(t *testing.T) {
	const (
		secretKey = "test-secret-key-12345678901234567890"
		issuer    = "test-issuer"
	)

	cases := []struct {
		name          string
		token         string
		expectError   bool
		errorContains string
	}{
		{
			name:  "valid token",
			token: createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(time.Hour)),
		},
		{
			name:          "expired token",
			token:         createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(-time.Hour)),
			expectError:   true,
			errorContains: "expired",
		},
		{
			name:          "wrong issuer",
			token:         createTestToken(secretKey, "wrong-issuer", "subject:1", "owner", time.Now().Add(time.Hour)),
			expectError:   true,
			errorContains: "issuer",
		},
		{
			// Any error is accepted here; the message is not pinned.
			name:        "invalid token",
			token:       "invalid.token.string",
			expectError: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			mw := &AuthMiddleware{
				config: AuthConfig{SecretKey: secretKey, Issuer: issuer},
			}

			_, err := mw.verifyToken(tc.token)

			switch {
			case !tc.expectError:
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
			case err == nil:
				t.Errorf("expected error but got nil")
			case tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains):
				t.Errorf("error = %v, want contains %v", err, tc.errorContains)
			}
		})
	}
}
// TestQueryKeyRejectMiddleware verifies that QueryKeyRejectMiddleware forwards
// requests whose query string carries no credential-like parameter and
// answers 401 to requests that do.
func TestQueryKeyRejectMiddleware(t *testing.T) {
	tests := []struct {
		name         string
		query        string
		expectStatus int
	}{
		{
			name:         "no query params",
			query:        "",
			expectStatus: http.StatusOK,
		},
		{
			name:         "normal params",
			query:        "?page=1&size=10",
			expectStatus: http.StatusOK,
		},
		{
			name:         "blocked key param",
			query:        "?key=abc123",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "blocked api_key param",
			query:        "?api_key=secret123",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "blocked token param",
			query:        "?token=bearer123",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "suspicious long param",
			query:        "?apikey=verylongparamvalueexceeding20chars",
			expectStatus: http.StatusUnauthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			middleware := &AuthMiddleware{
				auditEmitter: nil,
			}

			nextCalled := false
			nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextCalled = true
			})
			handler := middleware.QueryKeyRejectMiddleware(nextHandler)

			req := httptest.NewRequest("POST", "/api/v1/supply/accounts"+tt.query, nil)
			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)

			if tt.expectStatus == http.StatusOK {
				if !nextCalled {
					t.Errorf("expected next handler to be called")
				}
				return
			}

			if w.Code != tt.expectStatus {
				t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code)
			}
			// A rejected request must never reach the wrapped handler; checking
			// only the status code would miss a middleware that writes 401 and
			// still forwards the request.
			if nextCalled {
				t.Errorf("expected next handler not to be called for a rejected request")
			}
		})
	}
}
// TestBearerExtractMiddleware verifies that BearerExtractMiddleware forwards
// requests carrying a "Bearer <token>" Authorization header and answers 401
// when the header is missing, uses another scheme, or has an empty token.
func TestBearerExtractMiddleware(t *testing.T) {
	tests := []struct {
		name         string
		authHeader   string
		expectStatus int
	}{
		{
			name:         "valid bearer",
			authHeader:   "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
			expectStatus: http.StatusOK,
		},
		{
			name:         "missing header",
			authHeader:   "",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "wrong prefix",
			authHeader:   "Basic abc123",
			expectStatus: http.StatusUnauthorized,
		},
		{
			name:         "empty token",
			authHeader:   "Bearer ",
			expectStatus: http.StatusUnauthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			middleware := &AuthMiddleware{}

			// The handler only records that it was reached. The previous
			// version also contained an empty if-branch inspecting the
			// context; it asserted nothing and has been removed.
			nextCalled := false
			nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextCalled = true
			})
			handler := middleware.BearerExtractMiddleware(nextHandler)

			req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil)
			if tt.authHeader != "" {
				req.Header.Set("Authorization", tt.authHeader)
			}
			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)

			if tt.expectStatus == http.StatusOK {
				if !nextCalled {
					t.Errorf("expected next handler to be called")
				}
				return
			}

			if w.Code != tt.expectStatus {
				t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code)
			}
			// A rejected request must not be forwarded to the wrapped handler.
			if nextCalled {
				t.Errorf("expected next handler not to be called for a rejected request")
			}
		})
	}
}
// TestContainsScope checks containsScope for an exact match, the "*"
// wildcard entry, a target that is absent, and an empty scope list.
func TestContainsScope(t *testing.T) {
	type scopeCase struct {
		name     string
		scopes   []string
		target   string
		expected bool
	}

	cases := []scopeCase{
		{"exact match", []string{"read", "write", "delete"}, "write", true},
		{"wildcard", []string{"*"}, "anything", true},
		{"no match", []string{"read", "write"}, "admin", false},
		{"empty scopes", []string{}, "read", false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := containsScope(tc.scopes, tc.target); got != tc.expected {
				t.Errorf("containsScope(%v, %s) = %v, want %v", tc.scopes, tc.target, got, tc.expected)
			}
		})
	}
}
// TestRoleLevel pins the numeric level model.GetRoleLevelByCode reports for
// each role code listed below, including 0 for an unrecognised code.
func TestRoleLevel(t *testing.T) {
	cases := []struct {
		role     string
		expected int
	}{
		{role: "super_admin", expected: 100},
		{role: "org_admin", expected: 50},
		{role: "supply_admin", expected: 40},
		{role: "operator", expected: 30},
		{role: "developer", expected: 20},
		{role: "finops", expected: 20},
		{role: "viewer", expected: 10},
		{role: "unknown", expected: 0},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.role, func(t *testing.T) {
			level := model.GetRoleLevelByCode(tc.role)
			if level == tc.expected {
				return
			}
			t.Errorf("GetRoleLevelByCode(%s) = %d, want %d", tc.role, level, tc.expected)
		})
	}
}
// TestTokenCache walks one shared TokenCache through a miss, a set/get
// round trip, an explicit invalidation and a TTL expiry.
func TestTokenCache(t *testing.T) {
	cache := NewTokenCache()

	t.Run("get empty", func(t *testing.T) {
		value, ok := cache.Get("nonexistent")
		if ok {
			t.Errorf("expected not found")
		}
		if value != "" {
			t.Errorf("expected empty status")
		}
	})

	t.Run("set and get", func(t *testing.T) {
		cache.Set("token1", "active", time.Hour)

		value, ok := cache.Get("token1")
		if !ok {
			t.Errorf("expected to find token1")
		}
		if value != "active" {
			t.Errorf("expected status 'active', got '%s'", value)
		}
	})

	t.Run("invalidate", func(t *testing.T) {
		cache.Set("token2", "revoked", time.Hour)
		cache.Invalidate("token2")

		if _, ok := cache.Get("token2"); ok {
			t.Errorf("expected token2 to be invalidated")
		}
	})

	t.Run("expiration", func(t *testing.T) {
		// A one-nanosecond TTL has certainly elapsed after the sleep below.
		cache.Set("token3", "active", time.Nanosecond)
		time.Sleep(time.Millisecond)

		if _, ok := cache.Get("token3"); ok {
			t.Errorf("expected token3 to be expired")
		}
	})
}
// TestHIGH02_JWT_AlgorithmValidation covers HIGH-02 (JWT algorithm
// validation): verifyToken is expected to accept HS256 only.
// Note: HS384/HS512/RS256 would need configuration support; this test only
// expects HS256 to be accepted and every other method to be rejected.
//
// Each case carries its own signing key because the "none" method refuses
// to sign with a byte slice. The previous version discarded that signing
// error, so the "none" case verified an empty string instead of a real
// alg=none token.
func TestHIGH02_JWT_AlgorithmValidation(t *testing.T) {
	secretKey := "test-secret-key-12345678901234567890"
	issuer := "test-issuer"

	tests := []struct {
		name          string
		signingMethod jwt.SigningMethod
		signingKey    interface{}
		expectError   bool
		errorContains string
	}{
		{
			name:          "HS256 should be accepted with secret key",
			signingMethod: jwt.SigningMethodHS256,
			signingKey:    []byte(secretKey),
			expectError:   false,
		},
		{
			name:          "HS384 requires different implementation",
			signingMethod: jwt.SigningMethodHS384,
			signingKey:    []byte(secretKey),
			expectError:   true,
			errorContains: "unexpected signing method",
		},
		{
			name:          "HS512 requires different implementation",
			signingMethod: jwt.SigningMethodHS512,
			signingKey:    []byte(secretKey),
			expectError:   true,
			errorContains: "unexpected signing method",
		},
		{
			name:          "none algorithm should be rejected",
			signingMethod: jwt.SigningMethodNone,
			// SigningMethodNone only signs when given this sentinel key.
			signingKey:  jwt.UnsafeAllowNoneSignatureType,
			expectError: true,
			// The exact message depends on where verifyToken rejects the
			// token, so only the rejection itself is asserted.
			errorContains: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			claims := TokenClaims{
				RegisteredClaims: jwt.RegisteredClaims{
					Issuer:    issuer,
					Subject:   "subject:1",
					ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
					IssuedAt:  jwt.NewNumericDate(time.Now()),
				},
				SubjectID: "subject:1",
				Role:      "owner",
				Scope:     []string{"read", "write"},
				TenantID:  1,
			}

			token := jwt.NewWithClaims(tt.signingMethod, claims)
			tokenString, err := token.SignedString(tt.signingKey)
			if err != nil {
				// Without a real token the verification below would be vacuous.
				t.Fatalf("failed to sign test token: %v", err)
			}

			middleware := &AuthMiddleware{
				config: AuthConfig{
					SecretKey: secretKey,
					Issuer:    issuer,
				},
			}

			_, err = middleware.verifyToken(tokenString)

			switch {
			case !tt.expectError:
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
			case err == nil:
				t.Errorf("expected error but got nil")
			case tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains):
				t.Errorf("error = %v, want contains %v", err, tt.errorContains)
			}
		})
	}
}
// TestP001_RS256WithPublicKey is a placeholder for the requirement that
// RS256 verification needs a configured public key. It is always skipped;
// the skip message points to token_format_test.go for the RS256 coverage.
func TestP001_RS256WithPublicKey(t *testing.T) {
	const reason = "RS256 verification requires RSA key pair setup - tested in token_format_test.go"
	t.Skip(reason)
}
// MED-02: checkTokenStatus缓存未命中时应该查询后端而不是默认返回active
func TestMED02_TokenCacheMiss_ShouldNotAssumeActive(t *testing.T) {
// arrange
middleware := &AuthMiddleware{
config: AuthConfig{
SecretKey: "test-secret-key-12345678901234567890",
Issuer: "test-issuer",
},
tokenCache: NewTokenCache(), // 空的缓存
// 没有设置tokenBackend
}
// act - 查询一个不在缓存中的token
status, err := middleware.checkTokenStatus(context.Background(), "nonexistent-token-id")
// assert - 缓存未命中且没有后端时应该返回错误(安全修复)
// 修复前bug缓存未命中时默认返回"active"
// 修复后:缓存未命中且没有后端时返回错误
if err == nil {
t.Errorf("MED-02: cache miss without backend should return error, got status='%s'", status)
}
}
func TestTokenVerifyMiddleware_BackendErrorShouldReject(t *testing.T) {
secretKey := "test-secret-key-12345678901234567890"
issuer := "test-issuer"
authMiddleware := NewAuthMiddleware(AuthConfig{
SecretKey: secretKey,
Issuer: issuer,
Enabled: true,
}, NewTokenCache(), &stubTokenStatusBackend{err: errors.New("database unavailable")}, nil)
nextCalled := false
handler := authMiddleware.TokenVerifyMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
}))
req := httptest.NewRequest("GET", "/api/v1/supply/accounts", nil)
req = req.WithContext(context.WithValue(req.Context(), bearerTokenKey, createTestToken(secretKey, issuer, "subject:1", "org_admin", time.Now().Add(time.Hour))))
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if nextCalled {
t.Fatal("expected request to be rejected when token backend is unavailable")
}
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected status 401, got %d", w.Code)
}
if !strings.Contains(w.Body.String(), "AUTH_TOKEN_STATUS_UNAVAILABLE") {
t.Fatalf("expected response to contain AUTH_TOKEN_STATUS_UNAVAILABLE, got %s", w.Body.String())
}
}
// Helper functions

// createTestToken returns an HS256-signed JWT for use in tests.
//
// The token carries the given issuer, subject (also copied into SubjectID),
// role and expiry, an IssuedAt of now, the fixed scopes "read" and "write",
// and tenant ID 1. It is signed with secretKey.
//
// It panics if signing fails: returning an empty token would make callers
// fail later with misleading parse errors instead of pointing at the helper.
func createTestToken(secretKey, issuer, subject, role string, expiresAt time.Time) string {
	claims := TokenClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    issuer,
			Subject:   subject,
			ExpiresAt: jwt.NewNumericDate(expiresAt),
			IssuedAt:  jwt.NewNumericDate(time.Now()),
		},
		SubjectID: subject,
		Role:      role,
		Scope:     []string{"read", "write"},
		TenantID:  1,
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenString, err := token.SignedString([]byte(secretKey))
	if err != nil {
		panic("createTestToken: signing test token: " + err.Error())
	}
	return tokenString
}
// ==================== BruteForceProtection Tests ====================

// TestNewBruteForceProtection checks that the constructor stores both
// arguments and initialises the attempts map.
func TestNewBruteForceProtection(t *testing.T) {
	protection := NewBruteForceProtection(5, time.Minute)

	if got := protection.maxAttempts; got != 5 {
		t.Errorf("expected maxAttempts 5, got %d", got)
	}
	if got := protection.lockoutDuration; got != time.Minute {
		t.Errorf("expected lockoutDuration 1m, got %v", got)
	}
	if protection.attempts == nil {
		t.Error("expected attempts map to be initialized")
	}
}
// TestBruteForceProtection_RecordFailedAttempt checks that an IP is locked,
// with a positive remaining time, once it has failed maxAttempts (3) times.
func TestBruteForceProtection_RecordFailedAttempt(t *testing.T) {
	const ip = "192.168.1.1"
	protection := NewBruteForceProtection(3, time.Minute)

	// Three consecutive failures should trigger the lock.
	for i := 0; i < 3; i++ {
		protection.RecordFailedAttempt(ip)
	}

	isLocked, timeLeft := protection.IsLocked(ip)
	if !isLocked {
		t.Error("should be locked after 3 attempts")
	}
	if timeLeft <= 0 {
		t.Error("remaining time should be positive when locked")
	}
}
// TestBruteForceProtection_IsLocked checks IsLocked for an IP that was never
// recorded and for one that reached the attempt limit.
func TestBruteForceProtection_IsLocked(t *testing.T) {
	protection := NewBruteForceProtection(2, time.Hour)

	// An IP with no recorded failures must not be locked.
	if isLocked, _ := protection.IsLocked("192.168.1.100"); isLocked {
		t.Error("unrecorded IP should not be locked")
	}

	// Reaching the maximum number of attempts must lock the IP.
	const offender = "192.168.1.2"
	protection.RecordFailedAttempt(offender)
	protection.RecordFailedAttempt(offender)

	isLocked, timeLeft := protection.IsLocked(offender)
	if !isLocked {
		t.Error("should be locked after 2 attempts")
	}
	if timeLeft <= 0 || timeLeft > time.Hour {
		t.Errorf("remaining time should be within lockout duration, got %v", timeLeft)
	}
}
// TestBruteForceProtection_Reset checks that Reset lifts the lock of an IP
// that had reached the attempt limit.
func TestBruteForceProtection_Reset(t *testing.T) {
	const ip = "192.168.1.1"
	protection := NewBruteForceProtection(2, time.Hour)

	// Lock the IP.
	protection.RecordFailedAttempt(ip)
	protection.RecordFailedAttempt(ip)
	if isLocked, _ := protection.IsLocked(ip); !isLocked {
		t.Error("should be locked before reset")
	}

	// Reset and re-check.
	protection.Reset(ip)
	if isLocked, _ := protection.IsLocked(ip); isLocked {
		t.Error("should not be locked after reset")
	}
}
// TestBruteForceProtection_CleanExpired checks that CleanExpired removes the
// record of an IP whose lock has expired, and that the IP is then unlocked.
//
// The record count is asserted via Len because IsLocked alone would also
// report "not locked" for an expired entry that was never cleaned up.
func TestBruteForceProtection_CleanExpired(t *testing.T) {
	bp := NewBruteForceProtection(1, time.Millisecond)

	// Lock the IP.
	bp.RecordFailedAttempt("192.168.1.1")
	bp.RecordFailedAttempt("192.168.1.1")
	if bp.Len() != 1 {
		t.Fatalf("expected 1 record before clean, got %d", bp.Len())
	}

	// Wait for the lock to expire.
	time.Sleep(5 * time.Millisecond)

	// Clean up.
	bp.CleanExpired()

	// The record should be gone and the IP should no longer be locked.
	if bp.Len() != 0 {
		t.Errorf("expected 0 records after clean, got %d", bp.Len())
	}
	locked, _ := bp.IsLocked("192.168.1.1")
	if locked {
		t.Error("expired lock should be cleaned")
	}
}
// TestBruteForceProtection_Len checks that Len tracks the number of IPs with
// recorded failures as entries are added and reset.
func TestBruteForceProtection_Len(t *testing.T) {
	protection := NewBruteForceProtection(3, time.Hour)

	if n := protection.Len(); n != 0 {
		t.Errorf("expected 0, got %d", n)
	}

	for _, ip := range []string{"192.168.1.1", "192.168.1.2"} {
		protection.RecordFailedAttempt(ip)
	}
	if n := protection.Len(); n != 2 {
		t.Errorf("expected 2, got %d", n)
	}

	protection.Reset("192.168.1.1")
	if n := protection.Len(); n != 1 {
		t.Errorf("expected 1 after reset, got %d", n)
	}
}
// TestBruteForceProtection_MultipleIPs checks that failures are counted per
// IP: locking one address must not lock another.
func TestBruteForceProtection_MultipleIPs(t *testing.T) {
	const (
		first  = "192.168.1.1"
		second = "192.168.1.2"
	)
	protection := NewBruteForceProtection(2, time.Hour)

	// Each IP gets one failure of its own.
	protection.RecordFailedAttempt(first)
	protection.RecordFailedAttempt(second)

	// A second failure for the first IP should lock only that IP.
	protection.RecordFailedAttempt(first)

	firstLocked, _ := protection.IsLocked(first)
	secondLocked, _ := protection.IsLocked(second)

	if !firstLocked {
		t.Error("192.168.1.1 should be locked")
	}
	if secondLocked {
		t.Error("192.168.1.2 should still not be locked")
	}
}
// ==================== Helper Function Tests ====================

// TestGetRequestID checks the ID getRequestID reads from the request-ID
// header and that it returns "" when no such header is set.
func TestGetRequestID(t *testing.T) {
	type requestIDCase struct {
		name       string
		headers    map[string]string
		expectedID string
	}

	cases := []requestIDCase{
		{
			name:       "X-Request-Id header",
			headers:    map[string]string{"X-Request-Id": "req-123"},
			expectedID: "req-123",
		},
		{
			name:       "X-Request-ID header (uppercase)",
			headers:    map[string]string{"X-Request-ID": "req-456"},
			expectedID: "req-456",
		},
		{
			name:       "X-Request-Id only",
			headers:    map[string]string{"X-Request-Id": "req-123"},
			expectedID: "req-123",
		},
		{
			name:       "both empty",
			headers:    map[string]string{},
			expectedID: "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", "/", nil)
			for key, value := range tc.headers {
				req.Header.Set(key, value)
			}

			if got := getRequestID(req); got != tc.expectedID {
				t.Errorf("expected '%s', got '%s'", tc.expectedID, got)
			}
		})
	}
}
// TestGetClientIP checks which address getClientIP reports for combinations
// of X-Forwarded-For / X-Real-IP headers, RemoteAddr and trusted-proxy
// configuration.
func TestGetClientIP(t *testing.T) {
	// Trusted proxy ranges; they cover the private RemoteAddr used below.
	trustedProxies := []string{"192.168.0.0/16", "10.0.0.0/8"}

	tests := []struct {
		name       string
		headers    map[string]string
		remoteAddr string
		trusted    []string // trusted proxy configuration; nil means none is passed to getClientIP
		expectedIP string
	}{
		{
			name:       "X-Forwarded-For single (trusted proxy)",
			headers:    map[string]string{"X-Forwarded-For": "203.0.113.1"},
			remoteAddr: "192.168.1.1:1234",
			trusted:    trustedProxies,
			expectedIP: "203.0.113.1",
		},
		{
			name:       "X-Forwarded-For multiple (trusted proxy)",
			headers:    map[string]string{"X-Forwarded-For": "203.0.113.1, 198.51.100.1, 10.0.0.1"},
			remoteAddr: "192.168.1.1:1234",
			trusted:    trustedProxies,
			expectedIP: "203.0.113.1",
		},
		{
			name:       "X-Real-IP (trusted proxy)",
			headers:    map[string]string{"X-Real-IP": "203.0.113.5"},
			remoteAddr: "192.168.1.1:1234",
			trusted:    trustedProxies,
			expectedIP: "203.0.113.5",
		},
		{
			name:       "X-Forwarded-For takes precedence (trusted proxy)",
			headers:    map[string]string{"X-Forwarded-For": "203.0.113.1", "X-Real-IP": "203.0.113.5"},
			remoteAddr: "192.168.1.1:1234",
			trusted:    trustedProxies,
			expectedIP: "203.0.113.1",
		},
		{
			name:       "fallback to RemoteAddr (no trusted proxy)",
			headers:    map[string]string{},
			remoteAddr: "192.168.1.1:1234",
			trusted:    nil, // no trusted proxies configured
			expectedIP: "192.168.1.1",
		},
		{
			name: "SEC-003: Untrusted source ignores X-Forwarded-For",
			// The forwarded value deliberately differs from RemoteAddr. When
			// both were identical the case passed even if the spoofed header
			// was honoured.
			headers:    map[string]string{"X-Forwarded-For": "198.51.100.23"},
			remoteAddr: "203.0.113.1:1234", // public IP, outside the trusted ranges
			trusted:    trustedProxies,
			expectedIP: "203.0.113.1", // RemoteAddr must win
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", "/", nil)
			for k, v := range tt.headers {
				req.Header.Set(k, v)
			}
			req.RemoteAddr = tt.remoteAddr

			ip := getClientIP(req, tt.trusted...)
			if ip != tt.expectedIP {
				t.Errorf("expected '%s', got '%s'", tt.expectedIP, ip)
			}
		})
	}
}
// TestParseSubjectID pins the numeric ID parseSubjectID extracts from
// "prefix:number" subjects, and the 0 it returns for the inputs below that
// lack a prefix, are empty, or have a non-numeric ID part.
func TestParseSubjectID(t *testing.T) {
	type subjectCase struct {
		name     string
		subject  string
		expected int64
	}

	cases := []subjectCase{
		{name: "valid subject with prefix", subject: "user:12345", expected: 12345},
		{name: "subject without prefix", subject: "12345", expected: 0},
		{name: "empty subject", subject: "", expected: 0},
		{name: "invalid number", subject: "user:abc", expected: 0},
		{name: "multiple colons", subject: "user:12345:extra", expected: 12345},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := parseSubjectID(tc.subject); got != tc.expected {
				t.Errorf("expected %d, got %d", tc.expected, got)
			}
		})
	}
}
// TestComputeFingerprint checks that ComputeFingerprint is deterministic,
// separates different inputs, and yields a 64-character result.
func TestComputeFingerprint(t *testing.T) {
	const credential = "test-credential-123"

	first := ComputeFingerprint(credential)
	repeat := ComputeFingerprint(credential)
	other := ComputeFingerprint("different-credential")

	if first != repeat {
		t.Error("same input should produce same fingerprint")
	}
	if first == other {
		t.Error("different inputs should produce different fingerprints")
	}
	// A hex-encoded SHA-256 digest is 64 characters long.
	if n := len(first); n != 64 {
		t.Errorf("expected 64 hex chars, got %d", n)
	}
}
// ==================== GetTokenClaims Tests ====================

// TestGetTokenClaims checks GetTokenClaims for a context that holds claims,
// one that holds nothing, and one that holds a value of the wrong type under
// tokenClaimsKey.
func TestGetTokenClaims(t *testing.T) {
	t.Run("with valid claims", func(t *testing.T) {
		stored := &TokenClaims{
			SubjectID: "user:123",
			Role:      "admin",
			TenantID:  1,
		}

		got := GetTokenClaims(context.WithValue(context.Background(), tokenClaimsKey, stored))
		if got == nil {
			t.Fatal("expected claims, got nil")
		}
		if got.SubjectID != "user:123" {
			t.Errorf("expected SubjectID 'user:123', got '%s'", got.SubjectID)
		}
	})

	t.Run("without claims", func(t *testing.T) {
		if got := GetTokenClaims(context.Background()); got != nil {
			t.Error("expected nil when no claims in context")
		}
	})

	t.Run("with wrong type", func(t *testing.T) {
		// A plain string stored under the claims key must not be returned as claims.
		ctx := context.WithValue(context.Background(), tokenClaimsKey, "not a token claims")
		if got := GetTokenClaims(ctx); got != nil {
			t.Error("expected nil when wrong type in context")
		}
	})
}
// ==================== NewAuthMiddleware Tests ====================
func TestNewAuthMiddleware_DefaultCacheTTL(t *testing.T) {
config := AuthConfig{
SecretKey: "test-secret",
Issuer: "test-issuer",
CacheTTL: 0, // 应该使用默认值
}
mw := NewAuthMiddleware(config, nil, nil, nil)
if mw.config.CacheTTL != 30*time.Second {
t.Errorf("expected default CacheTTL 30s, got %v", mw.config.CacheTTL)
}
}
// TestNewAuthMiddleware_ExplicitCacheTTL checks that NewAuthMiddleware keeps
// an explicitly configured CacheTTL.
//
// The value is deliberately different from the 30-second default asserted by
// TestNewAuthMiddleware_DefaultCacheTTL. With an explicit 30s the test could
// not tell "kept the explicit value" from "overwrote it with the default".
func TestNewAuthMiddleware_ExplicitCacheTTL(t *testing.T) {
	const explicitTTL = 10 * time.Second

	config := AuthConfig{
		SecretKey: "test-secret",
		Issuer:    "test-issuer",
		CacheTTL:  explicitTTL, // set explicitly
	}

	mw := NewAuthMiddleware(config, nil, nil, nil)
	if mw.config.CacheTTL != explicitTTL {
		t.Errorf("expected explicit CacheTTL %v, got %v", explicitTTL, mw.config.CacheTTL)
	}
}
// ==================== ScopeRoleAuthzMiddleware Tests ====================

// TestScopeRoleAuthzMiddleware checks the status ScopeRoleAuthzMiddleware
// produces for requests with no claims in the context (401), with a role
// that is too low for the path (403), and with a sufficient role (forwarded
// to the next handler).
//
// Claims are injected directly into the request context under
// tokenClaimsKey, so no signed token is needed here.
func TestScopeRoleAuthzMiddleware(t *testing.T) {
	secretKey := "test-secret-key-12345678901234567890"
	issuer := "test-issuer"

	// Claims of a low-privilege user.
	claims := &TokenClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    issuer,
			Subject:   "user:1",
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
		},
		SubjectID: "user:1",
		Role:      "viewer",
		Scope:     []string{"read"},
		TenantID:  1,
	}

	// Claims of an organisation administrator.
	adminClaims := &TokenClaims{
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    issuer,
			Subject:   "user:1",
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
		},
		SubjectID: "user:1",
		Role:      "org_admin",
		Scope:     []string{"read", "write"},
		TenantID:  1,
	}

	middleware := &AuthMiddleware{
		config: AuthConfig{
			SecretKey: secretKey,
			Issuer:    issuer,
		},
	}

	// withClaims returns a setup step that attaches c to the request context.
	// It returns the derived request instead of overwriting *r in place.
	withClaims := func(c *TokenClaims) func(*http.Request) *http.Request {
		return func(r *http.Request) *http.Request {
			return r.WithContext(context.WithValue(r.Context(), tokenClaimsKey, c))
		}
	}

	tests := []struct {
		name          string
		path          string
		setupContext  func(r *http.Request) *http.Request
		requiredScope string
		expectStatus  int
	}{
		{
			name: "missing claims in context",
			path: "/api/v1/supply/accounts",
			// No claims are attached.
			setupContext:  func(r *http.Request) *http.Request { return r },
			requiredScope: "",
			expectStatus:  http.StatusUnauthorized,
		},
		{
			name:          "insufficient role for accounts",
			path:          "/api/v1/supply/accounts",
			setupContext:  withClaims(claims),
			requiredScope: "",
			expectStatus:  http.StatusForbidden,
		},
		{
			name:          "sufficient role for accounts",
			path:          "/api/v1/supply/accounts",
			setupContext:  withClaims(adminClaims),
			requiredScope: "",
			expectStatus:  http.StatusOK,
		},
		{
			name:          "viewer can access billing",
			path:          "/api/v1/supply/billing",
			setupContext:  withClaims(claims),
			requiredScope: "",
			expectStatus:  http.StatusOK,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			nextCalled := false
			nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextCalled = true
			})
			handler := middleware.ScopeRoleAuthzMiddleware(tt.requiredScope)(nextHandler)

			req := tt.setupContext(httptest.NewRequest("GET", tt.path, nil))
			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)

			if tt.expectStatus == http.StatusOK {
				if !nextCalled {
					t.Error("expected next handler to be called")
				}
				return
			}

			if w.Code != tt.expectStatus {
				t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code)
			}
			// An unauthorised or forbidden request must not be forwarded.
			if nextCalled {
				t.Error("expected next handler not to be called for a rejected request")
			}
		})
	}
}
// ==================== TokenCache Extended Tests ====================

// TestTokenCache_Len checks that Len follows the number of cached entries as
// tokens are added and invalidated.
func TestTokenCache_Len(t *testing.T) {
	cache := NewTokenCache()

	if n := cache.Len(); n != 0 {
		t.Errorf("expected 0, got %d", n)
	}

	cache.Set("token1", "active", time.Hour)
	if n := cache.Len(); n != 1 {
		t.Errorf("expected 1, got %d", n)
	}

	cache.Set("token2", "active", time.Hour)
	if n := cache.Len(); n != 2 {
		t.Errorf("expected 2, got %d", n)
	}

	cache.Invalidate("token1")
	if n := cache.Len(); n != 1 {
		t.Errorf("expected 1 after invalidate, got %d", n)
	}
}
// TestTokenCache_CleanExpired checks that an expired entry still counts
// towards Len until CleanExpired removes it.
func TestTokenCache_CleanExpired(t *testing.T) {
	cache := NewTokenCache()

	// Store a token whose TTL has certainly elapsed after the sleep.
	cache.Set("expired-token", "active", time.Nanosecond)
	time.Sleep(time.Millisecond)

	if n := cache.Len(); n != 1 {
		t.Errorf("expected 1 before clean, got %d", n)
	}

	cache.CleanExpired()

	if n := cache.Len(); n != 0 {
		t.Errorf("expected 0 after clean, got %d", n)
	}
}