package service import ( "context" "errors" "testing" "time" "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/cache" "github.com/user-management-system/internal/domain" "github.com/user-management-system/internal/security" "gorm.io/gorm" ) // ============================================================================= // Auth Runtime Helper Functions Tests // ============================================================================= func TestIsUserNotFoundError(t *testing.T) { tests := []struct { name string err error expected bool }{ { name: "nil error", err: nil, expected: false, }, { name: "gorm record not found", err: gorm.ErrRecordNotFound, expected: true, }, { name: "wrapped gorm record not found", err: errors.Join(gorm.ErrRecordNotFound, errors.New("additional context")), expected: true, }, { name: "other error", err: errors.New("some other error"), expected: false, }, { name: "generic error", err: errors.New("something went wrong"), expected: false, }, { name: "error containing user not found", err: errors.New("user not found"), expected: true, // contains "user not found" in lowercase }, { name: "error containing record not found", err: errors.New("record not found"), expected: true, // contains "record not found" }, { name: "error containing not found", err: errors.New("entity not found"), expected: true, // contains "not found" }, { name: "error containing 用户不存在", err: errors.New("用户不存在"), expected: true, // contains Chinese "用户不存在" }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := isUserNotFoundError(tt.err) if result != tt.expected { t.Errorf("isUserNotFoundError(%v) = %v, want %v", tt.err, result, tt.expected) } }) } } // ============================================================================= // OAuth State Tests // ============================================================================= func TestAuthService_CreateOAuthState(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} ctx := context.Background() t.Run("CreateOAuthState success", func(t *testing.T) { state, err := svc.CreateOAuthState(ctx, "http://localhost/callback") if err != nil { t.Fatalf("CreateOAuthState failed: %v", err) } if state == "" { t.Error("Expected non-empty state") } }) t.Run("CreateOAuthState with empty return URL", func(t *testing.T) { state, err := svc.CreateOAuthState(ctx, "") if err != nil { t.Fatalf("CreateOAuthState failed: %v", err) } if state == "" { t.Error("Expected non-empty state") } }) } func TestAuthService_CreateOAuthBindState(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} ctx := context.Background() t.Run("CreateOAuthBindState success", func(t *testing.T) { state, err := svc.CreateOAuthBindState(ctx, 1, "http://localhost/callback") if err != nil { t.Fatalf("CreateOAuthBindState failed: %v", err) } if state == "" { t.Error("Expected non-empty state") } }) t.Run("CreateOAuthBindState with invalid user ID", func(t *testing.T) { _, err := svc.CreateOAuthBindState(ctx, 0, "http://localhost/callback") if err == nil { t.Error("Expected error for invalid user ID") } }) } func TestAuthService_ConsumeOAuthState(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} ctx := context.Background() t.Run("ConsumeOAuthState invalid state", func(t *testing.T) { _, err := svc.ConsumeOAuthState(ctx, "invalid_state") if err == nil { t.Error("Expected error for invalid state") } }) t.Run("ConsumeOAuthState valid state", func(t *testing.T) { state, _ := svc.CreateOAuthState(ctx, "http://localhost/callback") returnTo, err := svc.ConsumeOAuthState(ctx, state) if err != nil { t.Fatalf("ConsumeOAuthState failed: %v", err) } if returnTo != "http://localhost/callback" { t.Errorf("Expected return URL 'http://localhost/callback', got %s", returnTo) } }) } func TestAuthService_ConsumeOAuthStatePayload(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} ctx := context.Background() t.Run("ConsumeOAuthStatePayload with bind purpose", func(t *testing.T) { state, _ := svc.CreateOAuthBindState(ctx, 123, "http://localhost/callback") payload, err := svc.ConsumeOAuthStatePayload(ctx, state) if err != nil { t.Fatalf("ConsumeOAuthStatePayload failed: %v", err) } if payload.Purpose != OAuthStatePurposeBind { t.Errorf("Expected purpose 'bind', got %s", payload.Purpose) } if payload.UserID != 123 { t.Errorf("Expected user ID 123, got %d", payload.UserID) } }) } func TestAuthService_CreateOAuthHandoff(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} ctx := context.Background() t.Run("CreateOAuthHandoff success", func(t *testing.T) { loginResp := &LoginResponse{ AccessToken: "test_token", RefreshToken: "test_refresh", } code, err := svc.CreateOAuthHandoff(ctx, loginResp) if err != nil { t.Fatalf("CreateOAuthHandoff failed: %v", err) } if code == "" { t.Error("Expected non-empty code") } }) t.Run("CreateOAuthHandoff with nil response", func(t *testing.T) { _, err := svc.CreateOAuthHandoff(ctx, nil) if err == nil { t.Error("Expected error for nil response") } }) } func TestAuthService_ConsumeOAuthHandoff(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} ctx := context.Background() t.Run("ConsumeOAuthHandoff invalid code", func(t *testing.T) { _, err := svc.ConsumeOAuthHandoff(ctx, "invalid_code") if err == nil { t.Error("Expected error for invalid code") } }) t.Run("ConsumeOAuthHandoff valid code", func(t *testing.T) { loginResp := &LoginResponse{ AccessToken: "test_token", RefreshToken: "test_refresh", } code, _ := svc.CreateOAuthHandoff(ctx, loginResp) resp, err := svc.ConsumeOAuthHandoff(ctx, code) if err != nil { t.Fatalf("ConsumeOAuthHandoff failed: %v", err) } if resp.AccessToken != "test_token" { t.Errorf("Expected access token 'test_token', got %s", resp.AccessToken) } }) } func TestAuthService_OAuthStateNilService(t *testing.T) { var nilSvc *AuthService ctx := context.Background() t.Run("CreateOAuthState nil service", func(t *testing.T) { _, err := nilSvc.CreateOAuthState(ctx, "http://localhost") if err == nil { t.Error("Expected error for nil service") } }) t.Run("ConsumeOAuthState nil service", func(t *testing.T) { _, err := nilSvc.ConsumeOAuthState(ctx, "state") if err == nil { t.Error("Expected error for nil service") } }) t.Run("CreateOAuthHandoff nil service", func(t *testing.T) { _, err := nilSvc.CreateOAuthHandoff(ctx, &LoginResponse{}) if err == nil { t.Error("Expected error for nil service") } }) t.Run("ConsumeOAuthHandoff nil service", func(t *testing.T) { _, err := nilSvc.ConsumeOAuthHandoff(ctx, "code") if err == nil { t.Error("Expected error for nil service") } }) } func TestGenerateOAuthEphemeralCode(t *testing.T) { code, err := generateOAuthEphemeralCode() if err != nil { t.Fatalf("generateOAuthEphemeralCode failed: %v", err) } if code == "" { t.Error("Expected non-empty code") } // Should generate different codes code2, _ := generateOAuthEphemeralCode() if code == code2 { t.Error("Expected different codes on each call") } } // ============================================================================= // Password Policy Tests // ============================================================================= func TestAuthService_SetPasswordPolicy(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} t.Run("SetPasswordPolicy success", func(t *testing.T) { policy := security.PasswordPolicy{ MinLength: 12, RequireSpecial: true, RequireNumber: true, } svc.SetPasswordPolicy(policy) // Verify policy is set if !svc.passwordPolicySet { t.Error("Expected passwordPolicySet to be true") } if svc.passwordPolicy.MinLength != 12 { t.Errorf("Expected MinLength 12, got %d", svc.passwordPolicy.MinLength) } }) t.Run("SetPasswordPolicy with defaults", func(t *testing.T) { svc2 := &AuthService{cache: cacheManager} policy := security.PasswordPolicy{} // Empty policy svc2.SetPasswordPolicy(policy) // Should normalize to default min length 8 if svc2.passwordPolicy.MinLength != 8 { t.Errorf("Expected normalized MinLength 8, got %d", svc2.passwordPolicy.MinLength) } }) } // ============================================================================= // Social Account Helper Tests // ============================================================================= func TestFindSocialAccountByProvider(t *testing.T) { tests := []struct { name string accounts []*domain.SocialAccount provider string expectNil bool }{ { name: "nil accounts", accounts: nil, provider: "github", expectNil: true, }, { name: "empty accounts", accounts: []*domain.SocialAccount{}, provider: "github", expectNil: true, }, { name: "found matching provider", accounts: []*domain.SocialAccount{ {Provider: "github", OpenID: "123"}, {Provider: "google", OpenID: "456"}, }, provider: "github", expectNil: false, }, { name: "case insensitive match", accounts: []*domain.SocialAccount{ {Provider: "GitHub", OpenID: "123"}, }, provider: "github", expectNil: false, }, { name: "provider not found", accounts: []*domain.SocialAccount{ {Provider: "google", OpenID: "456"}, }, provider: "github", expectNil: true, }, { name: "nil account in list", accounts: []*domain.SocialAccount{ nil, {Provider: "github", OpenID: "123"}, }, provider: "github", expectNil: false, }, { name: "empty provider", accounts: []*domain.SocialAccount{ {Provider: "github", OpenID: "123"}, }, provider: "", expectNil: true, }, { name: "provider with spaces", accounts: []*domain.SocialAccount{ {Provider: " github ", OpenID: "123"}, }, provider: "github", expectNil: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := findSocialAccountByProvider(tt.accounts, tt.provider) if (result == nil) != tt.expectNil { t.Errorf("findSocialAccountByProvider() nil = %v, expectNil = %v", result == nil, tt.expectNil) } }) } } // ============================================================================= // Available Login Method Count Tests // ============================================================================= func TestAuthService_AvailableLoginMethodCount(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) t.Run("nil user", func(t *testing.T) { svc := &AuthService{cache: cacheManager} count := svc.availableLoginMethodCount(nil, nil, "") if count != 0 { t.Errorf("Expected 0 for nil user, got %d", count) } }) t.Run("password only", func(t *testing.T) { svc := &AuthService{cache: cacheManager} user := &domain.User{Password: "hashed_password"} count := svc.availableLoginMethodCount(user, nil, "") if count != 1 { t.Errorf("Expected 1 for password only, got %d", count) } }) t.Run("password and email with email service", func(t *testing.T) { email := "test@example.com" svc := &AuthService{ cache: cacheManager, emailCodeSvc: &EmailCodeService{}, } user := &domain.User{Password: "hashed_password", Email: &email} count := svc.availableLoginMethodCount(user, nil, "") if count != 2 { t.Errorf("Expected 2 for password and email, got %d", count) } }) t.Run("password and phone with sms service", func(t *testing.T) { phone := "13800138000" svc := &AuthService{ cache: cacheManager, smsCodeSvc: &SMSCodeService{}, } user := &domain.User{Password: "hashed_password", Phone: &phone} count := svc.availableLoginMethodCount(user, nil, "") if count != 2 { t.Errorf("Expected 2 for password and phone, got %d", count) } }) t.Run("all methods", func(t *testing.T) { email := "test@example.com" phone := "13800138000" svc := &AuthService{ cache: cacheManager, emailCodeSvc: &EmailCodeService{}, smsCodeSvc: &SMSCodeService{}, } user := &domain.User{Password: "hashed_password", Email: &email, Phone: &phone} accounts := []*domain.SocialAccount{ {Provider: "github", Status: domain.SocialAccountStatusActive}, } count := svc.availableLoginMethodCount(user, accounts, "") if count != 4 { t.Errorf("Expected 4 for all methods, got %d", count) } }) t.Run("exclude social provider", func(t *testing.T) { email := "test@example.com" svc := &AuthService{ cache: cacheManager, emailCodeSvc: &EmailCodeService{}, } user := &domain.User{Password: "hashed_password", Email: &email} accounts := []*domain.SocialAccount{ {Provider: "github", Status: domain.SocialAccountStatusActive}, {Provider: "google", Status: domain.SocialAccountStatusActive}, } count := svc.availableLoginMethodCount(user, accounts, "github") // password + email + google (github excluded) if count != 3 { t.Errorf("Expected 3 with github excluded, got %d", count) } }) t.Run("inactive social accounts not counted", func(t *testing.T) { svc := &AuthService{cache: cacheManager} user := &domain.User{Password: "hashed_password"} accounts := []*domain.SocialAccount{ {Provider: "github", Status: domain.SocialAccountStatusActive}, {Provider: "google", Status: 0}, // inactive nil, // nil account } count := svc.availableLoginMethodCount(user, accounts, "") // password + github only if count != 2 { t.Errorf("Expected 2 with inactive filtered, got %d", count) } }) t.Run("empty password not counted", func(t *testing.T) { svc := &AuthService{cache: cacheManager} user := &domain.User{Password: " "} count := svc.availableLoginMethodCount(user, nil, "") if count != 0 { t.Errorf("Expected 0 for empty password, got %d", count) } }) } // ============================================================================= // Generate Unique Username Tests // ============================================================================= func TestGenerateUniqueUsername(t *testing.T) { t.Run("nil service returns sanitized username", func(t *testing.T) { var nilSvc *AuthService username, err := nilSvc.generateUniqueUsername(context.Background(), "Test User") if err != nil { t.Fatalf("Expected no error, got: %v", err) } if username != "test_user" { t.Errorf("Expected 'test_user', got %q", username) } }) t.Run("service with nil userRepo returns sanitized username", func(t *testing.T) { svc := &AuthService{} username, err := svc.generateUniqueUsername(context.Background(), "John Doe") if err != nil { t.Fatalf("Expected no error, got: %v", err) } if username != "john_doe" { t.Errorf("Expected 'john_doe', got %q", username) } }) t.Run("long username is truncated", func(t *testing.T) { svc := &AuthService{} longName := "this_is_a_very_long_username_that_should_be_truncated_to_forty_characters" username, err := svc.generateUniqueUsername(context.Background(), longName) if err != nil { t.Fatalf("Expected no error, got: %v", err) } if len(username) > 50 { t.Errorf("Username should be max 50 chars, got %d", len(username)) } }) t.Run("empty base returns user", func(t *testing.T) { svc := &AuthService{} username, err := svc.generateUniqueUsername(context.Background(), "") if err != nil { t.Fatalf("Expected no error, got: %v", err) } if username != "user" { t.Errorf("Expected 'user', got %q", username) } }) } // ============================================================================= // Set Login Log Repository Tests // ============================================================================= func TestAuthService_SetLoginLogRepository(t *testing.T) { svc := &AuthService{} // Should not panic with nil svc.SetLoginLogRepository(nil) } // ============================================================================= // Set Anomaly Detector Tests // ============================================================================= func TestAuthService_SetAnomalyDetector(t *testing.T) { svc := &AuthService{} // Should not panic with nil svc.SetAnomalyDetector(nil) } // ============================================================================= // Set Device Service Tests // ============================================================================= func TestAuthService_SetDeviceService(t *testing.T) { svc := &AuthService{} // Should not panic with nil svc.SetDeviceService(nil) } // ============================================================================= // Set SMS Code Service Tests // ============================================================================= func TestAuthService_SetSMSCodeService(t *testing.T) { svc := &AuthService{} // Should not panic with nil svc.SetSMSCodeService(nil) } // ============================================================================= // Available Login Method Count After Contact Removal Tests // ============================================================================= func TestAuthService_AvailableLoginMethodCountAfterContactRemoval(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) t.Run("nil user", func(t *testing.T) { svc := &AuthService{cache: cacheManager} count := svc.availableLoginMethodCountAfterContactRemoval(nil, nil, false, false) if count != 0 { t.Errorf("Expected 0 for nil user, got %d", count) } }) t.Run("password only no removal", func(t *testing.T) { svc := &AuthService{cache: cacheManager} user := &domain.User{Password: "hashed_password"} count := svc.availableLoginMethodCountAfterContactRemoval(user, nil, false, false) if count != 1 { t.Errorf("Expected 1 for password only, got %d", count) } }) t.Run("password and email with email service", func(t *testing.T) { email := "test@example.com" svc := &AuthService{ cache: cacheManager, emailCodeSvc: &EmailCodeService{}, } user := &domain.User{Password: "hashed_password", Email: &email} count := svc.availableLoginMethodCountAfterContactRemoval(user, nil, false, false) if count != 2 { t.Errorf("Expected 2 for password and email, got %d", count) } }) t.Run("remove email", func(t *testing.T) { email := "test@example.com" svc := &AuthService{ cache: cacheManager, emailCodeSvc: &EmailCodeService{}, } user := &domain.User{Password: "hashed_password", Email: &email} count := svc.availableLoginMethodCountAfterContactRemoval(user, nil, true, false) if count != 1 { t.Errorf("Expected 1 after email removal, got %d", count) } }) t.Run("remove phone", func(t *testing.T) { phone := "13800138000" svc := &AuthService{ cache: cacheManager, smsCodeSvc: &SMSCodeService{}, } user := &domain.User{Password: "hashed_password", Phone: &phone} count := svc.availableLoginMethodCountAfterContactRemoval(user, nil, false, true) if count != 1 { t.Errorf("Expected 1 after phone removal, got %d", count) } }) t.Run("social accounts counted", func(t *testing.T) { svc := &AuthService{cache: cacheManager} user := &domain.User{Password: "hashed_password"} accounts := []*domain.SocialAccount{ {Provider: "github", Status: domain.SocialAccountStatusActive}, {Provider: "google", Status: domain.SocialAccountStatusActive}, } count := svc.availableLoginMethodCountAfterContactRemoval(user, accounts, false, false) if count != 3 { t.Errorf("Expected 3 with social accounts, got %d", count) } }) t.Run("inactive social accounts not counted", func(t *testing.T) { svc := &AuthService{cache: cacheManager} user := &domain.User{Password: "hashed_password"} accounts := []*domain.SocialAccount{ {Provider: "github", Status: domain.SocialAccountStatusActive}, {Provider: "google", Status: 0}, // inactive nil, } count := svc.availableLoginMethodCountAfterContactRemoval(user, accounts, false, false) if count != 2 { t.Errorf("Expected 2 with inactive filtered, got %d", count) } }) } // ============================================================================= // Register OAuth Provider Tests // ============================================================================= func TestAuthService_RegisterOAuthProvider(t *testing.T) { t.Run("nil config does nothing", func(t *testing.T) { svc := &AuthService{} // Should not panic with nil config svc.RegisterOAuthProvider("github", nil) }) t.Run("nil oauth manager", func(t *testing.T) { svc := &AuthService{} cfg := &auth.OAuthConfig{ClientID: "test"} // Should not panic with nil oauthManager svc.RegisterOAuthProvider("github", cfg) }) } // ============================================================================= // Best Effort Register Device Public Tests // ============================================================================= func TestAuthService_BestEffortRegisterDevicePublic(t *testing.T) { t.Run("nil service does not panic", func(t *testing.T) { var nilSvc *AuthService // Should not panic nilSvc.BestEffortRegisterDevicePublic(context.Background(), 1, nil) }) t.Run("nil device service does not panic", func(t *testing.T) { svc := &AuthService{} svc.BestEffortRegisterDevicePublic(context.Background(), 1, &LoginRequest{}) // Should not panic }) } // ============================================================================= // Int Value and Int64 Value Tests // ============================================================================= func TestIntValue(t *testing.T) { tests := []struct { name string input interface{} expected int wantOk bool }{ {"int value", 42, 42, true}, {"int64 value", int64(100), 100, true}, {"float64 value", float64(99.0), 99, true}, {"float64 with decimal", float64(99.5), 99, true}, {"string value", "42", 0, false}, {"nil value", nil, 0, false}, {"negative int", -5, -5, true}, {"zero value", 0, 0, true}, {"large int64", int64(9999999999), 9999999999, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, ok := intValue(tt.input) if result != tt.expected || ok != tt.wantOk { t.Errorf("intValue(%v) = (%d, %v), want (%d, %v)", tt.input, result, ok, tt.expected, tt.wantOk) } }) } } func TestInt64Value(t *testing.T) { tests := []struct { name string input interface{} expected int64 wantOk bool }{ {"int value", 42, 42, true}, {"int64 value", int64(100), 100, true}, {"float64 value", float64(99.0), 99, true}, {"float64 with decimal", float64(99.5), 99, true}, {"string value", "42", 0, false}, {"nil value", nil, 0, false}, {"negative int64", int64(-5), -5, true}, {"zero value", 0, 0, true}, {"large int64", int64(9999999999), 9999999999, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, ok := int64Value(tt.input) if result != tt.expected || ok != tt.wantOk { t.Errorf("int64Value(%v) = (%d, %v), want (%d, %v)", tt.input, result, ok, tt.expected, tt.wantOk) } }) } } // ============================================================================= // Best Effort Update Last Login Tests // ============================================================================= func TestBestEffortUpdateLastLogin(t *testing.T) { t.Run("nil service does not panic", func(t *testing.T) { var nilSvc *AuthService // Should not panic nilSvc.bestEffortUpdateLastLogin(context.Background(), 1, "127.0.0.1", "password") }) } // ============================================================================= // Best Effort Assign Default Roles Tests // ============================================================================= func TestBestEffortAssignDefaultRoles(t *testing.T) { t.Run("nil service does not panic", func(t *testing.T) { var nilSvc *AuthService nilSvc.bestEffortAssignDefaultRoles(context.Background(), 1, "register") }) t.Run("service without repos does not panic", func(t *testing.T) { svc := &AuthService{} svc.bestEffortAssignDefaultRoles(context.Background(), 1, "register") }) } // ============================================================================= // Create OAuth State Payload Tests // ============================================================================= func TestCreateOAuthStatePayload(t *testing.T) { t.Run("nil service returns error", func(t *testing.T) { var nilSvc *AuthService _, err := nilSvc.createOAuthStatePayload(context.Background(), &OAuthStatePayload{Purpose: OAuthStatePurposeLogin}) if err == nil { t.Error("Expected error for nil service") } }) t.Run("service without cache returns error", func(t *testing.T) { svc := &AuthService{} _, err := svc.createOAuthStatePayload(context.Background(), &OAuthStatePayload{Purpose: OAuthStatePurposeLogin}) if err == nil { t.Error("Expected error when cache not configured") } }) t.Run("nil payload returns error", func(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} _, err := svc.createOAuthStatePayload(context.Background(), nil) if err == nil { t.Error("Expected error for nil payload") } }) t.Run("create state payload with cache", func(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} state, err := svc.createOAuthStatePayload(context.Background(), &OAuthStatePayload{ Purpose: OAuthStatePurposeLogin, ReturnTo: "http://localhost/callback", }) if err != nil { t.Fatalf("createOAuthStatePayload failed: %v", err) } if state == "" { t.Error("Expected non-empty state") } }) t.Run("create state payload with default purpose", func(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} state, err := svc.createOAuthStatePayload(context.Background(), &OAuthStatePayload{ ReturnTo: "http://localhost/callback", }) if err != nil { t.Fatalf("createOAuthStatePayload failed: %v", err) } if state == "" { t.Error("Expected non-empty state") } }) } // ============================================================================= // Verify Phone Registration Tests // ============================================================================= func TestVerifyPhoneRegistration(t *testing.T) { t.Run("nil service returns nil for empty phone", func(t *testing.T) { var nilSvc *AuthService err := nilSvc.verifyPhoneRegistration(context.Background(), &RegisterRequest{Phone: ""}) if err != nil { t.Errorf("Expected nil error for empty phone, got: %v", err) } }) t.Run("nil request returns nil", func(t *testing.T) { svc := &AuthService{} err := svc.verifyPhoneRegistration(context.Background(), nil) if err != nil { t.Errorf("Expected nil error for nil request, got: %v", err) } }) t.Run("service without SMS returns error", func(t *testing.T) { svc := &AuthService{} err := svc.verifyPhoneRegistration(context.Background(), &RegisterRequest{Phone: "13800138000", PhoneCode: "123456"}) if err == nil { t.Error("Expected error when SMS service not configured") } }) t.Run("empty phone code returns error", func(t *testing.T) { svc := &AuthService{smsCodeSvc: &SMSCodeService{}} err := svc.verifyPhoneRegistration(context.Background(), &RegisterRequest{Phone: "13800138000", PhoneCode: ""}) if err == nil { t.Error("Expected error when phone code is empty") } }) } // ============================================================================= // Consume OAuth State Payload Tests // ============================================================================= func TestConsumeOAuthStatePayload(t *testing.T) { t.Run("nil service returns error", func(t *testing.T) { var nilSvc *AuthService _, err := nilSvc.ConsumeOAuthStatePayload(context.Background(), "state123") if err == nil { t.Error("Expected error for nil service") } }) t.Run("service without cache returns error", func(t *testing.T) { svc := &AuthService{} _, err := svc.ConsumeOAuthStatePayload(context.Background(), "state123") if err == nil { t.Error("Expected error when cache not configured") } }) } // ============================================================================= // Consume OAuth Handoff Tests // ============================================================================= func TestConsumeOAuthHandoff(t *testing.T) { t.Run("nil service returns error", func(t *testing.T) { var nilSvc *AuthService _, err := nilSvc.ConsumeOAuthHandoff(context.Background(), "code123") if err == nil { t.Error("Expected error for nil service") } }) t.Run("service without cache returns error", func(t *testing.T) { svc := &AuthService{} _, err := svc.ConsumeOAuthHandoff(context.Background(), "code123") if err == nil { t.Error("Expected error when cache not configured") } }) } // ============================================================================= // Consume OAuth Handoff With Cache Tests // ============================================================================= func TestConsumeOAuthHandoff_WithCache(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} ctx := context.Background() t.Run("consume non-existent handoff", func(t *testing.T) { _, err := svc.ConsumeOAuthHandoff(ctx, "nonexistent_code") if err == nil { t.Error("Expected error for non-existent handoff") } }) t.Run("consume handoff with pointer response", func(t *testing.T) { resp := &LoginResponse{ AccessToken: "test_access_token", RefreshToken: "test_refresh_token", } cacheManager.Set(ctx, "oauth_handoff:test_code_1", resp, time.Minute, time.Minute) result, err := svc.ConsumeOAuthHandoff(ctx, "test_code_1") if err != nil { t.Fatalf("ConsumeOAuthHandoff failed: %v", err) } if result.AccessToken != "test_access_token" { t.Errorf("Expected access token, got %s", result.AccessToken) } }) t.Run("consume handoff with value response", func(t *testing.T) { resp := LoginResponse{ AccessToken: "value_access_token", RefreshToken: "value_refresh_token", } cacheManager.Set(ctx, "oauth_handoff:test_code_2", resp, time.Minute, time.Minute) result, err := svc.ConsumeOAuthHandoff(ctx, "test_code_2") if err != nil { t.Fatalf("ConsumeOAuthHandoff failed: %v", err) } if result.AccessToken != "value_access_token" { t.Errorf("Expected access token, got %s", result.AccessToken) } }) } // ============================================================================= // Consume OAuth State Payload With Cache Tests // ============================================================================= func TestConsumeOAuthStatePayload_WithCache(t *testing.T) { l1Cache := cache.NewL1Cache() l2Cache := cache.NewRedisCache(false) cacheManager := cache.NewCacheManager(l1Cache, l2Cache) svc := &AuthService{cache: cacheManager} ctx := context.Background() t.Run("consume non-existent state", func(t *testing.T) { _, err := svc.ConsumeOAuthStatePayload(ctx, "nonexistent_state") if err == nil { t.Error("Expected error for non-existent state") } }) t.Run("consume state with pointer payload", func(t *testing.T) { payload := &OAuthStatePayload{ Purpose: OAuthStatePurposeLogin, ReturnTo: "http://localhost/callback", } cacheManager.Set(ctx, "oauth_state:test_state_1", payload, time.Minute*10, time.Minute*10) result, err := svc.ConsumeOAuthStatePayload(ctx, "test_state_1") if err != nil { t.Fatalf("ConsumeOAuthStatePayload failed: %v", err) } if result.Purpose != OAuthStatePurposeLogin { t.Errorf("Expected purpose %s, got %s", OAuthStatePurposeLogin, result.Purpose) } }) t.Run("consume state with value payload", func(t *testing.T) { payload := OAuthStatePayload{ Purpose: OAuthStatePurposeBind, ReturnTo: "http://localhost/bind", UserID: 123, } cacheManager.Set(ctx, "oauth_state:test_state_2", payload, time.Minute*10, time.Minute*10) result, err := svc.ConsumeOAuthStatePayload(ctx, "test_state_2") if err != nil { t.Fatalf("ConsumeOAuthStatePayload failed: %v", err) } if result.UserID != 123 { t.Errorf("Expected UserID 123, got %d", result.UserID) } }) }