Files

370 lines
9.3 KiB
Go
Raw Permalink Normal View History

package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"gorm.io/gorm"
)
// oauthRegistrar is the optional capability an OAuth manager exposes when it
// accepts provider configurations at runtime. RegisterOAuthProvider probes
// for it with a type assertion instead of requiring it on the manager type.
type oauthRegistrar interface {
	RegisterProvider(p auth.OAuthProvider, c *auth.OAuthConfig)
}
// RegisterOAuthProvider forwards a provider configuration to the service's
// OAuth manager.
//
// It is a no-op when cfg is nil. It is also a no-op when the manager does not
// implement RegisterProvider (the oauthRegistrar interface). No error is
// reported in either case.
func (s *AuthService) RegisterOAuthProvider(provider auth.OAuthProvider, cfg *auth.OAuthConfig) {
	if cfg == nil {
		return
	}
	registrar, supported := s.oauthManager.(oauthRegistrar)
	if !supported {
		return
	}
	registrar.RegisterProvider(provider, cfg)
}
// findUserForLogin resolves a login identifier by trying, in order, the
// username, email and phone lookups on userRepo.
//
// A lookup whose error satisfies isUserNotFoundError falls through to the next
// one. Any other error stops the search and is wrapped as
// "lookup user by <field> failed: ...". When all three lookups miss, the
// phone lookup's result and its not-found error are returned unchanged, so
// callers can still classify the miss with isUserNotFoundError.
func (s *AuthService) findUserForLogin(ctx context.Context, account string) (*domain.User, error) {
	type lookup struct {
		field string
		find  func() (*domain.User, error)
	}
	lookups := []lookup{
		{field: "username", find: func() (*domain.User, error) { return s.userRepo.GetByUsername(ctx, account) }},
		{field: "email", find: func() (*domain.User, error) { return s.userRepo.GetByEmail(ctx, account) }},
		{field: "phone", find: func() (*domain.User, error) { return s.userRepo.GetByPhone(ctx, account) }},
	}

	var (
		found *domain.User
		err   error
	)
	for _, step := range lookups {
		found, err = step.find()
		switch {
		case err == nil:
			return found, nil
		case !isUserNotFoundError(err):
			return nil, fmt.Errorf("lookup user by %s failed: %w", step.field, err)
		}
	}
	// Every lookup reported "not found": surface the last (phone) result as-is.
	return found, err
}
// isUserNotFoundError reports whether err means "no such user".
//
// It matches gorm.ErrRecordNotFound anywhere in the wrap chain. It also
// matches any error whose message contains "not found" (case-insensitive),
// which covers "record not found" and "user not found", or contains the
// Chinese "用户不存在" ("user does not exist"). A nil error is never "not found".
//
// NOTE(review): the bare "not found" substring also matches unrelated
// failures (e.g. a missing table or column). Those are then treated as a
// missing user instead of being surfaced. Consider narrowing this once the
// repository's error types are known.
func isUserNotFoundError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, gorm.ErrRecordNotFound):
		return true
	}
	message := err.Error()
	if strings.Contains(strings.ToLower(message), "not found") {
		return true
	}
	return strings.Contains(message, "用户不存在")
}
// bestEffortAssignDefaultRoles links every role returned by
// roleRepo.GetDefaultRoles to userID with a single BatchCreate call.
//
// Failures are logged (tagged with source for traceability) and never
// returned, so account creation is not blocked by role assignment. The call is
// skipped when the service or either repository is missing, or when there are
// no default roles.
func (s *AuthService) bestEffortAssignDefaultRoles(ctx context.Context, userID int64, source string) {
	if s == nil || s.userRoleRepo == nil || s.roleRepo == nil {
		return
	}

	roles, err := s.roleRepo.GetDefaultRoles(ctx)
	switch {
	case err != nil:
		log.Printf("auth: load default roles failed, source=%s user_id=%d err=%v", source, userID, err)
		return
	case len(roles) == 0:
		return
	}

	assignments := make([]*domain.UserRole, len(roles))
	for i, role := range roles {
		assignments[i] = &domain.UserRole{UserID: userID, RoleID: role.ID}
	}

	if err := s.userRoleRepo.BatchCreate(ctx, assignments); err != nil {
		log.Printf("auth: assign default roles failed, source=%s user_id=%d role_count=%d err=%v", source, userID, len(assignments), err)
	}
}
// bestEffortUpdateLastLogin records the login IP for userID via
// userRepo.UpdateLastLogin. A failure is only logged, so the login itself
// is never failed by this bookkeeping. Nothing happens when the service or its
// user repository is absent.
func (s *AuthService) bestEffortUpdateLastLogin(ctx context.Context, userID int64, ip, source string) {
	if s == nil || s.userRepo == nil {
		return
	}
	err := s.userRepo.UpdateLastLogin(ctx, userID, ip)
	if err == nil {
		return
	}
	log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
}
// loginAttemptKey returns the cache key under which failed login attempts are
// counted.
//
// Once the account resolves to a user, the key is derived from the user ID,
// so attempts via username, email or phone share one counter. Otherwise it
// falls back to the trimmed, lower-cased account string.
func loginAttemptKey(account string, user *domain.User) string {
	if user == nil {
		normalized := strings.ToLower(strings.TrimSpace(account))
		return "login_attempt:account:" + normalized
	}
	return fmt.Sprintf("login_attempt:user:%d", user.ID)
}
// attemptCount interprets a cached login-attempt counter. Any value intValue
// cannot convert (including a missing/nil entry) counts as zero attempts.
func attemptCount(value interface{}) int {
	count, ok := intValue(value)
	if !ok {
		return 0
	}
	return count
}
// intValue converts a loosely-typed numeric value (for example one read back
// from a cache, possibly after a JSON round-trip) to an int.
//
// Accepted inputs are all signed and unsigned integer kinds, float32/float64
// (truncated toward zero) and json.Number holding an integer literal. It
// returns (0, false) in these cases:
//   - unsupported types, including nil;
//   - integers that do not fit the platform int, which would otherwise be
//     silently truncated on 32-bit builds;
//   - NaN, ±Inf or out-of-range floats, whose conversion to int is
//     implementation-specific in Go.
func intValue(value interface{}) (int, bool) {
	const (
		maxInt = int(^uint(0) >> 1)
		minInt = -maxInt - 1
	)
	fromFloat := func(f float64) (int, bool) {
		// Written as !(in range) so that NaN, which fails every comparison, is rejected.
		if !(f >= float64(minInt) && f < -float64(minInt)) {
			return 0, false
		}
		return int(f), true
	}
	fromUint64 := func(u uint64) (int, bool) {
		n := int(u)
		if n < 0 || uint64(n) != u {
			return 0, false
		}
		return n, true
	}

	switch v := value.(type) {
	case int:
		return v, true
	case int8:
		return int(v), true
	case int16:
		return int(v), true
	case int32:
		return int(v), true
	case int64:
		n := int(v)
		if int64(n) != v {
			return 0, false
		}
		return n, true
	case uint:
		return fromUint64(uint64(v))
	case uint8:
		return int(v), true
	case uint16:
		return int(v), true
	case uint32:
		return fromUint64(uint64(v))
	case uint64:
		return fromUint64(v)
	case float32:
		return fromFloat(float64(v))
	case float64:
		return fromFloat(v)
	case json.Number:
		n, err := v.Int64()
		if err != nil {
			return 0, false
		}
		// Re-dispatch so the int64 range check applies.
		return intValue(n)
	default:
		return 0, false
	}
}
// int64Value converts a loosely-typed numeric value (for example one read
// back from a cache, possibly after a JSON round-trip) to an int64.
//
// Accepted inputs are all signed and unsigned integer kinds, float32/float64
// (truncated toward zero) and json.Number holding an integer literal. It
// returns (0, false) in these cases:
//   - unsupported types, including nil;
//   - unsigned values above math.MaxInt64;
//   - NaN, ±Inf or out-of-range floats, whose conversion to int64 is
//     implementation-specific in Go.
func int64Value(value interface{}) (int64, bool) {
	// twoPow63 is 2^63, the smallest float64 magnitude outside the int64 range.
	const twoPow63 = float64(1 << 63)
	const maxInt64 = 1<<63 - 1

	fromFloat := func(f float64) (int64, bool) {
		// Written as !(in range) so that NaN, which fails every comparison, is rejected.
		if !(f >= -twoPow63 && f < twoPow63) {
			return 0, false
		}
		return int64(f), true
	}
	fromUint64 := func(u uint64) (int64, bool) {
		if u > maxInt64 {
			return 0, false
		}
		return int64(u), true
	}

	switch v := value.(type) {
	case int64:
		return v, true
	case int:
		return int64(v), true
	case int8:
		return int64(v), true
	case int16:
		return int64(v), true
	case int32:
		return int64(v), true
	case uint:
		return fromUint64(uint64(v))
	case uint8:
		return int64(v), true
	case uint16:
		return int64(v), true
	case uint32:
		return int64(v), true
	case uint64:
		return fromUint64(v)
	case float32:
		return fromFloat(float64(v))
	case float64:
		return fromFloat(v)
	case json.Number:
		n, err := v.Int64()
		if err != nil {
			return 0, false
		}
		return n, true
	default:
		return 0, false
	}
}
// verifyPhoneRegistration checks the SMS code supplied with a phone-based
// registration.
//
// Requests without a phone number (or a nil request) need no verification and
// pass. Otherwise it fails in two cases: SMS verification is not configured
// ("手机注册未启用", phone registration not enabled), or the code is blank
// ("手机验证码不能为空", phone code must not be empty). In all other cases it
// delegates to smsCodeSvc.VerifyCode with the "register" scene.
func (s *AuthService) verifyPhoneRegistration(ctx context.Context, req *RegisterRequest) error {
	switch {
	case req == nil || req.Phone == "":
		return nil
	case s.smsCodeSvc == nil:
		return errors.New("手机注册未启用")
	case req.PhoneCode == "":
		return errors.New("手机验证码不能为空")
	}
	return s.smsCodeSvc.VerifyCode(ctx, req.Phone, "register", req.PhoneCode)
}
// Cache namespaces and lifetimes for the short-lived OAuth artifacts stored
// by AuthService. Each TTL is passed for both TTL arguments of cache.Set.
const (
	// oauthStateCachePrefix prefixes the random state token issued before
	// redirecting to a provider.
	oauthStateCachePrefix = "oauth_state:"
	// oauthStateTTL bounds how long an issued state token can be consumed.
	oauthStateTTL = 10 * time.Minute

	// oauthHandoffCachePrefix prefixes the one-time code that carries a
	// LoginResponse from the OAuth callback to the client.
	oauthHandoffCachePrefix = "oauth_handoff:"
	// oauthHandoffTTL bounds how long a handoff code can be redeemed.
	oauthHandoffTTL = 1 * time.Minute
)
// OAuthStatePurpose records which OAuth flow a state token was issued for.
type OAuthStatePurpose string

