package middleware

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
)

// =============================================================================
// Test: admin_auth.go — Pure Function Unit Tests
// Covers: isWebSocketUpgradeRequest, extractJWTFromWebSocketSubprotocol,
// the Authorization header parsing pattern, and API key detection.
// =============================================================================

// TestIsWebSocketUpgradeRequest checks that a request is reported as a
// WebSocket upgrade only when both the Upgrade and Connection headers are
// present with matching values, independent of letter case.
func TestIsWebSocketUpgradeRequest(t *testing.T) {
	tests := []struct {
		name             string
		upgradeHeader    string
		connectionHeader string
		expected         bool
	}{
		{"valid websocket upgrade", "websocket", "upgrade", true},
		{"valid websocket with extra connection values", "websocket", "Upgrade, keep-alive", true},
		{"case insensitive upgrade", "WebSocket", "Upgrade", true},
		{"case insensitive connection", "websocket", "Upgrade", true},
		{"wrong upgrade value", "http/2", "upgrade", false},
		{"missing upgrade header", "", "upgrade", false},
		{"missing connection header", "websocket", "", false},
		{"both empty", "", "", false},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)
			c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
			// An empty value in the table means "header absent", not "header set to empty".
			if tc.upgradeHeader != "" {
				c.Request.Header.Set("Upgrade", tc.upgradeHeader)
			}
			if tc.connectionHeader != "" {
				c.Request.Header.Set("Connection", tc.connectionHeader)
			}

			got := isWebSocketUpgradeRequest(c)
			if got != tc.expected {
				t.Errorf("isWebSocketUpgradeRequest() = %v, want %v (upgrade=%q, connection=%q)",
					got, tc.expected, tc.upgradeHeader, tc.connectionHeader)
			}
		})
	}
}

// TestIsWebSocketUpgradeRequest_NilContext checks that a nil context neither
// panics nor is reported as an upgrade request.
func TestIsWebSocketUpgradeRequest_NilContext(t *testing.T) {
	// Call exactly once inside the panic guard: a second, unguarded call would
	// crash the test binary if the function panicked on nil.
	var got bool
	assertNoPanic(t, func() { got = isWebSocketUpgradeRequest(nil) })
	if got {
		t.Error("nil context should return false")
	}
}

// TestExtractJWTFromWebSocketSubprotocol checks that the JWT is taken from a
// comma-separated Sec-WebSocket-Protocol entry carrying the exact "jwt."
// prefix, with surrounding whitespace trimmed.
func TestExtractJWTFromWebSocketSubprotocol(t *testing.T) {
	tests := []struct {
		name           string
		protocolHeader string
		expectedToken  string
		description    string
	}{
		{
			name:           "valid jwt.token format",
			protocolHeader: "sub2api-admin, jwt.eyJhbGciOiJIUzI1NiJ9.test",
			expectedToken:  "eyJhbGciOiJIUzI1NiJ9.test",
			description:    "Should extract token after jwt. prefix",
		},
		{
			name:           "jwt.token at start",
			protocolHeader: "jwt.my-secret-token-here",
			expectedToken:  "my-secret-token-here",
			description:    "First protocol item can be jwt. prefixed",
		},
		{
			name:           "multiple protocols, jwt in middle",
			protocolHeader: "v1, jwt.token-123, v2",
			expectedToken:  "token-123",
			description:    "Finds jwt. prefix among comma-separated items",
		},
		{
			name:           "whitespace around token",
			protocolHeader: "  jwt.trimmed-token  ",
			expectedToken:  "trimmed-token",
			description:    "Trims whitespace from extracted token",
		},
		{
			name:           "empty after prefix returns empty",
			protocolHeader: "jwt.",
			expectedToken:  "",
			description:    "Empty after prefix → no match returned",
		},
		{
			name:           "no jwt prefix",
			protocolHeader: "sub2api-admin, v1, chat",
			expectedToken:  "",
			description:    "Returns empty when no jwt. prefix found",
		},
		{
			name:           "empty header",
			protocolHeader: "",
			expectedToken:  "",
			description:    "Empty header returns empty",
		},
		{
			name:           "similar but wrong prefix",
			protocolHeader: "jwttoken, bearer-token",
			expectedToken:  "",
			description:    "Must be exactly 'jwt.' prefix, not 'jwttoken'",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)
			c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.protocolHeader != "" {
				c.Request.Header.Set("Sec-WebSocket-Protocol", tc.protocolHeader)
			}

			// The nil-context case is covered by a dedicated test below.
			got := extractJWTFromWebSocketSubprotocol(c)
			if got != tc.expectedToken {
				t.Errorf("extractJWTFromWebSocketSubprotocol(%q)\n  got:  %q\n  want: %q\n  (%s)",
					tc.protocolHeader, got, tc.expectedToken, tc.description)
			}
		})
	}
}

// TestExtractJWTFromWebSocketSubprotocol_NilContext checks that a nil context
// yields an empty token without panicking.
func TestExtractJWTFromWebSocketSubprotocol_NilContext(t *testing.T) {
	var got string
	assertNoPanic(t, func() { got = extractJWTFromWebSocketSubprotocol(nil) })
	if got != "" {
		t.Errorf("nil context should return empty, got %q", got)
	}
}

// =============================================================================
// Test: Authorization Header Parsing Pattern
// Verifies the Bearer token parsing logic (pattern extracted from the
// adminAuth function).
// =============================================================================

