package middleware

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/user-management-system/internal/auth"
	"github.com/user-management-system/internal/cache"
	"github.com/user-management-system/internal/domain"
	apierrors "github.com/user-management-system/internal/pkg/errors"
	"github.com/user-management-system/internal/repository"
)

// AuthMiddleware authenticates requests with JWT access tokens and, on
// success, loads the caller's role and permission codes into the gin context.
//
// Revoked token IDs and per-user role/permission lookups are cached in an
// in-process L1 cache. An optional shared CacheManager (see SetCacheManager)
// is additionally consulted when checking for revoked tokens.
type AuthMiddleware struct {
	jwt                *auth.JWT
	userRepo           *repository.UserRepository
	userRoleRepo       *repository.UserRoleRepository
	roleRepo           *repository.RoleRepository
	rolePermissionRepo *repository.RolePermissionRepository
	permissionRepo     *repository.PermissionRepository
	l1Cache            *cache.L1Cache
	cacheManager       *cache.CacheManager
}

// NewAuthMiddleware builds an AuthMiddleware with a fresh L1 cache.
//
// The shared cache manager starts unset; call SetCacheManager to enable it.
// If any of the role/permission repositories is nil, role and permission
// loading is skipped. A nil userRepo disables the account-status check.
func NewAuthMiddleware(
	jwt *auth.JWT,
	userRepo *repository.UserRepository,
	userRoleRepo *repository.UserRoleRepository,
	roleRepo *repository.RoleRepository,
	rolePermissionRepo *repository.RolePermissionRepository,
	permissionRepo *repository.PermissionRepository,
) *AuthMiddleware {
	return &AuthMiddleware{
		jwt:                jwt,
		userRepo:           userRepo,
		userRoleRepo:       userRoleRepo,
		roleRepo:           roleRepo,
		rolePermissionRepo: rolePermissionRepo,
		permissionRepo:     permissionRepo,
		l1Cache:            cache.NewL1Cache(),
	}
}

// SetCacheManager attaches a shared cache that is consulted, after the L1
// cache, when checking whether a token ID has been revoked.
func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) {
	m.cacheManager = cm
}

// Required returns a handler that aborts with 401 unless the request carries
// a valid, non-revoked bearer access token belonging to an active user.
//
// On success it sets "user_id", "username", "token_jti", "role_codes" and
// "permission_codes" on the gin context before calling the next handler.
func (m *AuthMiddleware) Required() gin.HandlerFunc {
	return func(c *gin.Context) {
		token := m.extractToken(c)
		if token == "" {
			m.abortUnauthorized(c, "未提供认证令牌")
			return
		}

		claims, err := m.jwt.ValidateAccessToken(token)
		if err != nil {
			m.abortUnauthorized(c, "无效的认证令牌")
			return
		}

		ctx := c.Request.Context()
		if m.isJTIBlacklistedCtx(ctx, claims.JTI) {
			m.abortUnauthorized(c, "令牌已失效,请重新登录")
			return
		}
		if !m.isUserActive(ctx, claims.UserID) {
			m.abortUnauthorized(c, "账号不可用,请重新登录")
			return
		}

		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("token_jti", claims.JTI)
		m.setRolesAndPerms(c, claims.UserID)
		c.Next()
	}
}

// Optional returns a handler that populates the same context keys as
// Required when a valid, non-revoked token of an active user is present, and
// otherwise lets the request through anonymously (it never aborts).
func (m *AuthMiddleware) Optional() gin.HandlerFunc {
	return func(c *gin.Context) {
		if token := m.extractToken(c); token != "" {
			ctx := c.Request.Context()
			claims, err := m.jwt.ValidateAccessToken(token)
			// Short-circuit order matters: claims is only dereferenced when err == nil.
			if err == nil && !m.isJTIBlacklistedCtx(ctx, claims.JTI) && m.isUserActive(ctx, claims.UserID) {
				c.Set("user_id", claims.UserID)
				c.Set("username", claims.Username)
				c.Set("token_jti", claims.JTI)
				m.setRolesAndPerms(c, claims.UserID)
			}
		}
		c.Next()
	}
}

// abortUnauthorized writes a 401 UNAUTHORIZED API error with msg and stops
// the handler chain.
func (m *AuthMiddleware) abortUnauthorized(c *gin.Context, msg string) {
	c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", msg))
	c.Abort()
}

// setRolesAndPerms resolves userID's role and permission codes and stores
// them under "role_codes" and "permission_codes".
func (m *AuthMiddleware) setRolesAndPerms(c *gin.Context, userID int64) {
	roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), userID)
	c.Set("role_codes", roleCodes)
	c.Set("permission_codes", permCodes)
}

// blacklistKey is the cache key under which a revoked token ID is stored.
func (m *AuthMiddleware) blacklistKey(jti string) string {
	return "jwt_blacklist:" + jti
}

// userPermCacheKey is the L1 cache key for a user's resolved roles/permissions.
func (m *AuthMiddleware) userPermCacheKey(userID int64) string {
	return fmt.Sprintf("user_perms:%d", userID)
}

// isJTIBlacklisted reports whether jti has been revoked. It is kept for
// callers without a request context; request paths use isJTIBlacklistedCtx.
func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool {
	return m.isJTIBlacklistedCtx(context.Background(), jti)
}

