fix: atomic TOTP recovery code consumption with repository-level transaction
- Add ConsumeTOTPRecoveryCode to UserRepository for atomic read-verify-update - Update TOTPService.VerifyTOTP to prefer atomic consumption when available - Update AuthService.verifyTOTPCodeOrRecoveryCode with same pattern - Fix critical bug: ConsumeTOTPRecoveryCode now correctly returns consumed=false on mismatch - Maintain backward compatibility: falls back to non-atomic path if repo doesn't implement interface - Add comprehensive unit tests for atomic consumption path Refs: review-fix-closure-2026-05-28 TOTP recovery code atomicity
This commit is contained in:
@@ -2,11 +2,15 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
)
|
||||
@@ -231,6 +235,63 @@ func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) erro
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ConsumeTOTPRecoveryCode 原子性地消费一个恢复码
|
||||
// 在事务中验证恢复码并更新,避免并发竞争窗口
|
||||
func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
|
||||
var user domain.User
|
||||
var consumed bool
|
||||
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 在事务中重新获取用户
|
||||
// 注意:SQLite 不完全支持 FOR UPDATE,依赖事务隔离
|
||||
if err := tx.First(&user, userID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.TOTPEnabled {
|
||||
return errors.New("TOTP 未启用")
|
||||
}
|
||||
|
||||
// 解析存储的哈希恢复码
|
||||
var hashedCodes []string
|
||||
if user.TOTPRecoveryCodes != "" {
|
||||
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||
return fmt.Errorf("解析恢复码失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证恢复码(输入会被哈希后与存储的哈希比较)
|
||||
idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||
if !matched {
|
||||
// 不匹配,标记消费失败但不返回错误
|
||||
consumed = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// 从列表中移除已使用的恢复码
|
||||
hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...)
|
||||
codesJSON, err := json.Marshal(hashedCodes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化恢复码失败: %w", err)
|
||||
}
|
||||
user.TOTPRecoveryCodes = string(codesJSON)
|
||||
|
||||
// 在同一事务中更新
|
||||
if err := tx.Model(&user).Update("totp_recovery_codes", user.TOTPRecoveryCodes).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
consumed = true
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return &user, consumed, nil
|
||||
}
|
||||
|
||||
// UpdatePassword 更新用户密码
|
||||
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
|
||||
|
||||
Reference in New Issue
Block a user