package middleware import ( "context" "fmt" "net/http" "strings" "time" "github.com/gin-gonic/gin" "golang.org/x/sync/singleflight" "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/cache" "github.com/user-management-system/internal/domain" apierrors "github.com/user-management-system/internal/pkg/errors" ) // Interfaces for dependency inversion (DIP) — middleware depends on these abstractions, not concrete types. type authUserRepository interface { GetByID(ctx context.Context, id int64) (*domain.User, error) } type authUserRoleRepository interface { GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) } type AuthMiddleware struct { jwt *auth.JWT userRepo authUserRepository userRoleRepo authUserRoleRepository l1Cache *cache.L1Cache cacheManager *cache.CacheManager sfGroup singleflight.Group } func NewAuthMiddleware( jwt *auth.JWT, userRepo authUserRepository, userRoleRepo authUserRoleRepository, l1Cache *cache.L1Cache, ) *AuthMiddleware { return &AuthMiddleware{ jwt: jwt, userRepo: userRepo, userRoleRepo: userRoleRepo, l1Cache: l1Cache, } } func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) { m.cacheManager = cm } func (m *AuthMiddleware) Required() gin.HandlerFunc { return func(c *gin.Context) { token := m.extractToken(c) if token == "" { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "未提供认证令牌")) c.Abort() return } claims, err := m.jwt.ValidateAccessToken(token) if err != nil { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "无效的认证令牌")) c.Abort() return } if m.isJTIBlacklisted(c.Request.Context(), claims.JTI) { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "令牌已失效,请重新登录")) c.Abort() return } if m.isPasswordChangedSinceTokenIssued(c.Request.Context(), claims.UserID, claims.PCE) { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "密码已更新,请重新登录")) c.Abort() return } if !m.isUserActive(c.Request.Context(), claims.UserID) { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "账号不可用,请重新登录")) c.Abort() return } c.Set("user_id", claims.UserID) c.Set("username", claims.Username) c.Set("token_jti", claims.JTI) roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID) c.Set("role_codes", roleCodes) c.Set("permission_codes", permCodes) c.Next() } } func (m *AuthMiddleware) Optional() gin.HandlerFunc { return func(c *gin.Context) { token := m.extractToken(c) if token != "" { claims, err := m.jwt.ValidateAccessToken(token) if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && !m.isPasswordChangedSinceTokenIssued(c.Request.Context(), claims.UserID, claims.PCE) && m.isUserActive(c.Request.Context(), claims.UserID) { c.Set("user_id", claims.UserID) c.Set("username", claims.Username) c.Set("token_jti", claims.JTI) roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID) c.Set("role_codes", roleCodes) c.Set("permission_codes", permCodes) } } c.Next() } } func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool { if jti == "" { return false } key := "jwt_blacklist:" + jti // 先检查 L1 缓存 if _, ok := m.l1Cache.Get(key); ok { return true } // L1 miss 时使用 singleflight 防止缓存击穿 // 多个并发请求只会触发一次 L2 查询 if m.cacheManager != nil { val, err, _ := m.sfGroup.Do(key, func() (interface{}, error) { found, _ := m.cacheManager.Get(ctx, key) return found, nil }) if err == nil && val != nil { // 回写 L1 缓存 m.l1Cache.Set(key, true, 5*time.Minute) return true } } return false } // isPasswordChangedSinceTokenIssued 检查用户密码是否在令牌发放后已更改 // 如果 tokenPCE 为 0(旧令牌),则不检查(向后兼容) func (m *AuthMiddleware) isPasswordChangedSinceTokenIssued(ctx context.Context, userID int64, tokenPCE int64) bool { if tokenPCE == 0 { // 旧令牌没有密码变更时间戳,不拦截 return false } if m.userRepo == nil { return false } user, err := m.userRepo.GetByID(ctx, userID) if err != nil || user.PasswordChangedAt.IsZero() { return false } // 如果令牌的 PCE < 用户密码变更时间,说明密码在令牌发放后已更改 return tokenPCE < user.PasswordChangedAt.Unix() } func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) { if m.userRoleRepo == nil { return nil, nil } cacheKey := fmt.Sprintf("user_perms:%d", userID) if cached, ok := m.l1Cache.Get(cacheKey); ok { if entry, ok := cached.(userPermEntry); ok { return entry.roles, entry.perms } } // 使用已优化的单次 JOIN 查询获取用户角色和权限 roles, permissions, err := m.userRoleRepo.GetUserRolesAndPermissions(ctx, userID) if err != nil || len(roles) == 0 { return nil, nil } roleCodes := make([]string, 0, len(roles)) for _, role := range roles { roleCodes = append(roleCodes, role.Code) } permCodes := make([]string, 0, len(permissions)) for _, perm := range permissions { permCodes = append(permCodes, perm.Code) } m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 30*time.Minute) return roleCodes, permCodes } func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) { m.l1Cache.Delete(fmt.Sprintf("user_perms:%d", userID)) } func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) { if jti != "" && ttl > 0 { m.l1Cache.Set("jwt_blacklist:"+jti, true, ttl) } } func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool { if m.userRepo == nil { return true } user, err := m.userRepo.GetByID(ctx, userID) if err != nil { return false } return user.Status == domain.UserStatusActive } func (m *AuthMiddleware) extractToken(c *gin.Context) string { authHeader := c.GetHeader("Authorization") if authHeader == "" { return "" } parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 || parts[0] != "Bearer" { return "" } return parts[1] } type userPermEntry struct { roles []string perms []string }