Files
user-system/internal/auth/sso.go

234 lines
5.4 KiB
Go
Raw Normal View History

package auth
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"sync"
	"time"
)
// SSOOAuth2Config holds the OAuth2 client settings used for SSO.
type SSOOAuth2Config struct {
	// ClientID and ClientSecret are the OAuth2 client credentials.
	ClientID, ClientSecret string
	// RedirectURI and Scope are the OAuth2 redirect_uri and scope values.
	RedirectURI, Scope string
}
// SSOProvider is the interface implemented by an SSO provider.
type SSOProvider interface {
// Authorize handles an authorization request.
Authorize(ctx context.Context, req *SSOAuthorizeRequest) (*SSOAuthorizeResponse, error)
// Introspect validates an access token and reports its details.
// NOTE(review): SSOManager.IntrospectToken in this file reports unknown or
// expired tokens as Active=false with a nil error; confirm implementations of
// this interface follow the same convention.
Introspect(ctx context.Context, token string) (*SSOTokenInfo, error)
// Revoke revokes a token.
Revoke(ctx context.Context, token string) error
}
// SSOAuthorizeRequest carries the parameters of an authorization request.
type SSOAuthorizeRequest struct {
	// ClientID and RedirectURI identify the requesting client and the URI to
	// send the result to.
	ClientID, RedirectURI string
	// ResponseType selects the flow: "code" or "token".
	ResponseType string
	// Scope is the requested scope; State is the client-supplied state value.
	Scope, State string
	// UserID is the user the authorization is requested for (presumably
	// already authenticated upstream — confirm with callers).
	UserID int64
}
// SSOAuthorizeResponse is the result of an authorization request.
type SSOAuthorizeResponse struct {
	// Code is the issued authorization code (authorization_code flow).
	// State presumably echoes SSOAuthorizeRequest.State — confirm in the
	// provider implementation.
	Code, State string
}
// SSOTokenInfo describes an access token as reported by
// SSOManager.IntrospectToken. When Active is false, every other field is left
// at its zero value.
type SSOTokenInfo struct {
	Active    bool
	UserID    int64
	Username  string
	ExpiresAt time.Time
	Scope     string
	ClientID  string
}

// SSOSession is the server-side record behind an authorization code or an
// access token: the user and client it was issued to, its scope, and its
// validity window.
type SSOSession struct {
	SessionID string
	UserID    int64
	Username  string
	ClientID  string
	CreatedAt time.Time
	ExpiresAt time.Time
	Scope     string
}

// SSOManager issues, validates and revokes authorization codes and access
// tokens, keeping them in memory. Create it with NewSSOManager; the zero
// value has nil maps and is not usable.
//
// Authorization codes and access tokens are stored in separate maps so that a
// value of one kind is never accepted as the other. All methods are safe for
// concurrent use, which CleanupExpired relies on when it runs from a
// background goroutine.
type SSOManager struct {
	mu sync.Mutex
	// codes holds unredeemed authorization codes, keyed by code.
	codes map[string]*SSOSession
	// sessions holds issued access tokens, keyed by token.
	sessions map[string]*SSOSession
}

// NewSSOManager returns an empty, ready-to-use SSOManager.
func NewSSOManager() *SSOManager {
	return &SSOManager{
		codes:    make(map[string]*SSOSession),
		sessions: make(map[string]*SSOSession),
	}
}

// GenerateAuthorizationCode issues a single-use authorization code, valid for
// 10 minutes, for the given client, scope and user.
//
// It returns an error only if the system random source fails.
//
// NOTE(review): redirectURI is accepted but not recorded, so
// ValidateAuthorizationCode cannot check that the token request presents the
// same redirect URI (RFC 6749 §4.1.3). Callers must enforce that, and the
// client match, themselves.
func (m *SSOManager) GenerateAuthorizationCode(clientID, redirectURI, scope string, userID int64, username string) (string, error) {
	code, err := newSecureToken(32)
	if err != nil {
		return "", fmt.Errorf("generating authorization code: %w", err)
	}
	sessionID, err := newSecureToken(16)
	if err != nil {
		return "", fmt.Errorf("generating session id: %w", err)
	}
	now := time.Now()
	session := &SSOSession{
		SessionID: sessionID,
		UserID:    userID,
		Username:  username,
		ClientID:  clientID,
		CreatedAt: now,
		ExpiresAt: now.Add(10 * time.Minute), // authorization codes live 10 minutes
		Scope:     scope,
	}

	m.mu.Lock()
	m.codes[code] = session
	m.mu.Unlock()
	return code, nil
}

// ValidateAuthorizationCode redeems an authorization code and returns the
// session it was issued for.
//
// A code is consumed by the first call that sees it, whether or not it has
// expired. Access tokens are never accepted here. It returns an error if the
// code is unknown (never issued, already redeemed, revoked, or cleaned up) or
// expired.
func (m *SSOManager) ValidateAuthorizationCode(code string) (*SSOSession, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	session, ok := m.codes[code]
	if !ok {
		return nil, errors.New("invalid authorization code")
	}
	// Lookup and delete happen under one lock so two concurrent exchanges
	// cannot both redeem the same code.
	delete(m.codes, code)
	if time.Now().After(session.ExpiresAt) {
		return nil, errors.New("authorization code expired")
	}
	return session, nil
}

