feat: admin frontend - React + Vite, auth pages, user management, roles, permissions, webhooks, devices, logs

This commit is contained in:
2026-04-02 11:20:20 +08:00
parent dcc1f186f8
commit 4718980ab5
235 changed files with 35682 additions and 0 deletions

View File

@@ -0,0 +1,29 @@
import { describe, expect, it } from 'vitest'
import {
buildOAuthCallbackReturnTo,
parseOAuthCallbackHash,
sanitizeAuthRedirect,
} from './oauth'
describe('oauth auth helpers', () => {
it('sanitizes redirect paths to internal routes only', () => {
expect(sanitizeAuthRedirect('/users')).toBe('/users')
expect(sanitizeAuthRedirect('https://evil.example.com')).toBe('/dashboard')
expect(sanitizeAuthRedirect('//evil.example.com')).toBe('/dashboard')
expect(sanitizeAuthRedirect('users')).toBe('/dashboard')
})
it('builds oauth callback return url on current origin', () => {
expect(buildOAuthCallbackReturnTo('/users')).toBe('http://localhost:3000/login/oauth/callback?redirect=%2Fusers')
})
it('parses oauth callback hash payload', () => {
expect(parseOAuthCallbackHash('#status=success&code=abc&provider=github')).toEqual({
status: 'success',
code: 'abc',
provider: 'github',
message: '',
})
})
})

View File

@@ -0,0 +1,27 @@
export function sanitizeAuthRedirect(target: string | null | undefined, fallback: string = '/dashboard'): string {
const value = (target || '').trim()
if (!value.startsWith('/') || value.startsWith('//')) {
return fallback
}
return value
}
export function buildOAuthCallbackReturnTo(redirectPath: string): string {
const callbackUrl = new URL('/login/oauth/callback', window.location.origin)
if (redirectPath && redirectPath !== '/dashboard') {
callbackUrl.searchParams.set('redirect', redirectPath)
}
return callbackUrl.toString()
}
export function parseOAuthCallbackHash(hash: string): Record<string, string> {
const normalized = hash.startsWith('#') ? hash.slice(1) : hash
const values = new URLSearchParams(normalized)
return {
status: values.get('status') || '',
code: values.get('code') || '',
provider: values.get('provider') || '',
message: values.get('message') || '',
}
}

View File

@@ -0,0 +1,11 @@
/**
* 应用配置
* 从环境变量中读取配置项
*/
export const config = {
/**
* API 基础地址
*/
apiBaseUrl: import.meta.env.VITE_API_BASE_URL || '/api/v1',
} as const

View File

@@ -0,0 +1,126 @@
import { describe, expect, it } from 'vitest'
import { AppError, ErrorType, isAppError } from './AppError'
import { getErrorMessage, isFormValidationError } from './index'
describe('AppError', () => {
it('uses the default status and type when options are omitted', () => {
const error = new AppError(1001, 'business failed')
expect(error).toBeInstanceOf(AppError)
expect(error).toBeInstanceOf(Error)
expect(error.name).toBe('AppError')
expect(error.code).toBe(1001)
expect(error.status).toBe(500)
expect(error.type).toBe(ErrorType.BUSINESS)
expect(error.cause).toBeUndefined()
})
it('keeps explicit options including cause and exposes type guards', () => {
const cause = new Error('root cause')
const authByStatus = new AppError(2001, 'status-auth', {
status: 401,
type: ErrorType.BUSINESS,
cause,
})
const forbiddenByStatus = new AppError(2002, 'status-forbidden', {
status: 403,
type: ErrorType.BUSINESS,
})
const networkError = AppError.network('network failed', cause)
expect(authByStatus.cause).toBe(cause)
expect(authByStatus.isAuthError()).toBe(true)
expect(forbiddenByStatus.isForbidden()).toBe(true)
expect(networkError.isNetworkError()).toBe(true)
})
it('maps backend responses to the expected error type for each status family', () => {
const unauthorized = AppError.fromResponse({ code: 40101, message: 'unauthorized' }, 401)
const forbidden = AppError.fromResponse({ code: 40301, message: 'forbidden' }, 403)
const notFound = AppError.fromResponse({ code: 40401, message: 'missing' }, 404)
const network = AppError.fromResponse({ code: 50001, message: 'server error' }, 500)
const business = AppError.fromResponse({ code: 40001, message: 'business error' }, 400)
expect(unauthorized.type).toBe(ErrorType.AUTH)
expect(forbidden.type).toBe(ErrorType.FORBIDDEN)
expect(notFound.type).toBe(ErrorType.NOT_FOUND)
expect(network.type).toBe(ErrorType.NETWORK)
expect(business.type).toBe(ErrorType.BUSINESS)
})
it('creates auth, forbidden, and validation errors with the expected defaults', () => {
const auth = AppError.auth()
const forbidden = AppError.forbidden()
const validation = AppError.validation('validation failed')
expect(auth.code).toBe(401)
expect(auth.status).toBe(401)
expect(auth.type).toBe(ErrorType.AUTH)
expect(auth.message.length).toBeGreaterThan(0)
expect(forbidden.code).toBe(403)
expect(forbidden.status).toBe(403)
expect(forbidden.type).toBe(ErrorType.FORBIDDEN)
expect(forbidden.message.length).toBeGreaterThan(0)
expect(validation.code).toBe(400)
expect(validation.status).toBe(400)
expect(validation.type).toBe(ErrorType.VALIDATION)
expect(validation.message).toBe('validation failed')
})
it('returns user-facing messages for each supported error type', () => {
const networkMessage = AppError.network('network failed').getUserMessage()
const authMessage = AppError.auth('custom auth').getUserMessage()
const forbiddenMessage = AppError.forbidden('custom forbidden').getUserMessage()
const notFoundMessage = new AppError(40401, 'missing', {
status: 404,
type: ErrorType.NOT_FOUND,
}).getUserMessage()
const validationMessage = AppError.validation('validation failed').getUserMessage()
const customUnknownMessage = new AppError(9001, 'custom unknown', {
type: ErrorType.UNKNOWN,
}).getUserMessage()
const fallbackUnknownMessage = new AppError(9002, '', {
type: ErrorType.UNKNOWN,
}).getUserMessage()
expect(networkMessage.length).toBeGreaterThan(0)
expect(networkMessage).not.toBe('network failed')
expect(authMessage.length).toBeGreaterThan(0)
expect(authMessage).not.toBe('custom auth')
expect(forbiddenMessage.length).toBeGreaterThan(0)
expect(forbiddenMessage).not.toBe('custom forbidden')
expect(notFoundMessage.length).toBeGreaterThan(0)
expect(notFoundMessage).not.toBe('missing')
expect(validationMessage).toBe('validation failed')
expect(customUnknownMessage).toBe('custom unknown')
expect(fallbackUnknownMessage.length).toBeGreaterThan(0)
})
it('identifies AppError instances correctly', () => {
expect(isAppError(new AppError(1, 'boom'))).toBe(true)
expect(isAppError(new Error('boom'))).toBe(false)
expect(isAppError('boom')).toBe(false)
})
})
describe('error helpers', () => {
it('uses the AppError user message when available', () => {
const error = AppError.validation('invalid form')
expect(getErrorMessage(error, 'fallback')).toBe('invalid form')
})
it('falls back to generic Error messages and finally to the provided fallback', () => {
expect(getErrorMessage(new Error('plain error'), 'fallback')).toBe('plain error')
expect(getErrorMessage({ foo: 'bar' }, 'fallback')).toBe('fallback')
})
it('detects form validation errors only for objects with an errorFields array', () => {
expect(isFormValidationError({ errorFields: [] })).toBe(true)
expect(isFormValidationError({ errorFields: 'nope' })).toBe(false)
expect(isFormValidationError(null)).toBe(false)
})
})

