fix: P0/P1 security and quality fixes
P0-01: Add ESCAPE clause to LIKE queries in operation_log.go and device.go P0-02: Add atomic Increment to L1Cache and L2Cache interfaces P0-07: Add TOTP verification step after password login P1-01: Sanitize error messages in error.go middleware P1-03: Remove err.Error() from export error messages P1-04: Add error return to CountByResultSince in login_log.go P1-05: Add transactional DeleteCascade to RoleRepository P1-06: Add PasswordChangedAt tracking for JWT token invalidation P1-07: Wrap theme SetDefault in database transaction P1-08: Use config values for database pool parameters P1-09: Add rows.Err() checks in social_account_repo.go P1-10: Validate sortOrder with map in user.go ORDER BY P1-11: Add GORM tags to Announcement struct P1-15: Add pageSize upper limit (100) to device and log handlers
This commit is contained in:
@@ -79,6 +79,9 @@ func (h *DeviceHandler) GetMyDevices(c *gin.Context) {
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
@@ -293,6 +296,9 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
|
||||
@@ -63,7 +63,8 @@ func (h *ExportHandler) ExportUsers(c *gin.Context) {
|
||||
|
||||
data, filename, contentType, err := h.exportService.ExportUsers(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "导出失败: " + err.Error()})
|
||||
// 安全修复:不泄露内部错误详情
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "导出失败"})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,9 @@ func (h *LogHandler) GetMyLoginLogs(c *gin.Context) {
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
logs, total, err := h.loginLogService.GetMyLoginLogs(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
@@ -83,6 +86,9 @@ func (h *LogHandler) GetMyOperationLogs(c *gin.Context) {
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
logs, total, err := h.operationLogService.GetMyOperationLogs(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
|
||||
@@ -74,6 +74,12 @@ func (m *AuthMiddleware) Required() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if m.isPasswordChangedSinceTokenIssued(c.Request.Context(), claims.UserID, claims.PCE) {
|
||||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "密码已更新,请重新登录"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if !m.isUserActive(c.Request.Context(), claims.UserID) {
|
||||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "账号不可用,请重新登录"))
|
||||
c.Abort()
|
||||
@@ -97,7 +103,7 @@ func (m *AuthMiddleware) Optional() gin.HandlerFunc {
|
||||
token := m.extractToken(c)
|
||||
if token != "" {
|
||||
claims, err := m.jwt.ValidateAccessToken(token)
|
||||
if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && m.isUserActive(c.Request.Context(), claims.UserID) {
|
||||
if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && !m.isPasswordChangedSinceTokenIssued(c.Request.Context(), claims.UserID, claims.PCE) && m.isUserActive(c.Request.Context(), claims.UserID) {
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("token_jti", claims.JTI)
|
||||
@@ -140,6 +146,27 @@ func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool
|
||||
return false
|
||||
}
|
||||
|
||||
// isPasswordChangedSinceTokenIssued 检查用户密码是否在令牌发放后已更改
|
||||
// 如果 tokenPCE 为 0(旧令牌),则不检查(向后兼容)
|
||||
func (m *AuthMiddleware) isPasswordChangedSinceTokenIssued(ctx context.Context, userID int64, tokenPCE int64) bool {
|
||||
if tokenPCE == 0 {
|
||||
// 旧令牌没有密码变更时间戳,不拦截
|
||||
return false
|
||||
}
|
||||
|
||||
if m.userRepo == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
user, err := m.userRepo.GetByID(ctx, userID)
|
||||
if err != nil || user.PasswordChangedAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果令牌的 PCE < 用户密码变更时间,说明密码在令牌发放后已更改
|
||||
return tokenPCE < user.PasswordChangedAt.Unix()
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
|
||||
if m.userRoleRepo == nil {
|
||||
return nil, nil
|
||||
|
||||
@@ -22,7 +22,9 @@ func ErrorHandler() gin.HandlerFunc {
|
||||
if appErr, ok := err.Err.(*apierrors.ApplicationError); ok {
|
||||
c.JSON(int(appErr.Code), appErr)
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", err.Err.Error()))
|
||||
// 安全修复:未知错误不泄露内部详情,只返回通用消息
|
||||
// 详细错误记录到日志,供调试使用
|
||||
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", "服务器内部错误"))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ type Claims struct {
|
||||
Type string `json:"type"` // access, refresh
|
||||
Remember bool `json:"remember,omitempty"` // 记住登录标记
|
||||
JTI string `json:"jti"` // JWT ID,用于黑名单
|
||||
PCE int64 `json:"pce,omitempty"` // Password Changed Epoch,密码变更时间戳,用于 token 失效机制
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
@@ -318,8 +319,8 @@ func (j *JWT) GetAlgorithm() string {
|
||||
return j.algorithm
|
||||
}
|
||||
|
||||
// GenerateAccessToken 生成访问令牌(含JTI)
|
||||
func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error) {
|
||||
// GenerateAccessToken 生成访问令牌(含JTI和密码变更时间戳)
|
||||
func (j *JWT) GenerateAccessToken(userID int64, username string, pce int64) (string, error) {
|
||||
if err := j.ensureReady(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -334,6 +335,7 @@ func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error)
|
||||
Username: username,
|
||||
Type: "access",
|
||||
JTI: jti,
|
||||
PCE: pce,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(j.accessTokenExpire)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
@@ -345,8 +347,8 @@ func (j *JWT) GenerateAccessToken(userID int64, username string) (string, error)
|
||||
return token.SignedString(j.signingKey())
|
||||
}
|
||||
|
||||
// GenerateRefreshToken 生成刷新令牌(含JTI)
|
||||
func (j *JWT) GenerateRefreshToken(userID int64, username string) (string, error) {
|
||||
// GenerateRefreshToken 生成刷新令牌(含JTI和密码变更时间戳)
|
||||
func (j *JWT) GenerateRefreshToken(userID int64, username string, pce int64) (string, error) {
|
||||
if err := j.ensureReady(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -361,6 +363,7 @@ func (j *JWT) GenerateRefreshToken(userID int64, username string) (string, error
|
||||
Username: username,
|
||||
Type: "refresh",
|
||||
JTI: jti,
|
||||
PCE: pce,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(j.refreshTokenExpire)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
@@ -382,14 +385,14 @@ func (j *JWT) GetRefreshTokenExpire() time.Duration {
|
||||
return j.refreshTokenExpire
|
||||
}
|
||||
|
||||
// GenerateTokenPair 生成令牌对
|
||||
func (j *JWT) GenerateTokenPair(userID int64, username string) (accessToken, refreshToken string, err error) {
|
||||
accessToken, err = j.GenerateAccessToken(userID, username)
|
||||
// GenerateTokenPair 生成令牌对(含密码变更时间戳)
|
||||
func (j *JWT) GenerateTokenPair(userID int64, username string, pce int64) (accessToken, refreshToken string, err error) {
|
||||
accessToken, err = j.GenerateAccessToken(userID, username, pce)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
refreshToken, err = j.GenerateRefreshToken(userID, username)
|
||||
refreshToken, err = j.GenerateRefreshToken(userID, username, pce)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -397,17 +400,17 @@ func (j *JWT) GenerateTokenPair(userID int64, username string) (accessToken, ref
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
// GenerateTokenPairWithRemember 生成令牌对(支持记住登录)
|
||||
func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remember bool) (accessToken, refreshToken string, err error) {
|
||||
accessToken, err = j.GenerateAccessToken(userID, username)
|
||||
// GenerateTokenPairWithRemember 生成令牌对(支持记住登录,含密码变更时间戳)
|
||||
func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remember bool, pce int64) (accessToken, refreshToken string, err error) {
|
||||
accessToken, err = j.GenerateAccessToken(userID, username, pce)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if remember {
|
||||
refreshToken, err = j.GenerateLongLivedRefreshToken(userID, username)
|
||||
refreshToken, err = j.GenerateLongLivedRefreshToken(userID, username, pce)
|
||||
} else {
|
||||
refreshToken, err = j.GenerateRefreshToken(userID, username)
|
||||
refreshToken, err = j.GenerateRefreshToken(userID, username, pce)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
@@ -416,8 +419,8 @@ func (j *JWT) GenerateTokenPairWithRemember(userID int64, username string, remem
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
// GenerateLongLivedRefreshToken 生成长期刷新令牌(记住登录时使用)
|
||||
func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string) (string, error) {
|
||||
// GenerateLongLivedRefreshToken 生成长期刷新令牌(记住登录时使用,含密码变更时间戳)
|
||||
func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string, pce int64) (string, error) {
|
||||
if err := j.ensureReady(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -440,6 +443,7 @@ func (j *JWT) GenerateLongLivedRefreshToken(userID int64, username string) (stri
|
||||
Type: "refresh",
|
||||
Remember: true, // 长期会话标记
|
||||
JTI: jti,
|
||||
PCE: pce,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(expireDuration)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
@@ -506,5 +510,5 @@ func (j *JWT) RefreshAccessToken(refreshTokenString string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return j.GenerateAccessToken(claims.UserID, claims.Username)
|
||||
return j.GenerateAccessToken(claims.UserID, claims.Username, claims.PCE)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ func TestNewJWT_DoesNotPanicOnInvalidLegacyConfig(t *testing.T) {
|
||||
t.Fatal("expected manager instance")
|
||||
}
|
||||
|
||||
if _, err := manager.GenerateAccessToken(1, "tester"); err == nil {
|
||||
if _, err := manager.GenerateAccessToken(1, "tester", 0); err == nil {
|
||||
t.Fatal("expected invalid legacy manager to return error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestNewJWTWithOptions_RS256(t *testing.T) {
|
||||
t.Fatalf("create rs256 jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPair(42, "rs256-user")
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPair(42, "rs256-user", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate token pair failed: %v", err)
|
||||
}
|
||||
@@ -136,7 +136,7 @@ func TestGenerateAccessToken_Success(t *testing.T) {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
token, err := jwtManager.GenerateAccessToken(123, "testuser")
|
||||
token, err := jwtManager.GenerateAccessToken(123, "testuser", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate access token failed: %v", err)
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func TestGenerateRefreshToken_Success(t *testing.T) {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
token, err := jwtManager.GenerateRefreshToken(456, "refreshuser")
|
||||
token, err := jwtManager.GenerateRefreshToken(456, "refreshuser", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate refresh token failed: %v", err)
|
||||
}
|
||||
@@ -201,7 +201,7 @@ func TestGenerateTokenPair_Success(t *testing.T) {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPair(789, "pairuser")
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPair(789, "pairuser", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate token pair failed: %v", err)
|
||||
}
|
||||
@@ -238,7 +238,7 @@ func TestGenerateTokenPairWithRemember_Success(t *testing.T) {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(999, "rememberuser", true)
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(999, "rememberuser", true, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate token pair with remember failed: %v", err)
|
||||
}
|
||||
@@ -275,7 +275,7 @@ func TestValidateAccessToken_WrongType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Use a refresh token as if it were an access token
|
||||
refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser")
|
||||
refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate refresh token failed: %v", err)
|
||||
}
|
||||
@@ -298,7 +298,7 @@ func TestValidateRefreshToken_WrongType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Use an access token as if it were a refresh token
|
||||
accessToken, err := jwtManager.GenerateAccessToken(123, "testuser")
|
||||
accessToken, err := jwtManager.GenerateAccessToken(123, "testuser", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate access token failed: %v", err)
|
||||
}
|
||||
@@ -389,7 +389,7 @@ func TestGenerateLongLivedRefreshToken_Success(t *testing.T) {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
token, err := jwtManager.GenerateLongLivedRefreshToken(123, "longliveuser")
|
||||
token, err := jwtManager.GenerateLongLivedRefreshToken(123, "longliveuser", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate long lived refresh token failed: %v", err)
|
||||
}
|
||||
@@ -446,7 +446,7 @@ func TestRefreshAccessToken_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
// Generate a valid refresh token first
|
||||
refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser")
|
||||
refreshToken, err := jwtManager.GenerateRefreshToken(123, "testuser", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate refresh token failed: %v", err)
|
||||
}
|
||||
@@ -498,7 +498,7 @@ func TestRefreshAccessToken_AccessTokenProvided(t *testing.T) {
|
||||
}
|
||||
|
||||
// Generate an access token and try to use it as refresh
|
||||
accessToken, err := jwtManager.GenerateAccessToken(123, "testuser")
|
||||
accessToken, err := jwtManager.GenerateAccessToken(123, "testuser", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("generate access token failed: %v", err)
|
||||
}
|
||||
@@ -521,7 +521,7 @@ func TestGenerateTokenPairWithRemember_RememberFalse(t *testing.T) {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", false)
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", false, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateTokenPairWithRemember failed: %v", err)
|
||||
}
|
||||
@@ -553,7 +553,7 @@ func TestGenerateTokenPairWithRemember_NoRememberExpireConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should use RefreshTokenExpire when RememberLoginExpire is not set
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", true)
|
||||
accessToken, refreshToken, err := jwtManager.GenerateTokenPairWithRemember(123, "testuser", true, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateTokenPairWithRemember failed: %v", err)
|
||||
}
|
||||
@@ -583,7 +583,7 @@ func TestGenerateLongLivedRefreshToken_NoRememberExpire(t *testing.T) {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
token, err := jwtManager.GenerateLongLivedRefreshToken(123, "testuser")
|
||||
token, err := jwtManager.GenerateLongLivedRefreshToken(123, "testuser", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateLongLivedRefreshToken failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -59,11 +59,29 @@ func NewDB(cfg *config.Config) (*DB, error) {
|
||||
log.Printf("warn: set busy_timeout failed: %v", err)
|
||||
}
|
||||
|
||||
// 连接池配置:SQLite 本身不支持真正的并发写,但需要控制连接数量
|
||||
sqlDB.SetMaxOpenConns(10)
|
||||
sqlDB.SetMaxIdleConns(5)
|
||||
sqlDB.SetConnMaxLifetime(30 * time.Minute)
|
||||
sqlDB.SetConnMaxIdleTime(10 * time.Minute)
|
||||
// 连接池配置:使用配置文件中的参数
|
||||
maxOpenConns := 10
|
||||
maxIdleConns := 5
|
||||
connMaxLifetime := 30 * time.Minute
|
||||
connMaxIdleTime := 10 * time.Minute
|
||||
if cfg != nil {
|
||||
if cfg.Database.MaxOpenConns > 0 {
|
||||
maxOpenConns = cfg.Database.MaxOpenConns
|
||||
}
|
||||
if cfg.Database.MaxIdleConns > 0 {
|
||||
maxIdleConns = cfg.Database.MaxIdleConns
|
||||
}
|
||||
if cfg.Database.ConnMaxLifetimeMinutes > 0 {
|
||||
connMaxLifetime = time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute
|
||||
}
|
||||
if cfg.Database.ConnMaxIdleTimeMinutes > 0 {
|
||||
connMaxIdleTime = time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute
|
||||
}
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(maxOpenConns)
|
||||
sqlDB.SetMaxIdleConns(maxIdleConns)
|
||||
sqlDB.SetConnMaxLifetime(connMaxLifetime)
|
||||
sqlDB.SetConnMaxIdleTime(connMaxIdleTime)
|
||||
|
||||
log.Println("database: SQLite WAL mode enabled, connection pool configured")
|
||||
|
||||
|
||||
@@ -200,18 +200,18 @@ func (c AnnouncementCondition) validate() error {
|
||||
}
|
||||
|
||||
type Announcement struct {
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
NotifyMode string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Title string `gorm:"type:varchar(255);not null" json:"title"`
|
||||
Content string `gorm:"type:text;not null" json:"content"`
|
||||
Status string `gorm:"type:varchar(20);default:draft;index" json:"status"`
|
||||
NotifyMode string `gorm:"type:varchar(20);default:silent" json:"notify_mode"`
|
||||
Targeting AnnouncementTargeting `gorm:"type:text" json:"targeting"`
|
||||
StartsAt *time.Time `gorm:"type:datetime" json:"starts_at,omitempty"`
|
||||
EndsAt *time.Time `gorm:"type:datetime" json:"ends_at,omitempty"`
|
||||
CreatedBy *int64 `json:"created_by,omitempty"`
|
||||
UpdatedBy *int64 `json:"updated_by,omitempty"`
|
||||
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (a *Announcement) IsActiveAt(now time.Time) bool {
|
||||
|
||||
@@ -62,6 +62,9 @@ type User struct {
|
||||
TOTPEnabled bool `gorm:"default:false" json:"totp_enabled"`
|
||||
TOTPSecret string `gorm:"type:varchar(64)" json:"-"` // Base32 密钥,不返回给前端
|
||||
TOTPRecoveryCodes string `gorm:"type:text" json:"-"` // JSON 编码的恢复码列表
|
||||
|
||||
// PasswordChangedAt 密码更新时间,用于 token 失效机制
|
||||
PasswordChangedAt time.Time `gorm:"type:timestamp;index" json:"password_changed_at,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
|
||||
@@ -104,16 +104,19 @@ func (r *LoginLogRepository) DeleteOlderThan(ctx context.Context, days int) erro
|
||||
|
||||
// CountByResultSince 统计指定时间之后特定结果的登录次数
|
||||
// success=true 统计成功次数,false 统计失败次数
|
||||
func (r *LoginLogRepository) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 {
|
||||
func (r *LoginLogRepository) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) {
|
||||
status := 0 // 失败
|
||||
if success {
|
||||
status = 1 // 成功
|
||||
}
|
||||
var count int64
|
||||
r.db.WithContext(ctx).Model(&domain.LoginLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&domain.LoginLog{}).
|
||||
Where("status = ? AND created_at >= ?", status, since).
|
||||
Count(&count)
|
||||
return count
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ListAllForExport 获取所有登录日志(用于导出,无分页)
|
||||
|
||||
@@ -263,10 +263,14 @@ func TestLoginLogRepositoryQueriesAndRetention(t *testing.T) {
|
||||
t.Fatalf("expected 2 recent logs, got total=%d len=%d", total, len(recentLogs))
|
||||
}
|
||||
|
||||
if count := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); count != 1 {
|
||||
if count, err := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); err != nil {
|
||||
t.Fatalf("CountByResultSince failed: %v", err)
|
||||
} else if count != 1 {
|
||||
t.Fatalf("expected 1 recent success login, got %d", count)
|
||||
}
|
||||
if count := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); count != 1 {
|
||||
if count, err := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); err != nil {
|
||||
t.Fatalf("CountByResultSince failed: %v", err)
|
||||
} else if count != 1 {
|
||||
t.Fatalf("expected 1 recent failed login, got %d", count)
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,18 @@ func (r *RoleRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.Role{}, id).Error
|
||||
}
|
||||
|
||||
// DeleteCascade 级联删除角色(同时删除角色权限关联)
|
||||
func (r *RoleRepository) DeleteCascade(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 先删除角色权限关联
|
||||
if err := tx.Where("role_id = ?", id).Delete(&domain.RolePermission{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 再删除角色
|
||||
return tx.Delete(&domain.Role{}, id).Error
|
||||
})
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取角色
|
||||
func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) {
|
||||
var role domain.Role
|
||||
|
||||
@@ -204,6 +204,9 @@ func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID in
|
||||
}
|
||||
accounts = append(accounts, &account)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
}
|
||||
@@ -290,6 +293,9 @@ func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit in
|
||||
}
|
||||
accounts = append(accounts, &account)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return accounts, total, nil
|
||||
}
|
||||
|
||||
@@ -89,11 +89,13 @@ func (r *ThemeConfigRepository) ListAll(ctx context.Context) ([]*domain.ThemeCon
|
||||
|
||||
// SetDefault 设置默认主题
|
||||
func (r *ThemeConfigRepository) SetDefault(ctx context.Context, id int64) error {
|
||||
// 先清除所有默认标记
|
||||
if err := r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("is_default = ?", true).Update("is_default", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 设置新的默认主题
|
||||
return r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("id = ?", id).Update("is_default", true).Error
|
||||
// 使用事务确保原子性:先清除所有默认标记,再设置新默认
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 先清除所有默认标记
|
||||
if err := tx.Model(&domain.ThemeConfig{}).Where("is_default = ?", true).Update("is_default", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置新的默认主题
|
||||
return tx.Model(&domain.ThemeConfig{}).Where("id = ?", id).Update("is_default", true).Error
|
||||
})
|
||||
}
|
||||
|
||||
@@ -326,8 +326,9 @@ func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFil
|
||||
sortBy = filter.SortBy
|
||||
}
|
||||
}
|
||||
if filter.SortOrder == "asc" {
|
||||
sortOrder = "ASC"
|
||||
allowedSortOrders := map[string]bool{"asc": true, "desc": true}
|
||||
if allowedSortOrders[strings.ToLower(filter.SortOrder)] {
|
||||
sortOrder = strings.ToUpper(filter.SortOrder)
|
||||
}
|
||||
query = query.Order(sortBy + " " + sortOrder)
|
||||
|
||||
@@ -404,8 +405,9 @@ func (r *UserRepository) ListCursor(ctx context.Context, filter *AdvancedFilter,
|
||||
}
|
||||
|
||||
sortOrder := "DESC"
|
||||
if filter.SortOrder == "asc" {
|
||||
sortOrder = "ASC"
|
||||
allowedSortOrders := map[string]bool{"asc": true, "desc": true}
|
||||
if allowedSortOrders[strings.ToLower(filter.SortOrder)] {
|
||||
sortOrder = strings.ToUpper(filter.SortOrder)
|
||||
}
|
||||
|
||||
orderClause := sortBy + " " + sortOrder + ", id " + sortOrder
|
||||
|
||||
@@ -1369,10 +1369,12 @@ func (s *AuthService) generateLoginResponse(ctx context.Context, user *domain.Us
|
||||
var accessToken, refreshToken string
|
||||
var err error
|
||||
|
||||
pce := user.PasswordChangedAt.Unix()
|
||||
|
||||
if remember {
|
||||
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPairWithRemember(user.ID, user.Username, remember)
|
||||
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPairWithRemember(user.ID, user.Username, remember, pce)
|
||||
} else {
|
||||
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPair(user.ID, user.Username)
|
||||
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPair(user.ID, user.Username, pce)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -181,13 +181,8 @@ func (s *RoleService) DeleteRole(ctx context.Context, roleID int64) error {
|
||||
return errors.New("存在子角色,无法删除")
|
||||
}
|
||||
|
||||
// 删除角色权限关联
|
||||
if err := s.rolePermissionRepo.DeleteByRoleID(ctx, roleID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除角色
|
||||
return s.roleRepo.Delete(ctx, roleID)
|
||||
// 级联删除角色及其权限关联(在事务中执行)
|
||||
return s.roleRepo.DeleteCascade(ctx, roleID)
|
||||
}
|
||||
|
||||
// GetRole 获取角色信息
|
||||
|
||||
@@ -15,7 +15,7 @@ type statsUserRepository interface {
|
||||
}
|
||||
|
||||
type statsLoginLogRepository interface {
|
||||
CountByResultSince(ctx context.Context, success bool, since time.Time) int64
|
||||
CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error)
|
||||
}
|
||||
|
||||
// StatsService 统计服务
|
||||
@@ -115,9 +115,15 @@ func (s *StatsService) GetDashboardStats(ctx context.Context) (*DashboardStats,
|
||||
// 今日登录成功/失败
|
||||
today := daysAgo(0)
|
||||
if s.loginLogRepo != nil {
|
||||
loginStats.LoginsTodaySuccess = s.loginLogRepo.CountByResultSince(ctx, true, today)
|
||||
loginStats.LoginsTodayFailed = s.loginLogRepo.CountByResultSince(ctx, false, today)
|
||||
loginStats.LoginsWeek = s.loginLogRepo.CountByResultSince(ctx, true, daysAgo(7))
|
||||
if successCount, err := s.loginLogRepo.CountByResultSince(ctx, true, today); err == nil {
|
||||
loginStats.LoginsTodaySuccess = successCount
|
||||
}
|
||||
if failedCount, err := s.loginLogRepo.CountByResultSince(ctx, false, today); err == nil {
|
||||
loginStats.LoginsTodayFailed = failedCount
|
||||
}
|
||||
if weekCount, err := s.loginLogRepo.CountByResultSince(ctx, true, daysAgo(7)); err == nil {
|
||||
loginStats.LoginsWeek = weekCount
|
||||
}
|
||||
}
|
||||
|
||||
return &DashboardStats{
|
||||
|
||||
@@ -51,11 +51,11 @@ type mockStatsLoginLogRepoInternal struct {
|
||||
weekCount int64
|
||||
}
|
||||
|
||||
func (m *mockStatsLoginLogRepoInternal) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 {
|
||||
func (m *mockStatsLoginLogRepoInternal) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) {
|
||||
if success {
|
||||
return m.successCount
|
||||
return m.successCount, nil
|
||||
}
|
||||
return m.failedCount
|
||||
return m.failedCount, nil
|
||||
}
|
||||
|
||||
func TestStatsService_GetDashboardStats_Internal(t *testing.T) {
|
||||
|
||||
@@ -52,11 +52,11 @@ type mockStatsLoginLogRepo struct {
|
||||
weekCount int64
|
||||
}
|
||||
|
||||
func (m *mockStatsLoginLogRepo) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 {
|
||||
func (m *mockStatsLoginLogRepo) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) {
|
||||
if success {
|
||||
return m.successCount
|
||||
return m.successCount, nil
|
||||
}
|
||||
return m.failedCount
|
||||
return m.failedCount, nil
|
||||
}
|
||||
|
||||
func TestStatsService_GetUserStats(t *testing.T) {
|
||||
|
||||
@@ -141,6 +141,7 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
|
||||
|
||||
// 更新密码(使用同一哈希值)
|
||||
user.Password = newHashedPassword
|
||||
user.PasswordChangedAt = time.Now()
|
||||
return s.userRepo.Update(ctx, user)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user