fix: P1-02 OAuth context propagation and P1-16 AuthProvider double-check
P1-02: OAuth ExchangeCode and GetUserInfo now accept context parameter
to properly propagate request context to HTTP calls
P1-16: AuthProvider isAuthenticated now uses single source of truth
(effectiveUser !== null) instead of double-checking both
React state and module-level function
This commit is contained in:
@@ -186,7 +186,7 @@ export function AuthProvider({ children }: AuthProviderProps) {
|
|||||||
user: effectiveUser,
|
user: effectiveUser,
|
||||||
roles: effectiveRoles,
|
roles: effectiveRoles,
|
||||||
isAdmin,
|
isAdmin,
|
||||||
isAuthenticated: effectiveUser !== null && isAuthenticated(),
|
isAuthenticated: effectiveUser !== null,
|
||||||
isLoading,
|
isLoading,
|
||||||
onLoginSuccess,
|
onLoginSuccess,
|
||||||
logout,
|
logout,
|
||||||
|
|||||||
@@ -63,10 +63,10 @@ type OAuthManager interface {
|
|||||||
GetAuthURL(provider OAuthProvider, state string) (string, error)
|
GetAuthURL(provider OAuthProvider, state string) (string, error)
|
||||||
|
|
||||||
// ExchangeCode 换取访问令牌
|
// ExchangeCode 换取访问令牌
|
||||||
ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error)
|
ExchangeCode(ctx context.Context, provider OAuthProvider, code string) (*OAuthToken, error)
|
||||||
|
|
||||||
// GetUserInfo 获取用户信息
|
// GetUserInfo 获取用户信息
|
||||||
GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error)
|
GetUserInfo(ctx context.Context, provider OAuthProvider, token *OAuthToken) (*OAuthUser, error)
|
||||||
|
|
||||||
// ValidateToken 验证令牌
|
// ValidateToken 验证令牌
|
||||||
ValidateToken(token string) (bool, error)
|
ValidateToken(token string) (bool, error)
|
||||||
@@ -203,14 +203,12 @@ func (m *DefaultOAuthManager) GetAuthURL(provider OAuthProvider, state string) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeCode 换取访问令牌(使用真实 provider 实现)
|
// ExchangeCode 换取访问令牌(使用真实 provider 实现)
|
||||||
func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error) {
|
func (m *DefaultOAuthManager) ExchangeCode(ctx context.Context, provider OAuthProvider, code string) (*OAuthToken, error) {
|
||||||
entry, ok := m.entries[provider]
|
entry, ok := m.entries[provider]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrOAuthProviderNotSupported
|
return nil, ErrOAuthProviderNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
switch provider {
|
switch provider {
|
||||||
case OAuthProviderGoogle:
|
case OAuthProviderGoogle:
|
||||||
if entry.google != nil {
|
if entry.google != nil {
|
||||||
@@ -302,14 +300,12 @@ func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserInfo 获取用户信息(使用真实 provider 实现)
|
// GetUserInfo 获取用户信息(使用真实 provider 实现)
|
||||||
func (m *DefaultOAuthManager) GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) {
|
func (m *DefaultOAuthManager) GetUserInfo(ctx context.Context, provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) {
|
||||||
entry, ok := m.entries[provider]
|
entry, ok := m.entries[provider]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrOAuthProviderNotSupported
|
return nil, ErrOAuthProviderNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
switch provider {
|
switch provider {
|
||||||
case OAuthProviderGoogle:
|
case OAuthProviderGoogle:
|
||||||
if entry.google != nil {
|
if entry.google != nil {
|
||||||
@@ -448,8 +444,9 @@ func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
|
|||||||
}
|
}
|
||||||
// 尝试任一 provider 的 userinfo 端点验证
|
// 尝试任一 provider 的 userinfo 端点验证
|
||||||
tokenObj := &OAuthToken{AccessToken: token}
|
tokenObj := &OAuthToken{AccessToken: token}
|
||||||
|
ctx := context.Background()
|
||||||
for _, p := range providers {
|
for _, p := range providers {
|
||||||
if _, err := m.GetUserInfo(p.Provider, tokenObj); err == nil {
|
if _, err := m.GetUserInfo(ctx, p.Provider, tokenObj); err == nil {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -469,7 +466,8 @@ func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider,
|
|||||||
|
|
||||||
// 通过 provider 的 userinfo 端点验证 token
|
// 通过 provider 的 userinfo 端点验证 token
|
||||||
tokenObj := &OAuthToken{AccessToken: token}
|
tokenObj := &OAuthToken{AccessToken: token}
|
||||||
_, err := m.GetUserInfo(provider, tokenObj)
|
ctx := context.Background()
|
||||||
|
_, err := m.GetUserInfo(ctx, provider, tokenObj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ func TestDefaultOAuthManager_ExchangeCode(t *testing.T) {
|
|||||||
m := NewOAuthManager()
|
m := NewOAuthManager()
|
||||||
|
|
||||||
// Test non-existent provider
|
// Test non-existent provider
|
||||||
_, err := m.ExchangeCode(OAuthProviderGoogle, "test-code")
|
_, err := m.ExchangeCode(context.Background(), OAuthProviderGoogle, "test-code")
|
||||||
if err != ErrOAuthProviderNotSupported {
|
if err != ErrOAuthProviderNotSupported {
|
||||||
t.Errorf("Expected ErrOAuthProviderNotSupported, got %v", err)
|
t.Errorf("Expected ErrOAuthProviderNotSupported, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -148,7 +148,7 @@ func TestDefaultOAuthManager_GetUserInfo(t *testing.T) {
|
|||||||
|
|
||||||
// Test non-existent provider
|
// Test non-existent provider
|
||||||
token := &OAuthToken{AccessToken: "test-token"}
|
token := &OAuthToken{AccessToken: "test-token"}
|
||||||
_, err := m.GetUserInfo(OAuthProviderGoogle, token)
|
_, err := m.GetUserInfo(context.Background(), OAuthProviderGoogle, token)
|
||||||
if err != ErrOAuthProviderNotSupported {
|
if err != ErrOAuthProviderNotSupported {
|
||||||
t.Errorf("Expected ErrOAuthProviderNotSupported, got %v", err)
|
t.Errorf("Expected ErrOAuthProviderNotSupported, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -546,7 +546,7 @@ func TestOAuthManager_ExchangeCode_Errors(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// ExchangeCode should attempt HTTP call and fail
|
// ExchangeCode should attempt HTTP call and fail
|
||||||
_, err := m.ExchangeCode(OAuthProviderGoogle, "test-code")
|
_, err := m.ExchangeCode(context.Background(), OAuthProviderGoogle, "test-code")
|
||||||
// We expect an error because there's no mock server
|
// We expect an error because there's no mock server
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Log("ExchangeCode() unexpectedly succeeded - real network may be available")
|
t.Log("ExchangeCode() unexpectedly succeeded - real network may be available")
|
||||||
@@ -565,7 +565,7 @@ func TestOAuthManager_GetUserInfo_Errors(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
token := &OAuthToken{AccessToken: "test-token"}
|
token := &OAuthToken{AccessToken: "test-token"}
|
||||||
_, err := m.GetUserInfo(OAuthProviderGoogle, token)
|
_, err := m.GetUserInfo(context.Background(), OAuthProviderGoogle, token)
|
||||||
// We expect an error because there's no mock server
|
// We expect an error because there's no mock server
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Log("GetUserInfo() unexpectedly succeeded - real network may be available")
|
t.Log("GetUserInfo() unexpectedly succeeded - real network may be available")
|
||||||
|
|||||||
@@ -949,12 +949,12 @@ func (s *AuthService) OAuthCallback(ctx context.Context, provider, code string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
|
oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
|
||||||
token, err := s.oauthManager.ExchangeCode(oauthProvider, strings.TrimSpace(code))
|
token, err := s.oauthManager.ExchangeCode(ctx, oauthProvider, strings.TrimSpace(code))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthUser, err := s.oauthManager.GetUserInfo(oauthProvider, token)
|
oauthUser, err := s.oauthManager.GetUserInfo(ctx, oauthProvider, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1127,12 +1127,12 @@ func (s *AuthService) OAuthBindCallback(ctx context.Context, userID int64, provi
|
|||||||
}
|
}
|
||||||
|
|
||||||
oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
|
oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
|
||||||
token, err := s.oauthManager.ExchangeCode(oauthProvider, strings.TrimSpace(code))
|
token, err := s.oauthManager.ExchangeCode(ctx, oauthProvider, strings.TrimSpace(code))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthUser, err := s.oauthManager.GetUserInfo(oauthProvider, token)
|
oauthUser, err := s.oauthManager.GetUserInfo(ctx, oauthProvider, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,14 +32,14 @@ func (m *mockOAuthManager) GetAuthURL(provider auth.OAuthProvider, state string)
|
|||||||
return m.authURL, nil
|
return m.authURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockOAuthManager) ExchangeCode(provider auth.OAuthProvider, code string) (*auth.OAuthToken, error) {
|
func (m *mockOAuthManager) ExchangeCode(ctx context.Context, provider auth.OAuthProvider, code string) (*auth.OAuthToken, error) {
|
||||||
if m.exchangeErr != nil {
|
if m.exchangeErr != nil {
|
||||||
return nil, m.exchangeErr
|
return nil, m.exchangeErr
|
||||||
}
|
}
|
||||||
return &auth.OAuthToken{AccessToken: "mock-token"}, nil
|
return &auth.OAuthToken{AccessToken: "mock-token"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockOAuthManager) GetUserInfo(provider auth.OAuthProvider, token *auth.OAuthToken) (*auth.OAuthUser, error) {
|
func (m *mockOAuthManager) GetUserInfo(ctx context.Context, provider auth.OAuthProvider, token *auth.OAuthToken) (*auth.OAuthUser, error) {
|
||||||
if m.userInfoErr != nil {
|
if m.userInfoErr != nil {
|
||||||
return nil, m.userInfoErr
|
return nil, m.userInfoErr
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user