// isJTIBlacklistedCtx reports whether jti has been revoked, checking the L1
// cache first and then the shared cache manager (if configured) using ctx so
// request cancellation/deadlines propagate. An empty jti is never revoked.
func (m *AuthMiddleware) isJTIBlacklistedCtx(ctx context.Context, jti string) bool {
	if jti == "" {
		return false
	}
	key := m.blacklistKey(jti)
	if _, ok := m.l1Cache.Get(key); ok {
		return true
	}
	if m.cacheManager == nil {
		return false
	}
	_, ok := m.cacheManager.Get(ctx, key)
	return ok
}

// loadUserRolesAndPerms returns the role codes and permission codes for
// userID, served from the L1 cache when present.
//
// Permissions are collected from the directly assigned roles and all of their
// ancestor roles. Role codes, however, are resolved from the directly
// assigned roles only. Returns (nil, nil) when any repository is unset, the
// user has no roles, or the role lookup fails.
//
// Only complete results are cached (for 30 minutes, or until
// InvalidateUserPermCache). If an ancestor or permission lookup fails, the
// partial result is returned but not cached, so a transient DB error does not
// strip permissions for the whole TTL.
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
	if m.userRoleRepo == nil || m.roleRepo == nil || m.rolePermissionRepo == nil || m.permissionRepo == nil {
		return nil, nil
	}

	cacheKey := m.userPermCacheKey(userID)
	if cached, ok := m.l1Cache.Get(cacheKey); ok {
		if entry, ok := cached.(userPermEntry); ok {
			return entry.roles, entry.perms
		}
	}

	roleIDs, err := m.userRoleRepo.GetRoleIDsByUserID(ctx, userID)
	if err != nil || len(roleIDs) == 0 {
		return nil, nil
	}

	// Gather directly assigned roles plus every ancestor role, so inherited
	// permissions are included. A failed ancestor lookup degrades to the
	// roles gathered so far but marks the result as not cacheable.
	cacheable := true
	allRoleIDs := make([]int64, 0, len(roleIDs)*2)
	allRoleIDs = append(allRoleIDs, roleIDs...)
	for _, roleID := range roleIDs {
		ancestorIDs, err := m.roleRepo.GetAncestorIDs(ctx, roleID)
		if err != nil {
			cacheable = false
			continue
		}
		allRoleIDs = append(allRoleIDs, ancestorIDs...)
	}

	// Deduplicate while preserving first-seen order.
	seen := make(map[int64]struct{}, len(allRoleIDs))
	uniqueRoleIDs := make([]int64, 0, len(allRoleIDs))
	for _, id := range allRoleIDs {
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		uniqueRoleIDs = append(uniqueRoleIDs, id)
	}

	// Role codes intentionally come from the directly assigned roles only;
	// ancestors contribute permissions, not role codes.
	roles, err := m.roleRepo.GetByIDs(ctx, roleIDs)
	if err != nil {
		return nil, nil
	}
	roleCodes := make([]string, 0, len(roles))
	for _, role := range roles {
		roleCodes = append(roleCodes, role.Code)
	}

	permissionIDs, err := m.rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, uniqueRoleIDs)
	if err != nil {
		// Not cached: a transient failure must not pin an empty permission
		// set for the full TTL.
		return roleCodes, []string{}
	}

	permCodes := []string{}
	if len(permissionIDs) > 0 {
		permissions, err := m.permissionRepo.GetByIDs(ctx, permissionIDs)
		if err != nil {
			return roleCodes, nil
		}
		permCodes = make([]string, 0, len(permissions))
		for _, permission := range permissions {
			permCodes = append(permCodes, permission.Code)
		}
	}

	if cacheable {
		// PERF-01: a longer TTL reduces DB queries; role changes must call
		// InvalidateUserPermCache to take effect sooner.
		m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 30*time.Minute)
	}
	return roleCodes, permCodes
}

// InvalidateUserPermCache drops userID's cached roles/permissions from this
// process's L1 cache so the next request reloads them from the repositories.
func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) {
	m.l1Cache.Delete(m.userPermCacheKey(userID))
}

// AddToBlacklist marks jti as revoked for ttl. Empty IDs and non-positive
// TTLs are ignored.
//
// NOTE(review): this writes only to the process-local L1 cache, while the
// revocation check also reads the shared cacheManager. Unless the shared
// cache is populated elsewhere, revocations recorded here are visible only
// to this process — confirm for multi-instance deployments.
func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) {
	if jti != "" && ttl > 0 {
		m.l1Cache.Set(m.blacklistKey(jti), true, ttl)
	}
}

// isUserActive reports whether userID refers to an account whose status is
// domain.UserStatusActive. Any lookup error is treated as inactive; a nil
// userRepo disables the check (always active).
func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool {
	if m.userRepo == nil {
		return true
	}
	user, err := m.userRepo.GetByID(ctx, userID)
	if err != nil {
		return false
	}
	return user.Status == domain.UserStatusActive
}

// extractToken returns the credential from an "Authorization: Bearer <token>"
// header, or "" if the header is missing, malformed, or uses another scheme.
// The scheme name is matched case-insensitively (RFC 7235 §2.1) and
// surrounding whitespace around the token is trimmed.
func (m *AuthMiddleware) extractToken(c *gin.Context) string {
	authHeader := c.GetHeader("Authorization")
	if authHeader == "" {
		return ""
	}
	parts := strings.SplitN(authHeader, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
		return ""
	}
	return strings.TrimSpace(parts[1])
}

// userPermEntry is the L1 cache value holding a user's resolved role codes
// and permission codes.
type userPermEntry struct {
	roles []string
	perms []string
}