View File

@@ -0,0 +1,172 @@
/**
* AppError - 应用统一错误模型
*
* 用于统一处理后端业务错误和前端运行时错误
*/
/**
* 错误类型常量
*/
export const ErrorType = {
/** 业务错误 - 后端返回的业务逻辑错误 */
BUSINESS: 'BUSINESS',
/** 网络错误 - 请求失败、超时等 */
NETWORK: 'NETWORK',
/** 认证错误 - 401 未登录或 token 过期 */
AUTH: 'AUTH',
/** 权限错误 - 403 无权限访问 */
FORBIDDEN: 'FORBIDDEN',
/** 资源不存在 - 404 */
NOT_FOUND: 'NOT_FOUND',
/** 验证错误 - 表单校验失败 */
VALIDATION: 'VALIDATION',
/** 未知错误 */
UNKNOWN: 'UNKNOWN',
} as const
export type ErrorTypeValue = typeof ErrorType[keyof typeof ErrorType]
/**
* 应用错误类
*/
export class AppError extends Error {
/** 错误码 */
readonly code: number
/** HTTP 状态码 */
readonly status: number
/** 错误类型 */
readonly type: ErrorTypeValue
/** 原始错误 */
readonly cause?: Error
constructor(
code: number,
message: string,
options?: {
status?: number
type?: ErrorTypeValue
cause?: Error
}
) {
super(message)
this.name = 'AppError'
this.code = code
this.status = options?.status ?? 500
this.type = options?.type ?? ErrorType.BUSINESS
this.cause = options?.cause
// 确保 instanceof 正常工作
Object.setPrototypeOf(this, AppError.prototype)
}
/**
* 从后端响应创建错误
*/
static fromResponse(response: { code: number; message: string }, status: number): AppError {
let type: ErrorTypeValue = ErrorType.BUSINESS
if (status === 401) {
type = ErrorType.AUTH
} else if (status === 403) {
type = ErrorType.FORBIDDEN
} else if (status === 404) {
type = ErrorType.NOT_FOUND
} else if (status >= 500) {
type = ErrorType.NETWORK
}
return new AppError(response.code, response.message, { status, type })
}
/**
* 创建网络错误
*/
static network(message: string, cause?: Error): AppError {
return new AppError(0, message, {
status: 0,
type: ErrorType.NETWORK,
cause,
})
}
/**
* 创建认证错误
*/
static auth(message: string = '请先登录'): AppError {
return new AppError(401, message, {
status: 401,
type: ErrorType.AUTH,
})
}
/**
* 创建权限错误
*/
static forbidden(message: string = '无权限访问'): AppError {
return new AppError(403, message, {
status: 403,
type: ErrorType.FORBIDDEN,
})
}
/**
* 创建验证错误
*/
static validation(message: string): AppError {
return new AppError(400, message, {
status: 400,
type: ErrorType.VALIDATION,
})
}
/**
* 判断是否为认证错误
*/
isAuthError(): boolean {
return this.type === ErrorType.AUTH || this.status === 401
}
/**
* 判断是否为权限错误
*/
isForbidden(): boolean {
return this.type === ErrorType.FORBIDDEN || this.status === 403
}
/**
* 判断是否为网络错误
*/
isNetworkError(): boolean {
return this.type === ErrorType.NETWORK
}
/**
* 获取用户友好的错误消息
*/
getUserMessage(): string {
switch (this.type) {
case ErrorType.NETWORK:
return '网络连接失败,请检查网络后重试'
case ErrorType.AUTH:
return '登录已过期,请重新登录'
case ErrorType.FORBIDDEN:
return '您没有权限执行此操作'
case ErrorType.NOT_FOUND:
return '请求的资源不存在'
case ErrorType.VALIDATION:
return this.message
default:
return this.message || '操作失败,请稍后重试'
}
}
}
/**
* 判断是否为 AppError
*/
export function isAppError(error: unknown): error is AppError {
return error instanceof AppError
}

View File

@@ -0,0 +1,26 @@
import { AppError, ErrorType, isAppError } from './AppError'
export { AppError, ErrorType, isAppError }
export function getErrorMessage(error: unknown, fallback: string): string {
if (isAppError(error)) {
return error.getUserMessage()
}
if (error instanceof Error && error.message) {
return error.message
}
return fallback
}
export function isFormValidationError(
error: unknown,
): error is { errorFields: unknown[] } {
return (
typeof error === 'object' &&
error !== null &&
'errorFields' in error &&
Array.isArray((error as { errorFields?: unknown[] }).errorFields)
)
}

View File

@@ -0,0 +1,82 @@
import type { ReactNode } from 'react'
import { renderHook } from '@testing-library/react'
import { MemoryRouter } from 'react-router-dom'
import { describe, expect, it } from 'vitest'
import { useBreadcrumbs } from './useBreadcrumbs'
function createWrapper(pathname: string) {
return function Wrapper({ children }: { children: ReactNode }) {
return <MemoryRouter initialEntries={[pathname]}>{children}</MemoryRouter>
}
}
describe('useBreadcrumbs', () => {
it('returns an empty breadcrumb list at the root path', () => {
const { result } = renderHook(() => useBreadcrumbs(), {
wrapper: createWrapper('/'),
})
expect(result.current).toEqual([])
})
it('maps known single-segment routes to a terminal breadcrumb item', () => {
const { result } = renderHook(() => useBreadcrumbs(), {
wrapper: createWrapper('/dashboard'),
})
expect(result.current).toEqual([
{
title: '概览',
path: undefined,
},
])
})
it('builds nested breadcrumbs for supported child routes', () => {
const { logsResult } = {
logsResult: renderHook(() => useBreadcrumbs(), {
wrapper: createWrapper('/logs/login'),
}),
}
expect(logsResult.result.current).toEqual([
{
title: '审计日志',
path: '/logs',
},
{
title: '登录日志',
path: undefined,
},
])
const profileResult = renderHook(() => useBreadcrumbs(), {
wrapper: createWrapper('/profile/security'),
})
expect(profileResult.result.current).toEqual([
{
title: '个人资料',
path: '/profile',
},
{
title: '安全设置',
path: undefined,
},
])
})
it('skips unknown route segments while keeping known ancestors', () => {
const { result } = renderHook(() => useBreadcrumbs(), {
wrapper: createWrapper('/logs/unknown'),
})
expect(result.current).toEqual([
{
title: '审计日志',
path: '/logs',
},
])
})
})

View File

@@ -0,0 +1,48 @@
import { useMemo } from 'react'
import { useLocation } from 'react-router-dom'
import type { BreadcrumbProps } from 'antd'
const breadcrumbNameMap: Record<string, string> = {
'/dashboard': '概览',
'/users': '用户管理',
'/roles': '角色管理',
'/permissions': '权限管理',
'/logs': '审计日志',
'/logs/login': '登录日志',
'/logs/operation': '操作日志',
'/webhooks': 'Webhooks',
'/import-export': '导入导出',
'/profile': '个人资料',
'/profile/security': '安全设置',
}
export function useBreadcrumbs(): BreadcrumbProps['items'] {
const location = useLocation()
return useMemo(() => {
const pathSnippets = location.pathname.split('/').filter(Boolean)
if (pathSnippets.length === 0) {
return []
}
const items: BreadcrumbProps['items'] = []
let currentPath = ''
pathSnippets.forEach((snippet, index) => {
currentPath += `/${snippet}`
const name = breadcrumbNameMap[currentPath]
if (!name) {
return
}
items.push({
title: name,
path: index === pathSnippets.length - 1 ? undefined : currentPath,
})
})
return items
}, [location.pathname])
}

