feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers

This commit is contained in:
2026-04-02 11:19:50 +08:00
parent e59a77bc49
commit dcc1f186f8
298 changed files with 62603 additions and 0 deletions

View File

@@ -0,0 +1,240 @@
package middleware
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
apierrors "github.com/user-management-system/internal/pkg/errors"
"github.com/user-management-system/internal/repository"
)
type AuthMiddleware struct {
jwt *auth.JWT
userRepo *repository.UserRepository
userRoleRepo *repository.UserRoleRepository
roleRepo *repository.RoleRepository
rolePermissionRepo *repository.RolePermissionRepository
permissionRepo *repository.PermissionRepository
l1Cache *cache.L1Cache
cacheManager *cache.CacheManager
}
func NewAuthMiddleware(
jwt *auth.JWT,
userRepo *repository.UserRepository,
userRoleRepo *repository.UserRoleRepository,
roleRepo *repository.RoleRepository,
rolePermissionRepo *repository.RolePermissionRepository,
permissionRepo *repository.PermissionRepository,
) *AuthMiddleware {
return &AuthMiddleware{
jwt: jwt,
userRepo: userRepo,
userRoleRepo: userRoleRepo,
roleRepo: roleRepo,
rolePermissionRepo: rolePermissionRepo,
permissionRepo: permissionRepo,
l1Cache: cache.NewL1Cache(),
}
}
func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) {
m.cacheManager = cm
}
func (m *AuthMiddleware) Required() gin.HandlerFunc {
return func(c *gin.Context) {
token := m.extractToken(c)
if token == "" {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "未提供认证令牌"))
c.Abort()
return
}
claims, err := m.jwt.ValidateAccessToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "无效的认证令牌"))
c.Abort()
return
}
if m.isJTIBlacklisted(claims.JTI) {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "令牌已失效,请重新登录"))
c.Abort()
return
}
if !m.isUserActive(c.Request.Context(), claims.UserID) {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "账号不可用,请重新登录"))
c.Abort()
return
}
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("token_jti", claims.JTI)
roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID)
c.Set("role_codes", roleCodes)
c.Set("permission_codes", permCodes)
c.Next()
}
}
func (m *AuthMiddleware) Optional() gin.HandlerFunc {
return func(c *gin.Context) {
token := m.extractToken(c)
if token != "" {
claims, err := m.jwt.ValidateAccessToken(token)
if err == nil && !m.isJTIBlacklisted(claims.JTI) && m.isUserActive(c.Request.Context(), claims.UserID) {
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("token_jti", claims.JTI)
roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID)
c.Set("role_codes", roleCodes)
c.Set("permission_codes", permCodes)
}
}
c.Next()
}
}
func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool {
if jti == "" {
return false
}
key := "jwt_blacklist:" + jti
if _, ok := m.l1Cache.Get(key); ok {
return true
}
if m.cacheManager != nil {
if _, ok := m.cacheManager.Get(context.Background(), key); ok {
return true
}
}
return false
}
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
if m.userRoleRepo == nil || m.roleRepo == nil || m.rolePermissionRepo == nil || m.permissionRepo == nil {
return nil, nil
}
cacheKey := fmt.Sprintf("user_perms:%d", userID)
if cached, ok := m.l1Cache.Get(cacheKey); ok {
if entry, ok := cached.(userPermEntry); ok {
return entry.roles, entry.perms
}
}
roleIDs, err := m.userRoleRepo.GetRoleIDsByUserID(ctx, userID)
if err != nil || len(roleIDs) == 0 {
return nil, nil
}
// 收集所有角色ID包括直接分配的角色和所有祖先角色
allRoleIDs := make([]int64, 0, len(roleIDs)*2)
allRoleIDs = append(allRoleIDs, roleIDs...)
for _, roleID := range roleIDs {
ancestorIDs, err := m.roleRepo.GetAncestorIDs(ctx, roleID)
if err == nil && len(ancestorIDs) > 0 {
allRoleIDs = append(allRoleIDs, ancestorIDs...)
}
}
// 去重
seen := make(map[int64]bool)
uniqueRoleIDs := make([]int64, 0, len(allRoleIDs))
for _, id := range allRoleIDs {
if !seen[id] {
seen[id] = true
uniqueRoleIDs = append(uniqueRoleIDs, id)
}
}
roles, err := m.roleRepo.GetByIDs(ctx, roleIDs)
if err != nil {
return nil, nil
}
roleCodes := make([]string, 0, len(roles))
for _, role := range roles {
roleCodes = append(roleCodes, role.Code)
}
permissionIDs, err := m.rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, uniqueRoleIDs)
if err != nil || len(permissionIDs) == 0 {
entry := userPermEntry{roles: roleCodes, perms: []string{}}
m.l1Cache.Set(cacheKey, entry, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询
return entry.roles, entry.perms
}
permissions, err := m.permissionRepo.GetByIDs(ctx, permissionIDs)
if err != nil {
return roleCodes, nil
}
permCodes := make([]string, 0, len(permissions))
for _, permission := range permissions {
permCodes = append(permCodes, permission.Code)
}
m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询
return roleCodes, permCodes
}
func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) {
m.l1Cache.Delete(fmt.Sprintf("user_perms:%d", userID))
}
func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) {
if jti != "" && ttl > 0 {
m.l1Cache.Set("jwt_blacklist:"+jti, true, ttl)
}
}
func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool {
if m.userRepo == nil {
return true
}
user, err := m.userRepo.GetByID(ctx, userID)
if err != nil {
return false
}
return user.Status == domain.UserStatusActive
}
func (m *AuthMiddleware) extractToken(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
return ""
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
return ""
}
return parts[1]
}
type userPermEntry struct {
roles []string
perms []string
}

