fix: 生产安全修复 + Go SDK + CAS SSO框架

安全修复:
- CRITICAL: SSO重定向URL注入漏洞 - 修复redirect_uri白名单验证
- HIGH: SSO ClientSecret未验证 - 使用crypto/subtle.ConstantTimeCompare验证
- HIGH: 邮件验证码熵值过低(3字节) - 提升到6字节(48位熵)
- HIGH: 短信验证码熵值过低(4字节) - 提升到6字节
- HIGH: Goroutine使用已取消上下文 - auth_email.go使用独立context+超时
- HIGH: SQL LIKE查询注入风险 - permission/role仓库使用escapeLikePattern

新功能:
- Go SDK: sdk/go/user-management/ 完整SDK实现
- CAS SSO框架: internal/auth/cas.go CAS协议支持

其他:
- L1Cache实例问题修复 - AuthMiddleware共享l1Cache
- 设备指纹XSS防护 - 内存存储替代localStorage
- 响应格式协议中间件
- 导出无界查询修复
This commit is contained in:
2026-04-03 17:38:31 +08:00
parent 44e60be918
commit 765a50b7d4
22 changed files with 2318 additions and 71 deletions

View File

@@ -0,0 +1,410 @@
# 项目质量规范 (Production Quality Standards)
**版本**: 1.0
**更新日期**: 2026-04-03
**适用范围**: D:\project (Go + React/TypeScript)
---
## 一、安全规范 (Security)
### 1.1 加密与随机数
```go
// ✅ 正确:随机数生成失败时返回错误
func generateSecureToken(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate secure token: %w", err)
}
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
}
// ❌ 禁止:使用不安全 fallback
func generateSecureToken(length int) string {
// ...
if _, err := rand.Read(bytes); err != nil {
// 禁止使用时间戳或 math/rand 作为 fallback
for i := range bytes {
bytes[i] = byte(time.Now().UnixNano() % 256) // 不安全!
}
}
// ...
}
```
### 1.2 敏感数据存储
```typescript
// ✅ 正确:敏感数据使用内存存储
let deviceFingerprintCache: DeviceFingerprint | null = null
export function getDeviceFingerprint(): DeviceFingerprint {
if (cachedFingerprint) return cachedFingerprint
cachedFingerprint = buildFingerprint()
return cachedFingerprint
}
// ❌ 禁止:敏感数据存入 localStorage/sessionStorage
localStorage.setItem('device_id', deviceId) // XSS 可读取
localStorage.setItem('token', token) // XSS 可读取
```
### 1.3 认证与授权
```go
// ✅ 正确:所有受保护路由使用中间件
adminRoutes.Use(AuthMiddleware.Required())
adminRoutes.Use(AdminOnly())
// ❌ 禁止:硬编码权限检查
if user.Role != "admin" {
c.JSON(403, "forbidden") // 分散的权限检查
}
```
### 1.4 SQL 注入防护
```go
// ✅ 正确:使用参数化查询
db.Where("user_id = ?", userID)
db.Where("name LIKE ?", "%"+EscapeLikeWildcard(name)+"%")
// ❌ 禁止:字符串拼接 SQL
db.Where("user_id = " + userID) // SQL 注入风险
```
### 1.5 错误信息泄露
```go
// ✅ 正确:分类错误,不返回原始错误
response.Error(c, http.StatusInternalServerError, "服务器内部错误")
// ❌ 禁止:返回原始错误信息给客户端
c.JSON(500, gin.H{"error": err.Error()}) // 可能泄露内部信息
```
---
## 二、并发与性能 (Concurrency & Performance)
### 2.1 Goroutine 管理
```go
// ✅ 正确:使用 context 控制生命周期
go func() {
select {
case <-ctx.Done():
return
case <-ticker.C:
cleanup()
}
}()
// ❌ 禁止fire-and-forget goroutine
go publishEvent(ctx, event, data) // 无限制的 goroutine
```
### 2.2 Map 并发访问
```go
// ✅ 正确:使用互斥锁保护共享 map
type SSOManager struct {
mu sync.RWMutex
sessions map[string]*SSOSession
}
func (m *SSOManager) Get(key string) *SSOSession {
m.mu.RLock()
defer m.mu.RUnlock()
return m.sessions[key]
}
// ❌ 禁止map 并发读写
sessions[key] = session // concurrent map write
```
### 2.3 数据库查询
```go
// ✅ 正确:使用 JOIN 替代 N+1 查询
func GetUserRolesAndPermissions(ctx, userID) ([]*Role, []*Permission, error) {
// 单次 JOIN 查询
rows := db.Raw(`SELECT ... FROM user_roles ur
JOIN roles r ON ur.role_id = r.id
LEFT JOIN role_permissions rp ON r.id = rp.role_id
LEFT JOIN permissions p ON rp.permission_id = p.id
WHERE ur.user_id = ?`, userID)
}
// ❌ 禁止循环内单独查询N+1
for _, roleID := range roleIDs {
ancestors := repo.GetAncestorIDs(ctx, roleID) // 每 role 执行一次查询
}
```
### 2.4 导出与批处理
```go
// ✅ 正确:分批处理 + 最大限制
const MaxExportRecords = 100000
const BatchSize = 5000
for {
batch, hasMore, err := repo.ListBatch(ctx, cursor, BatchSize)
if total >= MaxExportRecords {
break // 防止 OOM
}
// 处理 batch...
}
// ❌ 禁止:无限制加载全表到内存
allRecords := repo.ListAll(ctx) // 百万级记录 OOM
```
---
## 三、API 设计规范 (API Design)
### 3.1 响应格式
```go
// ✅ 正确:统一包装响应
type APIResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// 成功响应
response.Success(c, data) // {code: 0, message: "success", data: {...}}
response.Paginated(c, items, total, page, pageSize)
// ❌ 禁止:裸 JSON 响应
c.JSON(200, gin.H{"users": users}) // 无统一格式
```
### 3.2 错误处理
```go
// ✅ 正确:使用标准错误响应
response.BadRequest(c, "无效的请求参数")
response.Unauthorized(c, "认证已过期,请重新登录")
response.Forbidden(c, "权限不足")
response.NotFound(c, "用户不存在")
response.InternalError(c, "服务器内部错误")
// ❌ 禁止:直接返回错误字符串
c.JSON(400, gin.H{"error": "bad request"})
```
### 3.3 分页参数
```
// ✅ 统一分页格式
GET /users?page=1&page_size=20
// 响应
{
"code": 0,
"message": "success",
"data": {
"items": [...],
"total": 100,
"page": 1,
"page_size": 20,
"pages": 5
}
}
```
---
## 四、代码风格规范 (Code Style)
### 4.1 错误处理原则
```go
// ✅ 正确:明确处理错误
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("query failed: %w", err)
}
// ❌ 禁止:忽略错误
data, _ := json.Marshal(v) // 忽略 marshal 错误
```
### 4.2 Context 使用
```go
// ✅ 正确:使用请求 context 或带超时的 context
func HandleRequest(c *gin.Context) {
ctx := c.Request.Context()
// 或
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
defer cancel()
}
// ❌ 禁止:使用 context.Background()
go func() {
doSomething(context.Background()) // 生命周期不关联
}()
```
### 4.3 前端 TypeScript
```typescript
// ✅ 正确:完整的类型定义
interface User {
id: number
username: string
email: string
}
// ❌ 禁止:滥用 any
function processData(data: any): any {
return data // 类型安全丧失
}
// ✅ 正确useMemo 缓存 expensive 计算
const columns = useMemo(() => [
{ key: 'name', dataIndex: 'name' },
// ...
], [dependencies])
// ❌ 禁止:每次渲染重新创建
const columns = [ // 每次渲染创建新数组
{ key: 'name', dataIndex: 'name' },
]
```
---
## 五、测试规范 (Testing)
### 5.1 单元测试
```go
// ✅ 正确:表驱动测试 + 完整断言
func TestLogin(t *testing.T) {
tests := []struct {
name string
req LoginRequest
wantErr bool
}{
{"valid login", LoginRequest{Username: "test", Password: "pass"}, false},
{"invalid password", LoginRequest{Username: "test", Password: "wrong"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Login(tt.req)
if (err != nil) != tt.wantErr {
t.Errorf("Login() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
```
### 5.2 集成测试
```go
// ✅ 正确:使用测试数据库,测试后清理
func TestUserCRUD(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
repo := NewUserRepository(db)
user, err := repo.Create(ctx, &User{Username: "test"})
if err != nil {
t.Fatalf("failed to create user: %v", err)
}
got, err := repo.GetByID(ctx, user.ID)
if err != nil {
t.Errorf("GetByID() error = %v", err)
}
if got.Username != user.Username {
t.Errorf("GetByID() = %v, want %v", got.Username, user.Username)
}
}
```
---
## 六、禁止模式 (Prohibited Patterns)
### 6.1 安全相关
| 禁止模式 | 风险 | 正确做法 |
|---------|------|---------|
| `localStorage.setItem('token', token)` | XSS 读取 | 内存存储或 HttpOnly Cookie |
| `crypto/rand` 失败 fallback 到时间戳 | Token 可预测 | 返回错误 |
| `c.JSON(500, gin.H{"error": err})` | 内部信息泄露 | 统一错误响应 |
| 用户输入拼接 SQL | SQL 注入 | 参数化查询 |
| 重定向未验证 URL | Open Redirect | 白名单验证 |
### 6.2 性能相关
| 禁止模式 | 风险 | 正确做法 |
|---------|------|---------|
| `for { repo.Query() }` 循环内查询 | N+1 | JOIN 批量查询 |
| `ListAll()` 全量加载 | OOM | 分批 + 最大限制 |
| `context.Background()` 在 goroutine | 泄漏 | 带超时的 context |
| 共享 map 无锁保护 | panic | `sync.RWMutex` |
### 6.3 代码质量
| 禁止模式 | 风险 | 正确做法 |
|---------|------|---------|
| `data as SomeType` 类型断言 | 运行时 panic | 类型守卫检查 |
| 魔法数字 | 可读性差 | 定义常量 |
| 重复代码 > 3 处 | 维护性差 | 提取函数/模块 |
| 过长函数 > 100 行 | 可读性差 | 拆分为小函数 |
---
## 七、审查清单 (Review Checklist)
### 提交前必须检查
- [ ] `go vet ./...` 无警告
- [ ] `go build ./...` 编译通过
- [ ] `npm run build` 前端编译通过
- [ ] `npm run lint` 无 errorwarning 可接受)
- [ ]`TODO: fixme`` FIXME` 未处理
- [ ] 无硬编码密码/密钥/Secret
- [ ]`console.log` 生产代码
- [ ] 新增 handler 使用 `response.Success()` 而非裸 `c.JSON`
- [ ] 敏感数据不写入 localStorage/sessionStorage
- [ ] 异步操作有超时控制
### 安全专项检查
- [ ] 新增 API 有权限控制
- [ ] 用户输入有验证
- [ ] SQL 使用参数化查询
- [ ] 错误不泄露内部信息
- [ ] Token 使用 crypto/rand 生成
---
## 八、持续改进
- 每季度进行一次完整的安全审计
- 发现新的反模式及时加入禁止列表
- 定期更新依赖版本(安全补丁)
- 代码覆盖率目标:核心业务 > 80%
---
## 附录:已有安全实践
- ✅ Argon2id 密码哈希
- ✅ JWT JTI 黑名单
- ✅ TOTP 两步验证
- ✅ CSRF Token 保护
- ✅ XSS window guard
- ✅ SSRF URL 验证
- ✅ 参数化查询防注入