View File

@@ -0,0 +1,83 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
const user = {
id: 1,
username: 'admin',
email: 'admin@example.com',
phone: '13800138000',
nickname: 'Admin',
avatar: '',
status: 1 as const,
}
const roles = [
{
id: 1,
name: 'Administrator',
code: 'admin',
description: 'System administrator',
is_system: true,
is_default: false,
status: 1 as const,
},
]
describe('auth-session', () => {
beforeEach(() => {
vi.resetModules()
vi.clearAllMocks()
})
it('stores and clears the session state in memory', async () => {
const session = await import('@/lib/http/auth-session')
session.setAccessToken('access-token', 60)
session.setCurrentUser(user)
session.setCurrentRoles(roles)
expect(session.getAccessToken()).toBe('access-token')
expect(session.getCurrentUser()).toEqual(user)
expect(session.getCurrentRoles()).toEqual(roles)
expect(session.getRoleCodes()).toEqual(['admin'])
expect(session.isAdmin()).toBe(true)
expect(session.isAuthenticated()).toBe(true)
session.clearSession()
expect(session.getAccessToken()).toBeNull()
expect(session.getCurrentUser()).toBeNull()
expect(session.getCurrentRoles()).toEqual([])
expect(session.isAuthenticated()).toBe(false)
})
it('starts empty after a module reload because the session is memory-only', async () => {
let session = await import('@/lib/http/auth-session')
session.setAccessToken('access-token', 60)
session.setCurrentUser(user)
session.setCurrentRoles(roles)
vi.resetModules()
session = await import('@/lib/http/auth-session')
expect(session.getAccessToken()).toBeNull()
expect(session.getCurrentUser()).toBeNull()
expect(session.getCurrentRoles()).toEqual([])
expect(session.isAuthenticated()).toBe(false)
})
it('marks the token as expired before the hard expiry time', async () => {
vi.useFakeTimers()
vi.setSystemTime(new Date('2026-03-21T00:00:00Z'))
const session = await import('@/lib/http/auth-session')
session.setAccessToken('access-token', 60)
expect(session.isAccessTokenExpired()).toBe(false)
vi.advanceTimersByTime(31_000)
expect(session.isAccessTokenExpired()).toBe(true)
session.clearSession()
vi.useRealTimers()
})
})

View File

@@ -0,0 +1,101 @@
import type { SessionUser, Role } from '@/types'
interface SessionState {
accessToken: string | null
expiresAt: number | null
user: SessionUser | null
roles: Role[]
isRefreshing: boolean
refreshPromise: Promise<void> | null
}
const sessionState: SessionState = {
accessToken: null,
expiresAt: null,
user: null,
roles: [],
isRefreshing: false,
refreshPromise: null,
}
export function getAccessToken(): string | null {
return sessionState.accessToken
}
export function setAccessToken(token: string, expiresIn: number): void {
sessionState.accessToken = token
sessionState.expiresAt = Date.now() + expiresIn * 1000
}
export function clearAccessToken(): void {
sessionState.accessToken = null
sessionState.expiresAt = null
}
export function isAccessTokenExpired(): boolean {
if (!sessionState.expiresAt) {
return true
}
return Date.now() > sessionState.expiresAt - 30_000
}
export function getCurrentUser(): SessionUser | null {
return sessionState.user
}
export function setCurrentUser(user: SessionUser): void {
sessionState.user = user
}
export function getCurrentRoles(): Role[] {
return sessionState.roles
}
export function setCurrentRoles(roles: Role[]): void {
sessionState.roles = roles
}
export function isAdmin(): boolean {
return sessionState.roles.some((role) => role.code === 'admin')
}
export function getRoleCodes(): string[] {
return sessionState.roles.map((role) => role.code)
}
export function isAuthenticated(): boolean {
return sessionState.accessToken !== null && sessionState.user !== null
}
export function clearSession(): void {
sessionState.accessToken = null
sessionState.expiresAt = null
sessionState.user = null
sessionState.roles = []
sessionState.isRefreshing = false
sessionState.refreshPromise = null
}
export function isRefreshing(): boolean {
return sessionState.isRefreshing
}
export function startRefreshing(): void {
sessionState.isRefreshing = true
}
export function endRefreshing(): void {
sessionState.isRefreshing = false
}
export function getRefreshPromise(): Promise<void> | null {
return sessionState.refreshPromise
}
export function setRefreshPromise(promise: Promise<void>): void {
sessionState.refreshPromise = promise
}
export function clearRefreshPromise(): void {
sessionState.refreshPromise = null
}

View File

