Files
user-system/internal/auth/oauth.go

507 lines
14 KiB
Go
Raw Normal View History

package auth
import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"sort"

	"github.com/user-management-system/internal/auth/providers"
)
// OAuthProvider identifies an external OAuth identity provider. The string
// value is the key under which a provider is registered with
// DefaultOAuthManager and is echoed back in OAuthUser.Provider and
// OAuthProviderInfo.Provider.
type OAuthProvider string

// Known provider identifiers. Weibo, Facebook and Twitter have no dedicated
// client in RegisterProvider; for them only the config-based auth-URL
// fallback in GetAuthURL applies.
const (
	OAuthProviderWeChat   OAuthProvider = "wechat"
	OAuthProviderQQ       OAuthProvider = "qq"
	OAuthProviderWeibo    OAuthProvider = "weibo"
	OAuthProviderGoogle   OAuthProvider = "google"
	OAuthProviderFacebook OAuthProvider = "facebook"
	OAuthProviderTwitter  OAuthProvider = "twitter"
	OAuthProviderGitHub   OAuthProvider = "github"
	OAuthProviderAlipay   OAuthProvider = "alipay"
	OAuthProviderDouyin   OAuthProvider = "douyin"
)
// OAuthUser is the provider-independent user profile produced by
// DefaultOAuthManager.GetUserInfo. Fields a provider does not supply are left
// at their zero value.
//
// OpenID is the provider's user identifier (GitHub's numeric ID is rendered
// as a decimal string). Gender is "male"/"female" for WeChat and Douyin
// (mapped from numeric codes; other codes give ""), and QQ's value is passed
// through unchanged. Extra holds provider-specific fields; only QQ sets it
// (province, city, year). Phone is not populated by GetUserInfo.
type OAuthUser struct {
	Provider OAuthProvider          `json:"provider"`
	OpenID   string                 `json:"open_id"`
	UnionID  string                 `json:"union_id,omitempty"`
	Nickname string                 `json:"nickname"`
	Avatar   string                 `json:"avatar"`
	Gender   string                 `json:"gender,omitempty"`
	Email    string                 `json:"email,omitempty"`
	Phone    string                 `json:"phone,omitempty"`
	Extra    map[string]interface{} `json:"extra,omitempty"`
}
// OAuthToken is the provider-independent result of exchanging an
// authorization code (see DefaultOAuthManager.ExchangeCode).
//
// ExpiresIn is copied from the provider response (presumably seconds;
// confirm per provider). The GitHub mapping leaves RefreshToken and
// ExpiresIn empty.
type OAuthToken struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ExpiresIn    int64  `json:"expires_in"`
	TokenType    string `json:"token_type"`
	OpenID       string `json:"open_id,omitempty"` // needed by WeChat-style providers (WeChat, QQ, Douyin) for user-info calls
}
// OAuthConfig holds the application credentials and endpoints for one
// provider, as passed to DefaultOAuthManager.RegisterProvider.
//
// For Alipay, ClientID holds the AppID and ClientSecret holds the RSA private
// key. AuthURL and Scope are only read by the generic fallback in GetAuthURL.
// TokenURL and UserInfoURL are not read anywhere in this file.
//
// NOTE(review): ClientSecret has a JSON tag, so marshaling this struct (e.g.
// returning the result of GetConfig from an API) would expose the secret.
// Confirm it is never serialized outward.
type OAuthConfig struct {
	ClientID     string `json:"client_id"`
	ClientSecret string `json:"client_secret"`
	RedirectURI  string `json:"redirect_uri"`
	Scope        string `json:"scope"`
	AuthURL      string `json:"auth_url"`
	TokenURL     string `json:"token_url"`
	UserInfoURL  string `json:"user_info_url"`
}
// OAuthManager is the provider-agnostic interface for the OAuth login flow.
type OAuthManager interface {
	// GetAuthURL returns the URL to send the user to for authorization with
	// provider; state is carried through to the provider.
	GetAuthURL(provider OAuthProvider, state string) (string, error)
	// ExchangeCode exchanges an authorization code for an access token.
	ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error)
	// GetUserInfo fetches the user's profile using a previously obtained token.
	GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error)
	// ValidateToken reports whether an access token is valid. The token alone
	// does not name its provider, so implementations decide which provider(s)
	// to check it against.
	ValidateToken(token string) (bool, error)
	// GetConfig returns the configuration for provider and whether the
	// provider is registered.
	GetConfig(provider OAuthProvider) (*OAuthConfig, bool)
	// GetEnabledProviders lists the providers known to the manager together
	// with whether each one is enabled.
	GetEnabledProviders() []OAuthProviderInfo
}
// OAuthProviderInfo describes one registered provider as reported by
// GetEnabledProviders. Name is a human-readable display name, and Enabled
// tells whether the provider has a configuration.
type OAuthProviderInfo struct {
	Provider OAuthProvider `json:"provider"`
	Enabled  bool          `json:"enabled"`
	Name     string        `json:"name"`
}
// providerEntry is the per-provider state kept by DefaultOAuthManager: the
// registered config plus, at most, the one dedicated client that
// RegisterProvider built for this provider ID. All client pointers stay nil
// for providers without a dedicated client.
type providerEntry struct {
	config      *OAuthConfig
	google      *providers.GoogleProvider
	wechat      *providers.WeChatProvider
	wechatRedir string // the WeChat client takes the redirect URI per GetAuthURL call, not at construction
	qq          *providers.QQProvider
	github      *providers.GitHubProvider
	alipay      *providers.AlipayProvider
	douyin      *providers.DouyinProvider
}
// DefaultOAuthManager is the default OAuthManager implementation. It keeps one
// providerEntry per registered provider and delegates the HTTP work to the
// dedicated clients from the providers package where one exists.
//
// NOTE(review): entries is a plain map with no locking, so calling
// RegisterProvider concurrently with other methods would be a data race.
// Presumably all registration happens at startup, before the manager is
// shared. Confirm.
type DefaultOAuthManager struct {
	entries map[OAuthProvider]*providerEntry
}
// NewOAuthManager returns a DefaultOAuthManager with no providers registered.
// Call RegisterProvider to add providers.
func NewOAuthManager() *DefaultOAuthManager {
	mgr := new(DefaultOAuthManager)
	mgr.entries = map[OAuthProvider]*providerEntry{}
	return mgr
}
// RegisterProvider registers (or replaces) the configuration for provider.
//
// For providers with a dedicated client (Google, WeChat, QQ, GitHub, Alipay,
// Douyin), the client is built from config. For any other provider (e.g.
// Weibo, Facebook, Twitter) only the config is stored; GetAuthURL uses it in
// its generic fallback.
//
// A nil config is accepted. The provider is recorded without a client and is
// reported with Enabled=false by GetEnabledProviders, instead of panicking on
// a nil dereference.
func (m *DefaultOAuthManager) RegisterProvider(provider OAuthProvider, config *OAuthConfig) {
	// Create the map lazily so a zero-value DefaultOAuthManager is usable.
	if m.entries == nil {
		m.entries = make(map[OAuthProvider]*providerEntry)
	}

	entry := &providerEntry{config: config}
	if config != nil {
		switch provider {
		case OAuthProviderGoogle:
			entry.google = providers.NewGoogleProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
		case OAuthProviderWeChat:
			// NOTE(review): "web" presumably selects WeChat's website-login
			// flavor. Confirm in providers.NewWeChatProvider. The redirect URI
			// is kept separately because the WeChat client takes it per call.
			entry.wechat = providers.NewWeChatProvider(config.ClientID, config.ClientSecret, "web")
			entry.wechatRedir = config.RedirectURI
		case OAuthProviderQQ:
			entry.qq = providers.NewQQProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
		case OAuthProviderGitHub:
			entry.github = providers.NewGitHubProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
		case OAuthProviderAlipay:
			// Alipay: ClientID holds the AppID and ClientSecret holds the RSA
			// private key. NOTE(review): the trailing false is presumably a
			// sandbox/production switch. Confirm in providers.NewAlipayProvider.
			entry.alipay = providers.NewAlipayProvider(config.ClientID, config.ClientSecret, config.RedirectURI, false)
		case OAuthProviderDouyin:
			entry.douyin = providers.NewDouyinProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
		}
	}
	m.entries[provider] = entry
}
// GetConfig returns the configuration stored for provider. The boolean is
// false when the provider was never registered. The returned config can
// still be nil if a nil config was registered.
func (m *DefaultOAuthManager) GetConfig(provider OAuthProvider) (*OAuthConfig, bool) {
	if e, found := m.entries[provider]; found {
		return e.config, true
	}
	return nil, false
}
// GetAuthURL returns the URL the user should be redirected to in order to
// authorize this application with provider. state is forwarded unchanged; the
// caller is expected to verify it on the callback.
//
// Providers with a dedicated client (Google, WeChat, QQ, GitHub, Alipay,
// Douyin) build the URL themselves. Any other registered provider falls back
// to a generic OAuth2 authorization-code URL built from its OAuthConfig. The
// parameters client_id, redirect_uri, response_type=code, scope and state
// are merged into the query of config.AuthURL. Existing query parameters on
// AuthURL are kept unless they share one of those names.
//
// Errors:
//   - ErrOAuthProviderNotSupported if the provider is not registered or has
//     no config.
//   - An error if the fallback's AuthURL is empty or unparsable.
//   - Otherwise, whatever the provider client returns.
func (m *DefaultOAuthManager) GetAuthURL(provider OAuthProvider, state string) (string, error) {
	entry, ok := m.entries[provider]
	if !ok {
		return "", ErrOAuthProviderNotSupported
	}

	switch provider {
	case OAuthProviderGoogle:
		if entry.google != nil {
			resp, err := entry.google.GetAuthURL(state)
			if err != nil {
				return "", err
			}
			return resp.URL, nil
		}
	case OAuthProviderWeChat:
		if entry.wechat != nil {
			// The WeChat client takes the redirect URI per call, not at construction.
			resp, err := entry.wechat.GetAuthURL(entry.wechatRedir, state)
			if err != nil {
				return "", err
			}
			return resp.URL, nil
		}
	case OAuthProviderQQ:
		if entry.qq != nil {
			resp, err := entry.qq.GetAuthURL(state)
			if err != nil {
				return "", err
			}
			return resp.URL, nil
		}
	case OAuthProviderGitHub:
		if entry.github != nil {
			return entry.github.GetAuthURL(state)
		}
	case OAuthProviderAlipay:
		if entry.alipay != nil {
			return entry.alipay.GetAuthURL(state)
		}
	case OAuthProviderDouyin:
		if entry.douyin != nil {
			return entry.douyin.GetAuthURL(state)
		}
	}

	// Generic fallback: build a standard OAuth2 authorization-code URL from the
	// stored config (used for e.g. Weibo, Twitter, Facebook).
	config := entry.config
	if config == nil {
		return "", ErrOAuthProviderNotSupported
	}
	if config.AuthURL == "" {
		return "", fmt.Errorf("provider %s: auth URL not configured", provider)
	}
	authURL, err := url.Parse(config.AuthURL)
	if err != nil {
		return "", fmt.Errorf("provider %s: invalid auth URL: %w", provider, err)
	}
	// Merge into any query already on AuthURL. Blindly appending "?..." gave a
	// malformed URL when AuthURL already had a query string. Values.Encode
	// sorts keys, which matches the historical parameter order.
	query := authURL.Query()
	query.Set("client_id", config.ClientID)
	query.Set("redirect_uri", config.RedirectURI)
	query.Set("response_type", "code")
	query.Set("scope", config.Scope)
	query.Set("state", state)
	authURL.RawQuery = query.Encode()
	return authURL.String(), nil
}
// ExchangeCode exchanges an OAuth authorization code for an access token.
// It is equivalent to ExchangeCodeContext with context.Background() and is
// kept so DefaultOAuthManager continues to satisfy OAuthManager.
func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error) {
	return m.ExchangeCodeContext(context.Background(), provider, code)
}

