fix: atomic TOTP recovery code consumption with repository-level transaction

- Add ConsumeTOTPRecoveryCode to UserRepository for atomic read-verify-update
- Update TOTPService.VerifyTOTP to prefer atomic consumption when available
- Update AuthService.verifyTOTPCodeOrRecoveryCode with same pattern
- Fix critical bug: ConsumeTOTPRecoveryCode now correctly returns consumed=false on mismatch
- Maintain backward compatibility: falls back to non-atomic path if repo doesn't implement interface
- Add comprehensive unit tests for atomic consumption path

Refs: review-fix-closure-2026-05-28 TOTP recovery code atomicity
This commit is contained in:
Your Name
2026-05-29 12:31:36 +08:00
parent 80c59e2c2c
commit 878ca731f4
4 changed files with 229 additions and 12 deletions

View File

@@ -2,11 +2,15 @@ package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
)
@@ -231,6 +235,63 @@ func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) erro
}).Error
}
// ConsumeTOTPRecoveryCode 原子性地消费一个恢复码
// 在事务中验证恢复码并更新,避免并发竞争窗口
func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
var user domain.User
var consumed bool
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 在事务中重新获取用户
// 注意SQLite 不完全支持 FOR UPDATE依赖事务隔离
if err := tx.First(&user, userID).Error; err != nil {
return err
}
if !user.TOTPEnabled {
return errors.New("TOTP 未启用")
}
// 解析存储的哈希恢复码
var hashedCodes []string
if user.TOTPRecoveryCodes != "" {
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
return fmt.Errorf("解析恢复码失败: %w", err)
}
}
// 验证恢复码(输入会被哈希后与存储的哈希比较)
idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
if !matched {
// 不匹配,标记消费失败但不返回错误
consumed = false
return nil
}
// 从列表中移除已使用的恢复码
hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...)
codesJSON, err := json.Marshal(hashedCodes)
if err != nil {
return fmt.Errorf("序列化恢复码失败: %w", err)
}
user.TOTPRecoveryCodes = string(codesJSON)
// 在同一事务中更新
if err := tx.Model(&user).Update("totp_recovery_codes", user.TOTPRecoveryCodes).Error; err != nil {
return err
}
consumed = true
return nil
})
if err != nil {
return nil, false, err
}
return &user, consumed, nil
}
// UpdatePassword 更新用户密码
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error