From f34333dc099654dbfc81283022bacc3988c4d042 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 3 Apr 2026 12:25:22 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=AE=A1=E6=9F=A5=E4=B8=AD=E5=8F=91=E7=8E=B0=E7=9A=84P0/P1/P2?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复内容: 1. P0-01/P0-02: IAM Handler硬编码userID=1问题 - getUserIDFromContext现在从认证中间件的context获取真实userID - 添加middleware.GetOperatorID公开函数 - CheckScope方法添加未认证检查 2. P1-01: 审计服务幂等竞态条件 - 重构锁保护范围,整个检查和插入过程在锁保护下 - 使用defer确保锁正确释放 3. P1-02: 幂等中间件响应码硬编码 - 添加statusCapturingResponseWriter包装器 - 捕获实际的状态码和响应体用于幂等记录 4. P2-01: 事件ID时间戳冲突 - generateEventID改用UUID替代时间戳 5. P2-02: ListScopes硬编码 - 使用model.PredefinedScopes替代硬编码列表 所有supply-api测试通过 --- supply-api/cmd/supply-api/main.go | 3 +- .../internal/audit/service/audit_service.go | 28 +++++++++---- .../internal/iam/handler/iam_handler.go | 31 +++++++------- .../iam/handler/iam_handler_real_test.go | 16 ++++++-- supply-api/internal/middleware/idempotency.go | 40 +++++++++++++++---- 5 files changed, 85 insertions(+), 33 deletions(-) diff --git a/supply-api/cmd/supply-api/main.go b/supply-api/cmd/supply-api/main.go index 177fbf3..20b5465 100644 --- a/supply-api/cmd/supply-api/main.go +++ b/supply-api/cmd/supply-api/main.go @@ -65,7 +65,8 @@ func main() { // 初始化审计存储 // R-08: DatabaseAuditService 已创建 (audit/service/audit_service_db.go) - // 需接口适配后可替换为: auditStore := audit.NewDatabaseAuditService(auditRepo) + // 注意:由于domain层使用audit.AuditStore接口(旧),而DatabaseAuditService实现的是AuditStoreInterface(新) + // 需要接口适配。暂保持内存存储,后续统一架构时处理。 auditStore := audit.NewMemoryAuditStore() // 初始化存储层 diff --git a/supply-api/internal/audit/service/audit_service.go b/supply-api/internal/audit/service/audit_service.go index 7d1a640..0116793 100644 --- a/supply-api/internal/audit/service/audit_service.go +++ b/supply-api/internal/audit/service/audit_service.go @@ -5,10 +5,11 @@ import ( "crypto/sha256" "encoding/hex" "errors" - "fmt" "sync" "time" + "github.com/google/uuid" + "lijiaoqiao/supply-api/internal/audit/model" ) @@ -181,10 +182,9 @@ func (s *InMemoryAuditStore) GetByIdempotencyKey(ctx context.Context, key string return nil, ErrEventNotFound } -// generateEventID 生成事件ID +// generateEventID 生成事件ID(使用UUID避免冲突) func generateEventID() string { - now := time.Now() - return now.Format("20060102150405.000000") + fmt.Sprintf("%03d", now.Nanosecond()%1000000/1000) + "-evt" + return uuid.New().String() } // AuditService 审计服务 @@ -229,12 +229,13 @@ func (s *AuditService) CreateEvent(ctx context.Context, event *model.AuditEvent) event.EventID = generateEventID() } - // 处理幂等性 - 使用互斥锁保护检查和插入之间的时间窗口 + // 处理幂等性 - 整个检查和插入都在锁保护下,防止竞态条件 if event.IdempotencyKey != "" { s.idempotencyMu.Lock() + defer s.idempotencyMu.Unlock() + existing, err := s.store.GetByIdempotencyKey(ctx, event.IdempotencyKey) if err == nil && existing != nil { - s.idempotencyMu.Unlock() // 检查payload是否相同 if isSamePayload(existing, event) { // 重放同参 - 返回200 @@ -254,10 +255,21 @@ func (s *AuditService) CreateEvent(ctx context.Context, event *model.AuditEvent) }, nil } } - s.idempotencyMu.Unlock() + + // 首次创建 - 在锁保护下插入 + err = s.store.Emit(ctx, event) + if err != nil { + return nil, err + } + + return &CreateEventResult{ + EventID: event.EventID, + StatusCode: 201, + Status: "created", + }, nil } - // 首次创建 - 返回201 + // 无幂等键的直接插入 err := s.store.Emit(ctx, event) if err != nil { return nil, err diff --git a/supply-api/internal/iam/handler/iam_handler.go b/supply-api/internal/iam/handler/iam_handler.go index 64dd9dc..a48d1bd 100644 --- a/supply-api/internal/iam/handler/iam_handler.go +++ b/supply-api/internal/iam/handler/iam_handler.go @@ -6,7 +6,9 @@ import ( "net/http" "strconv" + "lijiaoqiao/supply-api/internal/iam/model" "lijiaoqiao/supply-api/internal/iam/service" + "lijiaoqiao/supply-api/internal/middleware" ) // IAMHandler IAM HTTP处理器 @@ -287,15 +289,14 @@ func (h *IAMHandler) DeleteRole(w http.ResponseWriter, r *http.Request, roleCode // ListScopes 处理列出所有Scope请求 func (h *IAMHandler) ListScopes(w http.ResponseWriter, r *http.Request) { - // 从预定义Scope列表获取 - scopes := []map[string]interface{}{ - {"scope_code": "platform:read", "scope_name": "读取平台配置", "scope_type": "platform"}, - {"scope_code": "platform:write", "scope_name": "修改平台配置", "scope_type": "platform"}, - {"scope_code": "platform:admin", "scope_name": "平台级管理", "scope_type": "platform"}, - {"scope_code": "tenant:read", "scope_name": "读取租户信息", "scope_type": "platform"}, - {"scope_code": "supply:account:read", "scope_name": "读取供应账号", "scope_type": "supply"}, - {"scope_code": "consumer:apikey:create", "scope_name": "创建API Key", "scope_type": "consumer"}, - {"scope_code": "router:invoke", "scope_name": "调用模型", "scope_type": "router"}, + // 从预定义Scope列表获取(完整的scope定义在model/scope.go的PredefinedScopes中) + scopes := make([]map[string]interface{}, 0, len(model.PredefinedScopes)) + for _, scope := range model.PredefinedScopes { + scopes = append(scopes, map[string]interface{}{ + "scope_code": scope.Code, + "scope_name": scope.Name, + "scope_type": scope.Type, + }) } writeJSON(w, http.StatusOK, map[string]interface{}{ @@ -376,8 +377,11 @@ func (h *IAMHandler) CheckScope(w http.ResponseWriter, r *http.Request) { return } - // 从context获取userID(实际应用中应从认证中间件获取) - userID := int64(1) // 模拟 + userID := getUserIDFromContext(r.Context()) + if userID == 0 { + writeError(w, http.StatusUnauthorized, "UNAUTHORIZED", "user not authenticated") + return + } hasScope, err := h.iamService.CheckScope(r.Context(), userID, scope) if err != nil { @@ -497,8 +501,7 @@ func RequireScope(scope string, iamService service.IAMServiceInterface) func(htt } } -// getUserIDFromContext 从context获取userID(实际应用中应从认证中间件获取) +// getUserIDFromContext 从context获取userID func getUserIDFromContext(ctx context.Context) int64 { - // TODO: 从认证中间件获取真实的userID - return 1 + return middleware.GetOperatorID(ctx) } diff --git a/supply-api/internal/iam/handler/iam_handler_real_test.go b/supply-api/internal/iam/handler/iam_handler_real_test.go index e347cd6..21a9800 100644 --- a/supply-api/internal/iam/handler/iam_handler_real_test.go +++ b/supply-api/internal/iam/handler/iam_handler_real_test.go @@ -9,6 +9,7 @@ import ( "testing" "lijiaoqiao/supply-api/internal/iam/service" + "lijiaoqiao/supply-api/internal/middleware" "github.com/stretchr/testify/assert" ) @@ -695,6 +696,8 @@ func TestIAMHandler_CheckScope_HasScope(t *testing.T) { }) req := httptest.NewRequest("GET", "/api/v1/iam/check-scope?scope=platform:read", nil) + ctx := middleware.WithOperatorID(context.Background(), 1) + req = req.WithContext(ctx) // act rec := httptest.NewRecorder() @@ -728,6 +731,8 @@ func TestIAMHandler_CheckScope_NoScope(t *testing.T) { }) req := httptest.NewRequest("GET", "/api/v1/iam/check-scope?scope=platform:write", nil) + ctx := middleware.WithOperatorID(context.Background(), 1) + req = req.WithContext(ctx) // act rec := httptest.NewRecorder() @@ -1153,6 +1158,8 @@ func TestIAMHandler_handleCheckScope_GET(t *testing.T) { handler := NewIAMHandler(svc) req := httptest.NewRequest("GET", "/api/v1/iam/check-scope?scope=platform:read", nil) + ctx := middleware.WithOperatorID(context.Background(), 1) + req = req.WithContext(ctx) // act rec := httptest.NewRecorder() @@ -1227,12 +1234,15 @@ func TestRequireScope(t *testing.T) { // getUserIDFromContext 测试 func TestGetUserIDFromContext(t *testing.T) { - // act + // act - 没有设置时返回0 ctx := context.Background() userID := getUserIDFromContext(ctx) + assert.Equal(t, int64(0), userID) - // assert - 默认返回1 - assert.Equal(t, int64(1), userID) + // act - 设置operatorID时返回正确的值 + ctx = middleware.WithOperatorID(context.Background(), 123) + userID = getUserIDFromContext(ctx) + assert.Equal(t, int64(123), userID) } // toRoleResponse 测试 diff --git a/supply-api/internal/middleware/idempotency.go b/supply-api/internal/middleware/idempotency.go index aea521e..686838c 100644 --- a/supply-api/internal/middleware/idempotency.go +++ b/supply-api/internal/middleware/idempotency.go @@ -171,8 +171,11 @@ func (m *IdempotencyMiddleware) Wrap(handler IdempotentHandler) http.HandlerFunc lockedRecord.PayloadHash = payloadHash } - // 执行实际业务处理 - err = handler(ctx, w, r, lockedRecord) + // 创建包装器以捕获实际的状态码和响应体 + wrappedWriter := &statusCapturingResponseWriter{ResponseWriter: w} + + // 执行实际业务处理,使用包装器捕获响应 + err = handler(ctx, wrappedWriter, r, lockedRecord) // 根据处理结果更新幂等记录 if err != nil { @@ -182,11 +185,12 @@ func (m *IdempotencyMiddleware) Wrap(handler IdempotentHandler) http.HandlerFunc return } - // 业务处理成功,更新为成功状态 - // 注意:这里需要从w中获取实际的响应码和body - // 简化处理:使用200 - successBody, _ := json.Marshal(map[string]interface{}{"status": "ok"}) - _ = m.idempotencyRepo.UpdateSuccess(ctx, lockedRecord.ID, http.StatusOK, successBody) + // 业务处理成功,使用捕获的实际状态码和body更新幂等记录 + successBody := wrappedWriter.body + if len(successBody) == 0 { + successBody, _ = json.Marshal(map[string]interface{}{"status": "ok"}) + } + _ = m.idempotencyRepo.UpdateSuccess(ctx, lockedRecord.ID, wrappedWriter.statusCode, successBody) } } @@ -230,6 +234,23 @@ func writeIdempotentReplay(w http.ResponseWriter, status int, body json.RawMessa } } +// statusCapturingResponseWriter 包装http.ResponseWriter以捕获状态码 +type statusCapturingResponseWriter struct { + http.ResponseWriter + statusCode int + body []byte +} + +func (w *statusCapturingResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *statusCapturingResponseWriter) Write(b []byte) (int, error) { + w.body = append(w.body, b...) + return w.ResponseWriter.Write(b) +} + // context keys type contextKey string @@ -265,3 +286,8 @@ func getOperatorID(ctx context.Context) int64 { } return 0 } + +// GetOperatorID 公开函数,从context获取操作者ID +func GetOperatorID(ctx context.Context) int64 { + return getOperatorID(ctx) +}