// Change summary (from the originating commit):
//   - Add quoteIdentifier for SQL injection defense in setup.go
//   - Add setup_security_test.go for security tests
//   - Add admin auth middleware improvements
//   - Add admin auth test coverage

package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// =============================================================================
|
|
// Test: admin_auth.go — Pure Function Unit Tests
|
|
// Covers: isWebSocketUpgradeRequest, extractJWTFromWebSocketSubprotocol,
|
|
// Authorization header parsing pattern, API key detection
|
|
// =============================================================================
|
|
|
|
func TestIsWebSocketUpgradeRequest(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
upgradeHeader string
|
|
connectionHeader string
|
|
expected bool
|
|
}{
|
|
{"valid websocket upgrade", "websocket", "upgrade", true},
|
|
{"valid websocket with extra connection values", "websocket", "Upgrade, keep-alive", true},
|
|
{"case insensitive upgrade", "WebSocket", "Upgrade", true},
|
|
{"case insensitive connection", "websocket", "Upgrade", true},
|
|
{"wrong upgrade value", "http/2", "upgrade", false},
|
|
{"missing upgrade header", "", "upgrade", false},
|
|
{"missing connection header", "websocket", "", false},
|
|
{"both empty", "", "", false},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
if tc.upgradeHeader != "" { c.Request.Header.Set("Upgrade", tc.upgradeHeader) }
|
|
if tc.connectionHeader != "" { c.Request.Header.Set("Connection", tc.connectionHeader) }
|
|
|
|
got := isWebSocketUpgradeRequest(c)
|
|
if got != tc.expected {
|
|
t.Errorf("isWebSocketUpgradeRequest() = %v, want %v (upgrade=%q, connection=%q)",
|
|
got, tc.expected, tc.upgradeHeader, tc.connectionHeader)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsWebSocketUpgradeRequest_NilContext(t *testing.T) {
|
|
assertNoPanic(t, func() { isWebSocketUpgradeRequest(nil) })
|
|
if isWebSocketUpgradeRequest(nil) != false {
|
|
t.Error("nil context should return false")
|
|
}
|
|
}
|
|
|
|
func TestExtractJWTFromWebSocketSubprotocol(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
protocolHeader string
|
|
expectedToken string
|
|
description string
|
|
}{
|
|
{
|
|
name: "valid jwt.token format",
|
|
protocolHeader: "sub2api-admin, jwt.eyJhbGciOiJIUzI1NiJ9.test",
|
|
expectedToken: "eyJhbGciOiJIUzI1NiJ9.test",
|
|
description: "Should extract token after jwt. prefix",
|
|
},
|
|
{
|
|
name: "jwt.token at start",
|
|
protocolHeader: "jwt.my-secret-token-here",
|
|
expectedToken: "my-secret-token-here",
|
|
description: "First protocol item can be jwt. prefixed",
|
|
},
|
|
{
|
|
name: "multiple protocols, jwt in middle",
|
|
protocolHeader: "v1, jwt.token-123, v2",
|
|
expectedToken: "token-123",
|
|
description: "Finds jwt. prefix among comma-separated items",
|
|
},
|
|
{
|
|
name: "whitespace around token",
|
|
protocolHeader: " jwt.trimmed-token ",
|
|
expectedToken: "trimmed-token",
|
|
description: "Trims whitespace from extracted token",
|
|
},
|
|
{
|
|
name: "empty after prefix returns empty",
|
|
protocolHeader: "jwt.",
|
|
expectedToken: "",
|
|
description: "Empty after prefix → no match returned",
|
|
},
|
|
{
|
|
name: "no jwt prefix",
|
|
protocolHeader: "sub2api-admin, v1, chat",
|
|
expectedToken: "",
|
|
description: "Returns empty when no jwt. prefix found",
|
|
},
|
|
{
|
|
name: "empty header",
|
|
protocolHeader: "",
|
|
expectedToken: "",
|
|
description: "Empty header returns empty",
|
|
},
|
|
{
|
|
name: "similar but wrong prefix",
|
|
protocolHeader: "jwttoken, bearer-token",
|
|
expectedToken: "",
|
|
description: "Must be exactly 'jwt.' prefix, not 'jwttoken'",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
if tc.protocolHeader != "" { c.Request.Header.Set("Sec-WebSocket-Protocol", tc.protocolHeader) }
|
|
|
|
var got string
|
|
if strings.Contains(tc.name, "nil") && strings.Contains(tc.name, "context") {
|
|
c = nil
|
|
got = extractJWTFromWebSocketSubprotocol(c)
|
|
} else {
|
|
got = extractJWTFromWebSocketSubprotocol(c)
|
|
}
|
|
|
|
if got != tc.expectedToken {
|
|
t.Errorf("extractJWTFromWebSocketSubprotocol(%q)\n got: %q\n want: %q\n (%s)",
|
|
tc.protocolHeader, got, tc.expectedToken, tc.description)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractJWTFromWebSocketSubprotocol_NilContext(t *testing.T) {
|
|
got := extractJWTFromWebSocketSubprotocol(nil)
|
|
if got != "" { t.Errorf("nil context should return empty, got %q", got) }
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test: Authorization Header Parsing Pattern
|
|
// Verifies the Bearer token parsing logic (the pattern extracted from the adminAuth function).
|
|
// =============================================================================
|
|
|
|
func TestParseAuthorizationHeader_BearerToken(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
header string
|
|
expectToken string
|
|
expectValid bool
|
|
}{
|
|
{"Bearer eyJhbGciOiJIUzI1NiJ9.valid", "eyJhbGciOiJIUzI1NiJ9.valid", true},
|
|
{"bearer lowercase-token", "lowercase-token", true}, // case-insensitive Bearer
|
|
{"BEARER uppercase-token", "uppercase-token", true},
|
|
{"Bearer", "", false}, // no token after space
|
|
{"Basic dXNlcjpwYXNz", "", false}, // non-Bearer scheme
|
|
{"", "", false}, // empty header
|
|
{"Bearer ", "", true}, // only whitespace → trimmed empty is valid parse
|
|
{"Bearer spaced-token ", "spaced-token", true}, // trim whitespace
|
|
{"MAC token=abc", "", false}, // unknown scheme
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(fmt.Sprintf("auth=%q", truncateStr(tc.header, 30)), func(t *testing.T) {
|
|
parts := strings.SplitN(tc.header, " ", 2)
|
|
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
|
if tc.expectValid {
|
|
t.Fatalf("expected valid but parsing failed for %q", tc.header)
|
|
}
|
|
return // expected invalid
|
|
}
|
|
token := strings.TrimSpace(parts[1])
|
|
if !tc.expectValid {
|
|
t.Fatalf("expected invalid but got token %q for %q", token, tc.header)
|
|
}
|
|
if token == "" && tc.expectToken != "" {
|
|
t.Errorf("token mismatch: got empty, want %q", tc.expectToken)
|
|
}
|
|
if token != tc.expectToken {
|
|
t.Errorf("token mismatch: got %q, want %q", token, tc.expectToken)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test: API Key Header Detection
|
|
// =============================================================================
|
|
|
|
func TestAPIKeyHeaderDetection(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("x-api-key header present and non-empty", func(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
c.Request.Header.Set("x-api-key", "my-api-key-value")
|
|
|
|
key := c.GetHeader("x-api-key")
|
|
if key != "my-api-key-value" { t.Errorf("expected api key value, got %q", key) }
|
|
})
|
|
|
|
t.Run("x-api-key header absent", func(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
|
|
key := c.GetHeader("x-api-key")
|
|
if key != "" { t.Errorf("expected empty, got %q", key) }
|
|
})
|
|
|
|
t.Run("x-api-key header empty string", func(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
c.Request.Header.Set("x-api-key", "")
|
|
|
|
key := c.GetHeader("x-api-key")
|
|
if key != "" { t.Errorf("expected empty for empty header value, got %q", key) }
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test: Error Response Format Consistency
|
|
// Verifies that every authentication failure is returned in a uniform response format.
|
|
// =============================================================================
|
|
|
|
func TestAbortWithError_FormatConsistency(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
tests := []struct {
|
|
code int
|
|
errCode string
|
|
message string
|
|
}{
|
|
{401, "UNAUTHORIZED", "Authorization required"},
|
|
{401, "TOKEN_EXPIRED", "Token has expired"},
|
|
{401, "INVALID_TOKEN", "Invalid token"},
|
|
{401, "INVALID_ADMIN_KEY", "Invalid admin api key"},
|
|
{401, "USER_NOT_FOUND", "User not found"},
|
|
{401, "USER_INACTIVE", "User account is not active"},
|
|
{401, "TOKEN_REVOKED", "Token has been revoked (password changed)"},
|
|
{403, "FORBIDDEN", "Admin access required"},
|
|
{500, "INTERNAL_ERROR", "Internal server error"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(fmt.Sprintf("%d_%s", tc.code, tc.errCode), func(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(w)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
|
|
AbortWithError(c, tc.code, tc.errCode, tc.message)
|
|
|
|
if w.Code != tc.code {
|
|
t.Errorf("HTTP status code = %d, want %d", w.Code, tc.code)
|
|
}
|
|
body := w.Body.String()
|
|
if !strings.Contains(body, tc.errCode) {
|
|
t.Errorf("response missing error code %q, body=%s", tc.errCode, body)
|
|
}
|
|
if !strings.Contains(body, tc.message) {
|
|
t.Errorf("response missing message %q, body=%s", tc.message, body)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// truncateStr shortens s to at most limit runes and appends "..." when anything
// was cut off; it is used to keep subtest names compact.
//
// Unlike plain byte slicing, it never splits a multi-byte UTF-8 sequence. A
// negative limit is treated as 0 instead of causing a slice-bounds panic. For
// ASCII input the result is identical to byte-based truncation.
func truncateStr(s string, limit int) string {
	if limit < 0 {
		limit = 0
	}
	// Fast path: a string of at most limit bytes has at most limit runes.
	if len(s) <= limit {
		return s
	}
	kept := 0
	for i := range s {
		// i is the byte offset of the next rune; cut before it once limit runes are kept.
		if kept == limit {
			return s[:i] + "..."
		}
		kept++
	}
	// More bytes than limit, but no more than limit runes: nothing to cut.
	return s
}
|
|
|
|
// assertNoPanic runs fn and, if fn panics, records a non-fatal test error
// containing the recovered value.
//
// It is marked as a test helper and calls t.Errorf from its own frame rather
// than from a deferred closure. As a result, a failure is attributed to the
// caller's line.
func assertNoPanic(t *testing.T, fn func()) {
	t.Helper()
	recovered := func() (r any) {
		// recover only takes effect when called directly by the deferred function.
		defer func() { r = recover() }()
		fn()
		return nil
	}()
	if recovered != nil {
		t.Errorf("unexpected panic: %v", recovered)
	}
}
|