package handler

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"

	apierrors "github.com/user-management-system/internal/pkg/errors"
)

// TestHandleError_Nil verifies that handleError writes nothing to the
// response when it is given a nil error.
func TestHandleError_Nil(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := &mockResponseWriter{}
	c, _ := gin.CreateTestContext(w)

	handleError(c, nil)

	// A nil error must not produce a response: no status code, no body.
	assert.Equal(t, 0, w.code)
	assert.Empty(t, w.data)
	assert.False(t, c.Writer.Written())
}

// TestHandleError_ApplicationError verifies that each apierrors constructor
// is mapped by handleError to the matching HTTP status code.
func TestHandleError_ApplicationError(t *testing.T) {
	gin.SetMode(gin.TestMode)

	tests := []struct {
		name       string
		err        error
		wantStatus int
	}{
		{
			name:       "bad request error",
			err:        apierrors.BadRequest("invalid", "invalid input"),
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "not found error",
			err:        apierrors.NotFound("user", "user not found"),
			wantStatus: http.StatusNotFound,
		},
		{
			name:       "unauthorized error",
			err:        apierrors.Unauthorized("token", "invalid token"),
			wantStatus: http.StatusUnauthorized,
		},
		{
			name:       "forbidden error",
			err:        apierrors.Forbidden("permission", "permission denied"),
			wantStatus: http.StatusForbidden,
		},
		{
			name:       "conflict error",
			err:        apierrors.Conflict("user", "user already exists"),
			wantStatus: http.StatusConflict,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := &mockResponseWriter{}
			c, _ := gin.CreateTestContext(w)

			handleError(c, tt.err)

			assert.Equal(t, tt.wantStatus, w.code)
		})
	}
}

// TestClassifyErrorMessage checks the keyword-based mapping from a plain
// error message (English or Chinese) to an HTTP status code.
func TestClassifyErrorMessage(t *testing.T) {
	tests := []struct {
		name string
		msg  string
		want int
	}{
		// Not found
		{name: "not found EN", msg: "user not found", want: http.StatusNotFound},
		{name: "not found CN", msg: "用户不存在", want: http.StatusNotFound},
		{name: "not found CN2", msg: "找不到资源", want: http.StatusNotFound},

		// Conflict
		{name: "already exists EN", msg: "user already exists", want: http.StatusConflict},
		{name: "already exists CN", msg: "用户已存在", want: http.StatusConflict},
		{name: "duplicate", msg: "duplicate entry", want: http.StatusConflict},

		// Unauthorized
		{name: "unauthorized EN", msg: "unauthorized", want: http.StatusUnauthorized},
		{name: "invalid token", msg: "invalid token", want: http.StatusUnauthorized},
		{name: "token", msg: "token expired", want: http.StatusUnauthorized},
		{name: "unauthorized CN", msg: "令牌无效", want: http.StatusUnauthorized},

		// Forbidden
		{name: "forbidden EN", msg: "forbidden", want: http.StatusForbidden},
		{name: "permission", msg: "no permission", want: http.StatusForbidden},
		{name: "forbidden CN", msg: "权限不足", want: http.StatusForbidden},

		// Bad request
		{name: "invalid", msg: "invalid input", want: http.StatusBadRequest},
		{name: "required", msg: "field is required", want: http.StatusBadRequest},
		{name: "cannot be empty", msg: "name cannot be empty", want: http.StatusBadRequest},
		{name: "cannot be empty CN", msg: "名称不能为空", want: http.StatusBadRequest},
		{name: "incorrect password", msg: "密码不正确", want: http.StatusBadRequest},
		// Precedence check: the "token" keyword is matched before "expired",
		// so this message classifies as Unauthorized rather than BadRequest.
		{name: "expired", msg: "token expired", want: http.StatusUnauthorized},

		// Rate limit
		{name: "locked", msg: "account locked", want: http.StatusTooManyRequests},
		{name: "too many", msg: "too many attempts", want: http.StatusTooManyRequests},
		{name: "rate limit", msg: "rate limit exceeded", want: http.StatusTooManyRequests},

		// Internal server error (default)
		{name: "unknown error", msg: "unknown error occurred", want: http.StatusInternalServerError},
		{name: "database error", msg: "database connection failed", want: http.StatusInternalServerError},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := classifyErrorMessage(tt.msg)
			assert.Equal(t, tt.want, got, "classifyErrorMessage(%q)", tt.msg)
		})
	}
}