View File

@@ -0,0 +1,32 @@
package middleware
import (
"strings"
"github.com/gin-gonic/gin"
)
const sensitiveNoStoreCacheControl = "no-store, no-cache, must-revalidate, max-age=0"
// NoStoreSensitiveResponses prevents browser or intermediary caching for auth routes.
func NoStoreSensitiveResponses() gin.HandlerFunc {
return func(c *gin.Context) {
if shouldDisableCaching(c.FullPath(), c.Request.URL.Path) {
headers := c.Writer.Header()
headers.Set("Cache-Control", sensitiveNoStoreCacheControl)
headers.Set("Pragma", "no-cache")
headers.Set("Expires", "0")
headers.Set("Surrogate-Control", "no-store")
}
c.Next()
}
}
func shouldDisableCaching(routePath, requestPath string) bool {
path := strings.TrimSpace(routePath)
if path == "" {
path = strings.TrimSpace(requestPath)
}
return strings.HasPrefix(path, "/api/v1/auth")
}

View File

@@ -0,0 +1,67 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
var corsConfig = config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
}
func SetCORSConfig(cfg config.CORSConfig) {
corsConfig = cfg
}
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
cfg := corsConfig
origin := c.GetHeader("Origin")
if origin != "" {
allowOrigin, allowed := resolveAllowedOrigin(origin, cfg.AllowedOrigins, cfg.AllowCredentials)
if !allowed {
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusForbidden)
return
}
c.AbortWithStatus(http.StatusForbidden)
return
}
c.Writer.Header().Set("Access-Control-Allow-Origin", allowOrigin)
if cfg.AllowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
if c.Request.Method == http.MethodOptions {
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With, X-CSRF-Token")
c.Writer.Header().Set("Access-Control-Max-Age", "3600")
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
func resolveAllowedOrigin(origin string, allowedOrigins []string, allowCredentials bool) (string, bool) {
for _, allowed := range allowedOrigins {
if allowed == "*" {
if allowCredentials {
return origin, true
}
return "*", true
}
if strings.EqualFold(origin, allowed) {
return origin, true
}
}
return "", false
}

View File

@@ -0,0 +1,43 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
apierrors "github.com/user-management-system/internal/pkg/errors"
)
// ErrorHandler 错误处理中间件
func ErrorHandler() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
// 检查是否有错误
if len(c.Errors) > 0 {
// 获取最后一个错误
err := c.Errors.Last()
// 判断错误类型
if appErr, ok := err.Err.(*apierrors.ApplicationError); ok {
c.JSON(int(appErr.Code), appErr)
} else {
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", err.Err.Error()))
}
return
}
}
}
// Recover 恢复中间件
func Recover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", "服务器内部错误"))
c.Abort()
}
}()
c.Next()
}
}

