Files
user-system/internal/e2e/e2e_test.go

422 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package e2e
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/api/handler"
"github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/api/router"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/config"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/security"
"github.com/user-management-system/internal/service"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
"github.com/user-management-system/internal/domain"
)
// dbCounter hands out a unique suffix per setupRealServer call so every test
// gets its own named in-memory SQLite database.
var dbCounter int64

// setupRealServer wires the real repositories, services, handlers, middleware
// and router on top of a fresh in-memory SQLite database and serves the
// resulting gin engine from an httptest.Server.
//
// The test is skipped when SQLite cannot be opened and failed when schema
// migration fails. The returned cleanup func shuts the HTTP server down first
// and then closes the database; callers should defer it.
func setupRealServer(t *testing.T) (*httptest.Server, func()) {
	t.Helper()
	gin.SetMode(gin.TestMode)

	// A unique database name combined with cache=shared lets all pooled
	// connections of this server see the same in-memory database while
	// keeping it isolated from other tests.
	id := atomic.AddInt64(&dbCounter, 1)
	dsn := fmt.Sprintf("file:e2edb_%d_%s?mode=memory&cache=shared", id, t.Name())
	db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
		DriverName: "sqlite",
		DSN:        dsn,
	}), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Skipf("跳过 E2E 测试(SQLite 不可用): %v", err)
	}
	// Grab the underlying *sql.DB up front so every exit path can close it;
	// the previous cleanup ignored this error and would nil-deref on failure.
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("获取 sql.DB 失败: %v", err)
	}
	if err := db.AutoMigrate(
		&domain.User{},
		&domain.Role{},
		&domain.Permission{},
		&domain.UserRole{},
		&domain.RolePermission{},
		&domain.Device{},
		&domain.LoginLog{},
		&domain.OperationLog{},
		&domain.SocialAccount{},
		&domain.Webhook{},
		&domain.WebhookDelivery{},
	); err != nil {
		_ = sqlDB.Close() // best effort: the test is failing anyway
		t.Fatalf("数据库迁移失败: %v", err)
	}

	// Infrastructure: JWT signing plus a two-level cache. NOTE(review):
	// NewRedisCache(false) presumably disables the Redis tier — confirm.
	jwtManager := auth.NewJWT("test-secret-key-for-e2e", 15*time.Minute, 7*24*time.Hour)
	l1Cache := cache.NewL1Cache()
	l2Cache := cache.NewRedisCache(false)
	cacheManager := cache.NewCacheManager(l1Cache, l2Cache)

	// Repositories, all backed by the same database.
	userRepo := repository.NewUserRepository(db)
	roleRepo := repository.NewRoleRepository(db)
	permissionRepo := repository.NewPermissionRepository(db)
	userRoleRepo := repository.NewUserRoleRepository(db)
	rolePermissionRepo := repository.NewRolePermissionRepository(db)
	deviceRepo := repository.NewDeviceRepository(db)
	loginLogRepo := repository.NewLoginLogRepository(db)
	operationLogRepo := repository.NewOperationLogRepository(db)
	passwordHistoryRepo := repository.NewPasswordHistoryRepository(db)

	// Services.
	authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 6, 5, 15*time.Minute)
	authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
	smsCodeSvc := service.NewSMSCodeService(&service.MockSMSProvider{}, cacheManager, service.DefaultSMSCodeConfig())
	authSvc.SetSMSCodeService(smsCodeSvc)
	userSvc := service.NewUserService(userRepo, userRoleRepo, roleRepo, passwordHistoryRepo)
	roleSvc := service.NewRoleService(roleRepo, rolePermissionRepo)
	permSvc := service.NewPermissionService(permissionRepo)
	deviceSvc := service.NewDeviceService(deviceRepo, userRepo)
	loginLogSvc := service.NewLoginLogService(loginLogRepo)
	opLogSvc := service.NewOperationLogService(operationLogRepo)
	pwdResetCfg := &service.PasswordResetConfig{
		TokenTTL: 15 * time.Minute,
		SiteURL:  "http://localhost",
	}
	pwdResetSvc := service.NewPasswordResetService(userRepo, cacheManager, pwdResetCfg)
	captchaSvc := service.NewCaptchaService(cacheManager)
	totpSvc := service.NewTOTPService(userRepo)
	webhookSvc := service.NewWebhookService(db)

	// HTTP handlers.
	authH := handler.NewAuthHandler(authSvc)
	userH := handler.NewUserHandler(userSvc)
	roleH := handler.NewRoleHandler(roleSvc)
	permH := handler.NewPermissionHandler(permSvc)
	deviceH := handler.NewDeviceHandler(deviceSvc)
	logH := handler.NewLogHandler(loginLogSvc, opLogSvc)
	pwdResetH := handler.NewPasswordResetHandler(pwdResetSvc)
	captchaH := handler.NewCaptchaHandler(captchaSvc)
	totpH := handler.NewTOTPHandler(authSvc, totpSvc)
	webhookH := handler.NewWebhookHandler(webhookSvc)
	smsH := handler.NewSMSHandler()

	// Middleware.
	rateLimitMW := middleware.NewRateLimitMiddleware(config.RateLimitConfig{})
	authMW := middleware.NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo)
	authMW.SetCacheManager(cacheManager)
	opLogMW := middleware.NewOperationLogMiddleware(operationLogRepo)
	ipFilterMW := middleware.NewIPFilterMiddleware(security.NewIPFilter(), middleware.IPFilterConfig{})

	// The nil arguments are router dependencies these tests do not provide.
	r := router.NewRouter(
		authH, userH, roleH, permH, deviceH, logH,
		authMW, rateLimitMW, opLogMW,
		pwdResetH, captchaH, totpH, webhookH,
		ipFilterMW, nil, nil, smsH, nil, nil, nil,
	)
	engine := r.Setup()
	srv := httptest.NewServer(engine)

	cleanup := func() {
		srv.Close()
		if err := sqlDB.Close(); err != nil {
			t.Logf("关闭数据库失败: %v", err)
		}
	}
	return srv, cleanup
}
// TestE2ERegisterAndLogin exercises the full happy path against a real
// server: register -> login -> fetch user info -> logout.
//
// Every response body is closed via defer (closing an already-drained
// http.Response body is a harmless no-op), so no connection leaks even when
// the test aborts early with t.Fatal.
func TestE2ERegisterAndLogin(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	// 1. Register.
	regBody := map[string]interface{}{
		"username": "e2e_user1",
		"password": "E2ePass123!",
		"email":    "e2euser1@example.com",
	}
	regResp := doPost(t, base+"/api/v1/auth/register", nil, regBody)
	defer regResp.Body.Close()
	if regResp.StatusCode != http.StatusCreated {
		t.Fatalf("注册失败HTTP %d", regResp.StatusCode)
	}
	var regResult map[string]interface{}
	decodeJSON(t, regResp.Body, &regResult)
	if regResult["username"] == nil {
		t.Fatalf("注册响应缺少 username 字段")
	}
	t.Logf("注册成功: %v", regResult)

	// 2. Log in with the new credentials.
	loginBody := map[string]interface{}{
		"account":  "e2e_user1",
		"password": "E2ePass123!",
	}
	loginResp := doPost(t, base+"/api/v1/auth/login", nil, loginBody)
	defer loginResp.Body.Close()
	if loginResp.StatusCode != http.StatusOK {
		t.Fatalf("登录失败HTTP %d", loginResp.StatusCode)
	}
	var loginResult map[string]interface{}
	decodeJSON(t, loginResp.Body, &loginResult)
	if loginResult["access_token"] == nil {
		t.Fatal("登录响应中缺少 access_token")
	}
	token := fmt.Sprintf("%v", loginResult["access_token"])
	t.Logf("登录成功access_token 长度=%d", len(token))

	// 3. Fetch the current user's info with the access token.
	infoResp := doGet(t, base+"/api/v1/auth/userinfo", token)
	defer infoResp.Body.Close()
	if infoResp.StatusCode != http.StatusOK {
		t.Fatalf("获取用户信息失败HTTP %d", infoResp.StatusCode)
	}
	var infoResult map[string]interface{}
	decodeJSON(t, infoResp.Body, &infoResult)
	if infoResult["username"] == nil {
		t.Fatal("用户信息响应缺少 username 字段")
	}
	t.Logf("用户信息获取成功: %v", infoResult)

	// 4. Log out. The body is not inspected but must still be closed.
	logoutResp := doPost(t, base+"/api/v1/auth/logout", token, nil)
	defer logoutResp.Body.Close()
	if logoutResp.StatusCode != http.StatusOK {
		t.Fatalf("登出失败HTTP %d", logoutResp.StatusCode)
	}
	t.Log("登出成功")
}
// TestE2ELoginFailures verifies that logging in with a wrong password or with
// an unknown account is rejected.
func TestE2ELoginFailures(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	// Register the account first and make sure that worked; otherwise the
	// wrong-password case below could be rejected for the wrong reason.
	regResp := doPost(t, base+"/api/v1/auth/register", nil, map[string]interface{}{
		"username": "fail_user",
		"password": "CorrectPass1!",
		"email":    "failuser@example.com",
	})
	regResp.Body.Close()
	if regResp.StatusCode != http.StatusCreated {
		t.Fatalf("注册失败HTTP %d", regResp.StatusCode)
	}

	// Wrong password. The exact rejection status (e.g. 401 or 500) depends on
	// the implementation, so only "not 200" is asserted.
	loginResp := doPost(t, base+"/api/v1/auth/login", nil, map[string]interface{}{
		"account":  "fail_user",
		"password": "WrongPassword",
	})
	loginResp.Body.Close()
	if loginResp.StatusCode == http.StatusOK {
		t.Fatal("错误密码登录不应该成功")
	}
	t.Logf("错误密码正确拒绝: HTTP %d", loginResp.StatusCode)

	// Account that does not exist.
	notFoundResp := doPost(t, base+"/api/v1/auth/login", nil, map[string]interface{}{
		"account":  "nonexistent_user_xyz",
		"password": "SomePass1!",
	})
	notFoundResp.Body.Close()
	if notFoundResp.StatusCode == http.StatusOK {
		t.Fatal("不存在的用户登录不应该成功")
	}
	t.Logf("不存在用户正确拒绝: HTTP %d", notFoundResp.StatusCode)
}
// TestE2EUnauthorizedAccess checks that a JWT-protected endpoint answers 401
// both when no token is sent and when the token is malformed.
func TestE2EUnauthorizedAccess(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	// No Authorization header at all.
	resp := doGet(t, base+"/api/v1/auth/userinfo", "")
	resp.Body.Close()
	if resp.StatusCode != http.StatusUnauthorized {
		t.Fatalf("期望 401实际 %d", resp.StatusCode)
	}
	t.Log("未认证访问正确返回 401")

	// Syntactically bogus bearer token.
	resp2 := doGet(t, base+"/api/v1/auth/userinfo", "invalid.token.here")
	resp2.Body.Close()
	if resp2.StatusCode != http.StatusUnauthorized {
		t.Fatalf("无效 token 期望 401实际 %d", resp2.StatusCode)
	}
	t.Log("无效 token 正确返回 401")
}
// TestE2EPasswordReset registers a user and requests a password reset for
// that user's email, expecting 200.
func TestE2EPasswordReset(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	// The account must really exist; otherwise forgot-password would not be
	// exercised for a known email.
	regResp := doPost(t, base+"/api/v1/auth/register", nil, map[string]interface{}{
		"username": "reset_user",
		"password": "OldPass123!",
		"email":    "resetuser@example.com",
	})
	regResp.Body.Close()
	if regResp.StatusCode != http.StatusCreated {
		t.Fatalf("注册失败HTTP %d", regResp.StatusCode)
	}

	resp := doPost(t, base+"/api/v1/auth/forgot-password", nil, map[string]interface{}{
		"email": "resetuser@example.com",
	})
	resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("forgot-password 期望 200实际 %d", resp.StatusCode)
	}
	t.Log("密码重置请求正确返回 200")
}
// TestE2ECaptcha requests a new captcha, then fetches its image by the
// returned captcha_id.
func TestE2ECaptcha(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	resp := doGet(t, base+"/api/v1/auth/captcha", "")
	defer resp.Body.Close() // also covers the t.Fatalf path below
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("获取验证码期望 200实际 %d", resp.StatusCode)
	}
	var result map[string]interface{}
	decodeJSON(t, resp.Body, &result)
	if result["captcha_id"] == nil {
		t.Fatal("验证码响应缺少 captcha_id")
	}
	captchaID := fmt.Sprintf("%v", result["captcha_id"])
	t.Logf("验证码生成成功captcha_id=%s", captchaID)

	// The image body is not inspected, but it must be closed so the
	// connection is released.
	imgResp := doGet(t, base+"/api/v1/auth/captcha/image?captcha_id="+captchaID, "")
	defer imgResp.Body.Close()
	if imgResp.StatusCode != http.StatusOK {
		t.Fatalf("获取验证码图片失败HTTP %d", imgResp.StatusCode)
	}
	t.Log("验证码图片获取成功")
}
// TestE2EConcurrentLogin fires concurrent logins for a single account and
// asserts that none of them ends in a 5xx. Rate-limited or otherwise rejected
// attempts are tolerated. Skipped in -short mode.
func TestE2EConcurrentLogin(t *testing.T) {
	if testing.Short() {
		t.Skip("skip concurrent test in short mode")
	}
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	regResp := doPost(t, base+"/api/v1/auth/register", nil, map[string]interface{}{
		"username": "concurrent_user",
		"password": "ConcPass123!",
		"email":    "concurrent@example.com",
	})
	regResp.Body.Close()
	if regResp.StatusCode != http.StatusCreated {
		t.Fatalf("注册失败HTTP %d", regResp.StatusCode)
	}

	const concurrency = 20

	payload, err := json.Marshal(map[string]interface{}{
		"account":  "concurrent_user",
		"password": "ConcPass123!",
	})
	if err != nil {
		t.Fatalf("序列化登录请求失败: %v", err)
	}
	loginURL := base + "/api/v1/auth/login"
	client := &http.Client{Timeout: 10 * time.Second} // safe for concurrent use

	// login performs one attempt WITHOUT calling t.Fatal*. FailNow may only be
	// called from the test goroutine; calling it from a worker exits just that
	// worker, which then never reports back and blocks the collector below
	// until the test times out. Transport errors are returned instead.
	login := func() (int, bool, error) {
		req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, loginURL, bytes.NewReader(payload))
		if err != nil {
			return 0, false, err
		}
		req.Header.Set("Content-Type", "application/json")
		resp, err := client.Do(req)
		if err != nil {
			return 0, false, err
		}
		var body map[string]interface{}
		decodeJSON(t, resp.Body, &body) // closes the body; only logs via t
		return resp.StatusCode, resp.StatusCode == http.StatusOK && body["access_token"] != nil, nil
	}

	type result struct {
		success bool
		latency time.Duration
		status  int
		err     error
	}
	results := make(chan result, concurrency)
	start := time.Now()
	for i := 0; i < concurrency; i++ {
		go func() {
			t0 := time.Now()
			status, ok, err := login()
			results <- result{success: ok, latency: time.Since(t0), status: status, err: err}
		}()
	}

	success, fail := 0, 0
	var totalLatency time.Duration
	var reqErrs []error
	statusCount := make(map[int]int)
	for i := 0; i < concurrency; i++ {
		r := <-results
		totalLatency += r.latency
		if r.err != nil {
			fail++
			reqErrs = append(reqErrs, r.err)
			continue
		}
		if r.success {
			success++
		} else {
			fail++
		}
		statusCount[r.status]++
	}
	elapsed := time.Since(start)
	t.Logf("并发登录结果: 成功=%d 失败=%d 状态码分布=%v 总耗时=%v 平均=%v",
		success, fail, statusCount, elapsed, totalLatency/time.Duration(concurrency))

	// Fail only now, from the test goroutine, after every worker reported.
	if len(reqErrs) > 0 {
		t.Fatalf("并发登录请求失败 %d 次, 首个错误: %v", len(reqErrs), reqErrs[0])
	}
	for status, count := range statusCount {
		if status >= http.StatusInternalServerError {
			t.Fatalf("并发登录不应出现 5xx实际 status=%d count=%d", status, count)
		}
	}
	if success == 0 {
		t.Log("所有并发登录请求都被限流或拒绝;在当前路由限流配置下这属于可接受结果")
	}
}
// ---- HTTP helpers ----