// ExchangeCodeContext exchanges an OAuth authorization code for an access
// token using the provider's dedicated client. ctx is passed to the client,
// so the caller can cancel or bound the outbound HTTP call (for example with
// the incoming request's context).
//
// Provider-specific mapping into OAuthToken:
//   - WeChat, QQ, Alipay and Douyin set TokenType to "Bearer" and fill
//     OpenID. Alipay uses its UserID as the OpenID.
//   - QQ needs a second call (GetOpenID) to obtain the OpenID.
//   - GitHub maps only AccessToken and TokenType.
//
// Errors:
//   - ErrOAuthProviderNotSupported if provider is not registered.
//   - An error if code is empty.
//   - A "not implemented" error if the provider has no dedicated client
//     (e.g. Weibo, Facebook, Twitter).
//   - Otherwise, the client's error, unchanged.
func (m *DefaultOAuthManager) ExchangeCodeContext(ctx context.Context, provider OAuthProvider, code string) (*OAuthToken, error) {
	entry, ok := m.entries[provider]
	if !ok {
		return nil, ErrOAuthProviderNotSupported
	}
	// Fail fast instead of sending a request the provider can only reject.
	if code == "" {
		return nil, fmt.Errorf("provider %s: empty authorization code", provider)
	}

	switch provider {
	case OAuthProviderGoogle:
		if entry.google != nil {
			resp, err := entry.google.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    resp.TokenType,
			}, nil
		}
	case OAuthProviderWeChat:
		if entry.wechat != nil {
			resp, err := entry.wechat.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       resp.OpenID,
			}, nil
		}
	case OAuthProviderQQ:
		if entry.qq != nil {
			resp, err := entry.qq.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			// QQ returns the OpenID from a separate endpoint.
			openIDResp, err := entry.qq.GetOpenID(ctx, resp.AccessToken)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       openIDResp.OpenID,
			}, nil
		}
	case OAuthProviderGitHub:
		if entry.github != nil {
			resp, err := entry.github.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken: resp.AccessToken,
				TokenType:   resp.TokenType,
			}, nil
		}
	case OAuthProviderAlipay:
		if entry.alipay != nil {
			resp, err := entry.alipay.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       resp.UserID,
			}, nil
		}
	case OAuthProviderDouyin:
		if entry.douyin != nil {
			resp, err := entry.douyin.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.Data.AccessToken,
				RefreshToken: resp.Data.RefreshToken,
				ExpiresIn:    int64(resp.Data.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       resp.Data.OpenID,
			}, nil
		}
	}
	return nil, fmt.Errorf("provider %s: real HTTP exchange not implemented yet", provider)
}
// GetUserInfo fetches the user's profile from provider and normalizes it into
// an OAuthUser. It is equivalent to GetUserInfoContext with
// context.Background() and is kept so DefaultOAuthManager continues to
// satisfy OAuthManager.
func (m *DefaultOAuthManager) GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) {
	return m.GetUserInfoContext(context.Background(), provider, token)
}

