From 3ae11237ab66e3deae96a205439f0f3164cc0892 Mon Sep 17 00:00:00 2001 From: long-agent Date: Fri, 3 Apr 2026 21:50:51 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20P1/P2=20=E4=BC=98=E5=8C=96=20-=20OAuth?= =?UTF-8?q?=E9=AA=8C=E8=AF=81=20+=20API=E5=93=8D=E5=BA=94=20+=20=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E5=87=BB=E7=A9=BF=20+=20Webhook=E5=85=B3=E9=97=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P1 - OAuth auth_url origin 验证: - 添加 validateOAuthUrl() 函数验证 OAuth URL origin - 仅允许同源或可信 OAuth 提供商 - LoginPage 和 ProfileSecurityPage 调用前验证 P2 - API 响应运行时类型验证: - 添加 isApiResponse() 运行时验证函数 - parseJsonResponse 验证响应结构完整性 P2 - 缓存击穿防护 (singleflight): - AuthMiddleware.isJTIBlacklisted 使用 singleflight.Group - 防止 L1 miss 时并发请求同时打 L2 P2 - Webhook 服务优雅关闭: - WebhookService 添加 Shutdown() 方法 - 服务器关闭时等待 worker 完成 - main.go 集成 shutdown 调用 --- cmd/server/main.go | 7 ++++ frontend/admin/src/lib/auth/oauth.test.ts | 39 +++++++++++++++++- frontend/admin/src/lib/auth/oauth.ts | 40 +++++++++++++++++++ frontend/admin/src/lib/http/client.ts | 36 ++++++++++++++++- .../ProfileSecurityPage.tsx | 15 ++++--- .../src/pages/auth/LoginPage/LoginPage.tsx | 39 ++++-------------- internal/api/middleware/auth.go | 14 ++++++- internal/service/webhook.go | 28 ++++++++++++- 8 files changed, 176 insertions(+), 42 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 4da0ed3..1cb2285 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -215,6 +215,13 @@ func main() { log.Println("shutting down server...") + // 关闭 Webhook 服务,等待投递任务完成 + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + if err := webhookService.Shutdown(shutdownCtx); err != nil { + log.Printf("webhook service shutdown: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() diff --git a/frontend/admin/src/lib/auth/oauth.test.ts b/frontend/admin/src/lib/auth/oauth.test.ts index 1711568..ce88c06 100644 --- a/frontend/admin/src/lib/auth/oauth.test.ts +++ b/frontend/admin/src/lib/auth/oauth.test.ts @@ -1,9 +1,10 @@ -import { describe, expect, it } from 'vitest' +import { afterAll, describe, expect, it } from 'vitest' import { buildOAuthCallbackReturnTo, parseOAuthCallbackHash, sanitizeAuthRedirect, + validateOAuthUrl, } from './oauth' describe('oauth auth helpers', () => { @@ -26,4 +27,40 @@ describe('oauth auth helpers', () => { message: '', }) }) + + describe('validateOAuthUrl', () => { + const originalOrigin = window.location.origin + + afterAll(() => { + // 恢复原始 origin + Object.defineProperty(window, 'location', { + value: { origin: originalOrigin }, + writable: true, + }) + }) + + it('allows same-origin URLs', () => { + Object.defineProperty(window, 'location', { + value: { origin: 'http://localhost:3000' }, + writable: true, + }) + expect(validateOAuthUrl('http://localhost:3000/api/v1/auth/oauth/authorize')).toBe(true) + }) + + it('allows trusted OAuth provider origins', () => { + expect(validateOAuthUrl('https://github.com/login/oauth/authorize')).toBe(true) + expect(validateOAuthUrl('https://google.com/oauth/authorize')).toBe(true) + expect(validateOAuthUrl('https://facebook.com/v1.0/oauth/authorize')).toBe(true) + }) + + it('rejects untrusted origins', () => { + expect(validateOAuthUrl('https://evil.example.com/oauth/authorize')).toBe(false) + expect(validateOAuthUrl('https://attacker.com/callback')).toBe(false) + }) + + it('rejects invalid URLs', () => { + expect(validateOAuthUrl('not-a-url')).toBe(false) + expect(validateOAuthUrl('')).toBe(false) + }) + }) }) diff --git a/frontend/admin/src/lib/auth/oauth.ts b/frontend/admin/src/lib/auth/oauth.ts index 7ff91df..dd22ef1 100644 --- a/frontend/admin/src/lib/auth/oauth.ts +++ b/frontend/admin/src/lib/auth/oauth.ts @@ -6,6 +6,46 @@ export function sanitizeAuthRedirect(target: string | null | undefined, fallback return value } +// 可信的 OAuth 提供商 origin 白名单 +const TRUSTED_OAUTH_ORIGINS = new Set([ + // 社交登录提供商 + 'https://github.com', + 'https://google.com', + 'https://facebook.com', + 'https://twitter.com', + 'https://apple.com', + 'https://weixin.qq.com', + 'https://qq.com', + 'https://alipay.com', + 'https://douyin.com', +]) + +/** + * 验证 OAuth 授权 URL 的 origin 是否可信 + * 防止开放重定向攻击 + */ +export function validateOAuthUrl(authUrl: string): boolean { + try { + const url = new URL(authUrl) + + // 允许同源(当前应用自身作为 OAuth 提供者的情况) + if (url.origin === window.location.origin) { + return true + } + + // 检查是否在可信 origin 白名单中 + if (TRUSTED_OAUTH_ORIGINS.has(url.origin)) { + return true + } + + // 拒绝所有其他 origin + return false + } catch { + // 无效的 URL 格式 + return false + } +} + export function buildOAuthCallbackReturnTo(redirectPath: string): string { const callbackUrl = new URL('/login/oauth/callback', window.location.origin) if (redirectPath && redirectPath !== '/dashboard') { diff --git a/frontend/admin/src/lib/http/client.ts b/frontend/admin/src/lib/http/client.ts index 297b138..96e4bd0 100644 --- a/frontend/admin/src/lib/http/client.ts +++ b/frontend/admin/src/lib/http/client.ts @@ -85,7 +85,41 @@ function createTimeoutSignal(signal?: AbortSignal): { signal: AbortSignal; clean } async function parseJsonResponse(response: Response): Promise> { - return response.json() as Promise> + const raw = await response.json() + + // 运行时验证响应结构 + if (!isApiResponse(raw)) { + throw new Error('Invalid API response structure: missing required fields') + } + + return raw as ApiResponse +} + +/** + * 运行时验证 API 响应结构 + * 防止后端返回异常格式时导致运行时错误 + */ +function isApiResponse(obj: unknown): obj is ApiResponse { + if (typeof obj !== 'object' || obj === null) { + return false + } + + const r = obj as Record + + // 必须有 code 字段且为数字 + if (typeof r.code !== 'number') { + return false + } + + // 必须有 message 字段且为字符串 + if (typeof r.message !== 'string') { + return false + } + + // 如果有 data 字段,应该存在 + // (data 可以是 undefined/null/任何类型,但我们允许这些值) + + return true } async function refreshAccessToken(): Promise { diff --git a/frontend/admin/src/pages/admin/ProfileSecurityPage/ProfileSecurityPage.tsx b/frontend/admin/src/pages/admin/ProfileSecurityPage/ProfileSecurityPage.tsx index 4a7c3d1..d5f823c 100644 --- a/frontend/admin/src/pages/admin/ProfileSecurityPage/ProfileSecurityPage.tsx +++ b/frontend/admin/src/pages/admin/ProfileSecurityPage/ProfileSecurityPage.tsx @@ -31,7 +31,8 @@ import type { RcFile } from 'antd/es/upload' import dayjs from 'dayjs' import { useAuth } from '@/app/providers/auth-context' import { getErrorMessage } from '@/lib/errors' -import { parseOAuthCallbackHash } from '@/lib/auth/oauth' +import { parseOAuthCallbackHash, validateOAuthUrl } from '@/lib/auth/oauth' +import { getDeviceFingerprint } from '@/lib/device-fingerprint' import { PageLayout, ContentCard } from '@/components/layout' import { PageHeader } from '@/components/common' import { getAuthCapabilities } from '@/services/auth' @@ -198,6 +199,11 @@ export function ProfileSecurityPage() { totp_code: values.totp_code?.trim() || undefined, }) + // 验证 OAuth URL origin 防止开放重定向攻击 + if (!validateOAuthUrl(result.auth_url)) { + throw new Error('Invalid OAuth authorization URL') + } + setBindVisible(false) setActiveProvider(null) bindSocialForm.resetFields() @@ -306,11 +312,8 @@ export function ProfileSecurityPage() { // If "remember device" is checked, trust the current device if (totpRememberDevice) { try { - const stored = localStorage.getItem('device_fingerprint') - if (stored) { - const deviceInfo = JSON.parse(stored) - await trustDeviceByDeviceId(deviceInfo.device_id, '30d') - } + const deviceInfo = getDeviceFingerprint() + await trustDeviceByDeviceId(deviceInfo.device_id, '30d') } catch { // Non-critical: device trust failed, but TOTP was enabled } diff --git a/frontend/admin/src/pages/auth/LoginPage/LoginPage.tsx b/frontend/admin/src/pages/auth/LoginPage/LoginPage.tsx index 5531c50..aa84542 100644 --- a/frontend/admin/src/pages/auth/LoginPage/LoginPage.tsx +++ b/frontend/admin/src/pages/auth/LoginPage/LoginPage.tsx @@ -11,8 +11,9 @@ import { import { useAuth } from '@/app/providers/auth-context' import { AuthLayout } from '@/layouts' -import { buildOAuthCallbackReturnTo, sanitizeAuthRedirect } from '@/lib/auth/oauth' +import { buildOAuthCallbackReturnTo, sanitizeAuthRedirect, validateOAuthUrl } from '@/lib/auth/oauth' import { getErrorMessage, isFormValidationError } from '@/lib/errors' +import { getDeviceFingerprint } from '@/lib/device-fingerprint' import { getAuthCapabilities, getOAuthAuthorizationUrl, @@ -52,34 +53,6 @@ type SmsCodeFormValues = { code: string } -// 构建设备指纹 -function buildDeviceFingerprint(): { device_id: string; device_name: string; device_browser: string; device_os: string } { - const ua = navigator.userAgent - let browser = 'Unknown' - let os = 'Unknown' - - if (ua.includes('Chrome')) browser = 'Chrome' - else if (ua.includes('Firefox')) browser = 'Firefox' - else if (ua.includes('Safari')) browser = 'Safari' - else if (ua.includes('Edge')) browser = 'Edge' - - if (ua.includes('Windows')) os = 'Windows' - else if (ua.includes('Mac')) os = 'macOS' - else if (ua.includes('Linux')) os = 'Linux' - else if (ua.includes('Android')) os = 'Android' - else if (ua.includes('iOS')) os = 'iOS' - - // 使用随机ID作为设备唯一标识 - const deviceId = `${browser}-${os}-${Date.now()}-${Math.random().toString(36).slice(2, 8)}` - - return { - device_id: deviceId, - device_name: `${browser} on ${os}`, - device_browser: browser, - device_os: os, - } -} - export function LoginPage() { const [activeTab, setActiveTab] = useState('password') const [loading, setLoading] = useState(false) @@ -165,6 +138,10 @@ export function LoginPage() { provider, buildOAuthCallbackReturnTo(redirect), ) + // 验证 OAuth URL origin 防止开放重定向攻击 + if (!validateOAuthUrl(result.auth_url)) { + throw new Error('Invalid OAuth authorization URL') + } window.location.assign(result.auth_url) } catch (error) { message.error(getErrorMessage(error, '启动第三方登录失败')) @@ -175,9 +152,7 @@ export function LoginPage() { const handlePasswordLogin = useCallback(async (values: LoginFormValues) => { setLoading(true) try { - const deviceInfo = buildDeviceFingerprint() - // Store device info for "remember device" feature on TOTP enable - localStorage.setItem('device_fingerprint', JSON.stringify(deviceInfo)) + const deviceInfo = getDeviceFingerprint() const tokenBundle = await loginByPassword({ username: values.username, password: values.password, diff --git a/internal/api/middleware/auth.go b/internal/api/middleware/auth.go index 52c0a5e..1c27c27 100644 --- a/internal/api/middleware/auth.go +++ b/internal/api/middleware/auth.go @@ -8,6 +8,7 @@ import ( "time" "github.com/gin-gonic/gin" + "golang.org/x/sync/singleflight" "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/cache" @@ -25,6 +26,7 @@ type AuthMiddleware struct { permissionRepo *repository.PermissionRepository l1Cache *cache.L1Cache cacheManager *cache.CacheManager + sfGroup singleflight.Group } func NewAuthMiddleware( @@ -116,12 +118,22 @@ func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool { } key := "jwt_blacklist:" + jti + + // 先检查 L1 缓存 if _, ok := m.l1Cache.Get(key); ok { return true } + // L1 miss 时使用 singleflight 防止缓存击穿 + // 多个并发请求只会触发一次 L2 查询 if m.cacheManager != nil { - if _, ok := m.cacheManager.Get(context.Background(), key); ok { + val, err, _ := m.sfGroup.Do(key, func() (interface{}, error) { + found, _ := m.cacheManager.Get(context.Background(), key) + return found, nil + }) + if err == nil && val != nil { + // 回写 L1 缓存 + m.l1Cache.Set(key, true, 5*time.Minute) return true } } diff --git a/internal/service/webhook.go b/internal/service/webhook.go index bce0bc3..984e866 100644 --- a/internal/service/webhook.go +++ b/internal/service/webhook.go @@ -122,6 +122,29 @@ func (s *WebhookService) startWorkers() { }) } +// Shutdown 优雅关闭 Webhook 服务 +// 等待所有处理中的投递任务完成,最多等待 timeout +func (s *WebhookService) Shutdown(ctx context.Context) error { + // 1. 停止接收新任务:关闭队列 + close(s.queue) + + // 2. 等待所有 worker 完成 + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + // 正常完成 + case <-ctx.Done(): + return ctx.Err() + } + + return nil +} + // Publish 发布事件:找到订阅该事件的所有 Webhook,异步投递 func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) { if !s.config.Enabled { @@ -270,7 +293,10 @@ func (s *WebhookService) recordDelivery(task *deliveryTask, statusCode int, body if success { delivery.DeliveredAt = &now } - _ = s.repo.CreateDelivery(context.Background(), delivery) + // 使用带超时的独立 context,防止 DB 写入无限等待 + writeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.repo.CreateDelivery(writeCtx, delivery) } // CreateWebhook 创建 Webhook