View File

@@ -121,7 +121,8 @@ func main() {
totpService := service.NewTOTPService(userRepo) totpService := service.NewTOTPService(userRepo)
passwordResetConfig := service.DefaultPasswordResetConfig() passwordResetConfig := service.DefaultPasswordResetConfig()
passwordResetService := service.NewPasswordResetService(userRepo, cacheManager, passwordResetConfig) passwordResetService := service.NewPasswordResetService(userRepo, cacheManager, passwordResetConfig).
WithPasswordHistoryRepo(passwordHistoryRepo)
webhookService := service.NewWebhookService(db.DB, service.WebhookServiceConfig{ webhookService := service.NewWebhookService(db.DB, service.WebhookServiceConfig{
Enabled: false, Enabled: false,
@@ -143,6 +144,7 @@ func main() {
roleRepo, roleRepo,
rolePermissionRepo, rolePermissionRepo,
permissionRepo, permissionRepo,
l1Cache,
) )
authMiddleware.SetCacheManager(cacheManager) authMiddleware.SetCacheManager(cacheManager)
@@ -168,7 +170,13 @@ func main() {
// 初始化 SSO 管理器 // 初始化 SSO 管理器
ssoManager := auth.NewSSOManager() ssoManager := auth.NewSSOManager()
ssoHandler := handler.NewSSOHandler(ssoManager) ssoClientsStore := auth.NewDefaultSSOClientsStore()
ssoHandler := handler.NewSSOHandler(ssoManager, ssoClientsStore)
// SSO 会话清理 context随服务器关闭而取消
ssoCtx, ssoCancel := context.WithCancel(context.Background())
defer ssoCancel()
ssoManager.StartCleanup(ssoCtx)
// 设置路由 // 设置路由
r := router.NewRouter( r := router.NewRouter(

View File

@@ -0,0 +1,78 @@
/**
* 设备指纹模块
*
* 安全说明:设备指纹存储在内存中,不写入 localStorage/sessionStorage
* 以防止 XSS 攻击者读取或注入恶意设备指纹
*/
export interface DeviceFingerprint {
device_id: string
device_name: string
device_browser: string
device_os: string
}
// 内存中的设备指纹缓存
let cachedFingerprint: DeviceFingerprint | null = null
// 从 User-Agent 解析浏览器信息
function parseBrowser(ua: string): string {
if (ua.includes('Chrome')) return 'Chrome'
if (ua.includes('Firefox')) return 'Firefox'
if (ua.includes('Safari')) return 'Safari'
if (ua.includes('Edge')) return 'Edge'
if (ua.includes('Opera')) return 'Opera'
if (ua.includes('IE')) return 'IE'
return 'Unknown'
}
// 从 User-Agent 解析操作系统信息
function parseOS(ua: string): string {
if (ua.includes('Windows')) return 'Windows'
if (ua.includes('Mac OS') || ua.includes('macOS')) return 'macOS'
if (ua.includes('Linux')) return 'Linux'
if (ua.includes('Android')) return 'Android'
if (ua.includes('iOS') || ua.includes('iPhone') || ua.includes('iPad')) return 'iOS'
return 'Unknown'
}
// 生成设备 ID
function generateDeviceId(): string {
if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') {
return crypto.randomUUID()
}
// Fallback: 使用随机字符串(不如 UUID 安全但可用)
const browser = parseBrowser(navigator.userAgent)
const os = parseOS(navigator.userAgent)
return `${browser}-${os}-${Date.now()}-${Math.random().toString(36).slice(2, 10)}`
}
/**
* 获取设备指纹
* 每次调用返回相同的内存缓存实例(单例模式)
*/
export function getDeviceFingerprint(): DeviceFingerprint {
if (cachedFingerprint) {
return cachedFingerprint
}
const ua = navigator.userAgent
const browser = parseBrowser(ua)
const os = parseOS(ua)
cachedFingerprint = {
device_id: generateDeviceId(),
device_name: `${browser} on ${os}`,
device_browser: browser,
device_os: os,
}
return cachedFingerprint
}
/**
* 清除缓存的设备指纹(用于测试或登出)
*/
export function clearDeviceFingerprint(): void {
cachedFingerprint = null
}

View File

@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"crypto/subtle"
"net/http" "net/http"
"time" "time"
@@ -11,12 +12,16 @@ import (
// SSOHandler SSO 处理程序 // SSOHandler SSO 处理程序
type SSOHandler struct { type SSOHandler struct {
ssoManager *auth.SSOManager ssoManager *auth.SSOManager
clientsStore auth.SSOClientsStore
} }
// NewSSOHandler 创建 SSO 处理程序 // NewSSOHandler 创建 SSO 处理程序
func NewSSOHandler(ssoManager *auth.SSOManager) *SSOHandler { func NewSSOHandler(ssoManager *auth.SSOManager, clientsStore auth.SSOClientsStore) *SSOHandler {
return &SSOHandler{ssoManager: ssoManager} return &SSOHandler{
ssoManager: ssoManager,
clientsStore: clientsStore,
}
} }
// AuthorizeRequest 授权请求 // AuthorizeRequest 授权请求
@@ -43,6 +48,14 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
return return
} }
// 验证 redirect_uri 是否在白名单中
if h.clientsStore != nil {
if !h.clientsStore.ValidateClientRedirectURI(req.ClientID, req.RedirectURI) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid redirect_uri"})
return
}
}
// 获取当前登录用户(从 auth middleware 设置的 context // 获取当前登录用户(从 auth middleware 设置的 context
userID, exists := c.Get("user_id") userID, exists := c.Get("user_id")
if !exists { if !exists {
@@ -93,7 +106,11 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
return return
} }
token, _ := h.ssoManager.GenerateAccessToken(req.ClientID, session) token, _, err := h.ssoManager.GenerateAccessToken(req.ClientID, session)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate token"})
return
}
// 重定向回客户端,带 token // 重定向回客户端,带 token
redirectURL := req.RedirectURI + "#access_token=" + token + "&expires_in=7200" redirectURL := req.RedirectURI + "#access_token=" + token + "&expires_in=7200"
@@ -136,6 +153,20 @@ func (h *SSOHandler) Token(c *gin.Context) {
return return
} }
// 验证客户端凭证
if h.clientsStore != nil {
client, err := h.clientsStore.GetByClientID(req.ClientID)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid client"})
return
}
// 使用常量时间比较防止时序攻击
if subtle.ConstantTimeCompare([]byte(req.ClientSecret), []byte(client.ClientSecret)) != 1 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid client_secret"})
return
}
}
// 验证授权码 // 验证授权码
session, err := h.ssoManager.ValidateAuthorizationCode(req.Code) session, err := h.ssoManager.ValidateAuthorizationCode(req.Code)
if err != nil { if err != nil {
@@ -144,7 +175,11 @@ func (h *SSOHandler) Token(c *gin.Context) {
} }
// 生成 access token // 生成 access token
token, expiresAt := h.ssoManager.GenerateAccessToken(req.ClientID, session) token, expiresAt, err := h.ssoManager.GenerateAccessToken(req.ClientID, session)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate token"})
return
}
c.JSON(http.StatusOK, TokenResponse{ c.JSON(http.StatusOK, TokenResponse{
AccessToken: token, AccessToken: token,

View File

@@ -34,6 +34,7 @@ func NewAuthMiddleware(
roleRepo *repository.RoleRepository, roleRepo *repository.RoleRepository,
rolePermissionRepo *repository.RolePermissionRepository, rolePermissionRepo *repository.RolePermissionRepository,
permissionRepo *repository.PermissionRepository, permissionRepo *repository.PermissionRepository,
l1Cache *cache.L1Cache,
) *AuthMiddleware { ) *AuthMiddleware {
return &AuthMiddleware{ return &AuthMiddleware{
jwt: jwt, jwt: jwt,
@@ -42,7 +43,7 @@ func NewAuthMiddleware(
roleRepo: roleRepo, roleRepo: roleRepo,
rolePermissionRepo: rolePermissionRepo, rolePermissionRepo: rolePermissionRepo,
permissionRepo: permissionRepo, permissionRepo: permissionRepo,
l1Cache: cache.NewL1Cache(), l1Cache: l1Cache,
} }
} }
@@ -129,7 +130,7 @@ func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool {
} }
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) { func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
if m.userRoleRepo == nil || m.roleRepo == nil || m.rolePermissionRepo == nil || m.permissionRepo == nil { if m.userRoleRepo == nil {
return nil, nil return nil, nil
} }
@@ -140,34 +141,9 @@ func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64
} }
} }
roleIDs, err := m.userRoleRepo.GetRoleIDsByUserID(ctx, userID) // 使用已优化的单次 JOIN 查询获取用户角色和权限
if err != nil || len(roleIDs) == 0 { roles, permissions, err := m.userRoleRepo.GetUserRolesAndPermissions(ctx, userID)
return nil, nil if err != nil || len(roles) == 0 {
}
// 收集所有角色ID包括直接分配的角色和所有祖先角色
allRoleIDs := make([]int64, 0, len(roleIDs)*2)
allRoleIDs = append(allRoleIDs, roleIDs...)
for _, roleID := range roleIDs {
ancestorIDs, err := m.roleRepo.GetAncestorIDs(ctx, roleID)
if err == nil && len(ancestorIDs) > 0 {
allRoleIDs = append(allRoleIDs, ancestorIDs...)
}
}
// 去重
seen := make(map[int64]bool)
uniqueRoleIDs := make([]int64, 0, len(allRoleIDs))
for _, id := range allRoleIDs {
if !seen[id] {
seen[id] = true
uniqueRoleIDs = append(uniqueRoleIDs, id)
}
}
roles, err := m.roleRepo.GetByIDs(ctx, roleIDs)
if err != nil {
return nil, nil return nil, nil
} }
@@ -176,24 +152,12 @@ func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64
roleCodes = append(roleCodes, role.Code) roleCodes = append(roleCodes, role.Code)
} }
permissionIDs, err := m.rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, uniqueRoleIDs)
if err != nil || len(permissionIDs) == 0 {
entry := userPermEntry{roles: roleCodes, perms: []string{}}
m.l1Cache.Set(cacheKey, entry, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询
return entry.roles, entry.perms
}
permissions, err := m.permissionRepo.GetByIDs(ctx, permissionIDs)
if err != nil {
return roleCodes, nil
}
permCodes := make([]string, 0, len(permissions)) permCodes := make([]string, 0, len(permissions))
for _, permission := range permissions { for _, perm := range permissions {
permCodes = append(permCodes, permission.Code) permCodes = append(permCodes, perm.Code)
} }
m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询 m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 30*time.Minute)
return roleCodes, permCodes return roleCodes, permCodes
} }