@@ -0,0 +1,785 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
type JsonResponseInit = ResponseInit & {
status?: number
}
function jsonResponse(data: unknown, init: JsonResponseInit = {}) {
return new Response(JSON.stringify(data), {
status: 200,
headers: {
'Content-Type': 'application/json',
},
...init,
})
}
async function loadModules() {
vi.resetModules()
const session = await import('@/lib/http/auth-session')
const storage = await import('@/lib/storage')
const csrf = await import('@/lib/http/csrf')
const errors = await import('@/lib/errors')
const client = await import('@/lib/http/client')
return {
...session,
...storage,
...csrf,
...errors,
...client,
}
}
describe('http client', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.useRealTimers()
vi.unstubAllEnvs()
vi.unstubAllGlobals()
vi.stubGlobal('fetch', vi.fn())
})
afterEach(() => {
vi.useRealTimers()
vi.unstubAllEnvs()
vi.unstubAllGlobals()
})
it('builds query-string urls and skips undefined params without auth headers', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { ok: true },
}),
)
const { get } = await loadModules()
const result = await get(
'/users',
{ page: 2, active: true, keyword: undefined },
{ auth: false },
)
expect(result).toEqual({ ok: true })
expect(fetchMock).toHaveBeenCalledTimes(1)
const [requestUrl, requestInit] = fetchMock.mock.calls[0]
expect(String(requestUrl)).toBe(`${window.location.origin}/api/v1/users?page=2&active=true`)
expect(requestInit?.headers).not.toMatchObject({
Authorization: expect.any(String),
})
})
it('supports relative api base urls without a leading slash', async () => {
vi.stubEnv('VITE_API_BASE_URL', 'api/custom')
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { ok: true },
}),
)
const { get } = await loadModules()
await get('/status', undefined, { auth: false })
expect(fetchMock).toHaveBeenCalledWith(
`${window.location.origin}/api/custom/status`,
expect.any(Object),
)
})
it('supports absolute api base urls', async () => {
vi.stubEnv('VITE_API_BASE_URL', 'https://api.example.com/base')
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { ok: true },
}),
)
const { get } = await loadModules()
await get('/status', undefined, { auth: false })
expect(fetchMock).toHaveBeenCalledWith(
'https://api.example.com/base/status',
expect.any(Object),
)
})
it('sends FormData without forcing a JSON content type', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { uploaded: true },
}),
)
const { post } = await loadModules()
const formData = new FormData()
formData.append('file', new Blob(['demo'], { type: 'text/plain' }), 'demo.txt')
const result = await post('/upload', formData, { auth: false })
expect(result).toEqual({ uploaded: true })
expect(fetchMock).toHaveBeenCalledTimes(1)
const [requestUrl, requestInit] = fetchMock.mock.calls[0]
const headers = requestInit?.headers as Record<string, string> | undefined
expect(String(requestUrl)).toContain('/api/v1/upload')
expect(requestInit?.body).toBe(formData)
expect(requestInit?.credentials).toBe('include')
expect(headers?.['Content-Type']).toBeUndefined()
})
it('adds csrf and json headers for protected write requests', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { saved: true },
}),
)
const { CSRF_HEADER_NAME, put, setCSRFToken } = await loadModules()
setCSRFToken('csrf-token')
const result = await put('/users/1', { nickname: 'Demo' }, { auth: false })
expect(result).toEqual({ saved: true })
expect(fetchMock).toHaveBeenCalledTimes(1)
expect(fetchMock.mock.calls[0][1]).toMatchObject({
method: 'PUT',
body: JSON.stringify({ nickname: 'Demo' }),
headers: {
'Content-Type': 'application/json',
[CSRF_HEADER_NAME]: 'csrf-token',
},
})
})
it('adds csrf headers to delete requests', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { deleted: true },
}),
)
const { CSRF_HEADER_NAME, del, setCSRFToken } = await loadModules()
setCSRFToken('csrf-token')
const result = await del('/users/1', { auth: false })
expect(result).toEqual({ deleted: true })
expect(fetchMock.mock.calls[0][1]).toMatchObject({
method: 'DELETE',
headers: {
[CSRF_HEADER_NAME]: 'csrf-token',
},
})
})
it('refreshes an expired access token before sending the business request', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock
.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: {
access_token: 'access-token-new',
refresh_token: 'refresh-token-new',
expires_in: 3600,
user: {
id: 1,
username: 'admin',
email: 'admin@example.com',
phone: '13800138000',
nickname: 'Admin',
avatar: '',
status: 1,
},
},
}),
)
.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { ok: true },
}),
)
const { get, setAccessToken, setRefreshToken } = await loadModules()
setAccessToken('access-token-old', -1)
setRefreshToken('refresh-token-old')
const data = await get('/protected')
expect(data).toEqual({ ok: true })
expect(fetchMock).toHaveBeenCalledTimes(2)
expect(String(fetchMock.mock.calls[0][0])).toContain('/api/v1/auth/refresh')
expect(fetchMock.mock.calls[0][1]).toMatchObject({
credentials: 'include',
method: 'POST',
body: JSON.stringify({ refresh_token: 'refresh-token-old' }),
})
expect(fetchMock.mock.calls[1][1]?.headers).toMatchObject({
Authorization: 'Bearer access-token-new',
})
})
it('waits for an in-flight refresh promise before sending the request', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { ok: true },
}),
)
const { get, setAccessToken, setRefreshPromise, startRefreshing } = await loadModules()
setAccessToken('queued-access-token', 3600)
startRefreshing()
setRefreshPromise(Promise.resolve())
const result = await get('/protected')
expect(result).toEqual({ ok: true })
expect(fetchMock).toHaveBeenCalledTimes(1)
expect(fetchMock.mock.calls[0][1]?.headers).toMatchObject({
Authorization: 'Bearer queued-access-token',
})
})
it('clears the local session when refresh fails before the business request is sent', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(new Response(null, { status: 401 }))
const {
ErrorType,
get,
getAccessToken,
getRefreshToken,
setAccessToken,
setRefreshToken,
} = await loadModules()
setAccessToken('expired-access-token', -1)
setRefreshToken('refresh-token-old')
await expect(get('/protected')).rejects.toMatchObject({
status: 401,
type: ErrorType.AUTH,
})
expect(fetchMock).toHaveBeenCalledTimes(1)
expect(getAccessToken()).toBeNull()
expect(getRefreshToken()).toBeNull()
})
it('clears the local session when refresh returns a business error payload', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 10001,
message: 'refresh failed',
data: null,
}),
)
const {
ErrorType,
get,
getAccessToken,
getRefreshToken,
setAccessToken,
setRefreshToken,
} = await loadModules()
setAccessToken('expired-access-token', -1)
setRefreshToken('refresh-token-old')
await expect(get('/protected')).rejects.toMatchObject({
status: 401,
type: ErrorType.AUTH,
})
expect(fetchMock).toHaveBeenCalledTimes(1)
expect(getAccessToken()).toBeNull()
expect(getRefreshToken()).toBeNull()
})
it('retries once after a 401 response and rotates the in-memory refresh token', async () => {
const fetchMock = vi.mocked(fetch)
const capturedHeaders: Array<Record<string, string> | undefined> = []
fetchMock
.mockImplementationOnce(async (_url, requestInit) => {
capturedHeaders.push(
requestInit?.headers
? { ...(requestInit.headers as Record<string, string>) }
: undefined,
)
return new Response(null, { status: 401 })
})
.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: {
access_token: 'access-token-retried',
refresh_token: 'refresh-token-retried',
expires_in: 3600,
user: {
id: 1,
username: 'admin',
email: 'admin@example.com',
phone: '13800138000',
nickname: 'Admin',
avatar: '',
status: 1,
},
},
}),
)
.mockImplementationOnce(async (_url, requestInit) => {
capturedHeaders.push(
requestInit?.headers
? { ...(requestInit.headers as Record<string, string>) }
: undefined,
)
return jsonResponse({
code: 0,
message: 'ok',
data: { retried: true },
})
})
const { get, getRefreshToken, setAccessToken, setRefreshToken } = await loadModules()
setAccessToken('access-token-old', 3600)
setRefreshToken('refresh-token-old')
const data = await get('/protected')
expect(data).toEqual({ retried: true })
expect(fetchMock).toHaveBeenCalledTimes(3)
expect(capturedHeaders[0]).toMatchObject({
Authorization: 'Bearer access-token-old',
})
expect(String(fetchMock.mock.calls[1][0])).toContain('/api/v1/auth/refresh')
expect(fetchMock.mock.calls[1][1]).toMatchObject({
credentials: 'include',
method: 'POST',
body: JSON.stringify({ refresh_token: 'refresh-token-old' }),
})
expect(capturedHeaders[1]).toMatchObject({
Authorization: 'Bearer access-token-retried',
})
expect(getRefreshToken()).toBe('refresh-token-retried')
})
it('reuses an in-flight refresh token when a 401 retry happens during another refresh', async () => {
const fetchMock = vi.mocked(fetch)
const {
get,
setAccessToken,
setRefreshPromise,
startRefreshing,
} = await loadModules()
fetchMock
.mockImplementationOnce(async () => {
startRefreshing()
setAccessToken('shared-refresh-token', 3600)
setRefreshPromise(Promise.resolve())
return new Response(null, { status: 401 })
})
.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { retried: true },
}),
)
setAccessToken('access-token-old', 3600)
const data = await get('/protected')
expect(data).toEqual({ retried: true })
expect(fetchMock).toHaveBeenCalledTimes(2)
expect(fetchMock.mock.calls[1][1]?.headers).toMatchObject({
Authorization: 'Bearer shared-refresh-token',
})
})
it('fails the 401 retry when the shared refresh finishes without an access token', async () => {
const fetchMock = vi.mocked(fetch)
const {
clearAccessToken,
ErrorType,
get,
getAccessToken,
getRefreshToken,
setAccessToken,
setRefreshPromise,
setRefreshToken,
startRefreshing,
} = await loadModules()
fetchMock.mockImplementationOnce(async () => {
startRefreshing()
clearAccessToken()
setRefreshPromise(Promise.resolve())
return new Response(null, { status: 401 })
})
setAccessToken('access-token-old', 3600)
setRefreshToken('refresh-token-old')
await expect(get('/protected')).rejects.toMatchObject({
status: 401,
type: ErrorType.AUTH,
})
expect(fetchMock).toHaveBeenCalledTimes(1)
expect(getAccessToken()).toBeNull()
expect(getRefreshToken()).toBeNull()
})
it('clears the local session when the retried request still returns 401', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock
.mockResolvedValueOnce(new Response(null, { status: 401 }))
.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: {
access_token: 'access-token-retried',
refresh_token: 'refresh-token-retried',
expires_in: 3600,
user: {
id: 1,
username: 'admin',
email: 'admin@example.com',
phone: '13800138000',
nickname: 'Admin',
avatar: '',
status: 1,
},
},
}),
)
.mockResolvedValueOnce(new Response(null, { status: 401 }))
const {
ErrorType,
get,
getAccessToken,
getRefreshToken,
setAccessToken,
setRefreshToken,
} = await loadModules()
setAccessToken('access-token-old', 3600)
setRefreshToken('refresh-token-old')
await expect(get('/protected')).rejects.toMatchObject({
status: 401,
type: ErrorType.AUTH,
})
expect(fetchMock).toHaveBeenCalledTimes(3)
expect(getAccessToken()).toBeNull()
expect(getRefreshToken()).toBeNull()
})
it('maps 403 responses to forbidden errors', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(new Response(null, { status: 403 }))
const { ErrorType, get } = await loadModules()
await expect(get('/forbidden', undefined, { auth: false })).rejects.toMatchObject({
status: 403,
type: ErrorType.FORBIDDEN,
})
})
it('maps 404 responses to not-found errors', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(new Response(null, { status: 404 }))
const { ErrorType, get } = await loadModules()
await expect(get('/missing', undefined, { auth: false })).rejects.toMatchObject({
status: 404,
type: ErrorType.NOT_FOUND,
})
})
it('maps other non-ok responses to network errors', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(new Response(null, { status: 500 }))
const { ErrorType, get } = await loadModules()
await expect(get('/broken', undefined, { auth: false })).rejects.toMatchObject({
status: 0,
type: ErrorType.NETWORK,
})
})
it('maps non-zero business responses to AppError.fromResponse', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 10001,
message: 'business failure',
data: null,
}),
)
const { ErrorType, get } = await loadModules()
await expect(get('/business', undefined, { auth: false })).rejects.toMatchObject({
code: 10001,
status: 200,
type: ErrorType.BUSINESS,
})
})
it('converts aborted requests into timeout AppErrors', async () => {
vi.useFakeTimers()
const fetchMock = vi.mocked(fetch)
fetchMock.mockImplementation(
(_url, requestInit) =>
new Promise((_, reject) => {
;(requestInit?.signal as AbortSignal).addEventListener(
'abort',
() => reject(new DOMException('Aborted', 'AbortError')),
{ once: true },
)
}),
)
const { ErrorType, request } = await loadModules()
const requestPromise = expect(request('/slow', { auth: false })).rejects.toMatchObject({
status: 0,
type: ErrorType.NETWORK,
})
await vi.advanceTimersByTimeAsync(30_000)
await requestPromise
})
it('propagates a caller abort signal into the request timeout controller', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockImplementation(
(_url, requestInit) =>
new Promise((_, reject) => {
;(requestInit?.signal as AbortSignal).addEventListener(
'abort',
() => reject(new DOMException('Aborted', 'AbortError')),
{ once: true },
)
}),
)
const controller = new AbortController()
const { ErrorType, request } = await loadModules()
const requestPromise = expect(
request('/slow', { auth: false, signal: controller.signal }),
).rejects.toMatchObject({
status: 0,
type: ErrorType.NETWORK,
})
await Promise.resolve()
controller.abort()
await requestPromise
})
it('retries downloads after a 401 and returns the blob payload', async () => {
const fetchMock = vi.mocked(fetch)
const downloadedBlob = { kind: 'downloaded-blob' } as unknown as Blob
const successResponse = {
ok: true,
status: 200,
blob: vi.fn().mockResolvedValue(downloadedBlob),
} as unknown as Response
fetchMock
.mockResolvedValueOnce(new Response(null, { status: 401 }))
.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: {
access_token: 'download-access-token',
refresh_token: 'download-refresh-token',
expires_in: 3600,
user: {
id: 1,
username: 'admin',
email: 'admin@example.com',
phone: '13800138000',
nickname: 'Admin',
avatar: '',
status: 1,
},
},
}),
)
.mockResolvedValueOnce(successResponse)
const { download, getRefreshToken, setAccessToken, setRefreshToken } = await loadModules()
setAccessToken('access-token-old', 3600)
setRefreshToken('refresh-token-old')
const blob = await download('/export')
expect(blob).toBe(downloadedBlob)
expect(fetchMock).toHaveBeenCalledTimes(3)
expect(String(fetchMock.mock.calls[1][0])).toContain('/api/v1/auth/refresh')
expect(fetchMock.mock.calls[2][1]?.headers).toMatchObject({
Authorization: 'Bearer download-access-token',
})
expect(getRefreshToken()).toBe('download-refresh-token')
})
it('maps failed downloads to network AppErrors', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(new Response(null, { status: 500 }))
const { ErrorType, download } = await loadModules()
await expect(download('/export', undefined, { auth: false })).rejects.toMatchObject({
status: 0,
type: ErrorType.NETWORK,
})
})
it('clears the local session when a download retry still returns 401', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock
.mockResolvedValueOnce(new Response(null, { status: 401 }))
.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: {
access_token: 'download-access-token',
refresh_token: 'download-refresh-token',
expires_in: 3600,
user: {
id: 1,
username: 'admin',
email: 'admin@example.com',
phone: '13800138000',
nickname: 'Admin',
avatar: '',
status: 1,
},
},
}),
)
.mockResolvedValueOnce(new Response(null, { status: 401 }))
const {
ErrorType,
download,
getAccessToken,
getRefreshToken,
setAccessToken,
setRefreshToken,
} = await loadModules()
setAccessToken('access-token-old', 3600)
setRefreshToken('refresh-token-old')
await expect(download('/export')).rejects.toMatchObject({
status: 401,
type: ErrorType.AUTH,
})
expect(fetchMock).toHaveBeenCalledTimes(3)
expect(getAccessToken()).toBeNull()
expect(getRefreshToken()).toBeNull()
})
it('converts aborted downloads into timeout AppErrors', async () => {
vi.useFakeTimers()
const fetchMock = vi.mocked(fetch)
fetchMock.mockImplementation(
(_url, requestInit) =>
new Promise((_, reject) => {
;(requestInit?.signal as AbortSignal).addEventListener(
'abort',
() => reject(new DOMException('Aborted', 'AbortError')),
{ once: true },
)
}),
)
const { ErrorType, download } = await loadModules()
const downloadPromise = expect(
download('/export', undefined, { auth: false }),
).rejects.toMatchObject({
status: 0,
type: ErrorType.NETWORK,
})
await vi.advanceTimersByTimeAsync(30_000)
await downloadPromise
})
it('builds upload form data with additional fields', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
message: 'ok',
data: { uploaded: true },
}),
)
const { upload } = await loadModules()
const file = new File(['demo'], 'avatar.png', { type: 'image/png' })
const result = await upload(
'/upload',
file,
'asset',
{ folder: 'avatars' },
{ auth: false },
)
expect(result).toEqual({ uploaded: true })
expect(fetchMock).toHaveBeenCalledTimes(1)
const requestInit = fetchMock.mock.calls[0][1]
const body = requestInit?.body as FormData
expect(requestInit?.method).toBe('POST')
expect(body.get('folder')).toBe('avatars')
expect(body.get('asset')).toBeInstanceOf(File)
expect((body.get('asset') as File).name).toBe('avatar.png')
})
})

