feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers

This commit is contained in:
2026-04-02 11:19:50 +08:00
parent e59a77bc49
commit dcc1f186f8
298 changed files with 62603 additions and 0 deletions

View File

@@ -0,0 +1,260 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// AuthHandler handles authentication requests
type AuthHandler struct {
authService *service.AuthService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
return &AuthHandler{authService: authService}
}
func (h *AuthHandler) Register(c *gin.Context) {
var req struct {
Username string `json:"username" binding:"required"`
Email string `json:"email"`
Phone string `json:"phone"`
Password string `json:"password" binding:"required"`
Nickname string `json:"nickname"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
registerReq := &service.RegisterRequest{
Username: req.Username,
Email: req.Email,
Phone: req.Phone,
Password: req.Password,
Nickname: req.Nickname,
}
userInfo, err := h.authService.Register(c.Request.Context(), registerReq)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusCreated, userInfo)
}
func (h *AuthHandler) Login(c *gin.Context) {
var req struct {
Account string `json:"account"`
Username string `json:"username"`
Email string `json:"email"`
Phone string `json:"phone"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
loginReq := &service.LoginRequest{
Account: req.Account,
Username: req.Username,
Email: req.Email,
Phone: req.Phone,
Password: req.Password,
}
clientIP := c.ClientIP()
resp, err := h.authService.Login(c.Request.Context(), loginReq, clientIP)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, resp)
}
func (h *AuthHandler) Logout(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "logged out"})
}
func (h *AuthHandler) RefreshToken(c *gin.Context) {
var req struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
resp, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, resp)
}
func (h *AuthHandler) GetUserInfo(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
userInfo, err := h.authService.GetUserInfo(c.Request.Context(), userID)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, userInfo)
}
func (h *AuthHandler) GetCSRFToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"csrf_token": "not_implemented"})
}
func (h *AuthHandler) GetAuthCapabilities(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"register": true,
"login": true,
"oauth_login": false,
"totp": true,
})
}
func (h *AuthHandler) OAuthLogin(c *gin.Context) {
provider := c.Param("provider")
c.JSON(http.StatusOK, gin.H{"provider": provider, "message": "OAuth not configured"})
}
func (h *AuthHandler) OAuthCallback(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"error": "OAuth not configured"})
}
func (h *AuthHandler) OAuthExchange(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"error": "OAuth not configured"})
}
func (h *AuthHandler) GetEnabledOAuthProviders(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"providers": []string{}})
}
func (h *AuthHandler) ActivateEmail(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "email activation not configured"})
}
func (h *AuthHandler) ResendActivationEmail(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "email activation not configured"})
}
func (h *AuthHandler) SendEmailCode(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "email code login not configured"})
}
func (h *AuthHandler) LoginByEmailCode(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"error": "email code login not configured"})
}
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "password reset not configured"})
}
func (h *AuthHandler) ResetPassword(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "password reset not configured"})
}
func (h *AuthHandler) ValidateResetToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"valid": false})
}
func (h *AuthHandler) BootstrapAdmin(c *gin.Context) {
var req struct {
Username string `json:"username" binding:"required"`
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
bootstrapReq := &service.BootstrapAdminRequest{
Username: req.Username,
Email: req.Email,
Password: req.Password,
}
clientIP := c.ClientIP()
resp, err := h.authService.BootstrapAdmin(c.Request.Context(), bootstrapReq, clientIP)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusCreated, resp)
}
func (h *AuthHandler) SendEmailBindCode(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "email bind not configured"})
}
func (h *AuthHandler) BindEmail(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "email bind not configured"})
}
func (h *AuthHandler) UnbindEmail(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "email unbind not configured"})
}
func (h *AuthHandler) SendPhoneBindCode(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "phone bind not configured"})
}
func (h *AuthHandler) BindPhone(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "phone bind not configured"})
}
func (h *AuthHandler) UnbindPhone(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "phone unbind not configured"})
}
func (h *AuthHandler) GetSocialAccounts(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"accounts": []interface{}{}})
}
func (h *AuthHandler) BindSocialAccount(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "social binding not configured"})
}
func (h *AuthHandler) UnbindSocialAccount(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "social unbinding not configured"})
}
func (h *AuthHandler) SupportsEmailCodeLogin() bool {
return false
}
func getUserIDFromContext(c *gin.Context) (int64, bool) {
userID, exists := c.Get("user_id")
if !exists {
return 0, false
}
id, ok := userID.(int64)
return id, ok
}
func handleError(c *gin.Context, err error) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}

View File

@@ -0,0 +1,19 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
)
// AvatarHandler handles avatar upload requests
type AvatarHandler struct{}
// NewAvatarHandler creates a new AvatarHandler
func NewAvatarHandler() *AvatarHandler {
return &AvatarHandler{}
}
func (h *AvatarHandler) UploadAvatar(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "avatar upload not implemented"})
}

View File

@@ -0,0 +1,54 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// CaptchaHandler handles captcha requests
type CaptchaHandler struct {
captchaService *service.CaptchaService
}
// NewCaptchaHandler creates a new CaptchaHandler
func NewCaptchaHandler(captchaService *service.CaptchaService) *CaptchaHandler {
return &CaptchaHandler{captchaService: captchaService}
}
func (h *CaptchaHandler) GenerateCaptcha(c *gin.Context) {
result, err := h.captchaService.Generate(c.Request.Context())
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"captcha_id": result.CaptchaID,
"image": result.ImageData,
})
}
func (h *CaptchaHandler) GetCaptchaImage(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "captcha image endpoint"})
}
func (h *CaptchaHandler) VerifyCaptcha(c *gin.Context) {
var req struct {
CaptchaID string `json:"captcha_id" binding:"required"`
Answer string `json:"answer" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if h.captchaService.Verify(c.Request.Context(), req.CaptchaID, req.Answer) {
c.JSON(http.StatusOK, gin.H{"verified": true})
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid captcha"})
}
}

View File

@@ -0,0 +1,146 @@
package handler
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// CustomFieldHandler 自定义字段处理器
type CustomFieldHandler struct {
customFieldService *service.CustomFieldService
}
// NewCustomFieldHandler 创建自定义字段处理器
func NewCustomFieldHandler(customFieldService *service.CustomFieldService) *CustomFieldHandler {
return &CustomFieldHandler{customFieldService: customFieldService}
}
// CreateField 创建自定义字段
func (h *CustomFieldHandler) CreateField(c *gin.Context) {
var req service.CreateFieldRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
field, err := h.customFieldService.CreateField(c.Request.Context(), &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusCreated, field)
}
// UpdateField 更新自定义字段
func (h *CustomFieldHandler) UpdateField(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid field id"})
return
}
var req service.UpdateFieldRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
field, err := h.customFieldService.UpdateField(c.Request.Context(), id, &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, field)
}
// DeleteField 删除自定义字段
func (h *CustomFieldHandler) DeleteField(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid field id"})
return
}
if err := h.customFieldService.DeleteField(c.Request.Context(), id); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "field deleted"})
}
// GetField 获取自定义字段
func (h *CustomFieldHandler) GetField(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid field id"})
return
}
field, err := h.customFieldService.GetField(c.Request.Context(), id)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, field)
}
// ListFields 获取所有自定义字段
func (h *CustomFieldHandler) ListFields(c *gin.Context) {
fields, err := h.customFieldService.ListFields(c.Request.Context())
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"fields": fields})
}
// SetUserFieldValues 设置用户自定义字段值
func (h *CustomFieldHandler) SetUserFieldValues(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
var req struct {
Values map[string]string `json:"values" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.customFieldService.BatchSetUserFieldValues(c.Request.Context(), userID, req.Values); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "field values set"})
}
// GetUserFieldValues 获取用户自定义字段值
func (h *CustomFieldHandler) GetUserFieldValues(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
values, err := h.customFieldService.GetUserFieldValues(c.Request.Context(), userID)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"fields": values})
}

View File

