Files

422 lines
13 KiB
Go
Raw Permalink Normal View History

package e2e
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/api/handler"
"github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/api/router"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/config"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/security"
"github.com/user-management-system/internal/service"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
"github.com/user-management-system/internal/domain"
)
// dbCounter supplies a process-unique sequence number for every in-memory
// SQLite database created by setupRealServer. In this file it is only
// incremented via atomic.AddInt64.
var (
	dbCounter int64
)
// setupRealServer builds the complete application stack on top of a fresh
// in-memory SQLite database and serves it through an httptest.Server.
//
// The database is auto-migrated with the domain models the e2e tests need.
// Repositories, services, handlers and middleware are then wired together and
// handed to router.NewRouter. If SQLite cannot be opened the test is skipped,
// not failed.
//
// It returns the running server and a cleanup function that shuts the server
// down and closes the database. Callers should defer the cleanup.
func setupRealServer(t *testing.T) (*httptest.Server, func()) {
	t.Helper()
	gin.SetMode(gin.TestMode)

	// Each call gets its own named, shared-cache in-memory database. The atomic
	// counter alone makes the name unique within the test process. t.Name() is
	// deliberately not embedded: subtest names may contain '/', '#' or '?',
	// which would corrupt the file: URI and silently drop the query parameters
	// (mode=memory, cache=shared).
	id := atomic.AddInt64(&dbCounter, 1)
	dsn := fmt.Sprintf("file:e2edb_%d?mode=memory&cache=shared", id)
	db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
		DriverName: "sqlite",
		DSN:        dsn,
	}), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Skipf("跳过 E2E 测试SQLite 不可用): %v", err)
	}
	if err := db.AutoMigrate(
		&domain.User{},
		&domain.Role{},
		&domain.Permission{},
		&domain.UserRole{},
		&domain.RolePermission{},
		&domain.Device{},
		&domain.LoginLog{},
		&domain.OperationLog{},
		&domain.SocialAccount{},
		&domain.Webhook{},
		&domain.WebhookDelivery{},
	); err != nil {
		t.Fatalf("数据库迁移失败: %v", err)
	}

	// Core infrastructure: JWT issuer and a two-level cache. The Redis cache
	// is constructed with `false`; presumably that disables the real Redis
	// backend for tests (TODO confirm against cache.NewRedisCache).
	jwtManager := auth.NewJWT("test-secret-key-for-e2e", 15*time.Minute, 7*24*time.Hour)
	l1Cache := cache.NewL1Cache()
	l2Cache := cache.NewRedisCache(false)
	cacheManager := cache.NewCacheManager(l1Cache, l2Cache)

	// Repositories.
	userRepo := repository.NewUserRepository(db)
	roleRepo := repository.NewRoleRepository(db)
	permissionRepo := repository.NewPermissionRepository(db)
	userRoleRepo := repository.NewUserRoleRepository(db)
	rolePermissionRepo := repository.NewRolePermissionRepository(db)
	deviceRepo := repository.NewDeviceRepository(db)
	loginLogRepo := repository.NewLoginLogRepository(db)
	operationLogRepo := repository.NewOperationLogRepository(db)
	passwordHistoryRepo := repository.NewPasswordHistoryRepository(db)

	// Services.
	authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 6, 5, 15*time.Minute)
	authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
	smsCodeSvc := service.NewSMSCodeService(&service.MockSMSProvider{}, cacheManager, service.DefaultSMSCodeConfig())
	authSvc.SetSMSCodeService(smsCodeSvc)
	userSvc := service.NewUserService(userRepo, userRoleRepo, roleRepo, passwordHistoryRepo)
	roleSvc := service.NewRoleService(roleRepo, rolePermissionRepo)
	permSvc := service.NewPermissionService(permissionRepo)
	deviceSvc := service.NewDeviceService(deviceRepo, userRepo)
	loginLogSvc := service.NewLoginLogService(loginLogRepo)
	opLogSvc := service.NewOperationLogService(operationLogRepo)
	pwdResetCfg := &service.PasswordResetConfig{
		TokenTTL: 15 * time.Minute,
		SiteURL:  "http://localhost",
	}
	pwdResetSvc := service.NewPasswordResetService(userRepo, cacheManager, pwdResetCfg)
	captchaSvc := service.NewCaptchaService(cacheManager)
	totpSvc := service.NewTOTPService(userRepo)
	webhookSvc := service.NewWebhookService(db)

	// HTTP handlers.
	authH := handler.NewAuthHandler(authSvc)
	userH := handler.NewUserHandler(userSvc)
	roleH := handler.NewRoleHandler(roleSvc)
	permH := handler.NewPermissionHandler(permSvc)
	deviceH := handler.NewDeviceHandler(deviceSvc)
	logH := handler.NewLogHandler(loginLogSvc, opLogSvc)
	pwdResetH := handler.NewPasswordResetHandler(pwdResetSvc)
	captchaH := handler.NewCaptchaHandler(captchaSvc)
	totpH := handler.NewTOTPHandler(authSvc, totpSvc)
	webhookH := handler.NewWebhookHandler(webhookSvc)
	smsH := handler.NewSMSHandler()

	// Middleware.
	rateLimitMW := middleware.NewRateLimitMiddleware(config.RateLimitConfig{})
	authMW := middleware.NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo)
	authMW.SetCacheManager(cacheManager)
	opLogMW := middleware.NewOperationLogMiddleware(operationLogRepo)
	ipFilterMW := middleware.NewIPFilterMiddleware(security.NewIPFilter(), middleware.IPFilterConfig{})

	// The trailing nil arguments are optional router dependencies that these
	// tests do not exercise.
	r := router.NewRouter(
		authH, userH, roleH, permH, deviceH, logH,
		authMW, rateLimitMW, opLogMW,
		pwdResetH, captchaH, totpH, webhookH,
		ipFilterMW, nil, nil, smsH, nil, nil, nil,
	)
	engine := r.Setup()
	srv := httptest.NewServer(engine)

	cleanup := func() {
		srv.Close()
		// db.DB() can fail; never dereference a nil *sql.DB in cleanup.
		sqlDB, err := db.DB()
		if err != nil {
			t.Logf("获取底层数据库连接失败: %v", err)
			return
		}
		if err := sqlDB.Close(); err != nil {
			t.Logf("关闭数据库失败: %v", err)
		}
	}
	return srv, cleanup
}
// TestE2ERegisterAndLogin exercises the full account flow against a real
// server. It registers a user, logs in, fetches the current user's info with
// the issued access token, and finally logs out.
func TestE2ERegisterAndLogin(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	// 1. Register.
	regBody := map[string]interface{}{
		"username": "e2e_user1",
		"password": "E2ePass123!",
		"email":    "e2euser1@example.com",
	}
	regResp := doPost(t, base+"/api/v1/auth/register", nil, regBody)
	if regResp.StatusCode != http.StatusCreated {
		t.Fatalf("注册失败HTTP %d", regResp.StatusCode)
	}
	var regResult map[string]interface{}
	decodeJSON(t, regResp.Body, &regResult)
	if regResult["username"] == nil {
		t.Fatalf("注册响应缺少 username 字段")
	}
	t.Logf("注册成功: %v", regResult)

	// 2. Log in.
	loginBody := map[string]interface{}{
		"account":  "e2e_user1",
		"password": "E2ePass123!",
	}
	loginResp := doPost(t, base+"/api/v1/auth/login", nil, loginBody)
	if loginResp.StatusCode != http.StatusOK {
		t.Fatalf("登录失败HTTP %d", loginResp.StatusCode)
	}
	var loginResult map[string]interface{}
	decodeJSON(t, loginResp.Body, &loginResult)
	// Require a real, non-empty string. Formatting an arbitrary value with
	// fmt.Sprintf("%v") could turn a non-string into a bogus Bearer token.
	token, ok := loginResult["access_token"].(string)
	if !ok || token == "" {
		t.Fatal("登录响应中缺少 access_token")
	}
	t.Logf("登录成功access_token 长度=%d", len(token))

	// 3. Fetch the current user's info with the token.
	infoResp := doGet(t, base+"/api/v1/auth/userinfo", token)
	if infoResp.StatusCode != http.StatusOK {
		t.Fatalf("获取用户信息失败HTTP %d", infoResp.StatusCode)
	}
	var infoResult map[string]interface{}
	decodeJSON(t, infoResp.Body, &infoResult)
	if infoResult["username"] == nil {
		t.Fatal("用户信息响应缺少 username 字段")
	}
	t.Logf("用户信息获取成功: %v", infoResult)

	// 4. Log out. The body is not inspected, but it must still be closed so
	// the client connection is released.
	logoutResp := doPost(t, base+"/api/v1/auth/logout", token, nil)
	defer logoutResp.Body.Close()
	if logoutResp.StatusCode != http.StatusOK {
		t.Fatalf("登出失败HTTP %d", logoutResp.StatusCode)
	}
	t.Log("登出成功")
}
// TestE2ELoginFailures checks that logins with bad credentials are rejected.
// It covers both a wrong password for an existing user and an account that
// does not exist. Any non-200 status counts as a rejection.
func TestE2ELoginFailures(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	// Register the user first and make sure it exists. Otherwise the
	// wrong-password case below would pass vacuously.
	regResp := doPost(t, base+"/api/v1/auth/register", nil, map[string]interface{}{
		"username": "fail_user",
		"password": "CorrectPass1!",
		"email":    "failuser@example.com",
	})
	defer regResp.Body.Close()
	if regResp.StatusCode != http.StatusCreated {
		t.Fatalf("注册失败HTTP %d", regResp.StatusCode)
	}

	// Wrong password.
	loginResp := doPost(t, base+"/api/v1/auth/login", nil, map[string]interface{}{
		"account":  "fail_user",
		"password": "WrongPassword",
	})
	defer loginResp.Body.Close()
	// The exact rejection status (e.g. 401 or 500) depends on the
	// implementation; the test only requires that the login does not succeed.
	if loginResp.StatusCode == http.StatusOK {
		t.Fatal("错误密码登录不应该成功")
	}
	t.Logf("错误密码正确拒绝: HTTP %d", loginResp.StatusCode)

	// Non-existent user.
	notFoundResp := doPost(t, base+"/api/v1/auth/login", nil, map[string]interface{}{
		"account":  "nonexistent_user_xyz",
		"password": "SomePass1!",
	})
	defer notFoundResp.Body.Close()
	if notFoundResp.StatusCode == http.StatusOK {
		t.Fatal("不存在的用户登录不应该成功")
	}
	t.Logf("不存在用户正确拒绝: HTTP %d", notFoundResp.StatusCode)
}
// TestE2EUnauthorizedAccess verifies that a JWT-protected endpoint returns
// 401 in two cases: when no token is sent, and when a malformed token is sent.
func TestE2EUnauthorizedAccess(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	// No Authorization header at all.
	resp := doGet(t, base+"/api/v1/auth/userinfo", "")
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusUnauthorized {
		t.Fatalf("期望 401实际 %d", resp.StatusCode)
	}
	t.Log("未认证访问正确返回 401")

	// A three-segment, JWT-shaped string that is not a valid token.
	resp2 := doGet(t, base+"/api/v1/auth/userinfo", "invalid.token.here")
	defer resp2.Body.Close()
	if resp2.StatusCode != http.StatusUnauthorized {
		t.Fatalf("无效 token 期望 401实际 %d", resp2.StatusCode)
	}
	t.Log("无效 token 正确返回 401")
}
// TestE2EPasswordReset registers a user and then requests a password reset
// for that user's email, expecting HTTP 200. It does not follow the reset
// token any further.
func TestE2EPasswordReset(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	// The reset request is only meaningful for an account that really exists.
	regResp := doPost(t, base+"/api/v1/auth/register", nil, map[string]interface{}{
		"username": "reset_user",
		"password": "OldPass123!",
		"email":    "resetuser@example.com",
	})
	defer regResp.Body.Close()
	if regResp.StatusCode != http.StatusCreated {
		t.Fatalf("注册失败HTTP %d", regResp.StatusCode)
	}

	resp := doPost(t, base+"/api/v1/auth/forgot-password", nil, map[string]interface{}{
		"email": "resetuser@example.com",
	})
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("forgot-password 期望 200实际 %d", resp.StatusCode)
	}
	t.Log("密码重置请求正确返回 200")
}
// TestE2ECaptcha tests the image-captcha flow. It requests a new captcha,
// checks that the response carries a captcha_id, and fetches the image for
// that id.
func TestE2ECaptcha(t *testing.T) {
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	resp := doGet(t, base+"/api/v1/auth/captcha", "")
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("获取验证码期望 200实际 %d", resp.StatusCode)
	}
	var result map[string]interface{}
	decodeJSON(t, resp.Body, &result)
	if result["captcha_id"] == nil {
		t.Fatal("验证码响应缺少 captcha_id")
	}
	captchaID := fmt.Sprintf("%v", result["captcha_id"])
	t.Logf("验证码生成成功captcha_id=%s", captchaID)

	// The image bytes are not inspected, but the body must still be closed.
	imgResp := doGet(t, base+"/api/v1/auth/captcha/image?captcha_id="+captchaID, "")
	defer imgResp.Body.Close()
	if imgResp.StatusCode != http.StatusOK {
		t.Fatalf("获取验证码图片失败HTTP %d", imgResp.StatusCode)
	}
	t.Log("验证码图片获取成功")
}
// TestE2EConcurrentLogin fires `concurrency` simultaneous logins for a single
// account. It logs the success/failure counts, the status-code distribution
// and the latency.
//
// The test fails on any transport error or any 5xx response. Logins rejected
// with other statuses (for example by rate limiting) are tolerated. It is
// skipped in -short mode.
func TestE2EConcurrentLogin(t *testing.T) {
	if testing.Short() {
		t.Skip("skip concurrent test in short mode")
	}
	srv, cleanup := setupRealServer(t)
	defer cleanup()
	base := srv.URL

	regResp := doPost(t, base+"/api/v1/auth/register", nil, map[string]interface{}{
		"username": "concurrent_user",
		"password": "ConcPass123!",
		"email":    "concurrent@example.com",
	})
	defer regResp.Body.Close()
	if regResp.StatusCode != http.StatusCreated {
		t.Fatalf("注册失败HTTP %d", regResp.StatusCode)
	}

	const concurrency = 20
	type result struct {
		success bool
		latency time.Duration
		status  int
		err     error // non-nil when the request could not be built or sent
	}

	loginPayload, err := json.Marshal(map[string]interface{}{
		"account":  "concurrent_user",
		"password": "ConcPass123!",
	})
	if err != nil {
		t.Fatalf("序列化登录请求失败: %v", err)
	}
	client := &http.Client{Timeout: 10 * time.Second}

	// login performs a single request without calling t.Fatal. FailNow may only
	// be called from the test goroutine. A worker that exited through it would
	// also never send its result, leaving the collector loop below blocked
	// forever. Errors are therefore returned and reported on the test goroutine.
	login := func() result {
		t0 := time.Now()
		req, err := http.NewRequestWithContext(context.Background(), http.MethodPost,
			base+"/api/v1/auth/login", bytes.NewReader(loginPayload))
		if err != nil {
			return result{latency: time.Since(t0), err: err}
		}
		req.Header.Set("Content-Type", "application/json")
		resp, err := client.Do(req)
		if err != nil {
			return result{latency: time.Since(t0), err: err}
		}
		defer resp.Body.Close()
		var body map[string]interface{}
		// A decode failure leaves body without an access_token, so the request
		// is counted as a failure. t.Logf is safe to call from any goroutine.
		if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
			t.Logf("解析响应 JSON 失败: %v非致命", err)
		}
		return result{
			success: resp.StatusCode == http.StatusOK && body["access_token"] != nil,
			latency: time.Since(t0),
			status:  resp.StatusCode,
		}
	}

	// The buffer holds every result, so no worker ever blocks on send.
	results := make(chan result, concurrency)
	start := time.Now()
	for i := 0; i < concurrency; i++ {
		go func() { results <- login() }()
	}

	success, fail := 0, 0
	var totalLatency time.Duration
	statusCount := make(map[int]int)
	var transportErrs []error
	for i := 0; i < concurrency; i++ {
		r := <-results
		totalLatency += r.latency
		if r.err != nil {
			fail++
			transportErrs = append(transportErrs, r.err)
			continue
		}
		if r.success {
			success++
		} else {
			fail++
		}
		statusCount[r.status]++
	}
	elapsed := time.Since(start)
	t.Logf("并发登录结果: 成功=%d 失败=%d 状态码分布=%v 总耗时=%v 平均=%v",
		success, fail, statusCount, elapsed, totalLatency/time.Duration(concurrency))

	if len(transportErrs) > 0 {
		t.Fatalf("%d 个并发登录请求失败, 首个错误: %v", len(transportErrs), transportErrs[0])
	}
	for status, count := range statusCount {
		if status >= http.StatusInternalServerError {
			t.Fatalf("并发登录不应出现 5xx实际 status=%d count=%d", status, count)
		}
	}
	if success == 0 {
		t.Log("所有并发登录请求都被限流或拒绝;在当前路由限流配置下这属于可接受结果")
	}
}
// ---- HTTP helpers ----