View File

@@ -0,0 +1,134 @@
package middleware
import (
"net"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/security"
)
// IPFilterConfig IP过滤中间件配置
type IPFilterConfig struct {
TrustProxy bool // 是否信任 X-Forwarded-For
TrustedProxies []string // 可信代理 IP 列表
}
// IPFilterMiddleware IP 黑白名单过滤中间件
type IPFilterMiddleware struct {
filter *security.IPFilter
config IPFilterConfig
}
// NewIPFilterMiddleware 创建 IP 过滤中间件
func NewIPFilterMiddleware(filter *security.IPFilter, config IPFilterConfig) *IPFilterMiddleware {
return &IPFilterMiddleware{filter: filter, config: config}
}
// Filter 返回 Gin 中间件 HandlerFunc
// 逻辑:先取客户端真实 IP → 检查黑名单 → 被封则返回 403 并终止
func (m *IPFilterMiddleware) Filter() gin.HandlerFunc {
return func(c *gin.Context) {
ip := m.realIP(c)
blocked, reason := m.filter.IsBlocked(ip)
if blocked {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "访问被拒绝:" + reason,
})
return
}
// 将真实 IP 写入 context供后续中间件和 handler 直接取用
c.Set("client_ip", ip)
c.Next()
}
}
// GetFilter 返回底层 IPFilter供 handler 层做黑白名单管理
func (m *IPFilterMiddleware) GetFilter() *security.IPFilter {
return m.filter
}
// realIP 从请求中提取真实客户端 IP
// 优先级X-Forwarded-For > X-Real-IP > RemoteAddr
// SEC-05 修复:如果启用 TrustProxy只接受来自可信代理的 X-Forwarded-For
func (m *IPFilterMiddleware) realIP(c *gin.Context) string {
// 如果不信任代理,直接使用 TCP 连接 IP
if !m.config.TrustProxy {
return c.ClientIP()
}
// X-Forwarded-For 可能包含代理链
xff := c.GetHeader("X-Forwarded-For")
if xff != "" {
// 从右到左遍历(最右边是最后一次代理添加的)
for _, part := range strings.Split(xff, ",") {
ip := strings.TrimSpace(part)
if ip == "" {
continue
}
// 检查是否是可信代理
if !m.isTrustedProxy(ip) {
continue // 不是可信代理,跳过
}
// 是可信代理,检查是否为公网 IP
if !isPrivateIP(ip) {
return ip
}
}
}
// X-Real-IPNginx 反代常用)
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// 直接 TCP 连接的 RemoteAddr去掉端口号
ip, _, err := net.SplitHostPort(c.Request.RemoteAddr)
if err != nil {
return c.Request.RemoteAddr
}
return ip
}
// isTrustedProxy 检查 IP 是否在可信代理列表中
func (m *IPFilterMiddleware) isTrustedProxy(ip string) bool {
if len(m.config.TrustedProxies) == 0 {
return true // 如果没有配置可信代理列表,默认信任所有(兼容旧行为)
}
for _, trusted := range m.config.TrustedProxies {
if ip == trusted {
return true
}
}
return false
}
// isPrivateIP 判断是否为内网 IP
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
privateRanges := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"::1/128",
"fc00::/7",
}
for _, cidr := range privateRanges {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
continue
}
if network.Contains(ip) {
return true
}
}
return false
}

View File