@@ -0,0 +1,343 @@
package handler
import (
"fmt"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service"
)
// DeviceHandler handles device management requests
type DeviceHandler struct {
deviceService *service.DeviceService
}
// NewDeviceHandler creates a new DeviceHandler
func NewDeviceHandler(deviceService *service.DeviceService) *DeviceHandler {
return &DeviceHandler{deviceService: deviceService}
}
func (h *DeviceHandler) CreateDevice(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
var req service.CreateDeviceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
device, err := h.deviceService.CreateDevice(c.Request.Context(), userID, &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusCreated, device)
}
func (h *DeviceHandler) GetMyDevices(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"devices": devices,
"total": total,
"page": page,
"page_size": pageSize,
})
}
func (h *DeviceHandler) GetDevice(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
return
}
device, err := h.deviceService.GetDevice(c.Request.Context(), id)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, device)
}
func (h *DeviceHandler) UpdateDevice(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
return
}
var req service.UpdateDeviceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
device, err := h.deviceService.UpdateDevice(c.Request.Context(), id, &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, device)
}
func (h *DeviceHandler) DeleteDevice(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
return
}
if err := h.deviceService.DeleteDevice(c.Request.Context(), id); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "device deleted"})
}
func (h *DeviceHandler) UpdateDeviceStatus(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
return
}
var req struct {
Status string `json:"status" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var status domain.DeviceStatus
switch req.Status {
case "active", "1":
status = domain.DeviceStatusActive
case "inactive", "0":
status = domain.DeviceStatusInactive
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid status"})
return
}
if err := h.deviceService.UpdateDeviceStatus(c.Request.Context(), id, status); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "status updated"})
}
func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
userIDParam := c.Param("id")
userID, err := strconv.ParseInt(userIDParam, 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"devices": devices,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// GetAllDevices 获取所有设备列表(管理员)
func (h *DeviceHandler) GetAllDevices(c *gin.Context) {
var req service.GetAllDevicesRequest
if err := c.ShouldBindQuery(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
devices, total, err := h.deviceService.GetAllDevices(c.Request.Context(), &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"devices": devices,
"total": total,
"page": req.Page,
"page_size": req.PageSize,
})
}
// TrustDeviceRequest 信任设备请求
type TrustDeviceRequest struct {
TrustDuration string `json:"trust_duration"` // 信任持续时间,如 "30d" 表示30天
}
// TrustDevice 设置设备为信任设备
func (h *DeviceHandler) TrustDevice(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
return
}
var req TrustDeviceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 解析信任持续时间
trustDuration := parseDuration(req.TrustDuration)
if err := h.deviceService.TrustDevice(c.Request.Context(), id, trustDuration); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "device trusted"})
}
// TrustDeviceByDeviceID 根据设备标识字符串设置设备为信任状态
func (h *DeviceHandler) TrustDeviceByDeviceID(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
deviceID := c.Param("deviceId")
if deviceID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
return
}
var req TrustDeviceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 解析信任持续时间
trustDuration := parseDuration(req.TrustDuration)
if err := h.deviceService.TrustDeviceByDeviceID(c.Request.Context(), userID, deviceID, trustDuration); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "device trusted"})
}
// UntrustDevice 取消设备信任状态
func (h *DeviceHandler) UntrustDevice(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
return
}
if err := h.deviceService.UntrustDevice(c.Request.Context(), id); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "device untrusted"})
}
// GetMyTrustedDevices 获取我的信任设备列表
func (h *DeviceHandler) GetMyTrustedDevices(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
devices, err := h.deviceService.GetTrustedDevices(c.Request.Context(), userID)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"devices": devices})
}
// LogoutAllOtherDevices 登出所有其他设备
func (h *DeviceHandler) LogoutAllOtherDevices(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
// 从请求中获取当前设备ID
currentDeviceIDStr := c.GetHeader("X-Device-ID")
currentDeviceID, err := strconv.ParseInt(currentDeviceIDStr, 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid current device id"})
return
}
if err := h.deviceService.LogoutAllOtherDevices(c.Request.Context(), userID, currentDeviceID); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "all other devices logged out"})
}
// parseDuration 解析duration字符串如 "30d" -> 30天的time.Duration
func parseDuration(s string) time.Duration {
if s == "" {
return 0
}
// 简单实现,支持 d(天)和h(小时)
var d int
var h int
_, _ = d, h
switch s[len(s)-1] {
case 'd':
d = 1
_, _ = fmt.Sscanf(s[:len(s)-1], "%d", &d)
return time.Duration(d) * 24 * time.Hour
case 'h':
_, _ = fmt.Sscanf(s[:len(s)-1], "%d", &h)
return time.Duration(h) * time.Hour
}
return 0
}

View File

@@ -0,0 +1,31 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// ExportHandler handles user export/import requests
type ExportHandler struct {
exportService *service.ExportService
}
// NewExportHandler creates a new ExportHandler
func NewExportHandler(exportService *service.ExportService) *ExportHandler {
return &ExportHandler{exportService: exportService}
}
func (h *ExportHandler) ExportUsers(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "user export not implemented"})
}
func (h *ExportHandler) ImportUsers(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "user import not implemented"})
}
func (h *ExportHandler) GetImportTemplate(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"template": "id,username,email,nickname"})
}

View File

@@ -0,0 +1,93 @@
package handler
import (
"fmt"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// LogHandler handles log requests
type LogHandler struct {
loginLogService *service.LoginLogService
operationLogService *service.OperationLogService
}
// NewLogHandler creates a new LogHandler
func NewLogHandler(loginLogService *service.LoginLogService, operationLogService *service.OperationLogService) *LogHandler {
return &LogHandler{
loginLogService: loginLogService,
operationLogService: operationLogService,
}
}
func (h *LogHandler) GetMyLoginLogs(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
logs, total, err := h.loginLogService.GetMyLoginLogs(c.Request.Context(), userID, page, pageSize)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"logs": logs,
"total": total,
"page": page,
"page_size": pageSize,
})
}
func (h *LogHandler) GetMyOperationLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"logs": []interface{}{}})
}
func (h *LogHandler) GetLoginLogs(c *gin.Context) {
var req service.ListLoginLogRequest
if err := c.ShouldBindQuery(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
logs, total, err := h.loginLogService.GetLoginLogs(c.Request.Context(), &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"logs": logs,
"total": total,
})
}
func (h *LogHandler) GetOperationLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"logs": []interface{}{}})
}
func (h *LogHandler) ExportLoginLogs(c *gin.Context) {
var req service.ExportLoginLogRequest
if err := c.ShouldBindQuery(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
data, filename, contentType, err := h.loginLogService.ExportLoginLogs(c.Request.Context(), &req)
if err != nil {
handleError(c, err)
return
}
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename))
c.Data(http.StatusOK, contentType, data)
}

View File

@@ -0,0 +1,153 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// PasswordResetHandler handles password reset requests
type PasswordResetHandler struct {
passwordResetService *service.PasswordResetService
smsService *service.SMSCodeService
}
// NewPasswordResetHandler creates a new PasswordResetHandler
func NewPasswordResetHandler(passwordResetService *service.PasswordResetService) *PasswordResetHandler {
return &PasswordResetHandler{passwordResetService: passwordResetService}
}
// NewPasswordResetHandlerWithSMS creates a new PasswordResetHandler with SMS support
func NewPasswordResetHandlerWithSMS(passwordResetService *service.PasswordResetService, smsService *service.SMSCodeService) *PasswordResetHandler {
return &PasswordResetHandler{
passwordResetService: passwordResetService,
smsService: smsService,
}
}
func (h *PasswordResetHandler) ForgotPassword(c *gin.Context) {
var req struct {
Email string `json:"email" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.passwordResetService.ForgotPassword(c.Request.Context(), req.Email); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "password reset email sent"})
}
func (h *PasswordResetHandler) ValidateResetToken(c *gin.Context) {
token := c.Query("token")
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "token is required"})
return
}
valid, err := h.passwordResetService.ValidateResetToken(c.Request.Context(), token)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"valid": valid})
}
func (h *PasswordResetHandler) ResetPassword(c *gin.Context) {
var req struct {
Token string `json:"token" binding:"required"`
NewPassword string `json:"new_password" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.passwordResetService.ResetPassword(c.Request.Context(), req.Token, req.NewPassword); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "password reset successful"})
}
// ForgotPasswordByPhoneRequest 短信密码重置请求
type ForgotPasswordByPhoneRequest struct {
Phone string `json:"phone" binding:"required"`
}
// ForgotPasswordByPhone 发送短信验证码
func (h *PasswordResetHandler) ForgotPasswordByPhone(c *gin.Context) {
if h.smsService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "SMS service not configured"})
return
}
var req ForgotPasswordByPhoneRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 获取验证码(不发送,由调用方通过其他渠道发送)
code, err := h.passwordResetService.ForgotPasswordByPhone(c.Request.Context(), req.Phone)
if err != nil {
handleError(c, err)
return
}
if code == "" {
// 用户不存在,不提示
c.JSON(http.StatusOK, gin.H{"message": "verification code sent"})
return
}
// 通过SMS服务发送验证码
sendReq := &service.SendCodeRequest{
Phone: req.Phone,
Purpose: "password_reset",
}
_, err = h.smsService.SendCode(c.Request.Context(), sendReq)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "verification code sent"})
}
// ResetPasswordByPhoneRequest 短信验证码重置密码请求
type ResetPasswordByPhoneRequest struct {
Phone string `json:"phone" binding:"required"`
Code string `json:"code" binding:"required"`
NewPassword string `json:"new_password" binding:"required"`
}
// ResetPasswordByPhone 通过短信验证码重置密码
func (h *PasswordResetHandler) ResetPasswordByPhone(c *gin.Context) {
var req ResetPasswordByPhoneRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
err := h.passwordResetService.ResetPasswordByPhone(c.Request.Context(), &service.ResetPasswordByPhoneRequest{
Phone: req.Phone,
Code: req.Code,
NewPassword: req.NewPassword,
})
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "password reset successful"})
}

View File

@@ -0,0 +1,154 @@
package handler
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service"
)
// PermissionHandler handles permission management requests
type PermissionHandler struct {
permissionService *service.PermissionService
}
// NewPermissionHandler creates a new PermissionHandler
func NewPermissionHandler(permissionService *service.PermissionService) *PermissionHandler {
return &PermissionHandler{permissionService: permissionService}
}
func (h *PermissionHandler) CreatePermission(c *gin.Context) {
var req service.CreatePermissionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
perm, err := h.permissionService.CreatePermission(c.Request.Context(), &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusCreated, perm)
}
func (h *PermissionHandler) ListPermissions(c *gin.Context) {
var req service.ListPermissionRequest
if err := c.ShouldBindQuery(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
perms, total, err := h.permissionService.ListPermissions(c.Request.Context(), &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"permissions": perms,
"total": total,
})
}
func (h *PermissionHandler) GetPermission(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid permission id"})
return
}
perm, err := h.permissionService.GetPermission(c.Request.Context(), id)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, perm)
}
func (h *PermissionHandler) UpdatePermission(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid permission id"})
return
}
var req service.UpdatePermissionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
perm, err := h.permissionService.UpdatePermission(c.Request.Context(), id, &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, perm)
}
func (h *PermissionHandler) DeletePermission(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid permission id"})
return
}
if err := h.permissionService.DeletePermission(c.Request.Context(), id); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "permission deleted"})
}
func (h *PermissionHandler) UpdatePermissionStatus(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid permission id"})
return
}
var req struct {
Status string `json:"status" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var status domain.PermissionStatus
switch req.Status {
case "enabled", "1":
status = domain.PermissionStatusEnabled
case "disabled", "0":
status = domain.PermissionStatusDisabled
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid status"})
return
}
if err := h.permissionService.UpdatePermissionStatus(c.Request.Context(), id, status); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "status updated"})
}
func (h *PermissionHandler) GetPermissionTree(c *gin.Context) {
tree, err := h.permissionService.GetPermissionTree(c.Request.Context())
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"permissions": tree})
}

View File

@@ -0,0 +1,186 @@
package handler
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service"
)
// RoleHandler handles role management requests
type RoleHandler struct {
roleService *service.RoleService
}
// NewRoleHandler creates a new RoleHandler
func NewRoleHandler(roleService *service.RoleService) *RoleHandler {
return &RoleHandler{roleService: roleService}
}
func (h *RoleHandler) CreateRole(c *gin.Context) {
var req service.CreateRoleRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
role, err := h.roleService.CreateRole(c.Request.Context(), &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusCreated, role)
}
func (h *RoleHandler) ListRoles(c *gin.Context) {
var req service.ListRoleRequest
if err := c.ShouldBindQuery(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
roles, total, err := h.roleService.ListRoles(c.Request.Context(), &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"roles": roles,
"total": total,
})
}
func (h *RoleHandler) GetRole(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid role id"})
return
}
role, err := h.roleService.GetRole(c.Request.Context(), id)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, role)
}
func (h *RoleHandler) UpdateRole(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid role id"})
return
}
var req service.UpdateRoleRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
role, err := h.roleService.UpdateRole(c.Request.Context(), id, &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, role)
}
func (h *RoleHandler) DeleteRole(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid role id"})
return
}
if err := h.roleService.DeleteRole(c.Request.Context(), id); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "role deleted"})
}
func (h *RoleHandler) UpdateRoleStatus(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid role id"})
return
}
var req struct {
Status string `json:"status" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var status domain.RoleStatus
switch req.Status {
case "enabled", "1":
status = domain.RoleStatusEnabled
case "disabled", "0":
status = domain.RoleStatusDisabled
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid status"})
return
}
err = h.roleService.UpdateRoleStatus(c.Request.Context(), id, status)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "status updated"})
}
func (h *RoleHandler) GetRolePermissions(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid role id"})
return
}
perms, err := h.roleService.GetRolePermissions(c.Request.Context(), id)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"permissions": perms})
}
func (h *RoleHandler) AssignPermissions(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid role id"})
return
}
var req struct {
PermissionIDs []int64 `json:"permission_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
err = h.roleService.AssignPermissions(c.Request.Context(), id, req.PermissionIDs)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "permissions assigned"})
}