// Known OAuthStatePurpose values.
const (
	// OAuthStatePurposeLogin marks a state issued for signing in. It is also
	// the value assumed when a stored state carries no purpose.
	OAuthStatePurposeLogin OAuthStatePurpose = "login"
	// OAuthStatePurposeBind marks a state issued on behalf of an existing user
	// (see CreateOAuthBindState) to bind a provider account.
	OAuthStatePurposeBind OAuthStatePurpose = "bind"
)
// OAuthStatePayload is the value cached under an OAuth state token by
// createOAuthStatePayload and recovered by ConsumeOAuthStatePayload.
// Field order and JSON tags define its serialized form.
type OAuthStatePayload struct {
// Purpose selects the flow; an empty value is treated as OAuthStatePurposeLogin.
Purpose OAuthStatePurpose `json:"purpose"`
// ReturnTo is the caller-supplied return location, stored whitespace-trimmed.
// NOTE(review): presumably the post-flow redirect target — confirm where it is validated.
ReturnTo string `json:"return_to"`
// UserID is the requesting user for bind states; zero (and omitted from JSON) otherwise.
UserID int64 `json:"user_id,omitempty"`
}
// generateOAuthEphemeralCode returns 32 bytes from crypto/rand encoded as
// unpadded URL-safe base64 (43 characters). It is used both for OAuth state
// tokens and for handoff codes. Any error from the random source is returned
// as-is.
func generateOAuthEphemeralCode() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(entropy[:])
	return encoded, nil
}
// CreateOAuthState issues a login-purpose OAuth state token that remembers
// the trimmed returnTo value. See createOAuthStatePayload for storage details.
func (s *AuthService) CreateOAuthState(ctx context.Context, returnTo string) (string, error) {
	payload := &OAuthStatePayload{
		ReturnTo: strings.TrimSpace(returnTo),
		Purpose:  OAuthStatePurposeLogin,
	}
	return s.createOAuthStatePayload(ctx, payload)
}
// CreateOAuthBindState issues a bind-purpose OAuth state token for userID,
// remembering the trimmed returnTo value. A non-positive userID is rejected
// before anything is stored.
func (s *AuthService) CreateOAuthBindState(ctx context.Context, userID int64, returnTo string) (string, error) {
	if userID <= 0 {
		return "", errors.New("oauth binding user is required")
	}
	payload := &OAuthStatePayload{
		Purpose:  OAuthStatePurposeBind,
		UserID:   userID,
		ReturnTo: strings.TrimSpace(returnTo),
	}
	return s.createOAuthStatePayload(ctx, payload)
}
// createOAuthStatePayload stores payload in the cache under a fresh random
// state token and returns that token.
//
// An empty Purpose is defaulted to login. This is written into the caller's
// payload, which is then cached as the same pointer. The entry uses
// oauthStateTTL for both TTL arguments of cache.Set. Errors are returned for
// a missing service/cache, a nil payload, token generation failure, or a
// failed cache write.
func (s *AuthService) createOAuthStatePayload(ctx context.Context, payload *OAuthStatePayload) (string, error) {
	switch {
	case s == nil || s.cache == nil:
		return "", errors.New("oauth state storage unavailable")
	case payload == nil:
		return "", errors.New("oauth state payload is required")
	}
	if payload.Purpose == "" {
		payload.Purpose = OAuthStatePurposeLogin
	}

	token, err := generateOAuthEphemeralCode()
	if err != nil {
		return "", err
	}
	key := oauthStateCachePrefix + token
	if err = s.cache.Set(ctx, key, payload, oauthStateTTL, oauthStateTTL); err != nil {
		return "", err
	}
	return token, nil
}
// ConsumeOAuthState consumes a state token and returns only its trimmed
// return location. Validation and one-time-use semantics come from
// ConsumeOAuthStatePayload. A nil payload without an error is guarded
// defensively and yields "".
func (s *AuthService) ConsumeOAuthState(ctx context.Context, state string) (string, error) {
	payload, err := s.ConsumeOAuthStatePayload(ctx, state)
	switch {
	case err != nil:
		return "", err
	case payload == nil:
		return "", nil
	default:
		return strings.TrimSpace(payload.ReturnTo), nil
	}
}
// ConsumeOAuthStatePayload validates a state token issued by
// createOAuthStatePayload and removes it from the cache so it cannot be
// reused.
//
// The cached value is normalized into a fresh *OAuthStatePayload:
//   - *OAuthStatePayload / OAuthStatePayload are copied (a typed nil pointer
//     is treated like a nil entry instead of panicking);
//   - map[string]interface{}, []byte and json.RawMessage are decoded as the
//     JSON form of OAuthStatePayload (decode errors are returned);
//   - a string is a legacy entry holding only the return location;
//   - nil is a legacy entry with no return location;
//   - any other type is formatted with fmt.Sprint and used as the return
//     location, as before.
//
// In every case an empty Purpose becomes OAuthStatePurposeLogin and ReturnTo
// is whitespace-trimmed. A blank or unknown/expired state yields
// "OAuth state validation failed".
func (s *AuthService) ConsumeOAuthStatePayload(ctx context.Context, state string) (*OAuthStatePayload, error) {
	if s == nil || s.cache == nil {
		return nil, errors.New("oauth state storage unavailable")
	}
	state = strings.TrimSpace(state)
	if state == "" {
		// Issued tokens are never blank; fail fast without touching the cache.
		return nil, errors.New("OAuth state validation failed")
	}

	cacheKey := oauthStateCachePrefix + state
	value, ok := s.cache.Get(ctx, cacheKey)
	if !ok {
		return nil, errors.New("OAuth state validation failed")
	}
	// Deletion is best-effort: the entry also expires after oauthStateTTL.
	// NOTE(review): Get followed by Delete is not atomic, so two concurrent
	// callbacks carrying the same state could both observe the entry.
	_ = s.cache.Delete(ctx, cacheKey)

	normalize := func(p OAuthStatePayload) *OAuthStatePayload {
		if p.Purpose == "" {
			p.Purpose = OAuthStatePurposeLogin
		}
		p.ReturnTo = strings.TrimSpace(p.ReturnTo)
		return &p
	}
	decodeJSON := func(raw []byte) (*OAuthStatePayload, error) {
		var p OAuthStatePayload
		if err := json.Unmarshal(raw, &p); err != nil {
			return nil, err
		}
		return normalize(p), nil
	}

	switch typed := value.(type) {
	case *OAuthStatePayload:
		if typed == nil {
			return normalize(OAuthStatePayload{}), nil
		}
		return normalize(*typed), nil
	case OAuthStatePayload:
		return normalize(typed), nil
	case string:
		return normalize(OAuthStatePayload{ReturnTo: typed}), nil
	case nil:
		return normalize(OAuthStatePayload{}), nil
	case map[string]interface{}:
		raw, err := json.Marshal(typed)
		if err != nil {
			return nil, err
		}
		return decodeJSON(raw)
	case []byte:
		// Previously fell through to fmt.Sprint, which produced "[123 34 ...]"
		// as the return location and dropped Purpose/UserID.
		return decodeJSON(typed)
	case json.RawMessage:
		return decodeJSON(typed)
	default:
		return normalize(OAuthStatePayload{ReturnTo: fmt.Sprint(typed)}), nil
	}
}
// CreateOAuthHandoff parks loginResp in the cache under a fresh one-time code
// and returns that code. It uses oauthHandoffTTL for both TTL arguments of
// cache.Set. The code is redeemed with ConsumeOAuthHandoff.
//
// Errors are returned for a missing service/cache, a nil loginResp, code
// generation failure, or a failed cache write.
func (s *AuthService) CreateOAuthHandoff(ctx context.Context, loginResp *LoginResponse) (string, error) {
	switch {
	case s == nil || s.cache == nil:
		return "", errors.New("oauth handoff storage unavailable")
	case loginResp == nil:
		return "", errors.New("oauth handoff payload is required")
	}

	handoff, err := generateOAuthEphemeralCode()
	if err != nil {
		return "", err
	}
	key := oauthHandoffCachePrefix + handoff
	if err = s.cache.Set(ctx, key, loginResp, oauthHandoffTTL, oauthHandoffTTL); err != nil {
		return "", err
	}
	return handoff, nil
}
// ConsumeOAuthHandoff redeems a one-time code issued by CreateOAuthHandoff,
// deleting it from the cache and returning the stored *LoginResponse.
//
// Accepted cached forms:
//   - *LoginResponse and LoginResponse;
//   - []byte, json.RawMessage and string, decoded directly as JSON;
//   - any other value, such as a map[string]interface{} from a JSON-backed
//     cache, re-marshalled to JSON and decoded.
//
// An unknown, expired, nil, or typed-nil entry yields
// "OAuth handoff code is invalid or expired". This guarantees the function
// never returns (nil, nil) or an empty response with a nil error. JSON
// errors are returned as-is.
func (s *AuthService) ConsumeOAuthHandoff(ctx context.Context, code string) (*LoginResponse, error) {
	if s == nil || s.cache == nil {
		return nil, errors.New("oauth handoff storage unavailable")
	}
	cacheKey := oauthHandoffCachePrefix + strings.TrimSpace(code)
	value, ok := s.cache.Get(ctx, cacheKey)
	if !ok {
		return nil, errors.New("OAuth handoff code is invalid or expired")
	}
	// Deletion is best-effort: the entry also expires after oauthHandoffTTL.
	_ = s.cache.Delete(ctx, cacheKey)

	var raw []byte
	switch typed := value.(type) {
	case *LoginResponse:
		if typed == nil {
			// Returning the typed nil would hand callers (nil, nil).
			return nil, errors.New("OAuth handoff code is invalid or expired")
		}
		return typed, nil
	case LoginResponse:
		resp := typed
		return &resp, nil
	case nil:
		// json.Marshal(nil) is "null", which would decode into an empty
		// LoginResponse reported as success.
		return nil, errors.New("OAuth handoff code is invalid or expired")
	case []byte:
		raw = typed
	case json.RawMessage:
		raw = typed
	case string:
		// Marshalling a string would wrap it in a JSON string literal that can
		// never decode into a struct; decode its contents instead.
		raw = []byte(typed)
	default:
		encoded, err := json.Marshal(typed)
		if err != nil {
			return nil, err
		}
		raw = encoded
	}

	var resp LoginResponse
	if err := json.Unmarshal(raw, &resp); err != nil {
		return nil, err
	}
	return &resp, nil
}