@@ -0,0 +1,258 @@
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/security"
)
func init() {
gin.SetMode(gin.TestMode)
}
// newTestEngine 用给定的 IPFilterMiddleware 构建一个最简 Gin 引擎,
// 注册一个 GET /ping 路由,返回 client_ip 值。
func newTestEngine(f *security.IPFilter) *gin.Engine {
engine := gin.New()
engine.Use(NewIPFilterMiddleware(f, IPFilterConfig{}).Filter())
engine.GET("/ping", func(c *gin.Context) {
ip, _ := c.Get("client_ip")
c.JSON(http.StatusOK, gin.H{"ip": ip})
})
return engine
}
// doRequest 发送 GET /ping返回响应码和响应 body map。
func doRequest(engine *gin.Engine, remoteAddr, xff, xri string) (int, map[string]interface{}) {
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
req.RemoteAddr = remoteAddr
if xff != "" {
req.Header.Set("X-Forwarded-For", xff)
}
if xri != "" {
req.Header.Set("X-Real-IP", xri)
}
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
var body map[string]interface{}
_ = json.Unmarshal(w.Body.Bytes(), &body)
return w.Code, body
}
// ---------- 黑名单拦截 ----------
func TestIPFilter_BlockedIP_Returns403(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("1.2.3.4", "测试封禁", 0)
engine := newTestEngine(f)
code, body := doRequest(engine, "1.2.3.4:9999", "", "")
if code != http.StatusForbidden {
t.Fatalf("期望 403实际 %d", code)
}
msg, _ := body["message"].(string)
if msg == "" {
t.Error("期望 body 中包含 message 字段")
}
}
func TestIPFilter_NonBlockedIP_Returns200(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("9.9.9.9", "其他 IP", 0)
engine := newTestEngine(f)
code, _ := doRequest(engine, "1.2.3.4:9999", "", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
}
func TestIPFilter_EmptyBlacklist_AllPass(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
for _, ip := range []string{"1.1.1.1:80", "8.8.8.8:443", "203.0.113.5:1234"} {
code, _ := doRequest(engine, ip, "", "")
if code != http.StatusOK {
t.Errorf("IP %s 应通过,实际 %d", ip, code)
}
}
}
// ---------- 白名单豁免 ----------
func TestIPFilter_WhitelistOverridesBlacklist(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("5.5.5.5", "封禁测试", 0)
_ = f.AddToWhitelist("5.5.5.5", "白名单豁免")
engine := newTestEngine(f)
// 白名单优先,应通过
code, _ := doRequest(engine, "5.5.5.5:8080", "", "")
if code != http.StatusOK {
t.Fatalf("白名单 IP 应返回 200实际 %d", code)
}
}
// ---------- CIDR 黑名单 ----------
func TestIPFilter_CIDRBlacklist(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("10.10.10.0/24", "封禁整段 CIDR", 0)
engine := newTestEngine(f)
// 在 CIDR 范围内,应被封
code, _ := doRequest(engine, "10.10.10.55:1234", "", "")
if code != http.StatusForbidden {
t.Fatalf("CIDR 内 IP 应返回 403实际 %d", code)
}
// 不在 CIDR 范围内,应通过
code2, _ := doRequest(engine, "10.10.11.1:1234", "", "")
if code2 != http.StatusOK {
t.Fatalf("CIDR 外 IP 应返回 200实际 %d", code2)
}
}
// ---------- 过期规则 ----------
func TestIPFilter_ExpiredRule_Passes(t *testing.T) {
f := security.NewIPFilter()
// 封禁 1 纳秒,几乎立即过期
_ = f.AddToBlacklist("7.7.7.7", "即将过期", time.Nanosecond)
time.Sleep(2 * time.Millisecond)
engine := newTestEngine(f)
code, _ := doRequest(engine, "7.7.7.7:80", "", "")
if code != http.StatusOK {
t.Fatalf("过期规则不应拦截,期望 200实际 %d", code)
}
}
// ---------- client_ip 注入 ----------
func TestIPFilter_ClientIPSetInContext(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
code, body := doRequest(engine, "203.0.113.1:9000", "", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "203.0.113.1" {
t.Errorf("期望 client_ip=203.0.113.1,实际 %q", ip)
}
}
// ---------- realIP 提取逻辑 ----------
// TestRealIP_XForwardedFor_PublicIP 公网 X-Forwarded-For 取第一个非内网 IP
func TestRealIP_XForwardedFor_PublicIP(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
// X-Forwarded-For: 203.0.113.10, 192.168.1.1(代理内网)
code, body := doRequest(engine, "192.168.1.1:80", "203.0.113.10, 192.168.1.1", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "203.0.113.10" {
t.Errorf("期望从 X-Forwarded-For 取公网 IP实际 %q", ip)
}
}
// TestRealIP_XForwardedFor_AllPrivate 全内网则取第一个
func TestRealIP_XForwardedFor_AllPrivate(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
code, body := doRequest(engine, "10.0.0.2:80", "192.168.0.5, 10.0.0.1", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "192.168.0.5" {
t.Errorf("全内网时应取第一个,实际 %q", ip)
}
}
// TestRealIP_XRealIP_Fallback X-Forwarded-For 缺失时使用 X-Real-IP
func TestRealIP_XRealIP_Fallback(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
code, body := doRequest(engine, "192.168.1.1:80", "", "203.0.113.20")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "203.0.113.20" {
t.Errorf("期望 X-Real-IP 回退,实际 %q", ip)
}
}
// TestRealIP_RemoteAddr_Fallback 都无 header 时用 RemoteAddr
func TestRealIP_RemoteAddr_Fallback(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
code, body := doRequest(engine, "203.0.113.99:12345", "", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "203.0.113.99" {
t.Errorf("期望 RemoteAddr 回退,实际 %q", ip)
}
}
// ---------- GetFilter ----------
func TestIPFilterMiddleware_GetFilter(t *testing.T) {
f := security.NewIPFilter()
mw := NewIPFilterMiddleware(f, IPFilterConfig{})
if mw.GetFilter() != f {
t.Error("GetFilter 应返回同一个 IPFilter 实例")
}
}
// ---------- 并发安全 ----------
func TestIPFilter_ConcurrentRequests(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("66.66.66.66", "并发测试封禁", 0)
engine := newTestEngine(f)
done := make(chan struct{}, 20)
for i := 0; i < 20; i++ {
go func(i int) {
defer func() { done <- struct{}{} }()
var remoteAddr string
if i%2 == 0 {
remoteAddr = "66.66.66.66:9000"
} else {
remoteAddr = "1.2.3.4:9000"
}
code, _ := doRequest(engine, remoteAddr, "", "")
if i%2 == 0 && code != http.StatusForbidden {
t.Errorf("并发:封禁 IP 应返回 403实际 %d", code)
} else if i%2 != 0 && code != http.StatusOK {
t.Errorf("并发:正常 IP 应返回 200实际 %d", code)
}
}(i)
}
for i := 0; i < 20; i++ {
<-done
}
}

View File

@@ -0,0 +1,83 @@
package middleware
import (
"log"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
)
var sensitiveQueryKeys = map[string]struct{}{
"token": {},
"access_token": {},
"refresh_token": {},
"code": {},
"secret": {},
}
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
raw := sanitizeQuery(c.Request.URL.RawQuery)
c.Next()
latency := time.Since(start)
status := c.Writer.Status()
method := c.Request.Method
ip := c.ClientIP()
userAgent := c.Request.UserAgent()
userID, _ := c.Get("user_id")
log.Printf("[API] %s %s %s | status: %d | latency: %v | ip: %s | user_id: %v | ua: %s",
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
status,
latency,
ip,
userID,
userAgent,
)
if len(c.Errors) > 0 {
for _, err := range c.Errors {
log.Printf("[Error] %v", err)
}
}
if raw != "" {
log.Printf("[Query] %s?%s", path, raw)
}
}
}
func sanitizeQuery(raw string) string {
if raw == "" {
return ""
}
values, err := url.ParseQuery(raw)
if err != nil {
return ""
}
for key := range values {
if isSensitiveQueryKey(key) {
values.Set(key, "***")
}
}
return values.Encode()
}
func isSensitiveQueryKey(key string) bool {
normalized := strings.ToLower(strings.TrimSpace(key))
if _, ok := sensitiveQueryKeys[normalized]; ok {
return true
}
return strings.Contains(normalized, "token") || strings.Contains(normalized, "secret")
}

View File

@@ -0,0 +1,125 @@
package middleware
import (
"bytes"
"context"
"encoding/json"
"io"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
type OperationLogMiddleware struct {
repo *repository.OperationLogRepository
}
func NewOperationLogMiddleware(repo *repository.OperationLogRepository) *OperationLogMiddleware {
return &OperationLogMiddleware{repo: repo}
}
type bodyWriter struct {
gin.ResponseWriter
statusCode int
}
func newBodyWriter(w gin.ResponseWriter) *bodyWriter {
return &bodyWriter{ResponseWriter: w, statusCode: 200}
}
func (bw *bodyWriter) WriteHeader(code int) {
bw.statusCode = code
bw.ResponseWriter.WriteHeader(code)
}
func (bw *bodyWriter) WriteHeaderNow() {
bw.ResponseWriter.WriteHeaderNow()
}
func (m *OperationLogMiddleware) Record() gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
c.Next()
return
}
var reqParams string
if c.Request.Body != nil {
bodyBytes, err := io.ReadAll(io.LimitReader(c.Request.Body, 4096))
if err == nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
reqParams = sanitizeParams(bodyBytes)
}
}
bw := newBodyWriter(c.Writer)
c.Writer = bw
c.Next()
var userIDPtr *int64
if uid, exists := c.Get("user_id"); exists {
if id, ok := uid.(int64); ok {
userID := id
userIDPtr = &userID
}
}
logEntry := &domain.OperationLog{
UserID: userIDPtr,
OperationType: methodToType(method),
OperationName: c.FullPath(),
RequestMethod: method,
RequestPath: c.Request.URL.Path,
RequestParams: reqParams,
ResponseStatus: bw.statusCode,
IP: c.ClientIP(),
UserAgent: c.Request.UserAgent(),
}
go func(entry *domain.OperationLog) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = m.repo.Create(ctx, entry)
}(logEntry)
}
}
func methodToType(method string) string {
switch method {
case "POST":
return "CREATE"
case "PUT", "PATCH":
return "UPDATE"
case "DELETE":
return "DELETE"
default:
return "OTHER"
}
}
func sanitizeParams(data []byte) string {
var payload map[string]interface{}
if err := json.Unmarshal(data, &payload); err != nil {
if len(data) > 500 {
return string(data[:500]) + "..."
}
return string(data)
}
for _, field := range []string{"password", "old_password", "new_password", "confirm_password", "secret", "token"} {
if _, ok := payload[field]; ok {
payload[field] = "***"
}
}
result, err := json.Marshal(payload)
if err != nil {
return ""
}
return string(result)
}

