package service

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"testing"

	"github.com/user-management-system/internal/auth"
	"github.com/user-management-system/internal/domain"
)

// mustHashRecoveryCode hashes a plaintext recovery code with
// auth.HashRecoveryCode and aborts the test if hashing fails.
func mustHashRecoveryCode(t *testing.T, code string) string {
	t.Helper()
	hashed, err := auth.HashRecoveryCode(code)
	if err != nil {
		t.Fatalf("hash recovery code: %v", err)
	}
	return hashed
}

// mustMarshalJSON encodes value as JSON and aborts the test on failure.
func mustMarshalJSON(t *testing.T, value any) string {
	t.Helper()
	payload, err := json.Marshal(value)
	if err != nil {
		t.Fatalf("marshal json: %v", err)
	}
	return string(payload)
}

// totpTestRepo is an in-memory fake user repository for the TOTP service
// tests. It holds at most one user; the *Err fields inject failures into the
// corresponding methods and the *Called fields record which paths were used.
type totpTestRepo struct {
	// user is the single stored record; nil means no user exists.
	user *domain.User

	// Injected failures returned by the matching methods.
	getErr                      error
	updateTOTPErr               error
	consumeRecoveryCodeErr      error
	verifyTOTPOrRecoveryCodeErr error

	// Call recorders set when the matching method is invoked.
	consumeRecoveryCodeCalled      bool
	verifyTOTPOrRecoveryCodeCalled bool
}

// Create is a no-op stub.
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error {
	return nil
}

// Update is a no-op stub; it does not persist anything.
func (r *totpTestRepo) Update(ctx context.Context, user *domain.User) error {
	return nil
}

// UpdateTOTP stores a copy of user, or returns the injected updateTOTPErr.
func (r *totpTestRepo) UpdateTOTP(ctx context.Context, user *domain.User) error {
	if r.updateTOTPErr != nil {
		return r.updateTOTPErr
	}
	// Store a copy so later mutations by the caller do not leak into the fake.
	stored := *user
	r.user = &stored
	return nil
}

// Delete is a no-op stub.
func (r *totpTestRepo) Delete(ctx context.Context, id int64) error {
	return nil
}

// GetByID returns a copy of the stored user when its ID matches, the injected
// getErr if set, or a "not found" error otherwise.
func (r *totpTestRepo) GetByID(ctx context.Context, id int64) (*domain.User, error) {
	if r.getErr != nil {
		return nil, r.getErr
	}
	if r.user == nil || r.user.ID != id {
		return nil, errors.New("not found")
	}
	copyUser := *r.user
	return &copyUser, nil
}

func (r *totpTestRepo) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
	return nil, errors.New("not implemented")
}

func (r *totpTestRepo) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
	return nil, errors.New("not implemented")
}

func (r *totpTestRepo) GetByPhone(ctx context.Context, phone string) (*domain.User, error) {
	return nil, errors.New("not implemented")
}

func (r *totpTestRepo) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
	return nil, 0, errors.New("not implemented")
}

func (r *totpTestRepo) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
	return nil, 0, errors.New("not implemented")
}

func (r *totpTestRepo) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
	return errors.New("not implemented")
}

func (r *totpTestRepo) UpdateLastLogin(ctx context.Context, id int64, ip string) error {
	return errors.New("not implemented")
}

func (r *totpTestRepo) ExistsByUsername(ctx context.Context, username string) (bool, error) {
	return false, errors.New("not implemented")
}

func (r *totpTestRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	return false, errors.New("not implemented")
}

func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
	return false, errors.New("not implemented")
}

func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
	return nil, 0, errors.New("not implemented")
}

