2026-04-01 08:53:28 +08:00
|
|
|
|
package middleware
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"net/http"
|
|
|
|
|
|
"net/http/httptest"
|
|
|
|
|
|
"strings"
|
|
|
|
|
|
"testing"
|
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
func TestTokenVerify(t *testing.T) {
|
|
|
|
|
|
secretKey := "test-secret-key-12345678901234567890"
|
|
|
|
|
|
issuer := "test-issuer"
|
|
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name string
|
|
|
|
|
|
token string
|
|
|
|
|
|
expectError bool
|
2026-04-01 08:53:28 +08:00
|
|
|
|
errorContains string
|
|
|
|
|
|
}{
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "valid token",
|
|
|
|
|
|
token: createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(time.Hour)),
|
|
|
|
|
|
expectError: false,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "expired token",
|
|
|
|
|
|
token: createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(-time.Hour)),
|
|
|
|
|
|
expectError: true,
|
2026-04-01 08:53:28 +08:00
|
|
|
|
errorContains: "expired",
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "wrong issuer",
|
|
|
|
|
|
token: createTestToken(secretKey, "wrong-issuer", "subject:1", "owner", time.Now().Add(time.Hour)),
|
|
|
|
|
|
expectError: true,
|
2026-04-01 08:53:28 +08:00
|
|
|
|
errorContains: "issuer",
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "invalid token",
|
|
|
|
|
|
token: "invalid.token.string",
|
|
|
|
|
|
expectError: true,
|
2026-04-01 08:53:28 +08:00
|
|
|
|
errorContains: "",
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
middleware := &AuthMiddleware{
|
|
|
|
|
|
config: AuthConfig{
|
|
|
|
|
|
SecretKey: secretKey,
|
|
|
|
|
|
Issuer: issuer,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
_, err := middleware.verifyToken(tt.token)
|
|
|
|
|
|
|
|
|
|
|
|
if tt.expectError {
|
|
|
|
|
|
if err == nil {
|
|
|
|
|
|
t.Errorf("expected error but got nil")
|
|
|
|
|
|
} else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
|
|
|
|
|
|
t.Errorf("error = %v, want contains %v", err, tt.errorContains)
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
t.Errorf("unexpected error: %v", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestQueryKeyRejectMiddleware(t *testing.T) {
|
|
|
|
|
|
tests := []struct {
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name string
|
|
|
|
|
|
query string
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus int
|
|
|
|
|
|
}{
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "no query params",
|
|
|
|
|
|
query: "",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusOK,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "normal params",
|
|
|
|
|
|
query: "?page=1&size=10",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusOK,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "blocked key param",
|
|
|
|
|
|
query: "?key=abc123",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusUnauthorized,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "blocked api_key param",
|
|
|
|
|
|
query: "?api_key=secret123",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusUnauthorized,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "blocked token param",
|
|
|
|
|
|
query: "?token=bearer123",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusUnauthorized,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "suspicious long param",
|
|
|
|
|
|
query: "?apikey=verylongparamvalueexceeding20chars",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusUnauthorized,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
middleware := &AuthMiddleware{
|
|
|
|
|
|
auditEmitter: nil,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
nextCalled := false
|
|
|
|
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
|
nextCalled = true
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
handler := middleware.QueryKeyRejectMiddleware(nextHandler)
|
|
|
|
|
|
|
|
|
|
|
|
req := httptest.NewRequest("POST", "/api/v1/supply/accounts"+tt.query, nil)
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
|
|
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
if tt.expectStatus == http.StatusOK {
|
|
|
|
|
|
if !nextCalled {
|
|
|
|
|
|
t.Errorf("expected next handler to be called")
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
if w.Code != tt.expectStatus {
|
|
|
|
|
|
t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestBearerExtractMiddleware(t *testing.T) {
|
|
|
|
|
|
tests := []struct {
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name string
|
|
|
|
|
|
authHeader string
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus int
|
|
|
|
|
|
}{
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "valid bearer",
|
|
|
|
|
|
authHeader: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusOK,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "missing header",
|
|
|
|
|
|
authHeader: "",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusUnauthorized,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "wrong prefix",
|
|
|
|
|
|
authHeader: "Basic abc123",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusUnauthorized,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
2026-04-01 13:03:44 +08:00
|
|
|
|
name: "empty token",
|
|
|
|
|
|
authHeader: "Bearer ",
|
2026-04-01 08:53:28 +08:00
|
|
|
|
expectStatus: http.StatusUnauthorized,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
middleware := &AuthMiddleware{}
|
|
|
|
|
|
|
|
|
|
|
|
nextCalled := false
|
|
|
|
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
|
nextCalled = true
|
|
|
|
|
|
// 检查context中是否有bearer token
|
|
|
|
|
|
if r.Context().Value(bearerTokenKey) == nil && tt.authHeader != "" && strings.HasPrefix(tt.authHeader, "Bearer ") {
|
|
|
|
|
|
// 这是预期的,因为token可能无效
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
handler := middleware.BearerExtractMiddleware(nextHandler)
|
|
|
|
|
|
|
|
|
|
|
|
req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil)
|
|
|
|
|
|
if tt.authHeader != "" {
|
|
|
|
|
|
req.Header.Set("Authorization", tt.authHeader)
|
|
|
|
|
|
}
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
|
|
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
if tt.expectStatus == http.StatusOK {
|
|
|
|
|
|
if !nextCalled {
|
|
|
|
|
|
t.Errorf("expected next handler to be called")
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
if w.Code != tt.expectStatus {
|
|
|
|
|
|
t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestContainsScope(t *testing.T) {
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
name string
|
|
|
|
|
|
scopes []string
|
|
|
|
|
|
target string
|
|
|
|
|
|
expected bool
|
|
|
|
|
|
}{
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "exact match",
|
|
|
|
|
|
scopes: []string{"read", "write", "delete"},
|
|
|
|
|
|
target: "write",
|
|
|
|
|
|
expected: true,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "wildcard",
|
|
|
|
|
|
scopes: []string{"*"},
|
|
|
|
|
|
target: "anything",
|
|
|
|
|
|
expected: true,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "no match",
|
|
|
|
|
|
scopes: []string{"read", "write"},
|
|
|
|
|
|
target: "admin",
|
|
|
|
|
|
expected: false,
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
name: "empty scopes",
|
|
|
|
|
|
scopes: []string{},
|
|
|
|
|
|
target: "read",
|
|
|
|
|
|
expected: false,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
result := containsScope(tt.scopes, tt.target)
|
|
|
|
|
|
if result != tt.expected {
|
|
|
|
|
|
t.Errorf("containsScope(%v, %s) = %v, want %v", tt.scopes, tt.target, result, tt.expected)
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestRoleLevel(t *testing.T) {
|
|
|
|
|
|
hierarchy := map[string]int{
|
|
|
|
|
|
"admin": 3,
|
|
|
|
|
|
"owner": 2,
|
|
|
|
|
|
"viewer": 1,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
role string
|
|
|
|
|
|
expected int
|
|
|
|
|
|
}{
|
|
|
|
|
|
{"admin", 3},
|
|
|
|
|
|
{"owner", 2},
|
|
|
|
|
|
{"viewer", 1},
|
|
|
|
|
|
{"unknown", 0},
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.role, func(t *testing.T) {
|
|
|
|
|
|
result := roleLevel(tt.role, hierarchy)
|
|
|
|
|
|
if result != tt.expected {
|
|
|
|
|
|
t.Errorf("roleLevel(%s) = %d, want %d", tt.role, result, tt.expected)
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestTokenCache(t *testing.T) {
|
|
|
|
|
|
cache := NewTokenCache()
|
|
|
|
|
|
|
|
|
|
|
|
t.Run("get empty", func(t *testing.T) {
|
|
|
|
|
|
status, found := cache.Get("nonexistent")
|
|
|
|
|
|
if found {
|
|
|
|
|
|
t.Errorf("expected not found")
|
|
|
|
|
|
}
|
|
|
|
|
|
if status != "" {
|
|
|
|
|
|
t.Errorf("expected empty status")
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
t.Run("set and get", func(t *testing.T) {
|
|
|
|
|
|
cache.Set("token1", "active", time.Hour)
|
|
|
|
|
|
|
|
|
|
|
|
status, found := cache.Get("token1")
|
|
|
|
|
|
if !found {
|
|
|
|
|
|
t.Errorf("expected to find token1")
|
|
|
|
|
|
}
|
|
|
|
|
|
if status != "active" {
|
|
|
|
|
|
t.Errorf("expected status 'active', got '%s'", status)
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
t.Run("invalidate", func(t *testing.T) {
|
|
|
|
|
|
cache.Set("token2", "revoked", time.Hour)
|
|
|
|
|
|
cache.Invalidate("token2")
|
|
|
|
|
|
|
|
|
|
|
|
_, found := cache.Get("token2")
|
|
|
|
|
|
if found {
|
|
|
|
|
|
t.Errorf("expected token2 to be invalidated")
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
t.Run("expiration", func(t *testing.T) {
|
|
|
|
|
|
cache.Set("token3", "active", time.Nanosecond)
|
|
|
|
|
|
time.Sleep(time.Millisecond)
|
|
|
|
|
|
|
|
|
|
|
|
_, found := cache.Get("token3")
|
|
|
|
|
|
if found {
|
|
|
|
|
|
t.Errorf("expected token3 to be expired")
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Helper functions
|
|
|
|
|
|
|
|
|
|
|
|
func createTestToken(secretKey, issuer, subject, role string, expiresAt time.Time) string {
|
|
|
|
|
|
claims := TokenClaims{
|
|
|
|
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
|
|
|
|
Issuer: issuer,
|
|
|
|
|
|
Subject: subject,
|
|
|
|
|
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
|
|
|
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
|
|
|
|
},
|
|
|
|
|
|
SubjectID: subject,
|
2026-04-01 13:03:44 +08:00
|
|
|
|
Role: role,
|
|
|
|
|
|
Scope: []string{"read", "write"},
|
|
|
|
|
|
TenantID: 1,
|
2026-04-01 08:53:28 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
|
|
|
|
tokenString, _ := token.SignedString([]byte(secretKey))
|
|
|
|
|
|
return tokenString
|
|
|
|
|
|
}
|