Files
user-system/internal/service/totp_internal_test.go
Your Name 363c77d020 feat: atomic TOTP verification for DisableTOTP
- Add atomicTOTPVerifier interface for atomic TOTP/recovery code verification
- Implement VerifyTOTPOrRecoveryCode in UserRepository with transaction
- Update DisableTOTP to prefer atomic verification path
- Add unit tests for atomic verification success/failure paths
- Maintain backward compatibility with non-atomic fallback

Refs: TOTP verification atomicity completion
2026-05-29 12:47:05 +08:00

318 lines
9.9 KiB
Go

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
)
// mustHashRecoveryCode returns auth.HashRecoveryCode(code) and aborts the
// calling test (reported at the caller's line) if hashing fails.
func mustHashRecoveryCode(t *testing.T, code string) string {
	t.Helper()
	digest, hashErr := auth.HashRecoveryCode(code)
	if hashErr == nil {
		return digest
	}
	t.Fatalf("hash recovery code: %v", hashErr)
	return "" // unreachable: Fatalf stops the test goroutine
}
// mustMarshalJSON encodes value as JSON and returns it as a string, aborting
// the calling test (reported at the caller's line) if encoding fails.
func mustMarshalJSON(t *testing.T, value any) string {
	t.Helper()
	encoded, marshalErr := json.Marshal(value)
	if marshalErr == nil {
		return string(encoded)
	}
	t.Fatalf("marshal json: %v", marshalErr)
	return "" // unreachable: Fatalf stops the test goroutine
}
// totpTestRepo is an in-memory fake user repository for the TOTP service
// tests. It holds at most one user, lets a test inject failures into specific
// methods, and records whether the recovery-code verification methods ran.
type totpTestRepo struct {
	// user is the single stored record; nil means the repository is empty.
	user *domain.User

	// getErr, when non-nil, is returned by GetByID.
	getErr error
	// updateTOTPErr, when non-nil, is returned by UpdateTOTP.
	updateTOTPErr error

	// ConsumeTOTPRecoveryCode: injected failure and call recorder.
	consumeRecoveryCodeErr    error
	consumeRecoveryCodeCalled bool

	// VerifyTOTPOrRecoveryCode: injected failure and call recorder.
	verifyTOTPOrRecoveryCodeErr    error
	verifyTOTPOrRecoveryCodeCalled bool
}
// Create discards user and always reports success.
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error {
	return nil
}

// Update discards user and always reports success; it does not modify the
// stored record (only UpdateTOTP does).
func (r *totpTestRepo) Update(ctx context.Context, user *domain.User) error {
	return nil
}
// UpdateTOTP replaces the stored record with a copy of user. When
// updateTOTPErr is set it is returned and the stored record is left as is.
func (r *totpTestRepo) UpdateTOTP(ctx context.Context, user *domain.User) error {
	if injected := r.updateTOTPErr; injected != nil {
		return injected
	}
	// Store a copy so later mutations of the caller's object are not persisted.
	snapshot := *user
	r.user = &snapshot
	return nil
}
// Delete ignores id and always reports success without touching the store.
func (r *totpTestRepo) Delete(ctx context.Context, id int64) error {
	return nil
}
// GetByID returns a copy of the stored user when its ID equals id.
// An injected getErr takes precedence; an empty store or an ID mismatch
// yields a "not found" error.
func (r *totpTestRepo) GetByID(ctx context.Context, id int64) (*domain.User, error) {
	switch {
	case r.getErr != nil:
		return nil, r.getErr
	case r.user == nil || r.user.ID != id:
		return nil, errors.New("not found")
	}
	// Return a copy so the caller cannot mutate the stored record directly.
	found := *r.user
	return &found, nil
}
// Unsupported operations: each method below returns zero values together
// with a newly created "not implemented" error, so any unexpected call from
// the service under test shows up as a failure.

// GetByUsername is unsupported by this fake.
func (r *totpTestRepo) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
	return nil, errors.New("not implemented")
}

// GetByEmail is unsupported by this fake.
func (r *totpTestRepo) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
	return nil, errors.New("not implemented")
}

// GetByPhone is unsupported by this fake.
func (r *totpTestRepo) GetByPhone(ctx context.Context, phone string) (*domain.User, error) {
	return nil, errors.New("not implemented")
}

// List is unsupported by this fake.
func (r *totpTestRepo) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
	return nil, 0, errors.New("not implemented")
}

