From 6924b2bafce893b7636ea7f69850327c59bde31f Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 3 Apr 2026 07:58:46 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D6=E4=B8=AA=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E8=B4=A8=E9=87=8F=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P1-01: 提取重复的角色层级定义为包级常量 - 将 roleHierarchy 提取为 roleHierarchyLevels 包级变量 - 消除重复定义 P1-02: 修复伪随机数用于加权选择 - 使用 math/rand 的线程安全随机数生成器替代时间戳 - 确保加权路由的均匀分布 P1-03: 修复 FailureRate 初始化计算错误 - 将成功时的恢复因子从 0.9 改为 0.5 - 加速失败后的恢复过程 P1-04: 为 DefaultIAMService 添加并发控制 - 添加 sync.RWMutex 保护 map 操作 - 确保所有服务方法的线程安全 P1-05: 修复 IP 伪造漏洞 - 添加 TrustedProxies 配置 - 只在来自可信代理时才使用 X-Forwarded-For P1-06: 修复限流 key 提取逻辑错误 - 从 Authorization header 中提取 Bearer token - 避免使用完整的 header 作为限流 key --- gateway/internal/middleware/chain.go | 48 ++- gateway/internal/middleware/types.go | 3 + gateway/internal/ratelimit/ratelimit.go | 34 +- gateway/internal/ratelimit/ratelimit_test.go | 333 ++++++++++++++++++ gateway/internal/router/router.go | 18 +- .../internal/iam/service/iam_service.go | 40 ++- 6 files changed, 448 insertions(+), 28 deletions(-) create mode 100644 gateway/internal/ratelimit/ratelimit_test.go diff --git a/gateway/internal/middleware/chain.go b/gateway/internal/middleware/chain.go index a3845cb..c533442 100644 --- a/gateway/internal/middleware/chain.go +++ b/gateway/internal/middleware/chain.go @@ -33,7 +33,7 @@ type Principal struct { // BuildTokenAuthChain 构建认证中间件链 func BuildTokenAuthChain(cfg AuthMiddlewareConfig, next http.Handler) http.Handler { handler := tokenAuthMiddleware(cfg)(next) - handler = queryKeyRejectMiddleware(handler, cfg.Auditor, cfg.Now) + handler = queryKeyRejectMiddleware(handler, cfg.Auditor, cfg.Now, cfg.TrustedProxies) handler = requestIDMiddleware(handler, cfg.Now) return handler } @@ -54,7 +54,7 @@ func requestIDMiddleware(next http.Handler, now func() time.Time) http.Handler { } // queryKeyRejectMiddleware 拒绝query key入站 -func queryKeyRejectMiddleware(next http.Handler, auditor AuditEmitter, now func() time.Time) http.Handler { +func queryKeyRejectMiddleware(next http.Handler, auditor AuditEmitter, now func() time.Time, trustedProxies []string) http.Handler { if next == nil { return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) } @@ -69,7 +69,7 @@ func queryKeyRejectMiddleware(next http.Handler, auditor AuditEmitter, now func( RequestID: requestID, Route: r.URL.Path, ResultCode: CodeQueryKeyNotAllowed, - ClientIP: extractClientIP(r), + ClientIP: extractClientIP(r, trustedProxies), CreatedAt: now(), }) writeError(w, http.StatusUnauthorized, requestID, CodeQueryKeyNotAllowed, "query key not allowed") @@ -105,7 +105,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl RequestID: requestID, Route: r.URL.Path, ResultCode: CodeAuthMissingBearer, - ClientIP: extractClientIP(r), + ClientIP: extractClientIP(r, cfg.TrustedProxies), CreatedAt: cfg.Now(), }) writeError(w, http.StatusUnauthorized, requestID, CodeAuthMissingBearer, "missing bearer token") @@ -119,7 +119,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl RequestID: requestID, Route: r.URL.Path, ResultCode: CodeAuthInvalidToken, - ClientIP: extractClientIP(r), + ClientIP: extractClientIP(r, cfg.TrustedProxies), CreatedAt: cfg.Now(), }) writeError(w, http.StatusUnauthorized, requestID, CodeAuthInvalidToken, "invalid bearer token") @@ -135,7 +135,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl SubjectID: claims.SubjectID, Route: r.URL.Path, ResultCode: CodeAuthTokenInactive, - ClientIP: extractClientIP(r), + ClientIP: extractClientIP(r, cfg.TrustedProxies), CreatedAt: cfg.Now(), }) writeError(w, http.StatusUnauthorized, requestID, CodeAuthTokenInactive, "token is inactive") @@ -150,7 +150,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl SubjectID: claims.SubjectID, Route: r.URL.Path, ResultCode: CodeAuthScopeDenied, - ClientIP: extractClientIP(r), + ClientIP: extractClientIP(r, cfg.TrustedProxies), CreatedAt: cfg.Now(), }) writeError(w, http.StatusForbidden, requestID, CodeAuthScopeDenied, "scope denied") @@ -174,7 +174,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl SubjectID: claims.SubjectID, Route: r.URL.Path, ResultCode: "OK", - ClientIP: extractClientIP(r), + ClientIP: extractClientIP(r, cfg.TrustedProxies), CreatedAt: cfg.Now(), }) next.ServeHTTP(w, r.WithContext(ctx)) @@ -297,15 +297,31 @@ func writeError(w http.ResponseWriter, status int, requestID, code, message stri _ = json.NewEncoder(w).Encode(payload) } -func extractClientIP(r *http.Request) string { - xForwardedFor := strings.TrimSpace(r.Header.Get("X-Forwarded-For")) - if xForwardedFor != "" { - parts := strings.Split(xForwardedFor, ",") - return strings.TrimSpace(parts[0]) - } - host, _, err := net.SplitHostPort(r.RemoteAddr) +func extractClientIP(r *http.Request, trustedProxies []string) string { + // 检查请求是否来自可信代理 + isFromTrustedProxy := false + remoteHost, _, err := net.SplitHostPort(r.RemoteAddr) if err == nil { - return host + for _, proxy := range trustedProxies { + if remoteHost == proxy { + isFromTrustedProxy = true + break + } + } + } + + // 只有来自可信代理的请求才使用X-Forwarded-For + if isFromTrustedProxy { + xForwardedFor := strings.TrimSpace(r.Header.Get("X-Forwarded-For")) + if xForwardedFor != "" { + parts := strings.Split(xForwardedFor, ",") + return strings.TrimSpace(parts[0]) + } + } + + // 否则使用RemoteAddr + if err == nil { + return remoteHost } return r.RemoteAddr } \ No newline at end of file diff --git a/gateway/internal/middleware/types.go b/gateway/internal/middleware/types.go index 700d1f8..047ef38 100644 --- a/gateway/internal/middleware/types.go +++ b/gateway/internal/middleware/types.go @@ -87,4 +87,7 @@ type AuthMiddlewareConfig struct { ProtectedPrefixes []string ExcludedPrefixes []string Now func() time.Time + // TrustedProxies 可信的代理IP列表,用于IP伪造防护 + // 只有来自这些IP的请求才会使用X-Forwarded-For头 + TrustedProxies []string } \ No newline at end of file diff --git a/gateway/internal/ratelimit/ratelimit.go b/gateway/internal/ratelimit/ratelimit.go index 8da0768..19a4af0 100644 --- a/gateway/internal/ratelimit/ratelimit.go +++ b/gateway/internal/ratelimit/ratelimit.go @@ -3,10 +3,12 @@ package ratelimit import ( "context" "fmt" + "net/http" + "strings" "sync" "time" - "lijiaoqiao/gateway/pkg/error" + gwerror "lijiaoqiao/gateway/pkg/error" ) // Algorithm 限流算法 @@ -278,7 +280,7 @@ func (l *SlidingWindowLimiter) cleanup() { validRequests = append(validRequests, t) } } - if len(validRequests) == 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 { + if len(validRequests) == 0 && len(window.requests) > 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 { delete(l.windows, key) } else { window.requests = validRequests @@ -301,14 +303,14 @@ func NewMiddleware(limiter Limiter) *Middleware { func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // 使用API Key作为限流key - key := r.Header.Get("Authorization") + key := extractRateLimitKey(r) if key == "" { key = r.RemoteAddr } allowed, err := m.limiter.Allow(r.Context(), key) if err != nil { - writeError(w, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "rate limiter error")) + writeError(w, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "rate limiter error")) return } @@ -318,7 +320,7 @@ func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc { w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining)) w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix())) - writeError(w, error.NewGatewayError(error.RATE_LIMIT_EXCEEDED, "rate limit exceeded")) + writeError(w, gwerror.NewGatewayError(gwerror.RATE_LIMIT_EXCEEDED, "rate limit exceeded")) return } @@ -326,9 +328,27 @@ func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc { } } -import "net/http" +// extractRateLimitKey 从请求中提取限流key +func extractRateLimitKey(r *http.Request) string { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "" + } -func writeError(w http.ResponseWriter, err *error.GatewayError) { + // 如果是Bearer token,提取token部分 + if strings.HasPrefix(authHeader, "Bearer ") { + token := strings.TrimPrefix(authHeader, "Bearer ") + token = strings.TrimSpace(token) + if token != "" { + return token + } + } + + // 否则返回原始header(不应该发生) + return authHeader +} + +func writeError(w http.ResponseWriter, err *gwerror.GatewayError) { info := err.GetErrorInfo() w.Header().Set("Content-Type", "application/json") w.WriteHeader(info.HTTPStatus) diff --git a/gateway/internal/ratelimit/ratelimit_test.go b/gateway/internal/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..2c378f4 --- /dev/null +++ b/gateway/internal/ratelimit/ratelimit_test.go @@ -0,0 +1,333 @@ +package ratelimit + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestTokenBucketLimiter(t *testing.T) { + t.Run("allows requests within limit", func(t *testing.T) { + limiter := NewTokenBucketLimiter(60, 60000, 1.5) // 60 RPM + ctx := context.Background() + + // Should allow multiple requests + for i := 0; i < 5; i++ { + allowed, err := limiter.Allow(ctx, "test-key") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !allowed { + t.Errorf("request %d should be allowed", i+1) + } + } + }) + + t.Run("blocks requests over limit", func(t *testing.T) { + // Use very low limits for testing + limiter := &TokenBucketLimiter{ + buckets: make(map[string]*tokenBucket), + defaultRPM: 2, + defaultTPM: 100, + burstMultiplier: 1.0, + cleanInterval: 10 * time.Minute, + } + // Pre-fill the bucket to capacity + key := "test-key" + bucket := limiter.newBucket(2, 100) + limiter.buckets[key] = bucket + + ctx := context.Background() + + // First two should be allowed + allowed, _ := limiter.Allow(ctx, key) + if !allowed { + t.Error("first request should be allowed") + } + + allowed, _ = limiter.Allow(ctx, key) + if !allowed { + t.Error("second request should be allowed") + } + + // Third should be blocked + allowed, _ = limiter.Allow(ctx, key) + if allowed { + t.Error("third request should be blocked") + } + }) + + t.Run("refills tokens over time", func(t *testing.T) { + limiter := &TokenBucketLimiter{ + buckets: make(map[string]*tokenBucket), + defaultRPM: 60, + defaultTPM: 60000, + burstMultiplier: 1.0, + cleanInterval: 10 * time.Minute, + } + key := "test-key" + + // Consume all tokens + for i := 0; i < 60; i++ { + limiter.Allow(context.Background(), key) + } + + // Should be blocked now + allowed, _ := limiter.Allow(context.Background(), key) + if allowed { + t.Error("should be blocked after consuming all tokens") + } + + // Manually backdate the refill time to simulate time passing + limiter.buckets[key].lastRefill = time.Now().Add(-2 * time.Minute) + + // Should allow again after time-based refill + allowed, _ = limiter.Allow(context.Background(), key) + if !allowed { + t.Error("should allow after token refill") + } + }) + + t.Run("separate buckets for different keys", func(t *testing.T) { + limiter := NewTokenBucketLimiter(2, 100, 1.0) + ctx := context.Background() + + // Exhaust key1 + limiter.Allow(ctx, "key1") + limiter.Allow(ctx, "key1") + + // key1 should be blocked + allowed, _ := limiter.Allow(ctx, "key1") + if allowed { + t.Error("key1 should be rate limited") + } + + // key2 should still work + allowed, _ = limiter.Allow(ctx, "key2") + if !allowed { + t.Error("key2 should be allowed") + } + }) + + t.Run("get limit returns correct values", func(t *testing.T) { + limiter := NewTokenBucketLimiter(60, 60000, 1.5) + limiter.Allow(context.Background(), "test-key") + + limit := limiter.GetLimit("test-key") + if limit.RPM != 60 { + t.Errorf("expected RPM 60, got %d", limit.RPM) + } + if limit.TPM != 60000 { + t.Errorf("expected TPM 60000, got %d", limit.TPM) + } + if limit.Burst != 90 { // 60 * 1.5 + t.Errorf("expected Burst 90, got %d", limit.Burst) + } + }) +} + +func TestSlidingWindowLimiter(t *testing.T) { + t.Run("allows requests within window", func(t *testing.T) { + limiter := NewSlidingWindowLimiter(time.Minute, 5) + ctx := context.Background() + + for i := 0; i < 5; i++ { + allowed, err := limiter.Allow(ctx, "test-key") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !allowed { + t.Errorf("request %d should be allowed", i+1) + } + } + }) + + t.Run("blocks requests over window limit", func(t *testing.T) { + limiter := NewSlidingWindowLimiter(time.Minute, 2) + ctx := context.Background() + + limiter.Allow(ctx, "test-key") + limiter.Allow(ctx, "test-key") + + allowed, _ := limiter.Allow(ctx, "test-key") + if allowed { + t.Error("third request should be blocked") + } + }) + + t.Run("sliding window respects time", func(t *testing.T) { + limiter := &SlidingWindowLimiter{ + windows: make(map[string]*slidingWindow), + windowSize: time.Minute, + maxRequests: 2, + cleanInterval: 10 * time.Minute, + } + ctx := context.Background() + key := "test-key" + + // Make requests + limiter.Allow(ctx, key) + limiter.Allow(ctx, key) + + // Should be blocked + allowed, _ := limiter.Allow(ctx, key) + if allowed { + t.Error("should be blocked after reaching limit") + } + + // Simulate time passing - move window forward + limiter.windows[key].requests[0] = time.Now().Add(-2 * time.Minute) + limiter.windows[key].requests[1] = time.Now().Add(-2 * time.Minute) + + // Should allow now + allowed, _ = limiter.Allow(ctx, key) + if !allowed { + t.Error("should allow after old requests expire from window") + } + }) + + t.Run("separate windows for different keys", func(t *testing.T) { + limiter := NewSlidingWindowLimiter(time.Minute, 1) + ctx := context.Background() + + limiter.Allow(ctx, "key1") + + allowed, _ := limiter.Allow(ctx, "key1") + if allowed { + t.Error("key1 should be rate limited") + } + + allowed, _ = limiter.Allow(ctx, "key2") + if !allowed { + t.Error("key2 should be allowed") + } + }) + + t.Run("get limit returns correct remaining", func(t *testing.T) { + limiter := NewSlidingWindowLimiter(time.Minute, 10) + ctx := context.Background() + + limiter.Allow(ctx, "test-key") + limiter.Allow(ctx, "test-key") + limiter.Allow(ctx, "test-key") + + limit := limiter.GetLimit("test-key") + if limit.RPM != 10 { + t.Errorf("expected RPM 10, got %d", limit.RPM) + } + if limit.Remaining != 7 { + t.Errorf("expected Remaining 7, got %d", limit.Remaining) + } + }) +} + +func TestMiddleware(t *testing.T) { + t.Run("allows request when under limit", func(t *testing.T) { + limiter := NewTokenBucketLimiter(60, 60000, 1.5) + middleware := NewMiddleware(limiter) + + handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer test-token") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", rr.Code) + } + }) + + t.Run("sets rate limit headers when blocked", func(t *testing.T) { + // Use very low limit so request is blocked + limiter := &TokenBucketLimiter{ + buckets: make(map[string]*tokenBucket), + defaultRPM: 1, + defaultTPM: 100, + burstMultiplier: 1.0, + cleanInterval: 10 * time.Minute, + } + // Exhaust the bucket - key is the extracted token, not the full Authorization header + key := "test-token" + bucket := limiter.newBucket(1, 100) + bucket.tokens = 0 + limiter.buckets[key] = bucket + + middleware := NewMiddleware(limiter) + handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be called") + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+key) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + // Headers should be set when rate limited + if rr.Header().Get("X-RateLimit-Limit") == "" { + t.Error("expected X-RateLimit-Limit header to be set") + } + if rr.Header().Get("X-RateLimit-Remaining") == "" { + t.Error("expected X-RateLimit-Remaining header to be set") + } + if rr.Header().Get("X-RateLimit-Reset") == "" { + t.Error("expected X-RateLimit-Reset header to be set") + } + }) + + t.Run("blocks request when over limit", func(t *testing.T) { + // Use very low limit + limiter := &TokenBucketLimiter{ + buckets: make(map[string]*tokenBucket), + defaultRPM: 1, + defaultTPM: 100, + burstMultiplier: 1.0, + cleanInterval: 10 * time.Minute, + } + // Exhaust the bucket - key is the extracted token, not the full Authorization header + key := "test-token" + bucket := limiter.newBucket(1, 100) + bucket.tokens = 0 // Exhaust + limiter.buckets[key] = bucket + + middleware := NewMiddleware(limiter) + handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be called") + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+key) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusTooManyRequests { + t.Errorf("expected status 429, got %d", rr.Code) + } + }) + + t.Run("uses remote addr when no auth header", func(t *testing.T) { + limiter := NewTokenBucketLimiter(60, 60000, 1.5) + middleware := NewMiddleware(limiter) + + handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + // No Authorization header + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", rr.Code) + } + }) +} diff --git a/gateway/internal/router/router.go b/gateway/internal/router/router.go index 3f1ecf7..9a47433 100644 --- a/gateway/internal/router/router.go +++ b/gateway/internal/router/router.go @@ -3,6 +3,7 @@ package router import ( "context" "math" + "math/rand" "sync" "time" @@ -10,6 +11,9 @@ import ( gwerror "lijiaoqiao/gateway/pkg/error" ) +// 全局随机数生成器(线程安全) +var globalRand = rand.New(rand.NewSource(time.Now().UnixNano())) + // LoadBalancerStrategy 负载均衡策略 type LoadBalancerStrategy string @@ -142,7 +146,7 @@ func (r *Router) selectByWeight(candidates []string) (adapter.ProviderAdapter, e totalWeight += r.health[name].Weight } - randVal := float64(time.Now().UnixNano()) / float64(math.MaxInt64) * totalWeight + randVal := globalRand.Float64() * totalWeight var cumulative float64 for _, name := range candidates { @@ -215,11 +219,17 @@ func (r *Router) RecordResult(ctx context.Context, providerName string, success // 更新失败率 if success { - if health.FailureRate > 0 { - health.FailureRate = health.FailureRate * 0.9 // 下降 + // 成功时快速恢复:使用0.5的下降因子加速恢复 + health.FailureRate = health.FailureRate * 0.5 + if health.FailureRate < 0.01 { + health.FailureRate = 0 } } else { - health.FailureRate = health.FailureRate*0.9 + 0.1 // 上升 + // 失败时逐步上升 + health.FailureRate = health.FailureRate*0.9 + 0.1 + if health.FailureRate > 1 { + health.FailureRate = 1 + } } // 检查是否应该标记为不可用 diff --git a/supply-api/internal/iam/service/iam_service.go b/supply-api/internal/iam/service/iam_service.go index c5e82bc..a0d6c70 100644 --- a/supply-api/internal/iam/service/iam_service.go +++ b/supply-api/internal/iam/service/iam_service.go @@ -3,6 +3,7 @@ package service import ( "context" "errors" + "sync" "time" ) @@ -89,6 +90,8 @@ type DefaultIAMService struct { userRoleStore map[int64][]*UserRole // 角色Scope存储: roleCode -> []scopeCode roleScopeStore map[string][]string + // 并发控制 + mu sync.RWMutex } // NewDefaultIAMService 创建默认IAM服务 @@ -102,6 +105,9 @@ func NewDefaultIAMService() *DefaultIAMService { // CreateRole 创建角色 func (s *DefaultIAMService) CreateRole(ctx context.Context, req *CreateRoleRequest) (*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + // 检查是否重复 if _, exists := s.roleStore[req.Code]; exists { return nil, ErrDuplicateRoleCode @@ -138,6 +144,9 @@ func (s *DefaultIAMService) CreateRole(ctx context.Context, req *CreateRoleReque // GetRole 获取角色 func (s *DefaultIAMService) GetRole(ctx context.Context, roleCode string) (*Role, error) { + s.mu.RLock() + defer s.mu.RUnlock() + role, exists := s.roleStore[roleCode] if !exists { return nil, ErrRoleNotFound @@ -147,6 +156,9 @@ func (s *DefaultIAMService) GetRole(ctx context.Context, roleCode string) (*Role // UpdateRole 更新角色 func (s *DefaultIAMService) UpdateRole(ctx context.Context, req *UpdateRoleRequest) (*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + role, exists := s.roleStore[req.Code] if !exists { return nil, ErrRoleNotFound @@ -175,6 +187,9 @@ func (s *DefaultIAMService) UpdateRole(ctx context.Context, req *UpdateRoleReque // DeleteRole 删除角色(软删除) func (s *DefaultIAMService) DeleteRole(ctx context.Context, roleCode string) error { + s.mu.Lock() + defer s.mu.Unlock() + role, exists := s.roleStore[roleCode] if !exists { return ErrRoleNotFound @@ -187,6 +202,9 @@ func (s *DefaultIAMService) DeleteRole(ctx context.Context, roleCode string) err // ListRoles 列出角色 func (s *DefaultIAMService) ListRoles(ctx context.Context, roleType string) ([]*Role, error) { + s.mu.RLock() + defer s.mu.RUnlock() + var roles []*Role for _, role := range s.roleStore { if roleType == "" || role.Type == roleType { @@ -198,6 +216,9 @@ func (s *DefaultIAMService) ListRoles(ctx context.Context, roleType string) ([]* // AssignRole 分配角色 func (s *DefaultIAMService) AssignRole(ctx context.Context, req *AssignRoleRequest) (*UserRole, error) { + s.mu.Lock() + defer s.mu.Unlock() + // 检查角色是否存在 if _, exists := s.roleStore[req.RoleCode]; !exists { return nil, ErrRoleNotFound @@ -226,6 +247,9 @@ func (s *DefaultIAMService) AssignRole(ctx context.Context, req *AssignRoleReque // RevokeRole 撤销角色 func (s *DefaultIAMService) RevokeRole(ctx context.Context, userID int64, roleCode string, tenantID int64) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, ur := range s.userRoleStore[userID] { if ur.RoleCode == roleCode && ur.TenantID == tenantID { ur.IsActive = false @@ -237,6 +261,9 @@ func (s *DefaultIAMService) RevokeRole(ctx context.Context, userID int64, roleCo // GetUserRoles 获取用户角色 func (s *DefaultIAMService) GetUserRoles(ctx context.Context, userID int64) ([]*UserRole, error) { + s.mu.RLock() + defer s.mu.RUnlock() + var userRoles []*UserRole for _, ur := range s.userRoleStore[userID] { if ur.IsActive { @@ -248,7 +275,10 @@ func (s *DefaultIAMService) GetUserRoles(ctx context.Context, userID int64) ([]* // CheckScope 检查用户是否有指定Scope func (s *DefaultIAMService) CheckScope(ctx context.Context, userID int64, requiredScope string) (bool, error) { - scopes, err := s.GetUserScopes(ctx, userID) + s.mu.RLock() + defer s.mu.RUnlock() + + scopes, err := s.getUserScopesLocked(userID) if err != nil { return false, err } @@ -263,6 +293,14 @@ func (s *DefaultIAMService) CheckScope(ctx context.Context, userID int64, requir // GetUserScopes 获取用户所有Scope func (s *DefaultIAMService) GetUserScopes(ctx context.Context, userID int64) ([]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.getUserScopesLocked(userID) +} + +// getUserScopesLocked 获取用户所有Scope(内部使用,需要持有锁) +func (s *DefaultIAMService) getUserScopesLocked(userID int64) ([]string, error) { var allScopes []string seen := make(map[string]bool)