Files
tokens-reef/backend/internal/server/middleware/admin_auth_test.go
User 1a483baa90
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled
feat(security): add security enhancements and tests
- Add quoteIdentifier for SQL injection defense in setup.go
- Add setup_security_test.go for security tests
- Add admin auth middleware improvements
- Add admin auth test coverage
2026-04-17 07:24:23 +08:00

294 lines
9.4 KiB
Go

package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
// =============================================================================
// Test: admin_auth.go — Pure Function Unit Tests
// Covers: isWebSocketUpgradeRequest, extractJWTFromWebSocketSubprotocol,
// the Authorization header parsing pattern, and API key header detection.
// =============================================================================
// TestIsWebSocketUpgradeRequest checks detection of a WebSocket upgrade
// handshake from the Upgrade and Connection request headers, including
// case-insensitive matching and rejection when either header is missing or
// the Connection header does not carry the "upgrade" token.
func TestIsWebSocketUpgradeRequest(t *testing.T) {
	tests := []struct {
		name             string
		upgradeHeader    string
		connectionHeader string
		expected         bool
	}{
		{"valid websocket upgrade", "websocket", "upgrade", true},
		{"valid websocket with extra connection values", "websocket", "Upgrade, keep-alive", true},
		{"case insensitive upgrade", "WebSocket", "Upgrade", true},
		// All-caps Connection value so this case genuinely exercises
		// case-insensitive matching of that header.
		{"case insensitive connection", "websocket", "UPGRADE", true},
		{"wrong upgrade value", "http/2", "upgrade", false},
		{"missing upgrade header", "", "upgrade", false},
		{"missing connection header", "websocket", "", false},
		// Connection is present but does not request an upgrade.
		{"connection without upgrade token", "websocket", "keep-alive", false},
		{"both empty", "", "", false},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)
			c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
			// Only set headers that are non-empty so the "missing" cases
			// really omit the header instead of sending an empty value.
			if tc.upgradeHeader != "" {
				c.Request.Header.Set("Upgrade", tc.upgradeHeader)
			}
			if tc.connectionHeader != "" {
				c.Request.Header.Set("Connection", tc.connectionHeader)
			}
			got := isWebSocketUpgradeRequest(c)
			if got != tc.expected {
				t.Errorf("isWebSocketUpgradeRequest() = %v, want %v (upgrade=%q, connection=%q)",
					got, tc.expected, tc.upgradeHeader, tc.connectionHeader)
			}
		})
	}
}
// TestIsWebSocketUpgradeRequest_NilContext verifies that a nil *gin.Context is
// reported as not being a WebSocket upgrade, without panicking.
func TestIsWebSocketUpgradeRequest_NilContext(t *testing.T) {
	// Invoke the function exactly once, inside the panic guard. A second,
	// unguarded call would turn a reported panic into a crash of the test run.
	var got bool
	assertNoPanic(t, func() { got = isWebSocketUpgradeRequest(nil) })
	if got {
		t.Error("nil context should return false")
	}
}
// TestExtractJWTFromWebSocketSubprotocol checks that a JWT carried as a
// "jwt.<token>" entry in the comma-separated Sec-WebSocket-Protocol header is
// extracted, and that headers without such an entry yield an empty string.
// The nil-context case is covered by
// TestExtractJWTFromWebSocketSubprotocol_NilContext.
func TestExtractJWTFromWebSocketSubprotocol(t *testing.T) {
	tests := []struct {
		name           string
		protocolHeader string
		expectedToken  string
		description    string
	}{
		{
			name:           "valid jwt.token format",
			protocolHeader: "sub2api-admin, jwt.eyJhbGciOiJIUzI1NiJ9.test",
			expectedToken:  "eyJhbGciOiJIUzI1NiJ9.test",
			description:    "Should extract token after jwt. prefix",
		},
		{
			name:           "jwt.token at start",
			protocolHeader: "jwt.my-secret-token-here",
			expectedToken:  "my-secret-token-here",
			description:    "First protocol item can be jwt. prefixed",
		},
		{
			name:           "multiple protocols, jwt in middle",
			protocolHeader: "v1, jwt.token-123, v2",
			expectedToken:  "token-123",
			description:    "Finds jwt. prefix among comma-separated items",
		},
		{
			name:           "whitespace around token",
			protocolHeader: " jwt.trimmed-token ",
			expectedToken:  "trimmed-token",
			description:    "Trims whitespace from extracted token",
		},
		{
			name:           "empty after prefix returns empty",
			protocolHeader: "jwt.",
			expectedToken:  "",
			description:    "Empty after prefix → no match returned",
		},
		{
			name:           "no jwt prefix",
			protocolHeader: "sub2api-admin, v1, chat",
			expectedToken:  "",
			description:    "Returns empty when no jwt. prefix found",
		},
		{
			name:           "empty header",
			protocolHeader: "",
			expectedToken:  "",
			description:    "Empty header returns empty",
		},
		{
			name:           "similar but wrong prefix",
			protocolHeader: "jwttoken, bearer-token",
			expectedToken:  "",
			description:    "Must be exactly 'jwt.' prefix, not 'jwttoken'",
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)
			c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
			// Leave the header unset (rather than empty) for the empty case.
			if tc.protocolHeader != "" {
				c.Request.Header.Set("Sec-WebSocket-Protocol", tc.protocolHeader)
			}
			got := extractJWTFromWebSocketSubprotocol(c)
			if got != tc.expectedToken {
				t.Errorf("extractJWTFromWebSocketSubprotocol(%q)\n got: %q\n want: %q\n (%s)",
					tc.protocolHeader, got, tc.expectedToken, tc.description)
			}
		})
	}
}
// TestExtractJWTFromWebSocketSubprotocol_NilContext checks that passing a nil
// context yields an empty token.
func TestExtractJWTFromWebSocketSubprotocol_NilContext(t *testing.T) {
	if token := extractJWTFromWebSocketSubprotocol(nil); token != "" {
		t.Errorf("nil context should return empty, got %q", token)
	}
}
// =============================================================================
// Test: Authorization Header Parsing Pattern
// Verifies the Bearer token parsing logic (the pattern used inside adminAuth).
// =============================================================================
// TestParseAuthorizationHeader_BearerToken exercises the Authorization-header
// parsing pattern used by adminAuth: split once on the first space, require a
// case-insensitive "Bearer" scheme, then trim whitespace from the remainder.
//
// NOTE(review): the parsing is re-implemented inline here rather than calling
// production code, so this guards the pattern, not adminAuth itself.
// Extracting the parser into a shared helper would let both use one copy.
func TestParseAuthorizationHeader_BearerToken(t *testing.T) {
	t.Parallel()
	tests := []struct {
		header      string
		expectToken string
		expectValid bool
	}{
		{"Bearer eyJhbGciOiJIUzI1NiJ9.valid", "eyJhbGciOiJIUzI1NiJ9.valid", true},
		{"bearer lowercase-token", "lowercase-token", true}, // scheme is case-insensitive
		{"BEARER uppercase-token", "uppercase-token", true},
		{"Bearer", "", false},             // scheme only, no separating space
		{"Basic dXNlcjpwYXNz", "", false}, // non-Bearer scheme
		{"", "", false},                   // empty header
		// Parses to an empty token; rejecting an empty token is presumably
		// the caller's job — confirm against adminAuth.
		{"Bearer ", "", true},
		{"Bearer spaced-token ", "spaced-token", true}, // trailing whitespace trimmed
		{"Bearer  double-space", "double-space", true}, // extra separator space trimmed
		{"Bearer\ttab-token", "", false},               // only a space separates scheme and token
		{"MAC token=abc", "", false},                   // unknown scheme
	}
	for _, tc := range tests {
		tc := tc
		t.Run(fmt.Sprintf("auth=%q", truncateStr(tc.header, 30)), func(t *testing.T) {
			parts := strings.SplitN(tc.header, " ", 2)
			if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
				if tc.expectValid {
					t.Fatalf("expected valid but parsing failed for %q", tc.header)
				}
				return // expected invalid
			}
			token := strings.TrimSpace(parts[1])
			if !tc.expectValid {
				t.Fatalf("expected invalid but got token %q for %q", token, tc.header)
			}
			// One comparison covers the empty-token case too; a separate
			// empty check would report the same mismatch twice.
			if token != tc.expectToken {
				t.Errorf("token mismatch: got %q, want %q", token, tc.expectToken)
			}
		})
	}
}
// =============================================================================
// Test: API Key Header Detection
// =============================================================================
// TestAPIKeyHeaderDetection checks how the x-api-key header reads back through
// gin's Context.GetHeader when it is set to a value, absent, or set to an
// empty string.
func TestAPIKeyHeaderDetection(t *testing.T) {
	t.Parallel()

	// newRequestContext builds a fresh gin context around a bare GET request.
	newRequestContext := func() *gin.Context {
		ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
		ctx.Request = httptest.NewRequest(http.MethodGet, "/", nil)
		return ctx
	}

	t.Run("x-api-key header present and non-empty", func(t *testing.T) {
		ctx := newRequestContext()
		ctx.Request.Header.Set("x-api-key", "my-api-key-value")
		if got := ctx.GetHeader("x-api-key"); got != "my-api-key-value" {
			t.Errorf("expected api key value, got %q", got)
		}
	})

	t.Run("x-api-key header absent", func(t *testing.T) {
		ctx := newRequestContext()
		if got := ctx.GetHeader("x-api-key"); got != "" {
			t.Errorf("expected empty, got %q", got)
		}
	})

	t.Run("x-api-key header empty string", func(t *testing.T) {
		ctx := newRequestContext()
		ctx.Request.Header.Set("x-api-key", "")
		if got := ctx.GetHeader("x-api-key"); got != "" {
			t.Errorf("expected empty for empty header value, got %q", got)
		}
	})
}
// =============================================================================
// Test: Error Response Format Consistency
// Verifies that every authentication failure returns a uniform response format.
// =============================================================================
// TestAbortWithError_FormatConsistency verifies that AbortWithError writes the
// requested HTTP status and that the response body contains both the error
// code and the message for each code/message pair in the table.
func TestAbortWithError_FormatConsistency(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cases := []struct {
		code    int
		errCode string
		message string
	}{
		{http.StatusUnauthorized, "UNAUTHORIZED", "Authorization required"},
		{http.StatusUnauthorized, "TOKEN_EXPIRED", "Token has expired"},
		{http.StatusUnauthorized, "INVALID_TOKEN", "Invalid token"},
		{http.StatusUnauthorized, "INVALID_ADMIN_KEY", "Invalid admin api key"},
		{http.StatusUnauthorized, "USER_NOT_FOUND", "User not found"},
		{http.StatusUnauthorized, "USER_INACTIVE", "User account is not active"},
		{http.StatusUnauthorized, "TOKEN_REVOKED", "Token has been revoked (password changed)"},
		{http.StatusForbidden, "FORBIDDEN", "Admin access required"},
		{http.StatusInternalServerError, "INTERNAL_ERROR", "Internal server error"},
	}
	for _, want := range cases {
		want := want
		t.Run(fmt.Sprintf("%d_%s", want.code, want.errCode), func(t *testing.T) {
			rec := httptest.NewRecorder()
			ctx, _ := gin.CreateTestContext(rec)
			ctx.Request = httptest.NewRequest(http.MethodGet, "/", nil)

			AbortWithError(ctx, want.code, want.errCode, want.message)

			if rec.Code != want.code {
				t.Errorf("HTTP status code = %d, want %d", rec.Code, want.code)
			}
			// Substring checks only: the exact JSON shape is defined by
			// AbortWithError, which is outside this test.
			payload := rec.Body.String()
			if !strings.Contains(payload, want.errCode) {
				t.Errorf("response missing error code %q, body=%s", want.errCode, payload)
			}
			if !strings.Contains(payload, want.message) {
				t.Errorf("response missing message %q, body=%s", want.message, payload)
			}
		})
	}
}
// Helper functions
// truncateStr shortens s to at most maxLen runes, appending "..." when
// anything was cut off; it keeps generated subtest names short.
//
// Counting runes rather than bytes means a multi-byte UTF-8 character is never
// split into invalid bytes. A negative maxLen is treated as 0 instead of
// panicking on a negative slice index.
func truncateStr(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	runes := 0
	for i := range s { // i is the byte offset of each rune
		if runes == maxLen {
			return s[:i] + "..."
		}
		runes++
	}
	return s
}
// assertNoPanic runs fn and records a test error if it panics with a non-nil
// value. The panic is recovered, so the calling test keeps running.
//
// The recovered value is captured in a nested function and reported from this
// function's own frame. Combined with t.Helper, that makes the failure point
// at the caller's line rather than at a deferred closure.
func assertNoPanic(t *testing.T, fn func()) {
	t.Helper()
	var recovered any
	func() {
		defer func() { recovered = recover() }()
		fn()
	}()
	if recovered != nil {
		t.Errorf("unexpected panic: %v", recovered)
	}
}