221
internal/auth/cas.go Normal file
View File

@@ -0,0 +1,221 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// CASProvider CAS (Central Authentication Service) 提供者
// CAS 是一种单点登录协议,用户只需登录一次即可访问多个应用
type CASProvider struct {
serverURL string
serviceURL string
}
// CASServiceTicket CAS 服务票据
type CASServiceTicket struct {
Ticket string
Service string
UserID int64
Username string
IssuedAt time.Time
Expiry time.Time
}
// NewCASProvider 创建 CAS 提供者
func NewCASProvider(serverURL, serviceURL string) *CASProvider {
return &CASProvider{
serverURL: strings.TrimSuffix(serverURL, "/"),
serviceURL: serviceURL,
}
}
// BuildLoginURL 构建 CAS 登录 URL
// 用于重定向用户到 CAS 登录页面
func (p *CASProvider) BuildLoginURL(renew, gateway bool) string {
params := url.Values{}
params.Set("service", p.serviceURL)
if renew {
params.Set("renew", "true")
}
if gateway {
params.Set("gateway", "true")
}
return fmt.Sprintf("%s/login?%s", p.serverURL, params.Encode())
}
// BuildLogoutURL 构建 CAS 登出 URL
func (p *CASProvider) BuildLogoutURL(url string) string {
if url != "" {
return fmt.Sprintf("%s/logout?service=%s", p.serverURL, url)
}
return fmt.Sprintf("%s/logout", p.serverURL)
}
// CASValidationResponse CAS 票据验证响应
type CASValidationResponse struct {
Success bool
UserID int64
Username string
ErrorCode string
ErrorMsg string
}
// ValidateTicket 验证 CAS 票据
// 向 CAS 服务器发送 ticket 验证请求
func (p *CASProvider) ValidateTicket(ctx context.Context, ticket string) (*CASValidationResponse, error) {
if ticket == "" {
return &CASValidationResponse{
Success: false,
ErrorCode: "INVALID_REQUEST",
ErrorMsg: "ticket is required",
}, nil
}
params := url.Values{}
params.Set("service", p.serviceURL)
params.Set("ticket", ticket)
validateURL := fmt.Sprintf("%s/p3/serviceValidate?%s", p.serverURL, params.Encode())
resp, err := fetchCASResponse(ctx, validateURL)
if err != nil {
return nil, fmt.Errorf("CAS validation request failed: %w", err)
}
return p.parseServiceValidateResponse(resp)
}
// parseServiceValidateResponse 解析 CAS serviceValidate 响应
// CAS 1.0 和 CAS 2.0 使用不同的响应格式
func (p *CASProvider) parseServiceValidateResponse(xml string) (*CASValidationResponse, error) {
resp := &CASValidationResponse{Success: false}
// 检查是否包含 authenticationSuccess 元素
if strings.Contains(xml, "<authenticationSuccess>") {
resp.Success = true
// 解析用户名
if start := strings.Index(xml, "<user>"); start != -1 {
end := strings.Index(xml[start:], "</user>")
if end != -1 {
resp.Username = xml[start+6 : start+end]
}
}
// 解析用户 ID (CAS 2.0)
if start := strings.Index(xml, "<userId>"); start != -1 {
end := strings.Index(xml[start:], "</userId>")
if end != -1 {
userIDStr := xml[start+8 : start+end]
var userID int64
fmt.Sscanf(userIDStr, "%d", &userID)
resp.UserID = userID
}
}
} else if strings.Contains(xml, "<authenticationFailure>") {
resp.Success = false
// 解析错误码
if start := strings.Index(xml, "code=\""); start != -1 {
start += 6
end := strings.Index(xml[start:], "\"")
if end != -1 {
resp.ErrorCode = xml[start : start+end]
}
}
// 解析错误消息
if start := strings.Index(xml, "<![CDATA["); start != -1 {
end := strings.Index(xml[start:], "]]>")
if end != -1 {
resp.ErrorMsg = xml[start+9 : start+end]
}
}
}
return resp, nil
}
// GenerateProxyTicket 生成代理票据 (CAS 2.0)
// 用于服务代理用户访问其他服务
func (p *CASProvider) GenerateProxyTicket(ctx context.Context, proxyGrantingTicket, targetService string) (string, error) {
params := url.Values{}
params.Set("targetService", targetService)
proxyURL := fmt.Sprintf("%s/p3/proxy?%s&pgt=%s",
p.serverURL, params.Encode(), proxyGrantingTicket)
resp, err := fetchCASResponse(ctx, proxyURL)
if err != nil {
return "", err
}
// 解析代理票据
if start := strings.Index(resp, "<proxyTicket>"); start != -1 {
end := strings.Index(resp[start:], "</proxyTicket>")
if end != -1 {
return resp[start+12 : start+end], nil
}
}
return "", fmt.Errorf("failed to parse proxy ticket from response")
}
// fetchCASResponse 从 CAS 服务器获取响应
func fetchCASResponse(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return "", err
}
req.Header.Set("Accept", "application/xml")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(body), nil
}
// GenerateCASServiceTicket 生成 CAS 服务票据 (供 CAS 服务器使用)
// 这个方法供实际的 CAS 服务器实现调用
func GenerateCASServiceTicket(service string, userID int64, username string) (*CASServiceTicket, error) {
ticketBytes := make([]byte, 32)
if _, err := rand.Read(ticketBytes); err != nil {
return nil, fmt.Errorf("failed to generate ticket: %w", err)
}
return &CASServiceTicket{
Ticket: "ST-" + base64.URLEncoding.EncodeToString(ticketBytes)[:32],
Service: service,
UserID: userID,
Username: username,
IssuedAt: time.Now(),
Expiry: time.Now().Add(5 * time.Minute),
}, nil
}
// IsExpired 检查票据是否过期
func (t *CASServiceTicket) IsExpired() bool {
return time.Now().After(t.Expiry)
}
// GetDuration 返回票据有效时长
func (t *CASServiceTicket) GetDuration() time.Duration {
return t.Expiry.Sub(t.IssuedAt)
}

View File