View File

@@ -0,0 +1,127 @@
package middleware
import (
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
// RateLimitMiddleware 限流中间件
type RateLimitMiddleware struct {
cfg config.RateLimitConfig
limiters map[string]*SlidingWindowLimiter
mu sync.RWMutex
cleanupInt time.Duration
}
// SlidingWindowLimiter 滑动窗口限流器
type SlidingWindowLimiter struct {
mu sync.Mutex
window time.Duration
capacity int64
requests []int64
}
// NewSlidingWindowLimiter 创建滑动窗口限流器
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
return &SlidingWindowLimiter{
window: window,
capacity: capacity,
requests: make([]int64, 0),
}
}
// Allow 检查是否允许请求
func (l *SlidingWindowLimiter) Allow() bool {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now().UnixMilli()
cutoff := now - l.window.Milliseconds()
// 清理过期请求
var validRequests []int64
for _, t := range l.requests {
if t > cutoff {
validRequests = append(validRequests, t)
}
}
l.requests = validRequests
// 检查容量
if int64(len(l.requests)) >= l.capacity {
return false
}
l.requests = append(l.requests, now)
return true
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
return &RateLimitMiddleware{
cfg: cfg,
limiters: make(map[string]*SlidingWindowLimiter),
cleanupInt: 5 * time.Minute,
}
}
// Register 返回注册接口的限流中间件
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
return m.limitForKey("register", 60, 10)
}
// Login 返回登录接口的限流中间件
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
return m.limitForKey("login", 60, 5)
}
// API 返回 API 接口的限流中间件
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
return m.limitForKey("api", 60, 100)
}
// Refresh 返回刷新令牌的限流中间件
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
return m.limitForKey("refresh", 60, 10)
}
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
return func(c *gin.Context) {
if !limiter.Allow() {
c.JSON(429, gin.H{
"code": 429,
"message": "请求过于频繁,请稍后再试",
})
c.Abort()
return
}
c.Next()
}
}
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
m.mu.RLock()
limiter, exists := m.limiters[key]
m.mu.RUnlock()
if exists {
return limiter
}
m.mu.Lock()
defer m.mu.Unlock()
// 双重检查
if limiter, exists = m.limiters[key]; exists {
return limiter
}
limiter = NewSlidingWindowLimiter(window, capacity)
m.limiters[key] = limiter
return limiter
}