// ConsumeTOTPRecoveryCode checks code against the stored hashed recovery codes
// (a JSON array in TOTPRecoveryCodes). On a match it removes that hash, stores
// the updated user, and returns a copy with true. It returns (nil, false, nil)
// when no hash matches, and an error for an injected failure, a missing user,
// or an undecodable recovery-code payload.
func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
	r.consumeRecoveryCodeCalled = true
	if r.consumeRecoveryCodeErr != nil {
		return nil, false, r.consumeRecoveryCodeErr
	}
	if r.user == nil || r.user.ID != userID {
		return nil, false, errors.New("not found")
	}

	var hashedCodes []string
	if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" {
		if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil {
			return nil, false, fmt.Errorf("解析恢复码失败: %w", err)
		}
	}

	idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
	if !matched {
		return nil, false, nil
	}

	// hashedCodes was freshly decoded above, so removing in place is safe.
	hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...)
	stored := *r.user
	stored.TOTPRecoveryCodes = mustMarshalJSONFromHelper(hashedCodes)
	r.user = &stored

	// Return a distinct copy (like GetByID) so the caller cannot mutate the
	// fake's stored record through the returned pointer.
	result := stored
	return &result, true, nil
}

// VerifyTOTPOrRecoveryCode reports whether code is accepted for the user
// without consuming anything. It accepts the fixed test TOTP codes
// "123456"/"654321", otherwise it checks the stored hashed recovery codes.
func (r *totpTestRepo) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
	r.verifyTOTPOrRecoveryCodeCalled = true
	if r.verifyTOTPOrRecoveryCodeErr != nil {
		return false, r.verifyTOTPOrRecoveryCodeErr
	}
	if r.user == nil || r.user.ID != userID {
		return false, errors.New("not found")
	}
	if !r.user.TOTPEnabled {
		return false, errors.New("TOTP not enabled")
	}

	// Try the TOTP code first (simplified: only fixed test codes are accepted).
	if code == "123456" || code == "654321" {
		return true, nil
	}

	// Fall back to checking the hashed recovery codes.
	var hashedCodes []string
	if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" {
		if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil {
			return false, fmt.Errorf("解析恢复码失败: %w", err)
		}
	}
	_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
	return matched, nil
}

// mustMarshalJSONFromHelper is the panicking variant of mustMarshalJSON for
// fake-repository code that has no *testing.T available.
func mustMarshalJSONFromHelper(value any) string {
	payload, err := json.Marshal(value)
	if err != nil {
		panic(err)
	}
	return string(payload)
}

// A recovery-code payload that is not valid JSON must surface as a decode error.
func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) {
	repo := &totpTestRepo{user: &domain.User{
		ID:                42,
		Username:          "totp-user",
		TOTPEnabled:       true,
		TOTPSecret:        "invalid-secret",
		TOTPRecoveryCodes: "not-json",
	}}
	svc := NewTOTPService(repo)

	err := svc.VerifyTOTP(context.Background(), 42, "recovery-code")
	if err == nil {
		t.Fatal("expected corrupted recovery-code payload to fail")
	}
	if !strings.Contains(err.Error(), "解析恢复码失败") {
		t.Fatalf("expected decode error, got: %v", err)
	}
}

// A failure inside the atomic recovery-code consumption must be returned.
func TestTOTPService_ReturnsAtomicConsumptionErrorAfterRecoveryCodeConsumption(t *testing.T) {
	repo := &totpTestRepo{
		user: &domain.User{
			ID:                7,
			Username:          "totp-user",
			TOTPEnabled:       true,
			TOTPSecret:        "invalid-secret",
			TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
		},
		consumeRecoveryCodeErr: errors.New("write failed"),
	}
	svc := NewTOTPService(repo)

	err := svc.VerifyTOTP(context.Background(), 7, "RECOVERY-1")
	if err == nil {
		t.Fatal("expected update failure to be returned")
	}
	if !repo.consumeRecoveryCodeCalled {
		t.Fatal("expected atomic consumption path to be invoked")
	}
	if !strings.Contains(err.Error(), "消费恢复码失败") {
		t.Fatalf("expected atomic consume error, got: %v", err)
	}
}

