Files
user-system/internal/auth/oauth.go

507 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package auth
import (
"context"
"errors"
"fmt"
"net/url"
"github.com/user-management-system/internal/auth/providers"
)
// OAuthProvider identifies a third-party OAuth identity provider (for example
// "wechat" or "github"). It is the key DefaultOAuthManager registers
// providers under.
type OAuthProvider string
// Built-in provider identifiers. Each string value is what is stored in
// OAuthUser.Provider / OAuthProviderInfo.Provider and therefore what appears
// in their JSON output.
const (
	OAuthProviderWeChat OAuthProvider = "wechat"
	OAuthProviderQQ OAuthProvider = "qq"
	OAuthProviderWeibo OAuthProvider = "weibo"
	OAuthProviderGoogle OAuthProvider = "google"
	OAuthProviderFacebook OAuthProvider = "facebook"
	OAuthProviderTwitter OAuthProvider = "twitter"
	OAuthProviderGitHub OAuthProvider = "github"
	OAuthProviderAlipay OAuthProvider = "alipay"
	OAuthProviderDouyin OAuthProvider = "douyin"
)
// OAuthUser is the provider-independent user profile produced from a
// provider's user-info response.
type OAuthUser struct {
	// Provider is the provider the profile came from.
	Provider OAuthProvider `json:"provider"`
	// OpenID is the provider's identifier for this user.
	OpenID string `json:"open_id"`
	// UnionID is only filled by providers that return one (WeChat and Douyin
	// in DefaultOAuthManager).
	UnionID string `json:"union_id,omitempty"`
	Nickname string `json:"nickname"`
	// Avatar is the profile-picture URL, when the provider supplies one.
	Avatar string `json:"avatar"`
	// Gender is "male"/"female"/"" for WeChat and Douyin; other providers may
	// pass their own value through unchanged.
	Gender string `json:"gender,omitempty"`
	Email string `json:"email,omitempty"`
	Phone string `json:"phone,omitempty"`
	// Extra carries provider-specific fields that have no dedicated slot.
	Extra map[string]interface{} `json:"extra,omitempty"`
}
// OAuthToken is the normalized result of exchanging an authorization code.
type OAuthToken struct {
	AccessToken string `json:"access_token"`
	// RefreshToken is empty for providers that do not return one.
	RefreshToken string `json:"refresh_token,omitempty"`
	// ExpiresIn is the lifetime reported by the provider.
	// NOTE(review): presumably seconds for every provider — confirm.
	ExpiresIn int64 `json:"expires_in"`
	TokenType string `json:"token_type"`
	OpenID string `json:"open_id,omitempty"` // provider user ID; WeChat, QQ and Douyin user-info calls require it
}
// OAuthConfig holds the registration settings for one provider.
//
// For Alipay, ClientID carries the AppID and ClientSecret carries the RSA
// private key. AuthURL is only used by the generic OAuth2 fallback in
// DefaultOAuthManager.GetAuthURL; TokenURL and UserInfoURL are not read by
// DefaultOAuthManager itself.
type OAuthConfig struct {
	ClientID string `json:"client_id"`
	ClientSecret string `json:"client_secret"`
	RedirectURI string `json:"redirect_uri"`
	Scope string `json:"scope"`
	AuthURL string `json:"auth_url"`
	TokenURL string `json:"token_url"`
	UserInfoURL string `json:"user_info_url"`
}
// OAuthManager drives the OAuth login flow for a set of registered providers.
type OAuthManager interface {
	// GetAuthURL returns the authorization URL to redirect the user to. state
	// is embedded in the request (typically an anti-CSRF value that the
	// provider echoes back on the callback).
	GetAuthURL(provider OAuthProvider, state string) (string, error)
	// ExchangeCode exchanges an authorization code for an access token.
	ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error)
	// GetUserInfo fetches the user's profile with a previously obtained token.
	GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error)
	// ValidateToken reports whether token is valid. No provider is given, so
	// the implementation decides which provider(s) to check against.
	ValidateToken(token string) (bool, error)
	// GetConfig returns the configuration registered for provider and whether
	// the provider is registered at all.
	GetConfig(provider OAuthProvider) (*OAuthConfig, bool)
	// GetEnabledProviders lists the registered OAuth providers.
	GetEnabledProviders() []OAuthProviderInfo
}
// OAuthProviderInfo describes one registered provider for listing purposes.
type OAuthProviderInfo struct {
	Provider OAuthProvider `json:"provider"`
	// Enabled is true when the provider was registered with a non-nil config
	// (as computed by DefaultOAuthManager.GetEnabledProviders).
	Enabled bool `json:"enabled"`
	// Name is the human-readable display name.
	Name string `json:"name"`
}
// providerEntry is the internal per-provider registration record: the raw
// config plus, for providers with a native client, that client. RegisterProvider
// fills at most one client field, selected by the provider ID.
type providerEntry struct {
	config *OAuthConfig
	google *providers.GoogleProvider
	wechat *providers.WeChatProvider
	// wechatRedir is kept separately because the WeChat client is constructed
	// without a redirect URI and receives it on each GetAuthURL call.
	wechatRedir string
	qq *providers.QQProvider
	github *providers.GitHubProvider
	alipay *providers.AlipayProvider
	douyin *providers.DouyinProvider
}
// DefaultOAuthManager is the OAuthManager implementation that delegates to
// the real provider HTTP clients from the providers package.
//
// NOTE(review): entries is a plain map with no locking, so RegisterProvider
// must not run concurrently with any other method — presumably all
// registration happens at startup; confirm with callers.
type DefaultOAuthManager struct {
	entries map[OAuthProvider]*providerEntry
}
// NewOAuthManager returns a manager with no providers registered yet; add
// them with RegisterProvider.
func NewOAuthManager() *DefaultOAuthManager {
	mgr := new(DefaultOAuthManager)
	mgr.entries = map[OAuthProvider]*providerEntry{}
	return mgr
}
// RegisterProvider registers (or replaces) the configuration for provider.
//
// This is the legacy registration entry point. It always stores config. For
// providers with a native implementation (Google, WeChat, QQ, GitHub, Alipay,
// Douyin) it also builds the matching client from config. Any other provider
// ID only has its config stored; GetAuthURL serves it through a generic
// OAuth2 fallback, while ExchangeCode/GetUserInfo report it as not
// implemented.
//
// A nil config registers the provider without building a client, and
// GetEnabledProviders then reports it with Enabled=false.
func (m *DefaultOAuthManager) RegisterProvider(provider OAuthProvider, config *OAuthConfig) {
	if m.entries == nil {
		// Let a zero-value DefaultOAuthManager work instead of panicking on
		// the first map write.
		m.entries = make(map[OAuthProvider]*providerEntry)
	}
	entry := &providerEntry{config: config}
	if config == nil {
		// Every client constructor below reads config fields; a nil config
		// previously panicked here even though the rest of the manager
		// explicitly tolerates entries without config.
		m.entries[provider] = entry
		return
	}
	switch provider {
	case OAuthProviderGoogle:
		entry.google = providers.NewGoogleProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
	case OAuthProviderWeChat:
		// NOTE(review): "web" presumably selects the WeChat app/login type —
		// confirm against providers.NewWeChatProvider.
		entry.wechat = providers.NewWeChatProvider(config.ClientID, config.ClientSecret, "web")
		// The WeChat client takes the redirect URI per GetAuthURL call.
		entry.wechatRedir = config.RedirectURI
	case OAuthProviderQQ:
		entry.qq = providers.NewQQProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
	case OAuthProviderGitHub:
		entry.github = providers.NewGitHubProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
	case OAuthProviderAlipay:
		// Alipay: ClientID holds the AppID and ClientSecret holds the RSA
		// private key. NOTE(review): the trailing false is presumably a
		// sandbox flag — confirm against providers.NewAlipayProvider.
		entry.alipay = providers.NewAlipayProvider(config.ClientID, config.ClientSecret, config.RedirectURI, false)
	case OAuthProviderDouyin:
		entry.douyin = providers.NewDouyinProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
	}
	m.entries[provider] = entry
}
// GetConfig returns the configuration stored for provider. The boolean is
// false only when the provider was never registered; a registered provider
// may still report a nil config if it was registered with one.
func (m *DefaultOAuthManager) GetConfig(provider OAuthProvider) (*OAuthConfig, bool) {
	if registered, found := m.entries[provider]; found {
		return registered.config, true
	}
	return nil, false
}
// GetAuthURL returns the authorization URL for provider, carrying state.
//
// Providers with a native client (Google, WeChat, QQ, GitHub, Alipay, Douyin)
// build the URL themselves. Any other registered provider falls back to a
// standard OAuth2 authorization-code request assembled from its OAuthConfig.
//
// It returns ErrOAuthProviderNotSupported when the provider is not registered
// or has no config. It returns an error wrapping ErrOAuthProviderNotSupported
// when the fallback has no AuthURL, and a parse error when AuthURL is
// malformed.
func (m *DefaultOAuthManager) GetAuthURL(provider OAuthProvider, state string) (string, error) {
	entry, ok := m.entries[provider]
	if !ok {
		return "", ErrOAuthProviderNotSupported
	}
	switch provider {
	case OAuthProviderGoogle:
		if entry.google != nil {
			resp, err := entry.google.GetAuthURL(state)
			if err != nil {
				return "", err
			}
			return resp.URL, nil
		}
	case OAuthProviderWeChat:
		if entry.wechat != nil {
			resp, err := entry.wechat.GetAuthURL(entry.wechatRedir, state)
			if err != nil {
				return "", err
			}
			return resp.URL, nil
		}
	case OAuthProviderQQ:
		if entry.qq != nil {
			resp, err := entry.qq.GetAuthURL(state)
			if err != nil {
				return "", err
			}
			return resp.URL, nil
		}
	case OAuthProviderGitHub:
		if entry.github != nil {
			return entry.github.GetAuthURL(state)
		}
	case OAuthProviderAlipay:
		if entry.alipay != nil {
			return entry.alipay.GetAuthURL(state)
		}
	case OAuthProviderDouyin:
		if entry.douyin != nil {
			return entry.douyin.GetAuthURL(state)
		}
	}
	// Generic fallback (e.g. Weibo/Twitter/Facebook): a standard OAuth2
	// authorization-code request built from the stored config.
	if entry.config == nil {
		return "", ErrOAuthProviderNotSupported
	}
	return buildOAuth2AuthURL(provider, entry.config, state)
}

