Files
user-system/internal/service/auth_runtime_test.go
long-agent 582ad7a069 test: add comprehensive test coverage and improve code quality
- Add new test files for auth, service, and handler modules
- Improve test organization and coverage
- Refactor code for better maintainability
- Add captcha, settings, stats, and theme handler tests
- Add auth module tests (CAS, OAuth, password, SSO, state)
- Add service layer tests for auth, export, permissions, roles
- All Go tests pass (exit code 0)
- All frontend tests pass (325 tests in 59 files)
2026-04-17 20:43:50 +08:00

1092 lines
33 KiB
Go

package service
import (
"context"
"errors"
"testing"
"time"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/security"
"gorm.io/gorm"
)
// =============================================================================
// Auth Runtime Helper Functions Tests
// =============================================================================
// TestIsUserNotFoundError checks that isUserNotFoundError recognises gorm's
// ErrRecordNotFound (on its own or joined with another error) and error
// messages that mention a "not found" condition, while rejecting nil and
// unrelated errors.
func TestIsUserNotFoundError(t *testing.T) {
	cases := []struct {
		label string
		input error
		want  bool
	}{
		{label: "nil error", input: nil, want: false},
		{label: "gorm record not found", input: gorm.ErrRecordNotFound, want: true},
		{
			label: "wrapped gorm record not found",
			input: errors.Join(gorm.ErrRecordNotFound, errors.New("additional context")),
			want:  true,
		},
		{label: "other error", input: errors.New("some other error"), want: false},
		{label: "generic error", input: errors.New("something went wrong"), want: false},
		// The following plain errors are not the gorm sentinel, so they are
		// expected to be recognised by their message text.
		{label: "error containing user not found", input: errors.New("user not found"), want: true},
		{label: "error containing record not found", input: errors.New("record not found"), want: true},
		{label: "error containing not found", input: errors.New("entity not found"), want: true},
		// "用户不存在" is Chinese for "user does not exist".
		{label: "error containing 用户不存在", input: errors.New("用户不存在"), want: true},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			if got := isUserNotFoundError(tc.input); got != tc.want {
				t.Errorf("isUserNotFoundError(%v) = %v, want %v", tc.input, got, tc.want)
			}
		})
	}
}
// =============================================================================
// OAuth State Tests
// =============================================================================
func TestAuthService_CreateOAuthState(t *testing.T) {
l1Cache := cache.NewL1Cache()
l2Cache := cache.NewRedisCache(false)
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
svc := &AuthService{cache: cacheManager}
ctx := context.Background()
t.Run("CreateOAuthState success", func(t *testing.T) {
state, err := svc.CreateOAuthState(ctx, "http://localhost/callback")
if err != nil {
t.Fatalf("CreateOAuthState failed: %v", err)
}
if state == "" {
t.Error("Expected non-empty state")
}
})
t.Run("CreateOAuthState with empty return URL", func(t *testing.T) {
state, err := svc.CreateOAuthState(ctx, "")
if err != nil {
t.Fatalf("CreateOAuthState failed: %v", err)
}
if state == "" {
t.Error("Expected non-empty state")
}
})
}
func TestAuthService_CreateOAuthBindState(t *testing.T) {
l1Cache := cache.NewL1Cache()
l2Cache := cache.NewRedisCache(false)
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
svc := &AuthService{cache: cacheManager}
ctx := context.Background()
t.Run("CreateOAuthBindState success", func(t *testing.T) {
state, err := svc.CreateOAuthBindState(ctx, 1, "http://localhost/callback")
if err != nil {
t.Fatalf("CreateOAuthBindState failed: %v", err)
}
if state == "" {
t.Error("Expected non-empty state")
}
})
t.Run("CreateOAuthBindState with invalid user ID", func(t *testing.T) {
_, err := svc.CreateOAuthBindState(ctx, 0, "http://localhost/callback")
if err == nil {
t.Error("Expected error for invalid user ID")
}
})
}
// TestAuthService_ConsumeOAuthState verifies that an unknown state is rejected
// and that a state created by CreateOAuthState yields back its return URL.
func TestAuthService_ConsumeOAuthState(t *testing.T) {
	l1Cache := cache.NewL1Cache()
	l2Cache := cache.NewRedisCache(false)
	cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
	svc := &AuthService{cache: cacheManager}
	ctx := context.Background()

	t.Run("ConsumeOAuthState invalid state", func(t *testing.T) {
		_, err := svc.ConsumeOAuthState(ctx, "invalid_state")
		if err == nil {
			t.Error("Expected error for invalid state")
		}
	})

	t.Run("ConsumeOAuthState valid state", func(t *testing.T) {
		const returnURL = "http://localhost/callback"
		// Fail fast on setup errors so they are not misreported as consume failures.
		state, err := svc.CreateOAuthState(ctx, returnURL)
		if err != nil {
			t.Fatalf("CreateOAuthState failed: %v", err)
		}
		returnTo, err := svc.ConsumeOAuthState(ctx, state)
		if err != nil {
			t.Fatalf("ConsumeOAuthState failed: %v", err)
		}
		if returnTo != returnURL {
			t.Errorf("Expected return URL %q, got %q", returnURL, returnTo)
		}
	})
}
// TestAuthService_ConsumeOAuthStatePayload verifies that a state created by
// CreateOAuthBindState round-trips through ConsumeOAuthStatePayload with the
// bind purpose and the original user ID.
func TestAuthService_ConsumeOAuthStatePayload(t *testing.T) {
	l1Cache := cache.NewL1Cache()
	l2Cache := cache.NewRedisCache(false)
	cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
	svc := &AuthService{cache: cacheManager}
	ctx := context.Background()

	t.Run("ConsumeOAuthStatePayload with bind purpose", func(t *testing.T) {
		// Fail fast on setup errors so they are not misreported as consume failures.
		state, err := svc.CreateOAuthBindState(ctx, 123, "http://localhost/callback")
		if err != nil {
			t.Fatalf("CreateOAuthBindState failed: %v", err)
		}
		payload, err := svc.ConsumeOAuthStatePayload(ctx, state)
		if err != nil {
			t.Fatalf("ConsumeOAuthStatePayload failed: %v", err)
		}
		if payload.Purpose != OAuthStatePurposeBind {
			t.Errorf("Expected purpose 'bind', got %s", payload.Purpose)
		}
		if payload.UserID != 123 {
			t.Errorf("Expected user ID 123, got %d", payload.UserID)
		}
	})
}
func TestAuthService_CreateOAuthHandoff(t *testing.T) {
l1Cache := cache.NewL1Cache()
l2Cache := cache.NewRedisCache(false)
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
svc := &AuthService{cache: cacheManager}
ctx := context.Background()
t.Run("CreateOAuthHandoff success", func(t *testing.T) {
loginResp := &LoginResponse{
AccessToken: "test_token",
RefreshToken: "test_refresh",
}
code, err := svc.CreateOAuthHandoff(ctx, loginResp)
if err != nil {
t.Fatalf("CreateOAuthHandoff failed: %v", err)
}
if code == "" {
t.Error("Expected non-empty code")
}
})
t.Run("CreateOAuthHandoff with nil response", func(t *testing.T) {
_, err := svc.CreateOAuthHandoff(ctx, nil)
if err == nil {
t.Error("Expected error for nil response")
}
})
}
// TestAuthService_ConsumeOAuthHandoff verifies that an unknown handoff code is
// rejected and that a code created by CreateOAuthHandoff returns the stored
// tokens.
func TestAuthService_ConsumeOAuthHandoff(t *testing.T) {
	l1Cache := cache.NewL1Cache()
	l2Cache := cache.NewRedisCache(false)
	cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
	svc := &AuthService{cache: cacheManager}
	ctx := context.Background()

	t.Run("ConsumeOAuthHandoff invalid code", func(t *testing.T) {
		_, err := svc.ConsumeOAuthHandoff(ctx, "invalid_code")
		if err == nil {
			t.Error("Expected error for invalid code")
		}
	})

	t.Run("ConsumeOAuthHandoff valid code", func(t *testing.T) {
		loginResp := &LoginResponse{
			AccessToken:  "test_token",
			RefreshToken: "test_refresh",
		}
		// Fail fast on setup errors so they are not misreported as consume failures.
		code, err := svc.CreateOAuthHandoff(ctx, loginResp)
		if err != nil {
			t.Fatalf("CreateOAuthHandoff failed: %v", err)
		}
		resp, err := svc.ConsumeOAuthHandoff(ctx, code)
		if err != nil {
			t.Fatalf("ConsumeOAuthHandoff failed: %v", err)
		}
		if resp.AccessToken != "test_token" {
			t.Errorf("Expected access token 'test_token', got %s", resp.AccessToken)
		}
		// Both tokens were handed off, so both must come back.
		if resp.RefreshToken != "test_refresh" {
			t.Errorf("Expected refresh token 'test_refresh', got %s", resp.RefreshToken)
		}
	})
}
// TestAuthService_OAuthStateNilService asserts that every OAuth state/handoff
// entry point returns an error, rather than succeeding, when invoked on a nil
// *AuthService.
func TestAuthService_OAuthStateNilService(t *testing.T) {
	var nilSvc *AuthService
	ctx := context.Background()

	calls := []struct {
		name   string
		invoke func() error
	}{
		{"CreateOAuthState nil service", func() error {
			_, err := nilSvc.CreateOAuthState(ctx, "http://localhost")
			return err
		}},
		{"ConsumeOAuthState nil service", func() error {
			_, err := nilSvc.ConsumeOAuthState(ctx, "state")
			return err
		}},
		{"CreateOAuthHandoff nil service", func() error {
			_, err := nilSvc.CreateOAuthHandoff(ctx, &LoginResponse{})
			return err
		}},
		{"ConsumeOAuthHandoff nil service", func() error {
			_, err := nilSvc.ConsumeOAuthHandoff(ctx, "code")
			return err
		}},
	}

	for _, c := range calls {
		t.Run(c.name, func(t *testing.T) {
			if c.invoke() == nil {
				t.Error("Expected error for nil service")
			}
		})
	}
}
// TestGenerateOAuthEphemeralCode checks that generateOAuthEphemeralCode
// returns a non-empty code and that two consecutive calls return different
// codes.
func TestGenerateOAuthEphemeralCode(t *testing.T) {
	code, err := generateOAuthEphemeralCode()
	if err != nil {
		t.Fatalf("generateOAuthEphemeralCode failed: %v", err)
	}
	if code == "" {
		t.Error("Expected non-empty code")
	}

	// The second call's error must be checked too. Otherwise a failure there
	// would leave code2 empty and the inequality check below would pass vacuously.
	code2, err := generateOAuthEphemeralCode()
	if err != nil {
		t.Fatalf("generateOAuthEphemeralCode (second call) failed: %v", err)
	}
	if code == code2 {
		t.Error("Expected different codes on each call")
	}
}
// =============================================================================
// Password Policy Tests
// =============================================================================
// TestAuthService_SetPasswordPolicy verifies that SetPasswordPolicy stores an
// explicit policy and marks it as set, and that an empty policy ends up with
// a minimum length of 8.
func TestAuthService_SetPasswordPolicy(t *testing.T) {
	cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
	svc := &AuthService{cache: cacheManager}

	t.Run("SetPasswordPolicy success", func(t *testing.T) {
		svc.SetPasswordPolicy(security.PasswordPolicy{
			MinLength:      12,
			RequireSpecial: true,
			RequireNumber:  true,
		})
		if !svc.passwordPolicySet {
			t.Error("Expected passwordPolicySet to be true")
		}
		if got := svc.passwordPolicy.MinLength; got != 12 {
			t.Errorf("Expected MinLength 12, got %d", got)
		}
	})

	t.Run("SetPasswordPolicy with defaults", func(t *testing.T) {
		fresh := &AuthService{cache: cacheManager}
		// A zero-value policy is expected to be normalised to MinLength 8.
		fresh.SetPasswordPolicy(security.PasswordPolicy{})
		if got := fresh.passwordPolicy.MinLength; got != 8 {
			t.Errorf("Expected normalized MinLength 8, got %d", got)
		}
	})
}
// =============================================================================
// Social Account Helper Tests
// =============================================================================
// TestFindSocialAccountByProvider checks provider lookup over a list of social
// accounts. Besides nil vs. non-nil, it verifies that the account returned is
// the one whose provider matched (identified by its OpenID), so a lookup that
// returned the wrong account would be caught.
func TestFindSocialAccountByProvider(t *testing.T) {
	tests := []struct {
		name     string
		accounts []*domain.SocialAccount
		provider string
		// wantOpenID is the OpenID of the expected match; "" means no match (nil).
		wantOpenID string
	}{
		{
			name:       "nil accounts",
			accounts:   nil,
			provider:   "github",
			wantOpenID: "",
		},
		{
			name:       "empty accounts",
			accounts:   []*domain.SocialAccount{},
			provider:   "github",
			wantOpenID: "",
		},
		{
			name: "found matching provider",
			accounts: []*domain.SocialAccount{
				{Provider: "github", OpenID: "123"},
				{Provider: "google", OpenID: "456"},
			},
			provider:   "github",
			wantOpenID: "123",
		},
		{
			name: "case insensitive match",
			accounts: []*domain.SocialAccount{
				{Provider: "GitHub", OpenID: "123"},
			},
			provider:   "github",
			wantOpenID: "123",
		},
		{
			name: "provider not found",
			accounts: []*domain.SocialAccount{
				{Provider: "google", OpenID: "456"},
			},
			provider:   "github",
			wantOpenID: "",
		},
		{
			name: "nil account in list",
			accounts: []*domain.SocialAccount{
				nil,
				{Provider: "github", OpenID: "123"},
			},
			provider:   "github",
			wantOpenID: "123",
		},
		{
			name: "empty provider",
			accounts: []*domain.SocialAccount{
				{Provider: "github", OpenID: "123"},
			},
			provider:   "",
			wantOpenID: "",
		},
		{
			name: "provider with spaces",
			accounts: []*domain.SocialAccount{
				{Provider: " github ", OpenID: "123"},
			},
			provider:   "github",
			wantOpenID: "123",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := findSocialAccountByProvider(tt.accounts, tt.provider)
			if tt.wantOpenID == "" {
				if result != nil {
					t.Errorf("findSocialAccountByProvider(%q) returned account with OpenID %q, want nil", tt.provider, result.OpenID)
				}
				return
			}
			if result == nil {
				t.Fatalf("findSocialAccountByProvider(%q) = nil, want account with OpenID %q", tt.provider, tt.wantOpenID)
			}
			if result.OpenID != tt.wantOpenID {
				t.Errorf("findSocialAccountByProvider(%q) returned OpenID %q, want %q", tt.provider, result.OpenID, tt.wantOpenID)
			}
		})
	}
}
// =============================================================================
// Available Login Method Count Tests
// =============================================================================
// TestAuthService_AvailableLoginMethodCount checks how many login methods
// availableLoginMethodCount reports for a user. The cases cover a non-blank
// password, an email or phone when the matching code service is configured,
// active social accounts, and an optional excluded social provider.
func TestAuthService_AvailableLoginMethodCount(t *testing.T) {
	cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
	// strPtr returns a pointer to a fresh copy of s, so no two users share storage.
	strPtr := func(s string) *string { return &s }

	cases := []struct {
		name     string
		svc      *AuthService
		user     *domain.User
		accounts []*domain.SocialAccount
		exclude  string
		want     int
		failFmt  string
	}{
		{
			name:    "nil user",
			svc:     &AuthService{cache: cacheManager},
			want:    0,
			failFmt: "Expected 0 for nil user, got %d",
		},
		{
			name:    "password only",
			svc:     &AuthService{cache: cacheManager},
			user:    &domain.User{Password: "hashed_password"},
			want:    1,
			failFmt: "Expected 1 for password only, got %d",
		},
		{
			name:    "password and email with email service",
			svc:     &AuthService{cache: cacheManager, emailCodeSvc: &EmailCodeService{}},
			user:    &domain.User{Password: "hashed_password", Email: strPtr("test@example.com")},
			want:    2,
			failFmt: "Expected 2 for password and email, got %d",
		},
		{
			name:    "password and phone with sms service",
			svc:     &AuthService{cache: cacheManager, smsCodeSvc: &SMSCodeService{}},
			user:    &domain.User{Password: "hashed_password", Phone: strPtr("13800138000")},
			want:    2,
			failFmt: "Expected 2 for password and phone, got %d",
		},
		{
			name: "all methods",
			svc: &AuthService{
				cache:        cacheManager,
				emailCodeSvc: &EmailCodeService{},
				smsCodeSvc:   &SMSCodeService{},
			},
			user: &domain.User{
				Password: "hashed_password",
				Email:    strPtr("test@example.com"),
				Phone:    strPtr("13800138000"),
			},
			accounts: []*domain.SocialAccount{
				{Provider: "github", Status: domain.SocialAccountStatusActive},
			},
			want:    4,
			failFmt: "Expected 4 for all methods, got %d",
		},
		{
			// Expected: password + email + google; github is excluded.
			name: "exclude social provider",
			svc:  &AuthService{cache: cacheManager, emailCodeSvc: &EmailCodeService{}},
			user: &domain.User{Password: "hashed_password", Email: strPtr("test@example.com")},
			accounts: []*domain.SocialAccount{
				{Provider: "github", Status: domain.SocialAccountStatusActive},
				{Provider: "google", Status: domain.SocialAccountStatusActive},
			},
			exclude: "github",
			want:    3,
			failFmt: "Expected 3 with github excluded, got %d",
		},
		{
			// Expected: password + github; the zero-status and nil entries are skipped.
			name: "inactive social accounts not counted",
			svc:  &AuthService{cache: cacheManager},
			user: &domain.User{Password: "hashed_password"},
			accounts: []*domain.SocialAccount{
				{Provider: "github", Status: domain.SocialAccountStatusActive},
				{Provider: "google", Status: 0},
				nil,
			},
			want:    2,
			failFmt: "Expected 2 with inactive filtered, got %d",
		},
		{
			name:    "empty password not counted",
			svc:     &AuthService{cache: cacheManager},
			user:    &domain.User{Password: " "},
			want:    0,
			failFmt: "Expected 0 for empty password, got %d",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := tc.svc.availableLoginMethodCount(tc.user, tc.accounts, tc.exclude); got != tc.want {
				t.Errorf(tc.failFmt, got)
			}
		})
	}
}
// =============================================================================
// Generate Unique Username Tests
// =============================================================================
// TestGenerateUniqueUsername exercises generateUniqueUsername without a user
// repository (nil service or zero-value service). The cases cover sanitising
// the base name, bounding its length, and falling back to "user" for an
// empty base.
func TestGenerateUniqueUsername(t *testing.T) {
	ctx := context.Background()
	// mustGenerate calls generateUniqueUsername and aborts the subtest on error.
	mustGenerate := func(t *testing.T, svc *AuthService, base string) string {
		t.Helper()
		username, err := svc.generateUniqueUsername(ctx, base)
		if err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
		return username
	}

	t.Run("nil service returns sanitized username", func(t *testing.T) {
		var nilSvc *AuthService
		if got := mustGenerate(t, nilSvc, "Test User"); got != "test_user" {
			t.Errorf("Expected 'test_user', got %q", got)
		}
	})

	t.Run("service with nil userRepo returns sanitized username", func(t *testing.T) {
		if got := mustGenerate(t, &AuthService{}, "John Doe"); got != "john_doe" {
			t.Errorf("Expected 'john_doe', got %q", got)
		}
	})

	t.Run("long username is truncated", func(t *testing.T) {
		const longName = "this_is_a_very_long_username_that_should_be_truncated_to_forty_characters"
		if n := len(mustGenerate(t, &AuthService{}, longName)); n > 50 {
			t.Errorf("Username should be max 50 chars, got %d", n)
		}
	})

	t.Run("empty base returns user", func(t *testing.T) {
		if got := mustGenerate(t, &AuthService{}, ""); got != "user" {
			t.Errorf("Expected 'user', got %q", got)
		}
	})
}
// =============================================================================
// Set Login Log Repository Tests
// =============================================================================
// TestAuthService_SetLoginLogRepository checks that a nil repository is
// accepted; the test fails only if the setter panics.
func TestAuthService_SetLoginLogRepository(t *testing.T) {
	(&AuthService{}).SetLoginLogRepository(nil)
}
// =============================================================================
// Set Anomaly Detector Tests
// =============================================================================
// TestAuthService_SetAnomalyDetector checks that a nil detector is accepted;
// the test fails only if the setter panics.
func TestAuthService_SetAnomalyDetector(t *testing.T) {
	(&AuthService{}).SetAnomalyDetector(nil)
}
// =============================================================================
// Set Device Service Tests
// =============================================================================
// TestAuthService_SetDeviceService checks that a nil device service is
// accepted; the test fails only if the setter panics.
func TestAuthService_SetDeviceService(t *testing.T) {
	(&AuthService{}).SetDeviceService(nil)
}
// =============================================================================
// Set SMS Code Service Tests
// =============================================================================
// TestAuthService_SetSMSCodeService checks that a nil SMS code service is
// accepted; the test fails only if the setter panics.
func TestAuthService_SetSMSCodeService(t *testing.T) {
	(&AuthService{}).SetSMSCodeService(nil)
}
// =============================================================================
// Available Login Method Count After Contact Removal Tests
// =============================================================================
// TestAuthService_AvailableLoginMethodCountAfterContactRemoval checks the login
// method count that would remain if the user's email and/or phone were
// removed. The remove flags are passed as the third and fourth arguments.
func TestAuthService_AvailableLoginMethodCountAfterContactRemoval(t *testing.T) {
	cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
	// strPtr returns a pointer to a fresh copy of s, so no two users share storage.
	strPtr := func(s string) *string { return &s }

	cases := []struct {
		name        string
		svc         *AuthService
		user        *domain.User
		accounts    []*domain.SocialAccount
		removeEmail bool
		removePhone bool
		want        int
		failFmt     string
	}{
		{
			name:    "nil user",
			svc:     &AuthService{cache: cacheManager},
			want:    0,
			failFmt: "Expected 0 for nil user, got %d",
		},
		{
			name:    "password only no removal",
			svc:     &AuthService{cache: cacheManager},
			user:    &domain.User{Password: "hashed_password"},
			want:    1,
			failFmt: "Expected 1 for password only, got %d",
		},
		{
			name:    "password and email with email service",
			svc:     &AuthService{cache: cacheManager, emailCodeSvc: &EmailCodeService{}},
			user:    &domain.User{Password: "hashed_password", Email: strPtr("test@example.com")},
			want:    2,
			failFmt: "Expected 2 for password and email, got %d",
		},
		{
			name:        "remove email",
			svc:         &AuthService{cache: cacheManager, emailCodeSvc: &EmailCodeService{}},
			user:        &domain.User{Password: "hashed_password", Email: strPtr("test@example.com")},
			removeEmail: true,
			want:        1,
			failFmt:     "Expected 1 after email removal, got %d",
		},
		{
			name:        "remove phone",
			svc:         &AuthService{cache: cacheManager, smsCodeSvc: &SMSCodeService{}},
			user:        &domain.User{Password: "hashed_password", Phone: strPtr("13800138000")},
			removePhone: true,
			want:        1,
			failFmt:     "Expected 1 after phone removal, got %d",
		},
		{
			name: "social accounts counted",
			svc:  &AuthService{cache: cacheManager},
			user: &domain.User{Password: "hashed_password"},
			accounts: []*domain.SocialAccount{
				{Provider: "github", Status: domain.SocialAccountStatusActive},
				{Provider: "google", Status: domain.SocialAccountStatusActive},
			},
			want:    3,
			failFmt: "Expected 3 with social accounts, got %d",
		},
		{
			// The zero-status (inactive) account and the nil entry must be ignored.
			name: "inactive social accounts not counted",
			svc:  &AuthService{cache: cacheManager},
			user: &domain.User{Password: "hashed_password"},
			accounts: []*domain.SocialAccount{
				{Provider: "github", Status: domain.SocialAccountStatusActive},
				{Provider: "google", Status: 0},
				nil,
			},
			want:    2,
			failFmt: "Expected 2 with inactive filtered, got %d",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := tc.svc.availableLoginMethodCountAfterContactRemoval(tc.user, tc.accounts, tc.removeEmail, tc.removePhone)
			if got != tc.want {
				t.Errorf(tc.failFmt, got)
			}
		})
	}
}
// =============================================================================
// Register OAuth Provider Tests
// =============================================================================
// TestAuthService_RegisterOAuthProvider checks that RegisterOAuthProvider
// tolerates a nil config and a service with no oauthManager. Each subtest
// fails only if the call panics.
func TestAuthService_RegisterOAuthProvider(t *testing.T) {
	t.Run("nil config does nothing", func(t *testing.T) {
		(&AuthService{}).RegisterOAuthProvider("github", nil)
	})

	t.Run("nil oauth manager", func(t *testing.T) {
		// A zero-value AuthService has a nil oauthManager.
		(&AuthService{}).RegisterOAuthProvider("github", &auth.OAuthConfig{ClientID: "test"})
	})
}
// =============================================================================
// Best Effort Register Device Public Tests
// =============================================================================
// TestAuthService_BestEffortRegisterDevicePublic checks that the best-effort
// device registration is a safe no-op for a nil service and for a service
// without a device service. Each subtest fails only if the call panics.
func TestAuthService_BestEffortRegisterDevicePublic(t *testing.T) {
	ctx := context.Background()

	t.Run("nil service does not panic", func(t *testing.T) {
		var nilSvc *AuthService
		nilSvc.BestEffortRegisterDevicePublic(ctx, 1, nil)
	})

	t.Run("nil device service does not panic", func(t *testing.T) {
		(&AuthService{}).BestEffortRegisterDevicePublic(ctx, 1, &LoginRequest{})
	})
}
// =============================================================================
// Int Value and Int64 Value Tests
// =============================================================================
// TestIntValue checks intValue's conversion of dynamically typed numbers to
// int. int, int64 and float64 inputs convert, and a fractional float64 is
// truncated (99.5 -> 99). Strings and nil are rejected with ok == false.
func TestIntValue(t *testing.T) {
	// intBits is the width of int on the build target (32 or 64); this is the
	// same idiom strconv uses for IntSize. The 10-digit case does not fit in a
	// 32-bit int. Writing its expected value as an int constant would stop the
	// whole file from compiling on 32-bit targets, so it is derived from a
	// variable and skipped there instead.
	const intBits = 32 << (^uint(0) >> 63)
	large := int64(9999999999)

	tests := []struct {
		name       string
		input      any
		expected   int
		wantOk     bool
		needs64Bit bool
	}{
		{name: "int value", input: 42, expected: 42, wantOk: true},
		{name: "int64 value", input: int64(100), expected: 100, wantOk: true},
		{name: "float64 value", input: float64(99.0), expected: 99, wantOk: true},
		{name: "float64 with decimal", input: float64(99.5), expected: 99, wantOk: true},
		{name: "string value", input: "42", expected: 0, wantOk: false},
		{name: "nil value", input: nil, expected: 0, wantOk: false},
		{name: "negative int", input: -5, expected: -5, wantOk: true},
		{name: "zero value", input: 0, expected: 0, wantOk: true},
		{name: "large int64", input: large, expected: int(large), wantOk: true, needs64Bit: true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.needs64Bit && intBits < 64 {
				t.Skip("int is narrower than 64 bits on this platform")
			}
			result, ok := intValue(tt.input)
			if result != tt.expected || ok != tt.wantOk {
				t.Errorf("intValue(%v) = (%d, %v), want (%d, %v)", tt.input, result, ok, tt.expected, tt.wantOk)
			}
		})
	}
}
// TestInt64Value checks int64Value's conversion of dynamically typed numbers
// to int64. int, int64 and float64 inputs convert, and a fractional float64
// is truncated (99.5 -> 99). Strings and nil are rejected with ok == false.
func TestInt64Value(t *testing.T) {
	cases := []struct {
		label string
		in    any
		want  int64
		ok    bool
	}{
		{label: "int value", in: 42, want: 42, ok: true},
		{label: "int64 value", in: int64(100), want: 100, ok: true},
		{label: "float64 value", in: float64(99.0), want: 99, ok: true},
		{label: "float64 with decimal", in: float64(99.5), want: 99, ok: true},
		{label: "string value", in: "42", want: 0, ok: false},
		{label: "nil value", in: nil, want: 0, ok: false},
		{label: "negative int64", in: int64(-5), want: -5, ok: true},
		{label: "zero value", in: 0, want: 0, ok: true},
		{label: "large int64", in: int64(9999999999), want: 9999999999, ok: true},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got, gotOK := int64Value(tc.in)
			if got == tc.want && gotOK == tc.ok {
				return
			}
			t.Errorf("int64Value(%v) = (%d, %v), want (%d, %v)", tc.in, got, gotOK, tc.want, tc.ok)
		})
	}
}
// =============================================================================
// Best Effort Update Last Login Tests
// =============================================================================
// TestBestEffortUpdateLastLogin checks that the best-effort last-login update
// tolerates a nil receiver; the subtest fails only if the call panics.
func TestBestEffortUpdateLastLogin(t *testing.T) {
	t.Run("nil service does not panic", func(t *testing.T) {
		(*AuthService)(nil).bestEffortUpdateLastLogin(context.Background(), 1, "127.0.0.1", "password")
	})
}
// =============================================================================
// Best Effort Assign Default Roles Tests
// =============================================================================
// TestBestEffortAssignDefaultRoles checks that default-role assignment is a
// safe no-op for a nil service and for a service with no repositories. Each
// subtest fails only if the call panics.
func TestBestEffortAssignDefaultRoles(t *testing.T) {
	for _, tc := range []struct {
		name string
		svc  *AuthService
	}{
		{"nil service does not panic", nil},
		{"service without repos does not panic", &AuthService{}},
	} {
		t.Run(tc.name, func(t *testing.T) {
			tc.svc.bestEffortAssignDefaultRoles(context.Background(), 1, "register")
		})
	}
}
// =============================================================================
// Create OAuth State Payload Tests
// =============================================================================
// TestCreateOAuthStatePayload covers createOAuthStatePayload. It expects
// errors for a nil service, a service without a cache, and a nil payload, and
// a non-empty state when a cache is configured (with or without an explicit
// purpose).
func TestCreateOAuthStatePayload(t *testing.T) {
	ctx := context.Background()
	// newCachedService returns a service backed by a fresh cache manager built
	// from an L1 cache and cache.NewRedisCache(false).
	newCachedService := func() *AuthService {
		return &AuthService{
			cache: cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false)),
		}
	}

	t.Run("nil service returns error", func(t *testing.T) {
		var nilSvc *AuthService
		if _, err := nilSvc.createOAuthStatePayload(ctx, &OAuthStatePayload{Purpose: OAuthStatePurposeLogin}); err == nil {
			t.Error("Expected error for nil service")
		}
	})

	t.Run("service without cache returns error", func(t *testing.T) {
		if _, err := (&AuthService{}).createOAuthStatePayload(ctx, &OAuthStatePayload{Purpose: OAuthStatePurposeLogin}); err == nil {
			t.Error("Expected error when cache not configured")
		}
	})

	t.Run("nil payload returns error", func(t *testing.T) {
		if _, err := newCachedService().createOAuthStatePayload(ctx, nil); err == nil {
			t.Error("Expected error for nil payload")
		}
	})

	t.Run("create state payload with cache", func(t *testing.T) {
		state, err := newCachedService().createOAuthStatePayload(ctx, &OAuthStatePayload{
			Purpose:  OAuthStatePurposeLogin,
			ReturnTo: "http://localhost/callback",
		})
		if err != nil {
			t.Fatalf("createOAuthStatePayload failed: %v", err)
		}
		if state == "" {
			t.Error("Expected non-empty state")
		}
	})

	t.Run("create state payload with default purpose", func(t *testing.T) {
		// Purpose is deliberately left unset.
		state, err := newCachedService().createOAuthStatePayload(ctx, &OAuthStatePayload{
			ReturnTo: "http://localhost/callback",
		})
		if err != nil {
			t.Fatalf("createOAuthStatePayload failed: %v", err)
		}
		if state == "" {
			t.Error("Expected non-empty state")
		}
	})
}
// =============================================================================
// Verify Phone Registration Tests
// =============================================================================
// TestVerifyPhoneRegistration covers verifyPhoneRegistration. It expects no
// error for an empty phone or a nil request. It expects an error when a phone
// is given but SMS is not configured or the phone code is empty.
func TestVerifyPhoneRegistration(t *testing.T) {
	ctx := context.Background()

	t.Run("nil service returns nil for empty phone", func(t *testing.T) {
		var nilSvc *AuthService
		if err := nilSvc.verifyPhoneRegistration(ctx, &RegisterRequest{Phone: ""}); err != nil {
			t.Errorf("Expected nil error for empty phone, got: %v", err)
		}
	})

	t.Run("nil request returns nil", func(t *testing.T) {
		if err := (&AuthService{}).verifyPhoneRegistration(ctx, nil); err != nil {
			t.Errorf("Expected nil error for nil request, got: %v", err)
		}
	})

	t.Run("service without SMS returns error", func(t *testing.T) {
		req := &RegisterRequest{Phone: "13800138000", PhoneCode: "123456"}
		if err := (&AuthService{}).verifyPhoneRegistration(ctx, req); err == nil {
			t.Error("Expected error when SMS service not configured")
		}
	})

	t.Run("empty phone code returns error", func(t *testing.T) {
		svc := &AuthService{smsCodeSvc: &SMSCodeService{}}
		req := &RegisterRequest{Phone: "13800138000", PhoneCode: ""}
		if err := svc.verifyPhoneRegistration(ctx, req); err == nil {
			t.Error("Expected error when phone code is empty")
		}
	})
}
// =============================================================================
// Consume OAuth State Payload Tests
// =============================================================================
// TestConsumeOAuthStatePayload checks that ConsumeOAuthStatePayload fails when
// the service is nil or has no cache configured.
func TestConsumeOAuthStatePayload(t *testing.T) {
	cases := []struct {
		name    string
		svc     *AuthService
		failMsg string
	}{
		{"nil service returns error", nil, "Expected error for nil service"},
		{"service without cache returns error", &AuthService{}, "Expected error when cache not configured"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := tc.svc.ConsumeOAuthStatePayload(context.Background(), "state123"); err == nil {
				t.Error(tc.failMsg)
			}
		})
	}
}
// =============================================================================
// Consume OAuth Handoff Tests
// =============================================================================
// TestConsumeOAuthHandoff checks that ConsumeOAuthHandoff fails when the
// service is nil or has no cache configured.
func TestConsumeOAuthHandoff(t *testing.T) {
	cases := []struct {
		name    string
		svc     *AuthService
		failMsg string
	}{
		{"nil service returns error", nil, "Expected error for nil service"},
		{"service without cache returns error", &AuthService{}, "Expected error when cache not configured"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := tc.svc.ConsumeOAuthHandoff(context.Background(), "code123"); err == nil {
				t.Error(tc.failMsg)
			}
		})
	}
}
// =============================================================================
// Consume OAuth Handoff With Cache Tests
// =============================================================================
// TestConsumeOAuthHandoff_WithCache seeds the cache directly under the
// "oauth_handoff:" key prefix. It checks that ConsumeOAuthHandoff reads back
// both a stored *LoginResponse and a stored LoginResponse value, and rejects
// an unknown code.
func TestConsumeOAuthHandoff_WithCache(t *testing.T) {
	cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
	svc := &AuthService{cache: cacheManager}
	ctx := context.Background()

	t.Run("consume non-existent handoff", func(t *testing.T) {
		if _, err := svc.ConsumeOAuthHandoff(ctx, "nonexistent_code"); err == nil {
			t.Error("Expected error for non-existent handoff")
		}
	})

	t.Run("consume handoff with pointer response", func(t *testing.T) {
		stored := &LoginResponse{AccessToken: "test_access_token", RefreshToken: "test_refresh_token"}
		cacheManager.Set(ctx, "oauth_handoff:test_code_1", stored, time.Minute, time.Minute)

		got, err := svc.ConsumeOAuthHandoff(ctx, "test_code_1")
		if err != nil {
			t.Fatalf("ConsumeOAuthHandoff failed: %v", err)
		}
		if got.AccessToken != "test_access_token" {
			t.Errorf("Expected access token, got %s", got.AccessToken)
		}
	})

	t.Run("consume handoff with value response", func(t *testing.T) {
		stored := LoginResponse{AccessToken: "value_access_token", RefreshToken: "value_refresh_token"}
		cacheManager.Set(ctx, "oauth_handoff:test_code_2", stored, time.Minute, time.Minute)

		got, err := svc.ConsumeOAuthHandoff(ctx, "test_code_2")
		if err != nil {
			t.Fatalf("ConsumeOAuthHandoff failed: %v", err)
		}
		if got.AccessToken != "value_access_token" {
			t.Errorf("Expected access token, got %s", got.AccessToken)
		}
	})
}
// =============================================================================
// Consume OAuth State Payload With Cache Tests
// =============================================================================
// TestConsumeOAuthStatePayload_WithCache seeds the cache directly under the
// "oauth_state:" key prefix. It checks that ConsumeOAuthStatePayload reads
// back both a stored *OAuthStatePayload and a stored OAuthStatePayload value,
// and rejects an unknown state.
func TestConsumeOAuthStatePayload_WithCache(t *testing.T) {
	cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
	svc := &AuthService{cache: cacheManager}
	ctx := context.Background()
	const ttl = 10 * time.Minute

	t.Run("consume non-existent state", func(t *testing.T) {
		if _, err := svc.ConsumeOAuthStatePayload(ctx, "nonexistent_state"); err == nil {
			t.Error("Expected error for non-existent state")
		}
	})

	t.Run("consume state with pointer payload", func(t *testing.T) {
		stored := &OAuthStatePayload{Purpose: OAuthStatePurposeLogin, ReturnTo: "http://localhost/callback"}
		cacheManager.Set(ctx, "oauth_state:test_state_1", stored, ttl, ttl)

		got, err := svc.ConsumeOAuthStatePayload(ctx, "test_state_1")
		if err != nil {
			t.Fatalf("ConsumeOAuthStatePayload failed: %v", err)
		}
		if got.Purpose != OAuthStatePurposeLogin {
			t.Errorf("Expected purpose %s, got %s", OAuthStatePurposeLogin, got.Purpose)
		}
	})

	t.Run("consume state with value payload", func(t *testing.T) {
		stored := OAuthStatePayload{
			Purpose:  OAuthStatePurposeBind,
			ReturnTo: "http://localhost/bind",
			UserID:   123,
		}
		cacheManager.Set(ctx, "oauth_state:test_state_2", stored, ttl, ttl)

		got, err := svc.ConsumeOAuthStatePayload(ctx, "test_state_2")
		if err != nil {
			t.Fatalf("ConsumeOAuthStatePayload failed: %v", err)
		}
		if got.UserID != 123 {
			t.Errorf("Expected UserID 123, got %d", got.UserID)
		}
	})
}