View File

@@ -0,0 +1,23 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
)
// SMSHandler handles SMS requests
type SMSHandler struct{}
// NewSMSHandler creates a new SMSHandler
func NewSMSHandler() *SMSHandler {
return &SMSHandler{}
}
func (h *SMSHandler) SendCode(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "SMS not configured"})
}
func (h *SMSHandler) LoginByCode(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"error": "SMS login not configured"})
}

View File

@@ -0,0 +1,236 @@
package handler
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/auth"
)
// SSOHandler SSO 处理程序
type SSOHandler struct {
ssoManager *auth.SSOManager
}
// NewSSOHandler 创建 SSO 处理程序
func NewSSOHandler(ssoManager *auth.SSOManager) *SSOHandler {
return &SSOHandler{ssoManager: ssoManager}
}
// AuthorizeRequest 授权请求
type AuthorizeRequest struct {
ClientID string `form:"client_id" binding:"required"`
RedirectURI string `form:"redirect_uri" binding:"required"`
ResponseType string `form:"response_type" binding:"required"`
Scope string `form:"scope"`
State string `form:"state"`
}
// Authorize 处理 SSO 授权请求
// GET /api/v1/sso/authorize?client_id=xxx&redirect_uri=xxx&response_type=code&scope=openid&state=xxx
func (h *SSOHandler) Authorize(c *gin.Context) {
var req AuthorizeRequest
if err := c.ShouldBindQuery(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 验证 response_type
if req.ResponseType != "code" && req.ResponseType != "token" {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported response_type"})
return
}
// 获取当前登录用户(从 auth middleware 设置的 context
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
username, _ := c.Get("username")
// 生成授权码或 access token
if req.ResponseType == "code" {
code, err := h.ssoManager.GenerateAuthorizationCode(
req.ClientID,
req.RedirectURI,
req.Scope,
userID.(int64),
username.(string),
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate code"})
return
}
// 重定向回客户端
redirectURL := req.RedirectURI + "?code=" + code
if req.State != "" {
redirectURL += "&state=" + req.State
}
c.Redirect(http.StatusFound, redirectURL)
} else {
// implicit 模式,直接返回 token
code, err := h.ssoManager.GenerateAuthorizationCode(
req.ClientID,
req.RedirectURI,
req.Scope,
userID.(int64),
username.(string),
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate code"})
return
}
// 验证授权码获取 session
session, err := h.ssoManager.ValidateAuthorizationCode(code)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to validate code"})
return
}
token, _ := h.ssoManager.GenerateAccessToken(req.ClientID, session)
// 重定向回客户端,带 token
redirectURL := req.RedirectURI + "#access_token=" + token + "&expires_in=7200"
if req.State != "" {
redirectURL += "&state=" + req.State
}
c.Redirect(http.StatusFound, redirectURL)
}
}
// TokenRequest Token 请求
type TokenRequest struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code"`
RedirectURI string `form:"redirect_uri"`
ClientID string `form:"client_id" binding:"required"`
ClientSecret string `form:"client_secret" binding:"required"`
}
// TokenResponse Token 响应
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope"`
}
// Token 处理 Token 请求(授权码模式第二步)
// POST /api/v1/sso/token
func (h *SSOHandler) Token(c *gin.Context) {
var req TokenRequest
if err := c.ShouldBind(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 验证 grant_type
if req.GrantType != "authorization_code" {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported grant_type"})
return
}
// 验证授权码
session, err := h.ssoManager.ValidateAuthorizationCode(req.Code)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid code"})
return
}
// 生成 access token
token, expiresAt := h.ssoManager.GenerateAccessToken(req.ClientID, session)
c.JSON(http.StatusOK, TokenResponse{
AccessToken: token,
TokenType: "Bearer",
ExpiresIn: int64(time.Until(expiresAt).Seconds()),
Scope: session.Scope,
})
}
// IntrospectRequest Introspect 请求
type IntrospectRequest struct {
Token string `form:"token" binding:"required"`
ClientID string `form:"client_id"`
}
// IntrospectResponse Introspect 响应
type IntrospectResponse struct {
Active bool `json:"active"`
UserID int64 `json:"user_id,omitempty"`
Username string `json:"username,omitempty"`
ExpiresAt int64 `json:"exp,omitempty"`
Scope string `json:"scope,omitempty"`
}
// Introspect 验证 access token
// POST /api/v1/sso/introspect
func (h *SSOHandler) Introspect(c *gin.Context) {
var req IntrospectRequest
if err := c.ShouldBind(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
info, err := h.ssoManager.IntrospectToken(req.Token)
if err != nil {
c.JSON(http.StatusOK, IntrospectResponse{Active: false})
return
}
c.JSON(http.StatusOK, IntrospectResponse{
Active: info.Active,
UserID: info.UserID,
Username: info.Username,
ExpiresAt: info.ExpiresAt.Unix(),
Scope: info.Scope,
})
}
// RevokeRequest 撤销请求
type RevokeRequest struct {
Token string `form:"token" binding:"required"`
}
// Revoke 撤销 access token
// POST /api/v1/sso/revoke
func (h *SSOHandler) Revoke(c *gin.Context) {
var req RevokeRequest
if err := c.ShouldBind(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.ssoManager.RevokeToken(req.Token)
c.JSON(http.StatusOK, gin.H{"message": "token revoked"})
}
// UserInfoResponse 用户信息响应
type UserInfoResponse struct {
UserID int64 `json:"user_id"`
Username string `json:"username"`
}
// UserInfo 获取当前用户信息SSO 专用)
// GET /api/v1/sso/userinfo
func (h *SSOHandler) UserInfo(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
username, _ := c.Get("username")
c.JSON(http.StatusOK, UserInfoResponse{
UserID: userID.(int64),
Username: username.(string),
})
}

View File

@@ -0,0 +1,27 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// StatsHandler handles statistics requests
type StatsHandler struct {
statsService *service.StatsService
}
// NewStatsHandler creates a new StatsHandler
func NewStatsHandler(statsService *service.StatsService) *StatsHandler {
return &StatsHandler{statsService: statsService}
}
func (h *StatsHandler) GetDashboard(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "dashboard stats not implemented"})
}
func (h *StatsHandler) GetUserStats(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "user stats not implemented"})
}

View File

@@ -0,0 +1,153 @@
package handler
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// ThemeHandler 主题配置处理器
type ThemeHandler struct {
themeService *service.ThemeService
}
// NewThemeHandler 创建主题配置处理器
func NewThemeHandler(themeService *service.ThemeService) *ThemeHandler {
return &ThemeHandler{themeService: themeService}
}
// CreateTheme 创建主题
func (h *ThemeHandler) CreateTheme(c *gin.Context) {
var req service.CreateThemeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
theme, err := h.themeService.CreateTheme(c.Request.Context(), &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusCreated, theme)
}
// UpdateTheme 更新主题
func (h *ThemeHandler) UpdateTheme(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid theme id"})
return
}
var req service.UpdateThemeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
theme, err := h.themeService.UpdateTheme(c.Request.Context(), id, &req)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, theme)
}
// DeleteTheme 删除主题
func (h *ThemeHandler) DeleteTheme(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid theme id"})
return
}
if err := h.themeService.DeleteTheme(c.Request.Context(), id); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "theme deleted"})
}
// GetTheme 获取主题
func (h *ThemeHandler) GetTheme(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid theme id"})
return
}
theme, err := h.themeService.GetTheme(c.Request.Context(), id)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, theme)
}
// ListThemes 获取所有主题
func (h *ThemeHandler) ListThemes(c *gin.Context) {
themes, err := h.themeService.ListThemes(c.Request.Context())
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"themes": themes})
}
// ListAllThemes 获取所有主题(包括禁用的)
func (h *ThemeHandler) ListAllThemes(c *gin.Context) {
themes, err := h.themeService.ListAllThemes(c.Request.Context())
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"themes": themes})
}
// GetDefaultTheme 获取默认主题
func (h *ThemeHandler) GetDefaultTheme(c *gin.Context) {
theme, err := h.themeService.GetDefaultTheme(c.Request.Context())
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, theme)
}
// SetDefaultTheme 设置默认主题
func (h *ThemeHandler) SetDefaultTheme(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid theme id"})
return
}
if err := h.themeService.SetDefaultTheme(c.Request.Context(), id); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "default theme set"})
}
// GetActiveTheme 获取当前生效的主题(公开接口)
func (h *ThemeHandler) GetActiveTheme(c *gin.Context) {
theme, err := h.themeService.GetActiveTheme(c.Request.Context())
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, theme)
}

View File