View File

@@ -0,0 +1,367 @@
import { config } from '@/lib/config'
import { AppError, ErrorType } from '@/lib/errors'
import type { ApiResponse, RequestOptions } from '@/types'
import {
clearRefreshPromise,
clearSession,
endRefreshing,
getAccessToken,
getRefreshPromise,
isAccessTokenExpired,
isRefreshing,
setAccessToken,
setRefreshPromise,
startRefreshing,
} from './auth-session'
import { clearRefreshToken, getRefreshToken, setRefreshToken } from '../storage'
import { CSRF_PROTECTED_METHODS, getCSRFHeaders } from './csrf'
import type { TokenBundle } from '@/types'
const DEFAULT_TIMEOUT = 30_000
function isFormDataBody(body: unknown): body is FormData {
return typeof FormData !== 'undefined' && body instanceof FormData
}
function serializeBody(body: unknown): BodyInit | undefined {
if (body === undefined || body === null) {
return undefined
}
if (isFormDataBody(body)) {
return body
}
return JSON.stringify(body)
}
function resolveApiBaseUrl(): URL {
const origin = typeof window !== 'undefined' ? window.location.origin : 'http://localhost'
const rawBaseUrl = /^https?:\/\//i.test(config.apiBaseUrl)
? config.apiBaseUrl
: config.apiBaseUrl.startsWith('/')
? config.apiBaseUrl
: `/${config.apiBaseUrl}`
const baseUrl = new URL(rawBaseUrl, origin)
if (!baseUrl.pathname.endsWith('/')) {
baseUrl.pathname = `${baseUrl.pathname}/`
}
return baseUrl
}
function buildUrl(path: string, params?: Record<string, string | number | boolean | undefined>): string {
const url = new URL(path.replace(/^\/+/, ''), resolveApiBaseUrl())
if (params) {
for (const [key, value] of Object.entries(params)) {
if (value !== undefined) {
url.searchParams.append(key, String(value))
}
}
}
return url.toString()
}
function cleanupSessionOnAuthFailure(): never {
clearRefreshToken()
clearSession()
throw AppError.auth('会话已过期,请重新登录')
}
function createTimeoutSignal(signal?: AbortSignal): { signal: AbortSignal; cleanup: () => void } {
const controller = new AbortController()
const timeoutId = window.setTimeout(() => controller.abort(), DEFAULT_TIMEOUT)
if (signal) {
signal.addEventListener('abort', () => controller.abort(), { once: true })
}
return {
signal: controller.signal,
cleanup: () => window.clearTimeout(timeoutId),
}
}
async function parseJsonResponse<T>(response: Response): Promise<ApiResponse<T>> {
return response.json() as Promise<ApiResponse<T>>
}
async function refreshAccessToken(): Promise<TokenBundle> {
const refreshToken = getRefreshToken()
const body = refreshToken ? JSON.stringify({ refresh_token: refreshToken }) : undefined
const response = await fetch(buildUrl('/auth/refresh'), {
method: 'POST',
credentials: 'include',
headers: body ? { 'Content-Type': 'application/json' } : undefined,
body,
})
if (!response.ok) {
return cleanupSessionOnAuthFailure()
}
const result = await parseJsonResponse<TokenBundle>(response)
if (result.code !== 0) {
return cleanupSessionOnAuthFailure()
}
return result.data
}
async function performRefresh(): Promise<string> {
if (isRefreshing()) {
const promise = getRefreshPromise()
if (promise) {
await promise
}
const token = getAccessToken()
if (!token) {
return cleanupSessionOnAuthFailure()
}
return token
}
startRefreshing()
const promise = (async () => {
try {
const tokenBundle = await refreshAccessToken()
setAccessToken(tokenBundle.access_token, tokenBundle.expires_in)
setRefreshToken(tokenBundle.refresh_token)
return tokenBundle.access_token
} finally {
endRefreshing()
clearRefreshPromise()
}
})()
setRefreshPromise(
promise.then(
() => undefined,
() => undefined,
),
)
return promise
}
async function resolveAuthorizationHeader(auth: boolean): Promise<string | null> {
if (!auth) {
return null
}
let token = getAccessToken()
if (isRefreshing()) {
const promise = getRefreshPromise()
if (promise) {
await promise
token = getAccessToken()
}
}
if (token && isAccessTokenExpired()) {
token = await performRefresh()
}
return token
}
async function request<T>(path: string, options: RequestOptions = {}): Promise<T> {
const {
method = 'GET',
headers = {},
body,
params,
auth = true,
credentials = 'include',
signal,
} = options
const url = buildUrl(path, params)
const requestHeaders: Record<string, string> = { ...headers }
if (body !== undefined && body !== null && !isFormDataBody(body) && !requestHeaders['Content-Type']) {
requestHeaders['Content-Type'] = 'application/json'
}
if (CSRF_PROTECTED_METHODS.includes(method)) {
Object.assign(requestHeaders, getCSRFHeaders())
}
const authToken = await resolveAuthorizationHeader(auth)
if (authToken) {
requestHeaders.Authorization = `Bearer ${authToken}`
}
const timeout = createTimeoutSignal(signal)
try {
let response = await fetch(url, {
method,
headers: requestHeaders,
body: serializeBody(body),
credentials,
signal: timeout.signal,
})
if (response.status === 401 && auth) {
const refreshedToken = await performRefresh()
requestHeaders.Authorization = `Bearer ${refreshedToken}`
response = await fetch(url, {
method,
headers: requestHeaders,
body: serializeBody(body),
credentials,
signal: timeout.signal,
})
}
if (response.status === 401) {
return cleanupSessionOnAuthFailure()
}
if (!response.ok) {
if (response.status === 403) {
throw AppError.forbidden()
}
if (response.status === 404) {
throw new AppError(404, '请求的资源不存在', {
status: 404,
type: ErrorType.NOT_FOUND,
})
}
throw AppError.network(`请求失败: ${response.status}`)
}
const result = await parseJsonResponse<T>(response)
if (result.code !== 0) {
throw AppError.fromResponse(result, response.status)
}
return result.data
} catch (error) {
if (error instanceof DOMException && error.name === 'AbortError') {
throw AppError.network('请求超时,请稍后重试')
}
throw error
} finally {
timeout.cleanup()
}
}
export function get<T>(
path: string,
params?: Record<string, string | number | boolean | undefined>,
options?: Omit<RequestOptions, 'method' | 'params' | 'body'>,
): Promise<T> {
return request<T>(path, { ...options, method: 'GET', params })
}
export function post<T>(
path: string,
body?: unknown,
options?: Omit<RequestOptions, 'method' | 'body'>,
): Promise<T> {
return request<T>(path, { ...options, method: 'POST', body })
}
export function put<T>(
path: string,
body?: unknown,
options?: Omit<RequestOptions, 'method' | 'body'>,
): Promise<T> {
return request<T>(path, { ...options, method: 'PUT', body })
}
export function del<T>(
path: string,
options?: Omit<RequestOptions, 'method'>,
): Promise<T> {
return request<T>(path, { ...options, method: 'DELETE' })
}
async function resolveAuthorizedHeaders(options?: Omit<RequestOptions, 'method' | 'params' | 'body'>): Promise<Record<string, string>> {
const headers: Record<string, string> = { ...(options?.headers ?? {}) }
if (options?.auth !== false) {
const token = await resolveAuthorizationHeader(true)
if (token) {
headers.Authorization = `Bearer ${token}`
}
}
return headers
}
export async function download(
path: string,
params?: Record<string, string | number | boolean | undefined>,
options?: Omit<RequestOptions, 'method' | 'params'>,
): Promise<Blob> {
const url = buildUrl(path, params)
const headers = await resolveAuthorizedHeaders(options)
const timeout = createTimeoutSignal(options?.signal)
try {
let response = await fetch(url, {
headers,
credentials: options?.credentials ?? 'include',
signal: timeout.signal,
})
if (response.status === 401 && options?.auth !== false) {
const refreshedToken = await performRefresh()
headers.Authorization = `Bearer ${refreshedToken}`
response = await fetch(url, {
headers,
credentials: options?.credentials ?? 'include',
signal: timeout.signal,
})
}
if (response.status === 401) {
return cleanupSessionOnAuthFailure()
}
if (!response.ok) {
throw AppError.network(`下载失败: ${response.status}`)
}
return response.blob()
} catch (error) {
if (error instanceof DOMException && error.name === 'AbortError') {
throw AppError.network('下载超时,请稍后重试')
}
throw error
} finally {
timeout.cleanup()
}
}
export async function upload<T>(
path: string,
file: File,
fieldName: string = 'file',
additionalData?: Record<string, string>,
options?: Omit<RequestOptions, 'method' | 'body'>,
): Promise<T> {
const formData = new FormData()
formData.append(fieldName, file)
if (additionalData) {
for (const [key, value] of Object.entries(additionalData)) {
formData.append(key, value)
}
}
return request<T>(path, {
...options,
method: 'POST',
body: formData,
})
}
export { request }