// doPost sends a JSON POST request to url and returns the raw response. The
// caller owns resp.Body and must close it.
//
// body may be nil, in which case an empty request body is sent. token is only
// used when it is a non-empty string; nil or any other type sends no
// Authorization header. Marshalling, request-construction and transport
// failures abort the test via t.Fatalf, so doPost must only be called from the
// goroutine running the test.
func doPost(t *testing.T, url string, token interface{}, body map[string]interface{}) *http.Response {
	t.Helper()
	var bodyBytes []byte
	if body != nil {
		encoded, err := json.Marshal(body)
		if err != nil {
			// Previously discarded, which silently sent an empty body.
			t.Fatalf("序列化请求体失败: %v", err)
		}
		bodyBytes = encoded
	}
	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewBuffer(bodyBytes))
	if err != nil {
		t.Fatalf("创建请求失败: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if tok, ok := token.(string); ok && tok != "" {
		req.Header.Set("Authorization", "Bearer "+tok)
	}
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("请求失败: %v", err)
	}
	return resp
}
// doGet sends a GET request to url and returns the raw response, which the
// caller must close. An empty token means the request carries no
// Authorization header; otherwise "Bearer <token>" is sent. Any failure aborts
// the test via t.Fatalf, so call it only from the test goroutine.
func doGet(t *testing.T, url string, token string) *http.Response {
	t.Helper()

	req, buildErr := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
	if buildErr != nil {
		t.Fatalf("创建请求失败: %v", buildErr)
	}
	if len(token) > 0 {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	resp, sendErr := (&http.Client{Timeout: 10 * time.Second}).Do(req)
	if sendErr != nil {
		t.Fatalf("请求失败: %v", sendErr)
	}
	return resp
}
// decodeJSON decodes one JSON value from body into v and always closes body.
// A decode failure is only logged, never fatal, so v may be left unchanged or
// only partially populated; callers must check the fields they rely on.
func decodeJSON(t *testing.T, body io.ReadCloser, v interface{}) {
	t.Helper()
	defer body.Close()

	decodeErr := json.NewDecoder(body).Decode(v)
	if decodeErr == nil {
		return
	}
	t.Logf("解析响应 JSON 失败: %v非致命", decodeErr)
}