// A valid recovery code verifies and is removed, leaving the other hashes intact.
func TestTOTPService_ConsumesHashedRecoveryCodeOnVerify(t *testing.T) {
	// Capture the surviving hash once: re-hashing later would only match if
	// auth.HashRecoveryCode is deterministic (it may be salted).
	keptHash := mustHashRecoveryCode(t, "RECOVERY-2")
	repo := &totpTestRepo{
		user: &domain.User{
			ID:                8,
			Username:          "totp-user",
			TOTPEnabled:       true,
			TOTPSecret:        "invalid-secret",
			TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1"), keptHash}),
		},
	}
	svc := NewTOTPService(repo)

	if err := svc.VerifyTOTP(context.Background(), 8, "RECOVERY-1"); err != nil {
		t.Fatalf("expected hashed recovery code to verify, got: %v", err)
	}
	if !repo.consumeRecoveryCodeCalled {
		t.Fatal("expected atomic recovery-code consumption path to be used")
	}
	if repo.user == nil {
		t.Fatal("expected updated user to be persisted")
	}

	var remaining []string
	if err := json.Unmarshal([]byte(repo.user.TOTPRecoveryCodes), &remaining); err != nil {
		t.Fatalf("unmarshal remaining codes: %v", err)
	}
	if len(remaining) != 1 {
		t.Fatalf("expected 1 remaining recovery code, got %d", len(remaining))
	}
	if remaining[0] != keptHash {
		t.Fatalf("expected RECOVERY-2 hash to remain, got %q", remaining[0])
	}
}

// DisableTOTP must go through the repository's atomic verification method.
func TestTOTPService_DisableTOTP_UsesAtomicVerificationPath(t *testing.T) {
	repo := &totpTestRepo{
		user: &domain.User{
			ID:                10,
			Username:          "totp-user",
			TOTPEnabled:       true,
			TOTPSecret:        "test-secret",
			TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
		},
	}
	svc := NewTOTPService(repo)

	// Disable TOTP using a test recovery code.
	if err := svc.DisableTOTP(context.Background(), 10, "RECOVERY-1"); err != nil {
		t.Fatalf("expected disable to succeed with recovery code, got: %v", err)
	}
	if !repo.verifyTOTPOrRecoveryCodeCalled {
		t.Fatal("expected atomic verification path to be used")
	}
	if repo.user == nil {
		t.Fatal("expected updated user to be persisted")
	}
	if repo.user.TOTPEnabled {
		t.Fatal("expected TOTP to be disabled")
	}
}

// A wrong code must fail DisableTOTP and leave TOTP enabled.
func TestTOTPService_DisableTOTP_AtomicVerificationFailsOnWrongCode(t *testing.T) {
	repo := &totpTestRepo{
		user: &domain.User{
			ID:                11,
			Username:          "totp-user",
			TOTPEnabled:       true,
			TOTPSecret:        "test-secret",
			TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
		},
	}
	svc := NewTOTPService(repo)

	// Use a wrong recovery code.
	err := svc.DisableTOTP(context.Background(), 11, "WRONG-CODE")
	if err == nil {
		t.Fatal("expected disable to fail with wrong code")
	}
	if !repo.verifyTOTPOrRecoveryCodeCalled {
		t.Fatal("expected atomic verification path to be used")
	}
	if repo.user == nil {
		t.Fatal("expected user to remain stored")
	}
	if !repo.user.TOTPEnabled {
		t.Fatal("expected TOTP to remain enabled after failed verification")
	}
}

// Disabling with a hashed recovery code must clear the secret and recovery codes.
func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) {
	repo := &totpTestRepo{
		user: &domain.User{
			ID:                9,
			Username:          "totp-user",
			TOTPEnabled:       true,
			TOTPSecret:        "invalid-secret",
			TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
		},
	}
	svc := NewTOTPService(repo)

	if err := svc.DisableTOTP(context.Background(), 9, "RECOVERY-1"); err != nil {
		t.Fatalf("expected hashed recovery code to disable TOTP, got: %v", err)
	}
	if repo.user == nil {
		t.Fatal("expected updated user to be persisted")
	}
	if repo.user.TOTPEnabled {
		t.Fatal("expected TOTP to be disabled")
	}
	if repo.user.TOTPSecret != "" || repo.user.TOTPRecoveryCodes != "" {
		t.Fatalf("expected TOTP secret and recovery codes to be cleared, got secret=%q codes=%q", repo.user.TOTPSecret, repo.user.TOTPRecoveryCodes)
	}
}