fix: 生产安全修复 + Go SDK + CAS SSO框架
安全修复: - CRITICAL: SSO重定向URL注入漏洞 - 修复redirect_uri白名单验证 - HIGH: SSO ClientSecret未验证 - 使用crypto/subtle.ConstantTimeCompare验证 - HIGH: 邮件验证码熵值过低(3字节) - 提升到6字节(48位熵) - HIGH: 短信验证码熵值过低(4字节) - 提升到6字节 - HIGH: Goroutine使用已取消上下文 - auth_email.go使用独立context+超时 - HIGH: SQL LIKE查询注入风险 - permission/role仓库使用escapeLikePattern 新功能: - Go SDK: sdk/go/user-management/ 完整SDK实现 - CAS SSO框架: internal/auth/cas.go CAS协议支持 其他: - L1Cache实例问题修复 - AuthMiddleware共享l1Cache - 设备指纹XSS防护 - 内存存储替代localStorage - 响应格式协议中间件 - 导出无界查询修复
This commit is contained in:
221
internal/auth/cas.go
Normal file
221
internal/auth/cas.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CASProvider CAS (Central Authentication Service) 提供者
|
||||
// CAS 是一种单点登录协议,用户只需登录一次即可访问多个应用
|
||||
type CASProvider struct {
|
||||
serverURL string
|
||||
serviceURL string
|
||||
}
|
||||
|
||||
// CASServiceTicket CAS 服务票据
|
||||
type CASServiceTicket struct {
|
||||
Ticket string
|
||||
Service string
|
||||
UserID int64
|
||||
Username string
|
||||
IssuedAt time.Time
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
// NewCASProvider 创建 CAS 提供者
|
||||
func NewCASProvider(serverURL, serviceURL string) *CASProvider {
|
||||
return &CASProvider{
|
||||
serverURL: strings.TrimSuffix(serverURL, "/"),
|
||||
serviceURL: serviceURL,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildLoginURL 构建 CAS 登录 URL
|
||||
// 用于重定向用户到 CAS 登录页面
|
||||
func (p *CASProvider) BuildLoginURL(renew, gateway bool) string {
|
||||
params := url.Values{}
|
||||
params.Set("service", p.serviceURL)
|
||||
if renew {
|
||||
params.Set("renew", "true")
|
||||
}
|
||||
if gateway {
|
||||
params.Set("gateway", "true")
|
||||
}
|
||||
return fmt.Sprintf("%s/login?%s", p.serverURL, params.Encode())
|
||||
}
|
||||
|
||||
// BuildLogoutURL 构建 CAS 登出 URL
|
||||
func (p *CASProvider) BuildLogoutURL(url string) string {
|
||||
if url != "" {
|
||||
return fmt.Sprintf("%s/logout?service=%s", p.serverURL, url)
|
||||
}
|
||||
return fmt.Sprintf("%s/logout", p.serverURL)
|
||||
}
|
||||
|
||||
// CASValidationResponse CAS 票据验证响应
|
||||
type CASValidationResponse struct {
|
||||
Success bool
|
||||
UserID int64
|
||||
Username string
|
||||
ErrorCode string
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
// ValidateTicket 验证 CAS 票据
|
||||
// 向 CAS 服务器发送 ticket 验证请求
|
||||
func (p *CASProvider) ValidateTicket(ctx context.Context, ticket string) (*CASValidationResponse, error) {
|
||||
if ticket == "" {
|
||||
return &CASValidationResponse{
|
||||
Success: false,
|
||||
ErrorCode: "INVALID_REQUEST",
|
||||
ErrorMsg: "ticket is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("service", p.serviceURL)
|
||||
params.Set("ticket", ticket)
|
||||
|
||||
validateURL := fmt.Sprintf("%s/p3/serviceValidate?%s", p.serverURL, params.Encode())
|
||||
|
||||
resp, err := fetchCASResponse(ctx, validateURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CAS validation request failed: %w", err)
|
||||
}
|
||||
|
||||
return p.parseServiceValidateResponse(resp)
|
||||
}
|
||||
|
||||
// parseServiceValidateResponse 解析 CAS serviceValidate 响应
|
||||
// CAS 1.0 和 CAS 2.0 使用不同的响应格式
|
||||
func (p *CASProvider) parseServiceValidateResponse(xml string) (*CASValidationResponse, error) {
|
||||
resp := &CASValidationResponse{Success: false}
|
||||
|
||||
// 检查是否包含 authenticationSuccess 元素
|
||||
if strings.Contains(xml, "<authenticationSuccess>") {
|
||||
resp.Success = true
|
||||
|
||||
// 解析用户名
|
||||
if start := strings.Index(xml, "<user>"); start != -1 {
|
||||
end := strings.Index(xml[start:], "</user>")
|
||||
if end != -1 {
|
||||
resp.Username = xml[start+6 : start+end]
|
||||
}
|
||||
}
|
||||
|
||||
// 解析用户 ID (CAS 2.0)
|
||||
if start := strings.Index(xml, "<userId>"); start != -1 {
|
||||
end := strings.Index(xml[start:], "</userId>")
|
||||
if end != -1 {
|
||||
userIDStr := xml[start+8 : start+end]
|
||||
var userID int64
|
||||
fmt.Sscanf(userIDStr, "%d", &userID)
|
||||
resp.UserID = userID
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(xml, "<authenticationFailure>") {
|
||||
resp.Success = false
|
||||
|
||||
// 解析错误码
|
||||
if start := strings.Index(xml, "code=\""); start != -1 {
|
||||
start += 6
|
||||
end := strings.Index(xml[start:], "\"")
|
||||
if end != -1 {
|
||||
resp.ErrorCode = xml[start : start+end]
|
||||
}
|
||||
}
|
||||
|
||||
// 解析错误消息
|
||||
if start := strings.Index(xml, "<![CDATA["); start != -1 {
|
||||
end := strings.Index(xml[start:], "]]>")
|
||||
if end != -1 {
|
||||
resp.ErrorMsg = xml[start+9 : start+end]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// GenerateProxyTicket 生成代理票据 (CAS 2.0)
|
||||
// 用于服务代理用户访问其他服务
|
||||
func (p *CASProvider) GenerateProxyTicket(ctx context.Context, proxyGrantingTicket, targetService string) (string, error) {
|
||||
params := url.Values{}
|
||||
params.Set("targetService", targetService)
|
||||
|
||||
proxyURL := fmt.Sprintf("%s/p3/proxy?%s&pgt=%s",
|
||||
p.serverURL, params.Encode(), proxyGrantingTicket)
|
||||
|
||||
resp, err := fetchCASResponse(ctx, proxyURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 解析代理票据
|
||||
if start := strings.Index(resp, "<proxyTicket>"); start != -1 {
|
||||
end := strings.Index(resp[start:], "</proxyTicket>")
|
||||
if end != -1 {
|
||||
return resp[start+12 : start+end], nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("failed to parse proxy ticket from response")
|
||||
}
|
||||
|
||||
// fetchCASResponse 从 CAS 服务器获取响应
|
||||
func fetchCASResponse(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Accept", "application/xml")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// GenerateCASServiceTicket 生成 CAS 服务票据 (供 CAS 服务器使用)
|
||||
// 这个方法供实际的 CAS 服务器实现调用
|
||||
func GenerateCASServiceTicket(service string, userID int64, username string) (*CASServiceTicket, error) {
|
||||
ticketBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(ticketBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ticket: %w", err)
|
||||
}
|
||||
|
||||
return &CASServiceTicket{
|
||||
Ticket: "ST-" + base64.URLEncoding.EncodeToString(ticketBytes)[:32],
|
||||
Service: service,
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
IssuedAt: time.Now(),
|
||||
Expiry: time.Now().Add(5 * time.Minute),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsExpired 检查票据是否过期
|
||||
func (t *CASServiceTicket) IsExpired() bool {
|
||||
return time.Now().After(t.Expiry)
|
||||
}
|
||||
|
||||
// GetDuration 返回票据有效时长
|
||||
func (t *CASServiceTicket) GetDuration() time.Duration {
|
||||
return t.Expiry.Sub(t.IssuedAt)
|
||||
}
|
||||
Reference in New Issue
Block a user