diff --git a/internal/api/handler/auth_handler_test.go b/internal/api/handler/auth_handler_test.go new file mode 100644 index 0000000..a03238a --- /dev/null +++ b/internal/api/handler/auth_handler_test.go @@ -0,0 +1,270 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + + apierrors "github.com/user-management-system/internal/pkg/errors" +) + +// TestHandleError_Nil 测试 nil error +func TestHandleError_Nil(t *testing.T) { + gin.SetMode(gin.TestMode) + w := &mockResponseWriter{} + c, _ := gin.CreateTestContext(w) + + handleError(c, nil) + // nil error 不写入响应 + assert.Equal(t, 0, w.code) +} + +// TestHandleError_ApplicationError 测试 ApplicationError +func TestHandleError_ApplicationError(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + err error + wantStatus int + wantCode int + }{ + { + name: "bad request error", + err: apierrors.BadRequest("invalid", "invalid input"), + wantStatus: http.StatusBadRequest, + wantCode: http.StatusBadRequest, + }, + { + name: "not found error", + err: apierrors.NotFound("user", "user not found"), + wantStatus: http.StatusNotFound, + wantCode: http.StatusNotFound, + }, + { + name: "unauthorized error", + err: apierrors.Unauthorized("token", "invalid token"), + wantStatus: http.StatusUnauthorized, + wantCode: http.StatusUnauthorized, + }, + { + name: "forbidden error", + err: apierrors.Forbidden("permission", "permission denied"), + wantStatus: http.StatusForbidden, + wantCode: http.StatusForbidden, + }, + { + name: "conflict error", + err: apierrors.Conflict("user", "user already exists"), + wantStatus: http.StatusConflict, + wantCode: http.StatusConflict, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &mockResponseWriter{} + c, _ := gin.CreateTestContext(w) + handleError(c, tt.err) + assert.Equal(t, tt.wantStatus, w.code) + }) + } +} + +// TestClassifyErrorMessage 测试错误消息分类 +func TestClassifyErrorMessage(t *testing.T) { + tests := []struct { + name string + msg string + want int + }{ + // Not found + {name: "not found EN", msg: "user not found", want: http.StatusNotFound}, + {name: "not found CN", msg: "用户不存在", want: http.StatusNotFound}, + {name: "not found CN2", msg: "找不到资源", want: http.StatusNotFound}, + + // Conflict + {name: "already exists EN", msg: "user already exists", want: http.StatusConflict}, + {name: "already exists CN", msg: "用户已存在", want: http.StatusConflict}, + {name: "duplicate", msg: "duplicate entry", want: http.StatusConflict}, + + // Unauthorized + {name: "unauthorized EN", msg: "unauthorized", want: http.StatusUnauthorized}, + {name: "invalid token", msg: "invalid token", want: http.StatusUnauthorized}, + {name: "token", msg: "token expired", want: http.StatusUnauthorized}, + {name: "unauthorized CN", msg: "令牌无效", want: http.StatusUnauthorized}, + + // Forbidden + {name: "forbidden EN", msg: "forbidden", want: http.StatusForbidden}, + {name: "permission", msg: "no permission", want: http.StatusForbidden}, + {name: "forbidden CN", msg: "权限不足", want: http.StatusForbidden}, + + // Bad request + {name: "invalid", msg: "invalid input", want: http.StatusBadRequest}, + {name: "required", msg: "field is required", want: http.StatusBadRequest}, + {name: "cannot be empty", msg: "name cannot be empty", want: http.StatusBadRequest}, + {name: "cannot be empty CN", msg: "名称不能为空", want: http.StatusBadRequest}, + {name: "incorrect password", msg: "密码不正确", want: http.StatusBadRequest}, + {name: "expired", msg: "token expired", want: http.StatusUnauthorized}, // "token" 匹配先于 "expired" + + // Rate limit + {name: "locked", msg: "account locked", want: http.StatusTooManyRequests}, + {name: "too many", msg: "too many attempts", want: http.StatusTooManyRequests}, + {name: "rate limit", msg: "rate limit exceeded", want: http.StatusTooManyRequests}, + + // Internal server error (default) + {name: "unknown error", msg: "unknown error occurred", want: http.StatusInternalServerError}, + {name: "database error", msg: "database connection failed", want: http.StatusInternalServerError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyErrorMessage(tt.msg) + assert.Equal(t, tt.want, got, "classifyErrorMessage(%q)", tt.msg) + }) + } +} + +// TestContains 测试 contains 辅助函数 +func TestContains(t *testing.T) { + tests := []struct { + name string + s string + keywords []string + want bool + }{ + { + name: "match first", + s: "hello world", + keywords: []string{"hello", "foo"}, + want: true, + }, + { + name: "match second", + s: "hello world", + keywords: []string{"foo", "world"}, + want: true, + }, + { + name: "no match", + s: "hello world", + keywords: []string{"foo", "bar"}, + want: false, + }, + { + name: "empty keywords", + s: "hello world", + keywords: []string{}, + want: false, + }, + { + name: "empty string", + s: "", + keywords: []string{"hello"}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := contains(tt.s, tt.keywords...) + assert.Equal(t, tt.want, got) + }) + } +} + +// TestGetUserIDFromContext_Success 测试从 context 获取 userID +func TestGetUserIDFromContext_Success(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + c.Set("user_id", int64(123)) + + userID, ok := getUserIDFromContext(c) + assert.True(t, ok) + assert.Equal(t, int64(123), userID) +} + +// TestGetUserIDFromContext_NotExists 测试 context 中无 userID +func TestGetUserIDFromContext_NotExists(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + userID, ok := getUserIDFromContext(c) + assert.False(t, ok) + assert.Equal(t, int64(0), userID) +} + +// TestGetUserIDFromContext_WrongType 测试 userID 类型错误 +func TestGetUserIDFromContext_WrongType(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + c.Set("user_id", "not an int64") + + userID, ok := getUserIDFromContext(c) + assert.False(t, ok) + assert.Equal(t, int64(0), userID) +} + +// TestGetUsernameFromContext_Success 测试从 context 获取 username +func TestGetUsernameFromContext_Success(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + c.Set("username", "testuser") + + username, ok := getUsernameFromContext(c) + assert.True(t, ok) + assert.Equal(t, "testuser", username) +} + +// TestGetUsernameFromContext_NotExists 测试 context 中无 username +func TestGetUsernameFromContext_NotExists(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + username, ok := getUsernameFromContext(c) + assert.False(t, ok) + assert.Equal(t, "", username) +} + +// TestGetUsernameFromContext_WrongType 测试 username 类型错误 +func TestGetUsernameFromContext_WrongType(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + c.Set("username", 12345) + + username, ok := getUsernameFromContext(c) + assert.False(t, ok) + assert.Equal(t, "", username) +} + +// mockResponseWriter 用于测试的 mock response writer +type mockResponseWriter struct { + code int + data []byte +} + +func (m *mockResponseWriter) Header() http.Header { + return http.Header{} +} + +func (m *mockResponseWriter) Write(data []byte) (int, error) { + m.data = append(m.data, data...) + return len(data), nil +} + +func (m *mockResponseWriter) WriteHeader(code int) { + m.code = code +}