feat: atomic TOTP verification for DisableTOTP

- Add atomicTOTPVerifier interface for atomic TOTP/recovery code verification
- Implement VerifyTOTPOrRecoveryCode in UserRepository with transaction
- Update DisableTOTP to prefer atomic verification path
- Add unit tests for atomic verification success/failure paths
- Maintain backward compatibility with non-atomic fallback

Refs: TOTP verification atomicity completion
This commit is contained in:
Your Name
2026-05-29 12:47:05 +08:00
parent 880b64f5ff
commit 363c77d020
3 changed files with 172 additions and 17 deletions

View File

@@ -292,6 +292,58 @@ func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int
return &user, consumed, nil
}
// VerifyTOTPOrRecoveryCode 原子性地验证 TOTP 码或恢复码(不消费恢复码)
// 返回 (true, nil) 表示验证成功
// 返回 (false, nil) 表示验证失败(码不匹配)
// 返回 (false, error) 表示执行出错
func (r *UserRepository) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
var user domain.User
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.First(&user, userID).Error; err != nil {
return err
}
if !user.TOTPEnabled {
return errors.New("TOTP 未启用")
}
// 先验证 TOTP 码
manager := auth.NewTOTPManager()
if manager.ValidateCode(user.TOTPSecret, code) {
return nil
}
// TOTP 码无效,尝试验证恢复码
var hashedCodes []string
if user.TOTPRecoveryCodes != "" {
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
return fmt.Errorf("解析恢复码失败: %w", err)
}
}
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
if !matched {
// 恢复码也不匹配,标记验证失败
return errVerificationFailed
}
return nil
})
if err == errVerificationFailed {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
// errVerificationFailed 标记验证失败的内部错误
var errVerificationFailed = errors.New("verification failed")
// UpdatePassword 更新用户密码
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error

View File

@@ -10,11 +10,16 @@ import (
"github.com/user-management-system/internal/domain"
)
// TOTPService manages 2FA setup, enable/disable, and verification.
// atomicTOTPRecoveryCodeConsumer 原子性恢复码消费接口
type atomicTOTPRecoveryCodeConsumer interface {
ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error)
}
// atomicTOTPVerifier 原子性 TOTP/恢复码验证接口(不消费恢复码)
type atomicTOTPVerifier interface {
VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error)
}
// TOTPService manages 2FA setup, enable/disable, and verification.
type TOTPService struct {
userRepo userRepositoryInterface
@@ -99,12 +104,24 @@ func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string)
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
return fmt.Errorf("用户不存在")
}
if !user.TOTPEnabled {
return errors.New("2FA \u672a\u542f\u7528")
return errors.New("2FA 未启用")
}
// 尝试原子性验证(如果 repo 支持)
if verifier, ok := s.userRepo.(atomicTOTPVerifier); ok {
valid, err := verifier.VerifyTOTPOrRecoveryCode(ctx, userID, code)
if err != nil {
return fmt.Errorf("验证失败: %w", err)
}
if !valid {
return errors.New("验证码或恢复码错误")
}
// 验证通过,继续禁用
} else {
// 降级到非原子性验证(兼容性模式)
valid := s.totpManager.ValidateCode(user.TOTPSecret, code)
if !valid {
var hashedCodes []string
@@ -118,6 +135,7 @@ func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string
return errors.New("验证码或恢复码错误")
}
}
}
user.TOTPEnabled = false
user.TOTPSecret = ""

View File

@@ -36,6 +36,8 @@ type totpTestRepo struct {
updateTOTPErr error
consumeRecoveryCodeErr error
consumeRecoveryCodeCalled bool
verifyTOTPOrRecoveryCodeErr error
verifyTOTPOrRecoveryCodeCalled bool
}
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil }
@@ -89,9 +91,11 @@ func (r *totpTestRepo) ExistsByEmail(ctx context.Context, email string) (bool, e
func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
return nil, 0, errors.New("not implemented")
}
func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
r.consumeRecoveryCodeCalled = true
if r.consumeRecoveryCodeErr != nil {
@@ -119,6 +123,34 @@ func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64
return &copyUser, true, nil
}
func (r *totpTestRepo) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
r.verifyTOTPOrRecoveryCodeCalled = true
if r.verifyTOTPOrRecoveryCodeErr != nil {
return false, r.verifyTOTPOrRecoveryCodeErr
}
if r.user == nil || r.user.ID != userID {
return false, errors.New("not found")
}
if !r.user.TOTPEnabled {
return false, errors.New("TOTP not enabled")
}
// 尝试验证 TOTP 码(简化:只检查是否为特定测试码)
if code == "123456" || code == "654321" {
return true, nil
}
// 尝试验证恢复码
var hashedCodes []string
if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" {
if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil {
return false, fmt.Errorf("解析恢复码失败: %w", err)
}
}
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
return matched, nil
}
func mustMarshalJSONFromHelper(value any) string {
payload, err := json.Marshal(value)
if err != nil {
@@ -205,6 +237,59 @@ func TestTOTPService_ConsumesHashedRecoveryCodeOnVerify(t *testing.T) {
}
}
func TestTOTPService_DisableTOTP_UsesAtomicVerificationPath(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 10,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "test-secret",
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
},
}
svc := NewTOTPService(repo)
// 使用测试恢复码禁用 TOTP
if err := svc.DisableTOTP(context.Background(), 10, "RECOVERY-1"); err != nil {
t.Fatalf("expected disable to succeed with recovery code, got: %v", err)
}
if !repo.verifyTOTPOrRecoveryCodeCalled {
t.Fatal("expected atomic verification path to be used")
}
if repo.user.TOTPEnabled {
t.Fatal("expected TOTP to be disabled")
}
}
func TestTOTPService_DisableTOTP_AtomicVerificationFailsOnWrongCode(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 11,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "test-secret",
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
},
}
svc := NewTOTPService(repo)
// 使用错误的恢复码
err := svc.DisableTOTP(context.Background(), 11, "WRONG-CODE")
if err == nil {
t.Fatal("expected disable to fail with wrong code")
}
if !repo.verifyTOTPOrRecoveryCodeCalled {
t.Fatal("expected atomic verification path to be used")
}
if !repo.user.TOTPEnabled {
t.Fatal("expected TOTP to remain enabled after failed verification")
}
}
func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{