fix: 修复代码审查中发现的P0/P1/P2问题
修复内容: 1. P0-01/P0-02: IAM Handler硬编码userID=1问题 - getUserIDFromContext现在从认证中间件的context获取真实userID - 添加middleware.GetOperatorID公开函数 - CheckScope方法添加未认证检查 2. P1-01: 审计服务幂等竞态条件 - 重构锁保护范围,整个检查和插入过程在锁保护下 - 使用defer确保锁正确释放 3. P1-02: 幂等中间件响应码硬编码 - 添加statusCapturingResponseWriter包装器 - 捕获实际的状态码和响应体用于幂等记录 4. P2-01: 事件ID时间戳冲突 - generateEventID改用UUID替代时间戳 5. P2-02: ListScopes硬编码 - 使用model.PredefinedScopes替代硬编码列表 所有supply-api测试通过
This commit is contained in:
@@ -65,7 +65,8 @@ func main() {
|
||||
|
||||
// 初始化审计存储
|
||||
// R-08: DatabaseAuditService 已创建 (audit/service/audit_service_db.go)
|
||||
// 需接口适配后可替换为: auditStore := audit.NewDatabaseAuditService(auditRepo)
|
||||
// 注意:由于domain层使用audit.AuditStore接口(旧),而DatabaseAuditService实现的是AuditStoreInterface(新)
|
||||
// 需要接口适配。暂保持内存存储,后续统一架构时处理。
|
||||
auditStore := audit.NewMemoryAuditStore()
|
||||
|
||||
// 初始化存储层
|
||||
|
||||
@@ -5,10 +5,11 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
@@ -181,10 +182,9 @@ func (s *InMemoryAuditStore) GetByIdempotencyKey(ctx context.Context, key string
|
||||
return nil, ErrEventNotFound
|
||||
}
|
||||
|
||||
// generateEventID 生成事件ID
|
||||
// generateEventID 生成事件ID(使用UUID避免冲突)
|
||||
func generateEventID() string {
|
||||
now := time.Now()
|
||||
return now.Format("20060102150405.000000") + fmt.Sprintf("%03d", now.Nanosecond()%1000000/1000) + "-evt"
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// AuditService 审计服务
|
||||
@@ -229,12 +229,13 @@ func (s *AuditService) CreateEvent(ctx context.Context, event *model.AuditEvent)
|
||||
event.EventID = generateEventID()
|
||||
}
|
||||
|
||||
// 处理幂等性 - 使用互斥锁保护检查和插入之间的时间窗口
|
||||
// 处理幂等性 - 整个检查和插入都在锁保护下,防止竞态条件
|
||||
if event.IdempotencyKey != "" {
|
||||
s.idempotencyMu.Lock()
|
||||
defer s.idempotencyMu.Unlock()
|
||||
|
||||
existing, err := s.store.GetByIdempotencyKey(ctx, event.IdempotencyKey)
|
||||
if err == nil && existing != nil {
|
||||
s.idempotencyMu.Unlock()
|
||||
// 检查payload是否相同
|
||||
if isSamePayload(existing, event) {
|
||||
// 重放同参 - 返回200
|
||||
@@ -254,10 +255,21 @@ func (s *AuditService) CreateEvent(ctx context.Context, event *model.AuditEvent)
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
s.idempotencyMu.Unlock()
|
||||
|
||||
// 首次创建 - 在锁保护下插入
|
||||
err = s.store.Emit(ctx, event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateEventResult{
|
||||
EventID: event.EventID,
|
||||
StatusCode: 201,
|
||||
Status: "created",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 首次创建 - 返回201
|
||||
// 无幂等键的直接插入
|
||||
err := s.store.Emit(ctx, event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/iam/model"
|
||||
"lijiaoqiao/supply-api/internal/iam/service"
|
||||
"lijiaoqiao/supply-api/internal/middleware"
|
||||
)
|
||||
|
||||
// IAMHandler IAM HTTP处理器
|
||||
@@ -287,15 +289,14 @@ func (h *IAMHandler) DeleteRole(w http.ResponseWriter, r *http.Request, roleCode
|
||||
|
||||
// ListScopes 处理列出所有Scope请求
|
||||
func (h *IAMHandler) ListScopes(w http.ResponseWriter, r *http.Request) {
|
||||
// 从预定义Scope列表获取
|
||||
scopes := []map[string]interface{}{
|
||||
{"scope_code": "platform:read", "scope_name": "读取平台配置", "scope_type": "platform"},
|
||||
{"scope_code": "platform:write", "scope_name": "修改平台配置", "scope_type": "platform"},
|
||||
{"scope_code": "platform:admin", "scope_name": "平台级管理", "scope_type": "platform"},
|
||||
{"scope_code": "tenant:read", "scope_name": "读取租户信息", "scope_type": "platform"},
|
||||
{"scope_code": "supply:account:read", "scope_name": "读取供应账号", "scope_type": "supply"},
|
||||
{"scope_code": "consumer:apikey:create", "scope_name": "创建API Key", "scope_type": "consumer"},
|
||||
{"scope_code": "router:invoke", "scope_name": "调用模型", "scope_type": "router"},
|
||||
// 从预定义Scope列表获取(完整的scope定义在model/scope.go的PredefinedScopes中)
|
||||
scopes := make([]map[string]interface{}, 0, len(model.PredefinedScopes))
|
||||
for _, scope := range model.PredefinedScopes {
|
||||
scopes = append(scopes, map[string]interface{}{
|
||||
"scope_code": scope.Code,
|
||||
"scope_name": scope.Name,
|
||||
"scope_type": scope.Type,
|
||||
})
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
@@ -376,8 +377,11 @@ func (h *IAMHandler) CheckScope(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 从context获取userID(实际应用中应从认证中间件获取)
|
||||
userID := int64(1) // 模拟
|
||||
userID := getUserIDFromContext(r.Context())
|
||||
if userID == 0 {
|
||||
writeError(w, http.StatusUnauthorized, "UNAUTHORIZED", "user not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
hasScope, err := h.iamService.CheckScope(r.Context(), userID, scope)
|
||||
if err != nil {
|
||||
@@ -497,8 +501,7 @@ func RequireScope(scope string, iamService service.IAMServiceInterface) func(htt
|
||||
}
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从context获取userID(实际应用中应从认证中间件获取)
|
||||
// getUserIDFromContext 从context获取userID
|
||||
func getUserIDFromContext(ctx context.Context) int64 {
|
||||
// TODO: 从认证中间件获取真实的userID
|
||||
return 1
|
||||
return middleware.GetOperatorID(ctx)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/iam/service"
|
||||
"lijiaoqiao/supply-api/internal/middleware"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -695,6 +696,8 @@ func TestIAMHandler_CheckScope_HasScope(t *testing.T) {
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/check-scope?scope=platform:read", nil)
|
||||
ctx := middleware.WithOperatorID(context.Background(), 1)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -728,6 +731,8 @@ func TestIAMHandler_CheckScope_NoScope(t *testing.T) {
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/check-scope?scope=platform:write", nil)
|
||||
ctx := middleware.WithOperatorID(context.Background(), 1)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -1153,6 +1158,8 @@ func TestIAMHandler_handleCheckScope_GET(t *testing.T) {
|
||||
handler := NewIAMHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/check-scope?scope=platform:read", nil)
|
||||
ctx := middleware.WithOperatorID(context.Background(), 1)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -1227,12 +1234,15 @@ func TestRequireScope(t *testing.T) {
|
||||
// getUserIDFromContext 测试
|
||||
|
||||
func TestGetUserIDFromContext(t *testing.T) {
|
||||
// act
|
||||
// act - 没有设置时返回0
|
||||
ctx := context.Background()
|
||||
userID := getUserIDFromContext(ctx)
|
||||
assert.Equal(t, int64(0), userID)
|
||||
|
||||
// assert - 默认返回1
|
||||
assert.Equal(t, int64(1), userID)
|
||||
// act - 设置operatorID时返回正确的值
|
||||
ctx = middleware.WithOperatorID(context.Background(), 123)
|
||||
userID = getUserIDFromContext(ctx)
|
||||
assert.Equal(t, int64(123), userID)
|
||||
}
|
||||
|
||||
// toRoleResponse 测试
|
||||
|
||||
@@ -171,8 +171,11 @@ func (m *IdempotencyMiddleware) Wrap(handler IdempotentHandler) http.HandlerFunc
|
||||
lockedRecord.PayloadHash = payloadHash
|
||||
}
|
||||
|
||||
// 执行实际业务处理
|
||||
err = handler(ctx, w, r, lockedRecord)
|
||||
// 创建包装器以捕获实际的状态码和响应体
|
||||
wrappedWriter := &statusCapturingResponseWriter{ResponseWriter: w}
|
||||
|
||||
// 执行实际业务处理,使用包装器捕获响应
|
||||
err = handler(ctx, wrappedWriter, r, lockedRecord)
|
||||
|
||||
// 根据处理结果更新幂等记录
|
||||
if err != nil {
|
||||
@@ -182,11 +185,12 @@ func (m *IdempotencyMiddleware) Wrap(handler IdempotentHandler) http.HandlerFunc
|
||||
return
|
||||
}
|
||||
|
||||
// 业务处理成功,更新为成功状态
|
||||
// 注意:这里需要从w中获取实际的响应码和body
|
||||
// 简化处理:使用200
|
||||
successBody, _ := json.Marshal(map[string]interface{}{"status": "ok"})
|
||||
_ = m.idempotencyRepo.UpdateSuccess(ctx, lockedRecord.ID, http.StatusOK, successBody)
|
||||
// 业务处理成功,使用捕获的实际状态码和body更新幂等记录
|
||||
successBody := wrappedWriter.body
|
||||
if len(successBody) == 0 {
|
||||
successBody, _ = json.Marshal(map[string]interface{}{"status": "ok"})
|
||||
}
|
||||
_ = m.idempotencyRepo.UpdateSuccess(ctx, lockedRecord.ID, wrappedWriter.statusCode, successBody)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,6 +234,23 @@ func writeIdempotentReplay(w http.ResponseWriter, status int, body json.RawMessa
|
||||
}
|
||||
}
|
||||
|
||||
// statusCapturingResponseWriter 包装http.ResponseWriter以捕获状态码
|
||||
type statusCapturingResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (w *statusCapturingResponseWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *statusCapturingResponseWriter) Write(b []byte) (int, error) {
|
||||
w.body = append(w.body, b...)
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// context keys
|
||||
type contextKey string
|
||||
|
||||
@@ -265,3 +286,8 @@ func getOperatorID(ctx context.Context) int64 {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetOperatorID 公开函数,从context获取操作者ID
|
||||
func GetOperatorID(ctx context.Context) int64 {
|
||||
return getOperatorID(ctx)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user