View File

@@ -0,0 +1,192 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
function jsonResponse(data: unknown, init: ResponseInit = {}) {
return new Response(JSON.stringify(data), {
status: 200,
headers: {
'Content-Type': 'application/json',
},
...init,
})
}
async function loadCsrfModule() {
vi.resetModules()
return import('./csrf')
}
function clearCsrfCookie() {
if (typeof document === 'undefined') {
return
}
document.cookie = 'csrftoken=; expires=Thu, 01 Jan 1970 00:00:00 GMT; path=/'
}
describe('csrf helpers', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.unstubAllGlobals()
vi.unstubAllEnvs()
clearCsrfCookie()
vi.stubGlobal('fetch', vi.fn())
})
afterEach(() => {
clearCsrfCookie()
vi.restoreAllMocks()
vi.unstubAllGlobals()
vi.unstubAllEnvs()
})
it('returns null when cookie lookup runs without a document', async () => {
vi.stubGlobal('document', undefined)
const { getCSRFTokenFromCookie, getCSRFHeaders } = await loadCsrfModule()
expect(getCSRFTokenFromCookie()).toBeNull()
expect(getCSRFHeaders()).toEqual({})
})
it('stores csrf tokens in memory and falls back to the cookie for headers', async () => {
const {
CSRF_HEADER_NAME,
clearCSRFToken,
getCSRFHeaders,
getCSRFToken,
setCSRFToken,
} = await loadCsrfModule()
setCSRFToken('memory-token')
expect(getCSRFToken()).toBe('memory-token')
expect(getCSRFHeaders()).toEqual({
[CSRF_HEADER_NAME]: 'memory-token',
})
clearCSRFToken()
document.cookie = 'csrftoken=cookie-token; path=/'
expect(getCSRFToken()).toBeNull()
expect(getCSRFHeaders()).toEqual({
[CSRF_HEADER_NAME]: 'cookie-token',
})
})
it('prefers an existing csrf cookie and skips the network bootstrap', async () => {
const fetchMock = vi.mocked(fetch)
document.cookie = 'csrftoken=cookie-token; path=/'
const { getCSRFToken, initCSRFToken } = await loadCsrfModule()
const token = await initCSRFToken()
expect(token).toBe('cookie-token')
expect(getCSRFToken()).toBe('cookie-token')
expect(fetchMock).not.toHaveBeenCalled()
})
it('fetches and stores a csrf token from the default relative api base', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
data: {
csrf_token: 'api-token',
},
}),
)
const { getCSRFToken, initCSRFToken } = await loadCsrfModule()
const token = await initCSRFToken()
expect(token).toBe('api-token')
expect(getCSRFToken()).toBe('api-token')
expect(fetchMock).toHaveBeenCalledWith(
`${window.location.origin}/api/v1/auth/csrf-token`,
{
method: 'GET',
credentials: 'include',
headers: {
'Content-Type': 'application/json',
},
},
)
})
it('supports api base urls without a leading slash', async () => {
vi.stubEnv('VITE_API_BASE_URL', 'api/custom')
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
data: {
csrf_token: 'custom-token',
},
}),
)
const { initCSRFToken } = await loadCsrfModule()
await initCSRFToken()
expect(fetchMock).toHaveBeenCalledWith(
`${window.location.origin}/api/custom/auth/csrf-token`,
expect.any(Object),
)
})
it('supports absolute api base urls', async () => {
vi.stubEnv('VITE_API_BASE_URL', 'https://api.example.com/base')
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 0,
data: {
csrf_token: 'absolute-token',
},
}),
)
const { initCSRFToken } = await loadCsrfModule()
await initCSRFToken()
expect(fetchMock).toHaveBeenCalledWith(
'https://api.example.com/base/auth/csrf-token',
expect.any(Object),
)
})
it('falls back to a cookie exposed after the csrf bootstrap request fails', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockRejectedValueOnce(new Error('network failed'))
const cookieSpy = vi
.spyOn(document, 'cookie', 'get')
.mockReturnValueOnce('')
.mockReturnValueOnce('csrftoken=fallback-token')
const { getCSRFToken, initCSRFToken } = await loadCsrfModule()
const token = await initCSRFToken()
expect(token).toBe('fallback-token')
expect(getCSRFToken()).toBe('fallback-token')
expect(fetchMock).toHaveBeenCalledTimes(1)
cookieSpy.mockRestore()
})
it('returns null when the bootstrap response does not contain a csrf token', async () => {
const fetchMock = vi.mocked(fetch)
fetchMock.mockResolvedValueOnce(
jsonResponse({
code: 1,
data: {},
}),
)
const { getCSRFHeaders, getCSRFToken, initCSRFToken } = await loadCsrfModule()
const token = await initCSRFToken()
expect(token).toBeNull()
expect(getCSRFToken()).toBeNull()
expect(getCSRFHeaders()).toEqual({})
})
})