// TestContains exercises the contains helper, which reports whether s
// matches any of the given keywords.
func TestContains(t *testing.T) {
	tests := []struct {
		name     string
		s        string
		keywords []string
		want     bool
	}{
		{name: "match first", s: "hello world", keywords: []string{"hello", "foo"}, want: true},
		{name: "match second", s: "hello world", keywords: []string{"foo", "world"}, want: true},
		{name: "no match", s: "hello world", keywords: []string{"foo", "bar"}, want: false},
		{name: "empty keywords", s: "hello world", keywords: []string{}, want: false},
		{name: "empty string", s: "", keywords: []string{"hello"}, want: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := contains(tt.s, tt.keywords...)
			assert.Equal(t, tt.want, got)
		})
	}
}

// TestGetUserIDFromContext_Success checks that an int64 stored under
// "user_id" is returned with ok == true.
func TestGetUserIDFromContext_Success(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Set("user_id", int64(123))

	userID, ok := getUserIDFromContext(c)

	assert.True(t, ok)
	assert.Equal(t, int64(123), userID)
}

// TestGetUserIDFromContext_NotExists checks the result when "user_id" is
// absent from the context.
func TestGetUserIDFromContext_NotExists(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	userID, ok := getUserIDFromContext(c)

	assert.False(t, ok)
	assert.Equal(t, int64(0), userID)
}

// TestGetUserIDFromContext_WrongType checks that a non-int64 value stored
// under "user_id" is rejected.
func TestGetUserIDFromContext_WrongType(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Set("user_id", "not an int64")

	userID, ok := getUserIDFromContext(c)

	assert.False(t, ok)
	assert.Equal(t, int64(0), userID)
}

// TestGetUsernameFromContext_Success checks that a string stored under
// "username" is returned with ok == true.
func TestGetUsernameFromContext_Success(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Set("username", "testuser")

	username, ok := getUsernameFromContext(c)

	assert.True(t, ok)
	assert.Equal(t, "testuser", username)
}

// TestGetUsernameFromContext_NotExists checks the result when "username" is
// absent from the context.
func TestGetUsernameFromContext_NotExists(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	username, ok := getUsernameFromContext(c)

	assert.False(t, ok)
	assert.Equal(t, "", username)
}

// TestGetUsernameFromContext_WrongType checks that a non-string value stored
// under "username" is rejected.
func TestGetUsernameFromContext_WrongType(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Set("username", 12345)

	username, ok := getUsernameFromContext(c)

	assert.False(t, ok)
	assert.Equal(t, "", username)
}

// mockResponseWriter is a minimal http.ResponseWriter for tests.
// Unlike httptest.ResponseRecorder, whose Code defaults to 200, code stays 0
// until WriteHeader is called. This lets a test detect that no status was
// written at all.
type mockResponseWriter struct {
	code   int
	data   []byte
	header http.Header // created lazily so headers set by handlers persist
}

var _ http.ResponseWriter = (*mockResponseWriter)(nil)

// Header returns the writer's header map, creating it on first use. Returning
// the same map on every call lets header writes persist across calls.
func (m *mockResponseWriter) Header() http.Header {
	if m.header == nil {
		m.header = make(http.Header)
	}
	return m.header
}

// Write appends data to the captured body and reports it as fully written.
func (m *mockResponseWriter) Write(data []byte) (int, error) {
	m.data = append(m.data, data...)
	return len(data), nil
}

// WriteHeader records the status code passed by the caller.
func (m *mockResponseWriter) WriteHeader(code int) {
	m.code = code
}