test: 补齐 handler/repository/domain 层单元测试
This commit is contained in:
102
internal/api/middleware/gzip_test.go
Normal file
102
internal/api/middleware/gzip_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestGzipMiddleware_CompressesLargeJSONResponses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(GzipMiddleware())
|
||||
router.GET("/data", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.String(http.StatusOK, strings.Repeat("a", gzipMinLength+128))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/data", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if got := recorder.Header().Get("Content-Encoding"); got != "gzip" {
|
||||
t.Fatalf("Content-Encoding = %q, want gzip", got)
|
||||
}
|
||||
|
||||
reader, err := gzip.NewReader(bytes.NewReader(recorder.Body.Bytes()))
|
||||
if err != nil {
|
||||
t.Fatalf("gzip.NewReader() error = %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
payload, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll() error = %v", err)
|
||||
}
|
||||
if got := string(payload); got != strings.Repeat("a", gzipMinLength+128) {
|
||||
t.Fatalf("decompressed payload length = %d, want %d", len(got), gzipMinLength+128)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzipMiddleware_PassesThroughWhenCompressionNotUseful(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
acceptEncoding string
|
||||
contentType string
|
||||
body string
|
||||
}{
|
||||
{
|
||||
name: "client does not accept gzip",
|
||||
acceptEncoding: "",
|
||||
contentType: "application/json",
|
||||
body: strings.Repeat("b", gzipMinLength+64),
|
||||
},
|
||||
{
|
||||
name: "body below threshold",
|
||||
acceptEncoding: "gzip",
|
||||
contentType: "application/json",
|
||||
body: "small-body",
|
||||
},
|
||||
{
|
||||
name: "unsupported content type",
|
||||
acceptEncoding: "gzip",
|
||||
contentType: "image/png",
|
||||
body: strings.Repeat("c", gzipMinLength+64),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(GzipMiddleware())
|
||||
router.GET("/data", func(c *gin.Context) {
|
||||
c.Header("Content-Type", tc.contentType)
|
||||
c.String(http.StatusOK, tc.body)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/data", nil)
|
||||
if tc.acceptEncoding != "" {
|
||||
req.Header.Set("Accept-Encoding", tc.acceptEncoding)
|
||||
}
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if got := recorder.Header().Get("Content-Encoding"); got != "" {
|
||||
t.Fatalf("Content-Encoding = %q, want empty", got)
|
||||
}
|
||||
if got := recorder.Body.String(); got != tc.body {
|
||||
t.Fatalf("body length = %d, want %d", len(got), len(tc.body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
165
internal/api/middleware/operation_log_test.go
Normal file
165
internal/api/middleware/operation_log_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func newOperationLogRepositoryForTest(t *testing.T) *repository.OperationLogRepository {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: "file:operation_log_test?mode=memory&cache=shared",
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite failed: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.OperationLog{}); err != nil {
|
||||
t.Fatalf("migrate failed: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Exec("DELETE FROM operation_logs").Error; err != nil {
|
||||
t.Fatalf("cleanup operation_logs failed: %v", err)
|
||||
}
|
||||
|
||||
return repository.NewOperationLogRepository(db)
|
||||
}
|
||||
|
||||
func waitForOperationLogs(t *testing.T, repo *repository.OperationLogRepository, want int) []*domain.OperationLog {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
logs, _, err := repo.List(context.Background(), 0, 20)
|
||||
if err != nil {
|
||||
t.Fatalf("list operation logs failed: %v", err)
|
||||
}
|
||||
if len(logs) >= want {
|
||||
return logs
|
||||
}
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
}
|
||||
|
||||
logs, _, err := repo.List(context.Background(), 0, 20)
|
||||
if err != nil {
|
||||
t.Fatalf("list operation logs failed: %v", err)
|
||||
}
|
||||
t.Fatalf("timed out waiting for %d operation logs, got %d", want, len(logs))
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOperationLogMiddleware_SkipsReadOnlyMethods(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newOperationLogRepositoryForTest(t)
|
||||
router := gin.New()
|
||||
router.Use(NewOperationLogMiddleware(repo).Record())
|
||||
router.GET("/logs", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
logs, _, err := repo.List(context.Background(), 0, 20)
|
||||
if err != nil {
|
||||
t.Fatalf("list operation logs failed: %v", err)
|
||||
}
|
||||
if len(logs) != 0 {
|
||||
t.Fatalf("expected no logs for GET request, got %d", len(logs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperationLogMiddleware_RecordsAdminMutationAndSanitizesParams(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newOperationLogRepositoryForTest(t)
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("user_id", int64(42))
|
||||
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
||||
c.Next()
|
||||
})
|
||||
router.Use(NewOperationLogMiddleware(repo).Record())
|
||||
router.POST("/users", func(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
})
|
||||
|
||||
body := `{"username":"alice","password":"super-secret","token":"abc"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/users", strings.NewReader(body))
|
||||
req.RemoteAddr = "203.0.113.10:8080"
|
||||
req.Header.Set("User-Agent", "middleware-test")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201, got %d", recorder.Code)
|
||||
}
|
||||
|
||||
logs := waitForOperationLogs(t, repo, 1)
|
||||
entry := logs[0]
|
||||
if entry.UserID == nil || *entry.UserID != 42 {
|
||||
t.Fatalf("user_id = %#v, want 42", entry.UserID)
|
||||
}
|
||||
if entry.OperationType != "admin:CREATE" {
|
||||
t.Fatalf("operation_type = %q, want admin:CREATE", entry.OperationType)
|
||||
}
|
||||
if entry.ResponseStatus != http.StatusCreated {
|
||||
t.Fatalf("response_status = %d, want %d", entry.ResponseStatus, http.StatusCreated)
|
||||
}
|
||||
if strings.Contains(entry.RequestParams, "super-secret") || strings.Contains(entry.RequestParams, "abc") {
|
||||
t.Fatalf("expected sanitized params, got %s", entry.RequestParams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperationLogMiddleware_MethodToTypeAndSanitizeFallbacks(t *testing.T) {
|
||||
if got := methodToType(http.MethodPatch); got != "UPDATE" {
|
||||
t.Fatalf("methodToType(PATCH) = %q, want UPDATE", got)
|
||||
}
|
||||
if got := methodToType(http.MethodDelete); got != "DELETE" {
|
||||
t.Fatalf("methodToType(DELETE) = %q, want DELETE", got)
|
||||
}
|
||||
if got := methodToType(http.MethodGet); got != "OTHER" {
|
||||
t.Fatalf("methodToType(GET) = %q, want OTHER", got)
|
||||
}
|
||||
|
||||
raw := []byte(`{"password":"secret","name":"alice"}`)
|
||||
sanitized := sanitizeParams(raw)
|
||||
if strings.Contains(sanitized, "secret") {
|
||||
t.Fatalf("expected password to be masked, got %s", sanitized)
|
||||
}
|
||||
|
||||
plain := sanitizeParams([]byte("not-json"))
|
||||
if plain != "not-json" {
|
||||
t.Fatalf("sanitizeParams(non-json) = %q, want not-json", plain)
|
||||
}
|
||||
|
||||
var payload map[string]string
|
||||
if err := json.Unmarshal([]byte(sanitized), &payload); err != nil {
|
||||
t.Fatalf("unmarshal sanitized params failed: %v", err)
|
||||
}
|
||||
if payload["password"] != "***" {
|
||||
t.Fatalf("password = %q, want ***", payload["password"])
|
||||
}
|
||||
}
|
||||
114
internal/api/middleware/rbac_test.go
Normal file
114
internal/api/middleware/rbac_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func performRBACRequest(t *testing.T, setup func(*gin.Context), middleware gin.HandlerFunc) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
if setup != nil {
|
||||
router.Use(setup)
|
||||
}
|
||||
router.Use(middleware)
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func TestRequirePermissionRejectsMissingPermission(t *testing.T) {
|
||||
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||
c.Next()
|
||||
}, RequirePermission("users:write"))
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequirePermissionAllowsMatchingPermission(t *testing.T) {
|
||||
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||
c.Next()
|
||||
}, RequirePermission("users:read"))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAllPermissionsRequiresEveryCode(t *testing.T) {
|
||||
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||
c.Next()
|
||||
}, RequireAllPermissions("users:read", "users:write"))
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAnyPermissionIsAliasOfRequirePermission(t *testing.T) {
|
||||
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:write"})
|
||||
c.Next()
|
||||
}, RequireAnyPermission("users:read", "users:write"))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleAndAdminOnly(t *testing.T) {
|
||||
roleRecorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyRoleCodes, []string{"auditor"})
|
||||
c.Next()
|
||||
}, RequireRole("admin"))
|
||||
if roleRecorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected role check to return 403, got %d", roleRecorder.Code)
|
||||
}
|
||||
|
||||
adminRecorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
||||
c.Next()
|
||||
}, AdminOnly())
|
||||
if adminRecorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected admin check to return 200, got %d", adminRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRBACHelpersHandleMissingContextValues(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
|
||||
if got := GetRoleCodes(c); got != nil {
|
||||
t.Fatalf("GetRoleCodes() = %#v, want nil", got)
|
||||
}
|
||||
if got := GetPermissionCodes(c); got != nil {
|
||||
t.Fatalf("GetPermissionCodes() = %#v, want nil", got)
|
||||
}
|
||||
if IsAdmin(c) {
|
||||
t.Fatal("IsAdmin() = true, want false")
|
||||
}
|
||||
|
||||
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||
|
||||
if !IsAdmin(c) {
|
||||
t.Fatal("IsAdmin() = false, want true")
|
||||
}
|
||||
}
|
||||
119
internal/api/middleware/response_wrapper_test.go
Normal file
119
internal/api/middleware/response_wrapper_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestResponseWrapper_WrapsSuccessfulJSONPayload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(ResponseWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"id": 1, "name": "alice"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
want := `{"code":0,"data":{"id":1,"name":"alice"},"message":"success"}`
|
||||
if got := recorder.Body.String(); got != want {
|
||||
t.Fatalf("body = %s, want %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWrapper_PassesThroughMarkedResponses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(ResponseWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
WrapResponse(c)
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "already wrapped"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
want := `{"code":0,"message":"already wrapped"}`
|
||||
if got := recorder.Body.String(); got != want {
|
||||
t.Fatalf("body = %s, want %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWrapper_PassesThroughNonSuccessStatus(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(ResponseWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
want := `{"message":"bad request"}`
|
||||
if got := recorder.Body.String(); got != want {
|
||||
t.Fatalf("body = %s, want %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWrapper_PassesThroughInvalidJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(ResponseWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = c.Writer.WriteString("plain text")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
if got := recorder.Body.String(); got != "plain text" {
|
||||
t.Fatalf("body = %q, want plain text", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWrapper_NoWrapperMarksContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(NoWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
if _, exists := c.Get("response_wrapped"); !exists {
|
||||
t.Fatal("expected response_wrapped marker in context")
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user