// doPost sends a POST request with a JSON body to url and returns the raw
// response.
//
// body is JSON-encoded when non-nil; a nil body sends an empty request body.
// token is optional: a non-empty string is sent as "Authorization: Bearer
// <token>", and any other value (including nil) is ignored. If the body cannot
// be encoded, or the request cannot be built or sent, the test is aborted via
// t.Fatalf. doPost must therefore only be called from the test goroutine. The
// caller owns resp.Body and must close it.
func doPost(t *testing.T, url string, token interface{}, body map[string]interface{}) *http.Response {
	t.Helper()
	var bodyBytes []byte
	if body != nil {
		var err error
		bodyBytes, err = json.Marshal(body)
		if err != nil {
			// Previously ignored, which would silently send an empty body.
			t.Fatalf("序列化请求体失败: %v", err)
		}
	}
	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewReader(bodyBytes))
	if err != nil {
		t.Fatalf("创建请求失败: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if tok, ok := token.(string); ok && tok != "" {
		req.Header.Set("Authorization", "Bearer "+tok)
	}
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("请求失败: %v", err)
	}
	return resp
}
// doGet issues a GET request to url and returns the raw response.
//
// When token is non-empty it is sent as "Authorization: Bearer <token>". If
// the request cannot be built or sent, the test is aborted via t.Fatalf. The
// caller must close the response body.
func doGet(t *testing.T, url string, token string) *http.Response {
	t.Helper()

	req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
	if reqErr != nil {
		t.Fatalf("创建请求失败: %v", reqErr)
	}
	if len(token) > 0 {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	httpClient := &http.Client{Timeout: 10 * time.Second}
	resp, doErr := httpClient.Do(req)
	if doErr != nil {
		t.Fatalf("请求失败: %v", doErr)
	}
	return resp
}
// decodeJSON decodes a single JSON value from body into v and always closes
// body. Decode errors are only logged and marked non-fatal, so callers must
// validate the decoded content themselves.
func decodeJSON(t *testing.T, body io.ReadCloser, v interface{}) {
	t.Helper()
	defer func() { _ = body.Close() }()

	decodeErr := json.NewDecoder(body).Decode(v)
	if decodeErr == nil {
		return
	}
	t.Logf("解析响应 JSON 失败: %v非致命", decodeErr)
}
// Explicit reference that keeps the security import in use.
var (
	_ = security.NewIPFilter
)