From 5dbb530b76d5fb8aa77bb06c81fbcf7530e94a61 Mon Sep 17 00:00:00 2001 From: long-agent Date: Tue, 7 Apr 2026 17:46:25 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=E6=9C=AA=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E7=9A=84=E5=AD=A4=E7=AB=8B=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 清理以下未导入的包: - internal/response (未使用的响应结构体) - pkg/response (未使用的响应封装) - internal/model (TLSFingerprintProfile, ErrorPassthroughRule) - internal/models (SocialAccount, domain已有) - internal/pkg/response (未使用的响应封装) - internal/security/ratelimit (已迁移到middleware) 验证: go build ./... && go test ./... 通过 --- internal/model/error_passthrough_rule.go | 75 -- internal/model/tls_fingerprint_profile.go | 54 -- internal/models/social_account.go | 70 -- internal/pkg/response/response.go | 203 ------ internal/pkg/response/response_test.go | 788 ---------------------- internal/response/response.go | 50 -- internal/response/response_test.go | 34 - internal/security/ratelimit.go | 184 ----- pkg/response/response.go | 50 -- 9 files changed, 1508 deletions(-) delete mode 100644 internal/model/error_passthrough_rule.go delete mode 100644 internal/model/tls_fingerprint_profile.go delete mode 100644 internal/models/social_account.go delete mode 100644 internal/pkg/response/response.go delete mode 100644 internal/pkg/response/response_test.go delete mode 100644 internal/response/response.go delete mode 100644 internal/response/response_test.go delete mode 100644 internal/security/ratelimit.go delete mode 100644 pkg/response/response.go diff --git a/internal/model/error_passthrough_rule.go b/internal/model/error_passthrough_rule.go deleted file mode 100644 index 620736c..0000000 --- a/internal/model/error_passthrough_rule.go +++ /dev/null @@ -1,75 +0,0 @@ -// Package model 定义服务层使用的数据模型。 -package model - -import "time" - -// ErrorPassthroughRule 全局错误透传规则 -// 用于控制上游错误如何返回给客户端 -type ErrorPassthroughRule struct { - ID int64 `json:"id"` - Name string `json:"name"` // 规则名称 - Enabled bool `json:"enabled"` // 是否启用 - Priority int `json:"priority"` // 优先级(数字越小优先级越高) - ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系) - Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系) - MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件) - Platforms []string `json:"platforms"` // 适用平台列表 - PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码 - ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用) - PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息 - CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用) - SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录 - Description *string `json:"description"` // 规则描述 - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// MatchModeAny 表示任一条件匹配即可 -const MatchModeAny = "any" - -// MatchModeAll 表示所有条件都必须匹配 -const MatchModeAll = "all" - -// 支持的平台常量 -const ( - PlatformAnthropic = "anthropic" - PlatformOpenAI = "openai" - PlatformGemini = "gemini" - PlatformAntigravity = "antigravity" -) - -// AllPlatforms 返回所有支持的平台列表 -func AllPlatforms() []string { - return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity} -} - -// Validate 验证规则配置的有效性 -func (r *ErrorPassthroughRule) Validate() error { - if r.Name == "" { - return &ValidationError{Field: "name", Message: "name is required"} - } - if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll { - return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"} - } - // 至少需要配置一个匹配条件(错误码或关键词) - if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 { - return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"} - } - if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) { - return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"} - } - if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") { - return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"} - } - return nil -} - -// ValidationError 表示验证错误 -type ValidationError struct { - Field string - Message string -} - -func (e *ValidationError) Error() string { - return e.Field + ": " + e.Message -} diff --git a/internal/model/tls_fingerprint_profile.go b/internal/model/tls_fingerprint_profile.go deleted file mode 100644 index 3037c9e..0000000 --- a/internal/model/tls_fingerprint_profile.go +++ /dev/null @@ -1,54 +0,0 @@ -// Package model 定义服务层使用的数据模型。 -package model - -import ( - "time" - - "github.com/user-management-system/internal/pkg/tlsfingerprint" -) - -// TLSFingerprintProfile TLS 指纹配置模板 -// 包含完整的 ClientHello 参数,用于模拟特定客户端的 TLS 握手特征 -type TLSFingerprintProfile struct { - ID int64 `json:"id"` - Name string `json:"name"` - Description *string `json:"description"` - EnableGREASE bool `json:"enable_grease"` - CipherSuites []uint16 `json:"cipher_suites"` - Curves []uint16 `json:"curves"` - PointFormats []uint16 `json:"point_formats"` - SignatureAlgorithms []uint16 `json:"signature_algorithms"` - ALPNProtocols []string `json:"alpn_protocols"` - SupportedVersions []uint16 `json:"supported_versions"` - KeyShareGroups []uint16 `json:"key_share_groups"` - PSKModes []uint16 `json:"psk_modes"` - Extensions []uint16 `json:"extensions"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// Validate 验证模板配置的有效性 -func (p *TLSFingerprintProfile) Validate() error { - if p.Name == "" { - return &ValidationError{Field: "name", Message: "name is required"} - } - return nil -} - -// ToTLSProfile 将领域模型转换为运行时使用的 tlsfingerprint.Profile -// 空切片字段会在 dialer 中 fallback 到内置默认值 -func (p *TLSFingerprintProfile) ToTLSProfile() *tlsfingerprint.Profile { - return &tlsfingerprint.Profile{ - Name: p.Name, - EnableGREASE: p.EnableGREASE, - CipherSuites: p.CipherSuites, - Curves: p.Curves, - PointFormats: p.PointFormats, - SignatureAlgorithms: p.SignatureAlgorithms, - ALPNProtocols: p.ALPNProtocols, - SupportedVersions: p.SupportedVersions, - KeyShareGroups: p.KeyShareGroups, - PSKModes: p.PSKModes, - Extensions: p.Extensions, - } -} diff --git a/internal/models/social_account.go b/internal/models/social_account.go deleted file mode 100644 index 4319cdc..0000000 --- a/internal/models/social_account.go +++ /dev/null @@ -1,70 +0,0 @@ -package models - -import ( - "encoding/json" - "time" -) - -// SocialAccount 社交账号绑定模型 -type SocialAccount struct { - ID uint64 `json:"id" db:"id"` - UserID uint64 `json:"user_id" db:"user_id"` - Provider string `json:"provider" db:"provider"` // wechat, qq, weibo, google, facebook, twitter - ProviderUserID string `json:"provider_user_id" db:"provider_user_id"` - ProviderUsername string `json:"provider_username" db:"provider_username"` - AccessToken string `json:"-" db:"access_token"` // 不返回给前端 - RefreshToken string `json:"-" db:"refresh_token"` - ExpiresAt *time.Time `json:"expires_at" db:"expires_at"` - RawData JSON `json:"-" db:"raw_data"` - IsPrimary bool `json:"is_primary" db:"is_primary"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` -} - -// SocialAccountInfo 返回给前端的社交账号信息(不含敏感信息) -type SocialAccountInfo struct { - ID uint64 `json:"id"` - Provider string `json:"provider"` - ProviderUserID string `json:"provider_user_id"` - ProviderUsername string `json:"provider_username"` - IsPrimary bool `json:"is_primary"` - CreatedAt time.Time `json:"created_at"` -} - -// ToInfo 转换为安全信息 -func (sa *SocialAccount) ToInfo() *SocialAccountInfo { - return &SocialAccountInfo{ - ID: sa.ID, - Provider: sa.Provider, - ProviderUserID: sa.ProviderUserID, - ProviderUsername: sa.ProviderUsername, - IsPrimary: sa.IsPrimary, - CreatedAt: sa.CreatedAt, - } -} - -// JSON 自定义JSON类型,用于存储RawData -type JSON struct { - Data interface{} -} - -// Scan 实现 sql.Scanner 接口 -func (j *JSON) Scan(value interface{}) error { - if value == nil { - j.Data = nil - return nil - } - bytes, ok := value.([]byte) - if !ok { - return nil - } - return json.Unmarshal(bytes, &j.Data) -} - -// Value 实现 driver.Valuer 接口 -func (j JSON) Value() (interface{}, error) { - if j.Data == nil { - return nil, nil - } - return json.Marshal(j.Data) -} diff --git a/internal/pkg/response/response.go b/internal/pkg/response/response.go deleted file mode 100644 index d2f2f35..0000000 --- a/internal/pkg/response/response.go +++ /dev/null @@ -1,203 +0,0 @@ -// Package response provides standardized HTTP response helpers. -package response - -import ( - "log" - "math" - "net/http" - - infraerrors "github.com/user-management-system/internal/pkg/errors" - "github.com/user-management-system/internal/util/logredact" - "github.com/gin-gonic/gin" -) - -// Response 标准API响应格式 -type Response struct { - Code int `json:"code"` - Message string `json:"message"` - Reason string `json:"reason,omitempty"` - Metadata map[string]string `json:"metadata,omitempty"` - Data any `json:"data,omitempty"` -} - -// PaginatedData 分页数据格式(匹配前端期望) -type PaginatedData struct { - Items any `json:"items"` - Total int64 `json:"total"` - Page int `json:"page"` - PageSize int `json:"page_size"` - Pages int `json:"pages"` -} - -// Success 返回成功响应 -func Success(c *gin.Context, data any) { - c.JSON(http.StatusOK, Response{ - Code: 0, - Message: "success", - Data: data, - }) -} - -// Created 返回创建成功响应 -func Created(c *gin.Context, data any) { - c.JSON(http.StatusCreated, Response{ - Code: 0, - Message: "success", - Data: data, - }) -} - -// Accepted 返回异步接受响应 (HTTP 202) -func Accepted(c *gin.Context, data any) { - c.JSON(http.StatusAccepted, Response{ - Code: 0, - Message: "accepted", - Data: data, - }) -} - -// Error 返回错误响应 -func Error(c *gin.Context, statusCode int, message string) { - c.JSON(statusCode, Response{ - Code: statusCode, - Message: message, - Reason: "", - Metadata: nil, - }) -} - -// ErrorWithDetails returns an error response compatible with the existing envelope while -// optionally providing structured error fields (reason/metadata). -func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) { - c.JSON(statusCode, Response{ - Code: statusCode, - Message: message, - Reason: reason, - Metadata: metadata, - }) -} - -// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response. -// It returns true if an error was written. -func ErrorFrom(c *gin.Context, err error) bool { - if err == nil { - return false - } - - statusCode, status := infraerrors.ToHTTP(err) - - // Log internal errors with full details for debugging - if statusCode >= 500 && c.Request != nil { - log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error())) - } - - ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) - return true -} - -// BadRequest 返回400错误 -func BadRequest(c *gin.Context, message string) { - Error(c, http.StatusBadRequest, message) -} - -// Unauthorized 返回401错误 -func Unauthorized(c *gin.Context, message string) { - Error(c, http.StatusUnauthorized, message) -} - -// Forbidden 返回403错误 -func Forbidden(c *gin.Context, message string) { - Error(c, http.StatusForbidden, message) -} - -// NotFound 返回404错误 -func NotFound(c *gin.Context, message string) { - Error(c, http.StatusNotFound, message) -} - -// InternalError 返回500错误 -func InternalError(c *gin.Context, message string) { - Error(c, http.StatusInternalServerError, message) -} - -// Paginated 返回分页数据 -func Paginated(c *gin.Context, items any, total int64, page, pageSize int) { - pages := int(math.Ceil(float64(total) / float64(pageSize))) - if pages < 1 { - pages = 1 - } - - Success(c, PaginatedData{ - Items: items, - Total: total, - Page: page, - PageSize: pageSize, - Pages: pages, - }) -} - -// PaginationResult 分页结果(与pagination.PaginationResult兼容) -type PaginationResult struct { - Total int64 - Page int - PageSize int - Pages int -} - -// PaginatedWithResult 使用PaginationResult返回分页数据 -func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) { - if pagination == nil { - Success(c, PaginatedData{ - Items: items, - Total: 0, - Page: 1, - PageSize: 20, - Pages: 1, - }) - return - } - - Success(c, PaginatedData{ - Items: items, - Total: pagination.Total, - Page: pagination.Page, - PageSize: pagination.PageSize, - Pages: pagination.Pages, - }) -} - -// ParsePagination 解析分页参数 -func ParsePagination(c *gin.Context) (page, pageSize int) { - page = 1 - pageSize = 20 - - if p := c.Query("page"); p != "" { - if val, err := parseInt(p); err == nil && val > 0 { - page = val - } - } - - // 支持 page_size 和 limit 两种参数名 - if ps := c.Query("page_size"); ps != "" { - if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 { - pageSize = val - } - } else if l := c.Query("limit"); l != "" { - if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 { - pageSize = val - } - } - - return page, pageSize -} - -func parseInt(s string) (int, error) { - var result int - for _, c := range s { - if c < '0' || c > '9' { - return 0, nil - } - result = result*10 + int(c-'0') - } - return result, nil -} diff --git a/internal/pkg/response/response_test.go b/internal/pkg/response/response_test.go deleted file mode 100644 index ba31d7e..0000000 --- a/internal/pkg/response/response_test.go +++ /dev/null @@ -1,788 +0,0 @@ -//go:build unit - -package response - -import ( - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "testing" - - errors2 "github.com/user-management-system/internal/pkg/errors" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -// ---------- 辅助函数 ---------- - -// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体 -func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response { - t.Helper() - var got Response - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) - return got -} - -// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData) -func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) { - t.Helper() - // 先用 raw json 解析,因为 Data 是 any 类型 - var raw struct { - Code int `json:"code"` - Message string `json:"message"` - Reason string `json:"reason,omitempty"` - Data json.RawMessage `json:"data,omitempty"` - } - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw)) - - var pd PaginatedData - require.NoError(t, json.Unmarshal(raw.Data, &pd)) - - return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd -} - -// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination -func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil) - return w, c -} - -// ---------- 现有测试 ---------- - -func TestErrorWithDetails(t *testing.T) { - gin.SetMode(gin.TestMode) - - tests := []struct { - name string - statusCode int - message string - reason string - metadata map[string]string - want Response - }{ - { - name: "plain_error", - statusCode: http.StatusBadRequest, - message: "invalid request", - want: Response{ - Code: http.StatusBadRequest, - Message: "invalid request", - }, - }, - { - name: "structured_error", - statusCode: http.StatusForbidden, - message: "no access", - reason: "FORBIDDEN", - metadata: map[string]string{"k": "v"}, - want: Response{ - Code: http.StatusForbidden, - Message: "no access", - Reason: "FORBIDDEN", - Metadata: map[string]string{"k": "v"}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata) - - require.Equal(t, tt.statusCode, w.Code) - - var got Response - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) - require.Equal(t, tt.want, got) - }) - } -} - -func TestErrorFrom(t *testing.T) { - gin.SetMode(gin.TestMode) - - tests := []struct { - name string - err error - wantWritten bool - wantHTTPCode int - wantBody Response - }{ - { - name: "nil_error", - err: nil, - wantWritten: false, - }, - { - name: "application_error", - err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}), - wantWritten: true, - wantHTTPCode: http.StatusForbidden, - wantBody: Response{ - Code: http.StatusForbidden, - Message: "no access", - Reason: "FORBIDDEN", - Metadata: map[string]string{"scope": "admin"}, - }, - }, - { - name: "bad_request_error", - err: errors2.BadRequest("INVALID_REQUEST", "invalid request"), - wantWritten: true, - wantHTTPCode: http.StatusBadRequest, - wantBody: Response{ - Code: http.StatusBadRequest, - Message: "invalid request", - Reason: "INVALID_REQUEST", - }, - }, - { - name: "unauthorized_error", - err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"), - wantWritten: true, - wantHTTPCode: http.StatusUnauthorized, - wantBody: Response{ - Code: http.StatusUnauthorized, - Message: "unauthorized", - Reason: "UNAUTHORIZED", - }, - }, - { - name: "not_found_error", - err: errors2.NotFound("NOT_FOUND", "not found"), - wantWritten: true, - wantHTTPCode: http.StatusNotFound, - wantBody: Response{ - Code: http.StatusNotFound, - Message: "not found", - Reason: "NOT_FOUND", - }, - }, - { - name: "conflict_error", - err: errors2.Conflict("CONFLICT", "conflict"), - wantWritten: true, - wantHTTPCode: http.StatusConflict, - wantBody: Response{ - Code: http.StatusConflict, - Message: "conflict", - Reason: "CONFLICT", - }, - }, - { - name: "unknown_error_defaults_to_500", - err: errors.New("boom"), - wantWritten: true, - wantHTTPCode: http.StatusInternalServerError, - wantBody: Response{ - Code: http.StatusInternalServerError, - Message: errors2.UnknownMessage, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - written := ErrorFrom(c, tt.err) - require.Equal(t, tt.wantWritten, written) - - if !tt.wantWritten { - require.Equal(t, 200, w.Code) - require.Empty(t, w.Body.String()) - return - } - - require.Equal(t, tt.wantHTTPCode, w.Code) - var got Response - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) - require.Equal(t, tt.wantBody, got) - }) - } -} - -// ---------- 新增测试 ---------- - -func TestSuccess(t *testing.T) { - gin.SetMode(gin.TestMode) - - tests := []struct { - name string - data any - wantCode int - wantBody Response - }{ - { - name: "返回字符串数据", - data: "hello", - wantCode: http.StatusOK, - wantBody: Response{Code: 0, Message: "success", Data: "hello"}, - }, - { - name: "返回nil数据", - data: nil, - wantCode: http.StatusOK, - wantBody: Response{Code: 0, Message: "success"}, - }, - { - name: "返回map数据", - data: map[string]string{"key": "value"}, - wantCode: http.StatusOK, - wantBody: Response{Code: 0, Message: "success"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - Success(c, tt.data) - - require.Equal(t, tt.wantCode, w.Code) - - // 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice - got := parseResponseBody(t, w) - require.Equal(t, 0, got.Code) - require.Equal(t, "success", got.Message) - - if tt.data == nil { - require.Nil(t, got.Data) - } else { - require.NotNil(t, got.Data) - } - }) - } -} - -func TestCreated(t *testing.T) { - gin.SetMode(gin.TestMode) - - tests := []struct { - name string - data any - wantCode int - }{ - { - name: "创建成功_返回数据", - data: map[string]int{"id": 42}, - wantCode: http.StatusCreated, - }, - { - name: "创建成功_nil数据", - data: nil, - wantCode: http.StatusCreated, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - Created(c, tt.data) - - require.Equal(t, tt.wantCode, w.Code) - - got := parseResponseBody(t, w) - require.Equal(t, 0, got.Code) - require.Equal(t, "success", got.Message) - }) - } -} - -func TestError(t *testing.T) { - gin.SetMode(gin.TestMode) - - tests := []struct { - name string - statusCode int - message string - }{ - { - name: "400错误", - statusCode: http.StatusBadRequest, - message: "bad request", - }, - { - name: "500错误", - statusCode: http.StatusInternalServerError, - message: "internal error", - }, - { - name: "自定义状态码", - statusCode: 418, - message: "I'm a teapot", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - Error(c, tt.statusCode, tt.message) - - require.Equal(t, tt.statusCode, w.Code) - - got := parseResponseBody(t, w) - require.Equal(t, tt.statusCode, got.Code) - require.Equal(t, tt.message, got.Message) - require.Empty(t, got.Reason) - require.Nil(t, got.Metadata) - require.Nil(t, got.Data) - }) - } -} - -func TestBadRequest(t *testing.T) { - gin.SetMode(gin.TestMode) - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - BadRequest(c, "参数无效") - - require.Equal(t, http.StatusBadRequest, w.Code) - got := parseResponseBody(t, w) - require.Equal(t, http.StatusBadRequest, got.Code) - require.Equal(t, "参数无效", got.Message) -} - -func TestUnauthorized(t *testing.T) { - gin.SetMode(gin.TestMode) - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - Unauthorized(c, "未登录") - - require.Equal(t, http.StatusUnauthorized, w.Code) - got := parseResponseBody(t, w) - require.Equal(t, http.StatusUnauthorized, got.Code) - require.Equal(t, "未登录", got.Message) -} - -func TestForbidden(t *testing.T) { - gin.SetMode(gin.TestMode) - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - Forbidden(c, "无权限") - - require.Equal(t, http.StatusForbidden, w.Code) - got := parseResponseBody(t, w) - require.Equal(t, http.StatusForbidden, got.Code) - require.Equal(t, "无权限", got.Message) -} - -func TestNotFound(t *testing.T) { - gin.SetMode(gin.TestMode) - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - NotFound(c, "资源不存在") - - require.Equal(t, http.StatusNotFound, w.Code) - got := parseResponseBody(t, w) - require.Equal(t, http.StatusNotFound, got.Code) - require.Equal(t, "资源不存在", got.Message) -} - -func TestInternalError(t *testing.T) { - gin.SetMode(gin.TestMode) - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - InternalError(c, "服务器内部错误") - - require.Equal(t, http.StatusInternalServerError, w.Code) - got := parseResponseBody(t, w) - require.Equal(t, http.StatusInternalServerError, got.Code) - require.Equal(t, "服务器内部错误", got.Message) -} - -func TestPaginated(t *testing.T) { - gin.SetMode(gin.TestMode) - - tests := []struct { - name string - items any - total int64 - page int - pageSize int - wantPages int - wantTotal int64 - wantPage int - wantPageSize int - }{ - { - name: "标准分页_多页", - items: []string{"a", "b"}, - total: 25, - page: 1, - pageSize: 10, - wantPages: 3, - wantTotal: 25, - wantPage: 1, - wantPageSize: 10, - }, - { - name: "总数刚好整除", - items: []string{"a"}, - total: 20, - page: 2, - pageSize: 10, - wantPages: 2, - wantTotal: 20, - wantPage: 2, - wantPageSize: 10, - }, - { - name: "总数为0_pages至少为1", - items: []string{}, - total: 0, - page: 1, - pageSize: 10, - wantPages: 1, - wantTotal: 0, - wantPage: 1, - wantPageSize: 10, - }, - { - name: "单页数据", - items: []int{1, 2, 3}, - total: 3, - page: 1, - pageSize: 20, - wantPages: 1, - wantTotal: 3, - wantPage: 1, - wantPageSize: 20, - }, - { - name: "总数为1", - items: []string{"only"}, - total: 1, - page: 1, - pageSize: 10, - wantPages: 1, - wantTotal: 1, - wantPage: 1, - wantPageSize: 10, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - Paginated(c, tt.items, tt.total, tt.page, tt.pageSize) - - require.Equal(t, http.StatusOK, w.Code) - - resp, pd := parsePaginatedBody(t, w) - require.Equal(t, 0, resp.Code) - require.Equal(t, "success", resp.Message) - require.Equal(t, tt.wantTotal, pd.Total) - require.Equal(t, tt.wantPage, pd.Page) - require.Equal(t, tt.wantPageSize, pd.PageSize) - require.Equal(t, tt.wantPages, pd.Pages) - }) - } -} - -func TestPaginatedWithResult(t *testing.T) { - gin.SetMode(gin.TestMode) - - tests := []struct { - name string - items any - pagination *PaginationResult - wantTotal int64 - wantPage int - wantPageSize int - wantPages int - }{ - { - name: "正常分页结果", - items: []string{"a", "b"}, - pagination: &PaginationResult{ - Total: 50, - Page: 3, - PageSize: 10, - Pages: 5, - }, - wantTotal: 50, - wantPage: 3, - wantPageSize: 10, - wantPages: 5, - }, - { - name: "pagination为nil_使用默认值", - items: []string{}, - pagination: nil, - wantTotal: 0, - wantPage: 1, - wantPageSize: 20, - wantPages: 1, - }, - { - name: "单页结果", - items: []int{1}, - pagination: &PaginationResult{ - Total: 1, - Page: 1, - PageSize: 20, - Pages: 1, - }, - wantTotal: 1, - wantPage: 1, - wantPageSize: 20, - wantPages: 1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - PaginatedWithResult(c, tt.items, tt.pagination) - - require.Equal(t, http.StatusOK, w.Code) - - resp, pd := parsePaginatedBody(t, w) - require.Equal(t, 0, resp.Code) - require.Equal(t, "success", resp.Message) - require.Equal(t, tt.wantTotal, pd.Total) - require.Equal(t, tt.wantPage, pd.Page) - require.Equal(t, tt.wantPageSize, pd.PageSize) - require.Equal(t, tt.wantPages, pd.Pages) - }) - } -} - -func TestParsePagination(t *testing.T) { - gin.SetMode(gin.TestMode) - - tests := []struct { - name string - query string - wantPage int - wantPageSize int - }{ - { - name: "无参数_使用默认值", - query: "", - wantPage: 1, - wantPageSize: 20, - }, - { - name: "仅指定page", - query: "page=3", - wantPage: 3, - wantPageSize: 20, - }, - { - name: "仅指定page_size", - query: "page_size=50", - wantPage: 1, - wantPageSize: 50, - }, - { - name: "同时指定page和page_size", - query: "page=2&page_size=30", - wantPage: 2, - wantPageSize: 30, - }, - { - name: "使用limit代替page_size", - query: "limit=15", - wantPage: 1, - wantPageSize: 15, - }, - { - name: "page_size优先于limit", - query: "page_size=25&limit=50", - wantPage: 1, - wantPageSize: 25, - }, - { - name: "page为0_使用默认值", - query: "page=0", - wantPage: 1, - wantPageSize: 20, - }, - { - name: "page_size超过1000_使用默认值", - query: "page_size=1001", - wantPage: 1, - wantPageSize: 20, - }, - { - name: "page_size恰好1000_有效", - query: "page_size=1000", - wantPage: 1, - wantPageSize: 1000, - }, - { - name: "page为非数字_使用默认值", - query: "page=abc", - wantPage: 1, - wantPageSize: 20, - }, - { - name: "page_size为非数字_使用默认值", - query: "page_size=xyz", - wantPage: 1, - wantPageSize: 20, - }, - { - name: "limit为非数字_使用默认值", - query: "limit=abc", - wantPage: 1, - wantPageSize: 20, - }, - { - name: "page_size为0_使用默认值", - query: "page_size=0", - wantPage: 1, - wantPageSize: 20, - }, - { - name: "limit为0_使用默认值", - query: "limit=0", - wantPage: 1, - wantPageSize: 20, - }, - { - name: "大页码", - query: "page=999&page_size=100", - wantPage: 999, - wantPageSize: 100, - }, - { - name: "page_size为1_最小有效值", - query: "page_size=1", - wantPage: 1, - wantPageSize: 1, - }, - { - name: "混合数字和字母的page", - query: "page=12a", - wantPage: 1, - wantPageSize: 20, - }, - { - name: "limit超过1000_使用默认值", - query: "limit=2000", - wantPage: 1, - wantPageSize: 20, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, c := newContextWithQuery(tt.query) - - page, pageSize := ParsePagination(c) - - require.Equal(t, tt.wantPage, page, "page 不符合预期") - require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期") - }) - } -} - -func Test_parseInt(t *testing.T) { - tests := []struct { - name string - input string - wantVal int - wantErr bool - }{ - { - name: "正常数字", - input: "123", - wantVal: 123, - wantErr: false, - }, - { - name: "零", - input: "0", - wantVal: 0, - wantErr: false, - }, - { - name: "单个数字", - input: "5", - wantVal: 5, - wantErr: false, - }, - { - name: "大数字", - input: "99999", - wantVal: 99999, - wantErr: false, - }, - { - name: "包含字母_返回0", - input: "abc", - wantVal: 0, - wantErr: false, - }, - { - name: "数字开头接字母_返回0", - input: "12a", - wantVal: 0, - wantErr: false, - }, - { - name: "包含负号_返回0", - input: "-1", - wantVal: 0, - wantErr: false, - }, - { - name: "包含小数点_返回0", - input: "1.5", - wantVal: 0, - wantErr: false, - }, - { - name: "包含空格_返回0", - input: "1 2", - wantVal: 0, - wantErr: false, - }, - { - name: "空字符串", - input: "", - wantVal: 0, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - val, err := parseInt(tt.input) - if tt.wantErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } - require.Equal(t, tt.wantVal, val) - }) - } -} diff --git a/internal/response/response.go b/internal/response/response.go deleted file mode 100644 index a7dbf82..0000000 --- a/internal/response/response.go +++ /dev/null @@ -1,50 +0,0 @@ -package response - -// Response 统一响应结构 -type Response struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` -} - -// Success 成功响应 -func Success(data interface{}) *Response { - return &Response{ - Code: 0, - Message: "success", - Data: data, - } -} - -// Error 错误响应 -func Error(message string) *Response { - return &Response{ - Code: -1, - Message: message, - } -} - -// ErrorWithCode 带错误码的错误响应 -func ErrorWithCode(code int, message string) *Response { - return &Response{ - Code: code, - Message: message, - } -} - -// WithData 带扩展数据的成功响应 -func WithData(data interface{}, extra map[string]interface{}) *Response { - payload, ok := data.(map[string]interface{}) - if !ok { - payload = map[string]interface{}{ - "items": data, - } - } - - for k, v := range extra { - payload[k] = v - } - - resp := Success(payload) - return resp -} diff --git a/internal/response/response_test.go b/internal/response/response_test.go deleted file mode 100644 index 96d7f7f..0000000 --- a/internal/response/response_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package response - -import "testing" - -func TestWithDataWrapsSlicesAndMergesExtra(t *testing.T) { - resp := WithData([]string{"a", "b"}, map[string]interface{}{ - "total": 2, - "page": 1, - }) - - data, ok := resp.Data.(map[string]interface{}) - if !ok { - t.Fatalf("expected map payload, got %T", resp.Data) - } - if data["total"] != 2 { - t.Fatalf("expected total=2, got %v", data["total"]) - } - items, ok := data["items"].([]string) - if !ok || len(items) != 2 { - t.Fatalf("expected items slice to be preserved, got %#v", data["items"]) - } -} - -func TestWithDataPreservesMapPayload(t *testing.T) { - resp := WithData(map[string]interface{}{"user": "alice"}, map[string]interface{}{"page": 1}) - - data := resp.Data.(map[string]interface{}) - if data["user"] != "alice" { - t.Fatalf("expected user=alice, got %v", data["user"]) - } - if data["page"] != 1 { - t.Fatalf("expected page=1, got %v", data["page"]) - } -} diff --git a/internal/security/ratelimit.go b/internal/security/ratelimit.go deleted file mode 100644 index e236b22..0000000 --- a/internal/security/ratelimit.go +++ /dev/null @@ -1,184 +0,0 @@ -package security - -import ( - "sync" - "time" -) - -// RateLimitAlgorithm 限流算法类型 -type RateLimitAlgorithm string - -const ( - AlgorithmTokenBucket RateLimitAlgorithm = "token_bucket" - AlgorithmLeakyBucket RateLimitAlgorithm = "leaky_bucket" - AlgorithmSlidingWindow RateLimitAlgorithm = "sliding_window" - AlgorithmFixedWindow RateLimitAlgorithm = "fixed_window" -) - -// TokenBucket 令牌桶算法 -type TokenBucket struct { - capacity int64 - tokens int64 - rate int64 // 每秒产生的令牌数 - lastRefill time.Time - mu sync.Mutex -} - -// NewTokenBucket 创建令牌桶 -func NewTokenBucket(capacity, rate int64) *TokenBucket { - return &TokenBucket{ - capacity: capacity, - tokens: capacity, - rate: rate, - lastRefill: time.Now(), - } -} - -// Allow 检查是否允许访问 -func (tb *TokenBucket) Allow() bool { - tb.mu.Lock() - defer tb.mu.Unlock() - - now := time.Now() - elapsed := now.Sub(tb.lastRefill).Seconds() - - // 计算需要补充的令牌数 - refillTokens := int64(elapsed * float64(tb.rate)) - tb.tokens += refillTokens - if tb.tokens > tb.capacity { - tb.tokens = tb.capacity - } - tb.lastRefill = now - - // 检查是否有足够的令牌 - if tb.tokens > 0 { - tb.tokens-- - return true - } - - return false -} - -// LeakyBucket 漏桶算法 -type LeakyBucket struct { - capacity int64 - water int64 - rate int64 // 每秒漏出的水量 - lastLeak time.Time - mu sync.Mutex -} - -// NewLeakyBucket 创建漏桶 -func NewLeakyBucket(capacity, rate int64) *LeakyBucket { - return &LeakyBucket{ - capacity: capacity, - water: 0, - rate: rate, - lastLeak: time.Now(), - } -} - -// Allow 检查是否允许访问 -func (lb *LeakyBucket) Allow() bool { - lb.mu.Lock() - defer lb.mu.Unlock() - - now := time.Now() - elapsed := now.Sub(lb.lastLeak).Seconds() - - // 计算漏出的水量 - leakWater := int64(elapsed * float64(lb.rate)) - lb.water -= leakWater - if lb.water < 0 { - lb.water = 0 - } - lb.lastLeak = now - - // 检查桶是否已满 - if lb.water < lb.capacity { - lb.water++ - return true - } - - return false -} - -// SlidingWindow 滑动窗口算法 -type SlidingWindow struct { - window time.Duration - capacity int64 - requests []time.Time - mu sync.Mutex -} - -// NewSlidingWindow 创建滑动窗口 -func NewSlidingWindow(window time.Duration, capacity int64) *SlidingWindow { - return &SlidingWindow{ - window: window, - capacity: capacity, - requests: make([]time.Time, 0), - } -} - -// Allow 检查是否允许访问 -func (sw *SlidingWindow) Allow() bool { - sw.mu.Lock() - defer sw.mu.Unlock() - - now := time.Now() - - // 移除窗口外的请求 - validRequests := make([]time.Time, 0) - for _, req := range sw.requests { - if now.Sub(req) < sw.window { - validRequests = append(validRequests, req) - } - } - sw.requests = validRequests - - // 检查是否超过容量 - if int64(len(sw.requests)) < sw.capacity { - sw.requests = append(sw.requests, now) - return true - } - - return false -} - -// RateLimiter 限流器 -type RateLimiter struct { - algorithm RateLimitAlgorithm - limiter interface{} -} - -// NewRateLimiter 创建限流器 -func NewRateLimiter(algorithm RateLimitAlgorithm, capacity, rate int64, window time.Duration) *RateLimiter { - limiter := &RateLimiter{algorithm: algorithm} - - switch algorithm { - case AlgorithmTokenBucket: - limiter.limiter = NewTokenBucket(capacity, rate) - case AlgorithmLeakyBucket: - limiter.limiter = NewLeakyBucket(capacity, rate) - case AlgorithmSlidingWindow: - limiter.limiter = NewSlidingWindow(window, capacity) - default: - limiter.limiter = NewSlidingWindow(window, capacity) - } - - return limiter -} - -// Allow 检查是否允许访问 -func (rl *RateLimiter) Allow() bool { - switch rl.algorithm { - case AlgorithmTokenBucket: - return rl.limiter.(*TokenBucket).Allow() - case AlgorithmLeakyBucket: - return rl.limiter.(*LeakyBucket).Allow() - case AlgorithmSlidingWindow: - return rl.limiter.(*SlidingWindow).Allow() - default: - return rl.limiter.(*SlidingWindow).Allow() - } -} diff --git a/pkg/response/response.go b/pkg/response/response.go deleted file mode 100644 index af285fe..0000000 --- a/pkg/response/response.go +++ /dev/null @@ -1,50 +0,0 @@ -package response - -import ( - "net/http" - - "github.com/gin-gonic/gin" -) - -// Response 统一响应结构 -type Response struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` -} - -// Success 成功响应 -func Success(c *gin.Context, data interface{}) { - c.JSON(http.StatusOK, Response{ - Code: 0, - Message: "success", - Data: data, - }) -} - -// Error 错误响应 -func Error(c *gin.Context, httpStatus int, message string, err error) { - if err != nil { - // 在开发环境下返回详细错误信息 - if gin.Mode() == gin.DebugMode { - c.JSON(httpStatus, Response{ - Code: httpStatus, - Message: message, - Data: err.Error(), - }) - return - } - } - c.JSON(httpStatus, Response{ - Code: httpStatus, - Message: message, - }) -} - -// ErrorWithCode 错误响应(带自定义错误码) -func ErrorWithCode(c *gin.Context, code int, message string) { - c.JSON(http.StatusOK, Response{ - Code: code, - Message: message, - }) -}