// GenerateAccessToken issues an access token, valid for 2 hours, carrying the
// user and scope of session and bound to clientID. It returns the token and
// its expiry time.
//
// session must be non-nil. It panics if the system random source fails.
func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (string, time.Time) {
	token := generateSecureToken(32)
	now := time.Now()
	expiresAt := now.Add(2 * time.Hour) // access tokens live 2 hours
	accessSession := &SSOSession{
		SessionID: token,
		UserID:    session.UserID,
		Username:  session.Username,
		ClientID:  clientID,
		CreatedAt: now,
		ExpiresAt: expiresAt,
		Scope:     session.Scope,
	}

	m.mu.Lock()
	m.sessions[token] = accessSession
	m.mu.Unlock()
	return token, expiresAt
}

// IntrospectToken reports whether token is a live access token and, if so,
// whom it was issued to.
//
// Unknown, revoked and expired tokens, as well as authorization codes, yield
// Active=false with a nil error. Expired tokens are removed as a side effect.
// The error result is always nil.
func (m *SSOManager) IntrospectToken(token string) (*SSOTokenInfo, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	session, ok := m.sessions[token]
	if !ok {
		return &SSOTokenInfo{Active: false}, nil
	}
	if time.Now().After(session.ExpiresAt) {
		delete(m.sessions, token)
		return &SSOTokenInfo{Active: false}, nil
	}
	return &SSOTokenInfo{
		Active:    true,
		UserID:    session.UserID,
		Username:  session.Username,
		ExpiresAt: session.ExpiresAt,
		Scope:     session.Scope,
		ClientID:  session.ClientID,
	}, nil
}

// RevokeToken invalidates token. Revoking an unknown token is a no-op. It
// always returns nil.
//
// The value is removed from both the access-token and authorization-code
// stores, matching the earlier single-map behavior where either kind could be
// revoked here.
func (m *SSOManager) RevokeToken(token string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	delete(m.sessions, token)
	delete(m.codes, token)
	return nil
}

// CleanupExpired removes every expired authorization code and access token.
// It is intended to be called periodically, for example from a background
// goroutine.
func (m *SSOManager) CleanupExpired() {
	m.mu.Lock()
	defer m.mu.Unlock()

	now := time.Now()
	for _, store := range []map[string]*SSOSession{m.codes, m.sessions} {
		for key, session := range store {
			if now.After(session.ExpiresAt) {
				delete(store, key) // deleting during range is safe in Go
			}
		}
	}
}

// generateSecureToken returns a random, URL-safe token exactly length
// characters long, drawn from crypto/rand.
//
// It panics if the system random source fails. Continuing with the
// zero-filled buffer would produce a constant, predictable token.
func generateSecureToken(length int) string {
	token, err := newSecureToken(length)
	if err != nil {
		panic(err)
	}
	return token
}

// newSecureToken is like generateSecureToken but returns the random-source
// error instead of panicking.
func newSecureToken(length int) (string, error) {
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		return "", fmt.Errorf("reading crypto/rand: %w", err)
	}
	// Base64 of n bytes is at least n characters. Truncating to length keeps
	// the token size equal to the argument (about 6 bits of entropy per
	// character).
	return base64.URLEncoding.EncodeToString(buf)[:length], nil
}
// SSOClient is the stored configuration of a registered SSO client.
type SSOClient struct {
	ClientID     string
	ClientSecret string
	Name         string
	// RedirectURIs lists the redirect URIs accepted for this client.
	// ValidateClientRedirectURI compares against them by exact string match.
	RedirectURIs []string
}

// SSOClientsStore looks up registered SSO clients.
type SSOClientsStore interface {
	// GetByClientID returns the client registered under clientID, or an
	// error if there is none.
	GetByClientID(clientID string) (*SSOClient, error)
}

// DefaultSSOClientsStore is an in-memory SSOClientsStore. Its zero value is
// an empty, usable store.
//
// NOTE(review): it has no internal locking. Presumably clients are registered
// at start-up, before any lookups; confirm that before registering clients
// while requests are being served.
type DefaultSSOClientsStore struct {
	clients map[string]*SSOClient
}

// Compile-time check that DefaultSSOClientsStore implements SSOClientsStore.
var _ SSOClientsStore = (*DefaultSSOClientsStore)(nil)

// NewDefaultSSOClientsStore returns an empty in-memory client store.
func NewDefaultSSOClientsStore() *DefaultSSOClientsStore {
	return &DefaultSSOClientsStore{
		clients: make(map[string]*SSOClient),
	}
}

// RegisterClient stores client under its ClientID, replacing any client
// already registered with that ID. The pointer is stored as-is, not copied.
// client must be non-nil.
func (s *DefaultSSOClientsStore) RegisterClient(client *SSOClient) {
	if s.clients == nil {
		// Lazily initialize so the zero value is usable; writing to a nil map
		// would panic.
		s.clients = make(map[string]*SSOClient)
	}
	s.clients[client.ClientID] = client
}

// GetByClientID returns the client registered under clientID, or an error
// naming the ID if no such client exists.
func (s *DefaultSSOClientsStore) GetByClientID(clientID string) (*SSOClient, error) {
	client, ok := s.clients[clientID]
	if !ok {
		return nil, fmt.Errorf("client not found: %s", clientID)
	}
	return client, nil
}

// ValidateClientRedirectURI reports whether redirectURI exactly matches one
// of the redirect URIs registered for clientID. Unknown clients are rejected.
func (s *DefaultSSOClientsStore) ValidateClientRedirectURI(clientID, redirectURI string) bool {
	client, err := s.GetByClientID(clientID)
	if err != nil {
		return false
	}
	for _, uri := range client.RedirectURIs {
		if uri == redirectURI {
			return true
		}
	}
	return false
}