test: 补齐 handler/repository/domain 层单元测试

This commit is contained in:
2026-05-10 12:54:13 +08:00
parent b8e9af001f
commit 28012140cb
21 changed files with 5837 additions and 1 deletions

View File

@@ -0,0 +1,102 @@
package middleware
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestGzipMiddleware_CompressesLargeJSONResponses(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(GzipMiddleware())
router.GET("/data", func(c *gin.Context) {
c.Header("Content-Type", "application/json")
c.String(http.StatusOK, strings.Repeat("a", gzipMinLength+128))
})
req := httptest.NewRequest(http.MethodGet, "/data", nil)
req.Header.Set("Accept-Encoding", "gzip")
router.ServeHTTP(recorder, req)
if got := recorder.Header().Get("Content-Encoding"); got != "gzip" {
t.Fatalf("Content-Encoding = %q, want gzip", got)
}
reader, err := gzip.NewReader(bytes.NewReader(recorder.Body.Bytes()))
if err != nil {
t.Fatalf("gzip.NewReader() error = %v", err)
}
defer reader.Close()
payload, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("ReadAll() error = %v", err)
}
if got := string(payload); got != strings.Repeat("a", gzipMinLength+128) {
t.Fatalf("decompressed payload length = %d, want %d", len(got), gzipMinLength+128)
}
}
func TestGzipMiddleware_PassesThroughWhenCompressionNotUseful(t *testing.T) {
gin.SetMode(gin.TestMode)
testCases := []struct {
name string
acceptEncoding string
contentType string
body string
}{
{
name: "client does not accept gzip",
acceptEncoding: "",
contentType: "application/json",
body: strings.Repeat("b", gzipMinLength+64),
},
{
name: "body below threshold",
acceptEncoding: "gzip",
contentType: "application/json",
body: "small-body",
},
{
name: "unsupported content type",
acceptEncoding: "gzip",
contentType: "image/png",
body: strings.Repeat("c", gzipMinLength+64),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(GzipMiddleware())
router.GET("/data", func(c *gin.Context) {
c.Header("Content-Type", tc.contentType)
c.String(http.StatusOK, tc.body)
})
req := httptest.NewRequest(http.MethodGet, "/data", nil)
if tc.acceptEncoding != "" {
req.Header.Set("Accept-Encoding", tc.acceptEncoding)
}
router.ServeHTTP(recorder, req)
if got := recorder.Header().Get("Content-Encoding"); got != "" {
t.Fatalf("Content-Encoding = %q, want empty", got)
}
if got := recorder.Body.String(); got != tc.body {
t.Fatalf("body length = %d, want %d", len(got), len(tc.body))
}
})
}
}

View File

@@ -0,0 +1,165 @@
package middleware
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
)
func newOperationLogRepositoryForTest(t *testing.T) *repository.OperationLogRepository {
t.Helper()
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: "file:operation_log_test?mode=memory&cache=shared",
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite failed: %v", err)
}
if err := db.AutoMigrate(&domain.OperationLog{}); err != nil {
t.Fatalf("migrate failed: %v", err)
}
if err := db.Exec("DELETE FROM operation_logs").Error; err != nil {
t.Fatalf("cleanup operation_logs failed: %v", err)
}
return repository.NewOperationLogRepository(db)
}
func waitForOperationLogs(t *testing.T, repo *repository.OperationLogRepository, want int) []*domain.OperationLog {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
logs, _, err := repo.List(context.Background(), 0, 20)
if err != nil {
t.Fatalf("list operation logs failed: %v", err)
}
if len(logs) >= want {
return logs
}
time.Sleep(25 * time.Millisecond)
}
logs, _, err := repo.List(context.Background(), 0, 20)
if err != nil {
t.Fatalf("list operation logs failed: %v", err)
}
t.Fatalf("timed out waiting for %d operation logs, got %d", want, len(logs))
return nil
}
func TestOperationLogMiddleware_SkipsReadOnlyMethods(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newOperationLogRepositoryForTest(t)
router := gin.New()
router.Use(NewOperationLogMiddleware(repo).Record())
router.GET("/logs", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
time.Sleep(100 * time.Millisecond)
logs, _, err := repo.List(context.Background(), 0, 20)
if err != nil {
t.Fatalf("list operation logs failed: %v", err)
}
if len(logs) != 0 {
t.Fatalf("expected no logs for GET request, got %d", len(logs))
}
}
func TestOperationLogMiddleware_RecordsAdminMutationAndSanitizesParams(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newOperationLogRepositoryForTest(t)
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set("user_id", int64(42))
c.Set(ContextKeyRoleCodes, []string{"admin"})
c.Next()
})
router.Use(NewOperationLogMiddleware(repo).Record())
router.POST("/users", func(c *gin.Context) {
c.Status(http.StatusCreated)
})
body := `{"username":"alice","password":"super-secret","token":"abc"}`
req := httptest.NewRequest(http.MethodPost, "/users", strings.NewReader(body))
req.RemoteAddr = "203.0.113.10:8080"
req.Header.Set("User-Agent", "middleware-test")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d", recorder.Code)
}
logs := waitForOperationLogs(t, repo, 1)
entry := logs[0]
if entry.UserID == nil || *entry.UserID != 42 {
t.Fatalf("user_id = %#v, want 42", entry.UserID)
}
if entry.OperationType != "admin:CREATE" {
t.Fatalf("operation_type = %q, want admin:CREATE", entry.OperationType)
}
if entry.ResponseStatus != http.StatusCreated {
t.Fatalf("response_status = %d, want %d", entry.ResponseStatus, http.StatusCreated)
}
if strings.Contains(entry.RequestParams, "super-secret") || strings.Contains(entry.RequestParams, "abc") {
t.Fatalf("expected sanitized params, got %s", entry.RequestParams)
}
}
func TestOperationLogMiddleware_MethodToTypeAndSanitizeFallbacks(t *testing.T) {
if got := methodToType(http.MethodPatch); got != "UPDATE" {
t.Fatalf("methodToType(PATCH) = %q, want UPDATE", got)
}
if got := methodToType(http.MethodDelete); got != "DELETE" {
t.Fatalf("methodToType(DELETE) = %q, want DELETE", got)
}
if got := methodToType(http.MethodGet); got != "OTHER" {
t.Fatalf("methodToType(GET) = %q, want OTHER", got)
}
raw := []byte(`{"password":"secret","name":"alice"}`)
sanitized := sanitizeParams(raw)
if strings.Contains(sanitized, "secret") {
t.Fatalf("expected password to be masked, got %s", sanitized)
}
plain := sanitizeParams([]byte("not-json"))
if plain != "not-json" {
t.Fatalf("sanitizeParams(non-json) = %q, want not-json", plain)
}
var payload map[string]string
if err := json.Unmarshal([]byte(sanitized), &payload); err != nil {
t.Fatalf("unmarshal sanitized params failed: %v", err)
}
if payload["password"] != "***" {
t.Fatalf("password = %q, want ***", payload["password"])
}
}