View File

@@ -0,0 +1,156 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// contextKey 上下文键常量
const (
ContextKeyRoleCodes = "role_codes"
ContextKeyPermissionCodes = "permission_codes"
)
// RequirePermission 要求用户拥有指定权限之一OR 逻辑)
// 适用于需要单个或多选权限校验的路由
func RequirePermission(codes ...string) gin.HandlerFunc {
return func(c *gin.Context) {
if !hasAnyPermission(c, codes) {
c.JSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "权限不足",
})
c.Abort()
return
}
c.Next()
}
}
// RequireAllPermissions 要求用户拥有所有指定权限AND 逻辑)
func RequireAllPermissions(codes ...string) gin.HandlerFunc {
return func(c *gin.Context) {
if !hasAllPermissions(c, codes) {
c.JSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "权限不足,需要所有指定权限",
})
c.Abort()
return
}
c.Next()
}
}
// RequireRole 要求用户拥有指定角色之一OR 逻辑)
func RequireRole(codes ...string) gin.HandlerFunc {
return func(c *gin.Context) {
if !hasAnyRole(c, codes) {
c.JSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "权限不足,角色受限",
})
c.Abort()
return
}
c.Next()
}
}
// RequireAnyPermission RequirePermission 的别名,语义更清晰
func RequireAnyPermission(codes ...string) gin.HandlerFunc {
return RequirePermission(codes...)
}
// AdminOnly 仅限 admin 角色
func AdminOnly() gin.HandlerFunc {
return RequireRole("admin")
}
// GetRoleCodes 从 Context 获取当前用户角色代码列表
func GetRoleCodes(c *gin.Context) []string {
val, exists := c.Get(ContextKeyRoleCodes)
if !exists {
return nil
}
if codes, ok := val.([]string); ok {
return codes
}
return nil
}
// GetPermissionCodes 从 Context 获取当前用户权限代码列表
func GetPermissionCodes(c *gin.Context) []string {
val, exists := c.Get(ContextKeyPermissionCodes)
if !exists {
return nil
}
if codes, ok := val.([]string); ok {
return codes
}
return nil
}
// IsAdmin 判断当前用户是否为 admin
func IsAdmin(c *gin.Context) bool {
return hasAnyRole(c, []string{"admin"})
}
// hasAnyPermission 判断用户是否拥有任意一个权限
func hasAnyPermission(c *gin.Context, codes []string) bool {
// admin 角色拥有所有权限
if IsAdmin(c) {
return true
}
permCodes := GetPermissionCodes(c)
if len(permCodes) == 0 {
return false
}
permSet := toSet(permCodes)
for _, code := range codes {
if _, ok := permSet[code]; ok {
return true
}
}
return false
}
// hasAllPermissions 判断用户是否拥有所有权限
func hasAllPermissions(c *gin.Context, codes []string) bool {
if IsAdmin(c) {
return true
}
permCodes := GetPermissionCodes(c)
permSet := toSet(permCodes)
for _, code := range codes {
if _, ok := permSet[code]; !ok {
return false
}
}
return true
}
// hasAnyRole 判断用户是否拥有任意一个角色
func hasAnyRole(c *gin.Context, codes []string) bool {
roleCodes := GetRoleCodes(c)
if len(roleCodes) == 0 {
return false
}
roleSet := toSet(roleCodes)
for _, code := range codes {
if _, ok := roleSet[code]; ok {
return true
}
}
return false
}
// toSet 将字符串切片转换为 map 集合
func toSet(items []string) map[string]struct{} {
s := make(map[string]struct{}, len(items))
for _, item := range items {
s[item] = struct{}{}
}
return s
}

