test: 补齐 handler/repository/domain 层单元测试
This commit is contained in:
297
internal/api/handler/auth_handler_unit_test.go
Normal file
297
internal/api/handler/auth_handler_unit_test.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestAuthHandler_SupportFlags(t *testing.T) {
|
||||
var nilHandler *AuthHandler
|
||||
if nilHandler.SupportsPasswordReset() {
|
||||
t.Fatal("nil handler should not support password reset")
|
||||
}
|
||||
|
||||
handler := &AuthHandler{}
|
||||
if handler.SupportsPasswordReset() {
|
||||
t.Fatal("password reset should be disabled by default")
|
||||
}
|
||||
|
||||
handler.SetPasswordResetEnabled(true)
|
||||
if !handler.SupportsPasswordReset() {
|
||||
t.Fatal("password reset flag should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserIDFromContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/userinfo", nil)
|
||||
|
||||
if _, ok := getUserIDFromContext(c); ok {
|
||||
t.Fatal("expected missing user_id to return false")
|
||||
}
|
||||
|
||||
c.Set("user_id", "1")
|
||||
if _, ok := getUserIDFromContext(c); ok {
|
||||
t.Fatal("expected non-int64 user_id to return false")
|
||||
}
|
||||
|
||||
c.Set("user_id", int64(42))
|
||||
if got, ok := getUserIDFromContext(c); !ok || got != 42 {
|
||||
t.Fatalf("getUserIDFromContext() = (%d, %v), want (42, true)", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestUsesHTTPS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
if requestUsesHTTPS(nil) {
|
||||
t.Fatal("nil context should not use https")
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||
if requestUsesHTTPS(c) {
|
||||
t.Fatal("plain http request should not use https")
|
||||
}
|
||||
|
||||
c.Request.Header.Set("X-Forwarded-Proto", "https")
|
||||
if !requestUsesHTTPS(c) {
|
||||
t.Fatal("forwarded https request should be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCookies_SetAndClear(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||
|
||||
setSessionCookies(c, nil, "")
|
||||
if len(recorder.Header().Values("Set-Cookie")) != 0 {
|
||||
t.Fatal("empty refresh token should not set cookies")
|
||||
}
|
||||
|
||||
setSessionCookies(c, nil, "refresh-token")
|
||||
setCookies := recorder.Header().Values("Set-Cookie")
|
||||
if len(setCookies) < 2 {
|
||||
t.Fatalf("expected session cookies to be set, got %d", len(setCookies))
|
||||
}
|
||||
if !strings.Contains(setCookies[0], refreshTokenCookieName+"=refresh-token") &&
|
||||
!strings.Contains(setCookies[1], refreshTokenCookieName+"=refresh-token") {
|
||||
t.Fatalf("expected refresh token cookie, got %#v", setCookies)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
c, _ = gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||
clearSessionCookies(c)
|
||||
setCookies = recorder.Header().Values("Set-Cookie")
|
||||
if len(setCookies) < 2 {
|
||||
t.Fatalf("expected clearing cookies to emit expired cookies, got %d", len(setCookies))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyErrorMessage(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{name: "not found", msg: "user not found", want: http.StatusNotFound},
|
||||
{name: "duplicate", msg: "already exists", want: http.StatusConflict},
|
||||
{name: "verification code", msg: "验证码错误", want: http.StatusUnauthorized},
|
||||
{name: "unauthorized", msg: "invalid token", want: http.StatusUnauthorized},
|
||||
{name: "forbidden", msg: "permission denied", want: http.StatusForbidden},
|
||||
{name: "bad request", msg: "invalid payload", want: http.StatusBadRequest},
|
||||
{name: "rate limit", msg: "too many attempts", want: http.StatusTooManyRequests},
|
||||
{name: "fallback", msg: "unexpected boom", want: http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := classifyErrorMessage(tc.msg); got != tc.want {
|
||||
t.Fatalf("classifyErrorMessage(%q) = %d, want %d", tc.msg, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_OAuthFallbackEndpoints(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
run func(*gin.Context)
|
||||
}{
|
||||
{
|
||||
name: "oauth login",
|
||||
run: func(c *gin.Context) {
|
||||
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||||
h.OAuthLogin(c)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "oauth callback",
|
||||
run: func(c *gin.Context) {
|
||||
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||||
h.OAuthCallback(c)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "oauth exchange",
|
||||
run: func(c *gin.Context) {
|
||||
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||||
h.OAuthExchange(c)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "oauth providers",
|
||||
run: func(c *gin.Context) {
|
||||
h.GetEnabledOAuthProviders(c)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||
tc.run(c)
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_RefreshToken_InvalidJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString("{"))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.RefreshToken(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_ActivateEmail_MissingToken(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/activate-email", bytes.NewBufferString(`{}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.ActivateEmail(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_ResendActivationEmail_InvalidEmail(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/resend-activation-email", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.ResendActivationEmail(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_SendEmailCode_InvalidEmail(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/send-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.SendEmailCode(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_LoginByEmailCode_InvalidPayload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/login-by-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.LoginByEmailCode(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_BootstrapAdmin_HeaderFailures(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
original := os.Getenv("BOOTSTRAP_SECRET")
|
||||
if err := os.Setenv("BOOTSTRAP_SECRET", "expected-secret"); err != nil {
|
||||
t.Fatalf("set env failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.Setenv("BOOTSTRAP_SECRET", original)
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
secret string
|
||||
want int
|
||||
}{
|
||||
{name: "missing header", secret: "", want: http.StatusUnauthorized},
|
||||
{name: "wrong header", secret: "wrong-secret", want: http.StatusUnauthorized},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/bootstrap-admin", bytes.NewBufferString(`{"username":"admin","email":"admin@example.com","password":"AdminPass123!"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
if tc.secret != "" {
|
||||
c.Request.Header.Set("X-Bootstrap-Secret", tc.secret)
|
||||
}
|
||||
|
||||
h.BootstrapAdmin(c)
|
||||
|
||||
if recorder.Code != tc.want {
|
||||
t.Fatalf("expected %d, got %d", tc.want, recorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user