@@ -0,0 +1,132 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// TOTPHandler handles TOTP 2FA requests
type TOTPHandler struct {
authService *service.AuthService
totpService *service.TOTPService
}
// NewTOTPHandler creates a new TOTPHandler
func NewTOTPHandler(authService *service.AuthService, totpService *service.TOTPService) *TOTPHandler {
return &TOTPHandler{
authService: authService,
totpService: totpService,
}
}
func (h *TOTPHandler) GetTOTPStatus(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
enabled, err := h.totpService.GetTOTPStatus(c.Request.Context(), userID)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"enabled": enabled})
}
func (h *TOTPHandler) SetupTOTP(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
resp, err := h.totpService.SetupTOTP(c.Request.Context(), userID)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"secret": resp.Secret,
"qr_code_base64": resp.QRCodeBase64,
"recovery_codes": resp.RecoveryCodes,
})
}
func (h *TOTPHandler) EnableTOTP(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
var req struct {
Code string `json:"code" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.totpService.EnableTOTP(c.Request.Context(), userID, req.Code); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "TOTP enabled"})
}
func (h *TOTPHandler) DisableTOTP(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
var req struct {
Code string `json:"code" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.totpService.DisableTOTP(c.Request.Context(), userID, req.Code); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "TOTP disabled"})
}
func (h *TOTPHandler) VerifyTOTP(c *gin.Context) {
userID, ok := getUserIDFromContext(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
var req struct {
Code string `json:"code" binding:"required"`
DeviceID string `json:"device_id,omitempty"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.authService.VerifyTOTP(c.Request.Context(), userID, req.Code, req.DeviceID); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"verified": true})
}

View File

@@ -0,0 +1,261 @@
package handler
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service"
)
// UserHandler handles user management requests
type UserHandler struct {
userService *service.UserService
}
// NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{userService: userService}
}
func (h *UserHandler) CreateUser(c *gin.Context) {
var req struct {
Username string `json:"username" binding:"required"`
Email string `json:"email"`
Password string `json:"password"`
Nickname string `json:"nickname"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user := &domain.User{
Username: req.Username,
Email: domain.StrPtr(req.Email),
Nickname: req.Nickname,
Status: domain.UserStatusActive,
}
if req.Password != "" {
hashed, err := auth.HashPassword(req.Password)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to hash password"})
return
}
user.Password = hashed
}
if err := h.userService.Create(c.Request.Context(), user); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusCreated, toUserResponse(user))
}
func (h *UserHandler) ListUsers(c *gin.Context) {
offset, _ := strconv.ParseInt(c.DefaultQuery("offset", "0"), 10, 64)
limit, _ := strconv.ParseInt(c.DefaultQuery("limit", "20"), 10, 64)
users, total, err := h.userService.List(c.Request.Context(), int(offset), int(limit))
if err != nil {
handleError(c, err)
return
}
userResponses := make([]*UserResponse, len(users))
for i, u := range users {
userResponses[i] = toUserResponse(u)
}
c.JSON(http.StatusOK, gin.H{
"users": userResponses,
"total": total,
"offset": offset,
"limit": limit,
})
}
func (h *UserHandler) GetUser(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
return
}
user, err := h.userService.GetByID(c.Request.Context(), id)
if err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, toUserResponse(user))
}
func (h *UserHandler) UpdateUser(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
return
}
var req struct {
Email *string `json:"email"`
Nickname *string `json:"nickname"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, err := h.userService.GetByID(c.Request.Context(), id)
if err != nil {
handleError(c, err)
return
}
if req.Email != nil {
user.Email = req.Email
}
if req.Nickname != nil {
user.Nickname = *req.Nickname
}
if err := h.userService.Update(c.Request.Context(), user); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, toUserResponse(user))
}
func (h *UserHandler) DeleteUser(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
return
}
if err := h.userService.Delete(c.Request.Context(), id); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "user deleted"})
}
func (h *UserHandler) UpdatePassword(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
return
}
var req struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.userService.ChangePassword(c.Request.Context(), id, req.OldPassword, req.NewPassword); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "密码修改成功"})
}
func (h *UserHandler) UpdateUserStatus(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
return
}
var req struct {
Status string `json:"status" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var status domain.UserStatus
switch req.Status {
case "active", "1":
status = domain.UserStatusActive
case "inactive", "0":
status = domain.UserStatusInactive
case "locked", "2":
status = domain.UserStatusLocked
case "disabled", "3":
status = domain.UserStatusDisabled
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid status"})
return
}
if err := h.userService.UpdateStatus(c.Request.Context(), id, status); err != nil {
handleError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "status updated"})
}
func (h *UserHandler) GetUserRoles(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"roles": []interface{}{}})
}
func (h *UserHandler) AssignRoles(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "role assignment not implemented"})
}
func (h *UserHandler) UploadAvatar(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "avatar upload not implemented"})
}
func (h *UserHandler) ListAdmins(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"admins": []interface{}{}})
}
func (h *UserHandler) CreateAdmin(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "admin creation not implemented"})
}
func (h *UserHandler) DeleteAdmin(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "admin deletion not implemented"})
}
type UserResponse struct {
ID int64 `json:"id"`
Username string `json:"username"`
Email string `json:"email,omitempty"`
Nickname string `json:"nickname,omitempty"`
Status string `json:"status"`
}
func toUserResponse(u *domain.User) *UserResponse {
email := ""
if u.Email != nil {
email = *u.Email
}
return &UserResponse{
ID: u.ID,
Username: u.Username,
Email: email,
Nickname: u.Nickname,
Status: strconv.FormatInt(int64(u.Status), 10),
}
}

View File

@@ -0,0 +1,39 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/service"
)
// WebhookHandler handles webhook requests
type WebhookHandler struct {
webhookService *service.WebhookService
}
// NewWebhookHandler creates a new WebhookHandler
func NewWebhookHandler(webhookService *service.WebhookService) *WebhookHandler {
return &WebhookHandler{webhookService: webhookService}
}
func (h *WebhookHandler) CreateWebhook(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "webhook creation not implemented"})
}
func (h *WebhookHandler) ListWebhooks(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"webhooks": []interface{}{}})
}
func (h *WebhookHandler) UpdateWebhook(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "webhook update not implemented"})
}
func (h *WebhookHandler) DeleteWebhook(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "webhook deletion not implemented"})
}
func (h *WebhookHandler) GetWebhookDeliveries(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"deliveries": []interface{}{}})
}

View File

@@ -0,0 +1,240 @@
package middleware
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
apierrors "github.com/user-management-system/internal/pkg/errors"
"github.com/user-management-system/internal/repository"
)
type AuthMiddleware struct {
jwt *auth.JWT
userRepo *repository.UserRepository
userRoleRepo *repository.UserRoleRepository
roleRepo *repository.RoleRepository
rolePermissionRepo *repository.RolePermissionRepository
permissionRepo *repository.PermissionRepository
l1Cache *cache.L1Cache
cacheManager *cache.CacheManager
}
func NewAuthMiddleware(
jwt *auth.JWT,
userRepo *repository.UserRepository,
userRoleRepo *repository.UserRoleRepository,
roleRepo *repository.RoleRepository,
rolePermissionRepo *repository.RolePermissionRepository,
permissionRepo *repository.PermissionRepository,
) *AuthMiddleware {
return &AuthMiddleware{
jwt: jwt,
userRepo: userRepo,
userRoleRepo: userRoleRepo,
roleRepo: roleRepo,
rolePermissionRepo: rolePermissionRepo,
permissionRepo: permissionRepo,
l1Cache: cache.NewL1Cache(),
}
}
func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) {
m.cacheManager = cm
}
func (m *AuthMiddleware) Required() gin.HandlerFunc {
return func(c *gin.Context) {
token := m.extractToken(c)
if token == "" {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "未提供认证令牌"))
c.Abort()
return
}
claims, err := m.jwt.ValidateAccessToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "无效的认证令牌"))
c.Abort()
return
}
if m.isJTIBlacklisted(claims.JTI) {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "令牌已失效,请重新登录"))
c.Abort()
return
}
if !m.isUserActive(c.Request.Context(), claims.UserID) {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "账号不可用,请重新登录"))
c.Abort()
return
}
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("token_jti", claims.JTI)
roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID)
c.Set("role_codes", roleCodes)
c.Set("permission_codes", permCodes)
c.Next()
}
}
func (m *AuthMiddleware) Optional() gin.HandlerFunc {
return func(c *gin.Context) {
token := m.extractToken(c)
if token != "" {
claims, err := m.jwt.ValidateAccessToken(token)
if err == nil && !m.isJTIBlacklisted(claims.JTI) && m.isUserActive(c.Request.Context(), claims.UserID) {
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("token_jti", claims.JTI)
roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID)
c.Set("role_codes", roleCodes)
c.Set("permission_codes", permCodes)
}
}
c.Next()
}
}
func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool {
if jti == "" {
return false
}
key := "jwt_blacklist:" + jti
if _, ok := m.l1Cache.Get(key); ok {
return true
}
if m.cacheManager != nil {
if _, ok := m.cacheManager.Get(context.Background(), key); ok {
return true
}
}
return false
}
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
if m.userRoleRepo == nil || m.roleRepo == nil || m.rolePermissionRepo == nil || m.permissionRepo == nil {
return nil, nil
}
cacheKey := fmt.Sprintf("user_perms:%d", userID)
if cached, ok := m.l1Cache.Get(cacheKey); ok {
if entry, ok := cached.(userPermEntry); ok {
return entry.roles, entry.perms
}
}
roleIDs, err := m.userRoleRepo.GetRoleIDsByUserID(ctx, userID)
if err != nil || len(roleIDs) == 0 {
return nil, nil
}
// 收集所有角色ID包括直接分配的角色和所有祖先角色
allRoleIDs := make([]int64, 0, len(roleIDs)*2)
allRoleIDs = append(allRoleIDs, roleIDs...)
for _, roleID := range roleIDs {
ancestorIDs, err := m.roleRepo.GetAncestorIDs(ctx, roleID)
if err == nil && len(ancestorIDs) > 0 {
allRoleIDs = append(allRoleIDs, ancestorIDs...)
}
}
// 去重
seen := make(map[int64]bool)
uniqueRoleIDs := make([]int64, 0, len(allRoleIDs))
for _, id := range allRoleIDs {
if !seen[id] {
seen[id] = true
uniqueRoleIDs = append(uniqueRoleIDs, id)
}
}
roles, err := m.roleRepo.GetByIDs(ctx, roleIDs)
if err != nil {
return nil, nil
}
roleCodes := make([]string, 0, len(roles))
for _, role := range roles {
roleCodes = append(roleCodes, role.Code)
}
permissionIDs, err := m.rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, uniqueRoleIDs)
if err != nil || len(permissionIDs) == 0 {
entry := userPermEntry{roles: roleCodes, perms: []string{}}
m.l1Cache.Set(cacheKey, entry, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询
return entry.roles, entry.perms
}
permissions, err := m.permissionRepo.GetByIDs(ctx, permissionIDs)
if err != nil {
return roleCodes, nil
}
permCodes := make([]string, 0, len(permissions))
for _, permission := range permissions {
permCodes = append(permCodes, permission.Code)
}
m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询
return roleCodes, permCodes
}
func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) {
m.l1Cache.Delete(fmt.Sprintf("user_perms:%d", userID))
}
func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) {
if jti != "" && ttl > 0 {
m.l1Cache.Set("jwt_blacklist:"+jti, true, ttl)
}
}
func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool {
if m.userRepo == nil {
return true
}
user, err := m.userRepo.GetByID(ctx, userID)
if err != nil {
return false
}
return user.Status == domain.UserStatusActive
}
func (m *AuthMiddleware) extractToken(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
return ""
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
return ""
}
return parts[1]
}
type userPermEntry struct {
roles []string
perms []string
}

View File

@@ -0,0 +1,32 @@
package middleware
import (
"strings"
"github.com/gin-gonic/gin"
)
const sensitiveNoStoreCacheControl = "no-store, no-cache, must-revalidate, max-age=0"
// NoStoreSensitiveResponses prevents browser or intermediary caching for auth routes.
func NoStoreSensitiveResponses() gin.HandlerFunc {
return func(c *gin.Context) {
if shouldDisableCaching(c.FullPath(), c.Request.URL.Path) {
headers := c.Writer.Header()
headers.Set("Cache-Control", sensitiveNoStoreCacheControl)
headers.Set("Pragma", "no-cache")
headers.Set("Expires", "0")
headers.Set("Surrogate-Control", "no-store")
}
c.Next()
}
}
func shouldDisableCaching(routePath, requestPath string) bool {
path := strings.TrimSpace(routePath)
if path == "" {
path = strings.TrimSpace(requestPath)
}
return strings.HasPrefix(path, "/api/v1/auth")
}