View File

@@ -0,0 +1,139 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
func TestCORS_UsesConfiguredOrigins(t *testing.T) {
gin.SetMode(gin.TestMode)
SetCORSConfig(config.CORSConfig{
AllowedOrigins: []string{"https://app.example.com"},
AllowCredentials: true,
})
t.Cleanup(func() {
SetCORSConfig(config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
})
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil)
c.Request.Header.Set("Origin", "https://app.example.com")
c.Request.Header.Set("Access-Control-Request-Headers", "Authorization")
CORS()(c)
if recorder.Code != http.StatusNoContent {
t.Fatalf("expected 204, got %d", recorder.Code)
}
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != "https://app.example.com" {
t.Fatalf("unexpected allow origin: %s", got)
}
if got := recorder.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
t.Fatalf("expected credentials header to be 'true', got %q", got)
}
}
func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
sanitized := sanitizeQuery(raw)
if sanitized == "" {
t.Fatal("expected sanitized query")
}
if sanitized == raw {
t.Fatal("expected query to be sanitized")
}
for _, value := range []string{"abc123", "xyz", "s1"} {
if strings.Contains(sanitized, value) {
t.Fatalf("expected sensitive value %q to be masked in %q", value, sanitized)
}
}
if sanitizeQuery("") != "" {
t.Fatal("expected empty query to stay empty")
}
}
func TestSecurityHeaders_AttachesExpectedHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
SecurityHeaders()(c)
if got := recorder.Header().Get("X-Content-Type-Options"); got != "nosniff" {
t.Fatalf("unexpected nosniff header: %q", got)
}
if got := recorder.Header().Get("X-Frame-Options"); got != "DENY" {
t.Fatalf("unexpected frame options: %q", got)
}
if got := recorder.Header().Get("Content-Security-Policy"); got == "" {
t.Fatal("expected content security policy header")
}
if got := recorder.Header().Get("Strict-Transport-Security"); got != "" {
t.Fatalf("did not expect hsts header for http request, got %q", got)
}
}
func TestSecurityHeaders_AttachesHSTSForForwardedHTTPS(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
c.Request.Header.Set("X-Forwarded-Proto", "https")
SecurityHeaders()(c)
if got := recorder.Header().Get("Strict-Transport-Security"); !strings.Contains(got, "max-age=31536000") {
t.Fatalf("expected hsts header, got %q", got)
}
}
func TestNoStoreSensitiveResponses_AttachesExpectedHeadersToAuthRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/capabilities", nil)
NoStoreSensitiveResponses()(c)
if got := recorder.Header().Get("Cache-Control"); got != sensitiveNoStoreCacheControl {
t.Fatalf("unexpected cache-control header: %q", got)
}
if got := recorder.Header().Get("Pragma"); got != "no-cache" {
t.Fatalf("unexpected pragma header: %q", got)
}
if got := recorder.Header().Get("Expires"); got != "0" {
t.Fatalf("unexpected expires header: %q", got)
}
if got := recorder.Header().Get("Surrogate-Control"); got != "no-store" {
t.Fatalf("unexpected surrogate-control header: %q", got)
}
}
func TestNoStoreSensitiveResponses_DoesNotAttachHeadersToNonAuthRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
NoStoreSensitiveResponses()(c)
if got := recorder.Header().Get("Cache-Control"); got != "" {
t.Fatalf("did not expect cache-control header, got %q", got)
}
}