// buildOAuth2AuthURL builds an RFC 6749 §4.1.1 authorization request from
// config. Query parameters already present in config.AuthURL are kept and
// merged, instead of appending a second '?'.
func buildOAuth2AuthURL(provider OAuthProvider, config *OAuthConfig, state string) (string, error) {
	if config.AuthURL == "" {
		return "", fmt.Errorf("provider %s: auth URL not configured: %w", provider, ErrOAuthProviderNotSupported)
	}
	u, err := url.Parse(config.AuthURL)
	if err != nil {
		return "", fmt.Errorf("provider %s: invalid auth URL %q: %w", provider, config.AuthURL, err)
	}
	q := u.Query()
	q.Set("client_id", config.ClientID)
	q.Set("redirect_uri", config.RedirectURI)
	q.Set("response_type", "code")
	q.Set("scope", config.Scope)
	q.Set("state", state)
	// Encode sorts keys; for a plain AuthURL this yields the same
	// client_id, redirect_uri, response_type, scope, state order as before.
	u.RawQuery = q.Encode()
	return u.String(), nil
}
// ExchangeCode exchanges an authorization code for tokens using the
// provider's native client and normalizes the result into an OAuthToken.
//
// For WeChat, QQ, Alipay and Douyin the returned token also carries the user
// identifier in OpenID (Alipay's UserID; QQ needs an extra GetOpenID call).
// WeChat, QQ and Douyin require that OpenID later in GetUserInfo.
//
// Return values:
//   - ErrOAuthProviderNotSupported for an unregistered provider.
//   - An error for an empty code.
//   - A "not implemented" error for registered providers without a native
//     client.
//
// NOTE(review): requests run under context.Background() and cannot be
// cancelled by the caller; timeouts depend on the provider clients — confirm.
func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error) {
	entry, ok := m.entries[provider]
	if !ok {
		return nil, ErrOAuthProviderNotSupported
	}
	if code == "" {
		// The code is mandatory in the token request (RFC 6749 §4.1.3);
		// fail fast instead of making an HTTP call that cannot succeed.
		return nil, fmt.Errorf("provider %s: authorization code is required", provider)
	}
	ctx := context.Background()
	switch provider {
	case OAuthProviderGoogle:
		if entry.google != nil {
			resp, err := entry.google.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    resp.TokenType,
			}, nil
		}
	case OAuthProviderWeChat:
		if entry.wechat != nil {
			resp, err := entry.wechat.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       resp.OpenID,
			}, nil
		}
	case OAuthProviderQQ:
		if entry.qq != nil {
			resp, err := entry.qq.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			// QQ's token response has no openid; it needs a second call.
			openIDResp, err := entry.qq.GetOpenID(ctx, resp.AccessToken)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       openIDResp.OpenID,
			}, nil
		}
	case OAuthProviderGitHub:
		if entry.github != nil {
			resp, err := entry.github.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken: resp.AccessToken,
				TokenType:   resp.TokenType,
			}, nil
		}
	case OAuthProviderAlipay:
		if entry.alipay != nil {
			resp, err := entry.alipay.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.AccessToken,
				RefreshToken: resp.RefreshToken,
				ExpiresIn:    int64(resp.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       resp.UserID,
			}, nil
		}
	case OAuthProviderDouyin:
		if entry.douyin != nil {
			resp, err := entry.douyin.ExchangeCode(ctx, code)
			if err != nil {
				return nil, err
			}
			return &OAuthToken{
				AccessToken:  resp.Data.AccessToken,
				RefreshToken: resp.Data.RefreshToken,
				ExpiresIn:    int64(resp.Data.ExpiresIn),
				TokenType:    "Bearer",
				OpenID:       resp.Data.OpenID,
			}, nil
		}
	}
	return nil, fmt.Errorf("provider %s: real HTTP exchange not implemented yet", provider)
}
// GetUserInfo fetches the user's profile from provider with token and maps
// it onto OAuthUser.
//
// WeChat, QQ and Douyin look the user up by token.OpenID, which ExchangeCode
// fills for them. Numeric gender codes from WeChat and Douyin (1, 2) become
// "male"/"female"; QQ's gender string is passed through unchanged.
//
// Return values:
//   - ErrOAuthProviderNotSupported for an unregistered provider.
//   - An error when token is nil or has no access token.
//   - A "not implemented" error for registered providers without a native
//     client.
func (m *DefaultOAuthManager) GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) {
	entry, ok := m.entries[provider]
	if !ok {
		return nil, ErrOAuthProviderNotSupported
	}
	if token == nil || token.AccessToken == "" {
		// Previously a nil token panicked on token.AccessToken below.
		return nil, fmt.Errorf("provider %s: access token is required", provider)
	}
	ctx := context.Background()
	switch provider {
	case OAuthProviderGoogle:
		if entry.google != nil {
			info, err := entry.google.GetUserInfo(ctx, token.AccessToken)
			if err != nil {
				return nil, err
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.ID,
				Nickname: info.Name,
				Avatar:   info.Picture,
				Email:    info.Email,
			}, nil
		}
	case OAuthProviderWeChat:
		if entry.wechat != nil {
			info, err := entry.wechat.GetUserInfo(ctx, token.AccessToken, token.OpenID)
			if err != nil {
				return nil, err
			}
			gender := ""
			switch info.Sex {
			case 1:
				gender = "male"
			case 2:
				gender = "female"
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.OpenID,
				UnionID:  info.UnionID,
				Nickname: info.Nickname,
				Avatar:   info.HeadImgURL,
				Gender:   gender,
			}, nil
		}
	case OAuthProviderQQ:
		if entry.qq != nil {
			info, err := entry.qq.GetUserInfo(ctx, token.AccessToken, token.OpenID)
			if err != nil {
				return nil, err
			}
			// Prefer FigureURL2, then FigureURL1, then FigureURL.
			avatar := info.FigureURL2
			if avatar == "" {
				avatar = info.FigureURL1
			}
			if avatar == "" {
				avatar = info.FigureURL
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   token.OpenID,
				Nickname: info.Nickname,
				Avatar:   avatar,
				Gender:   info.Gender,
				Extra: map[string]interface{}{
					"province": info.Province,
					"city":     info.City,
					"year":     info.Year,
				},
			}, nil
		}
	case OAuthProviderGitHub:
		if entry.github != nil {
			info, err := entry.github.GetUserInfo(ctx, token.AccessToken)
			if err != nil {
				return nil, err
			}
			// GitHub's display name is optional; fall back to the login.
			nickname := info.Name
			if nickname == "" {
				nickname = info.Login
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   fmt.Sprintf("%d", info.ID),
				Nickname: nickname,
				Email:    info.Email,
			}, nil
		}
	case OAuthProviderAlipay:
		if entry.alipay != nil {
			info, err := entry.alipay.GetUserInfo(ctx, token.AccessToken)
			if err != nil {
				return nil, err
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.UserID,
				Nickname: info.Nickname,
				Avatar:   info.Avatar,
			}, nil
		}
	case OAuthProviderDouyin:
		if entry.douyin != nil {
			info, err := entry.douyin.GetUserInfo(ctx, token.AccessToken, token.OpenID)
			if err != nil {
				return nil, err
			}
			gender := ""
			switch info.Data.Gender {
			case 1:
				gender = "male"
			case 2:
				gender = "female"
			}
			return &OAuthUser{
				Provider: provider,
				OpenID:   info.Data.OpenID,
				UnionID:  info.Data.UnionID,
				Nickname: info.Data.Nickname,
				Avatar:   info.Data.Avatar,
				Gender:   gender,
			}, nil
		}
	}
	return nil, fmt.Errorf("provider %s: real HTTP user info not implemented yet", provider)
}
// ValidateToken reports whether token is accepted by any enabled provider.
//
// The interface supplies no provider, so validation is attempted by calling
// each enabled provider's user-info endpoint until one succeeds. Providers
// that need an OpenID (WeChat, QQ, Douyin) receive an empty one here and will
// most likely reject the token. Prefer ValidateTokenWithProvider whenever the
// issuing provider is known.
//
// NOTE(review): this fan-out sends the bearer token to every enabled
// provider, including ones that did not issue it — consider whether that is
// acceptable.
//
// Return values:
//   - (false, nil) for an empty token.
//   - (false, error) when no provider is enabled, rather than silently
//     passing.
//   - (false, nil) when every enabled provider rejects the token; individual
//     provider errors are treated as "not valid there".
func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
	if token == "" {
		return false, nil
	}
	tokenObj := &OAuthToken{AccessToken: token}
	tried := 0
	// Named "info" rather than "providers" to avoid shadowing the imported
	// providers package.
	for _, info := range m.GetEnabledProviders() {
		// GetEnabledProviders lists every registered provider; ones registered
		// without a config cannot validate anything.
		if !info.Enabled {
			continue
		}
		tried++
		if _, err := m.GetUserInfo(info.Provider, tokenObj); err == nil {
			return true, nil
		}
	}
	if tried == 0 {
		return false, errors.New("no OAuth providers configured")
	}
	return false, nil
}
// ValidateTokenWithProvider reports whether token is accepted by provider's
// user-info endpoint.
//
// Return values:
//   - (false, nil) for an empty token.
//   - An error when the provider is unregistered, has no config, or has an
//     empty ClientID.
//   - The provider's error when the user-info call fails.
func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, token string) (bool, error) {
	if token == "" {
		return false, nil
	}
	cfg, ok := m.GetConfig(provider)
	// GetConfig reports ok=true with a nil config for providers registered
	// without one; treat that as "not configured" instead of dereferencing nil.
	if !ok || cfg == nil || cfg.ClientID == "" {
		return false, fmt.Errorf("provider %s not configured", provider)
	}
	if _, err := m.GetUserInfo(provider, &OAuthToken{AccessToken: token}); err != nil {
		return false, err
	}
	return true, nil
}
// GetEnabledProviders 获取已启用的OAuth提供商
func (m *DefaultOAuthManager) GetEnabledProviders() []OAuthProviderInfo {
providerNames := map[OAuthProvider]string{
OAuthProviderGoogle: "Google",
OAuthProviderWeChat: "微信",
OAuthProviderQQ: "QQ",
OAuthProviderWeibo: "微博",
OAuthProviderFacebook: "Facebook",
OAuthProviderTwitter: "Twitter",
OAuthProviderGitHub: "GitHub",
OAuthProviderAlipay: "支付宝",
OAuthProviderDouyin: "抖音",
}
var result []OAuthProviderInfo
for provider, entry := range m.entries {
name := providerNames[provider]
if name == "" {
name = string(provider)
}
result = append(result, OAuthProviderInfo{
Provider: provider,
Enabled: entry.config != nil,
Name: name,
})
}
return result
}