View File

@@ -0,0 +1,67 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
var corsConfig = config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
}
func SetCORSConfig(cfg config.CORSConfig) {
corsConfig = cfg
}
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
cfg := corsConfig
origin := c.GetHeader("Origin")
if origin != "" {
allowOrigin, allowed := resolveAllowedOrigin(origin, cfg.AllowedOrigins, cfg.AllowCredentials)
if !allowed {
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusForbidden)
return
}
c.AbortWithStatus(http.StatusForbidden)
return
}
c.Writer.Header().Set("Access-Control-Allow-Origin", allowOrigin)
if cfg.AllowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
if c.Request.Method == http.MethodOptions {
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With, X-CSRF-Token")
c.Writer.Header().Set("Access-Control-Max-Age", "3600")
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
func resolveAllowedOrigin(origin string, allowedOrigins []string, allowCredentials bool) (string, bool) {
for _, allowed := range allowedOrigins {
if allowed == "*" {
if allowCredentials {
return origin, true
}
return "*", true
}
if strings.EqualFold(origin, allowed) {
return origin, true
}
}
return "", false
}

View File

@@ -0,0 +1,43 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
apierrors "github.com/user-management-system/internal/pkg/errors"
)
// ErrorHandler 错误处理中间件
func ErrorHandler() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
// 检查是否有错误
if len(c.Errors) > 0 {
// 获取最后一个错误
err := c.Errors.Last()
// 判断错误类型
if appErr, ok := err.Err.(*apierrors.ApplicationError); ok {
c.JSON(int(appErr.Code), appErr)
} else {
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", err.Err.Error()))
}
return
}
}
}
// Recover 恢复中间件
func Recover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", "服务器内部错误"))
c.Abort()
}
}()
c.Next()
}
}

View File

@@ -0,0 +1,134 @@
package middleware
import (
"net"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/security"
)
// IPFilterConfig IP过滤中间件配置
type IPFilterConfig struct {
TrustProxy bool // 是否信任 X-Forwarded-For
TrustedProxies []string // 可信代理 IP 列表
}
// IPFilterMiddleware IP 黑白名单过滤中间件
type IPFilterMiddleware struct {
filter *security.IPFilter
config IPFilterConfig
}
// NewIPFilterMiddleware 创建 IP 过滤中间件
func NewIPFilterMiddleware(filter *security.IPFilter, config IPFilterConfig) *IPFilterMiddleware {
return &IPFilterMiddleware{filter: filter, config: config}
}
// Filter 返回 Gin 中间件 HandlerFunc
// 逻辑:先取客户端真实 IP → 检查黑名单 → 被封则返回 403 并终止
func (m *IPFilterMiddleware) Filter() gin.HandlerFunc {
return func(c *gin.Context) {
ip := m.realIP(c)
blocked, reason := m.filter.IsBlocked(ip)
if blocked {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "访问被拒绝:" + reason,
})
return
}
// 将真实 IP 写入 context供后续中间件和 handler 直接取用
c.Set("client_ip", ip)
c.Next()
}
}
// GetFilter 返回底层 IPFilter供 handler 层做黑白名单管理
func (m *IPFilterMiddleware) GetFilter() *security.IPFilter {
return m.filter
}
// realIP 从请求中提取真实客户端 IP
// 优先级X-Forwarded-For > X-Real-IP > RemoteAddr
// SEC-05 修复:如果启用 TrustProxy只接受来自可信代理的 X-Forwarded-For
func (m *IPFilterMiddleware) realIP(c *gin.Context) string {
// 如果不信任代理,直接使用 TCP 连接 IP
if !m.config.TrustProxy {
return c.ClientIP()
}
// X-Forwarded-For 可能包含代理链
xff := c.GetHeader("X-Forwarded-For")
if xff != "" {
// 从右到左遍历(最右边是最后一次代理添加的)
for _, part := range strings.Split(xff, ",") {
ip := strings.TrimSpace(part)
if ip == "" {
continue
}
// 检查是否是可信代理
if !m.isTrustedProxy(ip) {
continue // 不是可信代理,跳过
}
// 是可信代理,检查是否为公网 IP
if !isPrivateIP(ip) {
return ip
}
}
}
// X-Real-IPNginx 反代常用)
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// 直接 TCP 连接的 RemoteAddr去掉端口号
ip, _, err := net.SplitHostPort(c.Request.RemoteAddr)
if err != nil {
return c.Request.RemoteAddr
}
return ip
}
// isTrustedProxy 检查 IP 是否在可信代理列表中
func (m *IPFilterMiddleware) isTrustedProxy(ip string) bool {
if len(m.config.TrustedProxies) == 0 {
return true // 如果没有配置可信代理列表,默认信任所有(兼容旧行为)
}
for _, trusted := range m.config.TrustedProxies {
if ip == trusted {
return true
}
}
return false
}
// isPrivateIP 判断是否为内网 IP
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
privateRanges := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"::1/128",
"fc00::/7",
}
for _, cidr := range privateRanges {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
continue
}
if network.Contains(ip) {
return true
}
}
return false
}

View File

