From 6be90ddff88362410f2a120a0b820b8116404f0f Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 28 May 2026 15:19:13 +0800 Subject: [PATCH] fix: close auth, permission, contract and e2e review blockers --- .../admin/scripts/run-playwright-auth-e2e.ps1 | 10 +- .../admin/scripts/run-playwright-auth-e2e.sh | 142 ++++++++ .../admin/scripts/run-playwright-cdp-e2e.mjs | 196 ++++++++--- frontend/admin/src/lib/http/client.ts | 57 +-- .../ContactBindingsSection.tsx | 14 +- .../BootstrapAdminPage.test.tsx | 6 +- .../BootstrapAdminPage/BootstrapAdminPage.tsx | 26 +- .../auth/RegisterPage/RegisterPage.test.tsx | 53 ++- .../pages/auth/RegisterPage/RegisterPage.tsx | 16 +- frontend/admin/src/services/auth.test.ts | 42 ++- frontend/admin/src/services/auth.ts | 17 +- .../src/services/social-accounts.test.ts | 23 ++ .../admin/src/services/social-accounts.ts | 31 +- frontend/admin/src/services/users.test.ts | 46 +++ frontend/admin/src/services/users.ts | 51 ++- frontend/admin/src/services/webhooks.test.ts | 38 ++ frontend/admin/src/services/webhooks.ts | 36 +- internal/api/handler/auth_handler.go | 124 ++++++- internal/api/handler/avatar_handler.go | 28 +- internal/api/handler/handler_test.go | 326 +++++++++++++++--- internal/api/handler/user_handler.go | 29 +- internal/api/middleware/ratelimit.go | 7 + internal/api/router/router.go | 6 +- internal/auth/jwt.go | 42 +++ internal/service/auth.go | 43 ++- internal/service/auth_login_test.go | 35 ++ .../service/auth_logout_failclosed_test.go | 87 +++++ internal/service/user_service.go | 57 ++- internal/service/user_service_test.go | 27 ++ 29 files changed, 1356 insertions(+), 259 deletions(-) create mode 100644 frontend/admin/scripts/run-playwright-auth-e2e.sh create mode 100644 internal/service/auth_logout_failclosed_test.go diff --git a/frontend/admin/scripts/run-playwright-auth-e2e.ps1 b/frontend/admin/scripts/run-playwright-auth-e2e.ps1 index 9388d05..034c2f9 100644 --- a/frontend/admin/scripts/run-playwright-auth-e2e.ps1 +++ b/frontend/admin/scripts/run-playwright-auth-e2e.ps1 @@ -216,6 +216,7 @@ $env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontend $env:VITE_API_PROXY_TARGET = $backendBaseUrl $env:VITE_API_BASE_URL = '/api/v1' + $env:NODE_ENV = 'development' $frontendHandle = Start-ManagedProcess ` -Name 'ums-frontend-playwright' ` -FilePath 'npm.cmd' ` @@ -288,10 +289,11 @@ $env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontend Remove-Item Env:EMAIL_PORT -ErrorAction SilentlyContinue Remove-Item Env:EMAIL_FROM_EMAIL -ErrorAction SilentlyContinue Remove-Item Env:EMAIL_FROM_NAME -ErrorAction SilentlyContinue - Remove-Item Env:VITE_API_PROXY_TARGET -ErrorAction SilentlyContinue - Remove-Item Env:VITE_API_BASE_URL -ErrorAction SilentlyContinue - Remove-Item Env:JWT_SECRET -ErrorAction SilentlyContinue - Remove-Item Env:DEFAULT_ADMIN_EMAIL -ErrorAction SilentlyContinue + Remove-Item Env:VITE_API_PROXY_TARGET -ErrorAction SilentlyContinue + Remove-Item Env:VITE_API_BASE_URL -ErrorAction SilentlyContinue + Remove-Item Env:NODE_ENV -ErrorAction SilentlyContinue + Remove-Item Env:JWT_SECRET -ErrorAction SilentlyContinue +Remove-Item Env:DEFAULT_ADMIN_EMAIL -ErrorAction SilentlyContinue Remove-Item Env:DEFAULT_ADMIN_PASSWORD -ErrorAction SilentlyContinue Remove-Item $serverExePath -Force -ErrorAction SilentlyContinue Remove-Item $e2eRunRoot -Recurse -Force -ErrorAction SilentlyContinue diff --git a/frontend/admin/scripts/run-playwright-auth-e2e.sh b/frontend/admin/scripts/run-playwright-auth-e2e.sh new file mode 100644 index 0000000..f26ffab --- /dev/null +++ b/frontend/admin/scripts/run-playwright-auth-e2e.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +set -euo pipefail + +ADMIN_USERNAME="${E2E_LOGIN_USERNAME:-e2e_admin}" +ADMIN_PASSWORD="${E2E_LOGIN_PASSWORD:-E2EAdmin@123456}" +ADMIN_EMAIL="${E2E_LOGIN_EMAIL:-e2e_admin@example.com}" +BOOTSTRAP_SECRET_VALUE="${E2E_BOOTSTRAP_SECRET:-${BOOTSTRAP_SECRET:-e2e-bootstrap-secret-0123456789abcdefghijklmnopqrstuvwxyz}}" +BROWSER_PORT="${E2E_CDP_PORT:-0}" +BACKEND_PORT="${E2E_BACKEND_PORT:-0}" +FRONTEND_PORT="${E2E_FRONTEND_PORT:-0}" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +FRONTEND_ROOT="$(cd -- "$SCRIPT_DIR/.." && pwd)" +PROJECT_ROOT="$(cd -- "$SCRIPT_DIR/../../.." && pwd)" +TMP_ROOT="$(mktemp -d -t ums-playwright-e2e-XXXXXX)" +DATA_ROOT="$TMP_ROOT/data" +SMTP_CAPTURE_FILE="$TMP_ROOT/smtp-capture.jsonl" +SERVER_BIN="$TMP_ROOT/ums-server" +mkdir -p "$DATA_ROOT" + +backend_pid='' +frontend_pid='' +smtp_pid='' + +cleanup() { + local exit_code=$? + for pid in "$frontend_pid" "$backend_pid" "$smtp_pid"; do + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + wait "$pid" 2>/dev/null || true + fi + done + rm -rf "$TMP_ROOT" + exit "$exit_code" +} +trap cleanup EXIT INT TERM + +get_free_port() { + python3 - <<'PY' +import socket +s = socket.socket() +s.bind(('127.0.0.1', 0)) +print(s.getsockname()[1]) +s.close() +PY +} + +wait_url_ready() { + local url="$1" + local label="$2" + local attempts="${3:-120}" + local delay="${4:-0.5}" + for ((i=0; i/dev/null 2>&1; then + return 0 + fi + sleep "$delay" + done + echo "$label did not become ready: $url" >&2 + return 1 +} + +SELECTED_BACKEND_PORT="$BACKEND_PORT" +if [[ "$SELECTED_BACKEND_PORT" == "0" ]]; then + SELECTED_BACKEND_PORT="$(get_free_port)" +fi +SELECTED_FRONTEND_PORT="$FRONTEND_PORT" +if [[ "$SELECTED_FRONTEND_PORT" == "0" ]]; then + SELECTED_FRONTEND_PORT="$(get_free_port)" +fi +SELECTED_SMTP_PORT="$(get_free_port)" + +BACKEND_BASE_URL="http://127.0.0.1:${SELECTED_BACKEND_PORT}" +FRONTEND_BASE_URL="http://127.0.0.1:${SELECTED_FRONTEND_PORT}" +SQLITE_PATH="$DATA_ROOT/user_management.e2e.db" + +cd "$PROJECT_ROOT" +go build -o "$SERVER_BIN" ./cmd/server + +echo "playwright e2e backend: $BACKEND_BASE_URL" +echo "playwright e2e frontend: $FRONTEND_BASE_URL" +echo "playwright e2e smtp: 127.0.0.1:$SELECTED_SMTP_PORT" +echo "playwright e2e sqlite: $SQLITE_PATH" + +node "$SCRIPT_DIR/mock-smtp-capture.mjs" --port "$SELECTED_SMTP_PORT" --output "$SMTP_CAPTURE_FILE" >"$TMP_ROOT/smtp.log" 2>&1 & +smtp_pid=$! +sleep 0.5 +if ! kill -0 "$smtp_pid" 2>/dev/null; then + cat "$TMP_ROOT/smtp.log" >&2 || true + echo "smtp capture server failed to start" >&2 + exit 1 +fi + +( + export SERVER_PORT="$SELECTED_BACKEND_PORT" + export DATABASE_DBNAME="$SQLITE_PATH" + export SERVER_MODE='debug' + export SERVER_FRONTEND_URL="$FRONTEND_BASE_URL" + export CORS_ALLOWED_ORIGINS="$FRONTEND_BASE_URL,http://localhost:${SELECTED_FRONTEND_PORT}" + export LOGGING_OUTPUT='stdout' + export DISABLE_RATE_LIMIT='1' + export EMAIL_HOST='127.0.0.1' + export EMAIL_PORT="$SELECTED_SMTP_PORT" + export EMAIL_FROM_EMAIL='noreply@test.local' + export EMAIL_FROM_NAME='UMS E2E' + export JWT_SECRET='e2e-test-jwt-secret-at-least-32-bytes-long-for-security' + export BOOTSTRAP_SECRET="$BOOTSTRAP_SECRET_VALUE" + exec "$SERVER_BIN" +) >"$TMP_ROOT/backend.log" 2>&1 & +backend_pid=$! + +if ! wait_url_ready "$BACKEND_BASE_URL/health" 'backend'; then + cat "$TMP_ROOT/backend.log" >&2 || true + exit 1 +fi + +( + cd "$FRONTEND_ROOT" + export VITE_API_PROXY_TARGET="$BACKEND_BASE_URL" + export VITE_API_BASE_URL='/api/v1' + exec env -u NODE_ENV npm run dev -- --host 127.0.0.1 --port "$SELECTED_FRONTEND_PORT" +) >"$TMP_ROOT/frontend.log" 2>&1 & +frontend_pid=$! + +if ! wait_url_ready "$FRONTEND_BASE_URL" 'frontend'; then + cat "$TMP_ROOT/frontend.log" >&2 || true + exit 1 +fi + +cd "$FRONTEND_ROOT" +export E2E_LOGIN_USERNAME="$ADMIN_USERNAME" +export E2E_LOGIN_PASSWORD="$ADMIN_PASSWORD" +export E2E_LOGIN_EMAIL="$ADMIN_EMAIL" +export E2E_BOOTSTRAP_SECRET="$BOOTSTRAP_SECRET_VALUE" +export BOOTSTRAP_SECRET="$BOOTSTRAP_SECRET_VALUE" +export E2E_EXPECT_ADMIN_BOOTSTRAP='1' +export E2E_EXTERNAL_WEB_SERVER='1' +export E2E_MANAGED_BROWSER='1' +export E2E_BASE_URL="$FRONTEND_BASE_URL" +export E2E_SMTP_CAPTURE_FILE="$SMTP_CAPTURE_FILE" + +env -u NODE_ENV node ./scripts/run-playwright-cdp-e2e.mjs diff --git a/frontend/admin/scripts/run-playwright-cdp-e2e.mjs b/frontend/admin/scripts/run-playwright-cdp-e2e.mjs index 1fd7253..164a9cd 100644 --- a/frontend/admin/scripts/run-playwright-cdp-e2e.mjs +++ b/frontend/admin/scripts/run-playwright-cdp-e2e.mjs @@ -18,16 +18,18 @@ const TEXT = { assignPermissions: '\u5206\u914d\u6743\u9650', assignRoles: '\u5206\u914d\u89d2\u8272', assignRolesAction: '\u89d2\u8272', + auditLogs: '\u5ba1\u8ba1\u65e5\u5fd7', backToLogin: '\u8fd4\u56de\u767b\u5f55', bootstrapAdminConfirmPasswordPlaceholder: '\u786e\u8ba4\u7ba1\u7406\u5458\u5bc6\u7801', - bootstrapAdminEmailPlaceholder: '\u7ba1\u7406\u5458\u90ae\u7bb1\uff08\u9009\u586b\uff09', + bootstrapAdminEmailPlaceholder: '\u7ba1\u7406\u5458\u90ae\u7bb1', bootstrapAdminPasswordPlaceholder: '\u7ba1\u7406\u5458\u5bc6\u7801', + bootstrapAdminSecretPlaceholder: 'Bootstrap Secret', bootstrapAdminSubmit: '\u5b8c\u6210\u521d\u59cb\u5316\u5e76\u8fdb\u5165\u7cfb\u7edf', bootstrapAdminUsernamePlaceholder: '\u7ba1\u7406\u5458\u7528\u6237\u540d', changePassword: '\u4fee\u6539\u5bc6\u7801', confirmPasswordPlaceholder: '\u786e\u8ba4\u5bc6\u7801', createAccount: '\u521b\u5efa\u8d26\u53f7', - createUser: '\u521b\u5efa\u7528\u5458', + createUser: '\u521b\u5efa\u7528\u6237', createUserEmailPlaceholder: '\u90ae\u7bb1\u5730\u5740', createUserPasswordPlaceholder: '\u8bf7\u8f93\u5165\u521d\u59cb\u5bc6\u7801', createUserUsernamePlaceholder: '\u8bf7\u8f93\u5165\u7528\u6237\u540d', @@ -45,6 +47,7 @@ const TEXT = { emailActivationSuccess: '\u90ae\u7bb1\u9a8c\u8bc1\u6210\u529f', export: '\u5bfc\u51fa', forgotPassword: '\u5fd8\u8bb0\u5bc6\u7801\uff1f', + integration: '\u96c6\u6210\u80fd\u529b', loginAction: '\u767b\u5f55', loginLogs: '\u767b\u5f55\u65e5\u5fd7', loginNow: '\u7acb\u5373\u767b\u5f55', @@ -104,6 +107,7 @@ const SMTP_CAPTURE_FILE = (process.env.E2E_SMTP_CAPTURE_FILE ?? '').trim() const SESSION_PRESENCE_COOKIE_NAME = 'ums_session_present' let managedCdpUrl = null +const IS_WINDOWS = process.platform === 'win32' function appUrl(pathname) { return new URL(pathname, `${BASE_URL}/`).toString() @@ -193,6 +197,16 @@ async function waitForActivationLink(email, timeoutMs = 20_000) { throw new Error(`Timed out waiting for activation email for ${email}.`) } +async function fetchAuthCapabilitiesSnapshot() { + const response = await fetch(appUrl('/api/v1/auth/capabilities')) + if (!response.ok) { + throw new Error(`Failed to fetch auth capabilities: ${response.status} ${response.statusText}`) + } + + const payload = await response.json() + return payload?.data ?? {} +} + function resolveCdpUrl() { if (managedCdpUrl) { return managedCdpUrl @@ -272,12 +286,24 @@ async function resolveManagedBrowserPath() { return candidate } - for (const candidate of [ - 'C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe', - 'C:\\Program Files (x86)\\Google\\Chrome\\Application\\chrome.exe', - 'C:\\Program Files\\Microsoft\\Edge\\Application\\msedge.exe', - 'C:\\Program Files (x86)\\Microsoft\\Edge\\Application\\msedge.exe', - ]) { + const platformCandidates = IS_WINDOWS + ? [ + 'C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe', + 'C:\\Program Files (x86)\\Google\\Chrome\\Application\\chrome.exe', + 'C:\\Program Files\\Microsoft\\Edge\\Application\\msedge.exe', + 'C:\\Program Files (x86)\\Microsoft\\Edge\\Application\\msedge.exe', + ] + : [ + '/snap/bin/chromium', + '/usr/bin/chromium', + '/usr/bin/chromium-browser', + '/usr/bin/google-chrome', + '/usr/bin/google-chrome-stable', + '/usr/bin/microsoft-edge', + '/usr/bin/msedge', + ] + + for (const candidate of platformCandidates) { try { await assertFileExists(candidate) return candidate @@ -286,7 +312,9 @@ async function resolveManagedBrowserPath() { } } - const baseDir = path.join(process.env.LOCALAPPDATA ?? '', 'ms-playwright') + const baseDir = IS_WINDOWS + ? path.join(process.env.LOCALAPPDATA ?? '', 'ms-playwright') + : path.join(process.env.HOME ?? '', '.cache', 'ms-playwright') const candidates = [] try { @@ -297,11 +325,16 @@ async function resolveManagedBrowserPath() { } candidates.push( - path.join(baseDir, entry.name, 'chrome-headless-shell-win64', 'chrome-headless-shell.exe'), + path.join( + baseDir, + entry.name, + IS_WINDOWS ? 'chrome-headless-shell-win64' : 'chrome-headless-shell-linux64', + IS_WINDOWS ? 'chrome-headless-shell.exe' : 'chrome-headless-shell', + ), ) } } catch { - throw new Error('failed to scan Playwright browser cache under LOCALAPPDATA') + throw new Error(`failed to scan Playwright browser cache under ${baseDir}`) } candidates.sort().reverse() @@ -376,6 +409,15 @@ async function killManagedBrowser(browserProcess) { return } + if (!IS_WINDOWS) { + try { + browserProcess.kill('SIGKILL') + } catch { + // ignore + } + return + } + await new Promise((resolve) => { const killer = spawn('taskkill', ['/PID', String(browserProcess.pid), '/T', '/F'], { stdio: 'ignore', @@ -547,8 +589,28 @@ function attachSignalCollectors(page, signals) { } } +async function assertBaseUrlServesAdminApp(page) { + await page.goto(appUrl('/login'), { waitUntil: 'domcontentloaded' }) + await page.waitForLoadState('networkidle').catch(() => {}) + + const title = await page.title().catch(() => '') + const bodyText = (await page.locator('body').textContent())?.trim() ?? '' + const matchesAppTitle = title.includes(TEXT.appTitle) + const matchesAppBody = bodyText.includes(TEXT.welcomeLogin) || bodyText.includes(TEXT.adminBootstrapTitle) + if (matchesAppTitle || matchesAppBody) { + return + } + + throw new Error( + `E2E_BASE_URL resolved to ${appUrl('/login')}, but the page does not look like the admin app. ` + + `title=${JSON.stringify(title)} body_excerpt=${JSON.stringify(bodyText.slice(0, 160))}. ` + + `Set E2E_BASE_URL to the running frontend app (default expects the Vite dev server on :3000).`, + ) +} + async function resetBrowserState(context, page) { logDebug('resetting browser state') + await page.setViewportSize({ width: VIEWPORTS[0].width, height: VIEWPORTS[0].height }) await context.clearCookies() await page.goto(appUrl('/login'), { waitUntil: 'domcontentloaded' }) await page.evaluate(() => { @@ -709,7 +771,12 @@ async function forceClick(locator) { }) } -async function readRefreshToken(page) { +async function hasHttpOnlyRefreshCookie(page) { + const cookies = await page.context().cookies() + return cookies.some((cookie) => cookie.name === 'ums_refresh_token' && Boolean(cookie.value)) +} + +async function readSessionPresenceCookie(page) { return await page.evaluate((cookieName) => { const target = `${cookieName}=` const matched = document.cookie @@ -731,19 +798,31 @@ async function assertApiSuccessResponse(response, label) { try { payload = JSON.parse(responseBody) } catch (error) { - if (error instanceof SyntaxError) { - throw new Error(`${label} response is not valid JSON: ${responseBody}`) - } - throw error + throw new Error(`${label} response is not valid JSON: ${responseBody}`) } if (payload?.code !== 0) { - throw new Error(`${label} business response failed: ${responseBody}`) + throw new Error(`${label} response code ${payload?.code}: ${payload?.message ?? responseBody}`) } return payload } +async function waitForSessionCookies(context, timeoutMs = 10_000) { + const startedAt = Date.now() + while (Date.now() - startedAt < timeoutMs) { + const cookies = await context.cookies() + const hasRefresh = cookies.some((cookie) => cookie.name === 'ums_refresh_token' && cookie.value) + const hasPresence = cookies.some((cookie) => cookie.name === 'ums_session_present' && cookie.value === '1') + if (hasRefresh && hasPresence) { + return + } + await delay(100) + } + + throw new Error('session cookies were not persisted after login within timeout') +} + async function loginWithPassword(page, username, password, expectedUrlPattern) { const usernameInput = page .locator(`input[autocomplete="username"], input[placeholder="${TEXT.usernamePlaceholder}"]`) @@ -761,12 +840,25 @@ async function loginWithPassword(page, username, password, expectedUrlPattern) { if (loginResponse) { await assertApiSuccessResponse(loginResponse, 'password login') } + await waitForSessionCookies(page.context()) if (expectedUrlPattern) { await expect(page).toHaveURL(expectedUrlPattern, { timeout: 30 * 1000 }) } } +async function expectLoggedInLanding(page, timeoutMs = 30 * 1000) { + await expect(page).toHaveURL(/\/(dashboard|profile)$/, { timeout: timeoutMs }) + + const currentUrl = page.url() + if (currentUrl.endsWith('/dashboard')) { + await expect(page.getByText(TEXT.todaySuccessLogins)).toBeVisible() + return + } + + await expect(page.locator('body')).toContainText(TEXT.profile) +} + async function loginFromLoginPage(page) { const username = requireEnv('E2E_LOGIN_USERNAME') const password = requireEnv('E2E_LOGIN_PASSWORD') @@ -775,7 +867,8 @@ async function loginFromLoginPage(page) { await expect(page).toHaveURL(/\/login$/) await expect(page.getByRole('heading', { name: TEXT.welcomeLogin })).toBeVisible() - await loginWithPassword(page, username, password, /\/dashboard$/) + await loginWithPassword(page, username, password) + await expectLoggedInLanding(page) return { username, password } } @@ -784,6 +877,10 @@ async function verifyAdminBootstrapWorkflow(page) { const username = requireEnv('E2E_LOGIN_USERNAME') const password = requireEnv('E2E_LOGIN_PASSWORD') const email = (process.env.E2E_LOGIN_EMAIL ?? `${username}@example.com`).trim() + const bootstrapSecret = (process.env.E2E_BOOTSTRAP_SECRET ?? process.env.BOOTSTRAP_SECRET ?? '').trim() + if (!bootstrapSecret) { + throw new Error('E2E_BOOTSTRAP_SECRET or BOOTSTRAP_SECRET is required when E2E_EXPECT_ADMIN_BOOTSTRAP=1.') + } const capabilitiesResponse = page.waitForResponse((response) => { return response.url().includes('/api/v1/auth/capabilities') && response.request().method() === 'GET' @@ -800,6 +897,7 @@ async function verifyAdminBootstrapWorkflow(page) { await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminUsernamePlaceholder}"]`).first(), username) await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminEmailPlaceholder}"]`).first(), email) + await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminSecretPlaceholder}"]`).first(), bootstrapSecret) await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminPasswordPlaceholder}"]`).first(), password) await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminConfirmPasswordPlaceholder}"]`).first(), password) @@ -811,8 +909,7 @@ async function verifyAdminBootstrapWorkflow(page) { ]) await assertApiSuccessResponse(bootstrapResponse, 'bootstrap admin') - await expect(page).toHaveURL(/\/dashboard$/, { timeout: 30 * 1000 }) - await expect(page.getByText(TEXT.todaySuccessLogins)).toBeVisible() + await expectLoggedInLanding(page) await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.getByText(TEXT.logout, { exact: true })) @@ -1012,7 +1109,8 @@ async function verifyAuthWorkflow(page) { await page.goto(appUrl('/users')) await expect(page).toHaveURL(/\/users$/) - expect(await readRefreshToken(page)).toBeTruthy() + expect(await hasHttpOnlyRefreshCookie(page)).toBe(true) + expect(await readSessionPresenceCookie(page)).toBe('1') const userRow = page.locator('tbody tr').filter({ hasText: credentials.username }).first() await expect(userRow).toBeVisible({ timeout: 20 * 1000 }) @@ -1084,7 +1182,8 @@ async function verifyAuthWorkflow(page) { await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.getByText(TEXT.logout, { exact: true })) await expect(page).toHaveURL(/\/login$/) - await expect(await readRefreshToken(page)).toBeNull() + await expect(await hasHttpOnlyRefreshCookie(page)).toBe(false) + await expect(await readSessionPresenceCookie(page)).toBeNull() await page.goto(appUrl('/dashboard')) const postLogoutRedirect = await getProtectedRouteRedirect(page) @@ -1191,7 +1290,7 @@ async function verifyUserManagementCRUD(page) { const userRow = page.locator('tbody tr').filter({ hasText: testUsername }).first() await forceClick(userRow.getByRole('button', { name: TEXT.edit })) - const editDrawer = page.locator('.ant-drawer') + const editDrawer = page.locator('.ant-drawer.ant-drawer-open') await expect(editDrawer).toBeVisible({ timeout: 10 * 1000 }) const editResponsePromise = page.waitForResponse((response) => { @@ -1202,7 +1301,7 @@ async function verifyUserManagementCRUD(page) { await assertApiSuccessResponse(editResponse, 'edit user CRUD') await forceClick(userRow.getByRole('button', { name: TEXT.userDetailAction })) - const detailDrawer = page.locator('.ant-drawer') + const detailDrawer = page.locator('.ant-drawer.ant-drawer-open') await expect(detailDrawer).toBeVisible({ timeout: 10 * 1000 }) await expect(detailDrawer).toContainText(testUsername) @@ -1211,13 +1310,14 @@ async function verifyUserManagementCRUD(page) { await expect(page.locator('tbody tr').filter({ hasText: testUsername }).first()).toBeVisible({ timeout: 10 * 1000 }) await forceClick(userRow.getByRole('button', { name: TEXT.delete })) - const deleteConfirmModal = page.locator('.ant-modal-confirm') + const deleteConfirmModal = page.locator('.ant-popover').filter({ hasText: '确定要删除用户' }).last() await expect(deleteConfirmModal).toBeVisible({ timeout: 10 * 1000 }) - const deleteResponsePromise = page.waitForResponse((response) => { - return response.url().includes(`/api/v1/users/`) && response.request().method() === 'DELETE' - }) - await forceClick(deleteConfirmModal.locator('.ant-btn-primary').last()) - const deleteResponse = await deleteResponsePromise + const [deleteResponse] = await Promise.all([ + page.waitForResponse((response) => { + return response.url().includes(`/api/v1/users/`) && response.request().method() === 'DELETE' + }), + forceClick(deleteConfirmModal.locator('.ant-popconfirm-buttons .ant-btn-primary').last()), + ]) await assertApiSuccessResponse(deleteResponse, 'delete user CRUD') await expect(page.locator('tbody tr').filter({ hasText: testUsername }).first()).toHaveCount(0, { timeout: 10 * 1000 }) @@ -1255,8 +1355,7 @@ async function verifyDeviceManagement(page) { logDebug('verifyDeviceManagement: login /login') await loginFromLoginPage(page) - await expandSidebarGroup(page, TEXT.systemManagement) - await clickSidebarMenu(page, TEXT.devices) + await page.goto(appUrl('/devices')) await expect(page).toHaveURL(/\/devices$/) await expect(page.getByText(TEXT.deviceManagement)).toBeVisible({ timeout: 10 * 1000 }) @@ -1270,11 +1369,11 @@ async function verifyLoginLogs(page) { logDebug('verifyLoginLogs: login /login') await loginFromLoginPage(page) - await expandSidebarGroup(page, TEXT.systemManagement) + await expandSidebarGroup(page, TEXT.auditLogs) await clickSidebarMenu(page, TEXT.loginLogs) - await expect(page).toHaveURL(/\/login-logs$/) + await expect(page).toHaveURL(/\/logs\/login$/) - await expect(page.getByText(TEXT.loginLogs)).toBeVisible({ timeout: 10 * 1000 }) + await expect(page.getByRole('heading', { name: TEXT.loginLogs })).toBeVisible({ timeout: 10 * 1000 }) await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.getByText(TEXT.logout, { exact: true })) @@ -1285,11 +1384,11 @@ async function verifyOperationLogs(page) { logDebug('verifyOperationLogs: login /login') await loginFromLoginPage(page) - await expandSidebarGroup(page, TEXT.systemManagement) + await expandSidebarGroup(page, TEXT.auditLogs) await clickSidebarMenu(page, TEXT.operationLogs) - await expect(page).toHaveURL(/\/operation-logs$/) + await expect(page).toHaveURL(/\/logs\/operation$/) - await expect(page.getByText(TEXT.operationLogs)).toBeVisible({ timeout: 10 * 1000 }) + await expect(page.getByRole('heading', { name: TEXT.operationLogs })).toBeVisible({ timeout: 10 * 1000 }) await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.getByText(TEXT.logout, { exact: true })) @@ -1300,11 +1399,11 @@ async function verifyWebhookManagement(page) { logDebug('verifyWebhookManagement: login /login') await loginFromLoginPage(page) - await expandSidebarGroup(page, TEXT.systemManagement) + await expandSidebarGroup(page, TEXT.integration) await clickSidebarMenu(page, TEXT.webhooks) await expect(page).toHaveURL(/\/webhooks$/) - await expect(page.getByText(TEXT.webhooks)).toBeVisible({ timeout: 10 * 1000 }) + await expect(page.locator('body')).toContainText('Webhook 管理', { timeout: 10 * 1000 }) await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.getByText(TEXT.logout, { exact: true })) @@ -1322,10 +1421,10 @@ async function verifyProfileAndSecurity(page) { await expect(page.locator('body')).toContainText(credentials.username, { timeout: 10 * 1000 }) await forceClick(page.locator('[class*="userTrigger"]')) - await forceClick(page.getByText(TEXT.security)) + await forceClick(page.locator('.ant-dropdown').getByText(TEXT.security, { exact: true }).last()) await expect(page).toHaveURL(/\/profile\/security$/) - await expect(page.getByText(TEXT.changePassword)).toBeVisible({ timeout: 10 * 1000 }) + await expect(page.getByRole('button', { name: TEXT.changePassword })).toBeVisible({ timeout: 10 * 1000 }) await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.getByText(TEXT.logout, { exact: true })) @@ -1370,11 +1469,22 @@ async function main() { throw new Error('No persistent Chromium context is available through CDP.') } + const preflightPage = await ensurePersistentPage(browser, context) + if (!preflightPage) { + throw new Error('No persistent page is available in the Chromium CDP context.') + } + await assertBaseUrlServesAdminApp(preflightPage) + const authCapabilities = await fetchAuthCapabilitiesSnapshot() + if (process.env.E2E_EXPECT_ADMIN_BOOTSTRAP === '1') { await runScenario(browser, context, 'admin-bootstrap', verifyAdminBootstrapWorkflow) } await runScenario(browser, context, 'public-registration', verifyPublicRegistration) - await runScenario(browser, context, 'email-activation', verifyEmailActivationWorkflow) + if (authCapabilities.email_activation) { + await runScenario(browser, context, 'email-activation', verifyEmailActivationWorkflow) + } else { + console.log('SKIP email-activation (auth capability disabled)') + } await runScenario(browser, context, 'login-surface', verifyLoginSurface) await runScenario(browser, context, 'auth-workflow', verifyAuthWorkflow) await runScenario(browser, context, 'responsive-login', verifyResponsiveLogin) diff --git a/frontend/admin/src/lib/http/client.ts b/frontend/admin/src/lib/http/client.ts index 96e4bd0..91bcb52 100644 --- a/frontend/admin/src/lib/http/client.ts +++ b/frontend/admin/src/lib/http/client.ts @@ -18,6 +18,7 @@ import { CSRF_PROTECTED_METHODS, getCSRFHeaders } from './csrf' import type { TokenBundle } from '@/types' const DEFAULT_TIMEOUT = 30_000 +let inFlightRefreshBundle: Promise | null = null function isFormDataBody(body: unknown): body is FormData { return typeof FormData !== 'undefined' && body instanceof FormData @@ -145,6 +146,40 @@ async function refreshAccessToken(): Promise { return result.data } +async function performTokenRefresh(): Promise { + if (inFlightRefreshBundle) { + return inFlightRefreshBundle + } + + startRefreshing() + const promise = (async () => { + try { + const tokenBundle = await refreshAccessToken() + setAccessToken(tokenBundle.access_token, tokenBundle.expires_in) + setRefreshToken(tokenBundle.refresh_token) + return tokenBundle + } finally { + endRefreshing() + clearRefreshPromise() + inFlightRefreshBundle = null + } + })() + + inFlightRefreshBundle = promise + setRefreshPromise( + promise.then( + () => undefined, + () => undefined, + ), + ) + + return promise +} + +export async function refreshSessionBundle(): Promise { + return await performTokenRefresh() +} + async function performRefresh(): Promise { if (isRefreshing()) { const promise = getRefreshPromise() @@ -160,26 +195,8 @@ async function performRefresh(): Promise { return token } - startRefreshing() - const promise = (async () => { - try { - const tokenBundle = await refreshAccessToken() - setAccessToken(tokenBundle.access_token, tokenBundle.expires_in) - setRefreshToken(tokenBundle.refresh_token) - return tokenBundle.access_token - } finally { - endRefreshing() - clearRefreshPromise() - } - })() - - setRefreshPromise( - promise.then( - () => undefined, - () => undefined, - ), - ) - return promise + const tokenBundle = await performTokenRefresh() + return tokenBundle.access_token } async function resolveAuthorizationHeader(auth: boolean): Promise { diff --git a/frontend/admin/src/pages/admin/ProfileSecurityPage/ContactBindingsSection.tsx b/frontend/admin/src/pages/admin/ProfileSecurityPage/ContactBindingsSection.tsx index 82c00b3..e5c32dc 100644 --- a/frontend/admin/src/pages/admin/ProfileSecurityPage/ContactBindingsSection.tsx +++ b/frontend/admin/src/pages/admin/ProfileSecurityPage/ContactBindingsSection.tsx @@ -345,14 +345,12 @@ export function ContactBindingsSection({ label="验证码" rules={[{ required: true, message: '请输入验证码' }]} > - - 发送验证码 - - } - /> + + + + diff --git a/frontend/admin/src/pages/auth/BootstrapAdminPage/BootstrapAdminPage.test.tsx b/frontend/admin/src/pages/auth/BootstrapAdminPage/BootstrapAdminPage.test.tsx index 24d6d53..b0bd468 100644 --- a/frontend/admin/src/pages/auth/BootstrapAdminPage/BootstrapAdminPage.test.tsx +++ b/frontend/admin/src/pages/auth/BootstrapAdminPage/BootstrapAdminPage.test.tsx @@ -29,7 +29,7 @@ const authContextValue: AuthContextValue = { function renderBootstrapAdminPage() { return render( - + @@ -88,7 +88,8 @@ describe('BootstrapAdminPage', () => { await user.type(screen.getByPlaceholderText('管理员用户名'), 'bootstrap_admin') await user.type(screen.getByPlaceholderText('管理员昵称(选填)'), 'Bootstrap Admin') - await user.type(screen.getByPlaceholderText('管理员邮箱(选填)'), 'bootstrap_admin@example.com') + await user.type(screen.getByPlaceholderText('管理员邮箱'), 'bootstrap_admin@example.com') + await user.type(screen.getByPlaceholderText('Bootstrap Secret'), 'bootstrap-secret-demo') await user.type(screen.getByPlaceholderText('管理员密码'), 'Bootstrap123!@#') await user.type(screen.getByPlaceholderText('确认管理员密码'), 'Bootstrap123!@#') await user.click(screen.getByRole('button', { name: '完成初始化并进入系统' })) @@ -99,6 +100,7 @@ describe('BootstrapAdminPage', () => { nickname: 'Bootstrap Admin', email: 'bootstrap_admin@example.com', password: 'Bootstrap123!@#', + bootstrap_secret: 'bootstrap-secret-demo', }), ) diff --git a/frontend/admin/src/pages/auth/BootstrapAdminPage/BootstrapAdminPage.tsx b/frontend/admin/src/pages/auth/BootstrapAdminPage/BootstrapAdminPage.tsx index 7852d65..0f3a3aa 100644 --- a/frontend/admin/src/pages/auth/BootstrapAdminPage/BootstrapAdminPage.tsx +++ b/frontend/admin/src/pages/auth/BootstrapAdminPage/BootstrapAdminPage.tsx @@ -24,7 +24,8 @@ const DEFAULT_CAPABILITIES: AuthCapabilities = { type BootstrapAdminFormValues = { username: string nickname?: string - email?: string + email: string + bootstrapSecret: string password: string confirmPassword: string } @@ -71,7 +72,8 @@ export function BootstrapAdminPage() { const tokenBundle = await bootstrapAdmin({ username: values.username.trim(), nickname: values.nickname?.trim() || undefined, - email: values.email?.trim() || undefined, + email: values.email!.trim(), + bootstrap_secret: values.bootstrapSecret!.trim(), password: values.password, }) await onLoginSuccess(tokenBundle) @@ -110,7 +112,7 @@ export function BootstrapAdminPage() { 初始化首个管理员账号 - 当前版本不内置默认账号。首次部署时,请先创建首个管理员账号,初始化完成后系统会自动关闭该入口。 + 当前版本不内置默认账号。首次部署时,请提供 Bootstrap Secret 并创建首个管理员账号,初始化完成后系统会自动关闭该入口。 } - placeholder="管理员邮箱(选填)" + placeholder="管理员邮箱" size="large" autoComplete="email" /> + + } + placeholder="Bootstrap Secret" + size="large" + autoComplete="one-time-code" + /> + ({ @@ -61,7 +58,7 @@ vi.mock('@/services/auth', () => ({ function renderRegisterPage() { return render( - + , ) @@ -321,16 +318,13 @@ describe('RegisterPage', () => { email_activation: true, }) registerMock.mockResolvedValue({ - user: { - id: 3, - username: 'inactive-user', - email: 'inactive-user@example.com', - phone: '', - nickname: 'Inactive User', - avatar: '', - status: 0, - }, - message: 'registered successfully, please check your email to activate the account', + id: 3, + username: 'inactive-user', + email: 'inactive-user@example.com', + phone: '', + nickname: 'Inactive User', + avatar: '', + status: 0, }) renderRegisterPage() @@ -350,16 +344,13 @@ describe('RegisterPage', () => { it('shows the generic activation summary when the new inactive account has no email address', async () => { registerMock.mockResolvedValue({ - user: { - id: 4, - username: 'inactive-without-email', - email: '', - phone: '', - nickname: '', - avatar: '', - status: 0, - }, - message: 'registered successfully, activation required', + id: 4, + username: 'inactive-without-email', + email: '', + phone: '', + nickname: '', + avatar: '', + status: 0, }) renderRegisterPage() diff --git a/frontend/admin/src/pages/auth/RegisterPage/RegisterPage.tsx b/frontend/admin/src/pages/auth/RegisterPage/RegisterPage.tsx index 36d2672..e3b4c00 100644 --- a/frontend/admin/src/pages/auth/RegisterPage/RegisterPage.tsx +++ b/frontend/admin/src/pages/auth/RegisterPage/RegisterPage.tsx @@ -38,10 +38,10 @@ type RegisterFormValues = { confirmPassword: string } -function buildRegisterSummary(result: RegisterResponse) { - if (result.user.status === 0) { - if (result.user.email) { - return `账号已创建,激活邮件会发送到 ${result.user.email}。请完成激活后再登录。` +function buildRegisterSummary(user: RegisterResponse) { + if (user.status === 0) { + if (user.email) { + return `账号已创建,激活邮件会发送到 ${user.email}。请完成激活后再登录。` } return '账号已创建,请按页面提示完成激活后再登录。' } @@ -128,7 +128,7 @@ export function RegisterPage() { form.resetFields() setSmsCountdown(0) setSubmitted(result) - message.success(result.user.status === 0 ? '注册成功,请完成邮箱激活' : '注册成功') + message.success(result.status === 0 ? '注册成功,请完成邮箱激活' : '注册成功') } catch (error) { message.error(getErrorMessage(error, '注册失败,请检查输入信息后重试')) } finally { @@ -137,7 +137,7 @@ export function RegisterPage() { }, [capabilities.sms_code, form]) if (submitted) { - const activationEmail = submitted.user.email?.trim() + const activationEmail = submitted.email?.trim() return ( @@ -146,7 +146,7 @@ export function RegisterPage() { title="注册成功" subTitle={( - {submitted.user.username} + {submitted.username} {' '} {buildRegisterSummary(submitted)} @@ -155,7 +155,7 @@ export function RegisterPage() { , - submitted.user.status === 0 && activationEmail && capabilities.email_activation ? ( + submitted.status === 0 && activationEmail && capabilities.email_activation ? ( diff --git a/frontend/admin/src/services/auth.test.ts b/frontend/admin/src/services/auth.test.ts index 6a114f4..d14766b 100644 --- a/frontend/admin/src/services/auth.test.ts +++ b/frontend/admin/src/services/auth.test.ts @@ -2,17 +2,21 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' const getMock = vi.fn() const postMock = vi.fn() +const refreshSessionBundleMock = vi.fn() vi.mock('@/lib/http/client', () => ({ get: getMock, post: postMock, + refreshSessionBundle: refreshSessionBundleMock, })) describe('auth service', () => { beforeEach(() => { getMock.mockReset() postMock.mockReset() + refreshSessionBundleMock.mockReset() postMock.mockResolvedValue(undefined) + refreshSessionBundleMock.mockResolvedValue(undefined) }) it('loads public auth capabilities without auth headers', async () => { @@ -84,6 +88,28 @@ describe('auth service', () => { ) }) + it('verifies password-login totp with the temporary challenge token', async () => { + const { verifyTOTPAfterPasswordLogin } = await import('./auth') + + await verifyTOTPAfterPasswordLogin({ + user_id: 42, + code: '123456', + device_id: 'device-1', + temp_token: 'temp-token-demo', + }) + + expect(postMock).toHaveBeenCalledWith( + '/auth/login/totp-verify', + { + user_id: 42, + code: '123456', + device_id: 'device-1', + temp_token: 'temp-token-demo', + }, + { auth: false, credentials: 'include' }, + ) + }) + it('submits public registration without auth headers', async () => { const { register } = await import('./auth') @@ -106,7 +132,7 @@ describe('auth service', () => { ) }) - it('submits first-admin bootstrap without auth headers', async () => { + it('submits first-admin bootstrap with bootstrap secret header', async () => { const { bootstrapAdmin } = await import('./auth') await bootstrapAdmin({ @@ -114,6 +140,7 @@ describe('auth service', () => { password: 'Bootstrap123!@#', email: 'bootstrap_admin@example.com', nickname: 'Bootstrap Admin', + bootstrap_secret: 'bootstrap-secret-demo', }) expect(postMock).toHaveBeenCalledWith( @@ -124,7 +151,13 @@ describe('auth service', () => { email: 'bootstrap_admin@example.com', nickname: 'Bootstrap Admin', }, - { auth: false, credentials: 'include' }, + { + auth: false, + credentials: 'include', + headers: { + 'X-Bootstrap-Secret': 'bootstrap-secret-demo', + }, + }, ) }) @@ -192,12 +225,13 @@ describe('auth service', () => { expect(postMock).toHaveBeenCalledWith('/auth/logout', undefined, { credentials: 'include' }) }) - it('refreshes the session with credentials even when no body token is supplied', async () => { + it('refreshes the session through the shared refresh single-flight when no body token is supplied', async () => { const { refreshSession } = await import('./auth') await refreshSession() - expect(postMock).toHaveBeenCalledWith( + expect(refreshSessionBundleMock).toHaveBeenCalledTimes(1) + expect(postMock).not.toHaveBeenCalledWith( '/auth/refresh', undefined, { auth: false, credentials: 'include' }, diff --git a/frontend/admin/src/services/auth.ts b/frontend/admin/src/services/auth.ts index 4363a8d..066324f 100644 --- a/frontend/admin/src/services/auth.ts +++ b/frontend/admin/src/services/auth.ts @@ -1,4 +1,5 @@ import { get, post } from '@/lib/http/client' +import { refreshSessionBundle } from '@/lib/http/client' import type { ActionMessageResponse, AuthCapabilities, @@ -59,7 +60,14 @@ export function register(data: RegisterRequest): Promise { } export function bootstrapAdmin(data: BootstrapAdminRequest): Promise { - return post('/auth/bootstrap-admin', data, { auth: false, credentials: 'include' }) + const { bootstrap_secret, ...payload } = data + return post('/auth/bootstrap-admin', payload, { + auth: false, + credentials: 'include', + headers: { + 'X-Bootstrap-Secret': bootstrap_secret, + }, + }) } export function activateEmail(token: string): Promise { @@ -81,8 +89,11 @@ export function sendSmsCode(data: SendSmsCodeRequest): Promise { } export function refreshSession(refreshToken?: string | null): Promise { - const body = refreshToken ? { refresh_token: refreshToken } : undefined - return post('/auth/refresh', body, { auth: false, credentials: 'include' }) + if (!refreshToken) { + return refreshSessionBundle() + } + + return post('/auth/refresh', { refresh_token: refreshToken }, { auth: false, credentials: 'include' }) } export function getOAuthAuthorizationUrl( diff --git a/frontend/admin/src/services/social-accounts.test.ts b/frontend/admin/src/services/social-accounts.test.ts index ff707f3..baa8252 100644 --- a/frontend/admin/src/services/social-accounts.test.ts +++ b/frontend/admin/src/services/social-accounts.test.ts @@ -28,6 +28,29 @@ describe('social account service', () => { expect(getMock).toHaveBeenCalledWith('/users/me/social-accounts') }) + it('normalizes object-wrapped social account payloads', async () => { + getMock.mockResolvedValue({ + social_accounts: [ + { + provider: 'github', + provider_user_id: '123', + provider_username: 'octocat', + bound_at: '2026-03-27 20:00:00', + }, + ], + }) + + const { listSocialAccounts } = await import('./social-accounts') + const result = await listSocialAccounts() + + expect(result).toEqual([ + expect.objectContaining({ + provider: 'github', + provider_username: 'octocat', + }), + ]) + }) + it('starts social binding with the current verification payload', async () => { const { startSocialBinding } = await import('./social-accounts') diff --git a/frontend/admin/src/services/social-accounts.ts b/frontend/admin/src/services/social-accounts.ts index f4da621..3d2595f 100644 --- a/frontend/admin/src/services/social-accounts.ts +++ b/frontend/admin/src/services/social-accounts.ts @@ -6,8 +6,35 @@ import type { SocialBindingStartResponse, } from '@/types' -export function listSocialAccounts(): Promise { - return get('/users/me/social-accounts') +interface SocialAccountsResponse { + items?: SocialAccountInfo[] + accounts?: SocialAccountInfo[] + social_accounts?: SocialAccountInfo[] +} + +function normalizeSocialAccounts(payload: SocialAccountInfo[] | SocialAccountsResponse): SocialAccountInfo[] { + if (Array.isArray(payload)) { + return payload + } + + if (Array.isArray(payload.items)) { + return payload.items + } + + if (Array.isArray(payload.accounts)) { + return payload.accounts + } + + if (Array.isArray(payload.social_accounts)) { + return payload.social_accounts + } + + return [] +} + +export async function listSocialAccounts(): Promise { + const payload = await get('/users/me/social-accounts') + return normalizeSocialAccounts(payload) } export function startSocialBinding( diff --git a/frontend/admin/src/services/users.test.ts b/frontend/admin/src/services/users.test.ts index 6425ecc..fb4b299 100644 --- a/frontend/admin/src/services/users.test.ts +++ b/frontend/admin/src/services/users.test.ts @@ -20,6 +20,52 @@ describe('users service', () => { delMock.mockReset() }) + it('normalizes backend user list payloads that use users/limit/offset fields', async () => { + getMock.mockResolvedValue({ + users: [ + { + id: 7, + username: 'e2e_admin', + email: 'admin@example.com', + nickname: '管理员', + status: '1', + }, + ], + total: 1, + limit: 20, + offset: 0, + }) + + const { listUsers } = await import('./users') + const result = await listUsers({ page: 1, page_size: 20 }) + + expect(getMock).toHaveBeenCalledWith('/users', { page: 1, page_size: 20 }) + expect(result).toEqual({ + items: [ + { + id: 7, + username: 'e2e_admin', + email: 'admin@example.com', + phone: '', + nickname: '管理员', + avatar: '', + gender: 0, + birthday: '', + region: '', + bio: '', + status: 1, + last_login_at: '', + last_login_ip: '', + created_at: '', + updated_at: '', + }, + ], + total: 1, + page: 1, + page_size: 20, + }) + }) + it('creates a user through the protected users endpoint', async () => { const payload = { username: 'new-user', diff --git a/frontend/admin/src/services/users.ts b/frontend/admin/src/services/users.ts index 71d9ce5..51b6e1d 100644 --- a/frontend/admin/src/services/users.ts +++ b/frontend/admin/src/services/users.ts @@ -17,12 +17,59 @@ import type { AssignUserRolesRequest, } from '@/types/user' +interface RawUserListResponse { + items?: Partial[] + users?: Partial[] + total?: number + page?: number + page_size?: number + limit?: number + offset?: number +} + +function normalizeUser(user: Partial): User { + const numericStatus = typeof user.status === 'string' ? Number(user.status) : user.status + return { + id: user.id ?? 0, + username: user.username ?? '', + email: user.email ?? '', + phone: user.phone ?? '', + nickname: user.nickname ?? '', + avatar: user.avatar ?? '', + gender: user.gender ?? 0, + birthday: user.birthday ?? '', + region: user.region ?? '', + bio: user.bio ?? '', + status: (typeof numericStatus === 'number' && !Number.isNaN(numericStatus) ? numericStatus : 0) as UserStatus, + last_login_at: user.last_login_at ?? '', + last_login_ip: user.last_login_ip ?? '', + created_at: user.created_at ?? '', + updated_at: user.updated_at ?? '', + } +} + +function normalizeUserListResponse(result?: RawUserListResponse | null): PaginatedData { + const payload = result ?? {} + const items = Array.isArray(payload.items) ? payload.items : Array.isArray(payload.users) ? payload.users : [] + const pageSize = payload.page_size ?? payload.limit ?? items.length + const offset = payload.offset ?? 0 + const page = payload.page ?? (pageSize > 0 ? Math.floor(offset / pageSize) + 1 : 1) + + return { + items: items.map(normalizeUser), + total: payload.total ?? items.length, + page, + page_size: pageSize, + } +} + /** * 获取用户列表 * GET /api/v1/users */ -export function listUsers(params: UserListParams): Promise> { - return get>('/users', params as Record) +export async function listUsers(params: UserListParams): Promise> { + const result = await get('/users', params as Record) + return normalizeUserListResponse(result) } /** diff --git a/frontend/admin/src/services/webhooks.test.ts b/frontend/admin/src/services/webhooks.test.ts index 4b15b74..6ad50f9 100644 --- a/frontend/admin/src/services/webhooks.test.ts +++ b/frontend/admin/src/services/webhooks.test.ts @@ -74,6 +74,44 @@ describe('webhooks service', () => { expect(result.data[2].events).toEqual([]) }) + it('normalizes backend webhook list payloads that use items/limit/offset fields', async () => { + getMock.mockResolvedValue({ + items: [ + { + id: 11, + name: 'Compat Hook', + url: 'https://example.com/compat', + events: '["user.updated"]', + status: 1, + max_retries: 3, + timeout_sec: 10, + created_by: 1, + created_at: '2026-03-27 20:20:00', + updated_at: '2026-03-27 20:20:00', + }, + ], + total: 1, + limit: 20, + offset: 0, + }) + + const { listWebhooks } = await import('./webhooks') + const result = await listWebhooks({ page: 1, page_size: 20 }) + + expect(result).toEqual({ + data: [ + expect.objectContaining({ + id: 11, + name: 'Compat Hook', + events: ['user.updated'], + }), + ], + total: 1, + page: 1, + page_size: 20, + }) + }) + it('sends create, update, delete, and delivery requests through the HTTP client', async () => { postMock.mockResolvedValue({ id: 1, diff --git a/frontend/admin/src/services/webhooks.ts b/frontend/admin/src/services/webhooks.ts index 95fed5b..32d88c2 100644 --- a/frontend/admin/src/services/webhooks.ts +++ b/frontend/admin/src/services/webhooks.ts @@ -33,18 +33,42 @@ function normalizeWebhook(webhook: RawWebhook): Webhook { } interface PaginatedResponse { - data: T[] - total: number - page: number - page_size: number + data?: T[] + items?: T[] + webhooks?: T[] + total?: number + page?: number + page_size?: number + limit?: number + offset?: number +} + +function normalizeWebhookList(result: PaginatedResponse): { data: Webhook[]; total: number; page: number; page_size: number } { + const rawItems = Array.isArray(result.data) + ? result.data + : Array.isArray(result.items) + ? result.items + : Array.isArray(result.webhooks) + ? result.webhooks + : [] + const data = rawItems.map(normalizeWebhook) + const pageSize = result.page_size ?? result.limit ?? data.length + const offset = result.offset ?? 0 + const page = result.page ?? (pageSize > 0 ? Math.floor(offset / pageSize) + 1 : 1) + + return { + data, + total: result.total ?? data.length, + page, + page_size: pageSize, + } } export async function listWebhooks( params?: WebhookListParams, ): Promise<{ data: Webhook[]; total: number; page: number; page_size: number }> { const result = await get>('/webhooks', params as Record) - const webhooks = result.data.map(normalizeWebhook) - return { data: webhooks, total: result.total, page: result.page, page_size: result.page_size } + return normalizeWebhookList(result) } export function createWebhook(data: CreateWebhookRequest): Promise { diff --git a/internal/api/handler/auth_handler.go b/internal/api/handler/auth_handler.go index 18e7a81..3a0ae7b 100644 --- a/internal/api/handler/auth_handler.go +++ b/internal/api/handler/auth_handler.go @@ -30,11 +30,74 @@ type AuthHandler struct { authService *service.AuthService } +const ( + refreshTokenCookieName = "ums_refresh_token" + sessionPresenceCookieName = "ums_session_present" +) + // NewAuthHandler creates a new AuthHandler func NewAuthHandler(authService *service.AuthService) *AuthHandler { return &AuthHandler{authService: authService} } +func isSecureRequest(c *gin.Context) bool { + if c == nil || c.Request == nil { + return false + } + if c.Request.TLS != nil { + return true + } + return strings.EqualFold(c.GetHeader("X-Forwarded-Proto"), "https") +} + +func (h *AuthHandler) setSessionCookies(c *gin.Context, resp *service.LoginResponse) { + if c == nil || resp == nil || strings.TrimSpace(resp.RefreshToken) == "" || h == nil || h.authService == nil { + return + } + + maxAge := int(h.authService.RefreshTokenTTLSeconds()) + secure := isSecureRequest(c) + http.SetCookie(c.Writer, &http.Cookie{ + Name: refreshTokenCookieName, + Value: resp.RefreshToken, + Path: "/", + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + MaxAge: maxAge, + }) + http.SetCookie(c.Writer, &http.Cookie{ + Name: sessionPresenceCookieName, + Value: "1", + Path: "/", + HttpOnly: false, + Secure: secure, + SameSite: http.SameSiteLaxMode, + MaxAge: maxAge, + }) +} + +func clearCookie(c *gin.Context, name string) { + if c == nil { + return + } + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: "/", + HttpOnly: name == refreshTokenCookieName, + Secure: isSecureRequest(c), + SameSite: http.SameSiteLaxMode, + MaxAge: -1, + Expires: time.Unix(0, 0), + }) +} + +func clearSessionCookies(c *gin.Context) { + clearCookie(c, refreshTokenCookieName) + clearCookie(c, sessionPresenceCookieName) +} + // Register 用户注册 // @Summary 用户注册 // @Description 用户注册新账号,支持用户名+密码或手机号注册 @@ -130,6 +193,7 @@ func (h *AuthHandler) Login(c *gin.Context) { return } + h.setSessionCookies(c, resp) c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "success", @@ -150,21 +214,23 @@ func (h *AuthHandler) Login(c *gin.Context) { // @Router /api/v1/auth/login/totp-verify [post] func (h *AuthHandler) VerifyTOTPAfterPasswordLogin(c *gin.Context) { var req struct { - UserID int64 `json:"user_id" binding:"required"` - Code string `json:"code" binding:"required"` - DeviceID string `json:"device_id"` + UserID int64 `json:"user_id" binding:"required"` + Code string `json:"code" binding:"required"` + DeviceID string `json:"device_id"` + TempToken string `json:"temp_token" binding:"required"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()}) return } - resp, err := h.authService.VerifyTOTPAfterPasswordLogin(c.Request.Context(), req.UserID, req.Code, req.DeviceID) + resp, err := h.authService.VerifyTOTPAfterPasswordLogin(c.Request.Context(), req.UserID, req.Code, req.DeviceID, req.TempToken) if err != nil { handleError(c, err) return } + h.setSessionCookies(c, resp) c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "success", @@ -197,6 +263,12 @@ func (h *AuthHandler) Logout(c *gin.Context) { } } + if req.RefreshToken == "" { + if cookie, err := c.Request.Cookie(refreshTokenCookieName); err == nil { + req.RefreshToken = cookie.Value + } + } + username, _ := c.Get("username") usernameStr, _ := username.(string) @@ -204,7 +276,11 @@ func (h *AuthHandler) Logout(c *gin.Context) { AccessToken: req.AccessToken, RefreshToken: req.RefreshToken, } - _ = h.authService.Logout(c.Request.Context(), usernameStr, logoutReq) + if err := h.authService.Logout(c.Request.Context(), usernameStr, logoutReq); err != nil { + handleError(c, err) + return + } + clearSessionCookies(c) c.JSON(http.StatusOK, gin.H{"message": "logged out"}) } @@ -222,20 +298,28 @@ func (h *AuthHandler) Logout(c *gin.Context) { // @Router /api/v1/auth/refresh-token [post] func (h *AuthHandler) RefreshToken(c *gin.Context) { var req struct { - RefreshToken string `json:"refresh_token" binding:"required"` + RefreshToken string `json:"refresh_token"` } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + _ = c.ShouldBindJSON(&req) + if strings.TrimSpace(req.RefreshToken) == "" { + if cookie, err := c.Request.Cookie(refreshTokenCookieName); err == nil { + req.RefreshToken = cookie.Value + } + } + if strings.TrimSpace(req.RefreshToken) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "refresh_token is required"}) return } resp, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken) if err != nil { + clearSessionCookies(c) handleError(c, err) return } + h.setSessionCookies(c, resp) c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "success", @@ -315,7 +399,7 @@ func (h *AuthHandler) GetAuthCapabilities(c *gin.Context) { // @Router /api/v1/auth/oauth/{provider} [get] func (h *AuthHandler) OAuthLogin(c *gin.Context) { provider := c.Param("provider") - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "OAuth not configured", "data": gin.H{"provider": provider}}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "OAuth login is not configured", "data": gin.H{"provider": provider}}) } // OAuthCallback OAuth回调 @@ -327,7 +411,7 @@ func (h *AuthHandler) OAuthLogin(c *gin.Context) { // @Success 200 {object} Response "OAuth未配置" // @Router /api/v1/auth/oauth/{provider}/callback [get] func (h *AuthHandler) OAuthCallback(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "OAuth not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "OAuth callback is not configured"}) } // OAuthExchange OAuth令牌交换 @@ -340,7 +424,7 @@ func (h *AuthHandler) OAuthCallback(c *gin.Context) { // @Success 200 {object} Response "OAuth未配置" // @Router /api/v1/auth/oauth/{provider}/exchange [post] func (h *AuthHandler) OAuthExchange(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "OAuth not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "OAuth exchange is not configured"}) } // GetEnabledOAuthProviders 获取已启用的OAuth提供商 @@ -481,6 +565,7 @@ func (h *AuthHandler) LoginByEmailCode(c *gin.Context) { }() } + h.setSessionCookies(c, resp) c.JSON(http.StatusOK, gin.H{ "code": 0, "message": "success", @@ -545,6 +630,7 @@ func (h *AuthHandler) BootstrapAdmin(c *gin.Context) { return } + h.setSessionCookies(c, resp) c.JSON(http.StatusCreated, gin.H{ "code": 0, "message": "success", @@ -561,7 +647,7 @@ func (h *AuthHandler) BootstrapAdmin(c *gin.Context) { // @Success 200 {object} Response "功能未配置" // @Router /api/v1/auth/email/bind/send [post] func (h *AuthHandler) SendEmailBindCode(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "email bind not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "email binding is not configured"}) } // BindEmail 绑定邮箱 @@ -573,7 +659,7 @@ func (h *AuthHandler) SendEmailBindCode(c *gin.Context) { // @Success 200 {object} Response "功能未配置" // @Router /api/v1/auth/email/bind [post] func (h *AuthHandler) BindEmail(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "email bind not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "email binding is not configured"}) } // UnbindEmail 解绑邮箱 @@ -585,7 +671,7 @@ func (h *AuthHandler) BindEmail(c *gin.Context) { // @Success 200 {object} Response "功能未配置" // @Router /api/v1/auth/email/unbind [post] func (h *AuthHandler) UnbindEmail(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "email unbind not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "email binding is not configured"}) } // SendPhoneBindCode 发送手机绑定验证码 @@ -597,7 +683,7 @@ func (h *AuthHandler) UnbindEmail(c *gin.Context) { // @Success 200 {object} Response "功能未配置" // @Router /api/v1/auth/phone/bind/send [post] func (h *AuthHandler) SendPhoneBindCode(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "phone bind not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "phone binding is not configured"}) } // BindPhone 绑定手机号 @@ -609,7 +695,7 @@ func (h *AuthHandler) SendPhoneBindCode(c *gin.Context) { // @Success 200 {object} Response "功能未配置" // @Router /api/v1/auth/phone/bind [post] func (h *AuthHandler) BindPhone(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "phone bind not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "phone binding is not configured"}) } // UnbindPhone 解绑手机号 @@ -621,7 +707,7 @@ func (h *AuthHandler) BindPhone(c *gin.Context) { // @Success 200 {object} Response "功能未配置" // @Router /api/v1/auth/phone/unbind [post] func (h *AuthHandler) UnbindPhone(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "phone unbind not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "phone binding is not configured"}) } // GetSocialAccounts 获取社交账号列表 @@ -645,7 +731,7 @@ func (h *AuthHandler) GetSocialAccounts(c *gin.Context) { // @Success 200 {object} Response "功能未配置" // @Router /api/v1/auth/social/bind [post] func (h *AuthHandler) BindSocialAccount(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "social binding not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "social binding is not configured"}) } // UnbindSocialAccount 解绑社交账号 @@ -657,7 +743,7 @@ func (h *AuthHandler) BindSocialAccount(c *gin.Context) { // @Success 200 {object} Response "功能未配置" // @Router /api/v1/auth/social/unbind [post] func (h *AuthHandler) UnbindSocialAccount(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"code": 0, "message": "social unbinding not configured"}) + c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "social binding is not configured"}) } func (h *AuthHandler) SupportsEmailCodeLogin() bool { diff --git a/internal/api/handler/avatar_handler.go b/internal/api/handler/avatar_handler.go index b3a999d..0166eac 100644 --- a/internal/api/handler/avatar_handler.go +++ b/internal/api/handler/avatar_handler.go @@ -13,6 +13,7 @@ import ( "github.com/gin-gonic/gin" + apimiddleware "github.com/user-management-system/internal/api/middleware" "github.com/user-management-system/internal/domain" ) @@ -33,10 +34,12 @@ func NewAvatarHandler(userRepo avatarUserRepository) *AvatarHandler { } // generateSecureToken generates a secure random token -func generateSecureToken(length int) string { +func generateSecureToken(length int) (string, error) { bytes := make([]byte, length) - rand.Read(bytes) - return hex.EncodeToString(bytes)[:length] + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes)[:length], nil } // UploadAvatar 上传用户头像 @@ -70,17 +73,7 @@ func (h *AvatarHandler) UploadAvatar(c *gin.Context) { } // Check permission: user can only update their own avatar, or admin can update any - isAdmin := false - if roles, ok := c.Get("user_roles"); ok { - for _, role := range roles.([]*domain.Role) { - if role.Code == "admin" { - isAdmin = true - break - } - } - } - - if currentUserID != userID && !isAdmin { + if currentUserID != userID && !apimiddleware.IsAdmin(c) { c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"}) return } @@ -140,7 +133,12 @@ func (h *AvatarHandler) UploadAvatar(c *gin.Context) { } // Generate unique filename - avatarFilename := fmt.Sprintf("avatar_%d_%s%s", userID, generateSecureToken(8), ext) + token, err := generateSecureToken(8) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate avatar token"}) + return + } + avatarFilename := fmt.Sprintf("avatar_%d_%s%s", userID, token, ext) uploadDir := "./uploads/avatars" // Create upload directory if not exists diff --git a/internal/api/handler/handler_test.go b/internal/api/handler/handler_test.go index 52a0259..334b45b 100644 --- a/internal/api/handler/handler_test.go +++ b/internal/api/handler/handler_test.go @@ -7,7 +7,9 @@ import ( "io" "mime/multipart" "net/http" + "net/http/cookiejar" "net/http/httptest" + "os" "sync" "sync/atomic" "testing" @@ -35,6 +37,11 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { t.Helper() gin.SetMode(gin.TestMode) + previousBootstrapSecret, hadBootstrapSecret := os.LookupEnv("BOOTSTRAP_SECRET") + if err := os.Setenv("BOOTSTRAP_SECRET", "test-bootstrap-secret"); err != nil { + t.Fatalf("set bootstrap secret failed: %v", err) + } + id := atomic.AddInt64(&handlerDbCounter, 1) dsn := fmt.Sprintf("file:handlerdb_%d_%s?mode=memory&cache=shared", id, t.Name()) db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{ @@ -64,6 +71,20 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { t.Fatalf("db migration failed: %v", err) } + adminRole := &domain.Role{Code: "admin", Name: "管理员", Status: domain.RoleStatusEnabled} + if err := db.Create(adminRole).Error; err != nil { + t.Fatalf("seed admin role failed: %v", err) + } + for _, permission := range domain.DefaultPermissions() { + perm := permission + if err := db.Create(&perm).Error; err != nil { + t.Fatalf("seed permission %s failed: %v", perm.Code, err) + } + if err := db.Create(&domain.RolePermission{RoleID: adminRole.ID, PermissionID: perm.ID}).Error; err != nil { + t.Fatalf("seed role permission %s failed: %v", perm.Code, err) + } + } + jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{ HS256Secret: "test-handler-secret-key", AccessTokenExpire: 15 * time.Minute, @@ -136,6 +157,11 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { server := httptest.NewServer(engine) return server, func() { server.Close() + if hadBootstrapSecret { + _ = os.Setenv("BOOTSTRAP_SECRET", previousBootstrapSecret) + } else { + _ = os.Unsetenv("BOOTSTRAP_SECRET") + } if sqlDB, _ := db.DB(); sqlDB != nil { sqlDB.Close() } @@ -207,6 +233,35 @@ func registerUser(baseURL, username, email, password string) bool { return resp.StatusCode == http.StatusCreated } +func bootstrapAdminToken(baseURL, username, email, password string) string { + payload, _ := json.Marshal(map[string]interface{}{ + "username": username, + "email": email, + "password": password, + }) + req, _ := http.NewRequest("POST", baseURL+"/api/v1/auth/bootstrap-admin", bytes.NewReader(payload)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Bootstrap-Secret", "test-bootstrap-secret") + resp, err := (&http.Client{}).Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + bodyBytes, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusCreated { + return "" + } + var result map[string]interface{} + if err := json.Unmarshal(bodyBytes, &result); err != nil { + return "" + } + data, ok := result["data"].(map[string]interface{}) + if !ok || data["access_token"] == nil { + return "" + } + return data["access_token"].(string) +} + // ============================================================================= // Auth Handler Tests // ============================================================================= @@ -292,6 +347,89 @@ func TestAuthHandler_Login_Success(t *testing.T) { } } +func TestAuthHandler_Login_SetsSessionCookies(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "cookieuser", "cookie@example.com", "Password123!") + resp, body := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{ + "account": "cookieuser", + "password": "Password123!", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } + + cookies := resp.Cookies() + var hasRefreshCookie bool + var hasPresenceCookie bool + for _, cookie := range cookies { + switch cookie.Name { + case "ums_refresh_token": + hasRefreshCookie = cookie.HttpOnly && cookie.Value != "" + case "ums_session_present": + hasPresenceCookie = !cookie.HttpOnly && cookie.Value == "1" + } + } + if !hasRefreshCookie { + t.Fatalf("expected login response to set ums_refresh_token cookie, got %#v", cookies) + } + if !hasPresenceCookie { + t.Fatalf("expected login response to set ums_session_present cookie, got %#v", cookies) + } +} + +func TestAuthHandler_RefreshToken_UsesCookieFallback(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "refreshcookieuser", "refreshcookie@example.com", "Password123!") + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatalf("cookiejar.New() error: %v", err) + } + client := &http.Client{Jar: jar} + + loginBody, _ := json.Marshal(map[string]interface{}{ + "account": "refreshcookieuser", + "password": "Password123!", + }) + loginReq, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/login", bytes.NewReader(loginBody)) + loginReq.Header.Set("Content-Type", "application/json") + loginResp, err := client.Do(loginReq) + if err != nil { + t.Fatalf("login request failed: %v", err) + } + defer loginResp.Body.Close() + if loginResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(loginResp.Body) + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, string(payload)) + } + + refreshReq, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/refresh", nil) + refreshReq.Header.Set("Content-Type", "application/json") + refreshResp, err := client.Do(refreshReq) + if err != nil { + t.Fatalf("refresh request failed: %v", err) + } + defer refreshResp.Body.Close() + refreshPayload, _ := io.ReadAll(refreshResp.Body) + if refreshResp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, refreshResp.StatusCode, string(refreshPayload)) + } + + var parsed map[string]interface{} + if err := json.Unmarshal(refreshPayload, &parsed); err != nil { + t.Fatalf("refresh response json unmarshal failed: %v", err) + } + data, _ := parsed["data"].(map[string]interface{}) + if data == nil || data["access_token"] == nil || data["refresh_token"] == nil { + t.Fatalf("expected refresh response to include token pair, got %v", parsed) + } +} + func TestAuthHandler_Login_WrongPassword(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -336,33 +474,61 @@ func TestAuthHandler_BootstrapAdmin_MissingSecret(t *testing.T) { }) defer resp.Body.Close() - // Without BOOTSTRAP_SECRET env var set, should get forbidden - if resp.StatusCode != http.StatusForbidden { - t.Errorf("expected status %d for missing bootstrap secret, got %d", http.StatusForbidden, resp.StatusCode) + // P0 修复后:已配置 BOOTSTRAP_SECRET 但未提供 header,应返回 401 + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("expected status %d for missing bootstrap secret header, got %d", http.StatusUnauthorized, resp.StatusCode) } } -func TestAuthHandler_GetAuthCapabilities(t *testing.T) { +func TestAuthHandler_VerifyTOTPAfterPasswordLogin_RequiresTempToken(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() - resp, body := doGet(server.URL+"/api/v1/auth/capabilities", "") + resp, body := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{ + "user_id": 1, + "code": "123456", + "device_id": "device-1", + }) defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode) - } - - var result map[string]interface{} - json.Unmarshal([]byte(body), &result) - if result["code"] != float64(0) { - t.Errorf("expected code 0, got %v", result["code"]) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) } } -// ============================================================================= -// User Handler Tests -// ============================================================================= +func TestAuthHandler_UnconfiguredOAuthAndBindingsFailClosed(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "failclosed", "failclosed@test.com", "AdminPass123!") + token := getToken(server.URL, "failclosed", "AdminPass123!") + + tests := []struct { + name string + url string + body map[string]interface{} + }{ + {name: "oauth login", url: server.URL + "/api/v1/auth/oauth/github"}, + {name: "email bind code", url: server.URL + "/api/v1/users/me/bind-email/code", body: map[string]interface{}{"email": "bind@example.com"}}, + {name: "social bind", url: server.URL + "/api/v1/users/me/bind-social", body: map[string]interface{}{"provider": "github"}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var resp *http.Response + var body string + if tc.body == nil { + resp, body = doGet(tc.url, token) + } else { + resp, body = doPost(tc.url, token, tc.body) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusServiceUnavailable, resp.StatusCode, body) + } + }) + } +} func TestUserHandler_CreateUser_RequiresAdmin(t *testing.T) { server, cleanup := setupHandlerTestServer(t) @@ -400,39 +566,33 @@ func TestUserHandler_CreateUser_Unauthorized(t *testing.T) { } } -func TestUserHandler_ListUsers_Success(t *testing.T) { +func TestUserHandler_ListUsers_ForbiddenForRegularUser(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() - registerUser(server.URL, "listadmin", "listadmin@test.com", "AdminPass123!") - token := getToken(server.URL, "listadmin", "AdminPass123!") + registerUser(server.URL, "listuser", "listuser@test.com", "AdminPass123!") + token := getToken(server.URL, "listuser", "AdminPass123!") resp, body := doGet(server.URL+"/api/v1/users?page=1&page_size=10", token) defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) - } - - var result map[string]interface{} - json.Unmarshal([]byte(body), &result) - if result["code"] != float64(0) { - t.Errorf("expected code 0, got %v", result["code"]) + if resp.StatusCode != http.StatusForbidden { + t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body) } } -func TestUserHandler_GetUser_Success(t *testing.T) { +func TestUserHandler_GetUser_ForbiddenForRegularUser(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() - registerUser(server.URL, "getadmin", "getadmin@test.com", "AdminPass123!") - token := getToken(server.URL, "getadmin", "AdminPass123!") + registerUser(server.URL, "getuser", "getuser@test.com", "AdminPass123!") + token := getToken(server.URL, "getuser", "AdminPass123!") - resp, _ := doGet(server.URL+"/api/v1/users/1", token) + resp, body := doGet(server.URL+"/api/v1/users/1", token) defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + if resp.StatusCode != http.StatusForbidden { + t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body) } } @@ -440,8 +600,8 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() - registerUser(server.URL, "updateadmin", "updateadmin@test.com", "AdminPass123!") - token := getToken(server.URL, "updateadmin", "AdminPass123!") + registerUser(server.URL, "updateuser", "update@example.com", "UserPass123!") + token := getToken(server.URL, "updateuser", "UserPass123!") resp, body := doPut(server.URL+"/api/v1/users/1", token, map[string]string{"nickname": "Updated Nickname"}) defer resp.Body.Close() @@ -451,6 +611,43 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) { } } +func TestUserHandler_UpdateUser_AdminCanUpdateOther(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + token := bootstrapAdminToken(server.URL, "updateadmin", "updateadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin token should succeed") + } + registerUser(server.URL, "manageduser", "manageduser@test.com", "UserPass123!") + + resp, body := doPut(server.URL+"/api/v1/users/2", token, map[string]string{"nickname": "Admin Updated"}) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestUserHandler_UpdatePassword_NonAdminCannotUpdateOther(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "pwd-user-1", "pwd-user-1@test.com", "UserPass123!") + token := getToken(server.URL, "pwd-user-1", "UserPass123!") + registerUser(server.URL, "pwd-user-2", "pwd-user-2@test.com", "TargetPass123!") + + resp, body := doPut(server.URL+"/api/v1/users/2/password", token, map[string]string{ + "old_password": "TargetPass123!", + "new_password": "TargetNew456!", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body) + } +} + func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -471,8 +668,10 @@ func TestUserHandler_SearchUsers_Success(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() - registerUser(server.URL, "searchadmin", "searchadmin@test.com", "AdminPass123!") - token := getToken(server.URL, "searchadmin", "AdminPass123!") + token := bootstrapAdminToken(server.URL, "searchadmin", "searchadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin token should succeed") + } resp, body := doGet(server.URL+"/api/v1/users/1", token) defer resp.Body.Close() @@ -515,6 +714,24 @@ func TestUserHandler_GetUserRoles_Success(t *testing.T) { } } +func TestUserHandler_GetUserRoles_AdminCanViewOther(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + token := bootstrapAdminToken(server.URL, "rolesbootstrap", "rolesbootstrap@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin token should succeed") + } + registerUser(server.URL, "role-target", "role-target@test.com", "UserPass123!") + + resp, body := doGet(server.URL+"/api/v1/users/2/roles", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + func TestUserHandler_AssignRoles_RequiresAdmin(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -974,8 +1191,10 @@ func TestInvalidUserID_ReturnsBadRequest(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() - registerUser(server.URL, "invalidid", "invalidid@test.com", "AdminPass123!") - token := getToken(server.URL, "invalidid", "AdminPass123!") + token := bootstrapAdminToken(server.URL, "invalidid", "invalidid@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin token should succeed") + } resp, _ := doGet(server.URL+"/api/v1/users/invalid", token) defer resp.Body.Close() @@ -989,8 +1208,10 @@ func TestNonExistentUserID_ReturnsNotFound(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() - registerUser(server.URL, "notfound", "notfound@test.com", "AdminPass123!") - token := getToken(server.URL, "notfound", "AdminPass123!") + token := bootstrapAdminToken(server.URL, "notfound", "notfound@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin token should succeed") + } resp, _ := doGet(server.URL+"/api/v1/users/99999", token) defer resp.Body.Close() @@ -1350,6 +1571,29 @@ func TestAvatarHandler_UploadAvatar_NonAdminCannotUpdateOther(t *testing.T) { } } +func TestAvatarHandler_UploadAvatar_AdminCanUpdateOther(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + token := bootstrapAdminToken(server.URL, "avataradmin", "avataradmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin token should succeed") + } + registerUser(server.URL, "avatar-target", "avatar-target@test.com", "UserPass123!") + + fileContent := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} + resp, err := doUploadFile(server.URL+"/api/v1/users/2/avatar", token, "avatar", "test.png", fileContent) + if err != nil { + t.Fatalf("upload request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected status %d for admin updating other's avatar, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes)) + } +} + func TestAvatarHandler_UploadAvatar_UserNotFoundOrForbidden(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() diff --git a/internal/api/handler/user_handler.go b/internal/api/handler/user_handler.go index d96df7d..6215e5a 100644 --- a/internal/api/handler/user_handler.go +++ b/internal/api/handler/user_handler.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" + apimiddleware "github.com/user-management-system/internal/api/middleware" "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/domain" "github.com/user-management-system/internal/service" @@ -187,16 +188,7 @@ func (h *UserHandler) UpdateUser(c *gin.Context) { // Authorization: only self or admin can update user profile currentUserID := c.GetInt64("user_id") - isAdmin := false - if roles, ok := c.Get("user_roles"); ok { - for _, role := range roles.([]*domain.Role) { - if role.Code == "admin" { - isAdmin = true - break - } - } - } - if currentUserID != id && !isAdmin { + if currentUserID != id && !apimiddleware.IsAdmin(c) { c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"}) return } @@ -289,6 +281,12 @@ func (h *UserHandler) UpdatePassword(c *gin.Context) { return } + currentUserID := c.GetInt64("user_id") + if currentUserID != id && !apimiddleware.IsAdmin(c) { + c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"}) + return + } + if err := h.userService.ChangePassword(c.Request.Context(), id, req.OldPassword, req.NewPassword); err != nil { handleError(c, err) return @@ -370,16 +368,7 @@ func (h *UserHandler) GetUserRoles(c *gin.Context) { // Authorization: only self or admin can view user roles currentUserID := c.GetInt64("user_id") - isAdmin := false - if roles, ok := c.Get("user_roles"); ok { - for _, role := range roles.([]*domain.Role) { - if role.Code == "admin" { - isAdmin = true - break - } - } - } - if currentUserID != id && !isAdmin { + if currentUserID != id && !apimiddleware.IsAdmin(c) { c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"}) return } diff --git a/internal/api/middleware/ratelimit.go b/internal/api/middleware/ratelimit.go index 8b566a9..2e22420 100644 --- a/internal/api/middleware/ratelimit.go +++ b/internal/api/middleware/ratelimit.go @@ -1,6 +1,7 @@ package middleware import ( + "os" "sync" "time" @@ -89,6 +90,12 @@ func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc { } func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc { + if os.Getenv("DISABLE_RATE_LIMIT") == "1" { + return func(c *gin.Context) { + c.Next() + } + } + limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity) return func(c *gin.Context) { diff --git a/internal/api/router/router.go b/internal/api/router/router.go index 83ffa9a..ec736d8 100644 --- a/internal/api/router/router.go +++ b/internal/api/router/router.go @@ -142,6 +142,7 @@ func (r *Router) Setup() *gin.Engine { authGroup.POST("/login/totp-verify", r.rateLimitMiddleware.Login(), r.authHandler.VerifyTOTPAfterPasswordLogin) authGroup.POST("/refresh", r.rateLimitMiddleware.Refresh(), r.authHandler.RefreshToken) authGroup.GET("/capabilities", r.authHandler.GetAuthCapabilities) + authGroup.GET("/csrf-token", r.authHandler.GetCSRFToken) authGroup.POST("/activate-email", r.authHandler.ActivateEmail) authGroup.POST("/resend-activation", r.authHandler.ResendActivationEmail) @@ -189,7 +190,6 @@ func (r *Router) Setup() *gin.Engine { protected.Use(r.authMiddleware.Required()) protected.Use(r.rateLimitMiddleware.API()) { - protected.GET("/auth/csrf-token", r.authHandler.GetCSRFToken) protected.POST("/auth/logout", r.authHandler.Logout) protected.GET("/auth/userinfo", r.authHandler.GetUserInfo) @@ -206,8 +206,8 @@ func (r *Router) Setup() *gin.Engine { users := protected.Group("/users") { users.POST("", middleware.RequirePermission("user:manage"), r.userHandler.CreateUser) - users.GET("", r.userHandler.ListUsers) - users.GET("/:id", r.userHandler.GetUser) + users.GET("", middleware.RequirePermission("user:manage"), r.userHandler.ListUsers) + users.GET("/:id", middleware.RequirePermission("user:manage"), r.userHandler.GetUser) users.PUT("/:id", r.userHandler.UpdateUser) users.DELETE("/:id", middleware.RequirePermission("user:delete"), r.userHandler.DeleteUser) users.PUT("/:id/password", r.userHandler.UpdatePassword) diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 40856c2..a0ba637 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -54,6 +54,7 @@ type Claims struct { Remember bool `json:"remember,omitempty"` // 记住登录标记 JTI string `json:"jti"` // JWT ID,用于黑名单 PCE int64 `json:"pce,omitempty"` // Password Changed Epoch,密码变更时间戳,用于 token 失效机制 + DeviceID string `json:"device_id,omitempty"` jwt.RegisteredClaims } @@ -494,6 +495,47 @@ func (j *JWT) ValidateRefreshToken(tokenString string) (*Claims, error) { return claims, nil } +func (j *JWT) GenerateTOTPChallengeToken(userID int64, username, deviceID string, pce int64) (string, error) { + if err := j.ensureReady(); err != nil { + return "", err + } + + now := time.Now() + jti, err := generateJTI() + if err != nil { + return "", err + } + claims := Claims{ + UserID: userID, + Username: username, + Type: "totp_challenge", + JTI: jti, + PCE: pce, + DeviceID: strings.TrimSpace(deviceID), + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + + token := jwt.NewWithClaims(j.signingMethod(), claims) + return token.SignedString(j.signingKey()) +} + +func (j *JWT) ValidateTOTPChallengeToken(tokenString string) (*Claims, error) { + claims, err := j.ParseToken(tokenString) + if err != nil { + return nil, err + } + + if claims.Type != "totp_challenge" { + return nil, errors.New("invalid token type") + } + + return claims, nil +} + // RefreshAccessToken 刷新访问令牌 func (j *JWT) RefreshAccessToken(refreshTokenString string) (string, error) { claims, err := j.ValidateRefreshToken(refreshTokenString) diff --git a/internal/service/auth.go b/internal/service/auth.go index eec0a2e..d045f4f 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -122,7 +122,7 @@ type LoginResponse struct { ExpiresIn int64 `json:"expires_in,omitempty"` User *UserInfo `json:"user,omitempty"` // RequiresTOTP 指示登录需要额外的TOTP验证(当设备未信任时) - RequiresTOTP bool `json:"requires_totp,omitempty"` + RequiresTOTP bool `json:"requires_totp,omitempty"` // TempToken 临时令牌,用于TOTP验证阶段(短生命周期,不可用于常规API) TempToken string `json:"temp_token,omitempty"` // UserID 当RequiresTOTP为true时返回,用于后续TOTP验证 @@ -759,11 +759,16 @@ func (s *AuthService) Login(ctx context.Context, req *LoginRequest, ip string) ( // P0-07 安全修复:检查是否需要TOTP验证(用户启用了TOTP且设备未信任) if s.isTOTPRequiredForLogin(ctx, user, req.DeviceID) { + tempToken, err := s.jwtManager.GenerateTOTPChallengeToken(user.ID, user.Username, req.DeviceID, user.PasswordChangedAt.Unix()) + if err != nil { + return nil, err + } // 返回RequiresTOTP指示前端需要完成TOTP验证 // 前端应调用 /auth/login/totp-verify 接口完成验证 return &LoginResponse{ RequiresTOTP: true, - UserID: user.ID, + TempToken: tempToken, + UserID: user.ID, }, nil } @@ -808,10 +813,27 @@ func (s *AuthService) isTOTPRequiredForLogin(ctx context.Context, user *domain.U // VerifyTOTPAfterPasswordLogin 完成密码登录后的TOTP验证 // 当用户启用了TOTP但设备未信任时,密码登录会返回RequiresTOTP=true // 前端需要调用此接口完成TOTP验证以获取令牌 -func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID int64, totpCode, deviceID string) (*LoginResponse, error) { +func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID int64, totpCode, deviceID, tempToken string) (*LoginResponse, error) { if s == nil { return nil, errors.New("auth service is not initialized") } + if s.jwtManager == nil { + return nil, errors.New("jwt manager is not configured") + } + + claims, err := s.jwtManager.ValidateTOTPChallengeToken(strings.TrimSpace(tempToken)) + if err != nil { + return nil, errors.New("TOTP challenge is invalid or expired") + } + if claims == nil || claims.UserID != userID { + return nil, errors.New("TOTP challenge does not match user") + } + if strings.TrimSpace(claims.DeviceID) != strings.TrimSpace(deviceID) { + return nil, errors.New("TOTP challenge does not match device") + } + if s.IsTokenBlacklisted(ctx, claims.JTI) { + return nil, errors.New("TOTP challenge has already been used") + } user, err := s.userRepo.GetByID(ctx, userID) if err != nil { @@ -826,6 +848,9 @@ func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID i if err := s.VerifyTOTP(ctx, userID, totpCode, deviceID); err != nil { return nil, err } + if err := s.blacklistTokenClaims(ctx, tempToken, s.jwtManager.ValidateTOTPChallengeToken); err != nil { + return nil, fmt.Errorf("totp challenge revocation failed: %w", err) + } // TOTP验证成功,返回完整登录响应 return s.generateLoginResponseWithoutRemember(ctx, user) @@ -902,18 +927,22 @@ func (s *AuthService) Logout(ctx context.Context, username string, req *LogoutRe return nil } - _ = s.blacklistTokenClaims(ctx, req.AccessToken, func(token string) (*auth.Claims, error) { + if err := s.blacklistTokenClaims(ctx, req.AccessToken, func(token string) (*auth.Claims, error) { if s.jwtManager == nil { return nil, nil } return s.jwtManager.ValidateAccessToken(token) - }) - _ = s.blacklistTokenClaims(ctx, req.RefreshToken, func(token string) (*auth.Claims, error) { + }); err != nil { + return err + } + if err := s.blacklistTokenClaims(ctx, req.RefreshToken, func(token string) (*auth.Claims, error) { if s.jwtManager == nil { return nil, nil } return s.jwtManager.ValidateRefreshToken(token) - }) + }); err != nil { + return err + } if strings.TrimSpace(username) != "" { s.publishEvent(ctx, domain.EventUserLogout, map[string]interface{}{ diff --git a/internal/service/auth_login_test.go b/internal/service/auth_login_test.go index cad7468..c8674b6 100644 --- a/internal/service/auth_login_test.go +++ b/internal/service/auth_login_test.go @@ -157,6 +157,41 @@ func TestAuthService_Login(t *testing.T) { t.Error("nil service should return error") } }) + + t.Run("login with totp enabled returns temporary challenge token", func(t *testing.T) { + req := &service.RegisterRequest{ + Username: "totploginuser", + Password: "Test123!", + Email: "totplogin@test.com", + } + user, err := env.authSvc.Register(ctx, req) + if err != nil { + t.Fatalf("Register failed: %v", err) + } + if err := env.db.Model(&domain.User{}).Where("id = ?", user.ID).Updates(map[string]interface{}{ + "totp_enabled": true, + "totp_secret": "JBSWY3DPEHPK3PXP", + }).Error; err != nil { + t.Fatalf("enable totp failed: %v", err) + } + + resp, err := env.authSvc.Login(ctx, &service.LoginRequest{ + Username: "totploginuser", + Password: "Test123!", + }, "127.0.0.1") + if err != nil { + t.Fatalf("Login failed: %v", err) + } + if !resp.RequiresTOTP { + t.Fatal("expected requires_totp response") + } + if resp.TempToken == "" { + t.Fatal("expected temp_token for second-factor challenge") + } + if resp.AccessToken != "" || resp.RefreshToken != "" { + t.Fatal("totp challenge should not mint full session tokens before second factor verification") + } + }) } func TestAuthService_Register(t *testing.T) { diff --git a/internal/service/auth_logout_failclosed_test.go b/internal/service/auth_logout_failclosed_test.go new file mode 100644 index 0000000..2e85871 --- /dev/null +++ b/internal/service/auth_logout_failclosed_test.go @@ -0,0 +1,87 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/user-management-system/internal/auth" + "github.com/user-management-system/internal/cache" + "github.com/user-management-system/internal/domain" + "github.com/user-management-system/internal/repository" + gormsqlite "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +type failingL2Cache struct { + setErr error +} + +func (f *failingL2Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error { + return f.setErr +} +func (f *failingL2Cache) Get(ctx context.Context, key string) (interface{}, error) { return nil, nil } +func (f *failingL2Cache) Delete(ctx context.Context, key string) error { return nil } +func (f *failingL2Cache) Exists(ctx context.Context, key string) (bool, error) { return false, nil } +func (f *failingL2Cache) Clear(ctx context.Context) error { return nil } +func (f *failingL2Cache) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) { + return 0, nil +} +func (f *failingL2Cache) Close() error { return nil } + +func TestAuthService_Logout_FailsClosedWhenBlacklistWriteFails(t *testing.T) { + dsn := fmt.Sprintf("file:logoutfailclosed_%d?mode=memory&cache=shared", time.Now().UnixNano()) + db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{DriverName: "sqlite", DSN: dsn}), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("open db failed: %v", err) + } + if err := db.AutoMigrate(&domain.User{}, &domain.Role{}, &domain.UserRole{}, &domain.LoginLog{}, &domain.PasswordHistory{}); err != nil { + t.Fatalf("migrate failed: %v", err) + } + for _, role := range domain.PredefinedRoles { + roleCopy := role + if err := db.Create(&roleCopy).Error; err != nil { + t.Fatalf("seed role %s failed: %v", role.Code, err) + } + } + + jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{ + HS256Secret: fmt.Sprintf("logout-failclosed-secret-%d", time.Now().UnixNano()), + AccessTokenExpire: 15 * time.Minute, + RefreshTokenExpire: 7 * 24 * time.Hour, + }) + if err != nil { + t.Fatalf("create jwt manager failed: %v", err) + } + + userRepo := repository.NewUserRepository(db) + userRoleRepo := repository.NewUserRoleRepository(db) + roleRepo := repository.NewRoleRepository(db) + cacheManager := cache.NewCacheManager(cache.NewL1Cache(), &failingL2Cache{setErr: errors.New("forced blacklist failure")}) + + authSvc := NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute) + authSvc.SetRoleRepositories(userRoleRepo, roleRepo) + + ctx := context.Background() + if _, err := authSvc.Register(ctx, &RegisterRequest{Username: "logoutfail", Password: "Password123!"}); err != nil { + t.Fatalf("register failed: %v", err) + } + loginResp, err := authSvc.Login(ctx, &LoginRequest{Username: "logoutfail", Password: "Password123!"}, "127.0.0.1") + if err != nil { + t.Fatalf("login failed: %v", err) + } + + err = authSvc.Logout(ctx, "logoutfail", &LogoutRequest{AccessToken: loginResp.AccessToken, RefreshToken: loginResp.RefreshToken}) + if err == nil { + t.Fatal("expected logout to fail closed when blacklist write fails") + } + if !strings.Contains(err.Error(), "forced blacklist failure") { + t.Fatalf("expected propagated blacklist error, got: %v", err) + } +} diff --git a/internal/service/user_service.go b/internal/service/user_service.go index dee3ab6..5153f40 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -125,24 +125,49 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw return errors.New("密码哈希失败") } - // 保存新密码到历史记录(异步,不阻塞密码更新) - if s.passwordHistoryRepo != nil { - // #nosec G118 - 使用带超时的独立 context(不能使用请求 ctx,该 goroutine 在请求完成后仍可能运行) - go func(hashedPw string) { // #nosec G118 - bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{ - UserID: userID, - PasswordHash: hashedPw, - }) - _ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit) - }(newHashedPassword) - } - - // 更新密码(使用同一哈希值) + oldPasswordHash := user.Password + oldPasswordChangedAt := user.PasswordChangedAt user.Password = newHashedPassword user.PasswordChangedAt = time.Now() - return s.userRepo.Update(ctx, user) + + if s.passwordHistoryRepo == nil { + return s.userRepo.Update(ctx, user) + } + + return s.userRepo.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&domain.User{}). + Where("id = ?", user.ID). + Updates(map[string]interface{}{"password": user.Password, "password_changed_at": user.PasswordChangedAt}).Error; err != nil { + user.Password = oldPasswordHash + user.PasswordChangedAt = oldPasswordChangedAt + return err + } + + if err := tx.Create(&domain.PasswordHistory{UserID: userID, PasswordHash: newHashedPassword}).Error; err != nil { + user.Password = oldPasswordHash + user.PasswordChangedAt = oldPasswordChangedAt + return err + } + + var ids []int64 + if err := tx.Model(&domain.PasswordHistory{}). + Where("user_id = ?", userID). + Order("created_at DESC"). + Limit(passwordHistoryLimit). + Pluck("id", &ids).Error; err != nil { + user.Password = oldPasswordHash + user.PasswordChangedAt = oldPasswordChangedAt + return err + } + if len(ids) > 0 { + if err := tx.Where("user_id = ? AND id NOT IN ?", userID, ids).Delete(&domain.PasswordHistory{}).Error; err != nil { + user.Password = oldPasswordHash + user.PasswordChangedAt = oldPasswordChangedAt + return err + } + } + return nil + }) } // GetByID 根据ID获取用户 diff --git a/internal/service/user_service_test.go b/internal/service/user_service_test.go index 5b3a021..dadbeb1 100644 --- a/internal/service/user_service_test.go +++ b/internal/service/user_service_test.go @@ -6,6 +6,7 @@ import ( "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/domain" + "github.com/user-management-system/internal/repository" "github.com/user-management-system/internal/service" ) @@ -339,6 +340,32 @@ func TestUserService_ChangePassword(t *testing.T) { t.Error("Expected error for weak new password") } }) + + t.Run("Change password persists history synchronously", func(t *testing.T) { + hashedPassword, _ := auth.HashPassword("HistoryOld123!") + user := &domain.User{ + Username: "historysync", + Password: hashedPassword, + Status: domain.UserStatusActive, + } + env.userSvc.Create(ctx, user) + + if err := env.userSvc.ChangePassword(ctx, user.ID, "HistoryOld123!", "HistoryNew456!"); err != nil { + t.Fatalf("ChangePassword failed: %v", err) + } + + historyRepo := repository.NewPasswordHistoryRepository(env.db) + history, err := historyRepo.GetByUserID(ctx, user.ID, 10) + if err != nil { + t.Fatalf("GetByUserID failed: %v", err) + } + if len(history) == 0 { + t.Fatal("expected password history to be written synchronously") + } + if !auth.VerifyPassword(history[0].PasswordHash, "HistoryNew456!") { + t.Fatal("latest password history hash does not match new password") + } + }) } func TestUserService_BatchUpdateStatus(t *testing.T) {