chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,204 @@
|
||||
// Package middleware provides HTTP middleware for authentication, authorization, and request processing.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewAdminAuthMiddleware 创建管理员认证中间件
|
||||
func NewAdminAuthMiddleware(
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
settingService *service.SettingService,
|
||||
) AdminAuthMiddleware {
|
||||
return AdminAuthMiddleware(adminAuth(authService, userService, settingService))
|
||||
}
|
||||
|
||||
// adminAuth 管理员认证中间件实现
|
||||
// 支持两种认证方式(通过不同的 header 区分):
|
||||
// 1. Admin API Key: x-api-key: <admin-api-key>
|
||||
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
|
||||
func adminAuth(
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
settingService *service.SettingService,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// WebSocket upgrade requests cannot set Authorization headers in browsers.
|
||||
// For admin WebSocket endpoints (e.g. Ops realtime), allow passing the JWT via
|
||||
// Sec-WebSocket-Protocol (subprotocol list) using a prefixed token item:
|
||||
// Sec-WebSocket-Protocol: sub2api-admin, jwt.<token>
|
||||
if isWebSocketUpgradeRequest(c) {
|
||||
if token := extractJWTFromWebSocketSubprotocol(c); token != "" {
|
||||
if !validateJWTForAdmin(c, token, authService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 x-api-key header(Admin API Key 认证)
|
||||
apiKey := c.GetHeader("x-api-key")
|
||||
if apiKey != "" {
|
||||
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 Authorization header(JWT 认证)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||
token := strings.TrimSpace(parts[1])
|
||||
if token == "" {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
return
|
||||
}
|
||||
if !validateJWTForAdmin(c, token, authService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 无有效认证信息
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
}
|
||||
}
|
||||
|
||||
func isWebSocketUpgradeRequest(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil {
|
||||
return false
|
||||
}
|
||||
// RFC6455 handshake uses:
|
||||
// Connection: Upgrade
|
||||
// Upgrade: websocket
|
||||
upgrade := strings.ToLower(strings.TrimSpace(c.GetHeader("Upgrade")))
|
||||
if upgrade != "websocket" {
|
||||
return false
|
||||
}
|
||||
connection := strings.ToLower(c.GetHeader("Connection"))
|
||||
return strings.Contains(connection, "upgrade")
|
||||
}
|
||||
|
||||
func extractJWTFromWebSocketSubprotocol(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
raw := strings.TrimSpace(c.GetHeader("Sec-WebSocket-Protocol"))
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// The header is a comma-separated list of tokens. We reserve the prefix "jwt."
|
||||
// for carrying the admin JWT.
|
||||
for _, part := range strings.Split(raw, ",") {
|
||||
p := strings.TrimSpace(part)
|
||||
if strings.HasPrefix(p, "jwt.") {
|
||||
token := strings.TrimSpace(strings.TrimPrefix(p, "jwt."))
|
||||
if token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// validateAdminAPIKey 验证管理员 API Key
|
||||
func validateAdminAPIKey(
|
||||
c *gin.Context,
|
||||
key string,
|
||||
settingService *service.SettingService,
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
|
||||
return false
|
||||
}
|
||||
|
||||
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
|
||||
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
|
||||
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取真实的管理员用户
|
||||
admin, err := userService.GetFirstAdmin(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: admin.ID,
|
||||
Concurrency: admin.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), admin.Role)
|
||||
c.Set("auth_method", "admin_api_key")
|
||||
return true
|
||||
}
|
||||
|
||||
// validateJWTForAdmin 验证 JWT 并检查管理员权限
|
||||
func validateJWTForAdmin(
|
||||
c *gin.Context,
|
||||
token string,
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
// 验证 JWT token
|
||||
claims, err := authService.ValidateToken(token)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrTokenExpired) {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return false
|
||||
}
|
||||
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||
return false
|
||||
}
|
||||
|
||||
// 从数据库获取用户
|
||||
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return false
|
||||
}
|
||||
|
||||
// 校验 TokenVersion,确保管理员改密后旧 token 失效
|
||||
if claims.TokenVersion != user.TokenVersion {
|
||||
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查管理员权限
|
||||
if !user.IsAdmin() {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: user.ID,
|
||||
Concurrency: user.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), user.Role)
|
||||
c.Set("auth_method", "jwt")
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
Email: "admin@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 2,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
userRepo := &stubUserRepo{
|
||||
getByID: func(ctx context.Context, id int64) (*service.User, error) {
|
||||
if id != admin.ID {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
clone := *admin
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
userService := service.NewUserService(userRepo, nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
t.Run("token_version_mismatch_rejected", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion - 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
|
||||
})
|
||||
|
||||
t.Run("token_version_match_allows", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion - 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
|
||||
})
|
||||
|
||||
t.Run("websocket_token_version_match_allows", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
type stubUserRepo struct {
|
||||
getByID func(ctx context.Context, id int64) (*service.User, error)
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||||
if s.getByID == nil {
|
||||
panic("GetByID not stubbed")
|
||||
}
|
||||
return s.getByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||
panic("unexpected GetByEmail call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
||||
panic("unexpected GetFirstAdmin call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected UpdateBalance call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected DeductBalance call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
panic("unexpected ExistsByEmail call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||
panic("unexpected AddGroupToAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected EnableTotp call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected DisableTotp call")
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AdminOnly 管理员权限中间件
|
||||
// 必须在JWTAuth中间件之后使用
|
||||
func AdminOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, ok := GetUserRoleFromContext(c)
|
||||
if !ok {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为管理员
|
||||
if role != service.RoleAdmin {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) APIKeyAuthMiddleware {
|
||||
return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
//
|
||||
// 中间件职责分为两层:
|
||||
// - 鉴权(Authentication):验证 Key 有效性、用户状态、IP 限制 —— 始终执行
|
||||
// - 计费执行(Billing Enforcement):过期/配额/订阅/余额检查 —— skipBilling 时整块跳过
|
||||
//
|
||||
// /v1/usage 端点只需鉴权,不需要计费执行(允许过期/配额耗尽的 Key 查询自身用量)。
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// ── 1. 提取 API Key ──────────────────────────────────────────
|
||||
|
||||
queryKey := strings.TrimSpace(c.Query("key"))
|
||||
queryApiKey := strings.TrimSpace(c.Query("api_key"))
|
||||
if queryKey != "" || queryApiKey != "" {
|
||||
AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
var apiKeyString string
|
||||
|
||||
if authHeader != "" {
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||
apiKeyString = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
// 如果Authorization header中没有,尝试从x-api-key header中提取
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.GetHeader("x-api-key")
|
||||
}
|
||||
|
||||
// 如果x-api-key header中没有,尝试从x-goog-api-key header中提取(Gemini CLI兼容)
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.GetHeader("x-goog-api-key")
|
||||
}
|
||||
|
||||
// 如果所有header都没有API key
|
||||
if apiKeyString == "" {
|
||||
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
|
||||
return
|
||||
}
|
||||
|
||||
// ── 2. 验证 Key 存在 ─────────────────────────────────────────
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
|
||||
return
|
||||
}
|
||||
|
||||
// ── 3. 基础鉴权(始终执行) ─────────────────────────────────
|
||||
|
||||
// disabled / 未知状态 → 无条件拦截(expired 和 quota_exhausted 留给计费阶段)
|
||||
if !apiKey.IsActive() &&
|
||||
apiKey.Status != service.StatusAPIKeyExpired &&
|
||||
apiKey.Status != service.StatusAPIKeyQuotaExhausted {
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 IP 限制(白名单/黑名单)
|
||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetTrustedClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
|
||||
if !allowed {
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 检查关联的用户
|
||||
if apiKey.User == nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !apiKey.User.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// ── 4. SimpleMode → early return ─────────────────────────────
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// ── 5. 加载订阅(订阅模式时始终加载) ───────────────────────
|
||||
|
||||
// skipBilling: /v1/usage 只需鉴权,跳过所有计费执行
|
||||
skipBilling := c.Request.URL.Path == "/v1/usage"
|
||||
|
||||
var subscription *service.UserSubscription
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
sub, subErr := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if subErr != nil {
|
||||
if !skipBilling {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
// skipBilling: 订阅不存在也放行,handler 会返回可用的数据
|
||||
} else {
|
||||
subscription = sub
|
||||
}
|
||||
}
|
||||
|
||||
// ── 6. 计费执行(skipBilling 时整块跳过) ────────────────────
|
||||
|
||||
if !skipBilling {
|
||||
// Key 状态检查
|
||||
switch apiKey.Status {
|
||||
case service.StatusAPIKeyQuotaExhausted:
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
return
|
||||
case service.StatusAPIKeyExpired:
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
|
||||
// 运行时过期/配额检查(即使状态是 active,也要检查时间和用量)
|
||||
if apiKey.IsExpired() {
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
if apiKey.IsQuotaExhausted() {
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
return
|
||||
}
|
||||
|
||||
// 订阅模式:验证订阅限额
|
||||
if subscription != nil {
|
||||
needsMaintenance, validateErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if validateErr != nil {
|
||||
code := "SUBSCRIPTION_INVALID"
|
||||
status := 403
|
||||
if errors.Is(validateErr, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(validateErr, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(validateErr, service.ErrMonthlyLimitExceeded) {
|
||||
code = "USAGE_LIMIT_EXCEEDED"
|
||||
status = 429
|
||||
}
|
||||
AbortWithError(c, status, code, validateErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 窗口维护异步化(不阻塞请求)
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 非订阅模式 或 订阅模式但 subscriptionService 未注入:回退到余额检查
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── 7. 设置上下文 → Next ─────────────────────────────────────
|
||||
|
||||
if subscription != nil {
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
}
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetAPIKeyFromContext 从上下文中获取API key
|
||||
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyAPIKey))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
apiKey, ok := value.(*service.APIKey)
|
||||
return apiKey, ok
|
||||
}
|
||||
|
||||
// GetSubscriptionFromContext 从上下文中获取订阅信息
|
||||
func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
|
||||
value, exists := c.Get(string(ContextKeySubscription))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
subscription, ok := value.(*service.UserSubscription)
|
||||
return subscription, ok
|
||||
}
|
||||
|
||||
func setGroupContext(c *gin.Context, group *service.Group) {
|
||||
if !service.IsGroupContextValid(group) {
|
||||
return
|
||||
}
|
||||
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) {
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
}
|
||||
|
||||
// APIKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
|
||||
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
|
||||
//
|
||||
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
|
||||
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
|
||||
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
|
||||
return
|
||||
}
|
||||
apiKeyString := extractAPIKeyForGoogle(c)
|
||||
if apiKeyString == "" {
|
||||
abortWithGoogleError(c, 401, "API key is required")
|
||||
return
|
||||
}
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
abortWithGoogleError(c, 401, "Invalid API key")
|
||||
return
|
||||
}
|
||||
abortWithGoogleError(c, 500, "Failed to validate API key")
|
||||
return
|
||||
}
|
||||
|
||||
if !apiKey.IsActive() {
|
||||
abortWithGoogleError(c, 401, "API key is disabled")
|
||||
return
|
||||
}
|
||||
if apiKey.User == nil {
|
||||
abortWithGoogleError(c, 401, "User associated with API key not found")
|
||||
return
|
||||
}
|
||||
if !apiKey.User.IsActive() {
|
||||
abortWithGoogleError(c, 401, "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// 简易模式:跳过余额和订阅检查
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
abortWithGoogleError(c, 403, "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
|
||||
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if err != nil {
|
||||
status := 403
|
||||
if errors.Is(err, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
|
||||
status = 429
|
||||
}
|
||||
abortWithGoogleError(c, status, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
if apiKey.User.Balance <= 0 {
|
||||
abortWithGoogleError(c, 403, "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// extractAPIKeyForGoogle extracts API key for Google/Gemini endpoints.
|
||||
// Priority: x-goog-api-key > Authorization: Bearer > x-api-key > query key
|
||||
// This allows OpenClaw and other clients using Bearer auth to work with Gemini endpoints.
|
||||
func extractAPIKeyForGoogle(c *gin.Context) string {
|
||||
// 1) preferred: Gemini native header
|
||||
if k := strings.TrimSpace(c.GetHeader("x-goog-api-key")); k != "" {
|
||||
return k
|
||||
}
|
||||
|
||||
// 2) fallback: Authorization: Bearer <key>
|
||||
auth := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
if auth != "" {
|
||||
parts := strings.SplitN(auth, " ", 2)
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||
if k := strings.TrimSpace(parts[1]); k != "" {
|
||||
return k
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3) x-api-key header (backward compatibility)
|
||||
if k := strings.TrimSpace(c.GetHeader("x-api-key")); k != "" {
|
||||
return k
|
||||
}
|
||||
|
||||
// 4) query parameter key (for specific paths)
|
||||
if allowGoogleQueryKey(c.Request.URL.Path) {
|
||||
if v := strings.TrimSpace(c.Query("key")); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func allowGoogleQueryKey(path string) bool {
|
||||
return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta")
|
||||
}
|
||||
|
||||
func abortWithGoogleError(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
@@ -0,0 +1,686 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeAPIKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
|
||||
}
|
||||
|
||||
type fakeGoogleSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
activateWindow func(ctx context.Context, id int64, start time.Time) error
|
||||
resetDaily func(ctx context.Context, id int64, start time.Time) error
|
||||
resetWeekly func(ctx context.Context, id int64, start time.Time) error
|
||||
resetMonthly func(ctx context.Context, id int64, start time.Time) error
|
||||
}
|
||||
|
||||
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
return "", 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if f.getByKey == nil {
|
||||
return nil, errors.New("unexpected call")
|
||||
}
|
||||
return f.getByKey(ctx, key)
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return f.GetByKey(ctx, key)
|
||||
}
|
||||
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
if f.updateLastUsed != nil {
|
||||
return f.updateLastUsed(ctx, id, usedAt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return &service.APIKeyRateLimitData{}, nil
|
||||
}
|
||||
|
||||
func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if f.getActive != nil {
|
||||
return f.getActive(ctx, userID, groupID)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||||
if f.updateStatus != nil {
|
||||
return f.updateStatus(ctx, subscriptionID, status)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.activateWindow != nil {
|
||||
return f.activateWindow(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.resetDaily != nil {
|
||||
return f.resetDaily(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.resetWeekly != nil {
|
||||
return f.resetWeekly(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.resetMonthly != nil {
|
||||
return f.resetMonthly(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type googleErrorResponse struct {
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
|
||||
return service.NewAPIKeyService(
|
||||
repo,
|
||||
nil, // userRepo (unused in GetByKey)
|
||||
nil, // groupRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
nil, // cache
|
||||
&config.Config{},
|
||||
)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "API key is required", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusBadRequest, resp.Error.Code)
|
||||
require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message)
|
||||
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 99,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformGemini,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyService := service.NewAPIKeyService(
|
||||
fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{RunMode: config.RunModeSimple},
|
||||
)
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r := gin.New()
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return &service.APIKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "Invalid API key", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, errors.New("db down")
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer any")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusInternalServerError, resp.Error.Code)
|
||||
require.Equal(t, "Failed to validate API key", resp.Error.Message)
|
||||
require.Equal(t, "INTERNAL", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return &service.APIKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusDisabled,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer disabled")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "API key is disabled", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return &service.APIKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
Balance: 0,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer ok")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusForbidden, resp.Error.Code)
|
||||
require.Equal(t, "Insufficient account balance", resp.Error.Message)
|
||||
require.Equal(t, "PERMISSION_DENIED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedOnSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 11,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 201,
|
||||
UserID: user.ID,
|
||||
Key: "google-touch-ok",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
var touchedID int64
|
||||
var touchedAt time.Time
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchedID = id
|
||||
touchedAt = usedAt
|
||||
return nil
|
||||
},
|
||||
})
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-goog-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, apiKey.ID, touchedID)
|
||||
require.False(t, touchedAt.IsZero())
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_TouchFailureDoesNotBlock(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 12,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 202,
|
||||
UserID: user.ID,
|
||||
Key: "google-touch-fail",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
touchCalls := 0
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchCalls++
|
||||
return errors.New("write failed")
|
||||
},
|
||||
})
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-goog-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 13,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 203,
|
||||
UserID: user.ID,
|
||||
Key: "google-touch-standard",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
touchCalls := 0
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchCalls++
|
||||
return nil
|
||||
},
|
||||
})
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := 1.0
|
||||
group := &service.Group{
|
||||
ID: 77,
|
||||
Name: "gemini-sub",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformGemini,
|
||||
Hydrated: true,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 999,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 501,
|
||||
UserID: user.ID,
|
||||
Key: "google-sub-limit",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
})
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
ID: 601,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
DailyWindowStart: &now,
|
||||
DailyUsageUSD: 10,
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if userID != user.ID || groupID != group.ID {
|
||||
return nil, service.ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}, nil, nil, &config.Config{RunMode: config.RunModeStandard})
|
||||
|
||||
r := gin.New()
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-goog-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusTooManyRequests, resp.Error.Code)
|
||||
require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status)
|
||||
require.Contains(t, resp.Error.Message, "daily usage limit exceeded")
|
||||
}
|
||||
@@ -0,0 +1,706 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := 1.0
|
||||
group := &service.Group{
|
||||
ID: 42,
|
||||
Name: "sub",
|
||||
Status: service.StatusActive,
|
||||
Hydrated: true,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("standard_mode_needs_maintenance_does_not_block_request", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
cfg.SubscriptionMaintenance.WorkerCount = 1
|
||||
cfg.SubscriptionMaintenance.QueueSize = 1
|
||||
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
|
||||
past := time.Now().Add(-48 * time.Hour)
|
||||
sub := &service.UserSubscription{
|
||||
ID: 55,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
DailyWindowStart: &past,
|
||||
DailyUsageUSD: 0,
|
||||
}
|
||||
maintenanceCalled := make(chan struct{}, 1)
|
||||
subscriptionRepo := &stubUserSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error {
|
||||
maintenanceCalled <- struct{}{}
|
||||
return nil
|
||||
},
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
|
||||
t.Cleanup(subscriptionService.Stop)
|
||||
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
select {
|
||||
case <-maintenanceCalled:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("expected maintenance to be scheduled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Authorization", "bearer "+apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
ID: 55,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
DailyWindowStart: &now,
|
||||
DailyUsageUSD: 10,
|
||||
}
|
||||
subscriptionRepo := &stubUserSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if userID != sub.UserID || groupID != sub.GroupID {
|
||||
return nil, service.ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
|
||||
invalidGroup := &service.Group{
|
||||
ID: group.ID,
|
||||
Platform: group.Platform,
|
||||
Status: group.Status,
|
||||
}
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup))
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
IPWhitelist: []string{"1.2.3.4"},
|
||||
}
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
require.NoError(t, router.SetTrustedProxies(nil))
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "touch-ok",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
var touchedID int64
|
||||
var touchedAt time.Time
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchedID = id
|
||||
touchedAt = usedAt
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, nil, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, apiKey.ID, touchedID)
|
||||
require.False(t, touchedAt.IsZero(), "expected touch timestamp")
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthTouchLastUsedFailureDoesNotBlock(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 8,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 101,
|
||||
UserID: user.ID,
|
||||
Key: "touch-fail",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
touchCalls := 0
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchCalls++
|
||||
return errors.New("db unavailable")
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, nil, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code, "touch failure should not block request")
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthTouchesLastUsedInStandardMode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 9,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 102,
|
||||
UserID: user.ID,
|
||||
Key: "touch-standard",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
touchCalls := 0
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchCalls++
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, nil, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
return router
|
||||
}
|
||||
|
||||
type stubApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
return "", 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if r.getByKey != nil {
|
||||
return r.getByKey(ctx, key)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return r.GetByKey(ctx, key)
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
if r.updateLastUsed != nil {
|
||||
return r.updateLastUsed(ctx, id, usedAt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
activateWindow func(ctx context.Context, id int64, start time.Time) error
|
||||
resetDaily func(ctx context.Context, id int64, start time.Time) error
|
||||
resetWeekly func(ctx context.Context, id int64, start time.Time) error
|
||||
resetMonthly func(ctx context.Context, id int64, start time.Time) error
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if r.getActive != nil {
|
||||
return r.getActive(ctx, userID, groupID)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||||
if r.updateStatus != nil {
|
||||
return r.updateStatus(ctx, subscriptionID, status)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||||
if r.activateWindow != nil {
|
||||
return r.activateWindow(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetDaily != nil {
|
||||
return r.resetDaily(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetWeekly != nil {
|
||||
return r.resetWeekly(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetMonthly != nil {
|
||||
return r.resetMonthly(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// AuthSubject is the minimal authenticated identity stored in gin context.
|
||||
// Decision: {UserID int64, Concurrency int}
|
||||
type AuthSubject struct {
|
||||
UserID int64
|
||||
Concurrency int
|
||||
}
|
||||
|
||||
func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) {
|
||||
value, exists := c.Get(string(ContextKeyUser))
|
||||
if !exists {
|
||||
return AuthSubject{}, false
|
||||
}
|
||||
subject, ok := value.(AuthSubject)
|
||||
return subject, ok
|
||||
}
|
||||
|
||||
func GetUserRoleFromContext(c *gin.Context) (string, bool) {
|
||||
value, exists := c.Get(string(ContextKeyUserRole))
|
||||
if !exists {
|
||||
return "", false
|
||||
}
|
||||
role, ok := value.(string)
|
||||
return role, ok
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
|
||||
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
|
||||
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
role, _ := GetUserRoleFromContext(c)
|
||||
if role == "admin" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
|
||||
// Allows: login, login/2fa, logout, refresh (admin needs these).
|
||||
// Blocks: register, forgot-password, reset-password, OAuth, etc.
|
||||
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
// Allow login, 2FA, logout, refresh, public settings
|
||||
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
|
||||
for _, suffix := range allowedSuffixes {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,239 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
v, ok := r.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = make(map[string]string, len(settings))
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
|
||||
t.Helper()
|
||||
|
||||
repo := &bmSettingRepo{
|
||||
values: map[string]string{
|
||||
service.SettingKeyBackendModeEnabled: enabled,
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{})
|
||||
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
|
||||
BackendModeEnabled: enabled == "true",
|
||||
}))
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
func stringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestBackendModeUserGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
role *string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_admin_allowed",
|
||||
enabled: "true",
|
||||
role: stringPtr("admin"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_user_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_no_role_blocked",
|
||||
enabled: "true",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_empty_role_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr(""),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
if tc.role != nil {
|
||||
role := *tc.role
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(ContextKeyUserRole), role)
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeUserGuard(svc))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendModeAuthGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login_2fa",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login/2fa",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_logout",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/logout",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_refresh",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/refresh",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_register",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_forgot_password",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/forgot-password",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeAuthGuard(svc))
|
||||
r.Any("/*path", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
|
||||
//
|
||||
// This is used by the Ops monitoring module for end-to-end request correlation.
|
||||
func ClientRequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if v := c.Request.Context().Value(ctxkey.ClientRequestID); v != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
|
||||
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
|
||||
ctx = logger.IntoContext(ctx, requestLogger)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var corsWarningOnce sync.Once
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
||||
allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)
|
||||
allowAll := false
|
||||
for _, origin := range allowedOrigins {
|
||||
if origin == "*" {
|
||||
allowAll = true
|
||||
break
|
||||
}
|
||||
}
|
||||
wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
|
||||
if wildcardWithSpecific {
|
||||
allowedOrigins = []string{"*"}
|
||||
}
|
||||
allowCredentials := cfg.AllowCredentials
|
||||
|
||||
corsWarningOnce.Do(func() {
|
||||
if len(allowedOrigins) == 0 {
|
||||
log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
|
||||
}
|
||||
if wildcardWithSpecific {
|
||||
log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
|
||||
}
|
||||
if allowAll && allowCredentials {
|
||||
log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
|
||||
}
|
||||
})
|
||||
if allowAll && allowCredentials {
|
||||
allowCredentials = false
|
||||
}
|
||||
|
||||
allowedSet := make(map[string]struct{}, len(allowedOrigins))
|
||||
for _, origin := range allowedOrigins {
|
||||
if origin == "" || origin == "*" {
|
||||
continue
|
||||
}
|
||||
allowedSet[origin] = struct{}{}
|
||||
}
|
||||
allowHeaders := []string{
|
||||
"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization",
|
||||
"accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key",
|
||||
}
|
||||
// OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
|
||||
openAIProperties := []string{
|
||||
"lang", "package-version", "os", "arch", "retry-count", "runtime",
|
||||
"runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout",
|
||||
}
|
||||
for _, prop := range openAIProperties {
|
||||
allowHeaders = append(allowHeaders, "x-stainless-"+prop)
|
||||
}
|
||||
allowHeadersValue := strings.Join(allowHeaders, ", ")
|
||||
|
||||
return func(c *gin.Context) {
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
originAllowed := allowAll
|
||||
if origin != "" && !allowAll {
|
||||
_, originAllowed = allowedSet[origin]
|
||||
}
|
||||
|
||||
if originAllowed {
|
||||
if allowAll {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else if origin != "" {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
c.Writer.Header().Add("Vary", "Origin")
|
||||
}
|
||||
if allowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag")
|
||||
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||
}
|
||||
// 处理预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
if originAllowed {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
} else {
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOrigins(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
@@ -0,0 +1,308 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// --- Task 8.2: 验证 CORS 条件化头部 ---
|
||||
|
||||
func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
origin string
|
||||
}{
|
||||
{
|
||||
name: "preflight_disallowed_origin",
|
||||
method: http.MethodOptions,
|
||||
origin: "https://evil.example.com",
|
||||
},
|
||||
{
|
||||
name: "get_disallowed_origin",
|
||||
method: http.MethodGet,
|
||||
origin: "https://evil.example.com",
|
||||
},
|
||||
{
|
||||
name: "post_disallowed_origin",
|
||||
method: http.MethodPost,
|
||||
origin: "https://attacker.example.com",
|
||||
},
|
||||
{
|
||||
name: "preflight_no_origin",
|
||||
method: http.MethodOptions,
|
||||
origin: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(tt.method, "/", nil)
|
||||
if tt.origin != "" {
|
||||
c.Request.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
|
||||
middleware(c)
|
||||
|
||||
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||
"不允许的 origin 不应收到 Allow-Headers")
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||
"不允许的 origin 不应收到 Allow-Methods")
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Max-Age"),
|
||||
"不允许的 origin 不应收到 Max-Age")
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"不允许的 origin 不应收到 Allow-Origin")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
}{
|
||||
{name: "preflight_OPTIONS", method: http.MethodOptions},
|
||||
{name: "normal_GET", method: http.MethodGet},
|
||||
{name: "normal_POST", method: http.MethodPost},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(tt.method, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||
"允许的 origin 应收到 Allow-Headers")
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||
"允许的 origin 应收到 Allow-Methods")
|
||||
assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"),
|
||||
"允许的 origin 应收到 Max-Age=86400")
|
||||
assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"允许的 origin 应收到 Allow-Origin")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://evil.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w.Code,
|
||||
"不允许的 origin 的 preflight 请求应返回 403")
|
||||
}
|
||||
|
||||
func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code,
|
||||
"允许的 origin 的 preflight 请求应返回 204")
|
||||
}
|
||||
|
||||
func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://any-origin.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"通配符配置应返回 *")
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||
"通配符 origin 应设置 Allow-Headers")
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||
"通配符 origin 应设置 Allow-Methods")
|
||||
}
|
||||
|
||||
func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
t.Run("allowed_origin_gets_credentials", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"),
|
||||
"允许的 origin 且开启 credentials 应设置 Allow-Credentials")
|
||||
})
|
||||
|
||||
t.Run("disallowed_origin_no_credentials", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://evil.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
|
||||
"不允许的 origin 不应收到 Allow-Credentials")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://any.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
// 通配符 + credentials 不兼容,credentials 应被禁用
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
|
||||
"通配符 origin 应禁用 Allow-Credentials")
|
||||
}
|
||||
|
||||
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{
|
||||
"https://app1.example.com",
|
||||
"https://app2.example.com",
|
||||
},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
t.Run("first_origin_allowed", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://app1.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
|
||||
t.Run("second_origin_allowed", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://app2.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
|
||||
t.Run("unlisted_origin_rejected", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://app3.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Contains(t, w.Header().Values("Vary"), "Origin",
|
||||
"非通配符允许的 origin 应设置 Vary: Origin")
|
||||
}
|
||||
|
||||
func TestNormalizeOrigins(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
expect []string
|
||||
}{
|
||||
{name: "nil_input", input: nil, expect: nil},
|
||||
{name: "empty_input", input: []string{}, expect: nil},
|
||||
{name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}},
|
||||
{name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeOrigins(tt.input)
|
||||
assert.Equal(t, tt.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewJWTAuthMiddleware 创建 JWT 认证中间件
|
||||
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
|
||||
return JWTAuthMiddleware(jwtAuth(authService, userService))
|
||||
}
|
||||
|
||||
// jwtAuth JWT认证中间件实现
|
||||
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从Authorization header中提取token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization header is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimSpace(parts[1])
|
||||
if tokenString == "" {
|
||||
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := authService.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrTokenExpired) {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库获取最新的用户信息
|
||||
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// Security: Validate TokenVersion to ensure token hasn't been invalidated
|
||||
// This check ensures tokens issued before a password change are rejected
|
||||
if claims.TokenVersion != user.TokenVersion {
|
||||
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: user.ID,
|
||||
Concurrency: user.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), user.Role)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go.
|
||||
@@ -0,0 +1,256 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。
|
||||
type stubJWTUserRepo struct {
|
||||
service.UserRepository
|
||||
users map[int64]*service.User
|
||||
}
|
||||
|
||||
func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, error) {
|
||||
u, ok := r.users[id]
|
||||
if !ok {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
|
||||
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
|
||||
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(gin.HandlerFunc(mw))
|
||||
r.GET("/protected", func(c *gin.Context) {
|
||||
subject, _ := GetAuthSubjectFromContext(c)
|
||||
role, _ := GetUserRoleFromContext(c)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": subject.UserID,
|
||||
"role": role,
|
||||
})
|
||||
})
|
||||
return r, authSvc
|
||||
}
|
||||
|
||||
func TestJWTAuth_ValidToken(t *testing.T) {
|
||||
user := &service.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
||||
|
||||
token, err := authSvc.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var body map[string]any
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, float64(1), body["user_id"])
|
||||
require.Equal(t, "user", body["role"])
|
||||
}
|
||||
|
||||
func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
|
||||
user := &service.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
||||
|
||||
token, err := authSvc.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
|
||||
router, _ := newJWTTestEnv(nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "UNAUTHORIZED", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_InvalidHeaderFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
}{
|
||||
{"无Bearer前缀", "Token abc123"},
|
||||
{"缺少空格分隔", "Bearerabc123"},
|
||||
{"仅有单词", "abc123"},
|
||||
}
|
||||
router, _ := newJWTTestEnv(nil)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", tt.header)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "INVALID_AUTH_HEADER", body.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_EmptyToken(t *testing.T) {
|
||||
router, _ := newJWTTestEnv(nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer ")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "EMPTY_TOKEN", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_TamperedToken(t *testing.T) {
|
||||
router, _ := newJWTTestEnv(nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "INVALID_TOKEN", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_UserNotFound(t *testing.T) {
|
||||
// 使用 user ID=1 的 token,但 repo 中没有该用户
|
||||
fakeUser := &service.User{
|
||||
ID: 999,
|
||||
Email: "ghost@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
// 创建环境时不注入此用户,这样 GetByID 会失败
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{})
|
||||
|
||||
token, err := authSvc.GenerateToken(fakeUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "USER_NOT_FOUND", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_UserInactive(t *testing.T) {
|
||||
user := &service.User{
|
||||
ID: 1,
|
||||
Email: "disabled@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusDisabled,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
||||
|
||||
token, err := authSvc.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "USER_INACTIVE", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_TokenVersionMismatch(t *testing.T) {
|
||||
// Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改)
|
||||
userForToken := &service.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
userInDB := &service.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 2, // 密码修改后版本递增
|
||||
}
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: userInDB})
|
||||
|
||||
token, err := authSvc.GenerateToken(userForToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "TOKEN_REVOKED", body.Code)
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logger 请求日志中间件
|
||||
func Logger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 请求路径
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 跳过健康检查等高频探针路径的日志
|
||||
if path == "/health" || path == "/setup/status" {
|
||||
return
|
||||
}
|
||||
|
||||
endTime := time.Now()
|
||||
latency := endTime.Sub(startTime)
|
||||
|
||||
method := c.Request.Method
|
||||
statusCode := c.Writer.Status()
|
||||
clientIP := c.ClientIP()
|
||||
protocol := c.Request.Proto
|
||||
accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64)
|
||||
platform, _ := c.Request.Context().Value(ctxkey.Platform).(string)
|
||||
model, _ := c.Request.Context().Value(ctxkey.Model).(string)
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "http.access"),
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.Int64("latency_ms", latency.Milliseconds()),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("protocol", protocol),
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
}
|
||||
if hasAccountID && accountID > 0 {
|
||||
fields = append(fields, zap.Int64("account_id", accountID))
|
||||
}
|
||||
if platform != "" {
|
||||
fields = append(fields, zap.String("platform", platform))
|
||||
}
|
||||
if model != "" {
|
||||
fields = append(fields, zap.String("model", model))
|
||||
}
|
||||
|
||||
l := logger.FromContext(c.Request.Context()).With(fields...)
|
||||
l.Info("http request completed", zap.Time("completed_at", endTime))
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
l.Warn("http request contains gin errors", zap.String("errors", c.Errors.String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ContextKey 定义上下文键类型
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// ContextKeyUser 用户上下文键
|
||||
ContextKeyUser ContextKey = "user"
|
||||
// ContextKeyUserRole 当前用户角色(string)
|
||||
ContextKeyUserRole ContextKey = "user_role"
|
||||
// ContextKeyAPIKey API密钥上下文键
|
||||
ContextKeyAPIKey ContextKey = "api_key"
|
||||
// ContextKeySubscription 订阅上下文键
|
||||
ContextKeySubscription ContextKey = "subscription"
|
||||
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
|
||||
ContextKeyForcePlatform ContextKey = "force_platform"
|
||||
)
|
||||
|
||||
// ForcePlatform 返回设置强制平台的中间件
|
||||
// 同时设置 request.Context(供 Service 使用)和 gin.Context(供 Handler 快速检查)
|
||||
func ForcePlatform(platform string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置到 request.Context,使用 ctxkey.ForcePlatform 供 Service 层读取
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, platform)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
// 同时设置到 gin.Context,供 Handler 快速检查
|
||||
c.Set(string(ContextKeyForcePlatform), platform)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查)
|
||||
func HasForcePlatform(c *gin.Context) bool {
|
||||
_, exists := c.Get(string(ContextKeyForcePlatform))
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetForcePlatformFromContext 从 gin.Context 获取强制平台
|
||||
func GetForcePlatformFromContext(c *gin.Context) (string, bool) {
|
||||
value, exists := c.Get(string(ContextKeyForcePlatform))
|
||||
if !exists {
|
||||
return "", false
|
||||
}
|
||||
platform, ok := value.(string)
|
||||
return platform, ok
|
||||
}
|
||||
|
||||
// ErrorResponse 标准错误响应结构
|
||||
type ErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewErrorResponse 创建错误响应
|
||||
func NewErrorResponse(code, message string) ErrorResponse {
|
||||
return ErrorResponse{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// AbortWithError 中断请求并返回JSON错误
|
||||
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
|
||||
c.JSON(statusCode, NewErrorResponse(code, message))
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// RequireGroupAssignment — 未分组 Key 拦截中间件
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式)
|
||||
type GatewayErrorWriter func(c *gin.Context, status int, message string)
|
||||
|
||||
// AnthropicErrorWriter 按 Anthropic API 规范输出错误
|
||||
func AnthropicErrorWriter(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{"type": "permission_error", "message": message},
|
||||
})
|
||||
}
|
||||
|
||||
// GoogleErrorWriter 按 Google API 规范输出错误
|
||||
func GoogleErrorWriter(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// RequireGroupAssignment 检查 API Key 是否已分配到分组,
|
||||
// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。
|
||||
func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKey, ok := GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey.GroupID != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
// 未分组 Key — 检查系统设置
|
||||
if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientRequestID_GeneratesWhenMissing(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(ClientRequestID())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
v := c.Request.Context().Value(ctxkey.ClientRequestID)
|
||||
require.NotNil(t, v)
|
||||
id, ok := v.(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, id)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestClientRequestID_PreservesExisting(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(ClientRequestID())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "keep", id)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep"))
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRequestBodyLimit_LimitsBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(RequestBodyLimit(4))
|
||||
r.POST("/t", func(c *gin.Context) {
|
||||
_, err := io.ReadAll(c.Request.Body)
|
||||
require.Error(t, err)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345"))
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestForcePlatform_SetsContextAndGinValue(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(ForcePlatform("anthropic"))
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
require.True(t, HasForcePlatform(c))
|
||||
v, ok := GetForcePlatformFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "anthropic", v)
|
||||
|
||||
ctxV := c.Request.Context().Value(ctxkey.ForcePlatform)
|
||||
require.Equal(t, "anthropic", ctxV)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthSubjectHelpers_RoundTrip(t *testing.T) {
|
||||
c := &gin.Context{}
|
||||
c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2})
|
||||
c.Set(string(ContextKeyUserRole), "admin")
|
||||
|
||||
sub, ok := GetAuthSubjectFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(1), sub.UserID)
|
||||
require.Equal(t, 2, sub.Concurrency)
|
||||
|
||||
role, ok := GetUserRoleFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "admin", role)
|
||||
}
|
||||
|
||||
func TestAPIKeyAndSubscriptionFromContext(t *testing.T) {
|
||||
c := &gin.Context{}
|
||||
|
||||
key := &service.APIKey{ID: 1}
|
||||
c.Set(string(ContextKeyAPIKey), key)
|
||||
gotKey, ok := GetAPIKeyFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(1), gotKey.ID)
|
||||
|
||||
sub := &service.UserSubscription{ID: 2}
|
||||
c.Set(string(ContextKeySubscription), sub)
|
||||
gotSub, ok := GetSubscriptionFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(2), gotSub.ID)
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Recovery converts panics into the project's standard JSON error envelope.
|
||||
//
|
||||
// It preserves Gin's broken-pipe handling by not attempting to write a response
|
||||
// when the client connection is already gone.
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return gin.CustomRecoveryWithWriter(gin.DefaultErrorWriter, func(c *gin.Context, recovered any) {
|
||||
recoveredErr, _ := recovered.(error)
|
||||
|
||||
if isBrokenPipe(recoveredErr) {
|
||||
if recoveredErr != nil {
|
||||
_ = c.Error(recoveredErr)
|
||||
}
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if c.Writer.Written() {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
response.ErrorWithDetails(
|
||||
c,
|
||||
http.StatusInternalServerError,
|
||||
infraerrors.UnknownMessage,
|
||||
infraerrors.UnknownReason,
|
||||
nil,
|
||||
)
|
||||
c.Abort()
|
||||
})
|
||||
}
|
||||
|
||||
func isBrokenPipe(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var opErr *net.OpError
|
||||
if !errors.As(err, &opErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
var syscallErr *os.SyscallError
|
||||
if !errors.As(opErr.Err, &syscallErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
msg := strings.ToLower(syscallErr.Error())
|
||||
return strings.Contains(msg, "broken pipe") || strings.Contains(msg, "connection reset by peer")
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecovery_PanicLogContainsInfo(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 临时替换 DefaultErrorWriter 以捕获日志输出
|
||||
var buf bytes.Buffer
|
||||
originalWriter := gin.DefaultErrorWriter
|
||||
gin.DefaultErrorWriter = &buf
|
||||
t.Cleanup(func() {
|
||||
gin.DefaultErrorWriter = originalWriter
|
||||
})
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Recovery())
|
||||
r.GET("/panic", func(c *gin.Context) {
|
||||
panic("custom panic message for test")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
|
||||
logOutput := buf.String()
|
||||
require.Contains(t, logOutput, "custom panic message for test", "日志应包含 panic 信息")
|
||||
require.Contains(t, logOutput, "recovery_test.go", "日志应包含堆栈跟踪文件名")
|
||||
}
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
handler gin.HandlerFunc
|
||||
wantHTTPCode int
|
||||
wantBody response.Response
|
||||
}{
|
||||
{
|
||||
name: "panic_returns_standard_json_500",
|
||||
handler: func(c *gin.Context) {
|
||||
panic("boom")
|
||||
},
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: response.Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: infraerrors.UnknownMessage,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no_panic_passthrough",
|
||||
handler: func(c *gin.Context) {
|
||||
response.Success(c, gin.H{"ok": true})
|
||||
},
|
||||
wantHTTPCode: http.StatusOK,
|
||||
wantBody: response.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]any{"ok": true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "panic_after_write_does_not_override_body",
|
||||
handler: func(c *gin.Context) {
|
||||
response.Success(c, gin.H{"ok": true})
|
||||
panic("boom")
|
||||
},
|
||||
wantHTTPCode: http.StatusOK,
|
||||
wantBody: response.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]any{"ok": true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(Recovery())
|
||||
r.GET("/t", tt.handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
||||
|
||||
var got response.Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.wantBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type testLogSink struct {
|
||||
mu sync.Mutex
|
||||
events []*logger.LogEvent
|
||||
}
|
||||
|
||||
func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.events = append(s.events, event)
|
||||
}
|
||||
|
||||
func (s *testLogSink) list() []*logger.LogEvent {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*logger.LogEvent, len(s.events))
|
||||
copy(out, s.events)
|
||||
return out
|
||||
}
|
||||
|
||||
func initMiddlewareTestLogger(t *testing.T) *testLogSink {
|
||||
return initMiddlewareTestLoggerWithLevel(t, "debug")
|
||||
}
|
||||
|
||||
func initMiddlewareTestLoggerWithLevel(t *testing.T, level string) *testLogSink {
|
||||
t.Helper()
|
||||
level = strings.TrimSpace(level)
|
||||
if level == "" {
|
||||
level = "debug"
|
||||
}
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: level,
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: false,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
sink := &testLogSink{}
|
||||
logger.SetSink(sink)
|
||||
t.Cleanup(func() {
|
||||
logger.SetSink(nil)
|
||||
})
|
||||
return sink
|
||||
}
|
||||
|
||||
func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(RequestLogger())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string)
|
||||
if !ok || reqID == "" {
|
||||
t.Fatalf("request_id missing in context")
|
||||
}
|
||||
if got := c.Writer.Header().Get(requestIDHeader); got != reqID {
|
||||
t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if w.Header().Get(requestIDHeader) == "" {
|
||||
t.Fatalf("X-Request-ID should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestLogger_KeepIncomingRequestID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(RequestLogger())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
|
||||
if reqID != "rid-fixed" {
|
||||
t.Fatalf("request_id=%q, want rid-fixed", reqID)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set(requestIDHeader, "rid-fixed")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if got := w.Header().Get(requestIDHeader); got != "rid-fixed" {
|
||||
t.Fatalf("header=%q, want rid-fixed", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_AccessLogIncludesCoreFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLogger(t)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Logger())
|
||||
r.Use(func(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101))
|
||||
ctx = context.WithValue(ctx, ctxkey.Platform, "openai")
|
||||
ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5")
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
})
|
||||
r.GET("/api/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
|
||||
events := sink.list()
|
||||
if len(events) == 0 {
|
||||
t.Fatalf("expected at least one log event")
|
||||
}
|
||||
found := false
|
||||
for _, event := range events {
|
||||
if event == nil || event.Message != "http request completed" {
|
||||
continue
|
||||
}
|
||||
found = true
|
||||
switch v := event.Fields["status_code"].(type) {
|
||||
case int:
|
||||
if v != http.StatusCreated {
|
||||
t.Fatalf("status_code field mismatch: %v", v)
|
||||
}
|
||||
case int64:
|
||||
if v != int64(http.StatusCreated) {
|
||||
t.Fatalf("status_code field mismatch: %v", v)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("status_code type mismatch: %T", v)
|
||||
}
|
||||
switch v := event.Fields["account_id"].(type) {
|
||||
case int64:
|
||||
if v != 101 {
|
||||
t.Fatalf("account_id field mismatch: %v", v)
|
||||
}
|
||||
case int:
|
||||
if v != 101 {
|
||||
t.Fatalf("account_id field mismatch: %v", v)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("account_id type mismatch: %T", v)
|
||||
}
|
||||
if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" {
|
||||
t.Fatalf("platform/model mismatch: %+v", event.Fields)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("access log event not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_HealthPathSkipped(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLogger(t)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Logger())
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if len(sink.list()) != 0 {
|
||||
t.Fatalf("health endpoint should not write access log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_AccessLogDroppedWhenLevelWarn(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLoggerWithLevel(t, "warn")
|
||||
|
||||
r := gin.New()
|
||||
r.Use(RequestLogger())
|
||||
r.Use(Logger())
|
||||
r.GET("/api/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
|
||||
events := sink.list()
|
||||
for _, event := range events {
|
||||
if event != nil && event.Message == "http request completed" {
|
||||
t.Fatalf("access log should not be indexed when level=warn: %+v", event)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。
|
||||
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const requestIDHeader = "X-Request-ID"
|
||||
|
||||
// RequestLogger 在请求入口注入 request-scoped logger。
|
||||
func RequestLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(c.GetHeader(requestIDHeader))
|
||||
if requestID == "" {
|
||||
requestID = uuid.NewString()
|
||||
}
|
||||
c.Header(requestIDHeader, requestID)
|
||||
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.RequestID, requestID)
|
||||
clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string)
|
||||
|
||||
requestLogger := logger.With(
|
||||
zap.String("component", "http"),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("client_request_id", strings.TrimSpace(clientRequestID)),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
)
|
||||
|
||||
ctx = logger.IntoContext(ctx, requestLogger)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// CSPNonceKey is the context key for storing the CSP nonce
|
||||
CSPNonceKey = "csp_nonce"
|
||||
// NonceTemplate is the placeholder in CSP policy for nonce
|
||||
NonceTemplate = "__CSP_NONCE__"
|
||||
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
|
||||
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
|
||||
)
|
||||
|
||||
// GenerateNonce generates a cryptographically secure random nonce.
|
||||
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
|
||||
func GenerateNonce() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate CSP nonce: %w", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// GetNonceFromContext retrieves the CSP nonce from gin context
|
||||
func GetNonceFromContext(c *gin.Context) string {
|
||||
if nonce, exists := c.Get(CSPNonceKey); exists {
|
||||
if s, ok := nonce.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SecurityHeaders sets baseline security headers for all responses.
|
||||
// getFrameSrcOrigins is an optional function that returns extra origins to inject into frame-src;
|
||||
// pass nil to disable dynamic frame-src injection.
|
||||
func SecurityHeaders(cfg config.CSPConfig, getFrameSrcOrigins func() []string) gin.HandlerFunc {
|
||||
policy := strings.TrimSpace(cfg.Policy)
|
||||
if policy == "" {
|
||||
policy = config.DefaultCSPPolicy
|
||||
}
|
||||
|
||||
// Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
|
||||
policy = enhanceCSPPolicy(policy)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
finalPolicy := policy
|
||||
if getFrameSrcOrigins != nil {
|
||||
for _, origin := range getFrameSrcOrigins() {
|
||||
if origin != "" {
|
||||
finalPolicy = addToDirective(finalPolicy, "frame-src", origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
if isAPIRoutePath(c) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.Enabled {
|
||||
// Generate nonce for this request
|
||||
nonce, err := GenerateNonce()
|
||||
if err != nil {
|
||||
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
|
||||
log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err)
|
||||
c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'unsafe-inline'"))
|
||||
} else {
|
||||
c.Set(CSPNonceKey, nonce)
|
||||
c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'nonce-"+nonce+"'"))
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func isAPIRoutePath(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return false
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
return strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/v1beta/") ||
|
||||
strings.HasPrefix(path, "/antigravity/") ||
|
||||
strings.HasPrefix(path, "/sora/") ||
|
||||
strings.HasPrefix(path, "/responses")
|
||||
}
|
||||
|
||||
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
|
||||
// This allows the application to work correctly even if the config file has an older CSP policy.
|
||||
func enhanceCSPPolicy(policy string) string {
|
||||
// Add nonce placeholder to script-src if not present
|
||||
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
|
||||
policy = addToDirective(policy, "script-src", NonceTemplate)
|
||||
}
|
||||
|
||||
// Add Cloudflare Insights domain to script-src if not present
|
||||
if !strings.Contains(policy, CloudflareInsightsDomain) {
|
||||
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
|
||||
}
|
||||
|
||||
return policy
|
||||
}
|
||||
|
||||
// addToDirective adds a value to a specific CSP directive.
|
||||
// If the directive doesn't exist, it will be added after default-src.
|
||||
func addToDirective(policy, directive, value string) string {
|
||||
// Find the directive in the policy
|
||||
directivePrefix := directive + " "
|
||||
idx := strings.Index(policy, directivePrefix)
|
||||
|
||||
if idx == -1 {
|
||||
// Directive not found, add it after default-src or at the beginning
|
||||
defaultSrcIdx := strings.Index(policy, "default-src ")
|
||||
if defaultSrcIdx != -1 {
|
||||
// Find the end of default-src directive (next semicolon)
|
||||
endIdx := strings.Index(policy[defaultSrcIdx:], ";")
|
||||
if endIdx != -1 {
|
||||
insertPos := defaultSrcIdx + endIdx + 1
|
||||
// Insert new directive after default-src
|
||||
return policy[:insertPos] + " " + directive + " 'self' " + value + ";" + policy[insertPos:]
|
||||
}
|
||||
}
|
||||
// Fallback: prepend the directive
|
||||
return directive + " 'self' " + value + "; " + policy
|
||||
}
|
||||
|
||||
// Find the end of this directive (next semicolon or end of string)
|
||||
endIdx := strings.Index(policy[idx:], ";")
|
||||
|
||||
if endIdx == -1 {
|
||||
// No semicolon found, directive goes to end of string
|
||||
return policy + " " + value
|
||||
}
|
||||
|
||||
// Insert value before the semicolon
|
||||
insertPos := idx + endIdx
|
||||
return policy[:insertPos] + " " + value + policy[insertPos:]
|
||||
}
|
||||
@@ -0,0 +1,388 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestGenerateNonce(t *testing.T) {
|
||||
t.Run("generates_valid_base64_string", func(t *testing.T) {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be valid base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should decode to 16 bytes
|
||||
assert.Len(t, decoded, 16)
|
||||
})
|
||||
|
||||
t.Run("generates_unique_nonces", func(t *testing.T) {
|
||||
nonces := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
assert.False(t, nonces[nonce], "nonce should be unique")
|
||||
nonces[nonce] = true
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonce_has_expected_length", func(t *testing.T) {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
// 16 bytes -> 24 chars in base64 (with padding)
|
||||
assert.Len(t, nonce, 24)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetNonceFromContext(t *testing.T) {
|
||||
t.Run("returns_nonce_when_present", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
expectedNonce := "test-nonce-123"
|
||||
c.Set(CSPNonceKey, expectedNonce)
|
||||
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.Equal(t, expectedNonce, nonce)
|
||||
})
|
||||
|
||||
t.Run("returns_empty_string_when_not_present", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.Empty(t, nonce)
|
||||
})
|
||||
|
||||
t.Run("returns_empty_for_wrong_type", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Set a non-string value
|
||||
c.Set(CSPNonceKey, 12345)
|
||||
|
||||
// Should return empty string for wrong type (safe type assertion)
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.Empty(t, nonce)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecurityHeaders(t *testing.T) {
|
||||
t.Run("sets_basic_security_headers", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{Enabled: false}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
|
||||
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
|
||||
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
|
||||
})
|
||||
|
||||
t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{Enabled: false}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
|
||||
})
|
||||
|
||||
t.Run("csp_enabled_sets_csp_header", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "default-src 'self'",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
assert.NotEmpty(t, csp)
|
||||
// Policy is auto-enhanced with nonce and Cloudflare Insights domain
|
||||
assert.Contains(t, csp, "default-src 'self'")
|
||||
assert.Contains(t, csp, "'nonce-")
|
||||
assert.Contains(t, csp, CloudflareInsightsDomain)
|
||||
})
|
||||
|
||||
t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
|
||||
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
|
||||
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
|
||||
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
|
||||
assert.Empty(t, GetNonceFromContext(c))
|
||||
})
|
||||
|
||||
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
assert.NotEmpty(t, csp)
|
||||
assert.NotContains(t, csp, "__CSP_NONCE__", "placeholder should be replaced")
|
||||
assert.Contains(t, csp, "'nonce-", "should contain nonce directive")
|
||||
|
||||
// Verify nonce is stored in context
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.NotEmpty(t, nonce)
|
||||
assert.Contains(t, csp, "'nonce-"+nonce+"'")
|
||||
})
|
||||
|
||||
t.Run("uses_default_policy_when_empty", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
assert.NotEmpty(t, csp)
|
||||
// Default policy should contain these elements
|
||||
assert.Contains(t, csp, "default-src 'self'")
|
||||
})
|
||||
|
||||
t.Run("uses_default_policy_when_whitespace_only", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: " \t\n ",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
assert.NotEmpty(t, csp)
|
||||
assert.Contains(t, csp, "default-src 'self'")
|
||||
})
|
||||
|
||||
t.Run("multiple_nonce_placeholders_replaced", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
csp := w.Header().Get("Content-Security-Policy")
|
||||
nonce := GetNonceFromContext(c)
|
||||
|
||||
// Count occurrences of the nonce
|
||||
count := strings.Count(csp, "'nonce-"+nonce+"'")
|
||||
assert.Equal(t, 2, count, "both placeholders should be replaced with same nonce")
|
||||
})
|
||||
|
||||
t.Run("calls_next_handler", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
nextCalled := false
|
||||
router := gin.New()
|
||||
router.Use(middleware)
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
nextCalled = true
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.True(t, nextCalled, "next handler should be called")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("nonce_unique_per_request", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "script-src __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
nonces := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
nonce := GetNonceFromContext(c)
|
||||
assert.False(t, nonces[nonce], "nonce should be unique per request")
|
||||
nonces[nonce] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCSPNonceKey(t *testing.T) {
|
||||
t.Run("constant_value", func(t *testing.T) {
|
||||
assert.Equal(t, "csp_nonce", CSPNonceKey)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNonceTemplate(t *testing.T) {
|
||||
t.Run("constant_value", func(t *testing.T) {
|
||||
assert.Equal(t, "__CSP_NONCE__", NonceTemplate)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnhanceCSPPolicy(t *testing.T) {
|
||||
t.Run("adds_nonce_placeholder_if_missing", func(t *testing.T) {
|
||||
policy := "default-src 'self'; script-src 'self'"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
assert.Contains(t, enhanced, NonceTemplate)
|
||||
assert.Contains(t, enhanced, CloudflareInsightsDomain)
|
||||
})
|
||||
|
||||
t.Run("does_not_duplicate_nonce_placeholder", func(t *testing.T) {
|
||||
policy := "default-src 'self'; script-src 'self' __CSP_NONCE__"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
// Should not duplicate
|
||||
count := strings.Count(enhanced, NonceTemplate)
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("does_not_duplicate_cloudflare_domain", func(t *testing.T) {
|
||||
policy := "default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
count := strings.Count(enhanced, CloudflareInsightsDomain)
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("handles_policy_without_script_src", func(t *testing.T) {
|
||||
policy := "default-src 'self'"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
assert.Contains(t, enhanced, "script-src")
|
||||
assert.Contains(t, enhanced, NonceTemplate)
|
||||
assert.Contains(t, enhanced, CloudflareInsightsDomain)
|
||||
})
|
||||
|
||||
t.Run("preserves_existing_nonce", func(t *testing.T) {
|
||||
policy := "script-src 'self' 'nonce-existing'"
|
||||
enhanced := enhanceCSPPolicy(policy)
|
||||
|
||||
// Should not add placeholder if nonce already exists
|
||||
assert.NotContains(t, enhanced, NonceTemplate)
|
||||
assert.Contains(t, enhanced, "'nonce-existing'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddToDirective(t *testing.T) {
|
||||
t.Run("adds_to_existing_directive", func(t *testing.T) {
|
||||
policy := "script-src 'self'; style-src 'self'"
|
||||
result := addToDirective(policy, "script-src", "https://example.com")
|
||||
|
||||
assert.Contains(t, result, "script-src 'self' https://example.com")
|
||||
})
|
||||
|
||||
t.Run("creates_directive_if_not_exists", func(t *testing.T) {
|
||||
policy := "default-src 'self'"
|
||||
result := addToDirective(policy, "script-src", "https://example.com")
|
||||
|
||||
assert.Contains(t, result, "script-src")
|
||||
assert.Contains(t, result, "https://example.com")
|
||||
})
|
||||
|
||||
t.Run("handles_directive_at_end_without_semicolon", func(t *testing.T) {
|
||||
policy := "default-src 'self'; script-src 'self'"
|
||||
result := addToDirective(policy, "script-src", "https://example.com")
|
||||
|
||||
assert.Contains(t, result, "https://example.com")
|
||||
})
|
||||
|
||||
t.Run("handles_empty_policy", func(t *testing.T) {
|
||||
policy := ""
|
||||
result := addToDirective(policy, "script-src", "https://example.com")
|
||||
|
||||
assert.Contains(t, result, "script-src")
|
||||
assert.Contains(t, result, "https://example.com")
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGenerateNonce(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = GenerateNonce()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
middleware(c)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// JWTAuthMiddleware JWT 认证中间件类型
|
||||
type JWTAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// AdminAuthMiddleware 管理员认证中间件类型
|
||||
type AdminAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// APIKeyAuthMiddleware API Key 认证中间件类型
|
||||
type APIKeyAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ProviderSet 中间件层的依赖注入
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewJWTAuthMiddleware,
|
||||
NewAdminAuthMiddleware,
|
||||
NewAPIKeyAuthMiddleware,
|
||||
)
|
||||
Reference in New Issue
Block a user