@@ -0,0 +1,258 @@
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/security"
)
func init() {
gin.SetMode(gin.TestMode)
}
// newTestEngine 用给定的 IPFilterMiddleware 构建一个最简 Gin 引擎,
// 注册一个 GET /ping 路由,返回 client_ip 值。
func newTestEngine(f *security.IPFilter) *gin.Engine {
engine := gin.New()
engine.Use(NewIPFilterMiddleware(f, IPFilterConfig{}).Filter())
engine.GET("/ping", func(c *gin.Context) {
ip, _ := c.Get("client_ip")
c.JSON(http.StatusOK, gin.H{"ip": ip})
})
return engine
}
// doRequest 发送 GET /ping返回响应码和响应 body map。
func doRequest(engine *gin.Engine, remoteAddr, xff, xri string) (int, map[string]interface{}) {
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
req.RemoteAddr = remoteAddr
if xff != "" {
req.Header.Set("X-Forwarded-For", xff)
}
if xri != "" {
req.Header.Set("X-Real-IP", xri)
}
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)
var body map[string]interface{}
_ = json.Unmarshal(w.Body.Bytes(), &body)
return w.Code, body
}
// ---------- 黑名单拦截 ----------
func TestIPFilter_BlockedIP_Returns403(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("1.2.3.4", "测试封禁", 0)
engine := newTestEngine(f)
code, body := doRequest(engine, "1.2.3.4:9999", "", "")
if code != http.StatusForbidden {
t.Fatalf("期望 403实际 %d", code)
}
msg, _ := body["message"].(string)
if msg == "" {
t.Error("期望 body 中包含 message 字段")
}
}
func TestIPFilter_NonBlockedIP_Returns200(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("9.9.9.9", "其他 IP", 0)
engine := newTestEngine(f)
code, _ := doRequest(engine, "1.2.3.4:9999", "", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
}
func TestIPFilter_EmptyBlacklist_AllPass(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
for _, ip := range []string{"1.1.1.1:80", "8.8.8.8:443", "203.0.113.5:1234"} {
code, _ := doRequest(engine, ip, "", "")
if code != http.StatusOK {
t.Errorf("IP %s 应通过,实际 %d", ip, code)
}
}
}
// ---------- 白名单豁免 ----------
func TestIPFilter_WhitelistOverridesBlacklist(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("5.5.5.5", "封禁测试", 0)
_ = f.AddToWhitelist("5.5.5.5", "白名单豁免")
engine := newTestEngine(f)
// 白名单优先,应通过
code, _ := doRequest(engine, "5.5.5.5:8080", "", "")
if code != http.StatusOK {
t.Fatalf("白名单 IP 应返回 200实际 %d", code)
}
}
// ---------- CIDR 黑名单 ----------
func TestIPFilter_CIDRBlacklist(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("10.10.10.0/24", "封禁整段 CIDR", 0)
engine := newTestEngine(f)
// 在 CIDR 范围内,应被封
code, _ := doRequest(engine, "10.10.10.55:1234", "", "")
if code != http.StatusForbidden {
t.Fatalf("CIDR 内 IP 应返回 403实际 %d", code)
}
// 不在 CIDR 范围内,应通过
code2, _ := doRequest(engine, "10.10.11.1:1234", "", "")
if code2 != http.StatusOK {
t.Fatalf("CIDR 外 IP 应返回 200实际 %d", code2)
}
}
// ---------- 过期规则 ----------
func TestIPFilter_ExpiredRule_Passes(t *testing.T) {
f := security.NewIPFilter()
// 封禁 1 纳秒,几乎立即过期
_ = f.AddToBlacklist("7.7.7.7", "即将过期", time.Nanosecond)
time.Sleep(2 * time.Millisecond)
engine := newTestEngine(f)
code, _ := doRequest(engine, "7.7.7.7:80", "", "")
if code != http.StatusOK {
t.Fatalf("过期规则不应拦截,期望 200实际 %d", code)
}
}
// ---------- client_ip 注入 ----------
func TestIPFilter_ClientIPSetInContext(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
code, body := doRequest(engine, "203.0.113.1:9000", "", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "203.0.113.1" {
t.Errorf("期望 client_ip=203.0.113.1,实际 %q", ip)
}
}
// ---------- realIP 提取逻辑 ----------
// TestRealIP_XForwardedFor_PublicIP 公网 X-Forwarded-For 取第一个非内网 IP
func TestRealIP_XForwardedFor_PublicIP(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
// X-Forwarded-For: 203.0.113.10, 192.168.1.1(代理内网)
code, body := doRequest(engine, "192.168.1.1:80", "203.0.113.10, 192.168.1.1", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "203.0.113.10" {
t.Errorf("期望从 X-Forwarded-For 取公网 IP实际 %q", ip)
}
}
// TestRealIP_XForwardedFor_AllPrivate 全内网则取第一个
func TestRealIP_XForwardedFor_AllPrivate(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
code, body := doRequest(engine, "10.0.0.2:80", "192.168.0.5, 10.0.0.1", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "192.168.0.5" {
t.Errorf("全内网时应取第一个,实际 %q", ip)
}
}
// TestRealIP_XRealIP_Fallback X-Forwarded-For 缺失时使用 X-Real-IP
func TestRealIP_XRealIP_Fallback(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
code, body := doRequest(engine, "192.168.1.1:80", "", "203.0.113.20")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "203.0.113.20" {
t.Errorf("期望 X-Real-IP 回退,实际 %q", ip)
}
}
// TestRealIP_RemoteAddr_Fallback 都无 header 时用 RemoteAddr
func TestRealIP_RemoteAddr_Fallback(t *testing.T) {
f := security.NewIPFilter()
engine := newTestEngine(f)
code, body := doRequest(engine, "203.0.113.99:12345", "", "")
if code != http.StatusOK {
t.Fatalf("期望 200实际 %d", code)
}
ip, _ := body["ip"].(string)
if ip != "203.0.113.99" {
t.Errorf("期望 RemoteAddr 回退,实际 %q", ip)
}
}
// ---------- GetFilter ----------
func TestIPFilterMiddleware_GetFilter(t *testing.T) {
f := security.NewIPFilter()
mw := NewIPFilterMiddleware(f, IPFilterConfig{})
if mw.GetFilter() != f {
t.Error("GetFilter 应返回同一个 IPFilter 实例")
}
}
// ---------- 并发安全 ----------
func TestIPFilter_ConcurrentRequests(t *testing.T) {
f := security.NewIPFilter()
_ = f.AddToBlacklist("66.66.66.66", "并发测试封禁", 0)
engine := newTestEngine(f)
done := make(chan struct{}, 20)
for i := 0; i < 20; i++ {
go func(i int) {
defer func() { done <- struct{}{} }()
var remoteAddr string
if i%2 == 0 {
remoteAddr = "66.66.66.66:9000"
} else {
remoteAddr = "1.2.3.4:9000"
}
code, _ := doRequest(engine, remoteAddr, "", "")
if i%2 == 0 && code != http.StatusForbidden {
t.Errorf("并发:封禁 IP 应返回 403实际 %d", code)
} else if i%2 != 0 && code != http.StatusOK {
t.Errorf("并发:正常 IP 应返回 200实际 %d", code)
}
}(i)
}
for i := 0; i < 20; i++ {
<-done
}
}

View File

@@ -0,0 +1,83 @@
package middleware
import (
"log"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
)
var sensitiveQueryKeys = map[string]struct{}{
"token": {},
"access_token": {},
"refresh_token": {},
"code": {},
"secret": {},
}
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
raw := sanitizeQuery(c.Request.URL.RawQuery)
c.Next()
latency := time.Since(start)
status := c.Writer.Status()
method := c.Request.Method
ip := c.ClientIP()
userAgent := c.Request.UserAgent()
userID, _ := c.Get("user_id")
log.Printf("[API] %s %s %s | status: %d | latency: %v | ip: %s | user_id: %v | ua: %s",
time.Now().Format("2006-01-02 15:04:05"),
method,
path,
status,
latency,
ip,
userID,
userAgent,
)
if len(c.Errors) > 0 {
for _, err := range c.Errors {
log.Printf("[Error] %v", err)
}
}
if raw != "" {
log.Printf("[Query] %s?%s", path, raw)
}
}
}
func sanitizeQuery(raw string) string {
if raw == "" {
return ""
}
values, err := url.ParseQuery(raw)
if err != nil {
return ""
}
for key := range values {
if isSensitiveQueryKey(key) {
values.Set(key, "***")
}
}
return values.Encode()
}
func isSensitiveQueryKey(key string) bool {
normalized := strings.ToLower(strings.TrimSpace(key))
if _, ok := sensitiveQueryKeys[normalized]; ok {
return true
}
return strings.Contains(normalized, "token") || strings.Contains(normalized, "secret")
}

View File

@@ -0,0 +1,125 @@
package middleware
import (
"bytes"
"context"
"encoding/json"
"io"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
type OperationLogMiddleware struct {
repo *repository.OperationLogRepository
}
func NewOperationLogMiddleware(repo *repository.OperationLogRepository) *OperationLogMiddleware {
return &OperationLogMiddleware{repo: repo}
}
type bodyWriter struct {
gin.ResponseWriter
statusCode int
}
func newBodyWriter(w gin.ResponseWriter) *bodyWriter {
return &bodyWriter{ResponseWriter: w, statusCode: 200}
}
func (bw *bodyWriter) WriteHeader(code int) {
bw.statusCode = code
bw.ResponseWriter.WriteHeader(code)
}
func (bw *bodyWriter) WriteHeaderNow() {
bw.ResponseWriter.WriteHeaderNow()
}
func (m *OperationLogMiddleware) Record() gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
c.Next()
return
}
var reqParams string
if c.Request.Body != nil {
bodyBytes, err := io.ReadAll(io.LimitReader(c.Request.Body, 4096))
if err == nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
reqParams = sanitizeParams(bodyBytes)
}
}
bw := newBodyWriter(c.Writer)
c.Writer = bw
c.Next()
var userIDPtr *int64
if uid, exists := c.Get("user_id"); exists {
if id, ok := uid.(int64); ok {
userID := id
userIDPtr = &userID
}
}
logEntry := &domain.OperationLog{
UserID: userIDPtr,
OperationType: methodToType(method),
OperationName: c.FullPath(),
RequestMethod: method,
RequestPath: c.Request.URL.Path,
RequestParams: reqParams,
ResponseStatus: bw.statusCode,
IP: c.ClientIP(),
UserAgent: c.Request.UserAgent(),
}
go func(entry *domain.OperationLog) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = m.repo.Create(ctx, entry)
}(logEntry)
}
}
func methodToType(method string) string {
switch method {
case "POST":
return "CREATE"
case "PUT", "PATCH":
return "UPDATE"
case "DELETE":
return "DELETE"
default:
return "OTHER"
}
}
func sanitizeParams(data []byte) string {
var payload map[string]interface{}
if err := json.Unmarshal(data, &payload); err != nil {
if len(data) > 500 {
return string(data[:500]) + "..."
}
return string(data)
}
for _, field := range []string{"password", "old_password", "new_password", "confirm_password", "secret", "token"} {
if _, ok := payload[field]; ok {
payload[field] = "***"
}
}
result, err := json.Marshal(payload)
if err != nil {
return ""
}
return string(result)
}

View File

@@ -0,0 +1,127 @@
package middleware
import (
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
// RateLimitMiddleware 限流中间件
type RateLimitMiddleware struct {
cfg config.RateLimitConfig
limiters map[string]*SlidingWindowLimiter
mu sync.RWMutex
cleanupInt time.Duration
}
// SlidingWindowLimiter 滑动窗口限流器
type SlidingWindowLimiter struct {
mu sync.Mutex
window time.Duration
capacity int64
requests []int64
}
// NewSlidingWindowLimiter 创建滑动窗口限流器
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
return &SlidingWindowLimiter{
window: window,
capacity: capacity,
requests: make([]int64, 0),
}
}
// Allow 检查是否允许请求
func (l *SlidingWindowLimiter) Allow() bool {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now().UnixMilli()
cutoff := now - l.window.Milliseconds()
// 清理过期请求
var validRequests []int64
for _, t := range l.requests {
if t > cutoff {
validRequests = append(validRequests, t)
}
}
l.requests = validRequests
// 检查容量
if int64(len(l.requests)) >= l.capacity {
return false
}
l.requests = append(l.requests, now)
return true
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
return &RateLimitMiddleware{
cfg: cfg,
limiters: make(map[string]*SlidingWindowLimiter),
cleanupInt: 5 * time.Minute,
}
}
// Register 返回注册接口的限流中间件
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
return m.limitForKey("register", 60, 10)
}
// Login 返回登录接口的限流中间件
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
return m.limitForKey("login", 60, 5)
}
// API 返回 API 接口的限流中间件
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
return m.limitForKey("api", 60, 100)
}
// Refresh 返回刷新令牌的限流中间件
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
return m.limitForKey("refresh", 60, 10)
}
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
return func(c *gin.Context) {
if !limiter.Allow() {
c.JSON(429, gin.H{
"code": 429,
"message": "请求过于频繁,请稍后再试",
})
c.Abort()
return
}
c.Next()
}
}
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
m.mu.RLock()
limiter, exists := m.limiters[key]
m.mu.RUnlock()
if exists {
return limiter
}
m.mu.Lock()
defer m.mu.Unlock()
// 双重检查
if limiter, exists = m.limiters[key]; exists {
return limiter
}
limiter = NewSlidingWindowLimiter(window, capacity)
m.limiters[key] = limiter
return limiter
}

View File

@@ -0,0 +1,156 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// contextKey 上下文键常量
const (
ContextKeyRoleCodes = "role_codes"
ContextKeyPermissionCodes = "permission_codes"
)
// RequirePermission 要求用户拥有指定权限之一OR 逻辑)
// 适用于需要单个或多选权限校验的路由
func RequirePermission(codes ...string) gin.HandlerFunc {
return func(c *gin.Context) {
if !hasAnyPermission(c, codes) {
c.JSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "权限不足",
})
c.Abort()
return
}
c.Next()
}
}
// RequireAllPermissions 要求用户拥有所有指定权限AND 逻辑)
func RequireAllPermissions(codes ...string) gin.HandlerFunc {
return func(c *gin.Context) {
if !hasAllPermissions(c, codes) {
c.JSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "权限不足,需要所有指定权限",
})
c.Abort()
return
}
c.Next()
}
}
// RequireRole 要求用户拥有指定角色之一OR 逻辑)
func RequireRole(codes ...string) gin.HandlerFunc {
return func(c *gin.Context) {
if !hasAnyRole(c, codes) {
c.JSON(http.StatusForbidden, gin.H{
"code": 403,
"message": "权限不足,角色受限",
})
c.Abort()
return
}
c.Next()
}
}
// RequireAnyPermission RequirePermission 的别名,语义更清晰
func RequireAnyPermission(codes ...string) gin.HandlerFunc {
return RequirePermission(codes...)
}
// AdminOnly 仅限 admin 角色
func AdminOnly() gin.HandlerFunc {
return RequireRole("admin")
}
// GetRoleCodes 从 Context 获取当前用户角色代码列表
func GetRoleCodes(c *gin.Context) []string {
val, exists := c.Get(ContextKeyRoleCodes)
if !exists {
return nil
}
if codes, ok := val.([]string); ok {
return codes
}
return nil
}
// GetPermissionCodes 从 Context 获取当前用户权限代码列表
func GetPermissionCodes(c *gin.Context) []string {
val, exists := c.Get(ContextKeyPermissionCodes)
if !exists {
return nil
}
if codes, ok := val.([]string); ok {
return codes
}
return nil
}
// IsAdmin 判断当前用户是否为 admin
func IsAdmin(c *gin.Context) bool {
return hasAnyRole(c, []string{"admin"})
}
// hasAnyPermission 判断用户是否拥有任意一个权限
func hasAnyPermission(c *gin.Context, codes []string) bool {
// admin 角色拥有所有权限
if IsAdmin(c) {
return true
}
permCodes := GetPermissionCodes(c)
if len(permCodes) == 0 {
return false
}
permSet := toSet(permCodes)
for _, code := range codes {
if _, ok := permSet[code]; ok {
return true
}
}
return false
}
// hasAllPermissions 判断用户是否拥有所有权限
func hasAllPermissions(c *gin.Context, codes []string) bool {
if IsAdmin(c) {
return true
}
permCodes := GetPermissionCodes(c)
permSet := toSet(permCodes)
for _, code := range codes {
if _, ok := permSet[code]; !ok {
return false
}
}
return true
}
// hasAnyRole 判断用户是否拥有任意一个角色
func hasAnyRole(c *gin.Context, codes []string) bool {
roleCodes := GetRoleCodes(c)
if len(roleCodes) == 0 {
return false
}
roleSet := toSet(roleCodes)
for _, code := range codes {
if _, ok := roleSet[code]; ok {
return true
}
}
return false
}
// toSet 将字符串切片转换为 map 集合
func toSet(items []string) map[string]struct{} {
s := make(map[string]struct{}, len(items))
for _, item := range items {
s[item] = struct{}{}
}
return s
}

