feat: atomic TOTP verification for DisableTOTP
- Add atomicTOTPVerifier interface for atomic TOTP/recovery code verification - Implement VerifyTOTPOrRecoveryCode in UserRepository with transaction - Update DisableTOTP to prefer atomic verification path - Add unit tests for atomic verification success/failure paths - Maintain backward compatibility with non-atomic fallback Refs: TOTP verification atomicity completion
This commit is contained in:
@@ -292,6 +292,58 @@ func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int
|
||||
return &user, consumed, nil
|
||||
}
|
||||
|
||||
// VerifyTOTPOrRecoveryCode 原子性地验证 TOTP 码或恢复码(不消费恢复码)
|
||||
// 返回 (true, nil) 表示验证成功
|
||||
// 返回 (false, nil) 表示验证失败(码不匹配)
|
||||
// 返回 (false, error) 表示执行出错
|
||||
func (r *UserRepository) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
|
||||
var user domain.User
|
||||
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.First(&user, userID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.TOTPEnabled {
|
||||
return errors.New("TOTP 未启用")
|
||||
}
|
||||
|
||||
// 先验证 TOTP 码
|
||||
manager := auth.NewTOTPManager()
|
||||
if manager.ValidateCode(user.TOTPSecret, code) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TOTP 码无效,尝试验证恢复码
|
||||
var hashedCodes []string
|
||||
if user.TOTPRecoveryCodes != "" {
|
||||
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||
return fmt.Errorf("解析恢复码失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||
if !matched {
|
||||
// 恢复码也不匹配,标记验证失败
|
||||
return errVerificationFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == errVerificationFailed {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// errVerificationFailed 标记验证失败的内部错误
|
||||
var errVerificationFailed = errors.New("verification failed")
|
||||
|
||||
// UpdatePassword 更新用户密码
|
||||
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
|
||||
|
||||
@@ -10,11 +10,16 @@ import (
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// TOTPService manages 2FA setup, enable/disable, and verification.
|
||||
// atomicTOTPRecoveryCodeConsumer 原子性恢复码消费接口
|
||||
type atomicTOTPRecoveryCodeConsumer interface {
|
||||
ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error)
|
||||
}
|
||||
|
||||
// atomicTOTPVerifier 原子性 TOTP/恢复码验证接口(不消费恢复码)
|
||||
type atomicTOTPVerifier interface {
|
||||
VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error)
|
||||
}
|
||||
|
||||
// TOTPService manages 2FA setup, enable/disable, and verification.
|
||||
type TOTPService struct {
|
||||
userRepo userRepositoryInterface
|
||||
@@ -99,12 +104,24 @@ func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string)
|
||||
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
||||
return fmt.Errorf("用户不存在")
|
||||
}
|
||||
if !user.TOTPEnabled {
|
||||
return errors.New("2FA \u672a\u542f\u7528")
|
||||
return errors.New("2FA 未启用")
|
||||
}
|
||||
|
||||
// 尝试原子性验证(如果 repo 支持)
|
||||
if verifier, ok := s.userRepo.(atomicTOTPVerifier); ok {
|
||||
valid, err := verifier.VerifyTOTPOrRecoveryCode(ctx, userID, code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("验证失败: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return errors.New("验证码或恢复码错误")
|
||||
}
|
||||
// 验证通过,继续禁用
|
||||
} else {
|
||||
// 降级到非原子性验证(兼容性模式)
|
||||
valid := s.totpManager.ValidateCode(user.TOTPSecret, code)
|
||||
if !valid {
|
||||
var hashedCodes []string
|
||||
@@ -118,6 +135,7 @@ func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string
|
||||
return errors.New("验证码或恢复码错误")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
user.TOTPEnabled = false
|
||||
user.TOTPSecret = ""
|
||||
|
||||
@@ -36,6 +36,8 @@ type totpTestRepo struct {
|
||||
updateTOTPErr error
|
||||
consumeRecoveryCodeErr error
|
||||
consumeRecoveryCodeCalled bool
|
||||
verifyTOTPOrRecoveryCodeErr error
|
||||
verifyTOTPOrRecoveryCodeCalled bool
|
||||
}
|
||||
|
||||
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil }
|
||||
@@ -89,9 +91,11 @@ func (r *totpTestRepo) ExistsByEmail(ctx context.Context, email string) (bool, e
|
||||
func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
|
||||
return nil, 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
|
||||
r.consumeRecoveryCodeCalled = true
|
||||
if r.consumeRecoveryCodeErr != nil {
|
||||
@@ -119,6 +123,34 @@ func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64
|
||||
return ©User, true, nil
|
||||
}
|
||||
|
||||
func (r *totpTestRepo) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
|
||||
r.verifyTOTPOrRecoveryCodeCalled = true
|
||||
if r.verifyTOTPOrRecoveryCodeErr != nil {
|
||||
return false, r.verifyTOTPOrRecoveryCodeErr
|
||||
}
|
||||
if r.user == nil || r.user.ID != userID {
|
||||
return false, errors.New("not found")
|
||||
}
|
||||
if !r.user.TOTPEnabled {
|
||||
return false, errors.New("TOTP not enabled")
|
||||
}
|
||||
|
||||
// 尝试验证 TOTP 码(简化:只检查是否为特定测试码)
|
||||
if code == "123456" || code == "654321" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 尝试验证恢复码
|
||||
var hashedCodes []string
|
||||
if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" {
|
||||
if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||
return false, fmt.Errorf("解析恢复码失败: %w", err)
|
||||
}
|
||||
}
|
||||
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
func mustMarshalJSONFromHelper(value any) string {
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
@@ -205,6 +237,59 @@ func TestTOTPService_ConsumesHashedRecoveryCodeOnVerify(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPService_DisableTOTP_UsesAtomicVerificationPath(t *testing.T) {
|
||||
repo := &totpTestRepo{
|
||||
user: &domain.User{
|
||||
ID: 10,
|
||||
Username: "totp-user",
|
||||
TOTPEnabled: true,
|
||||
TOTPSecret: "test-secret",
|
||||
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
|
||||
},
|
||||
}
|
||||
svc := NewTOTPService(repo)
|
||||
|
||||
// 使用测试恢复码禁用 TOTP
|
||||
if err := svc.DisableTOTP(context.Background(), 10, "RECOVERY-1"); err != nil {
|
||||
t.Fatalf("expected disable to succeed with recovery code, got: %v", err)
|
||||
}
|
||||
|
||||
if !repo.verifyTOTPOrRecoveryCodeCalled {
|
||||
t.Fatal("expected atomic verification path to be used")
|
||||
}
|
||||
|
||||
if repo.user.TOTPEnabled {
|
||||
t.Fatal("expected TOTP to be disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPService_DisableTOTP_AtomicVerificationFailsOnWrongCode(t *testing.T) {
|
||||
repo := &totpTestRepo{
|
||||
user: &domain.User{
|
||||
ID: 11,
|
||||
Username: "totp-user",
|
||||
TOTPEnabled: true,
|
||||
TOTPSecret: "test-secret",
|
||||
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
|
||||
},
|
||||
}
|
||||
svc := NewTOTPService(repo)
|
||||
|
||||
// 使用错误的恢复码
|
||||
err := svc.DisableTOTP(context.Background(), 11, "WRONG-CODE")
|
||||
if err == nil {
|
||||
t.Fatal("expected disable to fail with wrong code")
|
||||
}
|
||||
|
||||
if !repo.verifyTOTPOrRecoveryCodeCalled {
|
||||
t.Fatal("expected atomic verification path to be used")
|
||||
}
|
||||
|
||||
if !repo.user.TOTPEnabled {
|
||||
t.Fatal("expected TOTP to remain enabled after failed verification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) {
|
||||
repo := &totpTestRepo{
|
||||
user: &domain.User{
|
||||
|
||||
Reference in New Issue
Block a user