// TestParseAuthorizationHeader_BearerToken exercises the Bearer-token parsing
// pattern against well-formed, malformed, and non-Bearer headers.
func TestParseAuthorizationHeader_BearerToken(t *testing.T) {
	t.Parallel()

	// parseBearer reproduces the pattern under test: split once on the first
	// space, require a case-insensitive "Bearer" scheme, and trim the remainder.
	// ok reports whether the header had the Bearer shape at all; the token may
	// still be empty (e.g. "Bearer ").
	parseBearer := func(header string) (token string, ok bool) {
		scheme, rest, found := strings.Cut(header, " ")
		if !found || !strings.EqualFold(scheme, "Bearer") {
			return "", false
		}
		return strings.TrimSpace(rest), true
	}

	tests := []struct {
		header      string
		expectToken string
		expectValid bool
	}{
		{"Bearer eyJhbGciOiJIUzI1NiJ9.valid", "eyJhbGciOiJIUzI1NiJ9.valid", true},
		{"bearer lowercase-token", "lowercase-token", true}, // case-insensitive Bearer
		{"BEARER uppercase-token", "uppercase-token", true},
		{"Bearer", "", false},                           // no space, so no token part
		{"Basic dXNlcjpwYXNz", "", false},               // non-Bearer scheme
		{"", "", false},                                 // empty header
		{"Bearer ", "", true},                           // only whitespace → trimmed empty is a valid parse
		{"Bearer  spaced-token  ", "spaced-token", true}, // trim whitespace
		{"MAC token=abc", "", false},                    // unknown scheme
	}

	for _, tc := range tests {
		t.Run(fmt.Sprintf("auth=%q", truncateStr(tc.header, 30)), func(t *testing.T) {
			token, ok := parseBearer(tc.header)
			if ok != tc.expectValid {
				t.Fatalf("parse(%q): valid = %v, want %v (token %q)", tc.header, ok, tc.expectValid, token)
			}
			if token != tc.expectToken {
				t.Errorf("token mismatch for %q: got %q, want %q", tc.header, token, tc.expectToken)
			}
		})
	}
}

// =============================================================================
// Test: API Key Header Detection
// =============================================================================

// TestAPIKeyHeaderDetection checks how the x-api-key header is read from a
// gin context when present, absent, or explicitly empty.
func TestAPIKeyHeaderDetection(t *testing.T) {
	t.Parallel()

	t.Run("x-api-key header present and non-empty", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		c.Request.Header.Set("x-api-key", "my-api-key-value")

		key := c.GetHeader("x-api-key")
		if key != "my-api-key-value" {
			t.Errorf("expected api key value, got %q", key)
		}
	})

	t.Run("x-api-key header absent", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)

		key := c.GetHeader("x-api-key")
		if key != "" {
			t.Errorf("expected empty, got %q", key)
		}
	})

	t.Run("x-api-key header empty string", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		c.Request.Header.Set("x-api-key", "")

		key := c.GetHeader("x-api-key")
		if key != "" {
			t.Errorf("expected empty for empty header value, got %q", key)
		}
	})
}

// =============================================================================
// Test: Error Response Format Consistency
// Verifies that every authentication failure is returned in a uniform format.
// =============================================================================

// TestAbortWithError_FormatConsistency checks that AbortWithError writes the
// requested HTTP status and that the response body contains both the error
// code and the message.
func TestAbortWithError_FormatConsistency(t *testing.T) {
	gin.SetMode(gin.TestMode)

	tests := []struct {
		code    int
		errCode string
		message string
	}{
		{http.StatusUnauthorized, "UNAUTHORIZED", "Authorization required"},
		{http.StatusUnauthorized, "TOKEN_EXPIRED", "Token has expired"},
		{http.StatusUnauthorized, "INVALID_TOKEN", "Invalid token"},
		{http.StatusUnauthorized, "INVALID_ADMIN_KEY", "Invalid admin api key"},
		{http.StatusUnauthorized, "USER_NOT_FOUND", "User not found"},
		{http.StatusUnauthorized, "USER_INACTIVE", "User account is not active"},
		{http.StatusUnauthorized, "TOKEN_REVOKED", "Token has been revoked (password changed)"},
		{http.StatusForbidden, "FORBIDDEN", "Admin access required"},
		{http.StatusInternalServerError, "INTERNAL_ERROR", "Internal server error"},
	}

	for _, tc := range tests {
		t.Run(fmt.Sprintf("%d_%s", tc.code, tc.errCode), func(t *testing.T) {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)
			c.Request = httptest.NewRequest(http.MethodGet, "/", nil)

			AbortWithError(c, tc.code, tc.errCode, tc.message)

			if w.Code != tc.code {
				t.Errorf("HTTP status code = %d, want %d", w.Code, tc.code)
			}

			body := w.Body.String()
			if !strings.Contains(body, tc.errCode) {
				t.Errorf("response missing error code %q, body=%s", tc.errCode, body)
			}
			if !strings.Contains(body, tc.message) {
				t.Errorf("response missing message %q, body=%s", tc.message, body)
			}
		})
	}
}

// Helper functions

// truncateStr returns s unchanged when it has at most limit runes; otherwise
// it returns the first limit runes followed by "...". Truncating by rune,
// rather than by byte, avoids splitting a multi-byte UTF-8 sequence.
func truncateStr(s string, limit int) string {
	runes := []rune(s)
	if len(runes) <= limit {
		return s
	}
	return string(runes[:limit]) + "..."
}

// assertNoPanic runs fn and reports a test error (without aborting the test)
// if fn panics.
func assertNoPanic(t *testing.T, fn func()) {
	t.Helper()
	defer func() {
		if r := recover(); r != nil {
			t.Errorf("unexpected panic: %v", r)
		}
	}()
	fn()
}