View File

@@ -0,0 +1,114 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func performRBACRequest(t *testing.T, setup func(*gin.Context), middleware gin.HandlerFunc) *httptest.ResponseRecorder {
t.Helper()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
if setup != nil {
router.Use(setup)
}
router.Use(middleware)
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"code": 0})
})
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
router.ServeHTTP(recorder, req)
return recorder
}
func TestRequirePermissionRejectsMissingPermission(t *testing.T) {
recorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
c.Next()
}, RequirePermission("users:write"))
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", recorder.Code)
}
}
func TestRequirePermissionAllowsMatchingPermission(t *testing.T) {
recorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
c.Next()
}, RequirePermission("users:read"))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
}
func TestRequireAllPermissionsRequiresEveryCode(t *testing.T) {
recorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
c.Next()
}, RequireAllPermissions("users:read", "users:write"))
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", recorder.Code)
}
}
func TestRequireAnyPermissionIsAliasOfRequirePermission(t *testing.T) {
recorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyPermissionCodes, []string{"users:write"})
c.Next()
}, RequireAnyPermission("users:read", "users:write"))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
}
func TestRequireRoleAndAdminOnly(t *testing.T) {
roleRecorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyRoleCodes, []string{"auditor"})
c.Next()
}, RequireRole("admin"))
if roleRecorder.Code != http.StatusForbidden {
t.Fatalf("expected role check to return 403, got %d", roleRecorder.Code)
}
adminRecorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyRoleCodes, []string{"admin"})
c.Next()
}, AdminOnly())
if adminRecorder.Code != http.StatusOK {
t.Fatalf("expected admin check to return 200, got %d", adminRecorder.Code)
}
}
func TestRBACHelpersHandleMissingContextValues(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
if got := GetRoleCodes(c); got != nil {
t.Fatalf("GetRoleCodes() = %#v, want nil", got)
}
if got := GetPermissionCodes(c); got != nil {
t.Fatalf("GetPermissionCodes() = %#v, want nil", got)
}
if IsAdmin(c) {
t.Fatal("IsAdmin() = true, want false")
}
c.Set(ContextKeyRoleCodes, []string{"admin"})
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
if !IsAdmin(c) {
t.Fatal("IsAdmin() = false, want true")
}
}

View File

@@ -0,0 +1,119 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestResponseWrapper_WrapsSuccessfulJSONPayload(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ResponseWrapper())
router.GET("/users", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"id": 1, "name": "alice"})
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
want := `{"code":0,"data":{"id":1,"name":"alice"},"message":"success"}`
if got := recorder.Body.String(); got != want {
t.Fatalf("body = %s, want %s", got, want)
}
}
func TestResponseWrapper_PassesThroughMarkedResponses(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ResponseWrapper())
router.GET("/users", func(c *gin.Context) {
WrapResponse(c)
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "already wrapped"})
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
want := `{"code":0,"message":"already wrapped"}`
if got := recorder.Body.String(); got != want {
t.Fatalf("body = %s, want %s", got, want)
}
}
func TestResponseWrapper_PassesThroughNonSuccessStatus(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ResponseWrapper())
router.GET("/users", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", recorder.Code)
}
want := `{"message":"bad request"}`
if got := recorder.Body.String(); got != want {
t.Fatalf("body = %s, want %s", got, want)
}
}
func TestResponseWrapper_PassesThroughInvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ResponseWrapper())
router.GET("/users", func(c *gin.Context) {
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.WriteString("plain text")
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
if got := recorder.Body.String(); got != "plain text" {
t.Fatalf("body = %q, want plain text", got)
}
}
func TestResponseWrapper_NoWrapperMarksContext(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(NoWrapper())
router.GET("/users", func(c *gin.Context) {
if _, exists := c.Get("response_wrapped"); !exists {
t.Fatal("expected response_wrapped marker in context")
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
}