View File

@@ -0,0 +1,145 @@
/**
* CSRF Token 管理
*
* CSRF 保护机制:
* 1. GET 请求获取 CSRF Token从 cookie 或 API
* 2. POST/PUT/DELETE 请求将 Token 添加到 X-CSRF-Token 头
*
* 注意:由于使用 Bearer Token 认证(存储在内存中),
* CSRF 风险相对较低,但为增强安全性仍建议对关键操作启用。
*/
// 注意:避免从 './client' 导入,防止循环依赖
// 使用原生 fetch 获取 CSRF Token
import { config } from '@/lib/config'
// CSRF Token 存储
let csrfToken: string | null = null
/**
* 获取 CSRF Token
*/
export function getCSRFToken(): string | null {
return csrfToken
}
/**
* 设置 CSRF Token
*/
export function setCSRFToken(token: string): void {
csrfToken = token
}
/**
* 从 cookie 中读取 CSRF Token
* Django/Laravel 等框架通常在 cookie 中设置 csrftoken
*/
export function getCSRFTokenFromCookie(): string | null {
if (typeof document === 'undefined') {
return null
}
const match = document.cookie.match(/csrftoken=([^;]+)/)
return match ? match[1] : null
}
/**
* 解析 API 基础 URL
* 注意:此函数复制自 client.ts 以避免循环依赖
*/
function resolveApiBaseUrl(): URL {
const origin = typeof window !== 'undefined' ? window.location.origin : 'http://localhost'
const rawBaseUrl = /^https?:\/\//i.test(config.apiBaseUrl)
? config.apiBaseUrl
: config.apiBaseUrl.startsWith('/')
? config.apiBaseUrl
: `/${config.apiBaseUrl}`
const baseUrl = new URL(rawBaseUrl, origin)
if (!baseUrl.pathname.endsWith('/')) {
baseUrl.pathname = `${baseUrl.pathname}/`
}
return baseUrl
}
/**
* 构建完整 URL
*/
function buildUrl(path: string): string {
const normalizedPath = path.replace(/^\/+/, '')
const url = new URL(normalizedPath, resolveApiBaseUrl())
return url.toString()
}
/**
* 初始化 CSRF Token
* 从 cookie 或 API 获取 Token 并存储
*/
export async function initCSRFToken(): Promise<string | null> {
// 优先从 cookie 获取
let token = getCSRFTokenFromCookie()
if (!token) {
try {
// 使用原生 fetch 避免循环依赖
const response = await fetch(buildUrl('/auth/csrf-token'), {
method: 'GET',
credentials: 'include',
headers: {
'Content-Type': 'application/json',
},
})
if (response.ok) {
const result = await response.json()
// 后端返回字段名为 csrf_token
if (result.code === 0 && result.data?.csrf_token) {
token = result.data.csrf_token
}
}
} catch {
// API 不支持,使用 cookie 中的 token如果有
token = getCSRFTokenFromCookie()
}
}
if (token) {
setCSRFToken(token)
}
return token
}
/**
* 清除 CSRF Token登出时调用
*/
export function clearCSRFToken(): void {
csrfToken = null
}
/**
* CSRF Token 头名称
*/
export const CSRF_HEADER_NAME = 'X-CSRF-Token'
/**
* 获取带 CSRF Token 的请求头
* 用于 POST/PUT/DELETE 请求
*/
export function getCSRFHeaders(): Record<string, string> {
const token = csrfToken || getCSRFTokenFromCookie()
if (!token) {
return {}
}
return {
[CSRF_HEADER_NAME]: token
}
}
/**
* 需要 CSRF 保护的方法列表
*/
export const CSRF_PROTECTED_METHODS = ['POST', 'PUT', 'DELETE', 'PATCH']