View File

@@ -0,0 +1,45 @@
package middleware
import (
"strings"
"github.com/gin-gonic/gin"
)
const contentSecurityPolicy = "default-src 'none'; frame-ancestors 'none'; base-uri 'none'; form-action 'self'"
func SecurityHeaders() gin.HandlerFunc {
return func(c *gin.Context) {
headers := c.Writer.Header()
headers.Set("X-Content-Type-Options", "nosniff")
headers.Set("X-Frame-Options", "DENY")
headers.Set("Referrer-Policy", "strict-origin-when-cross-origin")
headers.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
headers.Set("Cross-Origin-Opener-Policy", "same-origin")
headers.Set("X-Permitted-Cross-Domain-Policies", "none")
if shouldAttachCSP(c.FullPath(), c.Request.URL.Path) {
headers.Set("Content-Security-Policy", contentSecurityPolicy)
}
if isHTTPSRequest(c) {
headers.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
}
c.Next()
}
}
func shouldAttachCSP(routePath, requestPath string) bool {
path := strings.TrimSpace(routePath)
if path == "" {
path = strings.TrimSpace(requestPath)
}
return !strings.HasPrefix(path, "/swagger/")
}
func isHTTPSRequest(c *gin.Context) bool {
if c.Request.TLS != nil {
return true
}
return strings.EqualFold(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")), "https")
}