View File

@@ -0,0 +1,139 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
func TestCORS_UsesConfiguredOrigins(t *testing.T) {
gin.SetMode(gin.TestMode)
SetCORSConfig(config.CORSConfig{
AllowedOrigins: []string{"https://app.example.com"},
AllowCredentials: true,
})
t.Cleanup(func() {
SetCORSConfig(config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
})
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil)
c.Request.Header.Set("Origin", "https://app.example.com")
c.Request.Header.Set("Access-Control-Request-Headers", "Authorization")
CORS()(c)
if recorder.Code != http.StatusNoContent {
t.Fatalf("expected 204, got %d", recorder.Code)
}
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != "https://app.example.com" {
t.Fatalf("unexpected allow origin: %s", got)
}
if got := recorder.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
t.Fatalf("expected credentials header to be 'true', got %q", got)
}
}
func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
sanitized := sanitizeQuery(raw)
if sanitized == "" {
t.Fatal("expected sanitized query")
}
if sanitized == raw {
t.Fatal("expected query to be sanitized")
}
for _, value := range []string{"abc123", "xyz", "s1"} {
if strings.Contains(sanitized, value) {
t.Fatalf("expected sensitive value %q to be masked in %q", value, sanitized)
}
}
if sanitizeQuery("") != "" {
t.Fatal("expected empty query to stay empty")
}
}
func TestSecurityHeaders_AttachesExpectedHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
SecurityHeaders()(c)
if got := recorder.Header().Get("X-Content-Type-Options"); got != "nosniff" {
t.Fatalf("unexpected nosniff header: %q", got)
}
if got := recorder.Header().Get("X-Frame-Options"); got != "DENY" {
t.Fatalf("unexpected frame options: %q", got)
}
if got := recorder.Header().Get("Content-Security-Policy"); got == "" {
t.Fatal("expected content security policy header")
}
if got := recorder.Header().Get("Strict-Transport-Security"); got != "" {
t.Fatalf("did not expect hsts header for http request, got %q", got)
}
}
func TestSecurityHeaders_AttachesHSTSForForwardedHTTPS(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
c.Request.Header.Set("X-Forwarded-Proto", "https")
SecurityHeaders()(c)
if got := recorder.Header().Get("Strict-Transport-Security"); !strings.Contains(got, "max-age=31536000") {
t.Fatalf("expected hsts header, got %q", got)
}
}
func TestNoStoreSensitiveResponses_AttachesExpectedHeadersToAuthRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/capabilities", nil)
NoStoreSensitiveResponses()(c)
if got := recorder.Header().Get("Cache-Control"); got != sensitiveNoStoreCacheControl {
t.Fatalf("unexpected cache-control header: %q", got)
}
if got := recorder.Header().Get("Pragma"); got != "no-cache" {
t.Fatalf("unexpected pragma header: %q", got)
}
if got := recorder.Header().Get("Expires"); got != "0" {
t.Fatalf("unexpected expires header: %q", got)
}
if got := recorder.Header().Get("Surrogate-Control"); got != "no-store" {
t.Fatalf("unexpected surrogate-control header: %q", got)
}
}
func TestNoStoreSensitiveResponses_DoesNotAttachHeadersToNonAuthRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
NoStoreSensitiveResponses()(c)
if got := recorder.Header().Get("Cache-Control"); got != "" {
t.Fatalf("did not expect cache-control header, got %q", got)
}
}

View File

