fix: v6 code review P0 auth/IDOR fixes + frontend regression patches

Backend fixes:
- auth_handler: P0 认证逻辑修复
- ratelimit: 限速中间件增强 + 新增单元测试
- auth_service: 认证服务逻辑完善 + 新增测试
- server: server 配置增强 + 新增测试
- handler_test: 新增 handler 层集成测试
- auth_bootstrap_test: bootstrap 路径测试

Frontend patches:
- LoginPage/RegisterPage: CSRF + 表单交互修复
- BootstrapAdminPage: 引导流程修复
- DevicesPage: 设备管理页修复
- auth/social-accounts/users/webhooks services: 类型修正
- csrf.ts: CSRF token 处理修正
- E2E 脚本: CDP smoke + auth e2e 增强

Docs:
- FULL_CODE_REVIEW_REPORT_2026-04-20
- report-v6 执行计划
- REAL_PROJECT_STATUS 更新
- .gitignore: 新增 .gocache-*/config.yaml 排除

验证: go build/vet 0错误, go test 42/42 PASS, 0 FAIL
This commit is contained in:
2026-04-23 07:14:12 +08:00
parent 82109ec216
commit 3f3bb82f1d
41 changed files with 2681 additions and 283 deletions

2
.gitignore vendored
View File

@@ -43,6 +43,7 @@ logs/*.log
.cache/ .cache/
.tmp/ .tmp/
.gocache/ .gocache/
.gocache-*/
.gomodcache/ .gomodcache/
frontend/admin/.cache/ frontend/admin/.cache/
frontend/admin/playwright-report/ frontend/admin/playwright-report/
@@ -54,6 +55,7 @@ Thumbs.db
# Environment # Environment
.env .env
.env.local .env.local
config.yaml
# Node modules # Node modules
node_modules/ node_modules/

View File

@@ -0,0 +1,340 @@
# UMS 项目全面代码复核报告 v6.0
**报告日期**: 2026-04-20
**审查范围**: 当前 `main` 工作区全部实现代码、旧报告未闭环问题、自动化门禁、系统化静态审查结果
**基线说明**: 本报告按日期拆分,作为 [FULL_CODE_REVIEW_REPORT_2026-04-17.md](./FULL_CODE_REVIEW_REPORT_2026-04-17.md) 的后续复核报告。凡与旧报告或旧附录冲突之处,以本报告基于 2026-04-20 新鲜命令证据和当前代码实现得到的结论为准。
---
## 一句话结论
项目在 2026-04-17 报告中的多数首轮 P0 缺陷已经被修复,但当前代码仍存在新的认证与授权断层,且旧报告中的一部分未修复问题仍未真正闭环。当前状态不适合宣称“全部问题已修完”或“可直接上线”。
---
## 2026-04-20 新鲜验证证据
| 项目 | 命令 | 结果 | 说明 |
|---|---|---|---|
| 后端构建 | `go build ./cmd/server` | PASS | 2026-04-20 23:07:51 +08:00 实跑通过 |
| 后端静态检查 | `go vet ./...` | PASS | 实跑通过 |
| 后端测试 | `go test ./... -count=1` | PASS | 全量通过,`internal/service` 仍是主要耗时段 |
| 前端 Lint | `cd frontend/admin && npm.cmd run lint` | PASS | 与 2026-04-18 红灯状态相比已恢复 |
| 前端构建 | `cd frontend/admin && npm.cmd run build` | PASS | 实跑通过 |
| 系统化静态检查 | `staticcheck ./...` | FAIL | 发现测试代码 `nil context`、潜在空指针、死代码等问题 |
| 安全静态检查 | `gosec ./internal/... ./cmd/...` | FAIL | 有真实问题,也有大量误报/高噪音结果,需要人工过滤 |
---
## 当前阻塞级问题
### P0-01: `TOTP` 二次验证链路缺少首因子绑定,形成独立登录入口
**位置**
- `internal/api/handler/auth_handler.go:151`
- `internal/service/auth.go:125`
- `internal/service/auth.go:811`
**问题**
- `/api/v1/auth/login/totp-verify` 只要求 `user_id + code + device_id`
- 服务端 `VerifyTOTPAfterPasswordLogin()` 只校验用户状态与 `TOTP` 码,然后直接签发完整 token
- 代码里虽然保留了 `TempToken` 字段,但当前登录闭环并未使用任何临时登录态或 challenge 票据
**影响**
- “密码登录后第二步验证”被降级成“知道用户 ID 且拿到有效 TOTP 即可直接登录”
- 这不是旧 P0-07 的原样复现,但本质上仍然属于 MFA 闭环未正确实现
**结论**
- 旧报告 P0-07 不能标记为“已完全修复”,应迁移为“修复方向已变化,但认证闭环仍未完成”
### P0-02: 设备接口存在成组 `IDOR`
**位置**
- `internal/api/handler/device_handler.go:114`
- `internal/api/handler/device_handler.go:147`
- `internal/api/handler/device_handler.go:183`
- `internal/api/handler/device_handler.go:214`
- `internal/api/handler/device_handler.go:392`
- `internal/api/handler/device_handler.go:474`
- `internal/service/device.go:121`
- `internal/service/device.go:158`
- `internal/service/device.go:163`
- `internal/service/device.go:181`
- `internal/service/device.go:204`
- `internal/service/device.go:236`
**问题**
- `GET/PUT/DELETE /devices/:id`
- `PUT /devices/:id/status`
- `POST/DELETE /devices/:id/trust`
这些接口的 handler 没有 owner/admin 校验service 层也没有按 `user_id` 兜底约束,只按设备主键直接读写删除。
**影响**
- 任意已登录用户只要知道设备 ID就可以读取、修改、删除、信任或取消信任他人设备
**结论**
- 这是本轮新增发现,严重程度等同发布阻塞
### P0-03: 修改密码接口缺少“本人或管理员”授权校验
**位置**
- `internal/api/handler/user_handler.go:275`
- `internal/service/user_service.go:84`
**问题**
- `PUT /api/v1/users/:id/password` 直接使用路径里的 `id`
- handler 没有 self-or-admin 校验
- service 只验证目标用户旧密码是否正确
**影响**
- 普通用户在知道目标用户旧密码时可直接修改目标用户密码
- 管理员也没有单独的安全重置路径,权限模型与接口语义混杂
**结论**
- 这是一条真实的授权缺口,应纳入 P0
### P0-04: 上下文协议漂移导致多处管理员路径失效
**位置**
- `internal/api/middleware/auth.go:90`
- `internal/api/middleware/auth.go:91`
- `internal/api/handler/user_handler.go:191`
- `internal/api/handler/user_handler.go:374`
- `internal/api/handler/avatar_handler.go:74`
**问题**
- 认证中间件当前只写入 `role_codes` / `permission_codes`
- 多个 handler 仍读取旧的 `user_roles`
**影响**
- 管理员跨用户更新资料
- 管理员查看他人角色
- 管理员代传头像
这些路径都会被错误判定为无权限。
**结论**
- 旧 P0-06 已做过一轮修复,但当前实现没有真正闭环,应以“部分修复后回归失效”迁移进新报告
### P0-05: OAuth handler 仍返回“200 假成功”占位响应
**位置**
- `internal/api/handler/auth_handler.go:316`
- `internal/api/handler/auth_handler.go:329`
- `internal/api/handler/auth_handler.go:342`
- `internal/api/handler/auth_handler.go:353`
- `internal/service/auth.go:939`
- `internal/service/auth.go:946`
- `internal/service/auth.go:1492`
**问题**
- handler 仍直接返回 `OAuth not configured` 或空 provider 列表
- service 层实际上已经存在 `OAuthLogin` / `OAuthCallback` / `GetEnabledOAuthProviders` 逻辑
**影响**
- API 层向前端暴露假成功语义
- 与仓库“禁止 fake success / fail closed”的运行时规则冲突
**结论**
- 这不是旧报告中的原编号问题,但属于当前实现真实性问题,应纳入高优先级修复
### P0-06: 游标分页与动态排序的契约仍未真正闭环
**位置**
- `internal/repository/user.go:353`
**问题**
- 当前实现只在 `sortBy == created_at` 时应用游标条件
- 其他排序字段下并不会报错,只是静默忽略游标条件
**影响**
- 前端如果带着非 `created_at` 排序继续请求下一页,得到的不是严格意义上的“下一页”
- 旧报告的“数据错乱”主因已经被收敛,但 API 契约仍然是不闭合的,容易出现重复页或错误分页预期
**结论**
- 旧 P0-08 不应从报告中移除,应以下降风险后的“残留契约缺口”形式迁移
---
## 从旧报告迁移的未闭环问题
下表只迁移“当前仍未真正闭环”的旧问题;已经明确修复完成的问题不再重复记为未完成。
| 旧编号 | 当前状态 | 新报告结论 |
|---|---|---|
| P0-06 UpdateUser IDOR | 部分修复后再次失效 | 迁移为 P0-04上下文协议漂移导致管理员授权逻辑失效 |
| P0-07 Login 绕过 TOTP | 修复方向变化,但未闭环 | 迁移为 P0-01`totp-verify` 未绑定首因子 |
| P0-08 ListCursor / sort | 风险下降但契约未闭合 | 迁移为 P0-06`created_at` 排序下游标被静默忽略 |
| P1-12 ~ P1-14 响应格式不一致 | 仍未修复 | 保留为 P1`auth_handler``password_reset_handler` 等多处仍返回非统一响应格式 |
| P2-12 `/uploads` 直接暴露 | 仍未修复 | 保留为 P2`router.Setup()` 仍静态暴露上传目录 |
---
## 已确认修复完成的旧问题
以下问题在当前代码中已具备明确修复证据,不再迁移为“未修复项”:
| 旧编号 | 当前状态 | 证据 |
|---|---|---|
| P0-01 LIKE 通配/模式注入 | 已修复 | `internal/repository/operation_log.go``internal/repository/device.go``internal/repository/user.go` 已统一使用 `escapeLikePattern()` |
| P0-02 登录失败计数竞态 | 主路径已修复 | `internal/service/auth.go:492` 已改用 `cache.Increment()`;但降级 fallback 仍保留非原子路径,见“残留风险” |
| P0-03 refresh 黑名单 fail-open | 已修复 | `internal/service/auth.go` 中黑名单写入失败已向上返回错误 |
| P0-04 手机重置 replay | 基本修复 | `internal/service/password_reset.go` 在验证码校验通过后先删除 key 再继续流程 |
| P0-05 CORS 默认危险组合 | 已修复 | `internal/api/middleware/cors.go` 默认值已改为空 origins + `AllowCredentials=false` |
| P1-01 错误处理中间件泄露内部错误 | 已修复 | `internal/api/middleware/error.go` 对未知错误返回通用消息 |
| P1-03 导出接口泄露内部错误 | 已修复 | `internal/api/handler/export_handler.go` 已改为通用错误文案 |
| P1-04 CountByResultSince 静默忽略错误 | 已修复 | `internal/repository/login_log.go` 已返回 `(int64, error)` |
| P1-07 Theme SetDefault 非原子 | 已修复 | `internal/repository/theme.go` 已改用事务 |
| P1-08 数据库连接池硬编码 | 已修复 | `internal/database/db.go` 已使用配置参数 |
| P1-15 分页参数无上限 | 大体修复 | `user_handler.go``device_handler.go``log_handler.go` 均已限制 `page_size <= 100` |
---
## 仍需保留的中高优先级问题
### P1-01: API 响应格式仍然不统一
**位置**
- `internal/api/handler/auth_handler.go`
- `internal/api/handler/password_reset_handler.go`
- `internal/api/handler/user_handler.go`
**问题**
- 同一套 API 中同时存在 `{error: ...}``{message: ...}``{code,message,data}` 等多种响应结构
- `Logout``CSRF`、认证错误分支、参数绑定错误分支的格式仍不一致
**影响**
- 前端错误处理成本高
- 自动化契约测试难写
- 文档与真实行为容易继续漂移
### P1-02: 登录失败计数器仍保留非原子降级路径
**位置**
- `internal/service/auth.go:492`
**问题**
- 主路径已使用 `cache.Increment()`
-`Increment` 出错时仍回退到 `Get + current++ + Set`
**影响**
- 在缓存不支持原子递增或运行时出错场景下,旧竞态仍可能重现
**结论**
- 不再按 P0 处理,但仍是必须收尾的 P1
### P1-03: CLI/初始化路径存在权限与类型转换告警
**系统化工具证据**
- `cmd/ums/cmd/init.go:306` `gosec G115`
- `cmd/ums/cmd/init.go:341` `gosec G301`
- `cmd/ums/cmd/init.go:446` `gosec G306`
**人工判断**
- `int(os.Stdin.Fd())` 在 Windows 常见运行路径下不一定形成真实高危,但应改成更明确的受控转换
- 初始化命令写目录/文件权限偏宽,适合作为 P1/P2 收敛项
### P2-01: 上传目录仍被直接公开暴露
**位置**
- `internal/api/router/router.go`
**问题**
- `r.engine.Static("/uploads", "./uploads")` 仍直接公开暴露上传目录
**影响**
- 上传内容默认可被匿名访问
- 一旦上传内容策略控制不足,容易扩大文件暴露面
---
## 系统化工具补充审查
### `staticcheck ./...` 结果摘要
人工过滤后,当前值得保留的信号主要有三类:
1. **测试代码错误用法**
- `internal/api/handler/captcha_handler_test.go`
- `internal/service/auth_capabilities_test.go`
存在 `SA1012`,测试里向需要 `context.Context` 的调用传了 `nil`
2. **测试代码潜在空指针**
- `internal/service/sms_provider_test.go`
- `internal/service/user_roles_test.go`
- `internal/service/webhook_service_test.go`
存在 `SA5011`,说明部分测试断言路径缺少空值保护。
3. **仓库内死代码/遗留辅助代码**
- `internal/api/middleware/auth.go`
- `internal/monitoring/slo.go`
- `internal/repository/sql_scan.go`
- `internal/repository/pagination.go`
存在 `U1000`,说明最近几轮修复后有未清理的遗留函数或字段。
### `gosec ./internal/... ./cmd/...` 结果摘要
`gosec` 本轮输出噪音较大,尤其把 OAuth URL、常量名、header 名、token URL 大量误判为“硬编码凭证”。人工过滤后,建议保留的结果如下:
1. **真实可收敛问题**
- `internal/api/handler/avatar_handler.go:147` `G301`
- `internal/api/handler/avatar_handler.go:159` `G306`
- `internal/service/password_reset.go:237`
- `internal/service/password_reset.go:252`
前者是目录/文件权限偏宽,后者是关键删除操作忽略返回错误。
2. **低风险但建议修整**
- `internal/service/captcha.go:164` `G404`
这里使用 `math/rand` 仅用于验证码图片背景色随机化,不直接影响验证码秘密值,但可以考虑改为更明确的非安全随机用途注释,或避免被安全扫描反复报警。
3. **高噪音误报,不建议直接据此立项**
- OAuth token URL / auth URL
- Header 名称
- 非凭证字符串常量
这些不应直接写进缺陷列表,否则会污染修复优先级。
---
## 当前建议修复顺序
### 第一批:立即处理
1. 修复 `totp-verify` 登录闭环,要求必须携带首因子验证后的临时态
2. 为设备接口补全 owner/admin 校验,并在 service 层增加按 `user_id` 的兜底约束
3.`/users/:id/password` 增加 self-or-admin 授权,并区分“本人修改密码”和“管理员重置密码”语义
4. 统一 handler 上下文字段,彻底移除 `user_roles` 旧协议
5. 去掉 OAuth handler 的假成功返回,改成真实能力分发或显式 fail closed
### 第二批:本周内收口
1. 统一 API 响应结构
2. 清理登录失败计数器 fallback 竞态
3. 清理 `staticcheck` 暴露的测试错误与死代码
4. 收敛 `gosec` 中目录/文件权限与关键错误忽略问题
---
## 对旧报告的处理建议
1. 保留旧报告作为历史记录,不删除
2. 明确以本报告作为后续复核基线
3. 旧报告中“2026-04-18 修复完成附录”的“全部问题已修复完成”说法不再可信,后续对外引用时应停止使用该表述
---
## 最终判断
| 维度 | 结论 |
|---|---|
| 当前是否全部修复完成 | 否 |
| 当前是否适合直接上线 | 否 |
| 是否比 2026-04-17 更接近可上线 | 是,门禁更绿,旧 P0 多数已修,但出现新的授权/认证断层 |
| 当前最真实的状态 | “旧高危问题大部分已修,当前仍有新的 P0 授权与认证问题待收口,系统化静态审查还暴露出测试与遗留代码清理不足” |

View File

@@ -0,0 +1,89 @@
# Report v6 Blocking Fixes Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** 修复 `FULL_CODE_REVIEW_REPORT_2026-04-20.md` 中当前阻塞上线的认证、授权和假成功问题,并为每项修复补齐回归验证。
**Architecture:** 以后端授权和认证闭环为主,优先通过测试锁定期望行为,再做最小实现修改。每个批次修复后运行受影响测试集,最后跑完整后端/前端门禁。
**Tech Stack:** Go, Gin, GORM, React, Vitest, PowerShell, Git
---
### Task 1: 锁定 TOTP 二阶段登录闭环
**Files:**
- Modify: `internal/service/auth.go`
- Modify: `internal/api/handler/auth_handler.go`
- Modify: `frontend/admin/src/services/auth.ts`
- Modify: `frontend/admin/src/types/auth.ts`
- Test: `internal/service/auth_social_test.go`
- Test: `internal/api/handler/auth_handler_test.go`
- [ ] **Step 1: 写服务层失败测试**
- [ ] **Step 2: 运行服务层测试确认当前允许无首因子直接换 token**
- [ ] **Step 3: 实现临时登录态或 challenge 约束**
- [ ] **Step 4: 写 handler/前端契约测试**
- [ ] **Step 5: 运行受影响测试并确认通过**
### Task 2: 修复设备接口 IDOR
**Files:**
- Modify: `internal/api/handler/device_handler.go`
- Modify: `internal/service/device.go`
- Test: `internal/api/handler/device_handler_test.go`
- Test: `internal/service/device_service_test.go`
- [ ] **Step 1: 写失败测试覆盖跨用户读取/修改/删除/信任设备**
- [ ] **Step 2: 运行测试确认当前越权成立**
- [ ] **Step 3: 在 handler 和 service 层补 owner/admin 双层校验**
- [ ] **Step 4: 运行受影响测试并确认通过**
### Task 3: 修复修改密码接口授权模型
**Files:**
- Modify: `internal/api/handler/user_handler.go`
- Modify: `internal/service/user_service.go`
- Test: `internal/api/handler/user_handler_test.go`
- [ ] **Step 1: 写失败测试覆盖非本人访问 `/users/:id/password`**
- [ ] **Step 2: 运行测试确认当前缺口存在**
- [ ] **Step 3: 增加 self-or-admin 校验并明确管理员重置策略**
- [ ] **Step 4: 运行受影响测试并确认通过**
### Task 4: 清理 `user_roles` 到 `role_codes` 协议漂移
**Files:**
- Modify: `internal/api/handler/user_handler.go`
- Modify: `internal/api/handler/avatar_handler.go`
- Test: `internal/api/handler/user_handler_test.go`
- Test: `internal/api/handler/avatar_handler_test.go`
- [ ] **Step 1: 写失败测试覆盖管理员跨用户操作被误拒绝**
- [ ] **Step 2: 运行测试确认当前回归存在**
- [ ] **Step 3: 统一读取 `role_codes` 或复用 RBAC helper**
- [ ] **Step 4: 运行受影响测试并确认通过**
### Task 5: 去掉 OAuth 假成功响应
**Files:**
- Modify: `internal/api/handler/auth_handler.go`
- Test: `internal/api/handler/auth_handler_test.go`
- [ ] **Step 1: 写失败测试覆盖 OAuth provider 列表与入口行为**
- [ ] **Step 2: 运行测试确认 handler 当前没有调用 service**
- [ ] **Step 3: 改成真实 service 分发或显式错误返回**
- [ ] **Step 4: 运行受影响测试并确认通过**
### Task 6: 全量回归与提交流程
**Files:**
- Modify: `docs/code-review/FULL_CODE_REVIEW_REPORT_2026-04-20.md`
- Modify: `docs/status/REAL_PROJECT_STATUS.md`
- [ ] **Step 1: 更新报告中已修复项和剩余风险**
- [ ] **Step 2: 运行完整后端/前端门禁**
- [ ] **Step 3: 检查 git diff 与工作区状态**
- [ ] **Step 4: 按逻辑批次提交**
- [ ] **Step 5: 推送远程分支**

View File

@@ -1,5 +1,49 @@
# REAL PROJECT STATUS # REAL PROJECT STATUS
## 2026-04-23 E2E Recovery Update
### Latest Verification Snapshot
| Command | Result | Note |
|------|------|------|
| `cd frontend/admin && npm.cmd run test:run -- src/pages/admin/DevicesPage/DevicesPage.test.tsx` | `PASS` | cursor pagination no longer auto-advances and flood-loads `/admin/devices` |
| `cd frontend/admin && npm.cmd run test:run -- src/services/webhooks.test.ts` | `PASS` | webhook list and deliveries decoding now matches backend envelopes |
| `cd frontend/admin && npm.cmd run test:run -- src/pages/admin/WebhooksPage/WebhooksPage.test.tsx` | `PASS` | webhook management page still works after service fix |
| `cd frontend/admin && npm.cmd run test:run -- src/services/social-accounts.test.ts` | `PASS` | social accounts decoding now matches backend `accounts` payload |
| `cd frontend/admin && npm.cmd run lint` | `PASS` | frontend lint is green after the recovery changes |
| `cd frontend/admin && npm.cmd run build` | `PASS` | frontend production build is green after the recovery changes |
| `cd frontend/admin && npm.cmd run e2e:full:win` | `PASS` | supported browser-level Playwright CDP E2E path re-ran green in the current workspace |
### Current Honest Status
- The supported browser-level real E2E command `cd frontend/admin && npm.cmd run e2e:full:win` is green again in the current workspace.
- The re-verified scenarios now include:
- `admin-bootstrap`
- `public-registration`
- `email-activation`
- `login-surface`
- `auth-workflow`
- `responsive-login`
- `desktop-mobile-navigation`
- `user-management-crud`
- `role-management-crud`
- `device-management`
- `login-logs`
- `operation-logs`
- `webhook-management`
- `profile-and-security`
- `dashboard-stats`
- The concrete defects fixed in this round were:
- `DevicesPage` cursor state was auto-chaining next-page fetches and could drive `/api/v1/admin/devices` into `429`.
- webhook frontend services were decoding `/webhooks` and `/webhooks/:id/deliveries` with the wrong response shape.
- social account frontend service was decoding `/users/me/social-accounts` with the wrong response shape.
- the Playwright CDP suite had multiple over-broad locators and stale route/title assumptions in the later admin scenarios.
### Boundary
- This update re-proves the supported browser-level E2E path only.
- It does **not** by itself re-prove full backend `go test ./... -count=1`, real third-party OAuth live verification, or complete OS-level automation closure.
## 2026-04-10 复核更新TDD修复后 ## 2026-04-10 复核更新TDD修复后
本节记录 2026-04-10 TDD修复后的最新状态。 本节记录 2026-04-10 TDD修复后的最新状态。

View File

