- Add handleError tests for ApplicationError types - Add classifyErrorMessage tests for error message classification - Add contains helper function tests - Add getUserIDFromContext/getUsernameFromContext tests - Cover error classification for both EN and CN error messages
271 lines
7.4 KiB
Go
271 lines
7.4 KiB
Go
package handler
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
apierrors "github.com/user-management-system/internal/pkg/errors"
|
|
)
|
|
|
|
// TestHandleError_Nil verifies that handleError does not touch the response
// when it is given a nil error.
func TestHandleError_Nil(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := new(mockResponseWriter)
	gctx, _ := gin.CreateTestContext(rec)

	handleError(gctx, nil)

	// A zero code means no status was ever written to the underlying writer.
	assert.Equal(t, 0, rec.code)
}
|
|
|
|
// TestHandleError_ApplicationError verifies that handleError writes the HTTP
// status matching each apierrors constructor.
//
// Only the HTTP status is asserted; the response body format is not checked
// here, so the table carries no expected body fields.
func TestHandleError_ApplicationError(t *testing.T) {
	gin.SetMode(gin.TestMode)

	tests := []struct {
		name       string
		err        error
		wantStatus int
	}{
		{
			name:       "bad request error",
			err:        apierrors.BadRequest("invalid", "invalid input"),
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "not found error",
			err:        apierrors.NotFound("user", "user not found"),
			wantStatus: http.StatusNotFound,
		},
		{
			name:       "unauthorized error",
			err:        apierrors.Unauthorized("token", "invalid token"),
			wantStatus: http.StatusUnauthorized,
		},
		{
			name:       "forbidden error",
			err:        apierrors.Forbidden("permission", "permission denied"),
			wantStatus: http.StatusForbidden,
		},
		{
			name:       "conflict error",
			err:        apierrors.Conflict("user", "user already exists"),
			wantStatus: http.StatusConflict,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := &mockResponseWriter{}
			c, _ := gin.CreateTestContext(w)

			handleError(c, tt.err)

			assert.Equal(t, tt.wantStatus, w.code, "status written for %v", tt.err)
		})
	}
}
|
|
|
|
// TestClassifyErrorMessage checks the HTTP status that classifyErrorMessage
// picks for English and Chinese error messages, including the fallback to 500
// for messages that match no keyword.
func TestClassifyErrorMessage(t *testing.T) {
	tests := []struct {
		name string
		msg  string
		want int
	}{
		// Not found
		{name: "not found EN", msg: "user not found", want: http.StatusNotFound},
		{name: "not found CN", msg: "用户不存在", want: http.StatusNotFound},
		{name: "not found CN2", msg: "找不到资源", want: http.StatusNotFound},

		// Conflict
		{name: "already exists EN", msg: "user already exists", want: http.StatusConflict},
		{name: "already exists CN", msg: "用户已存在", want: http.StatusConflict},
		{name: "duplicate", msg: "duplicate entry", want: http.StatusConflict},

		// Unauthorized
		{name: "unauthorized EN", msg: "unauthorized", want: http.StatusUnauthorized},
		{name: "invalid token", msg: "invalid token", want: http.StatusUnauthorized},
		{name: "token", msg: "token expired", want: http.StatusUnauthorized},
		{name: "unauthorized CN", msg: "令牌无效", want: http.StatusUnauthorized},
		// Precedence check: the message contains both "token" and "expired";
		// the "token" rule must win, so the result is 401 rather than 400.
		{name: "token takes precedence over expired", msg: "token expired", want: http.StatusUnauthorized},

		// Forbidden
		{name: "forbidden EN", msg: "forbidden", want: http.StatusForbidden},
		{name: "permission", msg: "no permission", want: http.StatusForbidden},
		{name: "forbidden CN", msg: "权限不足", want: http.StatusForbidden},

		// Bad request
		{name: "invalid", msg: "invalid input", want: http.StatusBadRequest},
		{name: "required", msg: "field is required", want: http.StatusBadRequest},
		{name: "cannot be empty", msg: "name cannot be empty", want: http.StatusBadRequest},
		{name: "cannot be empty CN", msg: "名称不能为空", want: http.StatusBadRequest},
		{name: "incorrect password", msg: "密码不正确", want: http.StatusBadRequest},

		// Rate limit
		{name: "locked", msg: "account locked", want: http.StatusTooManyRequests},
		{name: "too many", msg: "too many attempts", want: http.StatusTooManyRequests},
		{name: "rate limit", msg: "rate limit exceeded", want: http.StatusTooManyRequests},

		// Internal server error (default when nothing matches)
		{name: "unknown error", msg: "unknown error occurred", want: http.StatusInternalServerError},
		{name: "database error", msg: "database connection failed", want: http.StatusInternalServerError},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := classifyErrorMessage(tt.msg)
			assert.Equal(t, tt.want, got, "classifyErrorMessage(%q)", tt.msg)
		})
	}
}
|
|
|
|
// TestContains 测试 contains 辅助函数
|
|
func TestContains(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
s string
|
|
keywords []string
|
|
want bool
|
|
}{
|
|
{
|
|
name: "match first",
|
|
s: "hello world",
|
|
keywords: []string{"hello", "foo"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "match second",
|
|
s: "hello world",
|
|
keywords: []string{"foo", "world"},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "no match",
|
|
s: "hello world",
|
|
keywords: []string{"foo", "bar"},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "empty keywords",
|
|
s: "hello world",
|
|
keywords: []string{},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "empty string",
|
|
s: "",
|
|
keywords: []string{"hello"},
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := contains(tt.s, tt.keywords...)
|
|
assert.Equal(t, tt.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGetUserIDFromContext_Success checks that an int64 stored under
// "user_id" is returned together with ok == true.
func TestGetUserIDFromContext_Success(t *testing.T) {
	gin.SetMode(gin.TestMode)
	gctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	gctx.Set("user_id", int64(123))

	got, found := getUserIDFromContext(gctx)

	assert.True(t, found)
	assert.Equal(t, int64(123), got)
}
|
|
|
|
// TestGetUserIDFromContext_NotExists checks that a context without a
// "user_id" key yields the zero ID and ok == false.
func TestGetUserIDFromContext_NotExists(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	gctx, _ := gin.CreateTestContext(recorder)

	id, present := getUserIDFromContext(gctx)

	assert.False(t, present)
	assert.Equal(t, int64(0), id)
}
|
|
|
|
// TestGetUserIDFromContext_WrongType checks that a non-int64 value stored
// under "user_id" is rejected: zero ID and ok == false.
func TestGetUserIDFromContext_WrongType(t *testing.T) {
	gin.SetMode(gin.TestMode)
	gctx, _ := gin.CreateTestContext(httptest.NewRecorder())

	// Store a string where an int64 is expected.
	gctx.Set("user_id", "not an int64")

	id, ok := getUserIDFromContext(gctx)
	assert.False(t, ok)
	assert.Equal(t, int64(0), id)
}
|
|
|
|
// TestGetUsernameFromContext_Success checks that a string stored under
// "username" is returned together with ok == true.
func TestGetUsernameFromContext_Success(t *testing.T) {
	gin.SetMode(gin.TestMode)
	gctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	gctx.Set("username", "testuser")

	name, found := getUsernameFromContext(gctx)

	assert.True(t, found)
	assert.Equal(t, "testuser", name)
}
|
|
|
|
// TestGetUsernameFromContext_NotExists checks that a context without a
// "username" key yields the empty string and ok == false.
func TestGetUsernameFromContext_NotExists(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	gctx, _ := gin.CreateTestContext(recorder)

	name, present := getUsernameFromContext(gctx)

	assert.False(t, present)
	assert.Equal(t, "", name)
}
|
|
|
|
// TestGetUsernameFromContext_WrongType checks that a non-string value stored
// under "username" is rejected: empty string and ok == false.
func TestGetUsernameFromContext_WrongType(t *testing.T) {
	gin.SetMode(gin.TestMode)
	gctx, _ := gin.CreateTestContext(httptest.NewRecorder())

	// Store an int where a string is expected.
	gctx.Set("username", 12345)

	name, ok := getUsernameFromContext(gctx)
	assert.False(t, ok)
	assert.Equal(t, "", name)
}
|
|
|
|
// mockResponseWriter 用于测试的 mock response writer
|
|
type mockResponseWriter struct {
|
|
code int
|
|
data []byte
|
|
}
|
|
|
|
func (m *mockResponseWriter) Header() http.Header {
|
|
return http.Header{}
|
|
}
|
|
|
|
func (m *mockResponseWriter) Write(data []byte) (int, error) {
|
|
m.data = append(m.data, data...)
|
|
return len(data), nil
|
|
}
|
|
|
|
func (m *mockResponseWriter) WriteHeader(code int) {
|
|
m.code = code
|
|
}
|