@@ -0,0 +1,45 @@
package middleware
import (
"strings"
"github.com/gin-gonic/gin"
)
const contentSecurityPolicy = "default-src 'none'; frame-ancestors 'none'; base-uri 'none'; form-action 'self'"
func SecurityHeaders() gin.HandlerFunc {
return func(c *gin.Context) {
headers := c.Writer.Header()
headers.Set("X-Content-Type-Options", "nosniff")
headers.Set("X-Frame-Options", "DENY")
headers.Set("Referrer-Policy", "strict-origin-when-cross-origin")
headers.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
headers.Set("Cross-Origin-Opener-Policy", "same-origin")
headers.Set("X-Permitted-Cross-Domain-Policies", "none")
if shouldAttachCSP(c.FullPath(), c.Request.URL.Path) {
headers.Set("Content-Security-Policy", contentSecurityPolicy)
}
if isHTTPSRequest(c) {
headers.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
}
c.Next()
}
}
func shouldAttachCSP(routePath, requestPath string) bool {
path := strings.TrimSpace(routePath)
if path == "" {
path = strings.TrimSpace(requestPath)
}
return !strings.HasPrefix(path, "/swagger/")
}
func isHTTPSRequest(c *gin.Context) bool {
if c.Request.TLS != nil {
return true
}
return strings.EqualFold(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")), "https")
}

View File

@@ -0,0 +1,367 @@
package router
import (
"github.com/gin-gonic/gin"
swaggerFiles "github.com/swaggo/files"
"github.com/swaggo/gin-swagger"
"github.com/user-management-system/internal/api/handler"
"github.com/user-management-system/internal/api/middleware"
)
type Router struct {
engine *gin.Engine
authHandler *handler.AuthHandler
userHandler *handler.UserHandler
roleHandler *handler.RoleHandler
permissionHandler *handler.PermissionHandler
deviceHandler *handler.DeviceHandler
logHandler *handler.LogHandler
passwordResetHandler *handler.PasswordResetHandler
captchaHandler *handler.CaptchaHandler
totpHandler *handler.TOTPHandler
webhookHandler *handler.WebhookHandler
exportHandler *handler.ExportHandler
statsHandler *handler.StatsHandler
smsHandler *handler.SMSHandler
avatarHandler *handler.AvatarHandler
customFieldHandler *handler.CustomFieldHandler
themeHandler *handler.ThemeHandler
authMiddleware *middleware.AuthMiddleware
rateLimitMiddleware *middleware.RateLimitMiddleware
opLogMiddleware *middleware.OperationLogMiddleware
ipFilterMiddleware *middleware.IPFilterMiddleware
ssoHandler *handler.SSOHandler
}
func NewRouter(
authHandler *handler.AuthHandler,
userHandler *handler.UserHandler,
roleHandler *handler.RoleHandler,
permissionHandler *handler.PermissionHandler,
deviceHandler *handler.DeviceHandler,
logHandler *handler.LogHandler,
authMiddleware *middleware.AuthMiddleware,
rateLimitMiddleware *middleware.RateLimitMiddleware,
opLogMiddleware *middleware.OperationLogMiddleware,
passwordResetHandler *handler.PasswordResetHandler,
captchaHandler *handler.CaptchaHandler,
totpHandler *handler.TOTPHandler,
webhookHandler *handler.WebhookHandler,
ipFilterMiddleware *middleware.IPFilterMiddleware,
exportHandler *handler.ExportHandler,
statsHandler *handler.StatsHandler,
smsHandler *handler.SMSHandler,
customFieldHandler *handler.CustomFieldHandler,
themeHandler *handler.ThemeHandler,
ssoHandler *handler.SSOHandler,
avatarHandler ...*handler.AvatarHandler,
) *Router {
engine := gin.New()
var avatar *handler.AvatarHandler
if len(avatarHandler) > 0 {
avatar = avatarHandler[0]
}
return &Router{
engine: engine,
authHandler: authHandler,
userHandler: userHandler,
roleHandler: roleHandler,
permissionHandler: permissionHandler,
deviceHandler: deviceHandler,
logHandler: logHandler,
passwordResetHandler: passwordResetHandler,
captchaHandler: captchaHandler,
totpHandler: totpHandler,
webhookHandler: webhookHandler,
exportHandler: exportHandler,
statsHandler: statsHandler,
smsHandler: smsHandler,
customFieldHandler: customFieldHandler,
themeHandler: themeHandler,
ssoHandler: ssoHandler,
avatarHandler: avatar,
authMiddleware: authMiddleware,
rateLimitMiddleware: rateLimitMiddleware,
opLogMiddleware: opLogMiddleware,
ipFilterMiddleware: ipFilterMiddleware,
}
}
func (r *Router) Setup() *gin.Engine {
r.engine.Use(middleware.Recover())
r.engine.Use(middleware.ErrorHandler())
r.engine.Use(middleware.Logger())
r.engine.Use(middleware.SecurityHeaders())
r.engine.Use(middleware.NoStoreSensitiveResponses())
r.engine.Use(middleware.CORS())
r.engine.Static("/uploads", "./uploads")
r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
if r.ipFilterMiddleware != nil {
r.engine.Use(r.ipFilterMiddleware.Filter())
}
if r.opLogMiddleware != nil {
r.engine.Use(r.opLogMiddleware.Record())
}
v1 := r.engine.Group("/api/v1")
{
authGroup := v1.Group("/auth")
{
authGroup.POST("/register", r.rateLimitMiddleware.Register(), r.authHandler.Register)
authGroup.POST("/bootstrap-admin", r.rateLimitMiddleware.Register(), r.authHandler.BootstrapAdmin)
authGroup.POST("/login", r.rateLimitMiddleware.Login(), r.authHandler.Login)
authGroup.POST("/refresh", r.rateLimitMiddleware.Refresh(), r.authHandler.RefreshToken)
authGroup.GET("/capabilities", r.authHandler.GetAuthCapabilities)
authGroup.GET("/activate", r.authHandler.ActivateEmail)
authGroup.POST("/resend-activation", r.authHandler.ResendActivationEmail)
if r.authHandler.SupportsEmailCodeLogin() {
authGroup.POST("/send-email-code", r.rateLimitMiddleware.Register(), r.authHandler.SendEmailCode)
authGroup.POST("/login/email-code", r.rateLimitMiddleware.Login(), r.authHandler.LoginByEmailCode)
}
if r.smsHandler != nil {
authGroup.POST("/send-code", r.rateLimitMiddleware.Register(), r.smsHandler.SendCode)
authGroup.POST("/login/code", r.rateLimitMiddleware.Login(), r.smsHandler.LoginByCode)
}
if r.passwordResetHandler != nil {
authGroup.POST("/forgot-password", r.passwordResetHandler.ForgotPassword)
authGroup.GET("/reset-password", r.passwordResetHandler.ValidateResetToken)
authGroup.POST("/reset-password", r.passwordResetHandler.ResetPassword)
// 短信密码重置
authGroup.POST("/forgot-password/phone", r.passwordResetHandler.ForgotPasswordByPhone)
authGroup.POST("/reset-password/phone", r.passwordResetHandler.ResetPasswordByPhone)
}
if r.captchaHandler != nil {
authGroup.GET("/captcha", r.captchaHandler.GenerateCaptcha)
authGroup.GET("/captcha/image", r.captchaHandler.GetCaptchaImage)
authGroup.POST("/captcha/verify", r.captchaHandler.VerifyCaptcha)
}
authGroup.GET("/oauth/providers", r.authHandler.GetEnabledOAuthProviders)
authGroup.GET("/oauth/:provider", r.authHandler.OAuthLogin)
authGroup.GET("/oauth/:provider/callback", r.authHandler.OAuthCallback)
authGroup.POST("/oauth/exchange", r.authHandler.OAuthExchange)
}
// 公开主题接口(无需认证)
if r.themeHandler != nil {
themePublic := v1.Group("")
{
themePublic.GET("/theme/active", r.themeHandler.GetActiveTheme)
}
}
protected := v1.Group("")
protected.Use(r.authMiddleware.Required())
protected.Use(r.rateLimitMiddleware.API())
{
protected.GET("/auth/csrf-token", r.authHandler.GetCSRFToken)
protected.POST("/auth/logout", r.authHandler.Logout)
protected.GET("/auth/userinfo", r.authHandler.GetUserInfo)
protected.POST("/users/me/bind-email/code", r.authHandler.SendEmailBindCode)
protected.POST("/users/me/bind-email", r.authHandler.BindEmail)
protected.DELETE("/users/me/bind-email", r.authHandler.UnbindEmail)
protected.POST("/users/me/bind-phone/code", r.authHandler.SendPhoneBindCode)
protected.POST("/users/me/bind-phone", r.authHandler.BindPhone)
protected.DELETE("/users/me/bind-phone", r.authHandler.UnbindPhone)
protected.GET("/users/me/social-accounts", r.authHandler.GetSocialAccounts)
protected.POST("/users/me/bind-social", r.authHandler.BindSocialAccount)
protected.DELETE("/users/me/bind-social/:provider", r.authHandler.UnbindSocialAccount)
users := protected.Group("/users")
{
users.POST("", middleware.RequirePermission("user:manage"), r.userHandler.CreateUser)
users.GET("", r.userHandler.ListUsers)
users.GET("/:id", r.userHandler.GetUser)
users.PUT("/:id", r.userHandler.UpdateUser)
users.DELETE("/:id", middleware.RequirePermission("user:delete"), r.userHandler.DeleteUser)
users.PUT("/:id/password", r.userHandler.UpdatePassword)
users.PUT("/:id/status", middleware.RequirePermission("user:manage"), r.userHandler.UpdateUserStatus)
users.GET("/:id/roles", r.userHandler.GetUserRoles)
users.PUT("/:id/roles", middleware.RequirePermission("user:manage"), r.userHandler.AssignRoles)
if r.avatarHandler != nil {
users.POST("/:id/avatar", r.avatarHandler.UploadAvatar)
}
}
roles := protected.Group("/roles")
roles.Use(middleware.AdminOnly())
{
roles.POST("", r.roleHandler.CreateRole)
roles.GET("", r.roleHandler.ListRoles)
roles.GET("/:id", r.roleHandler.GetRole)
roles.PUT("/:id", r.roleHandler.UpdateRole)
roles.DELETE("/:id", r.roleHandler.DeleteRole)
roles.PUT("/:id/status", r.roleHandler.UpdateRoleStatus)
roles.GET("/:id/permissions", r.roleHandler.GetRolePermissions)
roles.PUT("/:id/permissions", r.roleHandler.AssignPermissions)
}
permissions := protected.Group("/permissions")
permissions.Use(middleware.AdminOnly())
{
permissions.POST("", r.permissionHandler.CreatePermission)
permissions.GET("", r.permissionHandler.ListPermissions)
permissions.GET("/tree", r.permissionHandler.GetPermissionTree)
permissions.GET("/:id", r.permissionHandler.GetPermission)
permissions.PUT("/:id", r.permissionHandler.UpdatePermission)
permissions.DELETE("/:id", r.permissionHandler.DeletePermission)
permissions.PUT("/:id/status", r.permissionHandler.UpdatePermissionStatus)
}
devices := protected.Group("/devices")
{
devices.GET("", r.deviceHandler.GetMyDevices)
devices.POST("", r.deviceHandler.CreateDevice)
devices.GET("/:id", r.deviceHandler.GetDevice)
devices.PUT("/:id", r.deviceHandler.UpdateDevice)
devices.DELETE("/:id", r.deviceHandler.DeleteDevice)
devices.PUT("/:id/status", r.deviceHandler.UpdateDeviceStatus)
devices.POST("/:id/trust", r.deviceHandler.TrustDevice)
devices.POST("/by-device-id/:deviceId/trust", r.deviceHandler.TrustDeviceByDeviceID)
devices.DELETE("/:id/trust", r.deviceHandler.UntrustDevice)
devices.GET("/me/trusted", r.deviceHandler.GetMyTrustedDevices)
devices.POST("/me/logout-others", r.deviceHandler.LogoutAllOtherDevices)
devices.GET("/users/:id", r.deviceHandler.GetUserDevices)
}
adminDevices := protected.Group("/admin/devices")
adminDevices.Use(middleware.AdminOnly())
{
adminDevices.GET("", r.deviceHandler.GetAllDevices)
adminDevices.DELETE("/:id", r.deviceHandler.DeleteDevice)
adminDevices.PUT("/:id/status", r.deviceHandler.UpdateDeviceStatus)
adminDevices.POST("/:id/trust", r.deviceHandler.TrustDevice)
adminDevices.DELETE("/:id/trust", r.deviceHandler.UntrustDevice)
}
if r.logHandler != nil {
logs := protected.Group("/logs")
{
logs.GET("/login/me", r.logHandler.GetMyLoginLogs)
logs.GET("/operation/me", r.logHandler.GetMyOperationLogs)
adminLogs := logs.Group("")
adminLogs.Use(middleware.AdminOnly())
{
adminLogs.GET("/login", r.logHandler.GetLoginLogs)
adminLogs.GET("/login/export", r.logHandler.ExportLoginLogs)
adminLogs.GET("/operation", r.logHandler.GetOperationLogs)
}
}
}
if r.totpHandler != nil {
twoFA := protected.Group("/auth/2fa")
{
twoFA.GET("/status", r.totpHandler.GetTOTPStatus)
twoFA.GET("/setup", r.totpHandler.SetupTOTP)
twoFA.POST("/enable", r.totpHandler.EnableTOTP)
twoFA.POST("/disable", r.totpHandler.DisableTOTP)
twoFA.POST("/verify", r.totpHandler.VerifyTOTP)
}
}
if r.webhookHandler != nil {
webhooks := protected.Group("/webhooks")
{
webhooks.POST("", r.webhookHandler.CreateWebhook)
webhooks.GET("", r.webhookHandler.ListWebhooks)
webhooks.PUT("/:id", r.webhookHandler.UpdateWebhook)
webhooks.DELETE("/:id", r.webhookHandler.DeleteWebhook)
webhooks.GET("/:id/deliveries", r.webhookHandler.GetWebhookDeliveries)
}
}
if r.exportHandler != nil {
adminUsers := protected.Group("/admin/users")
adminUsers.Use(middleware.AdminOnly())
{
adminUsers.GET("/export", r.exportHandler.ExportUsers)
adminUsers.POST("/import", r.exportHandler.ImportUsers)
adminUsers.GET("/import/template", r.exportHandler.GetImportTemplate)
}
}
adminMgmt := protected.Group("/admin/admins")
adminMgmt.Use(middleware.AdminOnly())
{
adminMgmt.GET("", r.userHandler.ListAdmins)
adminMgmt.POST("", r.userHandler.CreateAdmin)
adminMgmt.DELETE("/:id", r.userHandler.DeleteAdmin)
}
if r.statsHandler != nil {
adminStats := protected.Group("/admin/stats")
adminStats.Use(middleware.AdminOnly())
{
adminStats.GET("/dashboard", r.statsHandler.GetDashboard)
adminStats.GET("/users", r.statsHandler.GetUserStats)
}
}
if r.customFieldHandler != nil {
// 自定义字段管理(管理员)
customFields := protected.Group("/custom-fields")
customFields.Use(middleware.AdminOnly())
{
customFields.POST("", r.customFieldHandler.CreateField)
customFields.GET("", r.customFieldHandler.ListFields)
customFields.GET("/:id", r.customFieldHandler.GetField)
customFields.PUT("/:id", r.customFieldHandler.UpdateField)
customFields.DELETE("/:id", r.customFieldHandler.DeleteField)
}
// 用户自定义字段值(用户自己的)
userFields := protected.Group("/users/me/custom-fields")
{
userFields.GET("", r.customFieldHandler.GetUserFieldValues)
userFields.PUT("", r.customFieldHandler.SetUserFieldValues)
}
}
if r.themeHandler != nil {
// 主题管理(管理员)
themes := protected.Group("/themes")
themes.Use(middleware.AdminOnly())
{
themes.POST("", r.themeHandler.CreateTheme)
themes.GET("", r.themeHandler.ListAllThemes)
themes.GET("/default", r.themeHandler.GetDefaultTheme)
themes.PUT("/default/:id", r.themeHandler.SetDefaultTheme)
themes.GET("/:id", r.themeHandler.GetTheme)
themes.PUT("/:id", r.themeHandler.UpdateTheme)
themes.DELETE("/:id", r.themeHandler.DeleteTheme)
}
}
// SSO 单点登录接口(需要认证)
if r.ssoHandler != nil {
sso := protected.Group("/sso")
{
sso.GET("/authorize", r.ssoHandler.Authorize)
sso.POST("/token", r.ssoHandler.Token)
sso.POST("/introspect", r.ssoHandler.Introspect)
sso.POST("/revoke", r.ssoHandler.Revoke)
sso.GET("/userinfo", r.ssoHandler.UserInfo)
}
}
}
}
return r.engine
}
func (r *Router) GetEngine() *gin.Engine {
return r.engine
}