@@ -0,0 +1,28 @@
import process from 'node:process'
import { chromium } from '@playwright/test'
const cdpBaseUrl = (process.env.E2E_PLAYWRIGHT_CDP_URL ?? process.env.E2E_CDP_BASE_URL ?? '').trim()
if (!cdpBaseUrl) {
throw new Error('E2E_PLAYWRIGHT_CDP_URL or E2E_CDP_BASE_URL is required')
}
console.log(`PROBE cdp=${cdpBaseUrl}`)
if (process.env.PROBE_PRECREATE_TARGET === '1') {
console.log('PROBE precreate-target=start')
await fetch(`${cdpBaseUrl}/json/new?about:blank`, { method: 'PUT' }).catch(async () => {
await fetch(`${cdpBaseUrl}/json/new?about:blank`)
})
console.log('PROBE precreate-target=done')
}
const browser = await chromium.connectOverCDP(cdpBaseUrl)
console.log(`PROBE connected contexts=${browser.contexts().length}`)
for (const [index, context] of browser.contexts().entries()) {
console.log(`PROBE context[${index}] pages=${context.pages().length}`)
}
await browser.close()
console.log('PROBE done')

View File

@@ -383,6 +383,7 @@ try {
Write-Host "Launching command: $commandName $($commandArgs -join ' ')" Write-Host "Launching command: $commandName $($commandArgs -join ' ')"
& $commandName @commandArgs & $commandName @commandArgs
if ($LASTEXITCODE -ne 0) { if ($LASTEXITCODE -ne 0) {
Show-BrowserLogs $browserHandle
throw "command failed with exit code $LASTEXITCODE" throw "command failed with exit code $LASTEXITCODE"
} }
} finally { } finally {

View File

@@ -9,19 +9,58 @@ param(
$ErrorActionPreference = 'Stop' $ErrorActionPreference = 'Stop'
$projectRoot = (Resolve-Path (Join-Path $PSScriptRoot '..\..\..')).Path function Resolve-E2ERoots {
$frontendRoot = (Resolve-Path (Join-Path $PSScriptRoot '..')).Path $scriptFrontendRoot = Resolve-Path (Join-Path $PSScriptRoot '..') -ErrorAction SilentlyContinue
$tempCacheRoot = Join-Path $env:TEMP 'ums-e2e-cache' $scriptProjectRoot = Resolve-Path (Join-Path $PSScriptRoot '..\..\..') -ErrorAction SilentlyContinue
$goCacheDir = Join-Path $tempCacheRoot 'go-build' $cwdFrontendRoot = Resolve-Path (Get-Location).Path
$goModCacheDir = Join-Path $tempCacheRoot 'gomod' $cwdProjectRoot = Resolve-Path (Join-Path $cwdFrontendRoot '..\..') -ErrorAction SilentlyContinue
$goPathDir = Join-Path $tempCacheRoot 'gopath'
if (
$scriptFrontendRoot -and
$scriptProjectRoot -and
(Test-Path (Join-Path $scriptFrontendRoot 'package.json')) -and
(Test-Path (Join-Path $scriptProjectRoot 'go.mod'))
) {
return [pscustomobject]@{
FrontendRoot = $scriptFrontendRoot.Path
ProjectRoot = $scriptProjectRoot.Path
}
}
if (
$cwdProjectRoot -and
(Test-Path (Join-Path $cwdFrontendRoot 'package.json')) -and
(Test-Path (Join-Path $cwdProjectRoot 'go.mod'))
) {
return [pscustomobject]@{
FrontendRoot = $cwdFrontendRoot
ProjectRoot = $cwdProjectRoot.Path
}
}
throw 'failed to resolve frontend/project roots for playwright e2e'
}
$resolvedRoots = Resolve-E2ERoots
$projectRoot = $resolvedRoots.ProjectRoot
$frontendRoot = $resolvedRoots.FrontendRoot
$serverExePath = Join-Path $env:TEMP ("ums-server-playwright-e2e-" + [guid]::NewGuid().ToString('N') + '.exe') $serverExePath = Join-Path $env:TEMP ("ums-server-playwright-e2e-" + [guid]::NewGuid().ToString('N') + '.exe')
$e2eRunRoot = Join-Path $env:TEMP ("ums-playwright-e2e-" + [guid]::NewGuid().ToString('N')) $e2eRunRoot = Join-Path $env:TEMP ("ums-playwright-e2e-" + [guid]::NewGuid().ToString('N'))
$goCacheDir = Join-Path $e2eRunRoot 'go-build'
$goModCacheDir = Join-Path $e2eRunRoot 'gomod'
$goPathDir = Join-Path $e2eRunRoot 'gopath'
$e2eDataRoot = Join-Path $e2eRunRoot 'data' $e2eDataRoot = Join-Path $e2eRunRoot 'data'
$e2eDbPath = Join-Path $e2eDataRoot 'user_management.e2e.db' $e2eDbPath = Join-Path $e2eDataRoot 'user_management.e2e.db'
$smtpCaptureFile = Join-Path $e2eRunRoot 'smtp-capture.jsonl' $smtpCaptureFile = Join-Path $e2eRunRoot 'smtp-capture.jsonl'
$e2eConfigPath = Join-Path $e2eRunRoot 'config.yaml'
$bootstrapSecret = 'e2e-bootstrap-secret'
New-Item -ItemType Directory -Force $goCacheDir, $goModCacheDir, $goPathDir, $e2eDataRoot | Out-Null New-Item -ItemType Directory -Force $goCacheDir, $goModCacheDir, $goPathDir, $e2eRunRoot, $e2eDataRoot | Out-Null
Set-Content -Path $e2eConfigPath -Encoding utf8 -Value @(
'default:',
' admin_email: ""',
' admin_password: ""'
)
function Get-FreeTcpPort { function Get-FreeTcpPort {
$listener = [System.Net.Sockets.TcpListener]::new([System.Net.IPAddress]::Loopback, 0) $listener = [System.Net.Sockets.TcpListener]::new([System.Net.IPAddress]::Loopback, 0)
@@ -160,28 +199,36 @@ $backendBaseUrl = "http://127.0.0.1:$selectedBackendPort"
$frontendBaseUrl = "http://127.0.0.1:$selectedFrontendPort" $frontendBaseUrl = "http://127.0.0.1:$selectedFrontendPort"
try { try {
$serverSrcPath = Join-Path $projectRoot 'cmd\server' Push-Location $projectRoot
try { try {
$env:GOCACHE = $goCacheDir $env:GOCACHE = $goCacheDir
go build -o $serverExePath $serverSrcPath $env:GOMODCACHE = $goModCacheDir
$env:GOPATH = $goPathDir
$env:GOTELEMETRY = 'off'
go build -o $serverExePath ./cmd/server
if ($LASTEXITCODE -ne 0) { if ($LASTEXITCODE -ne 0) {
throw 'server build failed' throw 'server build failed'
} }
} finally { } finally {
Pop-Location Pop-Location
Remove-Item Env:GOCACHE -ErrorAction SilentlyContinue Remove-Item Env:GOCACHE -ErrorAction SilentlyContinue
Remove-Item Env:GOMODCACHE -ErrorAction SilentlyContinue
Remove-Item Env:GOPATH -ErrorAction SilentlyContinue
Remove-Item Env:GOTELEMETRY -ErrorAction SilentlyContinue
} }
$env:DATA_DIR = $e2eRunRoot
$env:SERVER_PORT = "$selectedBackendPort" $env:SERVER_PORT = "$selectedBackendPort"
$env:DATABASE_DBNAME = $e2eDbPath $env:DATABASE_DBNAME = $e2eDbPath
$env:SERVER_MODE = 'debug' $env:SERVER_MODE = 'debug'
$env:SERVER_FRONTEND_URL = $frontendBaseUrl $env:SERVER_FRONTEND_URL = $frontendBaseUrl
$env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontendPort" $env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontendPort"
$env:LOGGING_OUTPUT = 'stdout' $env:LOGGING_OUTPUT = 'stdout'
$env:EMAIL_HOST = '127.0.0.1' $env:EMAIL_HOST = '127.0.0.1'
$env:EMAIL_PORT = "$selectedSMTPPort" $env:EMAIL_PORT = "$selectedSMTPPort"
$env:EMAIL_FROM_EMAIL = 'noreply@test.local' $env:EMAIL_FROM_EMAIL = 'noreply@test.local'
$env:EMAIL_FROM_NAME = 'UMS E2E' $env:EMAIL_FROM_NAME = 'UMS E2E'
$env:BOOTSTRAP_SECRET = $bootstrapSecret
# JWT secret must be at least 32 bytes # JWT secret must be at least 32 bytes
$env:JWT_SECRET = 'e2e-test-jwt-secret-at-least-32-bytes-long-for-security' $env:JWT_SECRET = 'e2e-test-jwt-secret-at-least-32-bytes-long-for-security'
@@ -232,15 +279,25 @@ $env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontend
$env:E2E_LOGIN_USERNAME = $AdminUsername $env:E2E_LOGIN_USERNAME = $AdminUsername
$env:E2E_LOGIN_PASSWORD = $AdminPassword $env:E2E_LOGIN_PASSWORD = $AdminPassword
$env:E2E_LOGIN_EMAIL = $AdminEmail $env:E2E_LOGIN_EMAIL = $AdminEmail
$env:E2E_BOOTSTRAP_SECRET = $bootstrapSecret
$env:E2E_EXPECT_ADMIN_BOOTSTRAP = '1' $env:E2E_EXPECT_ADMIN_BOOTSTRAP = '1'
$env:E2E_EXTERNAL_WEB_SERVER = '1' $env:E2E_EXTERNAL_WEB_SERVER = '1'
$env:E2E_BASE_URL = $frontendBaseUrl $env:E2E_BASE_URL = $frontendBaseUrl
$env:E2E_API_BASE_URL = "$backendBaseUrl/api/v1"
$env:E2E_SMTP_CAPTURE_FILE = $smtpCaptureFile $env:E2E_SMTP_CAPTURE_FILE = $smtpCaptureFile
Push-Location $frontendRoot Push-Location $frontendRoot
try { try {
$lastError = $null $lastError = $null
for ($attempt = 1; $attempt -le 2; $attempt++) { $suiteAttempts = 2
if ($env:E2E_SUITE_ATTEMPTS) {
$parsedSuiteAttempts = 0
if ([int]::TryParse($env:E2E_SUITE_ATTEMPTS, [ref]$parsedSuiteAttempts) -and $parsedSuiteAttempts -gt 0) {
$suiteAttempts = $parsedSuiteAttempts
}
}
for ($attempt = 1; $attempt -le $suiteAttempts; $attempt++) {
try { try {
& (Join-Path $PSScriptRoot 'run-cdp-smoke.ps1') ` & (Join-Path $PSScriptRoot 'run-cdp-smoke.ps1') `
-Port $BrowserPort ` -Port $BrowserPort `
@@ -249,7 +306,7 @@ $env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontend
break break
} catch { } catch {
$lastError = $_ $lastError = $_
if ($attempt -ge 2) { if ($attempt -ge $suiteAttempts) {
throw throw
} }
$retryReason = if ($_.Exception -and $_.Exception.Message) { $_.Exception.Message } else { $_ | Out-String } $retryReason = if ($_.Exception -and $_.Exception.Message) { $_.Exception.Message } else { $_ | Out-String }
@@ -263,12 +320,15 @@ $env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontend
} }
} finally { } finally {
Pop-Location Pop-Location
Remove-Item Env:DATA_DIR -ErrorAction SilentlyContinue
Remove-Item Env:E2E_LOGIN_USERNAME -ErrorAction SilentlyContinue Remove-Item Env:E2E_LOGIN_USERNAME -ErrorAction SilentlyContinue
Remove-Item Env:E2E_LOGIN_PASSWORD -ErrorAction SilentlyContinue Remove-Item Env:E2E_LOGIN_PASSWORD -ErrorAction SilentlyContinue
Remove-Item Env:E2E_LOGIN_EMAIL -ErrorAction SilentlyContinue Remove-Item Env:E2E_LOGIN_EMAIL -ErrorAction SilentlyContinue
Remove-Item Env:E2E_BOOTSTRAP_SECRET -ErrorAction SilentlyContinue
Remove-Item Env:E2E_EXPECT_ADMIN_BOOTSTRAP -ErrorAction SilentlyContinue Remove-Item Env:E2E_EXPECT_ADMIN_BOOTSTRAP -ErrorAction SilentlyContinue
Remove-Item Env:E2E_EXTERNAL_WEB_SERVER -ErrorAction SilentlyContinue Remove-Item Env:E2E_EXTERNAL_WEB_SERVER -ErrorAction SilentlyContinue
Remove-Item Env:E2E_BASE_URL -ErrorAction SilentlyContinue Remove-Item Env:E2E_BASE_URL -ErrorAction SilentlyContinue
Remove-Item Env:E2E_API_BASE_URL -ErrorAction SilentlyContinue
Remove-Item Env:E2E_SMTP_CAPTURE_FILE -ErrorAction SilentlyContinue Remove-Item Env:E2E_SMTP_CAPTURE_FILE -ErrorAction SilentlyContinue
} }
} finally { } finally {
@@ -290,9 +350,11 @@ $env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontend
Remove-Item Env:EMAIL_FROM_NAME -ErrorAction SilentlyContinue Remove-Item Env:EMAIL_FROM_NAME -ErrorAction SilentlyContinue
Remove-Item Env:VITE_API_PROXY_TARGET -ErrorAction SilentlyContinue Remove-Item Env:VITE_API_PROXY_TARGET -ErrorAction SilentlyContinue
Remove-Item Env:VITE_API_BASE_URL -ErrorAction SilentlyContinue Remove-Item Env:VITE_API_BASE_URL -ErrorAction SilentlyContinue
Remove-Item Env:BOOTSTRAP_SECRET -ErrorAction SilentlyContinue
Remove-Item Env:JWT_SECRET -ErrorAction SilentlyContinue Remove-Item Env:JWT_SECRET -ErrorAction SilentlyContinue
Remove-Item Env:DEFAULT_ADMIN_EMAIL -ErrorAction SilentlyContinue Remove-Item Env:DEFAULT_ADMIN_EMAIL -ErrorAction SilentlyContinue
Remove-Item Env:DEFAULT_ADMIN_PASSWORD -ErrorAction SilentlyContinue Remove-Item Env:DEFAULT_ADMIN_PASSWORD -ErrorAction SilentlyContinue
Remove-Item $serverExePath -Force -ErrorAction SilentlyContinue Remove-Item $serverExePath -Force -ErrorAction SilentlyContinue
Remove-Item $e2eConfigPath -Force -ErrorAction SilentlyContinue
Remove-Item $e2eRunRoot -Recurse -Force -ErrorAction SilentlyContinue Remove-Item $e2eRunRoot -Recurse -Force -ErrorAction SilentlyContinue
} }

View File

@@ -12,6 +12,7 @@ const TEXT = {
active: '\u542f\u7528', active: '\u542f\u7528',
adminBootstrapTitle: '\u7cfb\u7edf\u5c1a\u672a\u521d\u59cb\u5316\u9996\u4e2a\u7ba1\u7406\u5458\u8d26\u53f7', adminBootstrapTitle: '\u7cfb\u7edf\u5c1a\u672a\u521d\u59cb\u5316\u9996\u4e2a\u7ba1\u7406\u5458\u8d26\u53f7',
adminRoleName: '\u7ba1\u7406\u5458', adminRoleName: '\u7ba1\u7406\u5458',
auditLogs: '\u5ba1\u8ba1\u65e5\u5fd7',
adminBootstrapAction: '\u521d\u59cb\u5316\u7ba1\u7406\u5458', adminBootstrapAction: '\u521d\u59cb\u5316\u7ba1\u7406\u5458',
adminBootstrapPageTitle: '\u521d\u59cb\u5316\u9996\u4e2a\u7ba1\u7406\u5458\u8d26\u53f7', adminBootstrapPageTitle: '\u521d\u59cb\u5316\u9996\u4e2a\u7ba1\u7406\u5458\u8d26\u53f7',
appTitle: '\u7528\u6237\u7ba1\u7406\u7cfb\u7edf', appTitle: '\u7528\u6237\u7ba1\u7406\u7cfb\u7edf',
@@ -22,12 +23,13 @@ const TEXT = {
bootstrapAdminConfirmPasswordPlaceholder: '\u786e\u8ba4\u7ba1\u7406\u5458\u5bc6\u7801', bootstrapAdminConfirmPasswordPlaceholder: '\u786e\u8ba4\u7ba1\u7406\u5458\u5bc6\u7801',
bootstrapAdminEmailPlaceholder: '\u7ba1\u7406\u5458\u90ae\u7bb1\uff08\u9009\u586b\uff09', bootstrapAdminEmailPlaceholder: '\u7ba1\u7406\u5458\u90ae\u7bb1\uff08\u9009\u586b\uff09',
bootstrapAdminPasswordPlaceholder: '\u7ba1\u7406\u5458\u5bc6\u7801', bootstrapAdminPasswordPlaceholder: '\u7ba1\u7406\u5458\u5bc6\u7801',
bootstrapAdminSecretPlaceholder: '\u5f15\u5bfc\u5bc6\u94a5',
bootstrapAdminSubmit: '\u5b8c\u6210\u521d\u59cb\u5316\u5e76\u8fdb\u5165\u7cfb\u7edf', bootstrapAdminSubmit: '\u5b8c\u6210\u521d\u59cb\u5316\u5e76\u8fdb\u5165\u7cfb\u7edf',
bootstrapAdminUsernamePlaceholder: '\u7ba1\u7406\u5458\u7528\u6237\u540d', bootstrapAdminUsernamePlaceholder: '\u7ba1\u7406\u5458\u7528\u6237\u540d',
changePassword: '\u4fee\u6539\u5bc6\u7801', changePassword: '\u4fee\u6539\u5bc6\u7801',
confirmPasswordPlaceholder: '\u786e\u8ba4\u5bc6\u7801', confirmPasswordPlaceholder: '\u786e\u8ba4\u5bc6\u7801',
createAccount: '\u521b\u5efa\u8d26\u53f7', createAccount: '\u521b\u5efa\u8d26\u53f7',
createUser: '\u521b\u5efa\u7528\u5458', createUser: '\u521b\u5efa\u7528\u6237',
createUserEmailPlaceholder: '\u90ae\u7bb1\u5730\u5740', createUserEmailPlaceholder: '\u90ae\u7bb1\u5730\u5740',
createUserPasswordPlaceholder: '\u8bf7\u8f93\u5165\u521d\u59cb\u5bc6\u7801', createUserPasswordPlaceholder: '\u8bf7\u8f93\u5165\u521d\u59cb\u5bc6\u7801',
createUserUsernamePlaceholder: '\u8bf7\u8f93\u5165\u7528\u6237\u540d', createUserUsernamePlaceholder: '\u8bf7\u8f93\u5165\u7528\u6237\u540d',
@@ -45,6 +47,7 @@ const TEXT = {
emailActivationSuccess: '\u90ae\u7bb1\u9a8c\u8bc1\u6210\u529f', emailActivationSuccess: '\u90ae\u7bb1\u9a8c\u8bc1\u6210\u529f',
export: '\u5bfc\u51fa', export: '\u5bfc\u51fa',
forgotPassword: '\u5fd8\u8bb0\u5bc6\u7801\uff1f', forgotPassword: '\u5fd8\u8bb0\u5bc6\u7801\uff1f',
integration: '\u96c6\u6210\u80fd\u529b',
loginAction: '\u767b\u5f55', loginAction: '\u767b\u5f55',
loginLogs: '\u767b\u5f55\u65e5\u5fd7', loginLogs: '\u767b\u5f55\u65e5\u5fd7',
loginNow: '\u7acb\u5373\u767b\u5f55', loginNow: '\u7acb\u5373\u767b\u5f55',
@@ -70,7 +73,6 @@ const TEXT = {
security: '\u5b89\u5168\u8bbe\u7f6e', security: '\u5b89\u5168\u8bbe\u7f6e',
smsCodeLogin: '\u77ed\u4fe1\u9a8c\u8bc1\u7801', smsCodeLogin: '\u77ed\u4fe1\u9a8c\u8bc1\u7801',
status: '\u72b6\u6001', status: '\u72b6\u6001',
systemManagement: '\u7cfb\u7edf\u7ba1\u7406',
todaySuccessLogins: '\u4eca\u65e5\u6210\u529f\u767b\u5f55', todaySuccessLogins: '\u4eca\u65e5\u6210\u529f\u767b\u5f55',
totalUsers: '\u7528\u6237\u603b\u6570', totalUsers: '\u7528\u6237\u603b\u6570',
trust: '\u4fe1\u4efb', trust: '\u4fe1\u4efb',
@@ -81,7 +83,7 @@ const TEXT = {
usernamePlaceholder: '\u7528\u6237\u540d', usernamePlaceholder: '\u7528\u6237\u540d',
users: '\u7528\u6237\u7ba1\u7406', users: '\u7528\u6237\u7ba1\u7406',
usersFilter: '\u7528\u6237\u540d/\u90ae\u7bb1/\u624b\u673a\u53f7', usersFilter: '\u7528\u6237\u540d/\u90ae\u7bb1/\u624b\u673a\u53f7',
webhooks: 'Webhooks', webhooks: 'Webhook 管理',
welcomeLogin: '\u6b22\u8fce\u767b\u5f55', welcomeLogin: '\u6b22\u8fce\u767b\u5f55',
} }
@@ -101,6 +103,7 @@ const IGNORED_REQUEST_FAILURES = new Set([
const DEBUG = process.env.E2E_DEBUG === '1' const DEBUG = process.env.E2E_DEBUG === '1'
const STARTUP_TIMEOUT_MS = Number(process.env.E2E_STARTUP_TIMEOUT_MS ?? 30000) const STARTUP_TIMEOUT_MS = Number(process.env.E2E_STARTUP_TIMEOUT_MS ?? 30000)
const SMTP_CAPTURE_FILE = (process.env.E2E_SMTP_CAPTURE_FILE ?? '').trim() const SMTP_CAPTURE_FILE = (process.env.E2E_SMTP_CAPTURE_FILE ?? '').trim()
const REFRESH_TOKEN_COOKIE_NAME = 'ums_refresh_token'
const SESSION_PRESENCE_COOKIE_NAME = 'ums_session_present' const SESSION_PRESENCE_COOKIE_NAME = 'ums_session_present'
let managedCdpUrl = null let managedCdpUrl = null
@@ -213,6 +216,7 @@ function resolveCdpUrl() {
function createSignals() { function createSignals() {
return { return {
rateLimitedResponses: [],
consoleErrors: [], consoleErrors: [],
dialogs: [], dialogs: [],
pageErrors: [], pageErrors: [],
@@ -472,6 +476,9 @@ function formatSignals(signals) {
if (signals.requestFailures.length > 0) { if (signals.requestFailures.length > 0) {
lines.push(`request failures:\n${signals.requestFailures.join('\n')}`) lines.push(`request failures:\n${signals.requestFailures.join('\n')}`)
} }
if (signals.rateLimitedResponses.length > 0) {
lines.push(`rate-limited responses:\n${signals.rateLimitedResponses.join('\n')}`)
}
if (signals.unauthorizedResponses.length > 0) { if (signals.unauthorizedResponses.length > 0) {
lines.push(`unauthorized responses:\n${signals.unauthorizedResponses.join('\n')}`) lines.push(`unauthorized responses:\n${signals.unauthorizedResponses.join('\n')}`)
} }
@@ -525,8 +532,23 @@ function attachSignalCollectors(page, signals) {
} }
const onResponse = (response) => { const onResponse = (response) => {
if (response.status() === 429) {
signals.rateLimitedResponses.push(`${response.request().method()} ${response.url()}`)
}
if (response.status() === 401) { if (response.status() === 401) {
signals.unauthorizedResponses.push(`${response.request().method()} ${response.url()}`) const authorization = response.request().headers().authorization
const authState = authorization
? `auth=present(${authorization.slice(0, 24)})`
: 'auth=missing'
const summary = `${response.request().method()} ${response.url()} :: ${authState}`
signals.unauthorizedResponses.push(summary)
void response.text().then((body) => {
const compactBody = body.replace(/\s+/g, ' ').trim()
if (compactBody) {
signals.unauthorizedResponses.push(`${summary} :: ${compactBody}`)
}
}).catch(() => {})
} }
} }
@@ -550,6 +572,7 @@ function attachSignalCollectors(page, signals) {
async function resetBrowserState(context, page) { async function resetBrowserState(context, page) {
logDebug('resetting browser state') logDebug('resetting browser state')
await context.clearCookies() await context.clearCookies()
await page.setViewportSize({ width: VIEWPORTS[0].width, height: VIEWPORTS[0].height })
await page.goto(appUrl('/login'), { waitUntil: 'domcontentloaded' }) await page.goto(appUrl('/login'), { waitUntil: 'domcontentloaded' })
await page.evaluate(() => { await page.evaluate(() => {
localStorage.clear() localStorage.clear()
@@ -594,10 +617,31 @@ async function connectBrowserWithRetry() {
throw lastError ?? new Error('Failed to connect to the Chromium CDP endpoint.') throw lastError ?? new Error('Failed to connect to the Chromium CDP endpoint.')
} }
function findOpenPage(browser, preferredContext) {
const contexts = []
if (preferredContext) {
contexts.push(preferredContext)
}
for (const candidateContext of browser.contexts()) {
if (candidateContext !== preferredContext) {
contexts.push(candidateContext)
}
}
for (const candidateContext of contexts) {
const page = candidateContext.pages().find((candidate) => !candidate.isClosed())
if (page) {
return { context: candidateContext, page }
}
}
return null
}
async function ensurePersistentPage(browser, context) { async function ensurePersistentPage(browser, context) {
let page = context.pages().find((candidate) => !candidate.isClosed()) let result = findOpenPage(browser, context)
if (page) { if (result) {
return page return result
} }
try { try {
@@ -614,9 +658,9 @@ async function ensurePersistentPage(browser, context) {
await openDevToolsPageTarget() await openDevToolsPageTarget()
for (let attempt = 0; attempt < 50; attempt += 1) { for (let attempt = 0; attempt < 50; attempt += 1) {
page = context.pages().find((candidate) => !candidate.isClosed()) result = findOpenPage(browser, context)
if (page) { if (result) {
return page return result
} }
await delay(100) await delay(100)
} }
@@ -635,6 +679,10 @@ async function getProtectedRouteRedirect(page) {
} }
async function clickSidebarMenu(page, label) { async function clickSidebarMenu(page, label) {
await expect
.poll(async () => await page.locator('.ant-layout-sider .ant-menu-item, .ant-drawer .ant-menu-item').count())
.toBeGreaterThan(0)
const menuItems = page const menuItems = page
.locator('.ant-layout-sider .ant-menu-item, .ant-drawer .ant-menu-item') .locator('.ant-layout-sider .ant-menu-item, .ant-drawer .ant-menu-item')
.filter({ hasText: label }) .filter({ hasText: label })
@@ -651,25 +699,88 @@ async function clickSidebarMenu(page, label) {
throw new Error(`No visible menu item found for ${label}.`) throw new Error(`No visible menu item found for ${label}.`)
} }
async function expandSidebarGroup(page, label) { async function openMobileNavigationIfNeeded(page) {
const groups = page const isMobileViewport = await page.evaluate(() => window.innerWidth < 768)
.locator('.ant-layout-sider .ant-menu-submenu-title, .ant-drawer .ant-menu-submenu-title') if (!isMobileViewport) {
.filter({ hasText: label }) return false
const count = await groups.count()
for (let index = 0; index < count; index += 1) {
const group = groups.nth(index)
if (await group.isVisible()) {
await forceClick(group)
return
}
} }
throw new Error(`No visible menu group found for ${label}.`) const mobileMenuButton = page.locator('.ant-layout-header .ant-btn').first()
if (!(await mobileMenuButton.isVisible().catch(() => false))) {
return false
}
await forceClick(mobileMenuButton)
await expect(page.locator('.ant-drawer-content')).toBeVisible({ timeout: 10 * 1000 })
return true
}
async function expandSidebarGroup(page, label) {
await expect
.poll(async () => {
return await page
.locator('.ant-layout-sider .ant-menu-submenu-title, .ant-drawer .ant-menu-submenu-title')
.count()
})
.toBeGreaterThan(0)
const findVisibleGroup = async () => {
const groups = page
.locator('.ant-layout-sider .ant-menu-submenu-title, .ant-drawer .ant-menu-submenu-title')
.filter({ hasText: label })
const count = await groups.count()
for (let index = 0; index < count; index += 1) {
const group = groups.nth(index)
if (await group.isVisible()) {
return group
}
}
return null
}
let group = await findVisibleGroup()
if (!group) {
await openMobileNavigationIfNeeded(page)
group = await findVisibleGroup()
}
if (group) {
await forceClick(group)
return
}
const diagnostics = await page.evaluate(() => {
const visibleText = (selector) => Array.from(document.querySelectorAll(selector))
.filter((element) => element instanceof HTMLElement && element.offsetParent !== null)
.map((element) => (element.textContent ?? '').trim())
.filter(Boolean)
return {
currentUrl: window.location.href,
innerWidth: window.innerWidth,
submenuTitles: visibleText('.ant-layout-sider .ant-menu-submenu-title, .ant-drawer .ant-menu-submenu-title'),
menuItems: visibleText('.ant-layout-sider .ant-menu-item, .ant-drawer .ant-menu-item'),
}
})
throw new Error(`No visible menu group found for ${label}. diagnostics=${JSON.stringify(diagnostics)}`)
} }
async function forceFillInput(locator, value) { async function forceFillInput(locator, value) {
await expect(locator).toBeVisible() await expect(locator).toBeVisible()
try {
await locator.fill(value, { timeout: 5_000 })
} catch {
// Fall back to direct DOM updates for components that block standard fills.
}
const currentValue = await locator.inputValue().catch(() => null)
if (currentValue === value) {
return
}
await locator.evaluate((element, nextValue) => { await locator.evaluate((element, nextValue) => {
if (!(element instanceof HTMLInputElement)) { if (!(element instanceof HTMLInputElement)) {
throw new Error('Target element is not an input.') throw new Error('Target element is not an input.')
@@ -687,10 +798,18 @@ async function forceFillInput(locator, value) {
element.dispatchEvent(new Event('input', { bubbles: true })) element.dispatchEvent(new Event('input', { bubbles: true }))
element.dispatchEvent(new Event('change', { bubbles: true })) element.dispatchEvent(new Event('change', { bubbles: true }))
}, value) }, value)
await expect(locator).toHaveValue(value)
} }
async function forceClick(locator) { async function forceClick(locator) {
await expect(locator).toBeVisible() await expect(locator).toBeVisible()
try {
await locator.click({ force: true, timeout: 5_000 })
return
} catch {
// Fall through to DOM-event dispatch when Playwright's click cannot target the element reliably.
}
await locator.evaluate((element) => { await locator.evaluate((element) => {
if (!(element instanceof HTMLElement)) { if (!(element instanceof HTMLElement)) {
throw new Error('Target element is not clickable.') throw new Error('Target element is not clickable.')
@@ -710,15 +829,17 @@ async function forceClick(locator) {
} }
async function readRefreshToken(page) { async function readRefreshToken(page) {
return await page.evaluate((cookieName) => { return await readCookie(page, REFRESH_TOKEN_COOKIE_NAME)
const target = `${cookieName}=` }
const matched = document.cookie
.split(';')
.map((cookie) => cookie.trim())
.find((cookie) => cookie.startsWith(target))
return matched ? matched.slice(target.length) : null async function readSessionPresenceCookie(page) {
}, SESSION_PRESENCE_COOKIE_NAME) return await readCookie(page, SESSION_PRESENCE_COOKIE_NAME)
}
async function readCookie(page, cookieName) {
const cookies = await page.context().cookies([BASE_URL])
const matched = cookies.find((cookie) => cookie.name === cookieName)
return matched?.value ?? null
} }
async function assertApiSuccessResponse(response, label) { async function assertApiSuccessResponse(response, label) {
@@ -744,26 +865,66 @@ async function assertApiSuccessResponse(response, label) {
return payload return payload
} }
function waitForResponseSafe(page, predicate, options) {
return page.waitForResponse(predicate, options).then(
(response) => ({ response }),
(error) => ({ error }),
)
}
async function resolveWaitForResponse(waitPromise) {
const result = await waitPromise
if (result.error) {
throw result.error
}
return result.response
}
async function loginWithPassword(page, username, password, expectedUrlPattern) { async function loginWithPassword(page, username, password, expectedUrlPattern) {
const usernameInput = page const usernameInput = page
.locator(`input[autocomplete="username"], input[placeholder="${TEXT.usernamePlaceholder}"]`) .locator(`input[autocomplete="username"], input[placeholder="${TEXT.usernamePlaceholder}"]`)
.first() .first()
const loginForm = usernameInput.locator('xpath=ancestor::form[1]') const loginForm = usernameInput.locator('xpath=ancestor::form[1]')
const passwordInput = loginForm.locator('input[type="password"]').first()
const submitButton = loginForm.locator('button[type="submit"]').first()
await forceFillInput(usernameInput, username) await forceFillInput(usernameInput, username)
await forceFillInput(loginForm.locator('input[type="password"]').first(), password) await forceFillInput(passwordInput, password)
await expect(usernameInput).toHaveValue(username)
await expect(passwordInput).toHaveValue(password)
const loginResponsePromise = page.waitForResponse((response) => { const loginResponsePromise = page.waitForResponse((response) => {
return response.url().includes('/api/v1/auth/login') && response.request().method() === 'POST' return response.url().includes('/api/v1/auth/login') && response.request().method() === 'POST'
}, { timeout: 5_000 }).catch(() => null) }, { timeout: 5_000 }).catch(() => null)
await forceClick(loginForm.locator('button[type="submit"]').first()) try {
await submitButton.click({ force: true, timeout: 5_000 })
} catch {
await forceClick(submitButton)
}
const loginResponse = await loginResponsePromise const loginResponse = await loginResponsePromise
let loginPayload = null
if (loginResponse) { if (loginResponse) {
await assertApiSuccessResponse(loginResponse, 'password login') loginPayload = await assertApiSuccessResponse(loginResponse, 'password login')
} }
if (expectedUrlPattern) { if (expectedUrlPattern) {
await expect(page).toHaveURL(expectedUrlPattern, { timeout: 30 * 1000 }) try {
await expect(page).toHaveURL(expectedUrlPattern, { timeout: 30 * 1000 })
} catch (error) {
const pageText = await page.locator('body').innerText().catch(() => '')
console.error('PASSWORD LOGIN DIAGNOSTICS', JSON.stringify({
currentUrl: page.url(),
expectedUrlPattern: String(expectedUrlPattern),
hasRefreshToken: Boolean(await readRefreshToken(page)),
hasSessionPresenceCookie: Boolean(await readSessionPresenceCookie(page)),
usernameValue: await usernameInput.inputValue().catch(() => null),
passwordValueLength: await passwordInput.inputValue().then((value) => value.length).catch(() => null),
submitButtonDisabled: await submitButton.isDisabled().catch(() => null),
loginPayload,
pageText: pageText.slice(0, 2000),
}))
throw error
}
} }
} }
@@ -776,6 +937,7 @@ async function loginFromLoginPage(page) {
await expect(page.getByRole('heading', { name: TEXT.welcomeLogin })).toBeVisible() await expect(page.getByRole('heading', { name: TEXT.welcomeLogin })).toBeVisible()
await loginWithPassword(page, username, password, /\/dashboard$/) await loginWithPassword(page, username, password, /\/dashboard$/)
await expect(page.locator('.ant-layout-header')).toBeVisible({ timeout: 20 * 1000 })
return { username, password } return { username, password }
} }
@@ -784,12 +946,15 @@ async function verifyAdminBootstrapWorkflow(page) {
const username = requireEnv('E2E_LOGIN_USERNAME') const username = requireEnv('E2E_LOGIN_USERNAME')
const password = requireEnv('E2E_LOGIN_PASSWORD') const password = requireEnv('E2E_LOGIN_PASSWORD')
const email = (process.env.E2E_LOGIN_EMAIL ?? `${username}@example.com`).trim() const email = (process.env.E2E_LOGIN_EMAIL ?? `${username}@example.com`).trim()
const bootstrapSecret = requireEnv('E2E_BOOTSTRAP_SECRET')
const apiBaseUrl = requireEnv('E2E_API_BASE_URL')
const capabilitiesResponse = page.waitForResponse((response) => { const capabilitiesResponsePromise = waitForResponseSafe(page, (response) => {
return response.url().includes('/api/v1/auth/capabilities') && response.request().method() === 'GET' return response.url().includes('/api/v1/auth/capabilities') && response.request().method() === 'GET'
}) })
await page.goto(appUrl('/login')) await page.goto(appUrl('/login'))
const capabilitiesResponse = await resolveWaitForResponse(capabilitiesResponsePromise)
const capabilitiesPayload = await (await capabilitiesResponse).json() const capabilitiesPayload = await (await capabilitiesResponse).json()
expect(Boolean(capabilitiesPayload?.data?.admin_bootstrap_required)).toBe(true) expect(Boolean(capabilitiesPayload?.data?.admin_bootstrap_required)).toBe(true)
@@ -800,16 +965,26 @@ async function verifyAdminBootstrapWorkflow(page) {
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminUsernamePlaceholder}"]`).first(), username) await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminUsernamePlaceholder}"]`).first(), username)
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminEmailPlaceholder}"]`).first(), email) await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminEmailPlaceholder}"]`).first(), email)
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminSecretPlaceholder}"]`).first(), bootstrapSecret)
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminPasswordPlaceholder}"]`).first(), password) await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminPasswordPlaceholder}"]`).first(), password)
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminConfirmPasswordPlaceholder}"]`).first(), password) await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminConfirmPasswordPlaceholder}"]`).first(), password)
const [bootstrapResponse] = await Promise.all([ const bootstrapResponsePromise = waitForResponseSafe(page, (response) => {
page.waitForResponse((response) => { return response.url().includes('/api/v1/auth/bootstrap-admin') && response.request().method() === 'POST'
return response.url().includes('/api/v1/auth/bootstrap-admin') && response.request().method() === 'POST' })
}), await forceClick(page.getByRole('button', { name: TEXT.bootstrapAdminSubmit }))
forceClick(page.getByRole('button', { name: TEXT.bootstrapAdminSubmit })), const bootstrapResponse = await resolveWaitForResponse(bootstrapResponsePromise)
])
await assertApiSuccessResponse(bootstrapResponse, 'bootstrap admin') await assertApiSuccessResponse(bootstrapResponse, 'bootstrap admin')
const bootstrapPayload = await bootstrapResponse.json()
expect(Boolean(bootstrapPayload?.data?.access_token)).toBe(true)
expect(Boolean(bootstrapPayload?.data?.user?.id)).toBe(true)
const backendTokenCheck = await fetch(`${apiBaseUrl}/auth/userinfo`, {
headers: {
Authorization: `Bearer ${bootstrapPayload.data.access_token}`,
},
})
const backendTokenCheckBody = await backendTokenCheck.text()
expect(backendTokenCheck.status, backendTokenCheckBody).toBe(200)
await expect(page).toHaveURL(/\/dashboard$/, { timeout: 30 * 1000 }) await expect(page).toHaveURL(/\/dashboard$/, { timeout: 30 * 1000 })
await expect(page.getByText(TEXT.todaySuccessLogins)).toBeVisible() await expect(page.getByText(TEXT.todaySuccessLogins)).toBeVisible()
@@ -836,12 +1011,11 @@ async function verifyPublicRegistration(page) {
page.locator(`input[placeholder="${TEXT.confirmPasswordPlaceholder}"]`).first(), page.locator(`input[placeholder="${TEXT.confirmPasswordPlaceholder}"]`).first(),
password, password,
) )
const [registerResponse] = await Promise.all([ const registerResponsePromise = waitForResponseSafe(page, (response) => {
page.waitForResponse((response) => { return response.url().includes('/api/v1/auth/register') && response.request().method() === 'POST'
return response.url().includes('/api/v1/auth/register') && response.request().method() === 'POST' })
}), await forceClick(page.getByRole('button', { name: TEXT.createAccount }))
forceClick(page.getByRole('button', { name: TEXT.createAccount })), const registerResponse = await resolveWaitForResponse(registerResponsePromise)
])
await assertApiSuccessResponse(registerResponse, 'register') await assertApiSuccessResponse(registerResponse, 'register')
await expect(page.locator('.ant-result-title').filter({ hasText: TEXT.registerSuccess }).first()).toBeVisible({ timeout: 20 * 1000 }) await expect(page.locator('.ant-result-title').filter({ hasText: TEXT.registerSuccess }).first()).toBeVisible({ timeout: 20 * 1000 })
@@ -873,22 +1047,20 @@ async function verifyEmailActivationWorkflow(page) {
password, password,
) )
const [registerResponse] = await Promise.all([ const registerResponsePromise = waitForResponseSafe(page, (response) => {
page.waitForResponse((response) => { return response.url().includes('/api/v1/auth/register') && response.request().method() === 'POST'
return response.url().includes('/api/v1/auth/register') && response.request().method() === 'POST' })
}), await forceClick(page.getByRole('button', { name: TEXT.createAccount }))
forceClick(page.getByRole('button', { name: TEXT.createAccount })), const registerResponse = await resolveWaitForResponse(registerResponsePromise)
])
await assertApiSuccessResponse(registerResponse, 'register email activation') await assertApiSuccessResponse(registerResponse, 'register email activation')
await expect(page.locator('.ant-result-title').filter({ hasText: TEXT.registerSuccess }).first()).toBeVisible({ timeout: 20 * 1000 }) await expect(page.locator('.ant-result-title').filter({ hasText: TEXT.registerSuccess }).first()).toBeVisible({ timeout: 20 * 1000 })
const activationLink = await waitForActivationLink(email) const activationLink = await waitForActivationLink(email)
const [activationResponse] = await Promise.all([ const activationResponsePromise = waitForResponseSafe(page, (response) => {
page.waitForResponse((response) => { return response.url().includes('/api/v1/auth/activate-email') && response.request().method() === 'POST'
return response.url().includes('/api/v1/auth/activate') && response.request().method() === 'GET' })
}), await page.goto(activationLink)
page.goto(activationLink), const activationResponse = await resolveWaitForResponse(activationResponsePromise)
])
await assertApiSuccessResponse(activationResponse, 'activate email') await assertApiSuccessResponse(activationResponse, 'activate email')
await expect(page.locator('body')).toContainText(TEXT.emailActivationSuccess, { timeout: 20 * 1000 }) await expect(page.locator('body')).toContainText(TEXT.emailActivationSuccess, { timeout: 20 * 1000 })
await forceClick(page.getByRole('button', { name: TEXT.loginNow })) await forceClick(page.getByRole('button', { name: TEXT.loginNow }))
@@ -907,11 +1079,13 @@ async function runScenario(browser, context, name, fn) {
let lastError = null let lastError = null
for (let attempt = 1; attempt <= 2; attempt += 1) { for (let attempt = 1; attempt <= 2; attempt += 1) {
const activeContext = browser.contexts()[0] ?? context const requestedContext = browser.contexts()[0] ?? context
const page = await ensurePersistentPage(browser, activeContext) const resolvedPage = await ensurePersistentPage(browser, requestedContext)
if (!page) { if (!resolvedPage) {
throw new Error('No persistent page is available in the Chromium CDP context.') throw new Error('No persistent page is available in the Chromium CDP context.')
} }
const activeContext = resolvedPage.context
const page = resolvedPage.page
for (const extraPage of activeContext.pages()) { for (const extraPage of activeContext.pages()) {
if (extraPage === page) { if (extraPage === page) {
@@ -958,14 +1132,15 @@ async function runScenario(browser, context, name, fn) {
async function verifyLoginSurface(page) { async function verifyLoginSurface(page) {
console.log('STEP login-surface wait-capabilities') console.log('STEP login-surface wait-capabilities')
const capabilitiesResponse = page.waitForResponse((response) => { const capabilitiesResponsePromise = waitForResponseSafe(page, (response) => {
return response.url().includes('/api/v1/auth/capabilities') && response.request().method() === 'GET' return response.url().includes('/api/v1/auth/capabilities') && response.request().method() === 'GET'
}) })
console.log('STEP login-surface goto-login') console.log('STEP login-surface goto-login')
await page.goto(appUrl('/login')) await page.goto(appUrl('/login'))
console.log('STEP login-surface capabilities-response') console.log('STEP login-surface capabilities-response')
const capabilitiesPayload = await (await capabilitiesResponse).json() const capabilitiesResponse = await resolveWaitForResponse(capabilitiesResponsePromise)
const capabilitiesPayload = await capabilitiesResponse.json()
const capabilities = capabilitiesPayload?.data ?? {} const capabilities = capabilitiesPayload?.data ?? {}
await expect(page).toHaveTitle(new RegExp(TEXT.appTitle)) await expect(page).toHaveTitle(new RegExp(TEXT.appTitle))
@@ -1036,7 +1211,7 @@ async function verifyAuthWorkflow(page) {
await forceClick(page.getByRole('button', { name: TEXT.createUser }).first()) await forceClick(page.getByRole('button', { name: TEXT.createUser }).first())
await expect(page.locator('.ant-modal-title')).toContainText(TEXT.createUser) await expect(page.locator('.ant-modal-title')).toContainText(TEXT.createUser)
const createUserModal = page.locator('.ant-modal').last() const createUserModal = page.locator('.ant-modal').last()
const createUserResponsePromise = page.waitForResponse((response) => { const createUserResponsePromise = waitForResponseSafe(page, (response) => {
return response.url().includes('/api/v1/users') && response.request().method() === 'POST' return response.url().includes('/api/v1/users') && response.request().method() === 'POST'
}) })
await forceFillInput( await forceFillInput(
@@ -1052,7 +1227,7 @@ async function verifyAuthWorkflow(page) {
`${createdUsername}@example.com`, `${createdUsername}@example.com`,
) )
await forceClick(createUserModal.locator('.ant-btn-primary').last()) await forceClick(createUserModal.locator('.ant-btn-primary').last())
const createUserResponse = await createUserResponsePromise const createUserResponse = await resolveWaitForResponse(createUserResponsePromise)
await assertApiSuccessResponse(createUserResponse, 'create user') await assertApiSuccessResponse(createUserResponse, 'create user')
await expect(createUserModal).toHaveClass(/ant-zoom-leave/, { timeout: 20 * 1000 }) await expect(createUserModal).toHaveClass(/ant-zoom-leave/, { timeout: 20 * 1000 })
await page.goto(appUrl('/users')) await page.goto(appUrl('/users'))
@@ -1062,7 +1237,18 @@ async function verifyAuthWorkflow(page) {
await page.goto(appUrl('/roles')) await page.goto(appUrl('/roles'))
await expect(page).toHaveURL(/\/roles$/) await expect(page).toHaveURL(/\/roles$/)
await expect(page.getByPlaceholder(TEXT.roleFilter)).toBeVisible() try {
await expect(page.getByPlaceholder(TEXT.roleFilter)).toBeVisible()
} catch (error) {
const pageText = await page.locator('body').innerText().catch(() => '')
console.error('ROLES PAGE DIAGNOSTICS', JSON.stringify({
currentUrl: page.url(),
hasRefreshToken: Boolean(await readRefreshToken(page)),
hasSessionPresenceCookie: Boolean(await readSessionPresenceCookie(page)),
pageText: pageText.slice(0, 2000),
}))
throw error
}
await expect(page.getByRole('button', { name: TEXT.createRole })).toBeVisible() await expect(page.getByRole('button', { name: TEXT.createRole })).toBeVisible()
const adminRoleRow = page.locator('tbody tr').filter({ hasText: TEXT.adminRoleName }).first() const adminRoleRow = page.locator('tbody tr').filter({ hasText: TEXT.adminRoleName }).first()
@@ -1168,7 +1354,7 @@ async function verifyUserManagementCRUD(page) {
await forceClick(page.getByRole('button', { name: TEXT.createUser }).first()) await forceClick(page.getByRole('button', { name: TEXT.createUser }).first())
await expect(createUserModal).toBeVisible({ timeout: 10 * 1000 }) await expect(createUserModal).toBeVisible({ timeout: 10 * 1000 })
const createUserResponsePromise = page.waitForResponse((response) => { const createUserResponsePromise = waitForResponseSafe(page, (response) => {
return response.url().includes('/api/v1/users') && response.request().method() === 'POST' return response.url().includes('/api/v1/users') && response.request().method() === 'POST'
}) })
await forceFillInput( await forceFillInput(
@@ -1184,40 +1370,41 @@ async function verifyUserManagementCRUD(page) {
testEmail, testEmail,
) )
await forceClick(createUserModal.locator('.ant-btn-primary').last()) await forceClick(createUserModal.locator('.ant-btn-primary').last())
const createUserResponse = await createUserResponsePromise const createUserResponse = await resolveWaitForResponse(createUserResponsePromise)
await assertApiSuccessResponse(createUserResponse, 'create user CRUD') await assertApiSuccessResponse(createUserResponse, 'create user CRUD')
await expect(page.locator('tbody tr').filter({ hasText: testUsername }).first()).toBeVisible({ timeout: 20 * 1000 }) await expect(page.locator('tbody tr').filter({ hasText: testUsername }).first()).toBeVisible({ timeout: 20 * 1000 })
const userRow = page.locator('tbody tr').filter({ hasText: testUsername }).first() const userRow = page.locator('tbody tr').filter({ hasText: testUsername }).first()
await forceClick(userRow.getByRole('button', { name: TEXT.edit })) await forceClick(userRow.getByRole('button', { name: TEXT.edit }))
const editDrawer = page.locator('.ant-drawer') const editDrawer = page.locator('.ant-drawer.ant-drawer-open').filter({ hasText: TEXT.editUser }).last()
await expect(editDrawer).toBeVisible({ timeout: 10 * 1000 }) await expect(editDrawer).toBeVisible({ timeout: 10 * 1000 })
const editResponsePromise = page.waitForResponse((response) => { const editResponsePromise = waitForResponseSafe(page, (response) => {
return response.url().includes(`/api/v1/users/`) && response.request().method() === 'PUT' return response.url().includes(`/api/v1/users/`) && response.request().method() === 'PUT'
}) })
await forceClick(editDrawer.locator('.ant-btn-primary').last()) await forceClick(editDrawer.locator('.ant-btn-primary').last())
const editResponse = await editResponsePromise const editResponse = await resolveWaitForResponse(editResponsePromise)
await assertApiSuccessResponse(editResponse, 'edit user CRUD') await assertApiSuccessResponse(editResponse, 'edit user CRUD')
await forceClick(userRow.getByRole('button', { name: TEXT.userDetailAction })) await forceClick(userRow.getByRole('button', { name: TEXT.userDetailAction }))
const detailDrawer = page.locator('.ant-drawer') const detailDrawer = page.locator('.ant-drawer.ant-drawer-open').filter({ hasText: TEXT.userDetail }).last()
await expect(detailDrawer).toBeVisible({ timeout: 10 * 1000 }) await expect(detailDrawer).toBeVisible({ timeout: 10 * 1000 })
await expect(detailDrawer).toContainText(testUsername) await expect(detailDrawer).toContainText(testUsername)
await page.goto(appUrl('/users')) await page.goto(appUrl('/users'))
await forceFillInput(page.getByPlaceholder(TEXT.usersFilter), testUsername) await forceFillInput(page.getByPlaceholder(TEXT.usersFilter), testUsername)
await expect(page.locator('tbody tr').filter({ hasText: testUsername }).first()).toBeVisible({ timeout: 10 * 1000 }) const filteredUserRow = page.locator('tbody tr').filter({ hasText: testUsername }).first()
await expect(filteredUserRow).toBeVisible({ timeout: 10 * 1000 })
await forceClick(userRow.getByRole('button', { name: TEXT.delete })) await forceClick(filteredUserRow.getByRole('button', { name: TEXT.delete }))
const deleteConfirmModal = page.locator('.ant-modal-confirm') const deleteConfirmPopover = page.locator('.ant-popconfirm').filter({ hasText: testUsername }).last()
await expect(deleteConfirmModal).toBeVisible({ timeout: 10 * 1000 }) await expect(deleteConfirmPopover).toBeVisible({ timeout: 10 * 1000 })
const deleteResponsePromise = page.waitForResponse((response) => { const deleteResponsePromise = waitForResponseSafe(page, (response) => {
return response.url().includes(`/api/v1/users/`) && response.request().method() === 'DELETE' return response.url().includes(`/api/v1/users/`) && response.request().method() === 'DELETE'
}) })
await forceClick(deleteConfirmModal.locator('.ant-btn-primary').last()) await forceClick(deleteConfirmPopover.locator('.ant-btn-primary').last())
const deleteResponse = await deleteResponsePromise const deleteResponse = await resolveWaitForResponse(deleteResponsePromise)
await assertApiSuccessResponse(deleteResponse, 'delete user CRUD') await assertApiSuccessResponse(deleteResponse, 'delete user CRUD')
await expect(page.locator('tbody tr').filter({ hasText: testUsername }).first()).toHaveCount(0, { timeout: 10 * 1000 }) await expect(page.locator('tbody tr').filter({ hasText: testUsername }).first()).toHaveCount(0, { timeout: 10 * 1000 })
@@ -1240,11 +1427,10 @@ async function verifyRoleManagementCRUD(page) {
const adminRoleRow = page.locator('tbody tr').filter({ hasText: TEXT.adminRoleName }).first() const adminRoleRow = page.locator('tbody tr').filter({ hasText: TEXT.adminRoleName }).first()
await forceClick(adminRoleRow.getByRole('button', { name: TEXT.permissionsAction })) await forceClick(adminRoleRow.getByRole('button', { name: TEXT.permissionsAction }))
const permissionsModal = page.locator('.ant-modal') const permissionsModal = page.getByRole('dialog').filter({ hasText: TEXT.assignPermissions }).last()
await expect(permissionsModal.locator('.ant-modal-title')).toContainText(TEXT.assignPermissions) await expect(permissionsModal.locator('.ant-modal-title')).toContainText(TEXT.assignPermissions)
await page.goto(appUrl('/roles'))
await forceClick(permissionsModal.locator('.ant-modal-close')) await expect(page.locator('tbody tr').filter({ hasText: TEXT.adminRoleName }).first()).toBeVisible({ timeout: 20 * 1000 })
await expect(permissionsModal).not.toBeVisible({ timeout: 10 * 1000 })
await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.locator('[class*="userTrigger"]'))
await forceClick(page.getByText(TEXT.logout, { exact: true })) await forceClick(page.getByText(TEXT.logout, { exact: true }))
@@ -1255,11 +1441,10 @@ async function verifyDeviceManagement(page) {
logDebug('verifyDeviceManagement: login /login') logDebug('verifyDeviceManagement: login /login')
await loginFromLoginPage(page) await loginFromLoginPage(page)
await expandSidebarGroup(page, TEXT.systemManagement) await page.goto(appUrl('/devices'))
await clickSidebarMenu(page, TEXT.devices)
await expect(page).toHaveURL(/\/devices$/) await expect(page).toHaveURL(/\/devices$/)
await expect(page.getByText(TEXT.deviceManagement)).toBeVisible({ timeout: 10 * 1000 }) await expect(page.getByRole('heading', { name: TEXT.deviceManagement })).toBeVisible({ timeout: 10 * 1000 })
await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.locator('[class*="userTrigger"]'))
await forceClick(page.getByText(TEXT.logout, { exact: true })) await forceClick(page.getByText(TEXT.logout, { exact: true }))
@@ -1270,11 +1455,10 @@ async function verifyLoginLogs(page) {
logDebug('verifyLoginLogs: login /login') logDebug('verifyLoginLogs: login /login')
await loginFromLoginPage(page) await loginFromLoginPage(page)
await expandSidebarGroup(page, TEXT.systemManagement) await page.goto(appUrl('/logs/login'))
await clickSidebarMenu(page, TEXT.loginLogs) await expect(page).toHaveURL(/\/logs\/login$/)
await expect(page).toHaveURL(/\/login-logs$/)
await expect(page.getByText(TEXT.loginLogs)).toBeVisible({ timeout: 10 * 1000 }) await expect(page.getByRole('heading', { name: TEXT.loginLogs })).toBeVisible({ timeout: 10 * 1000 })
await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.locator('[class*="userTrigger"]'))
await forceClick(page.getByText(TEXT.logout, { exact: true })) await forceClick(page.getByText(TEXT.logout, { exact: true }))
@@ -1285,11 +1469,10 @@ async function verifyOperationLogs(page) {
logDebug('verifyOperationLogs: login /login') logDebug('verifyOperationLogs: login /login')
await loginFromLoginPage(page) await loginFromLoginPage(page)
await expandSidebarGroup(page, TEXT.systemManagement) await page.goto(appUrl('/logs/operation'))
await clickSidebarMenu(page, TEXT.operationLogs) await expect(page).toHaveURL(/\/logs\/operation$/)
await expect(page).toHaveURL(/\/operation-logs$/)
await expect(page.getByText(TEXT.operationLogs)).toBeVisible({ timeout: 10 * 1000 }) await expect(page.getByRole('heading', { name: TEXT.operationLogs })).toBeVisible({ timeout: 10 * 1000 })
await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.locator('[class*="userTrigger"]'))
await forceClick(page.getByText(TEXT.logout, { exact: true })) await forceClick(page.getByText(TEXT.logout, { exact: true }))
@@ -1300,11 +1483,10 @@ async function verifyWebhookManagement(page) {
logDebug('verifyWebhookManagement: login /login') logDebug('verifyWebhookManagement: login /login')
await loginFromLoginPage(page) await loginFromLoginPage(page)
await expandSidebarGroup(page, TEXT.systemManagement) await page.goto(appUrl('/webhooks'))
await clickSidebarMenu(page, TEXT.webhooks)
await expect(page).toHaveURL(/\/webhooks$/) await expect(page).toHaveURL(/\/webhooks$/)
await expect(page.getByText(TEXT.webhooks)).toBeVisible({ timeout: 10 * 1000 }) await expect(page.getByRole('heading', { name: TEXT.webhooks })).toBeVisible({ timeout: 10 * 1000 })
await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.locator('[class*="userTrigger"]'))
await forceClick(page.getByText(TEXT.logout, { exact: true })) await forceClick(page.getByText(TEXT.logout, { exact: true }))
@@ -1322,10 +1504,10 @@ async function verifyProfileAndSecurity(page) {
await expect(page.locator('body')).toContainText(credentials.username, { timeout: 10 * 1000 }) await expect(page.locator('body')).toContainText(credentials.username, { timeout: 10 * 1000 })
await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.locator('[class*="userTrigger"]'))
await forceClick(page.getByText(TEXT.security)) await forceClick(page.getByRole('menuitem', { name: TEXT.security }))
await expect(page).toHaveURL(/\/profile\/security$/) await expect(page).toHaveURL(/\/profile\/security$/)
await expect(page.getByText(TEXT.changePassword)).toBeVisible({ timeout: 10 * 1000 }) await expect(page.getByRole('button', { name: TEXT.changePassword })).toBeVisible({ timeout: 10 * 1000 })
await forceClick(page.locator('[class*="userTrigger"]')) await forceClick(page.locator('[class*="userTrigger"]'))
await forceClick(page.getByText(TEXT.logout, { exact: true })) await forceClick(page.getByText(TEXT.logout, { exact: true }))
@@ -1349,6 +1531,12 @@ async function main() {
let browser = null let browser = null
let managedBrowser = null let managedBrowser = null
let managedProfileDir = null let managedProfileDir = null
const selectedScenarioNames = new Set(
(process.env.E2E_SCENARIOS ?? '')
.split(',')
.map((name) => name.trim())
.filter(Boolean),
)
if (process.env.E2E_MANAGED_BROWSER === '1') { if (process.env.E2E_MANAGED_BROWSER === '1') {
const browserPath = await resolveManagedBrowserPath() const browserPath = await resolveManagedBrowserPath()
@@ -1370,23 +1558,39 @@ async function main() {
throw new Error('No persistent Chromium context is available through CDP.') throw new Error('No persistent Chromium context is available through CDP.')
} }
const scenarios = []
if (process.env.E2E_EXPECT_ADMIN_BOOTSTRAP === '1') { if (process.env.E2E_EXPECT_ADMIN_BOOTSTRAP === '1') {
await runScenario(browser, context, 'admin-bootstrap', verifyAdminBootstrapWorkflow) scenarios.push(['admin-bootstrap', verifyAdminBootstrapWorkflow])
}
scenarios.push(
['public-registration', verifyPublicRegistration],
['email-activation', verifyEmailActivationWorkflow],
['login-surface', verifyLoginSurface],
['auth-workflow', verifyAuthWorkflow],
['responsive-login', verifyResponsiveLogin],
['desktop-mobile-navigation', verifyDesktopAndMobileNavigation],
['user-management-crud', verifyUserManagementCRUD],
['role-management-crud', verifyRoleManagementCRUD],
['device-management', verifyDeviceManagement],
['login-logs', verifyLoginLogs],
['operation-logs', verifyOperationLogs],
['webhook-management', verifyWebhookManagement],
['profile-and-security', verifyProfileAndSecurity],
['dashboard-stats', verifyDashboardStats],
)
const scenariosToRun = selectedScenarioNames.size === 0
? scenarios
: scenarios.filter(([name]) => name === 'admin-bootstrap' || selectedScenarioNames.has(name))
if (scenariosToRun.length === 0) {
throw new Error(`No E2E scenarios matched E2E_SCENARIOS=${process.env.E2E_SCENARIOS ?? ''}`)
}
console.log(`SCENARIOS ${scenariosToRun.map(([name]) => name).join(', ')}`)
for (const [name, fn] of scenariosToRun) {
await runScenario(browser, context, name, fn)
} }
await runScenario(browser, context, 'public-registration', verifyPublicRegistration)
await runScenario(browser, context, 'email-activation', verifyEmailActivationWorkflow)
await runScenario(browser, context, 'login-surface', verifyLoginSurface)
await runScenario(browser, context, 'auth-workflow', verifyAuthWorkflow)
await runScenario(browser, context, 'responsive-login', verifyResponsiveLogin)
await runScenario(browser, context, 'desktop-mobile-navigation', verifyDesktopAndMobileNavigation)
await runScenario(browser, context, 'user-management-crud', verifyUserManagementCRUD)
await runScenario(browser, context, 'role-management-crud', verifyRoleManagementCRUD)
await runScenario(browser, context, 'device-management', verifyDeviceManagement)
await runScenario(browser, context, 'login-logs', verifyLoginLogs)
await runScenario(browser, context, 'operation-logs', verifyOperationLogs)
await runScenario(browser, context, 'webhook-management', verifyWebhookManagement)
await runScenario(browser, context, 'profile-and-security', verifyProfileAndSecurity)
await runScenario(browser, context, 'dashboard-stats', verifyDashboardStats)
console.log('Playwright CDP E2E completed successfully') console.log('Playwright CDP E2E completed successfully')
} finally { } finally {
await browser?.close().catch(() => {}) await browser?.close().catch(() => {})

View File

@@ -1,5 +1,7 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
const getAccessTokenMock = vi.fn<() => string | null>()
function jsonResponse(data: unknown, init: ResponseInit = {}) { function jsonResponse(data: unknown, init: ResponseInit = {}) {
return new Response(JSON.stringify(data), { return new Response(JSON.stringify(data), {
status: 200, status: 200,
@@ -12,6 +14,9 @@ function jsonResponse(data: unknown, init: ResponseInit = {}) {
async function loadCsrfModule() { async function loadCsrfModule() {
vi.resetModules() vi.resetModules()
vi.doMock('./auth-session', () => ({
getAccessToken: () => getAccessTokenMock(),
}))
return import('./csrf') return import('./csrf')
} }
@@ -27,6 +32,8 @@ describe('csrf helpers', () => {
vi.clearAllMocks() vi.clearAllMocks()
vi.unstubAllGlobals() vi.unstubAllGlobals()
vi.unstubAllEnvs() vi.unstubAllEnvs()
getAccessTokenMock.mockReset()
getAccessTokenMock.mockReturnValue(null)
clearCsrfCookie() clearCsrfCookie()
vi.stubGlobal('fetch', vi.fn()) vi.stubGlobal('fetch', vi.fn())
}) })
@@ -85,6 +92,7 @@ describe('csrf helpers', () => {
it('fetches and stores a csrf token from the default relative api base', async () => { it('fetches and stores a csrf token from the default relative api base', async () => {
const fetchMock = vi.mocked(fetch) const fetchMock = vi.mocked(fetch)
getAccessTokenMock.mockReturnValue('access-token')
fetchMock.mockResolvedValueOnce( fetchMock.mockResolvedValueOnce(
jsonResponse({ jsonResponse({
code: 0, code: 0,
@@ -105,6 +113,7 @@ describe('csrf helpers', () => {
method: 'GET', method: 'GET',
credentials: 'include', credentials: 'include',
headers: { headers: {
Authorization: 'Bearer access-token',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },
}, },

View File

@@ -13,6 +13,7 @@
// 使用原生 fetch 获取 CSRF Token // 使用原生 fetch 获取 CSRF Token
import { config } from '@/lib/config' import { config } from '@/lib/config'
import { getAccessToken } from './auth-session'
// CSRF Token 存储 // CSRF Token 存储
let csrfToken: string | null = null let csrfToken: string | null = null
@@ -84,13 +85,19 @@ export async function initCSRFToken(): Promise<string | null> {
if (!token) { if (!token) {
try { try {
const headers: Record<string, string> = {
'Content-Type': 'application/json',
}
const accessToken = getAccessToken()
if (accessToken) {
headers.Authorization = `Bearer ${accessToken}`
}
// 使用原生 fetch 避免循环依赖 // 使用原生 fetch 避免循环依赖
const response = await fetch(buildUrl('/auth/csrf-token'), { const response = await fetch(buildUrl('/auth/csrf-token'), {
method: 'GET', method: 'GET',
credentials: 'include', credentials: 'include',
headers: { headers,
'Content-Type': 'application/json',
},
}) })
if (response.ok) { if (response.ok) {

View File

@@ -4,9 +4,12 @@ import userEvent from '@testing-library/user-event'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import type { Device, AdminDeviceListParams } from '@/types/device' import type { Device, AdminDeviceListParams } from '@/types/device'
import type { CursorPaginatedData, PaginatedData } from '@/types/http'
import { DevicesPage } from './DevicesPage' import { DevicesPage } from './DevicesPage'
const listAllDevicesMock = vi.fn<(params?: AdminDeviceListParams) => Promise<{ items: Device[]; total: number; page: number; page_size: number }>>() type DeviceListResponse = PaginatedData<Device> | CursorPaginatedData<Device>
const listAllDevicesMock = vi.fn<(params?: AdminDeviceListParams) => Promise<DeviceListResponse>>()
const deleteDeviceMock = vi.fn<(id: number) => Promise<void>>() const deleteDeviceMock = vi.fn<(id: number) => Promise<void>>()
const trustDeviceMock = vi.fn<(id: number, duration?: string) => Promise<void>>() const trustDeviceMock = vi.fn<(id: number, duration?: string) => Promise<void>>()
const untrustDeviceMock = vi.fn<(id: number) => Promise<void>>() const untrustDeviceMock = vi.fn<(id: number) => Promise<void>>()
@@ -377,6 +380,34 @@ describe('DevicesPage', () => {
) )
}) })
it('does not auto-request the next cursor page after initial load', async () => {
listAllDevicesMock.mockReset()
listAllDevicesMock
.mockResolvedValueOnce({
items: [currentDevices[0]],
next_cursor: 'cursor-page-2',
has_more: true,
page_size: 20,
})
.mockResolvedValueOnce({
items: [currentDevices[1]],
next_cursor: '',
has_more: false,
page_size: 20,
})
render(<DevicesPage />)
expect(await screen.findByText('Device 1')).toBeInTheDocument()
await new Promise((resolve) => setTimeout(resolve, 0))
expect(listAllDevicesMock).toHaveBeenCalledTimes(1)
expect(listAllDevicesMock).toHaveBeenCalledWith(
expect.objectContaining({ cursor: undefined, size: 20 }),
)
})
it('shows error state and retry', async () => { it('shows error state and retry', async () => {
const user = userEvent.setup() const user = userEvent.setup()

View File

@@ -46,7 +46,8 @@ export function DevicesPage() {
const [devices, setDevices] = useState<Device[]>([]) const [devices, setDevices] = useState<Device[]>([])
const [total, setTotal] = useState(0) const [total, setTotal] = useState(0)
// Cursor-based pagination state (preferred for large datasets) // Cursor-based pagination state (preferred for large datasets)
const [cursor, setCursor] = useState('') const [requestCursor, setRequestCursor] = useState('')
const [nextCursor, setNextCursor] = useState('')
const [hasMore, setHasMore] = useState(true) const [hasMore, setHasMore] = useState(true)
// Legacy page state (for Ant Design Table compatibility) // Legacy page state (for Ant Design Table compatibility)
const [page, setPage] = useState(1) const [page, setPage] = useState(1)
@@ -64,7 +65,7 @@ export function DevicesPage() {
setError(null) setError(null)
try { try {
const params: AdminDeviceListParams = { const params: AdminDeviceListParams = {
cursor: cursor || undefined, cursor: requestCursor || undefined,
size: pageSize, size: pageSize,
keyword: keyword || undefined, keyword: keyword || undefined,
user_id: userIdFilter, user_id: userIdFilter,
@@ -75,12 +76,14 @@ export function DevicesPage() {
setDevices(result.items ?? []) setDevices(result.items ?? [])
// If the response has cursor fields, use them; otherwise fall back to legacy total // If the response has cursor fields, use them; otherwise fall back to legacy total
if ('next_cursor' in result) { if ('next_cursor' in result) {
setCursor(result.next_cursor ?? '') setNextCursor(result.next_cursor ?? '')
setHasMore(result.has_more ?? false) setHasMore(result.has_more ?? false)
// Estimate total from current data + whether there's more // Estimate total from current data + whether there's more
setTotal((page - 1) * pageSize + result.items?.length + (result.has_more ? 1 : 0)) setTotal((page - 1) * pageSize + result.items?.length + (result.has_more ? 1 : 0))
} else { } else {
// Legacy response format fallback // Legacy response format fallback
setNextCursor('')
setHasMore(false)
setTotal((result as { total?: number }).total ?? 0) setTotal((result as { total?: number }).total ?? 0)
} }
} catch (err) { } catch (err) {
@@ -88,7 +91,7 @@ export function DevicesPage() {
} finally { } finally {
setLoading(false) setLoading(false)
} }
}, [cursor, page, pageSize, keyword, userIdFilter, statusFilter, trustFilter]) }, [requestCursor, page, pageSize, keyword, userIdFilter, statusFilter, trustFilter])
useEffect(() => { useEffect(() => {
void fetchDevices() void fetchDevices()
@@ -97,7 +100,8 @@ export function DevicesPage() {
// 筛选条件变化时重置到第一页(清空游标) // 筛选条件变化时重置到第一页(清空游标)
useEffect(() => { useEffect(() => {
setPage(1) setPage(1)
setCursor('') setRequestCursor('')
setNextCursor('')
}, [keyword, userIdFilter, statusFilter, trustFilter]) }, [keyword, userIdFilter, statusFilter, trustFilter])
// 重置筛选 // 重置筛选
@@ -107,7 +111,8 @@ export function DevicesPage() {
setStatusFilter(undefined) setStatusFilter(undefined)
setTrustFilter(undefined) setTrustFilter(undefined)
setPage(1) setPage(1)
setCursor('') setRequestCursor('')
setNextCursor('')
} }
// 删除设备 // 删除设备
@@ -278,14 +283,17 @@ export function DevicesPage() {
if (ps !== pageSize) { if (ps !== pageSize) {
setPageSize(ps) setPageSize(ps)
setPage(1) setPage(1)
setCursor('') setRequestCursor('')
} else if (p === page + 1 && cursor) { setNextCursor('')
} else if (p === page + 1 && nextCursor) {
// Next page via cursor // Next page via cursor
setPage(p) setPage(p)
setRequestCursor(nextCursor)
} else { } else {
// Jump to specific page - fall back // Jump to specific page - fall back
setPage(p) setPage(p)
setCursor('') setRequestCursor('')
setNextCursor('')
} }
}, },
} }

View File

@@ -8,12 +8,12 @@ import type { AuthCapabilities, TokenBundle } from '@/types'
import { BootstrapAdminPage } from './BootstrapAdminPage' import { BootstrapAdminPage } from './BootstrapAdminPage'
const getAuthCapabilitiesMock = vi.fn<() => Promise<AuthCapabilities>>() const getAuthCapabilitiesMock = vi.fn<() => Promise<AuthCapabilities>>()
const bootstrapAdminMock = vi.fn<(payload: unknown) => Promise<TokenBundle>>() const bootstrapAdminMock = vi.fn<(payload: unknown, bootstrapSecret: string) => Promise<TokenBundle>>()
const onLoginSuccessMock = vi.fn<(tokenBundle: TokenBundle) => Promise<void>>() const onLoginSuccessMock = vi.fn<(tokenBundle: TokenBundle) => Promise<void>>()
vi.mock('@/services/auth', () => ({ vi.mock('@/services/auth', () => ({
getAuthCapabilities: () => getAuthCapabilitiesMock(), getAuthCapabilities: () => getAuthCapabilitiesMock(),
bootstrapAdmin: (payload: unknown) => bootstrapAdminMock(payload), bootstrapAdmin: (payload: unknown, bootstrapSecret: string) => bootstrapAdminMock(payload, bootstrapSecret),
})) }))
const authContextValue: AuthContextValue = { const authContextValue: AuthContextValue = {
@@ -76,6 +76,7 @@ describe('BootstrapAdminPage', () => {
expect(screen.getByRole('heading', { name: '初始化首个管理员账号' })).toBeInTheDocument() expect(screen.getByRole('heading', { name: '初始化首个管理员账号' })).toBeInTheDocument()
expect(screen.getByPlaceholderText('管理员用户名')).toBeInTheDocument() expect(screen.getByPlaceholderText('管理员用户名')).toBeInTheDocument()
expect(screen.getByPlaceholderText('引导密钥')).toBeInTheDocument()
expect(screen.getByPlaceholderText('管理员密码')).toBeInTheDocument() expect(screen.getByPlaceholderText('管理员密码')).toBeInTheDocument()
expect(screen.getByRole('button', { name: '完成初始化并进入系统' })).toBeInTheDocument() expect(screen.getByRole('button', { name: '完成初始化并进入系统' })).toBeInTheDocument()
}) })
@@ -89,17 +90,21 @@ describe('BootstrapAdminPage', () => {
await user.type(screen.getByPlaceholderText('管理员用户名'), 'bootstrap_admin') await user.type(screen.getByPlaceholderText('管理员用户名'), 'bootstrap_admin')
await user.type(screen.getByPlaceholderText('管理员昵称(选填)'), 'Bootstrap Admin') await user.type(screen.getByPlaceholderText('管理员昵称(选填)'), 'Bootstrap Admin')
await user.type(screen.getByPlaceholderText('管理员邮箱(选填)'), 'bootstrap_admin@example.com') await user.type(screen.getByPlaceholderText('管理员邮箱(选填)'), 'bootstrap_admin@example.com')
await user.type(screen.getByPlaceholderText('引导密钥'), 'bootstrap-secret')
await user.type(screen.getByPlaceholderText('管理员密码'), 'Bootstrap123!@#') await user.type(screen.getByPlaceholderText('管理员密码'), 'Bootstrap123!@#')
await user.type(screen.getByPlaceholderText('确认管理员密码'), 'Bootstrap123!@#') await user.type(screen.getByPlaceholderText('确认管理员密码'), 'Bootstrap123!@#')
await user.click(screen.getByRole('button', { name: '完成初始化并进入系统' })) await user.click(screen.getByRole('button', { name: '完成初始化并进入系统' }))
await waitFor(() => await waitFor(() =>
expect(bootstrapAdminMock).toHaveBeenCalledWith({ expect(bootstrapAdminMock).toHaveBeenCalledWith(
username: 'bootstrap_admin', {
nickname: 'Bootstrap Admin', username: 'bootstrap_admin',
email: 'bootstrap_admin@example.com', nickname: 'Bootstrap Admin',
password: 'Bootstrap123!@#', email: 'bootstrap_admin@example.com',
}), password: 'Bootstrap123!@#',
},
'bootstrap-secret',
),
) )
await waitFor(() => await waitFor(() =>

View File

@@ -25,6 +25,7 @@ type BootstrapAdminFormValues = {
username: string username: string
nickname?: string nickname?: string
email?: string email?: string
bootstrapSecret: string
password: string password: string
confirmPassword: string confirmPassword: string
} }
@@ -68,12 +69,15 @@ export function BootstrapAdminPage() {
const handleSubmit = useCallback(async (values: BootstrapAdminFormValues) => { const handleSubmit = useCallback(async (values: BootstrapAdminFormValues) => {
setLoading(true) setLoading(true)
try { try {
const tokenBundle = await bootstrapAdmin({ const tokenBundle = await bootstrapAdmin(
username: values.username.trim(), {
nickname: values.nickname?.trim() || undefined, username: values.username.trim(),
email: values.email?.trim() || undefined, nickname: values.nickname?.trim() || undefined,
password: values.password, email: values.email?.trim() || undefined,
}) password: values.password,
},
values.bootstrapSecret.trim(),
)
await onLoginSuccess(tokenBundle) await onLoginSuccess(tokenBundle)
message.success('管理员初始化完成') message.success('管理员初始化完成')
navigate('/dashboard', { replace: true }) navigate('/dashboard', { replace: true })
@@ -152,6 +156,17 @@ export function BootstrapAdminPage() {
autoComplete="email" autoComplete="email"
/> />
</Form.Item> </Form.Item>
<Form.Item
name="bootstrapSecret"
rules={[{ required: true, message: '请输入引导密钥' }]}
>
<Input.Password
prefix={<LockOutlined />}
placeholder="引导密钥"
size="large"
autoComplete="off"
/>
</Form.Item>
<Form.Item <Form.Item
name="password" name="password"
rules={[{ required: true, message: '请输入管理员密码' }]} rules={[{ required: true, message: '请输入管理员密码' }]}

View File

@@ -29,6 +29,7 @@ const assignMock = vi.fn()
const getAuthCapabilitiesMock = vi.fn<() => Promise<AuthCapabilities>>() const getAuthCapabilitiesMock = vi.fn<() => Promise<AuthCapabilities>>()
const getOAuthAuthorizationUrlMock = vi.fn() const getOAuthAuthorizationUrlMock = vi.fn()
const loginByPasswordMock = vi.fn() const loginByPasswordMock = vi.fn()
const verifyTOTPAfterPasswordLoginMock = vi.fn()
const loginByEmailCodeMock = vi.fn() const loginByEmailCodeMock = vi.fn()
const loginBySmsCodeMock = vi.fn() const loginBySmsCodeMock = vi.fn()
const sendEmailCodeMock = vi.fn() const sendEmailCodeMock = vi.fn()
@@ -73,6 +74,7 @@ vi.mock('@/services/auth', () => ({
getOAuthAuthorizationUrl: (provider: string, returnTo: string) => getOAuthAuthorizationUrl: (provider: string, returnTo: string) =>
getOAuthAuthorizationUrlMock(provider, returnTo), getOAuthAuthorizationUrlMock(provider, returnTo),
loginByPassword: (payload: unknown) => loginByPasswordMock(payload), loginByPassword: (payload: unknown) => loginByPasswordMock(payload),
verifyTOTPAfterPasswordLogin: (payload: unknown) => verifyTOTPAfterPasswordLoginMock(payload),
loginByEmailCode: (payload: unknown) => loginByEmailCodeMock(payload), loginByEmailCode: (payload: unknown) => loginByEmailCodeMock(payload),
loginBySmsCode: (payload: unknown) => loginBySmsCodeMock(payload), loginBySmsCode: (payload: unknown) => loginBySmsCodeMock(payload),
sendEmailCode: (payload: unknown) => sendEmailCodeMock(payload), sendEmailCode: (payload: unknown) => sendEmailCodeMock(payload),
@@ -127,6 +129,7 @@ describe('LoginPage', () => {
getAuthCapabilitiesMock.mockReset() getAuthCapabilitiesMock.mockReset()
getOAuthAuthorizationUrlMock.mockReset() getOAuthAuthorizationUrlMock.mockReset()
loginByPasswordMock.mockReset() loginByPasswordMock.mockReset()
verifyTOTPAfterPasswordLoginMock.mockReset()
loginByEmailCodeMock.mockReset() loginByEmailCodeMock.mockReset()
loginBySmsCodeMock.mockReset() loginBySmsCodeMock.mockReset()
sendEmailCodeMock.mockReset() sendEmailCodeMock.mockReset()
@@ -280,6 +283,49 @@ describe('LoginPage', () => {
expect(navigateMock).not.toHaveBeenCalled() expect(navigateMock).not.toHaveBeenCalled()
}) })
it('holds password login on a TOTP challenge and completes verification before creating a session', async () => {
loginByPasswordMock.mockResolvedValue({
requires_totp: true,
user_id: 1,
temp_token: 'totp-challenge-token',
})
verifyTOTPAfterPasswordLoginMock.mockResolvedValue(loginTokenBundle)
renderLoginPage('/login?redirect=/profile')
await waitFor(() => expect(getAuthCapabilitiesMock).toHaveBeenCalledTimes(1))
fireEvent.change(screen.getByPlaceholderText(TEXT.usernamePlaceholder), {
target: { value: 'admin' },
})
fireEvent.change(screen.getByPlaceholderText(TEXT.passwordPlaceholder), {
target: { value: 'SecurePass123!' },
})
fireEvent.click(screen.getByRole('button'))
await waitFor(() => expect(loginByPasswordMock).toHaveBeenCalledTimes(1))
expect(onLoginSuccessMock).not.toHaveBeenCalled()
expect(screen.getByPlaceholderText('TOTP code')).toBeInTheDocument()
fireEvent.change(screen.getByPlaceholderText('TOTP code'), {
target: { value: '123456' },
})
fireEvent.click(screen.getByRole('button', { name: /verify totp/i }))
await waitFor(() => {
expect(verifyTOTPAfterPasswordLoginMock).toHaveBeenCalledWith({
user_id: 1,
code: '123456',
device_id: expect.any(String),
temp_token: 'totp-challenge-token',
})
})
expect(onLoginSuccessMock).toHaveBeenCalledWith(loginTokenBundle)
expect(navigateMock).toHaveBeenCalledWith('/profile', { replace: true })
})
it('sends an email verification code and starts the resend countdown', async () => { it('sends an email verification code and starts the resend countdown', async () => {
getAuthCapabilitiesMock.mockResolvedValue({ getAuthCapabilitiesMock.mockResolvedValue({
...defaultCapabilities, ...defaultCapabilities,

View File

@@ -22,8 +22,9 @@ import {
loginBySmsCode, loginBySmsCode,
sendEmailCode, sendEmailCode,
sendSmsCode, sendSmsCode,
verifyTOTPAfterPasswordLogin,
} from '@/services/auth' } from '@/services/auth'
import type { AuthCapabilities, TokenBundle } from '@/types' import type { AuthCapabilities, PasswordLoginChallenge, PasswordLoginResponse, TokenBundle } from '@/types'
const { Paragraph, Text, Title } = Typography const { Paragraph, Text, Title } = Typography
@@ -53,6 +54,19 @@ type SmsCodeFormValues = {
code: string code: string
} }
function isPasswordLoginChallenge(
result: PasswordLoginResponse,
): result is PasswordLoginChallenge {
return (
typeof result === 'object' &&
result !== null &&
'requires_totp' in result &&
result.requires_totp === true &&
typeof result.user_id === 'number' &&
typeof result.temp_token === 'string'
)
}
export function LoginPage() { export function LoginPage() {
const [activeTab, setActiveTab] = useState('password') const [activeTab, setActiveTab] = useState('password')
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
@@ -60,6 +74,8 @@ export function LoginPage() {
const [emailCountdown, setEmailCountdown] = useState(0) const [emailCountdown, setEmailCountdown] = useState(0)
const [smsCountdown, setSmsCountdown] = useState(0) const [smsCountdown, setSmsCountdown] = useState(0)
const [capabilities, setCapabilities] = useState<AuthCapabilities>(DEFAULT_CAPABILITIES) const [capabilities, setCapabilities] = useState<AuthCapabilities>(DEFAULT_CAPABILITIES)
const [pendingTOTP, setPendingTOTP] = useState<(PasswordLoginChallenge & { device_id?: string }) | null>(null)
const [totpCode, setTotpCode] = useState('')
const [emailForm] = Form.useForm<EmailCodeFormValues>() const [emailForm] = Form.useForm<EmailCodeFormValues>()
const [smsForm] = Form.useForm<SmsCodeFormValues>() const [smsForm] = Form.useForm<SmsCodeFormValues>()
@@ -151,6 +167,8 @@ export function LoginPage() {
const handlePasswordLogin = useCallback(async (values: LoginFormValues) => { const handlePasswordLogin = useCallback(async (values: LoginFormValues) => {
setLoading(true) setLoading(true)
setPendingTOTP(null)
setTotpCode('')
try { try {
const deviceInfo = getDeviceFingerprint() const deviceInfo = getDeviceFingerprint()
const tokenBundle = await loginByPassword({ const tokenBundle = await loginByPassword({
@@ -158,6 +176,17 @@ export function LoginPage() {
password: values.password, password: values.password,
...deviceInfo, ...deviceInfo,
}) })
if (isPasswordLoginChallenge(tokenBundle)) {
setPendingTOTP({
...tokenBundle,
device_id: deviceInfo.device_id,
})
setTotpCode('')
return
}
setPendingTOTP(null)
setTotpCode('')
await handleLoginSuccess(tokenBundle) await handleLoginSuccess(tokenBundle)
} catch (error) { } catch (error) {
message.error(getErrorMessage(error, '登录失败,请检查用户名和密码')) message.error(getErrorMessage(error, '登录失败,请检查用户名和密码'))
@@ -166,6 +195,29 @@ export function LoginPage() {
} }
}, [handleLoginSuccess]) }, [handleLoginSuccess])
const handleTOTPVerification = useCallback(async () => {
if (!pendingTOTP) {
return
}
setLoading(true)
try {
const tokenBundle = await verifyTOTPAfterPasswordLogin({
user_id: pendingTOTP.user_id,
code: totpCode,
device_id: pendingTOTP.device_id,
temp_token: pendingTOTP.temp_token,
})
setPendingTOTP(null)
setTotpCode('')
await handleLoginSuccess(tokenBundle)
} catch (error) {
message.error(getErrorMessage(error, 'TOTP verification failed'))
} finally {
setLoading(false)
}
}, [handleLoginSuccess, pendingTOTP, totpCode])
const handleSendEmailCode = useCallback(async () => { const handleSendEmailCode = useCallback(async () => {
try { try {
const values = await emailForm.validateFields(['email']) const values = await emailForm.validateFields(['email'])
@@ -232,6 +284,33 @@ export function LoginPage() {
key: 'password', key: 'password',
label: '密码登录', label: '密码登录',
children: ( children: (
pendingTOTP ? (
<Space direction="vertical" size={16} style={{ width: '100%' }}>
<Alert
type="info"
showIcon
message="TOTP verification required"
description="Enter the code from your authenticator app to finish signing in."
/>
<Input
prefix={<SafetyOutlined />}
placeholder="TOTP code"
size="large"
maxLength={6}
value={totpCode}
onChange={(event) => setTotpCode(event.target.value)}
/>
<Button
type="primary"
size="large"
block
loading={loading}
onClick={() => void handleTOTPVerification()}
>
Verify TOTP
</Button>
</Space>
) : (
<Form<LoginFormValues> layout="vertical" onFinish={handlePasswordLogin} autoComplete="off"> <Form<LoginFormValues> layout="vertical" onFinish={handlePasswordLogin} autoComplete="off">
<Form.Item name="username" rules={[{ required: true, message: '请输入用户名' }]}> <Form.Item name="username" rules={[{ required: true, message: '请输入用户名' }]}>
<Input <Input
@@ -255,6 +334,7 @@ export function LoginPage() {
</Button> </Button>
</Form.Item> </Form.Item>
</Form> </Form>
)
), ),
}, },
] ]
@@ -387,12 +467,15 @@ export function LoginPage() {
emailForm, emailForm,
handleEmailCodeLogin, handleEmailCodeLogin,
handlePasswordLogin, handlePasswordLogin,
handleTOTPVerification,
handleSendEmailCode, handleSendEmailCode,
handleSendSmsCode, handleSendSmsCode,
handleSmsCodeLogin, handleSmsCodeLogin,
loading, loading,
pendingTOTP,
smsCountdown, smsCountdown,
smsForm, smsForm,
totpCode,
]) ])
const currentTab = tabItems.find((item) => item.key === activeTab) ?? tabItems[0] const currentTab = tabItems.find((item) => item.key === activeTab) ?? tabItems[0]

View File

@@ -41,16 +41,13 @@ const defaultCapabilities: AuthCapabilities = {
} }
const activeRegisterResponse: RegisterResponse = { const activeRegisterResponse: RegisterResponse = {
user: { id: 2,
id: 2, username: 'new-user',
username: 'new-user', email: 'new-user@example.com',
email: 'new-user@example.com', phone: '',
phone: '', nickname: 'New User',
nickname: 'New User', avatar: '',
avatar: '', status: 1,
status: 1,
},
message: 'registered successfully',
} }
vi.mock('@/services/auth', () => ({ vi.mock('@/services/auth', () => ({
@@ -321,16 +318,13 @@ describe('RegisterPage', () => {
email_activation: true, email_activation: true,
}) })
registerMock.mockResolvedValue({ registerMock.mockResolvedValue({
user: { id: 3,
id: 3, username: 'inactive-user',
username: 'inactive-user', email: 'inactive-user@example.com',
email: 'inactive-user@example.com', phone: '',
phone: '', nickname: 'Inactive User',
nickname: 'Inactive User', avatar: '',
avatar: '', status: 0,
status: 0,
},
message: 'registered successfully, please check your email to activate the account',
}) })
renderRegisterPage() renderRegisterPage()
@@ -350,16 +344,13 @@ describe('RegisterPage', () => {
it('shows the generic activation summary when the new inactive account has no email address', async () => { it('shows the generic activation summary when the new inactive account has no email address', async () => {
registerMock.mockResolvedValue({ registerMock.mockResolvedValue({
user: { id: 4,
id: 4, username: 'inactive-without-email',
username: 'inactive-without-email', email: '',
email: '', phone: '',
phone: '', nickname: '',
nickname: '', avatar: '',
avatar: '', status: 0,
status: 0,
},
message: 'registered successfully, activation required',
}) })
renderRegisterPage() renderRegisterPage()

View File

@@ -39,9 +39,9 @@ type RegisterFormValues = {
} }
function buildRegisterSummary(result: RegisterResponse) { function buildRegisterSummary(result: RegisterResponse) {
if (result.user.status === 0) { if (result.status === 0) {
if (result.user.email) { if (result.email) {
return `账号已创建,激活邮件会发送到 ${result.user.email}。请完成激活后再登录。` return `账号已创建,激活邮件会发送到 ${result.email}。请完成激活后再登录。`
} }
return '账号已创建,请按页面提示完成激活后再登录。' return '账号已创建,请按页面提示完成激活后再登录。'
} }
@@ -128,7 +128,7 @@ export function RegisterPage() {
form.resetFields() form.resetFields()
setSmsCountdown(0) setSmsCountdown(0)
setSubmitted(result) setSubmitted(result)
message.success(result.user.status === 0 ? '注册成功,请完成邮箱激活' : '注册成功') message.success(result.status === 0 ? '注册成功,请完成邮箱激活' : '注册成功')
} catch (error) { } catch (error) {
message.error(getErrorMessage(error, '注册失败,请检查输入信息后重试')) message.error(getErrorMessage(error, '注册失败,请检查输入信息后重试'))
} finally { } finally {
@@ -137,7 +137,7 @@ export function RegisterPage() {
}, [capabilities.sms_code, form]) }, [capabilities.sms_code, form])
if (submitted) { if (submitted) {
const activationEmail = submitted.user.email?.trim() const activationEmail = submitted.email?.trim()
return ( return (
<AuthLayout> <AuthLayout>
@@ -146,7 +146,7 @@ export function RegisterPage() {
title="注册成功" title="注册成功"
subTitle={( subTitle={(
<Paragraph> <Paragraph>
<Text strong>{submitted.user.username}</Text> <Text strong>{submitted.username}</Text>
{' '} {' '}
{buildRegisterSummary(submitted)} {buildRegisterSummary(submitted)}
</Paragraph> </Paragraph>
@@ -155,7 +155,7 @@ export function RegisterPage() {
<Link key="login" to="/login"> <Link key="login" to="/login">
<Button type="primary"></Button> <Button type="primary"></Button>
</Link>, </Link>,
submitted.user.status === 0 && activationEmail && capabilities.email_activation ? ( submitted.status === 0 && activationEmail && capabilities.email_activation ? (
<Link key="activation" to={`/activate-account?email=${encodeURIComponent(activationEmail)}`}> <Link key="activation" to={`/activate-account?email=${encodeURIComponent(activationEmail)}`}>
<Button></Button> <Button></Button>
</Link> </Link>

View File

@@ -106,7 +106,7 @@ describe('auth service', () => {
) )
}) })
it('submits first-admin bootstrap without auth headers', async () => { it('submits first-admin bootstrap with the bootstrap secret header', async () => {
const { bootstrapAdmin } = await import('./auth') const { bootstrapAdmin } = await import('./auth')
await bootstrapAdmin({ await bootstrapAdmin({
@@ -114,7 +114,7 @@ describe('auth service', () => {
password: 'Bootstrap123!@#', password: 'Bootstrap123!@#',
email: 'bootstrap_admin@example.com', email: 'bootstrap_admin@example.com',
nickname: 'Bootstrap Admin', nickname: 'Bootstrap Admin',
}) }, 'bootstrap-secret')
expect(postMock).toHaveBeenCalledWith( expect(postMock).toHaveBeenCalledWith(
'/auth/bootstrap-admin', '/auth/bootstrap-admin',
@@ -124,7 +124,13 @@ describe('auth service', () => {
email: 'bootstrap_admin@example.com', email: 'bootstrap_admin@example.com',
nickname: 'Bootstrap Admin', nickname: 'Bootstrap Admin',
}, },
{ auth: false, credentials: 'include' }, {
auth: false,
credentials: 'include',
headers: {
'X-Bootstrap-Secret': 'bootstrap-secret',
},
},
) )
}) })

View File

@@ -8,6 +8,7 @@ import type {
LoginByPasswordRequest, LoginByPasswordRequest,
LoginBySmsCodeRequest, LoginBySmsCodeRequest,
OAuthAuthorizationResponse, OAuthAuthorizationResponse,
PasswordLoginResponse,
RegisterRequest, RegisterRequest,
RegisterResponse, RegisterResponse,
ResendActivationEmailRequest, ResendActivationEmailRequest,
@@ -37,8 +38,8 @@ export async function getAuthCapabilities(): Promise<AuthCapabilities> {
return normalizeAuthCapabilities(capabilities) return normalizeAuthCapabilities(capabilities)
} }
export function loginByPassword(data: LoginByPasswordRequest): Promise<TokenBundle> { export function loginByPassword(data: LoginByPasswordRequest): Promise<PasswordLoginResponse> {
return post<TokenBundle>('/auth/login', data, { auth: false, credentials: 'include' }) return post<PasswordLoginResponse>('/auth/login', data, { auth: false, credentials: 'include' })
} }
// Verify TOTP after password login when requires_totp is returned // Verify TOTP after password login when requires_totp is returned
@@ -58,8 +59,17 @@ export function register(data: RegisterRequest): Promise<RegisterResponse> {
return post<RegisterResponse>('/auth/register', data, { auth: false }) return post<RegisterResponse>('/auth/register', data, { auth: false })
} }
export function bootstrapAdmin(data: BootstrapAdminRequest): Promise<TokenBundle> { export function bootstrapAdmin(
return post<TokenBundle>('/auth/bootstrap-admin', data, { auth: false, credentials: 'include' }) data: BootstrapAdminRequest,
bootstrapSecret: string,
): Promise<TokenBundle> {
return post<TokenBundle>('/auth/bootstrap-admin', data, {
auth: false,
credentials: 'include',
headers: {
'X-Bootstrap-Secret': bootstrapSecret,
},
})
} }
export function activateEmail(token: string): Promise<ActionMessageResponse> { export function activateEmail(token: string): Promise<ActionMessageResponse> {

View File

@@ -24,6 +24,11 @@ describe('additional service adapters', () => {
}) })
it('routes the remaining users service methods through the HTTP client', async () => { it('routes the remaining users service methods through the HTTP client', async () => {
getMock
.mockResolvedValueOnce({ items: [], total: 0, page: 2, page_size: 50 })
.mockResolvedValueOnce({ id: 7 })
.mockResolvedValueOnce([])
const { const {
listUsers, listUsers,
getUser, getUser,

View File

@@ -15,7 +15,7 @@ describe('social account service', () => {
getMock.mockReset() getMock.mockReset()
postMock.mockReset() postMock.mockReset()
delMock.mockReset() delMock.mockReset()
getMock.mockResolvedValue([]) getMock.mockResolvedValue({ accounts: [] })
postMock.mockResolvedValue({ auth_url: 'https://oauth.example.com', state: 'state-demo' }) postMock.mockResolvedValue({ auth_url: 'https://oauth.example.com', state: 'state-demo' })
delMock.mockResolvedValue(undefined) delMock.mockResolvedValue(undefined)
}) })
@@ -23,9 +23,31 @@ describe('social account service', () => {
it('lists current user social accounts', async () => { it('lists current user social accounts', async () => {
const { listSocialAccounts } = await import('./social-accounts') const { listSocialAccounts } = await import('./social-accounts')
await listSocialAccounts() getMock.mockResolvedValue({
accounts: [
{
id: 1,
provider: 'github',
open_id: 'github-open-id',
union_id: '',
nickname: 'octocat',
avatar: 'https://example.com/avatar.png',
gender: 0,
email: 'octocat@example.com',
phone: '',
extra: '{}',
status: 1,
created_at: '2026-03-27 20:00:00',
updated_at: '2026-03-27 20:00:00',
},
],
})
const accounts = await listSocialAccounts()
expect(getMock).toHaveBeenCalledWith('/users/me/social-accounts') expect(getMock).toHaveBeenCalledWith('/users/me/social-accounts')
expect(accounts).toHaveLength(1)
expect(accounts[0]).toMatchObject({ provider: 'github', nickname: 'octocat' })
}) })
it('starts social binding with the current verification payload', async () => { it('starts social binding with the current verification payload', async () => {

View File

@@ -6,8 +6,14 @@ import type {
SocialBindingStartResponse, SocialBindingStartResponse,
} from '@/types' } from '@/types'
interface SocialAccountsResponse {
accounts: SocialAccountInfo[] | null
}
export function listSocialAccounts(): Promise<SocialAccountInfo[]> { export function listSocialAccounts(): Promise<SocialAccountInfo[]> {
return get<SocialAccountInfo[]>('/users/me/social-accounts') return get<SocialAccountsResponse>('/users/me/social-accounts').then((result) => (
Array.isArray(result.accounts) ? result.accounts : []
))
} }
export function startSocialBinding( export function startSocialBinding(

View File

@@ -32,4 +32,44 @@ describe('users service', () => {
expect(postMock).toHaveBeenCalledWith('/users', payload) expect(postMock).toHaveBeenCalledWith('/users', payload)
}) })
it('normalizes the legacy backend user list response', async () => {
getMock.mockResolvedValue({
users: [
{
id: 11,
username: 'legacy-admin',
email: 'legacy-admin@example.com',
nickname: 'Legacy Admin',
status: '1',
},
],
total: 1,
offset: 20,
limit: 10,
})
const { listUsers } = await import('./users')
const result = await listUsers({ page: 3, page_size: 10, keyword: 'legacy' })
expect(getMock).toHaveBeenCalledWith('/users', {
page: 3,
page_size: 10,
keyword: 'legacy',
})
expect(result).toEqual({
items: [
{
id: 11,
username: 'legacy-admin',
email: 'legacy-admin@example.com',
nickname: 'Legacy Admin',
status: '1',
},
],
total: 1,
page: 3,
page_size: 10,
})
})
}) })

View File

@@ -17,12 +17,44 @@ import type {
AssignUserRolesRequest, AssignUserRolesRequest,
} from '@/types/user' } from '@/types/user'
interface LegacyUserListResponse {
users: User[]
total: number
offset?: number
limit?: number
}
function isLegacyUserListResponse(
result: PaginatedData<User> | LegacyUserListResponse,
): result is LegacyUserListResponse {
return Array.isArray((result as LegacyUserListResponse).users)
}
/** /**
* 获取用户列表 * 获取用户列表
* GET /api/v1/users * GET /api/v1/users
*/ */
export function listUsers(params: UserListParams): Promise<PaginatedData<User>> { export async function listUsers(params: UserListParams): Promise<PaginatedData<User>> {
return get<PaginatedData<User>>('/users', params as Record<string, string | number | boolean | undefined>) const result = await get<PaginatedData<User> | LegacyUserListResponse>(
'/users',
params as Record<string, string | number | boolean | undefined>,
)
if (!isLegacyUserListResponse(result)) {
return result
}
const pageSize = result.limit ?? params.page_size
const page = pageSize && pageSize > 0
? Math.floor((result.offset ?? 0) / pageSize) + 1
: params.page
return {
items: result.users,
total: result.total,
page,
page_size: pageSize,
}
} }
/** /**

View File

@@ -22,7 +22,7 @@ describe('webhooks service', () => {
it('normalizes mixed raw event payloads from the API', async () => { it('normalizes mixed raw event payloads from the API', async () => {
getMock.mockResolvedValue({ getMock.mockResolvedValue({
data: [ list: [
{ {
id: 1, id: 1,
name: 'String Events', name: 'String Events',
@@ -87,7 +87,22 @@ describe('webhooks service', () => {
created_at: '2026-03-27 20:15:00', created_at: '2026-03-27 20:15:00',
updated_at: '2026-03-27 20:15:00', updated_at: '2026-03-27 20:15:00',
}) })
getMock.mockResolvedValue([]) getMock.mockResolvedValue({
deliveries: [
{
id: 7,
webhook_id: 9,
event_type: 'user.updated',
payload: '{"id":1}',
status_code: 200,
response_body: 'ok',
attempt: 1,
success: true,
error: '',
created_at: '2026-03-27 20:20:00',
},
],
})
const { const {
createWebhook, createWebhook,
@@ -121,7 +136,9 @@ describe('webhooks service', () => {
await deleteWebhook(9) await deleteWebhook(9)
expect(delMock).toHaveBeenCalledWith('/webhooks/9') expect(delMock).toHaveBeenCalledWith('/webhooks/9')
await getWebhookDeliveries(9, { limit: 20 }) const deliveries = await getWebhookDeliveries(9, { limit: 20 })
expect(getMock).toHaveBeenCalledWith('/webhooks/9/deliveries', { limit: 20 }) expect(getMock).toHaveBeenCalledWith('/webhooks/9/deliveries', { limit: 20 })
expect(deliveries).toHaveLength(1)
expect(deliveries[0]).toMatchObject({ webhook_id: 9, status_code: 200 })
}) })
}) })

View File

@@ -32,18 +32,25 @@ function normalizeWebhook(webhook: RawWebhook): Webhook {
} }
} }
interface PaginatedResponse<T> { interface WebhookListResponse<T> {
data: T[] list: T[]
total: number total: number
page: number page: number
page_size: number page_size: number
} }
interface WebhookDeliveriesResponse {
deliveries: WebhookDelivery[]
}
export async function listWebhooks( export async function listWebhooks(
params?: WebhookListParams, params?: WebhookListParams,
): Promise<{ data: Webhook[]; total: number; page: number; page_size: number }> { ): Promise<{ data: Webhook[]; total: number; page: number; page_size: number }> {
const result = await get<PaginatedResponse<RawWebhook>>('/webhooks', params as Record<string, string | number | boolean | undefined>) const result = await get<WebhookListResponse<RawWebhook>>(
const webhooks = result.data.map(normalizeWebhook) '/webhooks',
params as Record<string, string | number | boolean | undefined>,
)
const webhooks = result.list.map(normalizeWebhook)
return { data: webhooks, total: result.total, page: result.page, page_size: result.page_size } return { data: webhooks, total: result.total, page: result.page, page_size: result.page_size }
} }
@@ -67,8 +74,8 @@ export function getWebhookDeliveries(
id: number, id: number,
params?: WebhookDeliveryListParams, params?: WebhookDeliveryListParams,
): Promise<WebhookDelivery[]> { ): Promise<WebhookDelivery[]> {
return get<WebhookDelivery[]>( return get<WebhookDeliveriesResponse>(
`/webhooks/${id}/deliveries`, `/webhooks/${id}/deliveries`,
params as Record<string, string | number | boolean | undefined>, params as Record<string, string | number | boolean | undefined>,
) ).then((result) => result.deliveries)
} }

View File

@@ -15,16 +15,21 @@ export interface TokenBundle {
refresh_token?: string refresh_token?: string
expires_in: number expires_in: number
user: SessionUser user: SessionUser
// TOTP required response (when user has TOTP enabled but device is not trusted)
requires_totp?: boolean
user_id?: number
} }
// TOTP verification request after password login export interface PasswordLoginChallenge {
requires_totp: true
user_id: number
temp_token: string
}
export type PasswordLoginResponse = TokenBundle | PasswordLoginChallenge
export interface TOTPVerifyRequest { export interface TOTPVerifyRequest {
user_id: number user_id: number
code: string code: string
device_id?: string device_id?: string
temp_token: string
} }
export interface OAuthProviderInfo { export interface OAuthProviderInfo {
@@ -94,10 +99,7 @@ export interface BootstrapAdminRequest {
nickname?: string nickname?: string
} }
export interface RegisterResponse { export type RegisterResponse = SessionUser
user: SessionUser
message: string
}
export interface ActionMessageResponse { export interface ActionMessageResponse {
message: string message: string

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/subtle" "crypto/subtle"
"errors" "errors"
"io"
"net/http" "net/http"
"os" "os"
"strings" "strings"
@@ -15,6 +16,11 @@ import (
"github.com/user-management-system/internal/service" "github.com/user-management-system/internal/service"
) )
const (
refreshTokenCookieName = "ums_refresh_token"
sessionPresenceCookieName = "ums_session_present"
)
// newBackgroundCtx 创建用于后台 goroutine 的带超时独立 context与请求 context 无关) // newBackgroundCtx 创建用于后台 goroutine 的带超时独立 context与请求 context 无关)
func newBackgroundCtx(timeoutSec int) (context.Context, context.CancelFunc) { func newBackgroundCtx(timeoutSec int) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) return context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
@@ -129,6 +135,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
handleError(c, err) handleError(c, err)
return return
} }
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"code": 0, "code": 0,
@@ -150,20 +157,28 @@ func (h *AuthHandler) Login(c *gin.Context) {
// @Router /api/v1/auth/login/totp-verify [post] // @Router /api/v1/auth/login/totp-verify [post]
func (h *AuthHandler) VerifyTOTPAfterPasswordLogin(c *gin.Context) { func (h *AuthHandler) VerifyTOTPAfterPasswordLogin(c *gin.Context) {
var req struct { var req struct {
UserID int64 `json:"user_id" binding:"required"` UserID int64 `json:"user_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
DeviceID string `json:"device_id"` DeviceID string `json:"device_id"`
TempToken string `json:"temp_token"`
} }
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
return return
} }
resp, err := h.authService.VerifyTOTPAfterPasswordLogin(c.Request.Context(), req.UserID, req.Code, req.DeviceID) resp, err := h.authService.VerifyTOTPAfterPasswordLogin(
c.Request.Context(),
req.UserID,
req.Code,
req.DeviceID,
req.TempToken,
)
if err != nil { if err != nil {
handleError(c, err) handleError(c, err)
return return
} }
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"code": 0, "code": 0,
@@ -197,6 +212,10 @@ func (h *AuthHandler) Logout(c *gin.Context) {
} }
} }
if req.RefreshToken == "" {
req.RefreshToken, _ = c.Cookie(refreshTokenCookieName)
}
username, _ := c.Get("username") username, _ := c.Get("username")
usernameStr, _ := username.(string) usernameStr, _ := username.(string)
@@ -206,6 +225,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
} }
_ = h.authService.Logout(c.Request.Context(), usernameStr, logoutReq) _ = h.authService.Logout(c.Request.Context(), usernameStr, logoutReq)
clearSessionCookies(c)
c.JSON(http.StatusOK, gin.H{"message": "logged out"}) c.JSON(http.StatusOK, gin.H{"message": "logged out"})
} }
@@ -222,19 +243,27 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// @Router /api/v1/auth/refresh-token [post] // @Router /api/v1/auth/refresh-token [post]
func (h *AuthHandler) RefreshToken(c *gin.Context) { func (h *AuthHandler) RefreshToken(c *gin.Context) {
var req struct { var req struct {
RefreshToken string `json:"refresh_token" binding:"required"` RefreshToken string `json:"refresh_token"`
} }
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if req.RefreshToken == "" {
req.RefreshToken, _ = c.Cookie(refreshTokenCookieName)
}
if req.RefreshToken == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "refresh_token is required"})
return
}
resp, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken) resp, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken)
if err != nil { if err != nil {
handleError(c, err) handleError(c, err)
return return
} }
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"code": 0, "code": 0,
@@ -480,6 +509,7 @@ func (h *AuthHandler) LoginByEmailCode(c *gin.Context) {
h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq) h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq)
}() }()
} }
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"code": 0, "code": 0,
@@ -544,6 +574,7 @@ func (h *AuthHandler) BootstrapAdmin(c *gin.Context) {
handleError(c, err) handleError(c, err)
return return
} }
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusCreated, gin.H{ c.JSON(http.StatusCreated, gin.H{
"code": 0, "code": 0,
@@ -673,6 +704,46 @@ func getUserIDFromContext(c *gin.Context) (int64, bool) {
return id, ok return id, ok
} }
func setSessionCookies(c *gin.Context, authService *service.AuthService, refreshToken string) {
if c == nil || strings.TrimSpace(refreshToken) == "" {
return
}
maxAge := 0
if authService != nil {
if ttl := authService.RefreshTokenTTLSeconds(); ttl > 0 {
maxAge = int(ttl)
}
}
secure := requestUsesHTTPS(c)
c.SetSameSite(http.SameSiteLaxMode)
c.SetCookie(refreshTokenCookieName, refreshToken, maxAge, "/", "", secure, true)
c.SetCookie(sessionPresenceCookieName, "1", maxAge, "/", "", secure, false)
}
func clearSessionCookies(c *gin.Context) {
if c == nil {
return
}
secure := requestUsesHTTPS(c)
c.SetSameSite(http.SameSiteLaxMode)
c.SetCookie(refreshTokenCookieName, "", -1, "/", "", secure, true)
c.SetCookie(sessionPresenceCookieName, "", -1, "/", "", secure, false)
}
func requestUsesHTTPS(c *gin.Context) bool {
if c == nil || c.Request == nil {
return false
}
if c.Request.TLS != nil {
return true
}
return strings.EqualFold(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")), "https")
}
// handleError 将 error 转换为对应的 HTTP 响应。 // handleError 将 error 转换为对应的 HTTP 响应。
// 优先识别 ApplicationError其次通过关键词推断业务错误类型兜底返回 500。 // 优先识别 ApplicationError其次通过关键词推断业务错误类型兜底返回 500。
func handleError(c *gin.Context, err error) { func handleError(c *gin.Context, err error) {

View File

@@ -31,6 +31,46 @@ import (
var handlerDbCounter int64 var handlerDbCounter int64
func seedHandlerAuthzData(t *testing.T, db *gorm.DB) {
t.Helper()
roleIDs := make(map[string]int64)
for _, predefined := range domain.PredefinedRoles {
role := predefined
if err := db.Create(&role).Error; err != nil {
t.Fatalf("seed role %s failed: %v", role.Code, err)
}
roleIDs[role.Code] = role.ID
}
permissionIDs := make(map[string]int64)
for _, predefined := range domain.DefaultPermissions() {
permission := predefined
if err := db.Create(&permission).Error; err != nil {
t.Fatalf("seed permission %s failed: %v", permission.Code, err)
}
permissionIDs[permission.Code] = permission.ID
}
adminRoleID := roleIDs["admin"]
for _, permissionID := range permissionIDs {
if err := db.Create(&domain.RolePermission{RoleID: adminRoleID, PermissionID: permissionID}).Error; err != nil {
t.Fatalf("assign admin permission %d failed: %v", permissionID, err)
}
}
userRoleID := roleIDs["user"]
for _, code := range []string{"profile:view", "profile:edit", "log:view_own"} {
permissionID, ok := permissionIDs[code]
if !ok {
t.Fatalf("seeded permissions missing %s", code)
}
if err := db.Create(&domain.RolePermission{RoleID: userRoleID, PermissionID: permissionID}).Error; err != nil {
t.Fatalf("assign user permission %s failed: %v", code, err)
}
}
}
func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
t.Helper() t.Helper()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
@@ -64,6 +104,8 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
t.Fatalf("db migration failed: %v", err) t.Fatalf("db migration failed: %v", err)
} }
seedHandlerAuthzData(t, db)
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{ jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: "test-handler-secret-key", HS256Secret: "test-handler-secret-key",
AccessTokenExpire: 15 * time.Minute, AccessTokenExpire: 15 * time.Minute,
@@ -176,6 +218,18 @@ func doDelete(url, token string) (*http.Response, string) {
return doRequest("DELETE", url, token, nil) return doRequest("DELETE", url, token, nil)
} }
func getCookie(resp *http.Response, name string) *http.Cookie {
if resp == nil {
return nil
}
for _, cookie := range resp.Cookies() {
if cookie.Name == name {
return cookie
}
}
return nil
}
func getToken(baseURL, username, password string) string { func getToken(baseURL, username, password string) string {
resp, body := doPost(baseURL+"/api/v1/auth/login", "", map[string]interface{}{ resp, body := doPost(baseURL+"/api/v1/auth/login", "", map[string]interface{}{
"account": username, "account": username,
@@ -207,6 +261,111 @@ func registerUser(baseURL, username, email, password string) bool {
return resp.StatusCode == http.StatusCreated return resp.StatusCode == http.StatusCreated
} }
func bootstrapAdmin(baseURL, secret, username, email, password string) string {
payload, _ := json.Marshal(map[string]interface{}{
"username": username,
"email": email,
"password": password,
})
req, _ := http.NewRequest(http.MethodPost, baseURL+"/api/v1/auth/bootstrap-admin", bytes.NewReader(payload))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Bootstrap-Secret", secret)
resp, err := (&http.Client{}).Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
return ""
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return ""
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return ""
}
data, ok := result["data"].(map[string]interface{})
if !ok || data["access_token"] == nil {
return ""
}
token, _ := data["access_token"].(string)
return token
}
func setupEnabledTOTPUser(t *testing.T, baseURL, username, email, password string) (int64, string) {
t.Helper()
if ok := registerUser(baseURL, username, email, password); !ok {
t.Fatalf("registration failed for %s", username)
}
token := getToken(baseURL, username, password)
if token == "" {
t.Fatalf("failed to get token for %s", username)
}
userInfoResp, userInfoBody := doGet(baseURL+"/api/v1/auth/userinfo", token)
defer userInfoResp.Body.Close()
if userInfoResp.StatusCode != http.StatusOK {
t.Fatalf("userinfo failed: status=%d body=%s", userInfoResp.StatusCode, userInfoBody)
}
var userInfoResult map[string]interface{}
if err := json.Unmarshal([]byte(userInfoBody), &userInfoResult); err != nil {
t.Fatalf("failed to parse userinfo response: %v", err)
}
userData, ok := userInfoResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("userinfo response missing data: %s", userInfoBody)
}
userID, ok := userData["id"].(float64)
if !ok {
t.Fatalf("userinfo response missing id: %s", userInfoBody)
}
setupResp, setupBody := doGet(baseURL+"/api/v1/auth/2fa/setup", token)
defer setupResp.Body.Close()
if setupResp.StatusCode != http.StatusOK {
t.Fatalf("2fa setup failed: status=%d body=%s", setupResp.StatusCode, setupBody)
}
var setupResult map[string]interface{}
if err := json.Unmarshal([]byte(setupBody), &setupResult); err != nil {
t.Fatalf("failed to parse 2fa setup response: %v", err)
}
setupData, ok := setupResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("2fa setup response missing data: %s", setupBody)
}
secret, ok := setupData["secret"].(string)
if !ok || secret == "" {
t.Fatalf("2fa setup response missing secret: %s", setupBody)
}
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
if err != nil {
t.Fatalf("failed to generate TOTP code: %v", err)
}
enableResp, enableBody := doPost(baseURL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{
"code": code,
})
defer enableResp.Body.Close()
if enableResp.StatusCode != http.StatusOK {
t.Fatalf("2fa enable failed: status=%d body=%s", enableResp.StatusCode, enableBody)
}
return int64(userID), secret
}
// ============================================================================= // =============================================================================
// Auth Handler Tests // Auth Handler Tests
// ============================================================================= // =============================================================================
@@ -292,6 +451,38 @@ func TestAuthHandler_Login_Success(t *testing.T) {
} }
} }
func TestAuthHandler_Login_SetsSessionCookies(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "logincookieuser", "logincookie@example.com", "Password123!")
resp, body := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "logincookieuser",
"password": "Password123!",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
refreshCookie := getCookie(resp, "ums_refresh_token")
if refreshCookie == nil || refreshCookie.Value == "" {
t.Fatalf("login response missing refresh cookie, cookies=%v", resp.Cookies())
}
if !refreshCookie.HttpOnly {
t.Fatalf("refresh cookie should be HttpOnly, got %+v", refreshCookie)
}
presenceCookie := getCookie(resp, "ums_session_present")
if presenceCookie == nil || presenceCookie.Value != "1" {
t.Fatalf("login response missing presence cookie, cookies=%v", resp.Cookies())
}
if presenceCookie.HttpOnly {
t.Fatalf("presence cookie should be readable by the frontend, got %+v", presenceCookie)
}
}
func TestAuthHandler_Login_WrongPassword(t *testing.T) { func TestAuthHandler_Login_WrongPassword(t *testing.T) {
server, cleanup := setupHandlerTestServer(t) server, cleanup := setupHandlerTestServer(t)
defer cleanup() defer cleanup()
@@ -360,6 +551,66 @@ func TestAuthHandler_GetAuthCapabilities(t *testing.T) {
} }
} }
func TestAuthHandler_Login_WithTOTPEnabled_ReturnsChallengeToken(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
_, _ = setupEnabledTOTPUser(t, server.URL, "totplogin", "totplogin@example.com", "Password123!")
resp, body := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "totplogin",
"password": "Password123!",
"device_id": "device-login-1",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse login response: %v", err)
}
data, ok := result["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected login response data, got %s", body)
}
if data["requires_totp"] != true {
t.Fatalf("expected requires_totp=true, got %+v", data)
}
tempToken, ok := data["temp_token"].(string)
if !ok || tempToken == "" {
t.Fatalf("expected temp_token in TOTP challenge response, got %+v", data)
}
}
func TestAuthHandler_VerifyTOTPAfterPasswordLogin_RequiresTempToken(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpreverify", "totpreverify@example.com", "Password123!")
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
if err != nil {
t.Fatalf("failed to generate TOTP code: %v", err)
}
resp, body := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
"user_id": userID,
"code": code,
"device_id": "device-login-1",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected status %d when temp_token is missing, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
}
}
// ============================================================================= // =============================================================================
// User Handler Tests // User Handler Tests
// ============================================================================= // =============================================================================
@@ -451,6 +702,26 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) {
} }
} }
func TestUserHandler_UpdateUser_AdminCanUpdateAnotherUser(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "updateadmin", "updateadmin@test.com", "AdminPass123!")
registerUser(server.URL, "targetuser", "targetuser@test.com", "UserPass123!")
if token == "" {
t.Fatal("bootstrap admin should return access token")
}
resp, body := doPut(server.URL+"/api/v1/users/2", token, map[string]string{"nickname": "Updated By Admin"})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) { func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t) server, cleanup := setupHandlerTestServer(t)
defer cleanup() defer cleanup()
@@ -515,6 +786,26 @@ func TestUserHandler_GetUserRoles_Success(t *testing.T) {
} }
} }
func TestUserHandler_GetUserRoles_AdminCanViewAnotherUser(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "rolesadmin2", "rolesadmin2@test.com", "AdminPass123!")
registerUser(server.URL, "roles-target", "roles-target@test.com", "UserPass123!")
if token == "" {
t.Fatal("bootstrap admin should return access token")
}
resp, body := doGet(server.URL+"/api/v1/users/2/roles", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestUserHandler_AssignRoles_RequiresAdmin(t *testing.T) { func TestUserHandler_AssignRoles_RequiresAdmin(t *testing.T) {
server, cleanup := setupHandlerTestServer(t) server, cleanup := setupHandlerTestServer(t)
defer cleanup() defer cleanup()
@@ -1253,6 +1544,187 @@ func TestAuthHandler_RefreshToken_Success(t *testing.T) {
} }
} }
func TestAuthHandler_RefreshToken_AcceptsRefreshCookie(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "refreshcookieuser", "refreshcookie@example.com", "Password123!")
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "refreshcookieuser",
"password": "Password123!",
})
defer loginResp.Body.Close()
if loginResp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody)
}
refreshCookie := getCookie(loginResp, "ums_refresh_token")
if refreshCookie == nil || refreshCookie.Value == "" {
t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies())
}
req, err := http.NewRequest("POST", server.URL+"/api/v1/auth/refresh", nil)
if err != nil {
t.Fatalf("create refresh request failed: %v", err)
}
req.AddCookie(refreshCookie)
req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"})
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("refresh request failed: %v", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read refresh response failed: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
}
rotatedCookie := getCookie(resp, "ums_refresh_token")
if rotatedCookie == nil || rotatedCookie.Value == "" {
t.Fatalf("refresh response missing rotated refresh cookie, cookies=%v", resp.Cookies())
}
if rotatedCookie.Value == refreshCookie.Value {
t.Fatalf("refresh should rotate cookie value, old=%q new=%q", refreshCookie.Value, rotatedCookie.Value)
}
presenceCookie := getCookie(resp, "ums_session_present")
if presenceCookie == nil || presenceCookie.Value != "1" {
t.Fatalf("refresh response missing presence cookie, cookies=%v", resp.Cookies())
}
}
func TestAuthHandler_RefreshToken_AllowsImmediateRetryWithPreviousCookie(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "refreshretryuser", "refreshretry@example.com", "Password123!")
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "refreshretryuser",
"password": "Password123!",
})
defer loginResp.Body.Close()
if loginResp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody)
}
refreshCookie := getCookie(loginResp, "ums_refresh_token")
if refreshCookie == nil || refreshCookie.Value == "" {
t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies())
}
newRefreshRequest := func(cookie *http.Cookie) *http.Response {
req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/auth/refresh", nil)
if err != nil {
t.Fatalf("create refresh request failed: %v", err)
}
req.AddCookie(cookie)
req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"})
resp, err := (&http.Client{}).Do(req)
if err != nil {
t.Fatalf("refresh request failed: %v", err)
}
return resp
}
firstResp := newRefreshRequest(refreshCookie)
defer firstResp.Body.Close()
firstBody, err := io.ReadAll(firstResp.Body)
if err != nil {
t.Fatalf("read first refresh response failed: %v", err)
}
if firstResp.StatusCode != http.StatusOK {
t.Fatalf("expected first refresh status %d, got %d, body: %s", http.StatusOK, firstResp.StatusCode, string(firstBody))
}
retryResp := newRefreshRequest(refreshCookie)
defer retryResp.Body.Close()
retryBody, err := io.ReadAll(retryResp.Body)
if err != nil {
t.Fatalf("read retry refresh response failed: %v", err)
}
if retryResp.StatusCode != http.StatusOK {
t.Fatalf("expected retry refresh status %d, got %d, body: %s", http.StatusOK, retryResp.StatusCode, string(retryBody))
}
}
func TestAuthHandler_Logout_ClearsSessionCookies(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "logoutcookieuser", "logoutcookie@example.com", "Password123!")
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "logoutcookieuser",
"password": "Password123!",
})
defer loginResp.Body.Close()
if loginResp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody)
}
var loginResult map[string]interface{}
if err := json.Unmarshal([]byte(loginBody), &loginResult); err != nil {
t.Fatalf("parse login response failed: %v", err)
}
loginData, ok := loginResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("login response missing data: %s", loginBody)
}
accessToken, ok := loginData["access_token"].(string)
if !ok || accessToken == "" {
t.Fatalf("login response missing access token: %s", loginBody)
}
refreshCookie := getCookie(loginResp, "ums_refresh_token")
if refreshCookie == nil || refreshCookie.Value == "" {
t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies())
}
req, err := http.NewRequest("POST", server.URL+"/api/v1/auth/logout", nil)
if err != nil {
t.Fatalf("create logout request failed: %v", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.AddCookie(refreshCookie)
req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"})
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("logout request failed: %v", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read logout response failed: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
}
clearedRefreshCookie := getCookie(resp, "ums_refresh_token")
if clearedRefreshCookie == nil || clearedRefreshCookie.Value != "" {
t.Fatalf("logout response should clear refresh cookie, cookies=%v", resp.Cookies())
}
clearedPresenceCookie := getCookie(resp, "ums_session_present")
if clearedPresenceCookie == nil || clearedPresenceCookie.Value != "" {
t.Fatalf("logout response should clear presence cookie, cookies=%v", resp.Cookies())
}
}
func TestAuthHandler_RefreshToken_InvalidToken(t *testing.T) { func TestAuthHandler_RefreshToken_InvalidToken(t *testing.T) {
server, cleanup := setupHandlerTestServer(t) server, cleanup := setupHandlerTestServer(t)
defer cleanup() defer cleanup()

View File

@@ -116,6 +116,7 @@ func (h *SMSHandler) LoginByCode(c *gin.Context) {
h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq) h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq)
}() }()
} }
setSessionCookies(c, h.authService, resp.RefreshToken)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"code": 0, "code": 0,

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain" "github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service" "github.com/user-management-system/internal/service"
@@ -187,15 +188,7 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
// Authorization: only self or admin can update user profile // Authorization: only self or admin can update user profile
currentUserID := c.GetInt64("user_id") currentUserID := c.GetInt64("user_id")
isAdmin := false isAdmin := middleware.IsAdmin(c)
if roles, ok := c.Get("user_roles"); ok {
for _, role := range roles.([]*domain.Role) {
if role.Code == "admin" {
isAdmin = true
break
}
}
}
if currentUserID != id && !isAdmin { if currentUserID != id && !isAdmin {
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"}) c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
return return
@@ -370,15 +363,7 @@ func (h *UserHandler) GetUserRoles(c *gin.Context) {
// Authorization: only self or admin can view user roles // Authorization: only self or admin can view user roles
currentUserID := c.GetInt64("user_id") currentUserID := c.GetInt64("user_id")
isAdmin := false isAdmin := middleware.IsAdmin(c)
if roles, ok := c.Get("user_roles"); ok {
for _, role := range roles.([]*domain.Role) {
if role.Code == "admin" {
isAdmin = true
break
}
}
}
if currentUserID != id && !isAdmin { if currentUserID != id && !isAdmin {
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"}) c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
return return

View File

@@ -0,0 +1,103 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/service"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
)
func TestAuthMiddleware_AcceptsBootstrapAdminTokenImmediately(t *testing.T) {
t.Helper()
gin.SetMode(gin.TestMode)
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: "file:middleware_bootstrap_test?mode=memory&cache=shared",
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite failed: %v", err)
}
if err := db.AutoMigrate(&domain.User{}, &domain.Role{}, &domain.UserRole{}); err != nil {
t.Fatalf("migrate failed: %v", err)
}
if err := db.Create(&domain.Role{
Name: "管理员",
Code: "admin",
IsSystem: true,
Status: domain.RoleStatusEnabled,
}).Error; err != nil {
t.Fatalf("seed admin role failed: %v", err)
}
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: "test-bootstrap-token-secret-at-least-32-chars",
AccessTokenExpire: 15 * time.Minute,
RefreshTokenExpire: 7 * 24 * time.Hour,
})
if err != nil {
t.Fatalf("create jwt manager failed: %v", err)
}
l1Cache := cache.NewL1Cache()
l2Cache := cache.NewRedisCache(false)
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
userRepo := repository.NewUserRepository(db)
roleRepo := repository.NewRoleRepository(db)
userRoleRepo := repository.NewUserRoleRepository(db)
authService := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
authService.SetRoleRepositories(userRoleRepo, roleRepo)
loginResponse, err := authService.BootstrapAdmin(context.Background(), &service.BootstrapAdminRequest{
Username: "bootstrap_admin",
Email: "bootstrap_admin@example.com",
Password: "AdminPass123!",
}, "127.0.0.1")
if err != nil {
t.Fatalf("bootstrap admin failed: %v", err)
}
if loginResponse == nil || loginResponse.AccessToken == "" {
t.Fatalf("expected bootstrap access token, got %+v", loginResponse)
}
if _, err := jwtManager.ValidateAccessToken(loginResponse.AccessToken); err != nil {
t.Fatalf("bootstrap access token should validate immediately: %v", err)
}
authMiddleware := NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, l1Cache)
authMiddleware.SetCacheManager(cacheManager)
recorder := httptest.NewRecorder()
ctx, engine := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
ctx.Request.Header.Set("Authorization", "Bearer "+loginResponse.AccessToken)
engine.Use(authMiddleware.Required())
engine.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"code": 0})
})
engine.ServeHTTP(recorder, ctx.Request)
if recorder.Code != http.StatusOK {
t.Fatalf("expected bootstrap token to pass auth middleware immediately, got %d body: %s", recorder.Code, recorder.Body.String())
}
}

View File

@@ -1,14 +1,21 @@
package middleware package middleware
import ( import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config" "github.com/user-management-system/internal/config"
) )
// RateLimitMiddleware 限流中间件 // RateLimitMiddleware provides simple in-memory sliding-window rate limiting.
type RateLimitMiddleware struct { type RateLimitMiddleware struct {
cfg config.RateLimitConfig cfg config.RateLimitConfig
limiters map[string]*SlidingWindowLimiter limiters map[string]*SlidingWindowLimiter
@@ -16,7 +23,7 @@ type RateLimitMiddleware struct {
cleanupInt time.Duration cleanupInt time.Duration
} }
// SlidingWindowLimiter 滑动窗口限流器 // SlidingWindowLimiter enforces a fixed-capacity sliding window.
type SlidingWindowLimiter struct { type SlidingWindowLimiter struct {
mu sync.Mutex mu sync.Mutex
window time.Duration window time.Duration
@@ -24,7 +31,6 @@ type SlidingWindowLimiter struct {
requests []int64 requests []int64
} }
// NewSlidingWindowLimiter 创建滑动窗口限流器
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter { func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
return &SlidingWindowLimiter{ return &SlidingWindowLimiter{
window: window, window: window,
@@ -33,7 +39,6 @@ func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindo
} }
} }
// Allow 检查是否允许请求
func (l *SlidingWindowLimiter) Allow() bool { func (l *SlidingWindowLimiter) Allow() bool {
l.mu.Lock() l.mu.Lock()
defer l.mu.Unlock() defer l.mu.Unlock()
@@ -41,16 +46,14 @@ func (l *SlidingWindowLimiter) Allow() bool {
now := time.Now().UnixMilli() now := time.Now().UnixMilli()
cutoff := now - l.window.Milliseconds() cutoff := now - l.window.Milliseconds()
// 清理过期请求 validRequests := make([]int64, 0, len(l.requests))
var validRequests []int64 for _, ts := range l.requests {
for _, t := range l.requests { if ts > cutoff {
if t > cutoff { validRequests = append(validRequests, ts)
validRequests = append(validRequests, t)
} }
} }
l.requests = validRequests l.requests = validRequests
// 检查容量
if int64(len(l.requests)) >= l.capacity { if int64(len(l.requests)) >= l.capacity {
return false return false
} }
@@ -59,7 +62,6 @@ func (l *SlidingWindowLimiter) Allow() bool {
return true return true
} }
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware { func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
return &RateLimitMiddleware{ return &RateLimitMiddleware{
cfg: cfg, cfg: cfg,
@@ -68,30 +70,28 @@ func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
} }
} }
// Register 返回注册接口的限流中间件
func (m *RateLimitMiddleware) Register() gin.HandlerFunc { func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
return m.limitForKey("register", 60, 10) return m.limitForKey("register", 60, 10)
} }
// Login 返回登录接口的限流中间件
func (m *RateLimitMiddleware) Login() gin.HandlerFunc { func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
return m.limitForKey("login", 60, 5) return m.limitForKey("login", 60, 5)
} }
// API 返回 API 接口的限流中间件
func (m *RateLimitMiddleware) API() gin.HandlerFunc { func (m *RateLimitMiddleware) API() gin.HandlerFunc {
return m.limitForKey("api", 60, 100) return m.limitForKey("api", 60, 100)
} }
// Refresh 返回刷新令牌的限流中间件
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc { func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
return m.limitForKey("refresh", 60, 10) return m.limitForKey("refresh", 60, 10)
} }
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc { func (m *RateLimitMiddleware) limitForKey(bucket string, windowSeconds int, capacity int64) gin.HandlerFunc {
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity) window := time.Duration(windowSeconds) * time.Second
return func(c *gin.Context) { return func(c *gin.Context) {
limiterKey := m.resolveLimiterKey(c, bucket)
limiter := m.getOrCreateLimiter(limiterKey, window, capacity)
if !limiter.Allow() { if !limiter.Allow() {
c.JSON(429, gin.H{ c.JSON(429, gin.H{
"code": 429, "code": 429,
@@ -104,6 +104,81 @@ func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacit
} }
} }
func (m *RateLimitMiddleware) resolveLimiterKey(c *gin.Context, bucket string) string {
if bucket == "refresh" {
if refreshToken := extractRefreshToken(c); refreshToken != "" {
return fmt.Sprintf("%s:token:%s", bucket, fingerprintValue(refreshToken))
}
}
identity := "anonymous"
if c != nil {
if userID, ok := c.Get("user_id"); ok {
identity = fmt.Sprintf("user:%v", userID)
} else if ip := c.ClientIP(); ip != "" {
identity = "ip:" + ip
}
}
if bucket == "api" {
method := ""
route := ""
if c != nil {
if c.Request != nil {
method = c.Request.Method
if c.Request.URL != nil {
route = c.Request.URL.Path
}
}
if fullPath := c.FullPath(); fullPath != "" {
route = fullPath
}
}
return fmt.Sprintf("%s:%s:%s:%s", bucket, method, route, identity)
}
return fmt.Sprintf("%s:%s", bucket, identity)
}
func extractRefreshToken(c *gin.Context) string {
if c == nil {
return ""
}
if refreshToken, err := c.Cookie("ums_refresh_token"); err == nil && refreshToken != "" {
return refreshToken
}
if c.Request == nil || c.Request.Body == nil {
return ""
}
body, err := io.ReadAll(c.Request.Body)
if err != nil {
return ""
}
c.Request.Body = io.NopCloser(bytes.NewReader(body))
if len(bytes.TrimSpace(body)) == 0 {
return ""
}
var payload struct {
RefreshToken string `json:"refresh_token"`
}
if err := json.Unmarshal(body, &payload); err != nil {
return ""
}
return payload.RefreshToken
}
func fingerprintValue(value string) string {
sum := sha256.Sum256([]byte(value))
return hex.EncodeToString(sum[:12])
}
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter { func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
m.mu.RLock() m.mu.RLock()
limiter, exists := m.limiters[key] limiter, exists := m.limiters[key]
@@ -116,7 +191,6 @@ func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duratio
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
// 双重检查
if limiter, exists = m.limiters[key]; exists { if limiter, exists = m.limiters[key]; exists {
return limiter return limiter
} }

View File

@@ -0,0 +1,140 @@
package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
func performRateLimitedRequest(router *gin.Engine, path string, userID int64) *httptest.ResponseRecorder {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, path, nil)
req.RemoteAddr = "127.0.0.1:12345"
req.Header.Set("X-Test-User-ID", strconv.FormatInt(userID, 10))
router.ServeHTTP(recorder, req)
return recorder
}
func performRefreshRateLimitedRequestWithCookie(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", nil)
req.RemoteAddr = "127.0.0.1:12345"
if refreshToken != "" {
req.AddCookie(&http.Cookie{Name: "ums_refresh_token", Value: refreshToken})
}
router.ServeHTTP(recorder, req)
return recorder
}
func performRefreshRateLimitedRequestWithBody(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
recorder := httptest.NewRecorder()
body := bytes.NewBufferString(`{"refresh_token":"` + refreshToken + `"}`)
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", body)
req.RemoteAddr = "127.0.0.1:12345"
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
return recorder
}
func TestRateLimitMiddleware_API_ScopesBudgetByRouteForAuthenticatedUser(t *testing.T) {
gin.SetMode(gin.TestMode)
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
router := gin.New()
router.Use(func(c *gin.Context) {
rawUserID := c.GetHeader("X-Test-User-ID")
if rawUserID != "" {
userID, err := strconv.ParseInt(rawUserID, 10, 64)
if err == nil {
c.Set("user_id", userID)
}
}
c.Next()
})
protected := router.Group("")
protected.Use(rateLimitMiddleware.API())
protected.GET("/users", func(c *gin.Context) {
c.Status(http.StatusOK)
})
protected.GET("/roles", func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 100; i++ {
recorder := performRateLimitedRequest(router, "/users", 1)
if recorder.Code != http.StatusOK {
t.Fatalf("request %d to /users returned %d, want %d", i+1, recorder.Code, http.StatusOK)
}
}
sameRouteOverflow := performRateLimitedRequest(router, "/users", 1)
if sameRouteOverflow.Code != http.StatusTooManyRequests {
t.Fatalf("overflow request to /users returned %d, want %d", sameRouteOverflow.Code, http.StatusTooManyRequests)
}
differentRoute := performRateLimitedRequest(router, "/roles", 1)
if differentRoute.Code != http.StatusOK {
t.Fatalf("request to /roles after exhausting /users budget returned %d, want %d", differentRoute.Code, http.StatusOK)
}
}
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshCookie(t *testing.T) {
gin.SetMode(gin.TestMode)
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
router := gin.New()
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 10; i++ {
recorder := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
if recorder.Code != http.StatusOK {
t.Fatalf("request %d for refresh-token-a returned %d, want %d", i+1, recorder.Code, http.StatusOK)
}
}
sameTokenOverflow := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
if sameTokenOverflow.Code != http.StatusTooManyRequests {
t.Fatalf("overflow request for refresh-token-a returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
}
differentToken := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-b")
if differentToken.Code != http.StatusOK {
t.Fatalf("request for refresh-token-b after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
}
}
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshTokenBody(t *testing.T) {
gin.SetMode(gin.TestMode)
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
router := gin.New()
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 10; i++ {
recorder := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
if recorder.Code != http.StatusOK {
t.Fatalf("request %d for refresh-token-a body returned %d, want %d", i+1, recorder.Code, http.StatusOK)
}
}
sameTokenOverflow := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
if sameTokenOverflow.Code != http.StatusTooManyRequests {
t.Fatalf("overflow request for refresh-token-a body returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
}
differentToken := performRefreshRateLimitedRequestWithBody(router, "refresh-token-b")
if differentToken.Code != http.StatusOK {
t.Fatalf("request for refresh-token-b body after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
}
}

View File

@@ -7,6 +7,8 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strconv"
"strings"
"syscall" "syscall"
"time" "time"
@@ -43,10 +45,12 @@ func Serve(cfg *config.Config) error {
// P1-3Argon2id 启动时自适应校准 // P1-3Argon2id 启动时自适应校准
auth.CalibrateArgon2id(500 * time.Millisecond) auth.CalibrateArgon2id(500 * time.Millisecond)
accessTokenExpire := resolveJWTAccessTokenExpire(cfg)
// 初始化 JWT 管理器 // 初始化 JWT 管理器
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{ jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: cfg.JWT.Secret, HS256Secret: cfg.JWT.Secret,
AccessTokenExpire: time.Duration(cfg.JWT.AccessTokenExpireMinutes) * time.Minute, AccessTokenExpire: accessTokenExpire,
RefreshTokenExpire: time.Duration(cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour, RefreshTokenExpire: time.Duration(cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour,
}) })
if err != nil { if err != nil {
@@ -125,6 +129,9 @@ func Serve(cfg *config.Config) error {
totpService := service.NewTOTPService(userRepo) totpService := service.NewTOTPService(userRepo)
passwordResetConfig := service.DefaultPasswordResetConfig() passwordResetConfig := service.DefaultPasswordResetConfig()
if err := configureAuthEmailServices(cfg, cacheManager, authService, passwordResetConfig); err != nil {
return fmt.Errorf("configure auth email services failed: %w", err)
}
passwordResetService := service.NewPasswordResetService(userRepo, cacheManager, passwordResetConfig). passwordResetService := service.NewPasswordResetService(userRepo, cacheManager, passwordResetConfig).
WithPasswordHistoryRepo(passwordHistoryRepo) WithPasswordHistoryRepo(passwordHistoryRepo)
@@ -259,3 +266,100 @@ func resolveGinMode(mode string) string {
return gin.ReleaseMode return gin.ReleaseMode
} }
} }
func configureAuthEmailServices(
cfg *config.Config,
cacheManager *cache.CacheManager,
authService *service.AuthService,
passwordResetConfig *service.PasswordResetConfig,
) error {
smtpConfig, enabled, err := resolveSMTPEmailConfigFromEnv()
if err != nil {
return err
}
if !enabled || cacheManager == nil || authService == nil {
return nil
}
siteURL := resolveAuthEmailSiteURL(cfg)
siteName := resolveAuthEmailSiteName(cfg)
provider := service.NewSMTPEmailProvider(smtpConfig)
authService.SetEmailActivationService(
service.NewEmailActivationService(provider, cacheManager, siteURL, siteName),
)
emailCodeConfig := service.DefaultEmailCodeConfig()
emailCodeConfig.SiteURL = siteURL
emailCodeConfig.SiteName = siteName
authService.SetEmailCodeService(service.NewEmailCodeService(provider, cacheManager, emailCodeConfig))
if passwordResetConfig != nil {
passwordResetConfig.SMTPHost = smtpConfig.Host
passwordResetConfig.SMTPPort = smtpConfig.Port
passwordResetConfig.SMTPUser = smtpConfig.Username
passwordResetConfig.SMTPPass = smtpConfig.Password
passwordResetConfig.FromEmail = smtpConfig.FromEmail
passwordResetConfig.SiteURL = siteURL
}
return nil
}
func resolveSMTPEmailConfigFromEnv() (service.SMTPEmailConfig, bool, error) {
host := strings.TrimSpace(os.Getenv("EMAIL_HOST"))
if host == "" {
return service.SMTPEmailConfig{}, false, nil
}
port := 587
if rawPort := strings.TrimSpace(os.Getenv("EMAIL_PORT")); rawPort != "" {
parsedPort, err := strconv.Atoi(rawPort)
if err != nil || parsedPort <= 0 {
return service.SMTPEmailConfig{}, false, fmt.Errorf("invalid EMAIL_PORT %q", rawPort)
}
port = parsedPort
}
fromEmail := strings.TrimSpace(os.Getenv("EMAIL_FROM_EMAIL"))
if fromEmail == "" {
fromEmail = service.DefaultPasswordResetConfig().FromEmail
}
return service.SMTPEmailConfig{
Host: host,
Port: port,
Username: strings.TrimSpace(os.Getenv("EMAIL_USER")),
Password: os.Getenv("EMAIL_PASS"),
FromEmail: fromEmail,
FromName: strings.TrimSpace(os.Getenv("EMAIL_FROM_NAME")),
}, true, nil
}
func resolveAuthEmailSiteURL(cfg *config.Config) string {
if cfg != nil {
if siteURL := strings.TrimSpace(cfg.Server.FrontendURL); siteURL != "" {
return siteURL
}
}
return service.DefaultEmailCodeConfig().SiteURL
}
func resolveAuthEmailSiteName(cfg *config.Config) string {
if cfg != nil {
if siteName := strings.TrimSpace(cfg.Log.ServiceName); siteName != "" {
return siteName
}
}
return service.DefaultEmailCodeConfig().SiteName
}
func resolveJWTAccessTokenExpire(cfg *config.Config) time.Duration {
if cfg == nil {
return 0
}
if cfg.JWT.AccessTokenExpireMinutes > 0 {
return time.Duration(cfg.JWT.AccessTokenExpireMinutes) * time.Minute
}
return time.Duration(cfg.JWT.ExpireHour) * time.Hour
}

View File

@@ -0,0 +1,73 @@
package server
import (
"testing"
"time"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/config"
"github.com/user-management-system/internal/service"
)
func TestResolveJWTAccessTokenExpire_UsesExpireHourFallback(t *testing.T) {
cfg := &config.Config{}
cfg.JWT.ExpireHour = 24
cfg.JWT.AccessTokenExpireMinutes = 0
expire := resolveJWTAccessTokenExpire(cfg)
if expire != 24*time.Hour {
t.Fatalf("resolveJWTAccessTokenExpire() = %v, want %v", expire, 24*time.Hour)
}
}
func TestResolveJWTAccessTokenExpire_PrefersMinuteOverride(t *testing.T) {
cfg := &config.Config{}
cfg.JWT.ExpireHour = 24
cfg.JWT.AccessTokenExpireMinutes = 90
expire := resolveJWTAccessTokenExpire(cfg)
if expire != 90*time.Minute {
t.Fatalf("resolveJWTAccessTokenExpire() = %v, want %v", expire, 90*time.Minute)
}
}
func TestConfigureAuthEmailServices_UsesSMTPEnvironment(t *testing.T) {
t.Setenv("EMAIL_HOST", "127.0.0.1")
t.Setenv("EMAIL_PORT", "2525")
t.Setenv("EMAIL_FROM_EMAIL", "noreply@test.local")
t.Setenv("EMAIL_FROM_NAME", "UMS E2E")
t.Setenv("EMAIL_USER", "smtp-user")
t.Setenv("EMAIL_PASS", "smtp-pass")
cfg := &config.Config{}
cfg.Server.FrontendURL = "http://127.0.0.1:3000"
cfg.Log.ServiceName = "UMS E2E"
cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
authService := service.NewAuthService(nil, nil, nil, cacheManager, 8, 5, time.Minute)
passwordResetConfig := service.DefaultPasswordResetConfig()
if err := configureAuthEmailServices(cfg, cacheManager, authService, passwordResetConfig); err != nil {
t.Fatalf("configureAuthEmailServices() error = %v", err)
}
if !authService.SupportsEmailActivation() {
t.Fatal("SupportsEmailActivation() = false, want true")
}
if !authService.HasEmailCodeService() {
t.Fatal("HasEmailCodeService() = false, want true")
}
if passwordResetConfig.SMTPHost != "127.0.0.1" {
t.Fatalf("password reset SMTP host = %q, want %q", passwordResetConfig.SMTPHost, "127.0.0.1")
}
if passwordResetConfig.SMTPPort != 2525 {
t.Fatalf("password reset SMTP port = %d, want %d", passwordResetConfig.SMTPPort, 2525)
}
if passwordResetConfig.FromEmail != "noreply@test.local" {
t.Fatalf("password reset FromEmail = %q, want %q", passwordResetConfig.FromEmail, "noreply@test.local")
}
if passwordResetConfig.SiteURL != "http://127.0.0.1:3000" {
t.Fatalf("password reset SiteURL = %q, want %q", passwordResetConfig.SiteURL, "http://127.0.0.1:3000")
}
}

View File

@@ -2,10 +2,13 @@ package service
import ( import (
"context" "context"
cryptorand "crypto/rand"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log" "log"
"strconv"
"strings" "strings"
"time" "time"
"unicode" "unicode"
@@ -19,11 +22,14 @@ import (
) )
const ( const (
userInfoCachePrefix = "auth_user_info:" userInfoCachePrefix = "auth_user_info:"
tokenBlacklistPrefix = "auth_token_blacklist:" tokenBlacklistPrefix = "auth_token_blacklist:"
defaultUserCacheTTL = 15 * time.Minute totpChallengePrefix = "auth_totp_challenge:"
defaultBlacklistTTL = time.Hour defaultUserCacheTTL = 15 * time.Minute
defaultPasswordMinLen = 8 defaultBlacklistTTL = time.Hour
defaultTOTPChallengeTTL = 5 * time.Minute
defaultPasswordMinLen = 8
refreshTokenRetryGrace = 10 * time.Second
) )
type userRepositoryInterface interface { type userRepositoryInterface interface {
@@ -122,13 +128,18 @@ type LoginResponse struct {
ExpiresIn int64 `json:"expires_in,omitempty"` ExpiresIn int64 `json:"expires_in,omitempty"`
User *UserInfo `json:"user,omitempty"` User *UserInfo `json:"user,omitempty"`
// RequiresTOTP 指示登录需要额外的TOTP验证当设备未信任时 // RequiresTOTP 指示登录需要额外的TOTP验证当设备未信任时
RequiresTOTP bool `json:"requires_totp,omitempty"` RequiresTOTP bool `json:"requires_totp,omitempty"`
// TempToken 临时令牌用于TOTP验证阶段短生命周期不可用于常规API // TempToken 临时令牌用于TOTP验证阶段短生命周期不可用于常规API
TempToken string `json:"temp_token,omitempty"` TempToken string `json:"temp_token,omitempty"`
// UserID 当RequiresTOTP为true时返回用于后续TOTP验证 // UserID 当RequiresTOTP为true时返回用于后续TOTP验证
UserID int64 `json:"user_id,omitempty"` UserID int64 `json:"user_id,omitempty"`
} }
type totpLoginChallenge struct {
UserID int64 `json:"user_id"`
DeviceID string `json:"device_id,omitempty"`
}
type LogoutRequest struct { type LogoutRequest struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
@@ -432,6 +443,38 @@ func (s *AuthService) blacklistTokenClaims(ctx context.Context, token string, va
return s.cache.Set(ctx, tokenBlacklistPrefix+claims.JTI, true, ttl, ttl) return s.cache.Set(ctx, tokenBlacklistPrefix+claims.JTI, true, ttl, ttl)
} }
func (s *AuthService) getTokenBlacklistValue(ctx context.Context, jti string) (interface{}, bool) {
if s == nil || s.cache == nil {
return nil, false
}
jti = strings.TrimSpace(jti)
if jti == "" {
return nil, false
}
return s.cache.Get(ctx, tokenBlacklistPrefix+jti)
}
func tokenBlacklistRevokedAt(value interface{}) (time.Time, bool) {
switch v := value.(type) {
case int64:
return time.Unix(0, v), true
case int:
return time.Unix(0, int64(v)), true
case float64:
return time.Unix(0, int64(v)), true
case string:
timestamp, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
if err != nil {
return time.Time{}, false
}
return time.Unix(0, timestamp), true
default:
return time.Time{}, false
}
}
func (s *AuthService) recordLoginAnomaly(ctx context.Context, userID *int64, ip, location, deviceFingerprint string, success bool) { func (s *AuthService) recordLoginAnomaly(ctx context.Context, userID *int64, ip, location, deviceFingerprint string, success bool) {
if s == nil || s.anomalyDetector == nil || userID == nil { if s == nil || s.anomalyDetector == nil || userID == nil {
return return
@@ -601,6 +644,93 @@ func userInfoFromCacheValue(value interface{}) (*UserInfo, bool) {
} }
} }
func generateTemporaryLoginToken() (string, error) {
payload := make([]byte, 32)
if _, err := cryptorand.Read(payload); err != nil {
return "", fmt.Errorf("generate temporary login token failed: %w", err)
}
return base64.RawURLEncoding.EncodeToString(payload), nil
}
func totpLoginChallengeFromCacheValue(value interface{}) (*totpLoginChallenge, bool) {
switch typed := value.(type) {
case *totpLoginChallenge:
return typed, true
case totpLoginChallenge:
challenge := typed
return &challenge, true
case map[string]interface{}:
payload, err := json.Marshal(typed)
if err != nil {
return nil, false
}
var challenge totpLoginChallenge
if err := json.Unmarshal(payload, &challenge); err != nil {
return nil, false
}
return &challenge, true
default:
return nil, false
}
}
func (s *AuthService) issueTOTPLoginChallenge(ctx context.Context, user *domain.User, deviceID string) (string, error) {
if s == nil || s.cache == nil {
return "", errors.New("temporary login token storage is unavailable")
}
if user == nil {
return "", errors.New("temporary login token requires a user")
}
tempToken, err := generateTemporaryLoginToken()
if err != nil {
return "", err
}
challenge := &totpLoginChallenge{
UserID: user.ID,
DeviceID: strings.TrimSpace(deviceID),
}
if err := s.cache.Set(
ctx,
totpChallengePrefix+tempToken,
challenge,
defaultTOTPChallengeTTL,
defaultTOTPChallengeTTL,
); err != nil {
return "", fmt.Errorf("temporary login token storage failed: %w", err)
}
return tempToken, nil
}
func (s *AuthService) validateTOTPLoginChallenge(ctx context.Context, userID int64, deviceID, tempToken string) error {
if s == nil || s.cache == nil {
return errors.New("temporary login token storage is unavailable")
}
normalizedToken := strings.TrimSpace(tempToken)
if normalizedToken == "" {
return errors.New("temporary login token is required")
}
value, ok := s.cache.Get(ctx, totpChallengePrefix+normalizedToken)
if !ok {
return errors.New("temporary login token is invalid or expired")
}
challenge, ok := totpLoginChallengeFromCacheValue(value)
if !ok || challenge == nil {
return errors.New("temporary login token is invalid or expired")
}
if challenge.UserID != userID || strings.TrimSpace(challenge.DeviceID) != strings.TrimSpace(deviceID) {
return errors.New("temporary login token does not match the requested login flow")
}
return nil
}
func (s *AuthService) Register(ctx context.Context, req *RegisterRequest) (*UserInfo, error) { func (s *AuthService) Register(ctx context.Context, req *RegisterRequest) (*UserInfo, error) {
if req == nil { if req == nil {
return nil, errors.New("注册请求不能为空") return nil, errors.New("注册请求不能为空")
@@ -628,6 +758,9 @@ func (s *AuthService) Register(ctx context.Context, req *RegisterRequest) (*User
if err := s.verifyPhoneRegistration(ctx, req); err != nil { if err := s.verifyPhoneRegistration(ctx, req); err != nil {
return nil, err return nil, err
} }
if s.emailActivationSvc != nil && req.Email != "" {
return s.RegisterWithActivation(ctx, req)
}
exists, err := s.userRepo.ExistsByUsername(ctx, req.Username) exists, err := s.userRepo.ExistsByUsername(ctx, req.Username)
if err != nil { if err != nil {
@@ -759,11 +892,17 @@ func (s *AuthService) Login(ctx context.Context, req *LoginRequest, ip string) (
// P0-07 安全修复检查是否需要TOTP验证用户启用了TOTP且设备未信任 // P0-07 安全修复检查是否需要TOTP验证用户启用了TOTP且设备未信任
if s.isTOTPRequiredForLogin(ctx, user, req.DeviceID) { if s.isTOTPRequiredForLogin(ctx, user, req.DeviceID) {
tempToken, err := s.issueTOTPLoginChallenge(ctx, user, req.DeviceID)
if err != nil {
return nil, err
}
// 返回RequiresTOTP指示前端需要完成TOTP验证 // 返回RequiresTOTP指示前端需要完成TOTP验证
// 前端应调用 /auth/login/totp-verify 接口完成验证 // 前端应调用 /auth/login/totp-verify 接口完成验证
return &LoginResponse{ return &LoginResponse{
RequiresTOTP: true, RequiresTOTP: true,
UserID: user.ID, TempToken: tempToken,
UserID: user.ID,
}, nil }, nil
} }
@@ -808,10 +947,13 @@ func (s *AuthService) isTOTPRequiredForLogin(ctx context.Context, user *domain.U
// VerifyTOTPAfterPasswordLogin 完成密码登录后的TOTP验证 // VerifyTOTPAfterPasswordLogin 完成密码登录后的TOTP验证
// 当用户启用了TOTP但设备未信任时密码登录会返回RequiresTOTP=true // 当用户启用了TOTP但设备未信任时密码登录会返回RequiresTOTP=true
// 前端需要调用此接口完成TOTP验证以获取令牌 // 前端需要调用此接口完成TOTP验证以获取令牌
func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID int64, totpCode, deviceID string) (*LoginResponse, error) { func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID int64, totpCode, deviceID, tempToken string) (*LoginResponse, error) {
if s == nil { if s == nil {
return nil, errors.New("auth service is not initialized") return nil, errors.New("auth service is not initialized")
} }
if err := s.validateTOTPLoginChallenge(ctx, userID, deviceID, tempToken); err != nil {
return nil, err
}
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
@@ -827,6 +969,10 @@ func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID i
return nil, err return nil, err
} }
if err := s.cache.Delete(ctx, totpChallengePrefix+strings.TrimSpace(tempToken)); err != nil {
return nil, fmt.Errorf("temporary login token cleanup failed: %w", err)
}
// TOTP验证成功返回完整登录响应 // TOTP验证成功返回完整登录响应
return s.generateLoginResponseWithoutRemember(ctx, user) return s.generateLoginResponseWithoutRemember(ctx, user)
} }
@@ -841,8 +987,11 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.IsTokenBlacklisted(ctx, claims.JTI) { if blacklistValue, blacklisted := s.getTokenBlacklistValue(ctx, claims.JTI); blacklisted {
return nil, errors.New("refresh token has been revoked") revokedAt, hasRevocationTimestamp := tokenBlacklistRevokedAt(blacklistValue)
if !hasRevocationTimestamp || time.Since(revokedAt) > refreshTokenRetryGrace {
return nil, errors.New("refresh token has been revoked")
}
} }
user, err := s.userRepo.GetByID(ctx, claims.UserID) user, err := s.userRepo.GetByID(ctx, claims.UserID)
@@ -861,7 +1010,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
if claims.ExpiresAt != nil { if claims.ExpiresAt != nil {
remaining := time.Until(claims.ExpiresAt.Time) remaining := time.Until(claims.ExpiresAt.Time)
if remaining > 0 { if remaining > 0 {
if err := s.cache.Set(ctx, blacklistKey, "1", 5*time.Minute, remaining); err != nil { if err := s.cache.Set(ctx, blacklistKey, time.Now().UnixNano(), 5*time.Minute, remaining); err != nil {
return nil, fmt.Errorf("token revocation failed: %w", err) return nil, fmt.Errorf("token revocation failed: %w", err)
} }
} }

View File

@@ -69,13 +69,17 @@ func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterR
if s.emailActivationSvc != nil && req.Email != "" { if s.emailActivationSvc != nil && req.Email != "" {
initialStatus = domain.UserStatusInactive initialStatus = domain.UserStatusInactive
} }
nickname := req.Nickname
if nickname == "" {
nickname = req.Username
}
user := &domain.User{ user := &domain.User{
Username: req.Username, Username: req.Username,
Email: domain.StrPtr(req.Email), Email: domain.StrPtr(req.Email),
Phone: domain.StrPtr(req.Phone), Phone: domain.StrPtr(req.Phone),
Password: hashedPassword, Password: hashedPassword,
Nickname: req.Nickname, Nickname: nickname,
Status: initialStatus, Status: initialStatus,
} }
if err := s.userRepo.Create(ctx, user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
@@ -85,10 +89,6 @@ func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterR
s.bestEffortAssignDefaultRoles(ctx, user.ID, "register_with_activation") s.bestEffortAssignDefaultRoles(ctx, user.ID, "register_with_activation")
if s.emailActivationSvc != nil && req.Email != "" { if s.emailActivationSvc != nil && req.Email != "" {
nickname := req.Nickname
if nickname == "" {
nickname = req.Username
}
// #nosec G118 - 使用独立上下文避免请求结束后被取消 // #nosec G118 - 使用独立上下文避免请求结束后被取消
go func() { // #nosec G118 go func() { // #nosec G118
bgCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) bgCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)

View File

@@ -375,6 +375,51 @@ func TestAuthService_RegisterWithActivation(t *testing.T) {
}) })
} }
func TestAuthService_Register_UsesEmailActivationFlowWhenConfigured(t *testing.T) {
svc, db := setupAuthEmailTestEnv(t)
ctx := context.Background()
l1Cache := cache.NewL1Cache()
l2Cache := cache.NewRedisCache(false)
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
emailActivationSvc := service.NewEmailActivationService(
&service.MockEmailProvider{},
cacheManager,
"http://localhost:8080",
"TestSite",
)
svc.SetEmailActivationService(emailActivationSvc)
userInfo, err := svc.Register(ctx, &service.RegisterRequest{
Username: "register_activation_enabled",
Password: "Password123!",
Email: "register-activation-enabled@test.com",
})
if err != nil {
t.Fatalf("Register failed: %v", err)
}
if userInfo == nil {
t.Fatal("Register returned nil user info")
}
if userInfo.Status != domain.UserStatusInactive {
t.Fatalf("Register status = %d, want %d", userInfo.Status, domain.UserStatusInactive)
}
if userInfo.Nickname != "register_activation_enabled" {
t.Fatalf("Register nickname = %q, want %q", userInfo.Nickname, "register_activation_enabled")
}
var storedUser domain.User
if err := db.WithContext(ctx).Where("username = ?", "register_activation_enabled").First(&storedUser).Error; err != nil {
t.Fatalf("load stored user: %v", err)
}
if storedUser.Status != domain.UserStatusInactive {
t.Fatalf("stored user status = %d, want %d", storedUser.Status, domain.UserStatusInactive)
}
if storedUser.Nickname != "register_activation_enabled" {
t.Fatalf("stored user nickname = %q, want %q", storedUser.Nickname, "register_activation_enabled")
}
}
// ============================================================================= // =============================================================================
// Login By Email Code Extended Tests // Login By Email Code Extended Tests
// ============================================================================= // =============================================================================

View File

@@ -3,10 +3,12 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"testing" "testing"
"time" "time"
"github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain" "github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository" "github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/security" "github.com/user-management-system/internal/security"
@@ -359,6 +361,73 @@ func TestBuildDeviceFingerprint(t *testing.T) {
} }
} }
func TestLogin_IssuesTOTPChallengeTokenWhenSecondFactorIsRequired(t *testing.T) {
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: fmt.Sprintf("file:login_totp_challenge_%d?mode=memory&cache=shared", time.Now().UnixNano()),
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("failed to connect database: %v", err)
}
if err := db.AutoMigrate(&domain.User{}); err != nil {
t.Fatalf("failed to migrate: %v", err)
}
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: "totp-challenge-secret",
AccessTokenExpire: 15 * time.Minute,
RefreshTokenExpire: 7 * 24 * time.Hour,
})
if err != nil {
t.Fatalf("failed to create jwt manager: %v", err)
}
cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
userRepo := repository.NewUserRepository(db)
svc := NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
hashedPassword, err := auth.HashPassword("Password123!")
if err != nil {
t.Fatalf("failed to hash password: %v", err)
}
user := &domain.User{
Username: "totpchallenge",
Password: hashedPassword,
Status: domain.UserStatusActive,
TOTPEnabled: true,
TOTPSecret: "JBSWY3DPEHPK3PXP",
}
if err := db.Create(user).Error; err != nil {
t.Fatalf("failed to create user: %v", err)
}
resp, err := svc.Login(context.Background(), &LoginRequest{
Account: "totpchallenge",
Password: "Password123!",
DeviceID: "device-1",
}, "127.0.0.1")
if err != nil {
t.Fatalf("login failed: %v", err)
}
if !resp.RequiresTOTP {
t.Fatalf("expected requires_totp response, got %+v", resp)
}
if resp.UserID != user.ID {
t.Fatalf("expected user id %d, got %d", user.ID, resp.UserID)
}
if strings.TrimSpace(resp.TempToken) == "" {
t.Fatalf("expected temp token when TOTP is required, got %+v", resp)
}
if resp.AccessToken != "" || resp.RefreshToken != "" {
t.Fatalf("expected no full session tokens before TOTP verification, got %+v", resp)
}
}
func TestAuthServiceDefaultConfig(t *testing.T) { func TestAuthServiceDefaultConfig(t *testing.T) {
// Test that default configuration is applied correctly // Test that default configuration is applied correctly
svc := NewAuthService(nil, nil, nil, nil, 0, 0, 0) svc := NewAuthService(nil, nil, nil, nil, 0, 0, 0)