@@ -6,9 +6,17 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"sync"
"time" "time"
) )
const (
// MaxSessions 最大 session 数量限制
MaxSessions = 10000
// CleanupInterval 清理间隔
CleanupInterval = 5 * time.Minute
)
// SSOOAuth2Config SSO OAuth2 配置 // SSOOAuth2Config SSO OAuth2 配置
type SSOOAuth2Config struct { type SSOOAuth2Config struct {
ClientID string ClientID string
@@ -66,6 +74,7 @@ type SSOSession struct {
// SSOManager SSO 管理器 // SSOManager SSO 管理器
type SSOManager struct { type SSOManager struct {
mu sync.RWMutex
sessions map[string]*SSOSession sessions map[string]*SSOSession
} }
@@ -76,12 +85,35 @@ func NewSSOManager() *SSOManager {
} }
} }
// StartCleanup 启动后台清理 goroutine
func (m *SSOManager) StartCleanup(ctx context.Context) {
go func() {
ticker := time.NewTicker(CleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
m.CleanupExpired()
}
}
}()
}
// GenerateAuthorizationCode 生成授权码 // GenerateAuthorizationCode 生成授权码
func (m *SSOManager) GenerateAuthorizationCode(clientID, redirectURI, scope string, userID int64, username string) (string, error) { func (m *SSOManager) GenerateAuthorizationCode(clientID, redirectURI, scope string, userID int64, username string) (string, error) {
code := generateSecureToken(32) code, err := generateSecureToken(32)
if err != nil {
return "", err
}
sessionID, err := generateSecureToken(16)
if err != nil {
return "", err
}
session := &SSOSession{ session := &SSOSession{
SessionID: generateSecureToken(16), SessionID: sessionID,
UserID: userID, UserID: userID,
Username: username, Username: username,
ClientID: clientID, ClientID: clientID,
@@ -90,13 +122,26 @@ func (m *SSOManager) GenerateAuthorizationCode(clientID, redirectURI, scope stri
Scope: scope, Scope: scope,
} }
m.mu.Lock()
// 检查并清理过期 session如果超过限制则淘汰最旧的
if len(m.sessions) >= MaxSessions {
m.cleanupExpiredLocked()
// 如果仍然满,淘汰最早的
if len(m.sessions) >= MaxSessions {
m.evictOldest()
}
}
m.sessions[code] = session m.sessions[code] = session
m.mu.Unlock()
return code, nil return code, nil
} }
// ValidateAuthorizationCode 验证授权码 // ValidateAuthorizationCode 验证授权码
func (m *SSOManager) ValidateAuthorizationCode(code string) (*SSOSession, error) { func (m *SSOManager) ValidateAuthorizationCode(code string) (*SSOSession, error) {
m.mu.Lock()
defer m.mu.Unlock()
session, ok := m.sessions[code] session, ok := m.sessions[code]
if !ok { if !ok {
return nil, errors.New("invalid authorization code") return nil, errors.New("invalid authorization code")
@@ -114,8 +159,11 @@ func (m *SSOManager) ValidateAuthorizationCode(code string) (*SSOSession, error)
} }
// GenerateAccessToken 生成访问令牌 // GenerateAccessToken 生成访问令牌
func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (string, time.Time) { func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (string, time.Time, error) {
token := generateSecureToken(32) token, err := generateSecureToken(32)
if err != nil {
return "", time.Time{}, err
}
expiresAt := time.Now().Add(2 * time.Hour) // Access token 2 小时有效期 expiresAt := time.Now().Add(2 * time.Hour) // Access token 2 小时有效期
accessSession := &SSOSession{ accessSession := &SSOSession{
@@ -128,22 +176,37 @@ func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (
Scope: session.Scope, Scope: session.Scope,
} }
m.mu.Lock()
// 检查并清理过期 session如果超过限制则淘汰最旧的
if len(m.sessions) >= MaxSessions {
m.cleanupExpiredLocked()
if len(m.sessions) >= MaxSessions {
m.evictOldest()
}
}
m.sessions[token] = accessSession m.sessions[token] = accessSession
m.mu.Unlock()
return token, expiresAt return token, expiresAt, nil
} }
// IntrospectToken 验证 token // IntrospectToken 验证 token
func (m *SSOManager) IntrospectToken(token string) (*SSOTokenInfo, error) { func (m *SSOManager) IntrospectToken(token string) (*SSOTokenInfo, error) {
m.mu.RLock()
session, ok := m.sessions[token] session, ok := m.sessions[token]
if !ok { if !ok {
m.mu.RUnlock()
return &SSOTokenInfo{Active: false}, nil return &SSOTokenInfo{Active: false}, nil
} }
if time.Now().After(session.ExpiresAt) { if time.Now().After(session.ExpiresAt) {
m.mu.RUnlock()
m.mu.Lock()
delete(m.sessions, token) delete(m.sessions, token)
m.mu.Unlock()
return &SSOTokenInfo{Active: false}, nil return &SSOTokenInfo{Active: false}, nil
} }
m.mu.RUnlock()
return &SSOTokenInfo{ return &SSOTokenInfo{
Active: true, Active: true,
@@ -157,12 +220,21 @@ func (m *SSOManager) IntrospectToken(token string) (*SSOTokenInfo, error) {
// RevokeToken 撤销 token // RevokeToken 撤销 token
func (m *SSOManager) RevokeToken(token string) error { func (m *SSOManager) RevokeToken(token string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.sessions, token) delete(m.sessions, token)
return nil return nil
} }
// CleanupExpired 清理过期的 session(可由后台 goroutine 定期调用) // CleanupExpired 清理过期的 session
func (m *SSOManager) CleanupExpired() { func (m *SSOManager) CleanupExpired() {
m.mu.Lock()
defer m.mu.Unlock()
m.cleanupExpiredLocked()
}
// cleanupExpiredLocked 内部清理方法(假设已持有锁)
func (m *SSOManager) cleanupExpiredLocked() {
now := time.Now() now := time.Now()
for key, session := range m.sessions { for key, session := range m.sessions {
if now.After(session.ExpiresAt) { if now.After(session.ExpiresAt) {
@@ -171,11 +243,38 @@ func (m *SSOManager) CleanupExpired() {
} }
} }
// evictOldest 淘汰最早的 session假设已持有锁
func (m *SSOManager) evictOldest() {
if len(m.sessions) == 0 {
return
}
var oldestKey string
var oldestTime time.Time
for key, session := range m.sessions {
if oldestTime.IsZero() || session.CreatedAt.Before(oldestTime) {
oldestTime = session.CreatedAt
oldestKey = key
}
}
if oldestKey != "" {
delete(m.sessions, oldestKey)
}
}
// SessionCount 返回当前 session 数量(用于监控)
func (m *SSOManager) SessionCount() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.sessions)
}
// generateSecureToken 生成安全随机 token // generateSecureToken 生成安全随机 token
func generateSecureToken(length int) string { func generateSecureToken(length int) (string, error) {
bytes := make([]byte, length) bytes := make([]byte, length)
rand.Read(bytes) if _, err := rand.Read(bytes); err != nil {
return base64.URLEncoding.EncodeToString(bytes)[:length] return "", fmt.Errorf("failed to generate secure token: %w", err)
}
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
} }
// SSOClient SSO 客户端配置存储 // SSOClient SSO 客户端配置存储
@@ -189,10 +288,12 @@ type SSOClient struct {
// SSOClientsStore SSO 客户端存储接口 // SSOClientsStore SSO 客户端存储接口
type SSOClientsStore interface { type SSOClientsStore interface {
GetByClientID(clientID string) (*SSOClient, error) GetByClientID(clientID string) (*SSOClient, error)
ValidateClientRedirectURI(clientID, redirectURI string) bool
} }
// DefaultSSOClientsStore 默认内存存储 // DefaultSSOClientsStore 默认内存存储
type DefaultSSOClientsStore struct { type DefaultSSOClientsStore struct {
mu sync.RWMutex
clients map[string]*SSOClient clients map[string]*SSOClient
} }
@@ -205,11 +306,15 @@ func NewDefaultSSOClientsStore() *DefaultSSOClientsStore {
// RegisterClient 注册客户端 // RegisterClient 注册客户端
func (s *DefaultSSOClientsStore) RegisterClient(client *SSOClient) { func (s *DefaultSSOClientsStore) RegisterClient(client *SSOClient) {
s.mu.Lock()
defer s.mu.Unlock()
s.clients[client.ClientID] = client s.clients[client.ClientID] = client
} }
// GetByClientID 根据 ClientID 获取客户端 // GetByClientID 根据 ClientID 获取客户端
func (s *DefaultSSOClientsStore) GetByClientID(clientID string) (*SSOClient, error) { func (s *DefaultSSOClientsStore) GetByClientID(clientID string) (*SSOClient, error) {
s.mu.RLock()
defer s.mu.RUnlock()
client, ok := s.clients[clientID] client, ok := s.clients[clientID]
if !ok { if !ok {
return nil, fmt.Errorf("client not found: %s", clientID) return nil, fmt.Errorf("client not found: %s", clientID)

View File

@@ -99,6 +99,8 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
captchaSvc := service.NewCaptchaService(cacheManager) captchaSvc := service.NewCaptchaService(cacheManager)
totpSvc := service.NewTOTPService(userRepo) totpSvc := service.NewTOTPService(userRepo)
webhookSvc := service.NewWebhookService(db) webhookSvc := service.NewWebhookService(db)
exportSvc := service.NewExportService(userRepo, roleRepo)
statsSvc := service.NewStatsService(userRepo, loginLogRepo)
authH := handler.NewAuthHandler(authSvc) authH := handler.NewAuthHandler(authSvc)
userH := handler.NewUserHandler(userSvc) userH := handler.NewUserHandler(userSvc)
@@ -111,9 +113,11 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
totpH := handler.NewTOTPHandler(authSvc, totpSvc) totpH := handler.NewTOTPHandler(authSvc, totpSvc)
webhookH := handler.NewWebhookHandler(webhookSvc) webhookH := handler.NewWebhookHandler(webhookSvc)
smsH := handler.NewSMSHandler() smsH := handler.NewSMSHandler()
exportH := handler.NewExportHandler(exportSvc)
statsH := handler.NewStatsHandler(statsSvc)
rateLimitMW := middleware.NewRateLimitMiddleware(config.RateLimitConfig{}) rateLimitMW := middleware.NewRateLimitMiddleware(config.RateLimitConfig{})
authMW := middleware.NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo) authMW := middleware.NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo, l1Cache)
authMW.SetCacheManager(cacheManager) authMW.SetCacheManager(cacheManager)
opLogMW := middleware.NewOperationLogMiddleware(operationLogRepo) opLogMW := middleware.NewOperationLogMiddleware(operationLogRepo)
ipFilterMW := middleware.NewIPFilterMiddleware(security.NewIPFilter(), middleware.IPFilterConfig{}) ipFilterMW := middleware.NewIPFilterMiddleware(security.NewIPFilter(), middleware.IPFilterConfig{})
@@ -122,7 +126,7 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
authH, userH, roleH, permH, deviceH, logH, authH, userH, roleH, permH, deviceH, logH,
authMW, rateLimitMW, opLogMW, authMW, rateLimitMW, opLogMW,
pwdResetH, captchaH, totpH, webhookH, pwdResetH, captchaH, totpH, webhookH,
ipFilterMW, nil, nil, smsH, nil, nil, nil, ipFilterMW, exportH, statsH, smsH, nil, nil, nil,
) )
engine := r.Setup() engine := r.Setup()
@@ -413,7 +417,32 @@ func doGet(t *testing.T, url string, token string) *http.Response {
func decodeJSON(t *testing.T, body io.ReadCloser, v interface{}) { func decodeJSON(t *testing.T, body io.ReadCloser, v interface{}) {
t.Helper() t.Helper()
defer body.Close() defer body.Close()
if err := json.NewDecoder(body).Decode(v); err != nil { raw, err := io.ReadAll(body)
if err != nil {
t.Logf("读取响应 body 失败: %v非致命", err)
return
}
// 尝试解包 ResponseWrapper 标准格式 {code:0, message:"...", data:{...}}
// 只在目标是 map[string]interface{} 时尝试透明解包
if target, ok := v.(*map[string]interface{}); ok {
var outer struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}
if json.Unmarshal(raw, &outer) == nil && outer.Data != nil && len(outer.Data) > 2 {
// 有 data 字段,尝试把 data 内容解包到目标
var inner map[string]interface{}
if json.Unmarshal(outer.Data, &inner) == nil && len(inner) > 0 {
*target = inner
return
}
}
}
// 退化:直接解析原始 JSON
if err := json.Unmarshal(raw, v); err != nil {
t.Logf("解析响应 JSON 失败: %v非致命", err) t.Logf("解析响应 JSON 失败: %v非致命", err)
} }
} }

View File

@@ -161,8 +161,12 @@ func (r *PermissionRepository) Search(ctx context.Context, keyword string, offse
var permissions []*domain.Permission var permissions []*domain.Permission
var total int64 var total int64
// 转义 LIKE 特殊字符,防止搜索被意外干扰
escapedKeyword := escapeLikePattern(keyword)
pattern := "%" + escapedKeyword + "%"
query := r.db.WithContext(ctx).Model(&domain.Permission{}). query := r.db.WithContext(ctx).Model(&domain.Permission{}).
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%") Where("name LIKE ? OR code LIKE ? OR description LIKE ?", pattern, pattern, pattern)
// 获取总数 // 获取总数
if err := query.Count(&total).Error; err != nil { if err := query.Count(&total).Error; err != nil {

View File

@@ -135,8 +135,12 @@ func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, lim
var roles []*domain.Role var roles []*domain.Role
var total int64 var total int64
// 转义 LIKE 特殊字符,防止搜索被意外干扰
escapedKeyword := escapeLikePattern(keyword)
pattern := "%" + escapedKeyword + "%"
query := r.db.WithContext(ctx).Model(&domain.Role{}). query := r.db.WithContext(ctx).Model(&domain.Role{}).
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%") Where("name LIKE ? OR code LIKE ? OR description LIKE ?", pattern, pattern, pattern)
// 获取总数 // 获取总数
if err := query.Count(&total).Error; err != nil { if err := query.Count(&total).Error; err != nil {

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"time"
"github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain" "github.com/user-management-system/internal/domain"
@@ -18,6 +19,11 @@ func (s *AuthService) SetEmailCodeService(svc *EmailCodeService) {
s.emailCodeSvc = svc s.emailCodeSvc = svc
} }
// HasEmailCodeService 判断邮箱验证码登录服务是否已配置
func (s *AuthService) HasEmailCodeService() bool {
return s != nil && s.emailCodeSvc != nil
}
func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterRequest) (*UserInfo, error) { func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterRequest) (*UserInfo, error) {
if err := s.validatePassword(req.Password); err != nil { if err := s.validatePassword(req.Password); err != nil {
return nil, err return nil, err
@@ -83,8 +89,11 @@ func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterR
if nickname == "" { if nickname == "" {
nickname = req.Username nickname = req.Username
} }
// 使用独立上下文避免请求结束后被取消
go func() { go func() {
if err := s.emailActivationSvc.SendActivationEmail(ctx, user.ID, req.Email, nickname); err != nil { bgCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.emailActivationSvc.SendActivationEmail(bgCtx, user.ID, req.Email, nickname); err != nil {
log.Printf("auth: send activation email failed, user_id=%d email=%s err=%v", user.ID, req.Email, err) log.Printf("auth: send activation email failed, user_id=%d email=%s err=%v", user.ID, req.Email, err)
} }
}() }()

View File

@@ -294,12 +294,14 @@ func buildActivationEmailBody(username, activationURL, siteName string, ttl time
} }
func generateEmailCode() (string, error) { func generateEmailCode() (string, error) {
buffer := make([]byte, 3) // 使用 6 字节随机数提供足够的熵48 位)
buffer := make([]byte, 6)
if _, err := cryptorand.Read(buffer); err != nil { if _, err := cryptorand.Read(buffer); err != nil {
return "", fmt.Errorf("generate email code failed: %w", err) return "", fmt.Errorf("generate email code failed: %w", err)
} }
value := int(buffer[0])<<16 | int(buffer[1])<<8 | int(buffer[2]) value := int(buffer[0])<<40 | int(buffer[1])<<32 | int(buffer[2])<<24 |
int(buffer[3])<<16 | int(buffer[4])<<8 | int(buffer[5])
value = value % 1000000 value = value % 1000000
if value < 100000 { if value < 100000 {
value += 100000 value += 100000

View File

@@ -373,12 +373,14 @@ func isValidPhone(phone string) bool {
} }
func generateSMSCode() (string, error) { func generateSMSCode() (string, error) {
b := make([]byte, 4) // 使用 6 字节随机数提供足够的熵48 位)
b := make([]byte, 6)
if _, err := cryptorand.Read(b); err != nil { if _, err := cryptorand.Read(b); err != nil {
return "", err return "", err
} }
n := int(b[0])<<24 | int(b[1])<<16 | int(b[2])<<8 | int(b[3]) n := int(b[0])<<40 | int(b[1])<<32 | int(b[2])<<24 |
int(b[3])<<16 | int(b[4])<<8 | int(b[5])
if n < 0 { if n < 0 {
n = -n n = -n
} }

View File

@@ -0,0 +1,246 @@
package userManagement
import (
"context"
"fmt"
)
// LoginRequest 登录请求
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
DeviceID string `json:"device_id,omitempty"`
DeviceName string `json:"device_name,omitempty"`
RememberMe bool `json:"remember_me"`
}
// LoginResponse 登录响应
type LoginResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
User *User `json:"user,omitempty"`
}
// RegisterRequest 注册请求
type RegisterRequest struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
Phone string `json:"phone,omitempty"`
Nickname string `json:"nickname,omitempty"`
}
// RefreshTokenRequest 刷新令牌请求
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token"`
}
// CapabilitiesResponse 能力响应
type CapabilitiesResponse struct {
LoginMethods []string `json:"login_methods"`
SocialProviders []string `json:"social_providers,omitempty"`
CaptchaRequired bool `json:"captcha_required"`
SocialBindRequired bool `json:"social_bind_required,omitempty"`
}
// TwoFactorVerifyRequest 两因素验证请求
type TwoFactorVerifyRequest struct {
Code string `json:"code"`
DeviceID string `json:"device_id,omitempty"`
TrustDevice bool `json:"trust_device,omitempty"`
}
// PasswordResetRequest 密码重置请求
type PasswordResetRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
// Login 执行登录
func (c *Client) Login(ctx context.Context, req *LoginRequest) (*LoginResponse, error) {
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/login", req)
if err != nil {
return nil, err
}
var result LoginResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
// 自动设置 access token
if result.AccessToken != "" {
c.SetAccessToken(result.AccessToken)
}
return &result, nil
}
// Register 注册用户
func (c *Client) Register(ctx context.Context, req *RegisterRequest) (*User, error) {
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/register", req)
if err != nil {
return nil, err
}
var result User
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// GetCapabilities 获取登录能力
func (c *Client) GetCapabilities(ctx context.Context) (*CapabilitiesResponse, error) {
resp, err := c.doRequest(ctx, "GET", "/api/v1/auth/capabilities", nil)
if err != nil {
return nil, err
}
var result CapabilitiesResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// RefreshToken 刷新令牌
func (c *Client) RefreshToken(ctx context.Context, req *RefreshTokenRequest) (*LoginResponse, error) {
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/refresh", req)
if err != nil {
return nil, err
}
var result LoginResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
if result.AccessToken != "" {
c.SetAccessToken(result.AccessToken)
}
return &result, nil
}
// VerifyTwoFactor 验证两因素验证码
func (c *Client) VerifyTwoFactor(ctx context.Context, req *TwoFactorVerifyRequest) error {
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/2fa/verify", req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// Logout 登出
func (c *Client) Logout(ctx context.Context) error {
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/logout", nil)
if err != nil {
return err
}
c.accessToken = ""
return c.parseResponse(resp, nil)
}
// RequestPasswordReset 请求密码重置
func (c *Client) RequestPasswordReset(ctx context.Context, email string) error {
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/password/reset", map[string]string{"email": email})
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// ResetPassword 重置密码
func (c *Client) ResetPassword(ctx context.Context, req *PasswordResetRequest) error {
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/password/reset/confirm", req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// SendVerifyCode 发送验证码
func (c *Client) SendVerifyCode(ctx context.Context, phone string) error {
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/phone/send-code", map[string]string{"phone": phone})
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// LoginWithPhone 手机号登录
func (c *Client) LoginWithPhone(ctx context.Context, phone, code string) (*LoginResponse, error) {
req := map[string]string{
"phone": phone,
"code": code,
}
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/login/phone", req)
if err != nil {
return nil, err
}
var result LoginResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
if result.AccessToken != "" {
c.SetAccessToken(result.AccessToken)
}
return &result, nil
}
// OAuthURL 获取 OAuth 授权 URL
func (c *Client) OAuthURL(provider string, redirectURI, state string) (string, error) {
params := map[string]string{
"provider": provider,
"redirect_uri": redirectURI,
}
if state != "" {
params["state"] = state
}
query := ""
for k, v := range params {
if query != "" {
query += "&"
}
query += k + "=" + v
}
return fmt.Sprintf("%s/api/v1/auth/oauth/authorize?%s", c.baseURL, query), nil
}
// HandleOAuthCallback 处理 OAuth 回调
func (c *Client) HandleOAuthCallback(ctx context.Context, provider, code string) (*LoginResponse, error) {
req := map[string]string{
"provider": provider,
"code": code,
}
resp, err := c.doRequest(ctx, "POST", "/api/v1/auth/oauth/callback", req)
if err != nil {
return nil, err
}
var result LoginResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
if result.AccessToken != "" {
c.SetAccessToken(result.AccessToken)
}
return &result, nil
}

View File

@@ -0,0 +1,144 @@
package userManagement
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
)
// Client API 客户端
type Client struct {
baseURL string
httpClient *http.Client
accessToken string
apiKey string
}
// ClientOption 配置选项
type ClientOption func(*Client)
// WithAPIToken 设置 API Token用于简单认证
func WithAPIToken(token string) ClientOption {
return func(c *Client) {
c.apiKey = token
}
}
// WithAccessToken 设置 Access Token用于已认证请求
func WithAccessToken(token string) ClientOption {
return func(c *Client) {
c.accessToken = token
}
}
// WithHTTPClient 设置自定义 HTTP 客户端
func WithHTTPClient(httpClient *http.Client) ClientOption {
return func(c *Client) {
c.httpClient = httpClient
}
}
// NewClient 创建新的 API 客户端
func NewClient(baseURL string, opts ...ClientOption) *Client {
c := &Client{
baseURL: baseURL,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
for _, opt := range opts {
opt(c)
}
return c
}
// APIResponse 标准 API 响应
type APIResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data,omitempty"`
}
// ErrorResponse 错误响应
type ErrorResponse struct {
Code int `json:"code"`
Message string `json:"message"`
}
func (c *Client) doRequest(ctx context.Context, method, path string, body interface{}) (*http.Response, error) {
u, err := url.JoinPath(c.baseURL, path)
if err != nil {
return nil, fmt.Errorf("failed to join URL: %w", err)
}
var reqBody io.Reader
if body != nil {
jsonData, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
reqBody = bytes.NewReader(jsonData)
}
req, err := http.NewRequestWithContext(ctx, method, u, reqBody)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
if c.accessToken != "" {
req.Header.Set("Authorization", "Bearer "+c.accessToken)
} else if c.apiKey != "" {
req.Header.Set("X-API-Key", c.apiKey)
}
return c.httpClient.Do(req)
}
func (c *Client) parseResponse(resp *http.Response, result interface{}) error {
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode >= 400 {
var errResp ErrorResponse
if err := json.Unmarshal(body, &errResp); err == nil {
return fmt.Errorf("API error %d: %s", resp.StatusCode, errResp.Message)
}
return fmt.Errorf("API error %d: %s", resp.StatusCode, string(body))
}
if result == nil {
return nil
}
var apiResp APIResponse
if err := json.Unmarshal(body, &apiResp); err != nil {
return fmt.Errorf("failed to unmarshal response: %w", err)
}
if apiResp.Data != nil {
if err := json.Unmarshal(apiResp.Data, result); err != nil {
return fmt.Errorf("failed to unmarshal data: %w", err)
}
}
return nil
}
// SetAccessToken 设置访问令牌
func (c *Client) SetAccessToken(token string) {
c.accessToken = token
}

View File

@@ -0,0 +1,138 @@
package userManagement
import (
"context"
"fmt"
)
// ListDevicesParams 设备列表查询参数
type ListDevicesParams struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
UserID int64 `json:"user_id,omitempty"`
IsActive *bool `json:"is_active,omitempty"`
IsTrusted *bool `json:"is_trusted,omitempty"`
}
// GetMyDevices 获取当前用户的设备列表
func (c *Client) GetMyDevices(ctx context.Context) ([]*Device, error) {
resp, err := c.doRequest(ctx, "GET", "/api/v1/devices/me", nil)
if err != nil {
return nil, err
}
var result []*Device
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return result, nil
}
// GetTrustedDevices 获取信任设备列表
func (c *Client) GetTrustedDevices(ctx context.Context) ([]*Device, error) {
resp, err := c.doRequest(ctx, "GET", "/api/v1/devices/me/trusted", nil)
if err != nil {
return nil, err
}
var result []*Device
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return result, nil
}
// GetDevice 获取设备详情
func (c *Client) GetDevice(ctx context.Context, id int64) (*Device, error) {
resp, err := c.doRequest(ctx, "GET", fmt.Sprintf("/api/v1/devices/%d", id), nil)
if err != nil {
return nil, err
}
var result Device
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// ListDevices 获取设备列表(管理员用)
func (c *Client) ListDevices(ctx context.Context, params *ListDevicesParams) (*PaginatedResponse, error) {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
path := fmt.Sprintf("/api/v1/admin/devices?page=%d&page_size=%d", params.Page, params.PageSize)
if params.UserID > 0 {
path += fmt.Sprintf("&user_id=%d", params.UserID)
}
resp, err := c.doRequest(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
var result PaginatedResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// TrustDevice 信任设备
func (c *Client) TrustDevice(ctx context.Context, deviceID int64) error {
resp, err := c.doRequest(ctx, "POST", fmt.Sprintf("/api/v1/devices/%d/trust", deviceID), nil)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// UntrustDevice 取消设备信任
func (c *Client) UntrustDevice(ctx context.Context, deviceID int64) error {
resp, err := c.doRequest(ctx, "DELETE", fmt.Sprintf("/api/v1/devices/%d/trust", deviceID), nil)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// TrustDeviceByDeviceID 通过 device_id 信任设备
func (c *Client) TrustDeviceByDeviceID(ctx context.Context, deviceID string) error {
resp, err := c.doRequest(ctx, "POST", fmt.Sprintf("/api/v1/devices/by-device-id/%s/trust", deviceID), nil)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// RevokeDevice 撤销设备
func (c *Client) RevokeDevice(ctx context.Context, deviceID int64) error {
resp, err := c.doRequest(ctx, "DELETE", fmt.Sprintf("/api/v1/devices/%d", deviceID), nil)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// LogoutOtherDevices 登出其他设备
func (c *Client) LogoutOtherDevices(ctx context.Context, currentDeviceID string) error {
req := map[string]string{"current_device_id": currentDeviceID}
resp, err := c.doRequest(ctx, "POST", "/api/v1/devices/me/logout-others", req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}

View File

@@ -0,0 +1,135 @@
package userManagement
import (
"context"
"fmt"
"log"
)
// Example_basic_usage 基础使用示例
func Example_basic_usage() {
// 创建客户端
client := NewClient("https://api.example.com")
// 登录
loginResp, err := client.Login(context.Background(), &LoginRequest{
Username: "admin",
Password: "password123",
DeviceName: "Go SDK Test",
})
if err != nil {
log.Fatalf("Login failed: %v", err)
}
fmt.Printf("Logged in as %s, token: %s...\n", loginResp.User.Username, loginResp.AccessToken[:20])
}
// Example_user_management 用户管理示例
func Example_user_management() {
client := NewClient("https://api.example.com", WithAPIToken("your-api-token"))
// 获取当前用户
user, err := client.GetCurrentUser(context.Background())
if err != nil {
log.Fatalf("GetCurrentUser failed: %v", err)
}
fmt.Printf("Current user: %s (%s)\n", user.Username, user.Email)
// 创建新用户
newUser, err := client.CreateUser(context.Background(), &CreateUserRequest{
Username: "newuser",
Email: "newuser@example.com",
Password: "SecurePass123!",
Status: UserStatusActive,
})
if err != nil {
log.Fatalf("CreateUser failed: %v", err)
}
fmt.Printf("Created user: %s (ID: %d)\n", newUser.Username, newUser.ID)
// 更新用户
updatedUser, err := client.UpdateUser(context.Background(), newUser.ID, &UpdateUserRequest{
Nickname: "New Nickname",
})
if err != nil {
log.Fatalf("UpdateUser failed: %v", err)
}
fmt.Printf("Updated nickname: %s\n", updatedUser.Nickname)
// 删除用户
if err := client.DeleteUser(context.Background(), newUser.ID); err != nil {
log.Fatalf("DeleteUser failed: %v", err)
}
fmt.Printf("User %d deleted\n", newUser.ID)
}
// Example_device_management 设备管理示例
func Example_device_management() {
client := NewClient("https://api.example.com", WithAccessToken("access-token"))
// 获取我的设备
devices, err := client.GetMyDevices(context.Background())
if err != nil {
log.Fatalf("GetMyDevices failed: %v", err)
}
fmt.Printf("My devices (%d):\n", len(devices))
for _, d := range devices {
trustStatus := "untrusted"
if d.IsTrusted {
trustStatus = "trusted"
}
fmt.Printf(" - %s (%s) [%s]\n", d.DeviceName, d.DeviceType, trustStatus)
}
// 获取信任设备
trusted, err := client.GetTrustedDevices(context.Background())
if err != nil {
log.Fatalf("GetTrustedDevices failed: %v", err)
}
fmt.Printf("Trusted devices: %d\n", len(trusted))
}
// Example_role_management 角色管理示例
func Example_role_management() {
client := NewClient("https://api.example.com", WithAccessToken("access-token"))
// 获取角色列表
roles, err := client.ListRoles(context.Background(), &ListRolesParams{
Page: 1,
PageSize: 20,
})
if err != nil {
log.Fatalf("ListRoles failed: %v", err)
}
fmt.Printf("Total roles: %d\n", roles.Total)
// 获取权限树
permissions, err := client.ListPermissions(context.Background())
if err != nil {
log.Fatalf("ListPermissions failed: %v", err)
}
fmt.Printf("Total permissions: %d\n", len(permissions))
}
// Example_totp TOTP 两因素认证示例
func Example_totp() {
client := NewClient("https://api.example.com", WithAccessToken("access-token"))
// 启用 TOTP
setup, err := client.EnableTOTP(context.Background())
if err != nil {
log.Fatalf("EnableTOTP failed: %v", err)
}
fmt.Printf("TOTP Secret: %s\n", setup.Secret)
fmt.Printf("QR Code URL: %s\n", setup.QRCodeURL)
fmt.Printf("Recovery Codes: %v\n", setup.RecoveryCodes)
// 用户手动验证 TOTP 后才能正式启用
// 这里用示例 code 验证
if err := client.VerifyTOTP(context.Background(), "123456"); err != nil {
fmt.Printf("TOTP verification: %v\n", err)
} else {
fmt.Println("TOTP verified successfully")
}
}

View File

@@ -0,0 +1,3 @@
module github.com/user-management-system/sdk/go
go 1.21

View File

@@ -0,0 +1,135 @@
package userManagement
import (
"context"
"fmt"
"time"
)
// ListLoginLogsParams 登录日志查询参数
type ListLoginLogsParams struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
UserID int64 `json:"user_id,omitempty"`
Status int `json:"status,omitempty"`
StartAt *time.Time `json:"start_at,omitempty"`
EndAt *time.Time `json:"end_at,omitempty"`
}
// ListOperationLogsParams 操作日志查询参数
type ListOperationLogsParams struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
UserID int64 `json:"user_id,omitempty"`
Action string `json:"action,omitempty"`
Resource string `json:"resource,omitempty"`
StartAt *time.Time `json:"start_at,omitempty"`
EndAt *time.Time `json:"end_at,omitempty"`
}
// GetLoginLogs 获取登录日志列表
func (c *Client) GetLoginLogs(ctx context.Context, params *ListLoginLogsParams) (*PaginatedResponse, error) {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
path := fmt.Sprintf("/api/v1/logs/login?page=%d&page_size=%d", params.Page, params.PageSize)
if params.UserID > 0 {
path += fmt.Sprintf("&user_id=%d", params.UserID)
}
if params.Status > 0 {
path += fmt.Sprintf("&status=%d", params.Status)
}
resp, err := c.doRequest(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
var result PaginatedResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// GetOperationLogs 获取操作日志列表
func (c *Client) GetOperationLogs(ctx context.Context, params *ListOperationLogsParams) (*PaginatedResponse, error) {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
path := fmt.Sprintf("/api/v1/logs/operation?page=%d&page_size=%d", params.Page, params.PageSize)
if params.UserID > 0 {
path += fmt.Sprintf("&user_id=%d", params.UserID)
}
if params.Action != "" {
path += "&action=" + params.Action
}
if params.Resource != "" {
path += "&resource=" + params.Resource
}
resp, err := c.doRequest(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
var result PaginatedResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// ExportLoginLogsRequest 导出登录日志请求
type ExportLoginLogsRequest struct {
Format string `json:"format"` // "xlsx" or "csv"
UserID int64 `json:"user_id,omitempty"`
Status int `json:"status,omitempty"`
StartAt *time.Time `json:"start_at,omitempty"`
EndAt *time.Time `json:"end_at,omitempty"`
Fields string `json:"fields,omitempty"`
}
// ExportLoginLogs 导出登录日志(返回下载 URL
func (c *Client) ExportLoginLogs(ctx context.Context, req *ExportLoginLogsRequest) (string, error) {
resp, err := c.doRequest(ctx, "GET", "/api/v1/logs/login/export", req)
if err != nil {
return "", err
}
var result map[string]string
if err := c.parseResponse(resp, &result); err != nil {
return "", err
}
if url, ok := result["download_url"]; ok {
return url, nil
}
return "", nil
}
// GetStats 获取统计信息
func (c *Client) GetStats(ctx context.Context) (*Stats, error) {
resp, err := c.doRequest(ctx, "GET", "/api/v1/stats/dashboard", nil)
if err != nil {
return nil, err
}
var result Stats
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}

View File

@@ -0,0 +1,157 @@
package userManagement
import (
"context"
"fmt"
)
// CreateRoleRequest 创建角色请求
type CreateRoleRequest struct {
Name string `json:"name"`
Code string `json:"code"`
Description string `json:"description,omitempty"`
PermissionIDs []int64 `json:"permission_ids,omitempty"`
Status RoleStatus `json:"status,omitempty"`
}
// UpdateRoleRequest 更新角色请求
type UpdateRoleRequest struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
PermissionIDs []int64 `json:"permission_ids,omitempty"`
Status RoleStatus `json:"status,omitempty"`
}
// ListRolesParams 角色列表查询参数
type ListRolesParams struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Keyword string `json:"keyword,omitempty"`
Status string `json:"status,omitempty"`
}
// GetRole 获取角色详情
func (c *Client) GetRole(ctx context.Context, id int64) (*Role, error) {
resp, err := c.doRequest(ctx, "GET", fmt.Sprintf("/api/v1/roles/%d", id), nil)
if err != nil {
return nil, err
}
var result Role
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// ListRoles 获取角色列表
func (c *Client) ListRoles(ctx context.Context, params *ListRolesParams) (*PaginatedResponse, error) {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
path := fmt.Sprintf("/api/v1/roles?page=%d&page_size=%d", params.Page, params.PageSize)
if params.Keyword != "" {
path += "&keyword=" + params.Keyword
}
if params.Status != "" {
path += "&status=" + params.Status
}
resp, err := c.doRequest(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
var result PaginatedResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// CreateRole 创建角色
func (c *Client) CreateRole(ctx context.Context, req *CreateRoleRequest) (*Role, error) {
resp, err := c.doRequest(ctx, "POST", "/api/v1/roles", req)
if err != nil {
return nil, err
}
var result Role
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// UpdateRole 更新角色
func (c *Client) UpdateRole(ctx context.Context, id int64, req *UpdateRoleRequest) (*Role, error) {
resp, err := c.doRequest(ctx, "PUT", fmt.Sprintf("/api/v1/roles/%d", id), req)
if err != nil {
return nil, err
}
var result Role
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// DeleteRole 删除角色
func (c *Client) DeleteRole(ctx context.Context, id int64) error {
resp, err := c.doRequest(ctx, "DELETE", fmt.Sprintf("/api/v1/roles/%d", id), nil)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// AssignPermissions 分配权限给角色
func (c *Client) AssignPermissions(ctx context.Context, roleID int64, permissionIDs []int64) error {
req := map[string][]int64{"permission_ids": permissionIDs}
resp, err := c.doRequest(ctx, "POST", fmt.Sprintf("/api/v1/roles/%d/permissions", roleID), req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// GetRolePermissions 获取角色权限
func (c *Client) GetRolePermissions(ctx context.Context, roleID int64) ([]*Permission, error) {
resp, err := c.doRequest(ctx, "GET", fmt.Sprintf("/api/v1/roles/%d/permissions", roleID), nil)
if err != nil {
return nil, err
}
var result []*Permission
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return result, nil
}
// ListPermissions 获取权限列表(树形)
func (c *Client) ListPermissions(ctx context.Context) ([]*Permission, error) {
resp, err := c.doRequest(ctx, "GET", "/api/v1/permissions", nil)
if err != nil {
return nil, err
}
var result []*Permission
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return result, nil
}

View File

@@ -0,0 +1,171 @@
package userManagement
import "time"
// User 用户
type User struct {
ID int64 `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Phone string `json:"phone,omitempty"`
Nickname string `json:"nickname,omitempty"`
Avatar string `json:"avatar,omitempty"`
Status UserStatus `json:"status"`
RoleIDs []int64 `json:"role_ids,omitempty"`
Roles []*Role `json:"roles,omitempty"`
IsSuperAdmin bool `json:"is_super_admin"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
}
// UserStatus 用户状态
type UserStatus string
const (
UserStatusActive UserStatus = "active"
UserStatusInactive UserStatus = "inactive"
UserStatusBanned UserStatus = "banned"
)
// Role 角色
type Role struct {
ID int64 `json:"id"`
Name string `json:"name"`
Code string `json:"code"`
Description string `json:"description,omitempty"`
Status RoleStatus `json:"status"`
Permissions []*Permission `json:"permissions,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// RoleStatus 角色状态
type RoleStatus string
const (
RoleStatusActive RoleStatus = "active"
RoleStatusInactive RoleStatus = "inactive"
)
// Permission 权限
type Permission struct {
ID int64 `json:"id"`
Name string `json:"name"`
Code string `json:"code"`
Description string `json:"description,omitempty"`
Type PermissionType `json:"type"`
ParentID *int64 `json:"parent_id,omitempty"`
Children []*Permission `json:"children,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// PermissionType 权限类型
type PermissionType string
const (
PermissionTypeMenu PermissionType = "menu"
PermissionTypeAction PermissionType = "action"
PermissionTypeAPI PermissionType = "api"
)
// Device 设备
type Device struct {
ID int64 `json:"id"`
DeviceID string `json:"device_id"`
DeviceName string `json:"device_name"`
DeviceType DeviceType `json:"device_type"`
DeviceOS string `json:"device_os"`
DeviceBrowser string `json:"device_browser"`
IP string `json:"ip"`
Location string `json:"location,omitempty"`
IsTrusted bool `json:"is_trusted"`
IsActive bool `json:"is_active"`
LastActiveAt time.Time `json:"last_active_at"`
CreatedAt time.Time `json:"created_at"`
UserID int64 `json:"user_id"`
}
// DeviceType 设备类型
type DeviceType string
const (
DeviceTypeDesktop DeviceType = "desktop"
DeviceTypeMobile DeviceType = "mobile"
DeviceTypeTablet DeviceType = "tablet"
)
// LoginLog 登录日志
type LoginLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Username string `json:"username"`
IP string `json:"ip"`
Location string `json:"location,omitempty"`
DeviceID string `json:"device_id"`
DeviceName string `json:"device_name"`
Status LoginStatus `json:"status"`
FailReason string `json:"fail_reason,omitempty"`
LoginMethod string `json:"login_method"`
CreatedAt time.Time `json:"created_at"`
}
// LoginStatus 登录状态
type LoginStatus int
const (
LoginStatusFailed LoginStatus = 0
LoginStatusSuccess LoginStatus = 1
)
// OperationLog 操作日志
type OperationLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Username string `json:"username"`
Action string `json:"action"`
Resource string `json:"resource"`
ResourceID *int64 `json:"resource_id,omitempty"`
Details string `json:"details,omitempty"`
IP string `json:"ip"`
Status int `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// Webhook Webhook 配置
type Webhook struct {
ID int64 `json:"id"`
Name string `json:"name"`
URL string `json:"url"`
Events []string `json:"events"`
Secret string `json:"secret,omitempty"`
IsActive bool `json:"is_active"`
RetryCount int `json:"retry_count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Stats 统计信息
type Stats struct {
TotalUsers int64 `json:"total_users"`
ActiveUsers int64 `json:"active_users"`
TotalDevices int64 `json:"total_devices"`
ActiveDevices int64 `json:"active_devices"`
TodayLogins int64 `json:"today_logins"`
TodayFailLogins int64 `json:"today_fail_logins"`
}
// Pagination 分页参数
type Pagination struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// PaginatedResponse 分页响应
type PaginatedResponse struct {
Items interface{} `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
}

View File

@@ -0,0 +1,247 @@
package userManagement
import (
"context"
"fmt"
)
// CreateUserRequest 创建用户请求
type CreateUserRequest struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
Phone string `json:"phone,omitempty"`
Nickname string `json:"nickname,omitempty"`
RoleIDs []int64 `json:"role_ids,omitempty"`
Status UserStatus `json:"status,omitempty"`
}
// UpdateUserRequest 更新用户请求
type UpdateUserRequest struct {
Email string `json:"email,omitempty"`
Phone string `json:"phone,omitempty"`
Nickname string `json:"nickname,omitempty"`
Avatar string `json:"avatar,omitempty"`
Status UserStatus `json:"status,omitempty"`
}
// AssignRolesRequest 分配角色请求
type AssignRolesRequest struct {
RoleIDs []int64 `json:"role_ids"`
}
// ListUsersParams 用户列表查询参数
type ListUsersParams struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Keyword string `json:"keyword,omitempty"`
Status string `json:"status,omitempty"`
}
// GetCurrentUser 获取当前登录用户
func (c *Client) GetCurrentUser(ctx context.Context) (*User, error) {
resp, err := c.doRequest(ctx, "GET", "/api/v1/users/me", nil)
if err != nil {
return nil, err
}
var result User
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// GetUser 获取用户详情
func (c *Client) GetUser(ctx context.Context, id int64) (*User, error) {
resp, err := c.doRequest(ctx, "GET", fmt.Sprintf("/api/v1/users/%d", id), nil)
if err != nil {
return nil, err
}
var result User
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// ListUsers 获取用户列表
func (c *Client) ListUsers(ctx context.Context, params *ListUsersParams) (*PaginatedResponse, error) {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 20
}
path := fmt.Sprintf("/api/v1/users?page=%d&page_size=%d", params.Page, params.PageSize)
if params.Keyword != "" {
path += "&keyword=" + params.Keyword
}
if params.Status != "" {
path += "&status=" + params.Status
}
resp, err := c.doRequest(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
var result PaginatedResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// CreateUser 创建用户
func (c *Client) CreateUser(ctx context.Context, req *CreateUserRequest) (*User, error) {
resp, err := c.doRequest(ctx, "POST", "/api/v1/users", req)
if err != nil {
return nil, err
}
var result User
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// UpdateUser 更新用户
func (c *Client) UpdateUser(ctx context.Context, id int64, req *UpdateUserRequest) (*User, error) {
resp, err := c.doRequest(ctx, "PUT", fmt.Sprintf("/api/v1/users/%d", id), req)
if err != nil {
return nil, err
}
var result User
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// DeleteUser 删除用户
func (c *Client) DeleteUser(ctx context.Context, id int64) error {
resp, err := c.doRequest(ctx, "DELETE", fmt.Sprintf("/api/v1/users/%d", id), nil)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// AssignRoles 分配角色
func (c *Client) AssignRoles(ctx context.Context, userID int64, req *AssignRolesRequest) error {
resp, err := c.doRequest(ctx, "POST", fmt.Sprintf("/api/v1/users/%d/roles", userID), req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// GetUserRoles 获取用户角色
func (c *Client) GetUserRoles(ctx context.Context, userID int64) ([]*Role, error) {
resp, err := c.doRequest(ctx, "GET", fmt.Sprintf("/api/v1/users/%d/roles", userID), nil)
if err != nil {
return nil, err
}
var result []*Role
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return result, nil
}
// UpdatePassword 更新密码
func (c *Client) UpdatePassword(ctx context.Context, oldPassword, newPassword string) error {
req := map[string]string{
"old_password": oldPassword,
"new_password": newPassword,
}
resp, err := c.doRequest(ctx, "POST", "/api/v1/users/me/password", req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// BindEmail 绑定邮箱
func (c *Client) BindEmail(ctx context.Context, email string) error {
req := map[string]string{"email": email}
resp, err := c.doRequest(ctx, "POST", "/api/v1/users/me/email", req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// BindPhone 绑定手机
func (c *Client) BindPhone(ctx context.Context, phone, code string) error {
req := map[string]string{
"phone": phone,
"code": code,
}
resp, err := c.doRequest(ctx, "POST", "/api/v1/users/me/phone", req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// EnableTOTP 启用 TOTP
func (c *Client) EnableTOTP(ctx context.Context) (*TOTPSetupResponse, error) {
resp, err := c.doRequest(ctx, "POST", "/api/v1/users/me/2fa/totp/setup", nil)
if err != nil {
return nil, err
}
var result TOTPSetupResponse
if err := c.parseResponse(resp, &result); err != nil {
return nil, err
}
return &result, nil
}
// TOTPSetupResponse TOTP 设置响应
type TOTPSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
RecoveryCodes []string `json:"recovery_codes,omitempty"`
}
// VerifyTOTP 验证 TOTP
func (c *Client) VerifyTOTP(ctx context.Context, code string) error {
req := map[string]string{"code": code}
resp, err := c.doRequest(ctx, "POST", "/api/v1/users/me/2fa/totp/verify", req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}
// DisableTOTP 禁用 TOTP
func (c *Client) DisableTOTP(ctx context.Context, code string) error {
req := map[string]string{"code": code}
resp, err := c.doRequest(ctx, "POST", "/api/v1/users/me/2fa/totp/disable", req)
if err != nil {
return err
}
return c.parseResponse(resp, nil)
}