// Source: user-system/internal/auth/oauth_utils.go (Go, 197 lines, 4.5 KiB)

package auth
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/oauth2"
)
// StateStore is an in-memory store for OAuth state parameters.
// Each stored state token is associated with the time at which it expires.
type StateStore struct {
// states maps a state token to its expiry time.
states map[string]time.Time
// mu guards access to states.
mu sync.RWMutex
}
// stateStore is the package-level StateStore instance shared by
// GenerateState, ValidateState and CleanupStates. Its map is allocated here
// so it can be written to without further initialization.
var stateStore = &StateStore{
states: make(map[string]time.Time),
}
// GenerateState creates a new OAuth state parameter and records it in the
// package-level state store.
//
// The value is 32 bytes read from crypto/rand, encoded with URL-safe base64.
// It is stored with an expiry ten minutes in the future. An error is returned
// only when the random source cannot be read; nothing is stored in that case.
func GenerateState() (string, error) {
	const ttl = 10 * time.Minute

	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("generate state failed: %w", err)
	}
	token := base64.URLEncoding.EncodeToString(raw[:])

	// Register the token together with its expiry time.
	stateStore.mu.Lock()
	defer stateStore.mu.Unlock()
	stateStore.states[token] = time.Now().Add(ttl)

	return token, nil
}
// ValidateState reports whether state is a known, unexpired OAuth state
// parameter.
//
// A state is single-use: once it has been looked up it is removed from the
// store, whether or not it had already expired. Unknown states yield false.
func ValidateState(state string) bool {
	stateStore.mu.Lock()
	defer stateStore.mu.Unlock()

	deadline, known := stateStore.states[state]
	if !known {
		return false
	}

	// The entry is consumed in every case, so drop it before checking expiry.
	delete(stateStore.states, state)

	expired := time.Now().After(deadline)
	return !expired
}
// CleanupStates removes every state whose expiry time has passed from the
// package-level state store. The store is locked for the duration of the
// sweep.
func CleanupStates() {
	stateStore.mu.Lock()
	defer stateStore.mu.Unlock()

	cutoff := time.Now()
	for token, deadline := range stateStore.states {
		if !deadline.Before(cutoff) {
			// Still valid (or expiring exactly now): keep it.
			continue
		}
		delete(stateStore.states, token)
	}
}
// HTTPClient is the HTTP client used by the Get and PostForm helpers in this
// file. Its overall request timeout is 30 seconds.
var HTTPClient = &http.Client{
Timeout: 30 * time.Second,
}
// Get sends a GET request to url through HTTPClient and returns the raw
// response. The response body is not closed here; callers such as GetJSON
// close it themselves.
func Get(url string) (*http.Response, error) {
return HTTPClient.Get(url)
}
// PostForm sends data as a form-encoded POST request to url through
// HTTPClient and returns the raw response. The response body is not closed
// here; callers such as PostFormJSON close it themselves.
//
// Note: inside the body the url parameter shadows the net/url package.
func PostForm(url string, data url.Values) (*http.Response, error) {
return HTTPClient.PostForm(url, data)
}
// GetJSON issues a GET request to url and decodes the JSON response body into
// result, which should be a pointer that encoding/json can decode into.
//
// It returns the error from Get, an error carrying the status code when the
// response is anything other than 200 OK, or the JSON decoding error.
func GetJSON(url string, result interface{}) error {
	resp, err := Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		return fmt.Errorf("HTTP request failed with status %d", status)
	}

	decoder := json.NewDecoder(resp.Body)
	return decoder.Decode(result)
}
// PostFormJSON posts data as a form to url and decodes the JSON response body
// into result, which should be a pointer that encoding/json can decode into.
//
// It returns the error from PostForm, an error carrying the status code when
// the response is anything other than 200 OK, or the JSON decoding error.
func PostFormJSON(url string, data url.Values, result interface{}) error {
	resp, err := PostForm(url, data)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		return fmt.Errorf("HTTP request failed with status %d", status)
	}

	decoder := json.NewDecoder(resp.Body)
	return decoder.Decode(result)
}
// BuildAuthURL builds a standard OAuth authorization-code URL.
//
// baseURL is the provider's authorization endpoint; any query parameters it
// already carries are preserved. client_id, redirect_uri, scope and state are
// set from the arguments and response_type is always "code". Existing values
// for those keys in baseURL are overwritten.
//
// If baseURL cannot be parsed, the empty string is returned. The signature
// has no error result, so "" is the failure signal callers should check for.
func BuildAuthURL(baseURL, clientID, redirectURI, scope, state string) string {
	u, err := url.Parse(baseURL)
	if err != nil {
		// url.Parse returns a nil *URL on failure; using it would panic.
		return ""
	}

	q := u.Query()
	q.Set("client_id", clientID)
	q.Set("redirect_uri", redirectURI)
	q.Set("scope", scope)
	q.Set("state", state)
	q.Set("response_type", "code")
	u.RawQuery = q.Encode()

	return u.String()
}
// ParseAccessTokenResponse decodes a JSON access-token response into an
// OAuthToken.
//
// The fields access_token, refresh_token, expires_in and token_type are read
// from resp; fields missing from the JSON keep their zero values. A JSON
// decoding error is returned unchanged together with a nil token.
func ParseAccessTokenResponse(resp []byte) (*OAuthToken, error) {
	var payload struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int64  `json:"expires_in"`
		TokenType    string `json:"token_type"`
	}

	err := json.Unmarshal(resp, &payload)
	if err != nil {
		return nil, err
	}

	token := &OAuthToken{}
	token.AccessToken = payload.AccessToken
	token.RefreshToken = payload.RefreshToken
	token.ExpiresIn = payload.ExpiresIn
	token.TokenType = payload.TokenType
	return token, nil
}
// ParseQueryAccessToken extracts the access token from a response body that
// is formatted as a URL query string (as returned by some APIs that answer
// with text/plain instead of JSON).
//
// It returns the first value of the "access_token" key, or "" when the key is
// absent. If body is not a valid query string, the parse error is returned
// and the token is "".
func ParseQueryAccessToken(body string) (accessToken string, err error) {
	parsed, parseErr := url.ParseQuery(body)
	if parseErr != nil {
		return "", parseErr
	}

	token := parsed.Get("access_token")
	return token, nil
}
// ParseJSONPResponse parses a JSONP response (as used by QQ and similar
// platforms) into a generic map.
//
// The JSON payload is taken to be everything between the first "(" and the
// last ")" in jsonp; the callback name and any trailing characters such as
// ";" are ignored. An "invalid JSONP format" error is returned when either
// parenthesis is missing or when the last ")" appears before the first "(".
// JSON decoding errors are returned unchanged.
func ParseJSONPResponse(jsonp string) (map[string]interface{}, error) {
	// Strip the callback wrapper.
	start := strings.Index(jsonp, "(")
	end := strings.LastIndex(jsonp, ")")
	// end < start (e.g. ")(") would make the slice below panic, so reject it
	// together with the missing-parenthesis cases.
	if start == -1 || end == -1 || end < start {
		return nil, fmt.Errorf("invalid JSONP format")
	}
	jsonStr := jsonp[start+1 : end]

	var result map[string]interface{}
	if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
		return nil, err
	}
	return result, nil
}
// ToOAuth2Config converts the package's OAuthConfig into an oauth2.Config.
//
// config.Scope is treated as a comma-separated list. Each entry is trimmed of
// surrounding whitespace and empty entries are dropped, so an empty Scope
// yields no scopes rather than a single empty one. ClientID, ClientSecret,
// RedirectURI, AuthURL and TokenURL are copied over unchanged.
//
// config must be non-nil.
func ToOAuth2Config(config *OAuthConfig) *oauth2.Config {
	// strings.Split("", ",") returns [""], which would otherwise be passed on
	// as one blank scope.
	var scopes []string
	for _, s := range strings.Split(config.Scope, ",") {
		if s = strings.TrimSpace(s); s != "" {
			scopes = append(scopes, s)
		}
	}

	return &oauth2.Config{
		ClientID:     config.ClientID,
		ClientSecret: config.ClientSecret,
		RedirectURL:  config.RedirectURI,
		Scopes:       scopes,
		Endpoint: oauth2.Endpoint{
			AuthURL:  config.AuthURL,
			TokenURL: config.TokenURL,
		},
	}
}