View File

@@ -0,0 +1,32 @@
export {
get,
post,
put,
del,
download,
upload,
request,
} from './client'
export {
getAccessToken,
setAccessToken,
clearAccessToken,
isAccessTokenExpired,
getCurrentUser,
setCurrentUser,
getCurrentRoles,
setCurrentRoles,
isAdmin,
getRoleCodes,
isAuthenticated,
clearSession,
isRefreshing,
startRefreshing,
endRefreshing,
getRefreshPromise,
setRefreshPromise,
clearRefreshPromise,
} from './auth-session'
export { AppError, ErrorType, isAppError } from '@/lib/errors'

View File

@@ -0,0 +1,4 @@
export * from './config'
export * from './errors'
export * from './http'
export * from './storage'

View File

@@ -0,0 +1,7 @@
export {
getRefreshToken,
setRefreshToken,
clearRefreshToken,
hasRefreshToken,
hasSessionPresenceCookie,
} from './token-storage'

View File

@@ -0,0 +1,68 @@
import { afterEach, describe, expect, it, vi } from 'vitest'
import {
clearRefreshToken,
getRefreshToken,
hasRefreshToken,
hasSessionPresenceCookie,
setRefreshToken,
} from './token-storage'
const originalDocument = globalThis.document
describe('token-storage', () => {
afterEach(() => {
clearRefreshToken()
vi.restoreAllMocks()
Object.defineProperty(globalThis, 'document', {
configurable: true,
value: originalDocument,
})
})
it('stores refresh tokens in memory and normalizes empty values to null', () => {
setRefreshToken(' refresh-token ')
expect(getRefreshToken()).toBe('refresh-token')
expect(hasRefreshToken()).toBe(true)
setRefreshToken(' ')
expect(getRefreshToken()).toBeNull()
expect(hasRefreshToken()).toBe(false)
setRefreshToken(undefined)
expect(getRefreshToken()).toBeNull()
})
it('clears the in-memory refresh token explicitly', () => {
setRefreshToken('token-to-clear')
expect(hasRefreshToken()).toBe(true)
clearRefreshToken()
expect(getRefreshToken()).toBeNull()
expect(hasRefreshToken()).toBe(false)
})
it('detects the session presence cookie when it is present among other cookies', () => {
vi.spyOn(document, 'cookie', 'get').mockReturnValue('foo=bar; ums_session_present=1; theme=dark')
expect(hasSessionPresenceCookie()).toBe(true)
})
it('returns false when the session presence cookie is absent', () => {
vi.spyOn(document, 'cookie', 'get').mockReturnValue('foo=bar; theme=dark')
expect(hasSessionPresenceCookie()).toBe(false)
})
it('returns false when document is unavailable', () => {
Object.defineProperty(globalThis, 'document', {
configurable: true,
value: undefined,
})
expect(hasSessionPresenceCookie()).toBe(false)
})
})

View File

@@ -0,0 +1,38 @@
/**
* In-memory refresh token storage.
*
* The authoritative session continuity mechanism is now the backend-managed
* HttpOnly refresh cookie. This module only keeps a process-local copy so the
* current tab can still send an explicit logout payload when available.
*/
let refreshToken: string | null = null
const SESSION_PRESENCE_COOKIE_NAME = 'ums_session_present'
export function getRefreshToken(): string | null {
return refreshToken
}
export function setRefreshToken(token: string | null | undefined): void {
const value = (token || '').trim()
refreshToken = value || null
}
export function clearRefreshToken(): void {
refreshToken = null
}
export function hasRefreshToken(): boolean {
return refreshToken !== null
}
export function hasSessionPresenceCookie(): boolean {
if (typeof document === 'undefined') {
return false
}
return document.cookie
.split(';')
.map((cookie) => cookie.trim())
.some((cookie) => cookie.startsWith(`${SESSION_PRESENCE_COOKIE_NAME}=`))
}