// ListByStatus is unsupported by this fake.
func (r *totpTestRepo) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
	return nil, 0, errors.New("not implemented")
}

// UpdateStatus is unsupported by this fake.
func (r *totpTestRepo) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
	return errors.New("not implemented")
}

// UpdateLastLogin is unsupported by this fake.
func (r *totpTestRepo) UpdateLastLogin(ctx context.Context, id int64, ip string) error {
	return errors.New("not implemented")
}

// ExistsByUsername is unsupported by this fake.
func (r *totpTestRepo) ExistsByUsername(ctx context.Context, username string) (bool, error) {
	return false, errors.New("not implemented")
}

// ExistsByEmail is unsupported by this fake.
func (r *totpTestRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	return false, errors.New("not implemented")
}

// ExistsByPhone is unsupported by this fake.
func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
	return false, errors.New("not implemented")
}

// Search is unsupported by this fake.
func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
	return nil, 0, errors.New("not implemented")
}
// ConsumeTOTPRecoveryCode emulates atomic recovery-code consumption. It
// records the call, decodes the stored JSON list of hashed recovery codes,
// removes the entry that auth.VerifyRecoveryCode reports as matching code,
// persists the shortened list, and returns the updated user with true.
//
// It returns:
//   - consumeRecoveryCodeErr, when one is injected;
//   - a "not found" error, when no user with userID is stored;
//   - a wrapped decode error, when the stored payload is not valid JSON;
//   - (nil, false, nil), when no stored code matches.
//
// The returned user is a copy distinct from the stored record, mirroring
// GetByID, so the caller cannot change repository state by mutating it.
func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
	r.consumeRecoveryCodeCalled = true
	if r.consumeRecoveryCodeErr != nil {
		return nil, false, r.consumeRecoveryCodeErr
	}
	if r.user == nil || r.user.ID != userID {
		return nil, false, errors.New("not found")
	}
	var hashedCodes []string
	if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" {
		if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil {
			return nil, false, fmt.Errorf("解析恢复码失败: %w", err)
		}
	}
	idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
	if !matched {
		return nil, false, nil
	}
	// hashedCodes was freshly decoded above, so removing in place cannot
	// affect any other slice.
	hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...)
	updated := *r.user
	updated.TOTPRecoveryCodes = mustMarshalJSONFromHelper(hashedCodes)
	r.user = &updated
	// Previously the stored pointer itself was returned, so any caller
	// mutation silently "persisted". Hand back an independent copy instead.
	result := updated
	return &result, true, nil
}
// VerifyTOTPOrRecoveryCode fakes the atomic TOTP-or-recovery-code check. It
// records the call, then reports true when code is one of the fixed test
// TOTP codes ("123456", "654321") or matches a stored hashed recovery code.
// The stored TOTP secret is never consulted.
//
// It returns the injected verifyTOTPOrRecoveryCodeErr when set, a "not found"
// error when no user with userID is stored, a "TOTP not enabled" error when
// the stored user has TOTP disabled, and a wrapped decode error when the
// stored recovery-code payload is not valid JSON.
//
// NOTE(review): unlike ConsumeTOTPRecoveryCode, a matched recovery code is
// not removed from the stored list here. Confirm whether the real repository
// consumes it inside its transaction, and mirror that here if so.
func (r *totpTestRepo) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
	r.verifyTOTPOrRecoveryCodeCalled = true
	if injected := r.verifyTOTPOrRecoveryCodeErr; injected != nil {
		return false, injected
	}
	stored := r.user
	if stored == nil || stored.ID != userID {
		return false, errors.New("not found")
	}
	if !stored.TOTPEnabled {
		return false, errors.New("TOTP not enabled")
	}
	switch code {
	case "123456", "654321":
		// Stand-ins for valid TOTP codes; real TOTP validation is not simulated.
		return true, nil
	}
	// Fall back to the stored hashed recovery codes.
	var hashedCodes []string
	if payload := stored.TOTPRecoveryCodes; strings.TrimSpace(payload) != "" {
		if err := json.Unmarshal([]byte(payload), &hashedCodes); err != nil {
			return false, fmt.Errorf("解析恢复码失败: %w", err)
		}
	}
	_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
	return matched, nil
}
// mustMarshalJSONFromHelper JSON-encodes value and panics if encoding fails.
// It is the *testing.T-free counterpart of mustMarshalJSON, for code such as
// fake repository methods that has no test handle to report through.
func mustMarshalJSONFromHelper(value any) string {
	encoded, err := json.Marshal(value)
	if err == nil {
		return string(encoded)
	}
	panic(err)
}
func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) {
repo := &totpTestRepo{user: &domain.User{
ID: 42,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "invalid-secret",
TOTPRecoveryCodes: "not-json",
}}
svc := NewTOTPService(repo)
err := svc.VerifyTOTP(context.Background(), 42, "recovery-code")
if err == nil {
t.Fatal("expected corrupted recovery-code payload to fail")
}
if !strings.Contains(err.Error(), "解析恢复码失败") {
t.Fatalf("expected decode error, got: %v", err)
}
}
func TestTOTPService_ReturnsAtomicConsumptionErrorAfterRecoveryCodeConsumption(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 7,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "invalid-secret",
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
},
consumeRecoveryCodeErr: errors.New("write failed"),
}
svc := NewTOTPService(repo)
err := svc.VerifyTOTP(context.Background(), 7, "RECOVERY-1")
if err == nil {
t.Fatal("expected update failure to be returned")
}
if !repo.consumeRecoveryCodeCalled {
t.Fatal("expected atomic consumption path to be invoked")
}
if !strings.Contains(err.Error(), "消费恢复码失败") {
t.Fatalf("expected atomic consume error, got: %v", err)
}
}
func TestTOTPService_ConsumesHashedRecoveryCodeOnVerify(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 8,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "invalid-secret",
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1"), mustHashRecoveryCode(t, "RECOVERY-2")}),
},
}
svc := NewTOTPService(repo)
if err := svc.VerifyTOTP(context.Background(), 8, "RECOVERY-1"); err != nil {
t.Fatalf("expected hashed recovery code to verify, got: %v", err)
}
if !repo.consumeRecoveryCodeCalled {
t.Fatal("expected atomic recovery-code consumption path to be used")
}
if repo.user == nil {
t.Fatal("expected updated user to be persisted")
}
var remaining []string
if err := json.Unmarshal([]byte(repo.user.TOTPRecoveryCodes), &remaining); err != nil {
t.Fatalf("unmarshal remaining codes: %v", err)
}
if len(remaining) != 1 {
t.Fatalf("expected 1 remaining recovery code, got %d", len(remaining))
}
if remaining[0] != mustHashRecoveryCode(t, "RECOVERY-2") {
t.Fatalf("expected RECOVERY-2 hash to remain, got %q", remaining[0])
}
}
func TestTOTPService_DisableTOTP_UsesAtomicVerificationPath(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 10,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "test-secret",
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
},
}
svc := NewTOTPService(repo)
// 使用测试恢复码禁用 TOTP
if err := svc.DisableTOTP(context.Background(), 10, "RECOVERY-1"); err != nil {
t.Fatalf("expected disable to succeed with recovery code, got: %v", err)
}
if !repo.verifyTOTPOrRecoveryCodeCalled {
t.Fatal("expected atomic verification path to be used")
}
if repo.user.TOTPEnabled {
t.Fatal("expected TOTP to be disabled")
}
}
func TestTOTPService_DisableTOTP_AtomicVerificationFailsOnWrongCode(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 11,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "test-secret",
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
},
}
svc := NewTOTPService(repo)
// 使用错误的恢复码
err := svc.DisableTOTP(context.Background(), 11, "WRONG-CODE")
if err == nil {
t.Fatal("expected disable to fail with wrong code")
}
if !repo.verifyTOTPOrRecoveryCodeCalled {
t.Fatal("expected atomic verification path to be used")
}
if !repo.user.TOTPEnabled {
t.Fatal("expected TOTP to remain enabled after failed verification")
}
}
func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 9,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "invalid-secret",
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
},
}
svc := NewTOTPService(repo)
if err := svc.DisableTOTP(context.Background(), 9, "RECOVERY-1"); err != nil {
t.Fatalf("expected hashed recovery code to disable TOTP, got: %v", err)
}
if repo.user == nil {
t.Fatal("expected updated user to be persisted")
}
if repo.user.TOTPEnabled {
t.Fatal("expected TOTP to be disabled")
}
if repo.user.TOTPSecret != "" || repo.user.TOTPRecoveryCodes != "" {
t.Fatalf("expected TOTP secret and recovery codes to be cleared, got secret=%q codes=%q", repo.user.TOTPSecret, repo.user.TOTPRecoveryCodes)
}
}