// GetUserInfoContext fetches the user's profile from provider with token and
// normalizes it into an OAuthUser. ctx is passed to the provider client, so
// the caller can cancel or bound the outbound HTTP call.
//
// WeChat, QQ and Douyin also need token.OpenID, which ExchangeCode fills in
// for them. Gender is normalized to "male"/"female" for WeChat and Douyin
// (numeric codes 1/2; other codes give ""). QQ's gender string is passed
// through unchanged.
//
// Errors:
//   - ErrOAuthProviderNotSupported for an unregistered provider.
//   - An error for a nil token or an empty access token.
//   - A "not implemented" error for providers without a dedicated client.
//   - Otherwise, the client's error, unchanged.
func (m *DefaultOAuthManager) GetUserInfoContext(ctx context.Context, provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) {
	entry, ok := m.entries[provider]
	if !ok {
		return nil, ErrOAuthProviderNotSupported
	}
	// Guard against a nil dereference, and avoid a network call that cannot succeed.
	if token == nil || token.AccessToken == "" {
		return nil, fmt.Errorf("provider %s: missing access token", provider)
	}

	switch provider {
	case OAuthProviderGoogle:
		if entry.google != nil {
			info, err := entry.google.GetUserInfo(ctx, token.AccessToken)
			if err != nil {
				return nil, err
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.ID,
				Nickname: info.Name,
				Avatar:   info.Picture,
				Email:    info.Email,
			}, nil
		}
	case OAuthProviderWeChat:
		if entry.wechat != nil {
			info, err := entry.wechat.GetUserInfo(ctx, token.AccessToken, token.OpenID)
			if err != nil {
				return nil, err
			}
			gender := ""
			switch info.Sex {
			case 1:
				gender = "male"
			case 2:
				gender = "female"
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.OpenID,
				UnionID:  info.UnionID,
				Nickname: info.Nickname,
				Avatar:   info.HeadImgURL,
				Gender:   gender,
			}, nil
		}
	case OAuthProviderQQ:
		if entry.qq != nil {
			info, err := entry.qq.GetUserInfo(ctx, token.AccessToken, token.OpenID)
			if err != nil {
				return nil, err
			}
			// Prefer FigureURL2, then FigureURL1, then FigureURL (presumably
			// largest to smallest avatar; confirm against the QQ API).
			avatar := info.FigureURL2
			if avatar == "" {
				avatar = info.FigureURL1
			}
			if avatar == "" {
				avatar = info.FigureURL
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   token.OpenID,
				Nickname: info.Nickname,
				Avatar:   avatar,
				Gender:   info.Gender,
				Extra: map[string]interface{}{
					"province": info.Province,
					"city":     info.City,
					"year":     info.Year,
				},
			}, nil
		}
	case OAuthProviderGitHub:
		if entry.github != nil {
			info, err := entry.github.GetUserInfo(ctx, token.AccessToken)
			if err != nil {
				return nil, err
			}
			// GitHub users may leave the display name empty; fall back to the login.
			nickname := info.Name
			if nickname == "" {
				nickname = info.Login
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   fmt.Sprintf("%d", info.ID),
				Nickname: nickname,
				Email:    info.Email,
			}, nil
		}
	case OAuthProviderAlipay:
		if entry.alipay != nil {
			info, err := entry.alipay.GetUserInfo(ctx, token.AccessToken)
			if err != nil {
				return nil, err
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.UserID,
				Nickname: info.Nickname,
				Avatar:   info.Avatar,
			}, nil
		}
	case OAuthProviderDouyin:
		if entry.douyin != nil {
			info, err := entry.douyin.GetUserInfo(ctx, token.AccessToken, token.OpenID)
			if err != nil {
				return nil, err
			}
			gender := ""
			switch info.Data.Gender {
			case 1:
				gender = "male"
			case 2:
				gender = "female"
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.Data.OpenID,
				UnionID:  info.Data.UnionID,
				Nickname: info.Data.Nickname,
				Avatar:   info.Data.Avatar,
				Gender:   gender,
			}, nil
		}
	}
	return nil, fmt.Errorf("provider %s: real HTTP user info not implemented yet", provider)
}
// ValidateToken reports whether token is accepted by any enabled provider's
// user-info endpoint.
//
// The token alone carries no provider, so each enabled provider (one with a
// non-nil config) is probed in turn through GetUserInfo. The first success
// returns (true, nil). An empty token returns (false, nil) without any call.
// If no provider is enabled, an error is returned instead of silently passing
// or failing. Prefer ValidateTokenWithProvider whenever the provider is known.
//
// The probe token has no OpenID, so providers whose user-info call uses one
// (WeChat, QQ, Douyin) presumably reject it.
//
// NOTE(review): this sends the raw access token to every enabled provider,
// including providers that did not issue it. That leaks a credential to third
// parties and costs one network round-trip per provider. Consider removing
// this method in favour of ValidateTokenWithProvider.
func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
	if token == "" {
		return false, nil
	}

	// Named "candidates" (not "providers") so it does not shadow the imported package.
	candidates := m.GetEnabledProviders()
	probe := &OAuthToken{AccessToken: token}
	probed := 0
	for _, info := range candidates {
		// Entries without a config have no client to validate against.
		if !info.Enabled {
			continue
		}
		probed++
		if _, err := m.GetUserInfo(info.Provider, probe); err == nil {
			return true, nil
		}
	}
	if probed == 0 {
		return false, errors.New("no OAuth providers configured")
	}
	return false, nil
}
// ValidateTokenWithProvider reports whether token is accepted by provider's
// user-info endpoint.
//
// An empty token returns (false, nil) without any network call. A provider
// that is unregistered, registered with a nil config, or has an empty
// ClientID yields a "not configured" error. Otherwise, a GetUserInfo failure
// is returned as-is together with false.
//
// The probe token carries no OpenID, so providers whose user-info call uses
// one (WeChat, QQ, Douyin) presumably reject it.
func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, token string) (bool, error) {
	if token == "" {
		return false, nil
	}
	cfg, ok := m.GetConfig(provider)
	// GetConfig can report ok with a nil config (a nil config was registered),
	// so check for nil before dereferencing.
	if !ok || cfg == nil || cfg.ClientID == "" {
		return false, fmt.Errorf("provider %s not configured", provider)
	}
	if _, err := m.GetUserInfo(provider, &OAuthToken{AccessToken: token}); err != nil {
		return false, err
	}
	return true, nil
}
// GetEnabledProviders 获取已启用的OAuth提供商
func (m *DefaultOAuthManager) GetEnabledProviders() []OAuthProviderInfo {
providerNames := map[OAuthProvider]string{
OAuthProviderGoogle: "Google",
OAuthProviderWeChat: "微信",
OAuthProviderQQ: "QQ",
OAuthProviderWeibo: "微博",
OAuthProviderFacebook: "Facebook",
OAuthProviderTwitter: "Twitter",
OAuthProviderGitHub: "GitHub",
OAuthProviderAlipay: "支付宝",
OAuthProviderDouyin: "抖音",
}
var result []OAuthProviderInfo
for provider, entry := range m.entries {
name := providerNames[provider]
if name == "" {
name = string(provider)
}
result = append(result, OAuthProviderInfo{
Provider: provider,
Enabled: entry.config != nil,
Name: name,
})
}
return result
}