feat: atomic TOTP verification for DisableTOTP

- Add atomicTOTPVerifier interface for atomic TOTP/recovery code verification
- Implement VerifyTOTPOrRecoveryCode in UserRepository with transaction
- Update DisableTOTP to prefer atomic verification path
- Add unit tests for atomic verification success/failure paths
- Maintain backward compatibility with non-atomic fallback

Refs: TOTP verification atomicity completion
This commit is contained in:
Your Name
2026-05-29 12:47:05 +08:00
parent 880b64f5ff
commit 363c77d020
3 changed files with 172 additions and 17 deletions

View File

@@ -292,6 +292,58 @@ func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int
return &user, consumed, nil
}
// VerifyTOTPOrRecoveryCode 原子性地验证 TOTP 码或恢复码(不消费恢复码)
// 返回 (true, nil) 表示验证成功
// 返回 (false, nil) 表示验证失败(码不匹配)
// 返回 (false, error) 表示执行出错
func (r *UserRepository) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
var user domain.User
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.First(&user, userID).Error; err != nil {
return err
}
if !user.TOTPEnabled {
return errors.New("TOTP 未启用")
}
// 先验证 TOTP 码
manager := auth.NewTOTPManager()
if manager.ValidateCode(user.TOTPSecret, code) {
return nil
}
// TOTP 码无效,尝试验证恢复码
var hashedCodes []string
if user.TOTPRecoveryCodes != "" {
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
return fmt.Errorf("解析恢复码失败: %w", err)
}
}
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
if !matched {
// 恢复码也不匹配,标记验证失败
return errVerificationFailed
}
return nil
})
if err == errVerificationFailed {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
// errVerificationFailed 标记验证失败的内部错误
var errVerificationFailed = errors.New("verification failed")
// UpdatePassword 更新用户密码
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error