test: add auth handler error classification tests

- Add handleError tests for ApplicationError types
- Add classifyErrorMessage tests for error message classification
- Add contains helper function tests
- Add getUserIDFromContext/getUsernameFromContext tests
- Cover error classification for both EN and CN error messages
This commit is contained in:
Your Name
2026-05-29 14:38:08 +08:00
parent 5d767abe72
commit f0930489f1

View File

@@ -0,0 +1,270 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
apierrors "github.com/user-management-system/internal/pkg/errors"
)
// TestHandleError_Nil 测试 nil error
func TestHandleError_Nil(t *testing.T) {
gin.SetMode(gin.TestMode)
w := &mockResponseWriter{}
c, _ := gin.CreateTestContext(w)
handleError(c, nil)
// nil error 不写入响应
assert.Equal(t, 0, w.code)
}
// TestHandleError_ApplicationError 测试 ApplicationError
func TestHandleError_ApplicationError(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
err error
wantStatus int
wantCode int
}{
{
name: "bad request error",
err: apierrors.BadRequest("invalid", "invalid input"),
wantStatus: http.StatusBadRequest,
wantCode: http.StatusBadRequest,
},
{
name: "not found error",
err: apierrors.NotFound("user", "user not found"),
wantStatus: http.StatusNotFound,
wantCode: http.StatusNotFound,
},
{
name: "unauthorized error",
err: apierrors.Unauthorized("token", "invalid token"),
wantStatus: http.StatusUnauthorized,
wantCode: http.StatusUnauthorized,
},
{
name: "forbidden error",
err: apierrors.Forbidden("permission", "permission denied"),
wantStatus: http.StatusForbidden,
wantCode: http.StatusForbidden,
},
{
name: "conflict error",
err: apierrors.Conflict("user", "user already exists"),
wantStatus: http.StatusConflict,
wantCode: http.StatusConflict,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &mockResponseWriter{}
c, _ := gin.CreateTestContext(w)
handleError(c, tt.err)
assert.Equal(t, tt.wantStatus, w.code)
})
}
}
// TestClassifyErrorMessage 测试错误消息分类
func TestClassifyErrorMessage(t *testing.T) {
tests := []struct {
name string
msg string
want int
}{
// Not found
{name: "not found EN", msg: "user not found", want: http.StatusNotFound},
{name: "not found CN", msg: "用户不存在", want: http.StatusNotFound},
{name: "not found CN2", msg: "找不到资源", want: http.StatusNotFound},
// Conflict
{name: "already exists EN", msg: "user already exists", want: http.StatusConflict},
{name: "already exists CN", msg: "用户已存在", want: http.StatusConflict},
{name: "duplicate", msg: "duplicate entry", want: http.StatusConflict},
// Unauthorized
{name: "unauthorized EN", msg: "unauthorized", want: http.StatusUnauthorized},
{name: "invalid token", msg: "invalid token", want: http.StatusUnauthorized},
{name: "token", msg: "token expired", want: http.StatusUnauthorized},
{name: "unauthorized CN", msg: "令牌无效", want: http.StatusUnauthorized},
// Forbidden
{name: "forbidden EN", msg: "forbidden", want: http.StatusForbidden},
{name: "permission", msg: "no permission", want: http.StatusForbidden},
{name: "forbidden CN", msg: "权限不足", want: http.StatusForbidden},
// Bad request
{name: "invalid", msg: "invalid input", want: http.StatusBadRequest},
{name: "required", msg: "field is required", want: http.StatusBadRequest},
{name: "cannot be empty", msg: "name cannot be empty", want: http.StatusBadRequest},
{name: "cannot be empty CN", msg: "名称不能为空", want: http.StatusBadRequest},
{name: "incorrect password", msg: "密码不正确", want: http.StatusBadRequest},
{name: "expired", msg: "token expired", want: http.StatusUnauthorized}, // "token" 匹配先于 "expired"
// Rate limit
{name: "locked", msg: "account locked", want: http.StatusTooManyRequests},
{name: "too many", msg: "too many attempts", want: http.StatusTooManyRequests},
{name: "rate limit", msg: "rate limit exceeded", want: http.StatusTooManyRequests},
// Internal server error (default)
{name: "unknown error", msg: "unknown error occurred", want: http.StatusInternalServerError},
{name: "database error", msg: "database connection failed", want: http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyErrorMessage(tt.msg)
assert.Equal(t, tt.want, got, "classifyErrorMessage(%q)", tt.msg)
})
}
}
// TestContains 测试 contains 辅助函数
func TestContains(t *testing.T) {
tests := []struct {
name string
s string
keywords []string
want bool
}{
{
name: "match first",
s: "hello world",
keywords: []string{"hello", "foo"},
want: true,
},
{
name: "match second",
s: "hello world",
keywords: []string{"foo", "world"},
want: true,
},
{
name: "no match",
s: "hello world",
keywords: []string{"foo", "bar"},
want: false,
},
{
name: "empty keywords",
s: "hello world",
keywords: []string{},
want: false,
},
{
name: "empty string",
s: "",
keywords: []string{"hello"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := contains(tt.s, tt.keywords...)
assert.Equal(t, tt.want, got)
})
}
}
// TestGetUserIDFromContext_Success 测试从 context 获取 userID
func TestGetUserIDFromContext_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Set("user_id", int64(123))
userID, ok := getUserIDFromContext(c)
assert.True(t, ok)
assert.Equal(t, int64(123), userID)
}
// TestGetUserIDFromContext_NotExists 测试 context 中无 userID
func TestGetUserIDFromContext_NotExists(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
userID, ok := getUserIDFromContext(c)
assert.False(t, ok)
assert.Equal(t, int64(0), userID)
}
// TestGetUserIDFromContext_WrongType 测试 userID 类型错误
func TestGetUserIDFromContext_WrongType(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Set("user_id", "not an int64")
userID, ok := getUserIDFromContext(c)
assert.False(t, ok)
assert.Equal(t, int64(0), userID)
}
// TestGetUsernameFromContext_Success 测试从 context 获取 username
func TestGetUsernameFromContext_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Set("username", "testuser")
username, ok := getUsernameFromContext(c)
assert.True(t, ok)
assert.Equal(t, "testuser", username)
}
// TestGetUsernameFromContext_NotExists 测试 context 中无 username
func TestGetUsernameFromContext_NotExists(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
username, ok := getUsernameFromContext(c)
assert.False(t, ok)
assert.Equal(t, "", username)
}
// TestGetUsernameFromContext_WrongType 测试 username 类型错误
func TestGetUsernameFromContext_WrongType(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Set("username", 12345)
username, ok := getUsernameFromContext(c)
assert.False(t, ok)
assert.Equal(t, "", username)
}
// mockResponseWriter 用于测试的 mock response writer
type mockResponseWriter struct {
code int
data []byte
}
func (m *mockResponseWriter) Header() http.Header {
return http.Header{}
}
func (m *mockResponseWriter) Write(data []byte) (int, error) {
m.data = append(m.data, data...)
return len(data), nil
}
func (m *mockResponseWriter) WriteHeader(code int) {
m.code = code
}