Files

1451 lines
39 KiB
Go
Raw Permalink Normal View History

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
"time"
"unicode"
"unicode/utf8"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/security"
)
// Cache key prefixes and default tuning values used by AuthService.
const (
// userInfoCachePrefix prefixes the cache key holding a user's UserInfo; the decimal user ID is appended.
userInfoCachePrefix = "auth_user_info:"
// tokenBlacklistPrefix prefixes the cache key marking a revoked token; the token's JTI is appended.
tokenBlacklistPrefix = "auth_token_blacklist:"
// defaultUserCacheTTL is how long cached UserInfo entries are kept.
defaultUserCacheTTL = 15 * time.Minute
// defaultBlacklistTTL is the blacklist lifetime used when a token's expiry is missing or already past.
defaultBlacklistTTL = time.Hour
// defaultPasswordMinLen is the minimum password length used when none is configured.
defaultPasswordMinLen = 8
)
// userRepositoryInterface is the user persistence contract AuthService depends
// on. It is declared on the consumer side so any implementation can be
// injected through NewAuthService.
type userRepositoryInterface interface {
Create(ctx context.Context, user *domain.User) error
Update(ctx context.Context, user *domain.User) error
// UpdateTOTP persists the user's TOTP-related fields (used when a recovery code is consumed).
UpdateTOTP(ctx context.Context, user *domain.User) error
Delete(ctx context.Context, id int64) error
GetByID(ctx context.Context, id int64) (*domain.User, error)
GetByUsername(ctx context.Context, username string) (*domain.User, error)
GetByEmail(ctx context.Context, email string) (*domain.User, error)
GetByPhone(ctx context.Context, phone string) (*domain.User, error)
// List and the other paged queries presumably return (page, total count) — confirm with the implementation.
List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error)
ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error)
UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error
UpdateLastLogin(ctx context.Context, id int64, ip string) error
ExistsByUsername(ctx context.Context, username string) (bool, error)
ExistsByEmail(ctx context.Context, email string) (bool, error)
ExistsByPhone(ctx context.Context, phone string) (bool, error)
Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error)
}
// userRoleRepositoryInterface is the user-role persistence contract; it is
// attached with SetRoleRepositories.
type userRoleRepositoryInterface interface {
BatchCreate(ctx context.Context, userRoles []*domain.UserRole) error
GetUserIDByRoleID(ctx context.Context, roleID int64) ([]int64, error)
}
// roleRepositoryInterface provides role lookups; it is attached with
// SetRoleRepositories.
type roleRepositoryInterface interface {
GetDefaultRoles(ctx context.Context) ([]*domain.Role, error)
GetByCode(ctx context.Context, code string) (*domain.Role, error)
}
// loginLogRepositoryInterface persists login audit records. writeLoginLog
// calls Create on a background goroutine.
type loginLogRepositoryInterface interface {
Create(ctx context.Context, loginRecord *domain.LoginLog) error
}
// anomalyRecorder records a login attempt and returns any anomaly events it
// detected. recordLoginAnomaly publishes a non-empty result as
// EventAnomalyDetected.
type anomalyRecorder interface {
RecordLogin(ctx context.Context, userID int64, ip, location, deviceFingerprint string, success bool) []security.AnomalyEvent
}
// PasswordStrengthInfo summarises the character classes present in a password,
// as computed by GetPasswordStrength.
type PasswordStrengthInfo struct {
// Score is the number of distinct classes present (upper, lower, digit, special), 0-4.
Score int `json:"score"`
// Length is the password length in runes, not bytes.
Length int `json:"length"`
HasUpper bool `json:"has_upper"`
HasLower bool `json:"has_lower"`
HasDigit bool `json:"has_digit"`
// HasSpecial is set for Unicode punctuation or symbol runes.
HasSpecial bool `json:"has_special"`
}
// RegisterRequest is the payload for AuthService.Register. Username and
// Password are required. Email and Phone are optional, but must be unique when
// given, and Phone must pass the phone-format check.
type RegisterRequest struct {
Username string `json:"username" binding:"required"`
Email string `json:"email"`
Phone string `json:"phone"`
// PhoneCode is presumably the SMS code checked by verifyPhoneRegistration — confirm.
PhoneCode string `json:"phone_code"`
Password string `json:"password" binding:"required"`
// Nickname defaults to Username when blank.
Nickname string `json:"nickname"`
}
// LoginRequest is the payload for AuthService.Login. The account identifier may
// be sent in any of Account, Username, Email or Phone; GetAccount uses the first
// non-blank one. The Device* fields are joined into the device fingerprint
// passed to anomaly detection, and, when DeviceID is set, are used to register
// the device after a successful login.
type LoginRequest struct {
Account string `json:"account"`
Username string `json:"username"`
Email string `json:"email"`
Phone string `json:"phone"`
Password string `json:"password"`
Remember bool `json:"remember"` // "remember me": tokens are issued via GenerateTokenPairWithRemember
DeviceID string `json:"device_id,omitempty"` // unique device identifier
DeviceName string `json:"device_name,omitempty"` // device name
DeviceBrowser string `json:"device_browser,omitempty"` // browser
DeviceOS string `json:"device_os,omitempty"` // operating system
}
// GetAccount returns the login identifier supplied in the request. It checks
// Account, Username, Email and Phone in that order and returns the first value
// that is non-blank after trimming surrounding whitespace (the trimmed form is
// returned). A nil request, or one with no identifier, yields "".
func (r *LoginRequest) GetAccount() string {
	if r == nil {
		return ""
	}
	candidates := [...]string{r.Account, r.Username, r.Email, r.Phone}
	for i := 0; i < len(candidates); i++ {
		value := strings.TrimSpace(candidates[i])
		if value == "" {
			continue
		}
		return value
	}
	return ""
}
// UserInfo is the public view of a user. It is returned by the auth API and
// stored in the user-info cache (key userInfoCachePrefix + user ID). Email and
// Phone are omitted from JSON when empty.
type UserInfo struct {
ID int64 `json:"id"`
Username string `json:"username"`
Email string `json:"email,omitempty"`
Phone string `json:"phone,omitempty"`
Nickname string `json:"nickname,omitempty"`
Avatar string `json:"avatar,omitempty"`
Status domain.UserStatus `json:"status"`
}
// LoginResponse is returned by successful logins and token refreshes.
type LoginResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the access-token lifetime in seconds.
ExpiresIn int64 `json:"expires_in"`
User *UserInfo `json:"user"`
}
// LogoutRequest carries the tokens that Logout should revoke. Either field may
// be empty, in which case that token is skipped.
type LogoutRequest struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// AuthService implements registration, password and OAuth login, token
// refresh/revocation, social-account binding and TOTP verification.
// Required collaborators are passed to NewAuthService. Optional ones are
// attached with the Set* methods; most features backed by an optional
// collaborator are skipped while it is nil.
type AuthService struct {
userRepo userRepositoryInterface
socialRepo repository.SocialAccountRepository
jwtManager *auth.JWT
// cache backs the user-info cache, the token blacklist and the login-attempt counters.
cache *cache.CacheManager
// passwordMinLength is used by validatePassword when no password policy has been set.
passwordMinLength int
// maxLoginAttempts is the failure count at which Login refuses further attempts for an account.
maxLoginAttempts int
// loginLockDuration is the TTL given to the failed-attempt counter on each write.
loginLockDuration time.Duration
userRoleRepo userRoleRepositoryInterface
roleRepo roleRepositoryInterface
loginLogRepo loginLogRepositoryInterface
webhookSvc *WebhookService
passwordPolicy security.PasswordPolicy
// passwordPolicySet makes validatePassword use passwordPolicy instead of the built-in strength check.
passwordPolicySet bool
anomalyDetector anomalyRecorder
// smsCodeSvc, when set, makes SMS-code login count as an available login method.
smsCodeSvc *SMSCodeService
emailActivationSvc *EmailActivationService
// emailCodeSvc, when set, makes e-mail-code login count as an available login method.
emailCodeSvc *EmailCodeService
oauthManager auth.OAuthManager
deviceService *DeviceService
}
// NewAuthService builds an AuthService around the required repositories, JWT
// manager and cache, and gives it a fresh OAuth manager.
//
// Non-positive tuning values fall back to defaults:
//   - passwordMinLength -> defaultPasswordMinLen
//   - maxLoginAttempts  -> 5
//   - loginLockDuration -> 15 minutes
//
// Optional collaborators (webhooks, roles, login logs, devices, ...) are
// attached afterwards through the Set* methods.
func NewAuthService(
	userRepo userRepositoryInterface,
	socialRepo repository.SocialAccountRepository,
	jwtManager *auth.JWT,
	cacheManager *cache.CacheManager,
	passwordMinLength int,
	maxLoginAttempts int,
	loginLockDuration time.Duration,
) *AuthService {
	svc := &AuthService{
		userRepo:          userRepo,
		socialRepo:        socialRepo,
		jwtManager:        jwtManager,
		cache:             cacheManager,
		passwordMinLength: defaultPasswordMinLen,
		maxLoginAttempts:  5,
		loginLockDuration: 15 * time.Minute,
		oauthManager:      auth.NewOAuthManager(),
	}
	if passwordMinLength > 0 {
		svc.passwordMinLength = passwordMinLength
	}
	if maxLoginAttempts > 0 {
		svc.maxLoginAttempts = maxLoginAttempts
	}
	if loginLockDuration > 0 {
		svc.loginLockDuration = loginLockDuration
	}
	return svc
}
// SetWebhookService attaches the webhook publisher; while it is nil,
// publishEvent is a no-op.
func (s *AuthService) SetWebhookService(webhookSvc *WebhookService) {
s.webhookSvc = webhookSvc
}
// SetRoleRepositories attaches the user-role and role repositories.
func (s *AuthService) SetRoleRepositories(userRoleRepo userRoleRepositoryInterface, roleRepo roleRepositoryInterface) {
s.userRoleRepo = userRoleRepo
s.roleRepo = roleRepo
}
// SetLoginLogRepository attaches the login audit repository; while it is nil,
// writeLoginLog is a no-op.
func (s *AuthService) SetLoginLogRepository(loginLogRepo loginLogRepositoryInterface) {
s.loginLogRepo = loginLogRepo
}
// SetPasswordPolicy stores the normalised policy and makes validatePassword use
// it instead of the built-in strength check.
func (s *AuthService) SetPasswordPolicy(policy security.PasswordPolicy) {
s.passwordPolicy = policy.Normalize()
s.passwordPolicySet = true
}
// SetAnomalyDetector attaches the login anomaly detector used by
// recordLoginAnomaly.
func (s *AuthService) SetAnomalyDetector(detector anomalyRecorder) {
s.anomalyDetector = detector
}
// SetDeviceService attaches the device service. It enables device registration
// after password login and trusted-device TOTP skipping in VerifyTOTP.
func (s *AuthService) SetDeviceService(svc *DeviceService) {
s.deviceService = svc
}
// SetSMSCodeService attaches the SMS code service; availableLoginMethodCount
// then counts SMS login for users with a phone number.
func (s *AuthService) SetSMSCodeService(svc *SMSCodeService) {
s.smsCodeSvc = svc
}
// sanitizeUsername turns an arbitrary display name (for example an OAuth
// nickname or an e-mail local part) into a username candidate:
//   - letters and digits are kept and lower-cased;
//   - '.', '-' and '_' are kept as-is;
//   - each run of whitespace becomes a single '_' (never at the start, and not
//     directly after an existing '_');
//   - every other rune is dropped.
//
// Leading and trailing '.', '_' and '-' are stripped, and the result is capped
// at 50 runes without ending in one of those separators. If nothing usable
// remains, "user" is returned.
func sanitizeUsername(value string) string {
	const maxRunes = 50
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return "user"
	}
	var builder strings.Builder
	lastUnderscore := false
	for _, r := range trimmed {
		switch {
		case unicode.IsLetter(r) || unicode.IsDigit(r):
			builder.WriteRune(unicode.ToLower(r))
			lastUnderscore = false
		case r == '.' || r == '-' || r == '_':
			builder.WriteRune(r)
			// An explicit underscore already separates words, so a following
			// whitespace run must not add a second one ("a_ b" -> "a_b").
			lastUnderscore = r == '_'
		case unicode.IsSpace(r):
			if !lastUnderscore && builder.Len() > 0 {
				builder.WriteByte('_')
				lastUnderscore = true
			}
		}
	}
	result := strings.Trim(builder.String(), "._-")
	if runes := []rune(result); len(runes) > maxRunes {
		// Truncation can expose a separator at the new end; strip it again so
		// the result honours the same "no trailing separator" rule.
		result = strings.TrimRight(string(runes[:maxRunes]), "._-")
	}
	if result == "" {
		return "user"
	}
	return result
}
// generateUniqueUsername derives a username from base via sanitizeUsername and
// makes it unique against the user repository.
//
// If the sanitised name is already taken, its stem is capped at 40 runes and
// the suffixes "_1" through "_1000" are tried in order. Without a service or
// repository, the sanitised name is returned unchecked. Repository errors are
// returned unchanged, and exhausting every suffix is reported as an error.
func (s *AuthService) generateUniqueUsername(ctx context.Context, base string) (string, error) {
	candidate := sanitizeUsername(base)
	if s == nil || s.userRepo == nil {
		return candidate, nil
	}
	taken, err := s.userRepo.ExistsByUsername(ctx, candidate)
	if err != nil {
		return "", err
	}
	if !taken {
		return candidate, nil
	}

	stem := candidate
	if runes := []rune(stem); len(runes) > 40 {
		stem = string(runes[:40])
	}
	for suffix := 1; suffix <= 1000; suffix++ {
		next := fmt.Sprintf("%s_%d", stem, suffix)
		taken, err = s.userRepo.ExistsByUsername(ctx, next)
		if err != nil {
			return "", err
		}
		if !taken {
			return next, nil
		}
	}
	return "", errors.New("unable to generate unique username")
}
// validatePasswordStrength checks password against the built-in rules.
//
// minLength <= 0 falls back to defaultPasswordMinLen, and length is counted in
// runes. In strict mode the password must contain upper-case, lower-case and
// digit characters. Otherwise at least two of the four character classes
// (upper, lower, digit, special) must be present.
func validatePasswordStrength(password string, minLength int, strict bool) error {
	required := minLength
	if required <= 0 {
		required = defaultPasswordMinLen
	}
	strength := GetPasswordStrength(password)
	if strength.Length < required {
		return fmt.Errorf("密码长度不能少于%d位", required)
	}
	if !strict {
		if strength.Score < 2 {
			return errors.New("密码强度不足")
		}
		return nil
	}
	if strength.HasUpper && strength.HasLower && strength.HasDigit {
		return nil
	}
	return errors.New("密码必须包含大小写字母和数字")
}
// GetPasswordStrength reports which character classes appear in password and
// its length in runes.
//
// Each rune is assigned to the first matching class in the order upper, lower,
// digit, special (Unicode punctuation or symbol). Runes in none of these
// classes (spaces, CJK letters, ...) are ignored. Score is the number of
// distinct classes present, from 0 to 4.
func GetPasswordStrength(password string) PasswordStrengthInfo {
	var info PasswordStrengthInfo
	info.Length = utf8.RuneCountInString(password)
	for _, c := range password {
		switch {
		case unicode.IsUpper(c):
			info.HasUpper = true
		case unicode.IsLower(c):
			info.HasLower = true
		case unicode.IsDigit(c):
			info.HasDigit = true
		case unicode.IsPunct(c), unicode.IsSymbol(c):
			info.HasSpecial = true
		}
	}
	for _, present := range [...]bool{info.HasUpper, info.HasLower, info.HasDigit, info.HasSpecial} {
		if present {
			info.Score++
		}
	}
	return info
}
// validatePassword applies the password policy configured via
// SetPasswordPolicy, if any. Otherwise it runs the built-in non-strict
// strength check with the service's minimum length, falling back to
// defaultPasswordMinLen when that is unset or the receiver is nil.
func (s *AuthService) validatePassword(password string) error {
	if s == nil {
		return validatePasswordStrength(password, defaultPasswordMinLen, false)
	}
	if s.passwordPolicySet {
		return s.passwordPolicy.Validate(password)
	}
	minLength := s.passwordMinLength
	if minLength <= 0 {
		minLength = defaultPasswordMinLen
	}
	return validatePasswordStrength(password, minLength, false)
}
// accessTokenTTLSeconds returns the configured access-token lifetime in whole
// seconds, or 0 when no JWT manager is available.
func (s *AuthService) accessTokenTTLSeconds() int64 {
	if s != nil && s.jwtManager != nil {
		expire := s.jwtManager.GetAccessTokenExpire()
		return int64(expire.Seconds())
	}
	return 0
}

// RefreshTokenTTLSeconds returns the configured refresh-token lifetime in whole
// seconds, or 0 when no JWT manager is available.
func (s *AuthService) RefreshTokenTTLSeconds() int64 {
	if s != nil && s.jwtManager != nil {
		expire := s.jwtManager.GetRefreshTokenExpire()
		return int64(expire.Seconds())
	}
	return 0
}
// buildUserInfo projects a domain user onto the public UserInfo view. The
// optional e-mail and phone pointers are read through domain.DerefStr. A nil
// user yields nil.
func (s *AuthService) buildUserInfo(user *domain.User) *UserInfo {
	if user == nil {
		return nil
	}
	info := new(UserInfo)
	info.ID = user.ID
	info.Username = user.Username
	info.Email = domain.DerefStr(user.Email)
	info.Phone = domain.DerefStr(user.Phone)
	info.Nickname = user.Nickname
	info.Avatar = user.Avatar
	info.Status = user.Status
	return info
}
// ensureUserActive returns nil only for a non-nil user whose status is
// UserStatusActive. Otherwise it returns an error describing why the account
// cannot be used: missing, not activated, locked, disabled, or an unknown
// status.
func (s *AuthService) ensureUserActive(user *domain.User) error {
	if user == nil {
		return errors.New("用户不存在")
	}
	if user.Status == domain.UserStatusActive {
		return nil
	}
	var reason string
	switch user.Status {
	case domain.UserStatusInactive:
		reason = "账号未激活"
	case domain.UserStatusLocked:
		reason = "账号已锁定"
	case domain.UserStatusDisabled:
		reason = "账号已禁用"
	default:
		reason = "账号状态异常"
	}
	return errors.New(reason)
}
// blacklistTokenClaims validates token with validate. If that yields claims
// with a non-blank JTI, it marks the JTI as revoked in the cache until the
// token's expiry, or for defaultBlacklistTTL when the expiry is missing or
// already past.
//
// A missing cache, a blank token, a nil validator, a validation failure, or
// claims without a JTI are silently ignored. Only the cache write error is
// returned.
func (s *AuthService) blacklistTokenClaims(ctx context.Context, token string, validate func(string) (*auth.Claims, error)) error {
	if s == nil || s.cache == nil {
		return nil
	}
	raw := strings.TrimSpace(token)
	if raw == "" || validate == nil {
		return nil
	}
	claims, err := validate(raw)
	if err != nil || claims == nil {
		return nil
	}
	jti := claims.JTI
	if strings.TrimSpace(jti) == "" {
		return nil
	}

	lifetime := defaultBlacklistTTL
	if claims.ExpiresAt != nil {
		if remaining := time.Until(claims.ExpiresAt.Time); remaining > 0 {
			lifetime = remaining
		}
	}
	return s.cache.Set(ctx, tokenBlacklistPrefix+jti, true, lifetime, lifetime)
}
// recordLoginAnomaly feeds a login attempt to the anomaly detector. If any
// anomaly events come back, it publishes them as EventAnomalyDetected together
// with the attempt's context. It does nothing without a detector or a user ID.
func (s *AuthService) recordLoginAnomaly(ctx context.Context, userID *int64, ip, location, deviceFingerprint string, success bool) {
	if s == nil || userID == nil || s.anomalyDetector == nil {
		return
	}
	uid := *userID
	detected := s.anomalyDetector.RecordLogin(ctx, uid, ip, location, deviceFingerprint, success)
	if len(detected) > 0 {
		payload := map[string]interface{}{
			"user_id":  uid,
			"ip":       ip,
			"location": location,
			"device":   deviceFingerprint,
			"events":   detected,
			"success":  success,
		}
		s.publishEvent(ctx, domain.EventAnomalyDetected, payload)
	}
}
// publishEvent publishes a webhook event asynchronously when a WebhookService
// is configured; otherwise it is a no-op.
//
// The publish runs on its own goroutine and usually outlives the request that
// triggered it (Login returns right after publishing, and an HTTP server
// cancels the request context once the handler returns). It therefore gets
// context.Background() rather than ctx, so that cancellation of the caller's
// context cannot abort the publish. This is the same approach writeLoginLog
// takes for its asynchronous write.
//
// Trade-off: request-scoped values carried by ctx are not visible to Publish.
func (s *AuthService) publishEvent(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
	if s == nil || s.webhookSvc == nil {
		return
	}
	go s.webhookSvc.Publish(context.Background(), eventType, data)
}
// writeLoginLog records a login attempt through loginLogRepo, if one is
// configured.
//
// Status is stored as 1 for success and 0 for failure, and failReason carries
// the error text of a failed attempt. userID may be nil when the account could
// not be resolved. The write runs on a separate goroutine with
// context.Background() (ctx itself is not used), so it cannot hold up or be
// cancelled by the caller. Write errors are only logged.
func (s *AuthService) writeLoginLog(
	ctx context.Context,
	userID *int64,
	loginType domain.LoginType,
	ip string,
	success bool,
	failReason string,
) {
	if s == nil || s.loginLogRepo == nil {
		return
	}
	status := 0
	if success {
		status = 1
	}
	loginRecord := &domain.LoginLog{
		UserID:     userID,
		LoginType:  int(loginType),
		IP:         ip,
		Status:     status,
		FailReason: failReason,
	}
	// Render the ID up front: formatting the *int64 itself with %v would log
	// the pointer address instead of the user ID.
	userIDText := "<nil>"
	if userID != nil {
		userIDText = fmt.Sprintf("%d", *userID)
	}
	go func() {
		if err := s.loginLogRepo.Create(context.Background(), loginRecord); err != nil {
			log.Printf("auth: write login log failed, user_id=%s login_type=%d err=%v", userIDText, loginType, err)
		}
	}()
}
// incrementFailAttempts bumps the failed-login counter stored under key and
// returns the new count.
//
// Each write passes loginLockDuration as the TTL, so (assuming Set replaces
// the expiry) the lock window slides with every failure. It returns 0 and
// stores nothing when there is no cache or key is empty. A failed cache write
// is logged, and the incremented count is still returned.
//
// NOTE(review): this is a read-modify-write, not an atomic increment, so
// concurrent failures for the same key can be under-counted. Consider an
// atomic INCR if the cache layer offers one.
func (s *AuthService) incrementFailAttempts(ctx context.Context, key string) int {
	if s == nil || s.cache == nil || key == "" {
		return 0
	}
	var previous int
	if stored, found := s.cache.Get(ctx, key); found {
		previous = attemptCount(stored)
	}
	next := previous + 1
	if err := s.cache.Set(ctx, key, next, s.loginLockDuration, s.loginLockDuration); err != nil {
		log.Printf("auth: store login attempts failed, key=%s err=%v", key, err)
	}
	return next
}
// isValidPhoneSimple reports whether phone passes the package's phone-format
// check; it simply delegates to isValidPhone.
func isValidPhoneSimple(phone string) bool {
return isValidPhone(phone)
}
// buildDeviceFingerprint builds the device fingerprint string for a login. It
// joins the non-empty device fields of req (DeviceID, DeviceName,
// DeviceBrowser, DeviceOS, in that order) with "|", using the values verbatim
// without trimming. It returns "" for a nil request or when every device field
// is empty.
func buildDeviceFingerprint(req *LoginRequest) string {
	if req == nil {
		return ""
	}
	var sb strings.Builder
	for _, field := range [...]string{req.DeviceID, req.DeviceName, req.DeviceBrowser, req.DeviceOS} {
		if field == "" {
			continue
		}
		if sb.Len() > 0 {
			sb.WriteByte('|')
		}
		sb.WriteString(field)
	}
	return sb.String()
}
// bestEffortRegisterDevice tries to create or update the device record for a
// login through DeviceService. It only acts when a device service is
// configured and the request carries a DeviceID. Any error from CreateDevice
// is deliberately discarded, so device bookkeeping never fails a login.
func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64, req *LoginRequest) {
	if s == nil || s.deviceService == nil {
		return
	}
	if req == nil || req.DeviceID == "" {
		return
	}
	deviceReq := &CreateDeviceRequest{
		DeviceID:      req.DeviceID,
		DeviceName:    req.DeviceName,
		DeviceBrowser: req.DeviceBrowser,
		DeviceOS:      req.DeviceOS,
	}
	if _, err := s.deviceService.CreateDevice(ctx, userID, deviceReq); err != nil {
		// Best effort: the result is intentionally ignored.
		return
	}
}
// cacheUserInfo stores the public UserInfo view of user under
// userInfoCachePrefix + <id> for defaultUserCacheTTL. It is a no-op without a
// cache or a user. Cache write errors are ignored, because GetUserInfo falls
// back to the repository on a miss.
func (s *AuthService) cacheUserInfo(ctx context.Context, user *domain.User) {
	if s == nil || s.cache == nil || user == nil {
		return
	}
	if info := s.buildUserInfo(user); info != nil {
		key := fmt.Sprintf("%s%d", userInfoCachePrefix, user.ID)
		_ = s.cache.Set(ctx, key, info, defaultUserCacheTTL, defaultUserCacheTTL)
	}
}
// userInfoFromCacheValue converts a value read from the user-info cache back
// into a *UserInfo. Accepted shapes:
//   - *UserInfo or UserInfo;
//   - map[string]interface{} (presumably a JSON-decoded entry from a remote
//     cache tier — confirm against cache.CacheManager), which is round-tripped
//     through encoding/json.
//
// Any other value, a nil *UserInfo, or a map that cannot be decoded reports
// ok == false, so the caller falls back to the repository.
//
// A fresh copy is always returned, so callers cannot mutate the instance held
// by the cache.
func userInfoFromCacheValue(value interface{}) (*UserInfo, bool) {
	switch typed := value.(type) {
	case *UserInfo:
		if typed == nil {
			// A typed nil would otherwise count as a hit and make
			// GetUserInfo return (nil, nil).
			return nil, false
		}
		userInfo := *typed
		return &userInfo, true
	case UserInfo:
		userInfo := typed
		return &userInfo, true
	case map[string]interface{}:
		payload, err := json.Marshal(typed)
		if err != nil {
			return nil, false
		}
		var userInfo UserInfo
		if err := json.Unmarshal(payload, &userInfo); err != nil {
			return nil, false
		}
		return &userInfo, true
	default:
		return nil, false
	}
}
// Register creates a new, immediately active user from req and returns its
// public info.
//
// Username, Email and Phone are trimmed in place on req. Validation runs in
// this order:
//  1. username and password are non-empty;
//  2. the phone format is valid (when a phone is given);
//  3. the password policy passes;
//  4. verifyPhoneRegistration passes;
//  5. username, e-mail and phone are unique (empty optional fields skipped).
//
// The nickname defaults to the username. After the user is stored,
// bestEffortAssignDefaultRoles is invoked, the user-info cache is primed and
// EventUserRegistered is published.
//
// NOTE(review): the uniqueness checks and Create are not atomic; concurrent
// registrations presumably rely on unique constraints in the repository —
// confirm.
func (s *AuthService) Register(ctx context.Context, req *RegisterRequest) (*UserInfo, error) {
	if req == nil {
		return nil, errors.New("注册请求不能为空")
	}
	if s == nil || s.userRepo == nil {
		return nil, errors.New("user repository is not configured")
	}
	req.Username = strings.TrimSpace(req.Username)
	req.Email = strings.TrimSpace(req.Email)
	req.Phone = strings.TrimSpace(req.Phone)

	switch {
	case req.Username == "":
		return nil, errors.New("用户名不能为空")
	case req.Password == "":
		return nil, errors.New("密码不能为空")
	case req.Phone != "" && !isValidPhoneSimple(req.Phone):
		return nil, errors.New("手机号格式不正确")
	}
	if err := s.validatePassword(req.Password); err != nil {
		return nil, err
	}
	if err := s.verifyPhoneRegistration(ctx, req); err != nil {
		return nil, err
	}

	// Uniqueness checks in a fixed order; empty optional fields are skipped
	// (the username is known to be non-empty at this point).
	uniqueness := []struct {
		value    string
		exists   func(context.Context, string) (bool, error)
		conflict string
	}{
		{req.Username, s.userRepo.ExistsByUsername, "用户名已存在"},
		{req.Email, s.userRepo.ExistsByEmail, "邮箱已存在"},
		{req.Phone, s.userRepo.ExistsByPhone, "手机号已存在"},
	}
	for _, check := range uniqueness {
		if check.value == "" {
			continue
		}
		taken, err := check.exists(ctx, check.value)
		if err != nil {
			return nil, err
		}
		if taken {
			return nil, errors.New(check.conflict)
		}
	}

	hashedPassword, err := auth.HashPassword(req.Password)
	if err != nil {
		return nil, err
	}
	nickname := req.Username
	if trimmed := strings.TrimSpace(req.Nickname); trimmed != "" {
		nickname = trimmed
	}
	user := &domain.User{
		Username: req.Username,
		Email:    domain.StrPtr(req.Email),
		Phone:    domain.StrPtr(req.Phone),
		Password: hashedPassword,
		Nickname: nickname,
		Status:   domain.UserStatusActive,
	}
	if err := s.userRepo.Create(ctx, user); err != nil {
		return nil, err
	}

	s.bestEffortAssignDefaultRoles(ctx, user.ID, "register")
	s.cacheUserInfo(ctx, user)
	info := s.buildUserInfo(user)
	s.publishEvent(ctx, domain.EventUserRegistered, info)
	return info, nil
}
// Login authenticates req with a password and, on success, returns a freshly
// issued token pair plus the user's public info. ip is recorded in login logs,
// anomaly detection and webhook payloads.
//
// Flow:
//  1. Validate the request and resolve the account (GetAccount, findUserForLogin).
//  2. Refuse immediately when the failed-attempt counter for the account has
//     reached maxLoginAttempts.
//  3. Verify the password. Unknown accounts and wrong passwords both return
//     "账号或密码错误" and increment the failure counter.
//  4. Only after the password has been verified, check the account status.
//     Callers without the password therefore cannot learn whether an account
//     exists or is disabled, locked or inactive.
//  5. Issue tokens first, and only then record the success (reset the
//     counter, update last login, write the login log, feed the anomaly
//     detector, register the device, publish EventUserLogin). A token
//     generation failure is thus never reported as a successful login.
func (s *AuthService) Login(ctx context.Context, req *LoginRequest, ip string) (*LoginResponse, error) {
	if req == nil {
		return nil, errors.New("登录请求不能为空")
	}
	if s == nil || s.userRepo == nil || s.jwtManager == nil {
		return nil, errors.New("auth service is not fully configured")
	}
	account := req.GetAccount()
	if account == "" {
		return nil, errors.New("账号不能为空")
	}
	if strings.TrimSpace(req.Password) == "" {
		return nil, errors.New("密码不能为空")
	}
	deviceFingerprint := buildDeviceFingerprint(req)

	user, err := s.findUserForLogin(ctx, account)
	if err != nil && !isUserNotFoundError(err) {
		s.writeLoginLog(ctx, nil, domain.LoginTypePassword, ip, false, err.Error())
		return nil, err
	}

	attemptKey := loginAttemptKey(account, user)
	if s.cache != nil {
		if value, ok := s.cache.Get(ctx, attemptKey); ok && attemptCount(value) >= s.maxLoginAttempts {
			lockErr := errors.New("账号已锁定,请稍后再试")
			// Attribute the lock-out to the user when the account is known.
			var lockedUserID *int64
			if user != nil {
				lockedUserID = &user.ID
			}
			s.writeLoginLog(ctx, lockedUserID, domain.LoginTypePassword, ip, false, lockErr.Error())
			return nil, lockErr
		}
	}

	if user == nil {
		// NOTE(review): this path skips the password-hash comparison, so its
		// latency may differ from the wrong-password path.
		s.incrementFailAttempts(ctx, attemptKey)
		s.writeLoginLog(ctx, nil, domain.LoginTypePassword, ip, false, "用户不存在")
		return nil, errors.New("账号或密码错误")
	}

	// The password is checked before the account status so that the status is
	// only revealed to callers who know the password.
	if !auth.VerifyPassword(user.Password, req.Password) {
		failCount := s.incrementFailAttempts(ctx, attemptKey)
		failErr := errors.New("账号或密码错误")
		if failCount >= s.maxLoginAttempts {
			s.publishEvent(ctx, domain.EventUserLocked, map[string]interface{}{
				"user_id":  user.ID,
				"username": user.Username,
				"ip":       ip,
			})
		}
		s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, false, failErr.Error())
		s.recordLoginAnomaly(ctx, &user.ID, ip, "", deviceFingerprint, false)
		s.publishEvent(ctx, domain.EventLoginFailed, map[string]interface{}{
			"user_id":  user.ID,
			"username": user.Username,
			"ip":       ip,
		})
		return nil, failErr
	}

	if err := s.ensureUserActive(user); err != nil {
		s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, false, err.Error())
		s.recordLoginAnomaly(ctx, &user.ID, ip, "", deviceFingerprint, false)
		return nil, err
	}

	// Issue tokens before recording success, so a failure here is not
	// reported as a successful login. This also refreshes the cached user info.
	resp, err := s.generateLoginResponse(ctx, user, req.Remember)
	if err != nil {
		s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, false, err.Error())
		return nil, err
	}

	if s.cache != nil {
		// Clearing the failure counter is best effort.
		_ = s.cache.Delete(ctx, attemptKey)
	}
	s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "password")
	s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, true, "")
	s.recordLoginAnomaly(ctx, &user.ID, ip, "", deviceFingerprint, true)
	s.bestEffortRegisterDevice(ctx, user.ID, req)
	s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
		"user_id":  user.ID,
		"username": user.Username,
		"ip":       ip,
		"method":   "password",
	})
	return resp, nil
}
// RefreshToken exchanges a valid, non-revoked refresh token for a new token
// pair. The user is reloaded and must still be active, and the Remember flag
// from the presented token's claims is carried over.
//
// NOTE(review): the presented refresh token is not revoked after use, so it
// remains valid until it expires. Consider rotation (blacklisting its JTI
// here) if replay is a concern.
func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
	if s == nil || s.jwtManager == nil || s.userRepo == nil {
		return nil, errors.New("auth service is not fully configured")
	}
	presented := strings.TrimSpace(refreshToken)
	claims, err := s.jwtManager.ValidateRefreshToken(presented)
	if err != nil {
		return nil, err
	}
	if revoked := s.IsTokenBlacklisted(ctx, claims.JTI); revoked {
		return nil, errors.New("refresh token has been revoked")
	}

	user, err := s.userRepo.GetByID(ctx, claims.UserID)
	if err != nil {
		return nil, err
	}
	if err = s.ensureUserActive(user); err != nil {
		return nil, err
	}
	return s.generateLoginResponse(ctx, user, claims.Remember)
}
// GetUserInfo returns the public info for userID. It serves a usable entry
// from the user-info cache when one exists; otherwise it loads the user from
// the repository and re-caches it. Repository errors are returned unchanged.
func (s *AuthService) GetUserInfo(ctx context.Context, userID int64) (*UserInfo, error) {
	if s == nil || s.userRepo == nil {
		return nil, errors.New("user repository is not configured")
	}
	if s.cache != nil {
		key := fmt.Sprintf("%s%d", userInfoCachePrefix, userID)
		if cached, found := s.cache.Get(ctx, key); found {
			if info, ok := userInfoFromCacheValue(cached); ok {
				return info, nil
			}
		}
	}

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, err
	}
	s.cacheUserInfo(ctx, user)
	return s.buildUserInfo(user), nil
}
// Logout revokes the access and refresh tokens in req by blacklisting their
// JTIs via blacklistTokenClaims. It publishes EventUserLogout when a non-blank
// username is given.
//
// Logout is best effort and always returns nil. Tokens that fail validation
// and blacklist write errors are ignored, and nothing is revoked without a
// JWT manager.
func (s *AuthService) Logout(ctx context.Context, username string, req *LogoutRequest) error {
	if s == nil || req == nil {
		return nil
	}
	if jwtManager := s.jwtManager; jwtManager != nil {
		_ = s.blacklistTokenClaims(ctx, req.AccessToken, func(token string) (*auth.Claims, error) {
			return jwtManager.ValidateAccessToken(token)
		})
		_ = s.blacklistTokenClaims(ctx, req.RefreshToken, func(token string) (*auth.Claims, error) {
			return jwtManager.ValidateRefreshToken(token)
		})
	}
	if name := strings.TrimSpace(username); name != "" {
		s.publishEvent(ctx, domain.EventUserLogout, map[string]interface{}{
			"username": name,
		})
	}
	return nil
}
// IsTokenBlacklisted reports whether the cache holds a blacklist entry for
// jti (trimmed), as written by blacklistTokenClaims. It is always false when
// there is no cache or the JTI is blank.
func (s *AuthService) IsTokenBlacklisted(ctx context.Context, jti string) bool {
	if s == nil || s.cache == nil {
		return false
	}
	if id := strings.TrimSpace(jti); id != "" {
		_, found := s.cache.Get(ctx, tokenBlacklistPrefix+id)
		return found
	}
	return false
}
// OAuthLogin returns the provider's authorization URL for state. provider is
// trimmed and lower-cased before it is looked up.
func (s *AuthService) OAuthLogin(ctx context.Context, provider, state string) (string, error) {
	if s == nil || s.oauthManager == nil {
		return "", errors.New("oauth manager is not configured")
	}
	normalized := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
	return s.oauthManager.GetAuthURL(normalized, state)
}
// OAuthCallback completes an OAuth login. It exchanges code for a provider
// token, fetches the provider profile and resolves it to a local user:
//   - If a social account is already bound to (provider, open_id), its user is
//     logged in and the stored profile fields are refreshed (update errors
//     are only logged).
//   - Otherwise an existing user with the same e-mail is reused, or a new
//     active user is created. Its unique username is derived from the
//     nickname, the e-mail local part, or "<provider>_<open_id>". A social
//     account is then created to bind the profile to that user.
//
// The resolved user must be active. On success the last-login time, user
// cache, login log, anomaly detector and EventUserLogin webhook are updated,
// and a non-remembered token pair is returned.
//
// A profile with an empty open_id is rejected. It cannot identify an account,
// and looking it up could match an unrelated binding stored with an empty
// open_id.
//
// NOTE(review): linking by e-mail trusts the provider's e-mail claim. Unless
// the provider guarantees verified addresses, this can let a third-party
// account take over a local one — confirm per provider.
func (s *AuthService) OAuthCallback(ctx context.Context, provider, code string) (*LoginResponse, error) {
	if s == nil || s.oauthManager == nil || s.socialRepo == nil || s.userRepo == nil {
		return nil, errors.New("oauth login is not fully configured")
	}
	oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
	token, err := s.oauthManager.ExchangeCode(oauthProvider, strings.TrimSpace(code))
	if err != nil {
		return nil, err
	}
	oauthUser, err := s.oauthManager.GetUserInfo(oauthProvider, token)
	if err != nil {
		return nil, err
	}
	if oauthUser == nil {
		return nil, errors.New("oauth user info is empty")
	}
	if strings.TrimSpace(oauthUser.OpenID) == "" {
		return nil, errors.New("oauth user open id is empty")
	}
	socialAccount, err := s.socialRepo.GetByProviderAndOpenID(ctx, string(oauthProvider), oauthUser.OpenID)
	if err != nil {
		return nil, err
	}
	var user *domain.User
	if socialAccount != nil {
		// Returning login: refresh the stored provider profile.
		user, err = s.userRepo.GetByID(ctx, socialAccount.UserID)
		if err != nil {
			return nil, err
		}
		socialAccount.UnionID = oauthUser.UnionID
		socialAccount.Nickname = oauthUser.Nickname
		socialAccount.Avatar = oauthUser.Avatar
		socialAccount.Gender = oauthUser.Gender
		socialAccount.Email = oauthUser.Email
		socialAccount.Phone = oauthUser.Phone
		socialAccount.Status = domain.SocialAccountStatusActive
		if oauthUser.Extra != nil {
			socialAccount.Extra = oauthUser.Extra
		}
		if err := s.socialRepo.Update(ctx, socialAccount); err != nil {
			// Profile refresh is best effort; the login still proceeds.
			log.Printf("auth: update social account failed, provider=%s open_id=%s err=%v", oauthProvider, oauthUser.OpenID, err)
		}
	} else {
		// First login with this provider identity: reuse a user with the same
		// e-mail if there is one, otherwise create a new user.
		if strings.TrimSpace(oauthUser.Email) != "" {
			user, err = s.userRepo.GetByEmail(ctx, strings.TrimSpace(oauthUser.Email))
			if err != nil {
				if !isUserNotFoundError(err) {
					return nil, err
				}
				user = nil
			}
		}
		if user == nil {
			baseUsername := oauthUser.Nickname
			if baseUsername == "" && oauthUser.Email != "" {
				baseUsername = strings.Split(strings.TrimSpace(oauthUser.Email), "@")[0]
			}
			if baseUsername == "" {
				baseUsername = string(oauthProvider) + "_" + oauthUser.OpenID
			}
			username, err := s.generateUniqueUsername(ctx, baseUsername)
			if err != nil {
				return nil, err
			}
			user = &domain.User{
				Username: username,
				Email:    domain.StrPtr(strings.TrimSpace(oauthUser.Email)),
				Phone:    domain.StrPtr(strings.TrimSpace(oauthUser.Phone)),
				Nickname: strings.TrimSpace(oauthUser.Nickname),
				Avatar:   strings.TrimSpace(oauthUser.Avatar),
				Status:   domain.UserStatusActive,
			}
			if user.Nickname == "" {
				user.Nickname = user.Username
			}
			if err := s.userRepo.Create(ctx, user); err != nil {
				return nil, err
			}
			s.bestEffortAssignDefaultRoles(ctx, user.ID, "oauth")
			s.publishEvent(ctx, domain.EventUserRegistered, s.buildUserInfo(user))
		}
		socialAccount = &domain.SocialAccount{
			UserID:   user.ID,
			Provider: string(oauthProvider),
			OpenID:   oauthUser.OpenID,
			UnionID:  oauthUser.UnionID,
			Nickname: oauthUser.Nickname,
			Avatar:   oauthUser.Avatar,
			Gender:   oauthUser.Gender,
			Email:    oauthUser.Email,
			Phone:    oauthUser.Phone,
			Status:   domain.SocialAccountStatusActive,
		}
		if oauthUser.Extra != nil {
			socialAccount.Extra = oauthUser.Extra
		}
		if err := s.socialRepo.Create(ctx, socialAccount); err != nil {
			return nil, err
		}
	}
	if err := s.ensureUserActive(user); err != nil {
		return nil, err
	}
	s.bestEffortUpdateLastLogin(ctx, user.ID, "", "oauth")
	s.cacheUserInfo(ctx, user)
	s.writeLoginLog(ctx, &user.ID, domain.LoginTypeOAuth, "", true, "")
	s.recordLoginAnomaly(ctx, &user.ID, "", "", "", true)
	s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
		"user_id":  user.ID,
		"username": user.Username,
		"method":   "oauth",
		"provider": string(oauthProvider),
	})
	return s.generateLoginResponseWithoutRemember(ctx, user)
}
// StartSocialAccountBinding begins binding an OAuth provider to an existing,
// active user.
//
// The user must re-authenticate with currentPassword or a TOTP/recovery code
// (verifySensitiveAction) and must not already have an account bound for
// provider. A bind state is then created for userID/returnTo via
// CreateOAuthBindState. The provider's authorization URL is returned together
// with that state.
func (s *AuthService) StartSocialAccountBinding(
	ctx context.Context,
	userID int64,
	provider string,
	returnTo string,
	currentPassword string,
	totpCode string,
) (string, string, error) {
	if s == nil || s.oauthManager == nil || s.socialRepo == nil || s.userRepo == nil {
		return "", "", errors.New("social account binding is not fully configured")
	}
	providerName := strings.ToLower(strings.TrimSpace(provider))

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return "", "", err
	}
	if err = s.ensureUserActive(user); err != nil {
		return "", "", err
	}
	if err = s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
		return "", "", err
	}

	bound, err := s.GetSocialAccounts(ctx, userID)
	if err != nil {
		return "", "", err
	}
	if findSocialAccountByProvider(bound, providerName) != nil {
		return "", "", auth.ErrOAuthAlreadyBound
	}

	state, err := s.CreateOAuthBindState(ctx, userID, returnTo)
	if err != nil {
		return "", "", err
	}
	authURL, err := s.OAuthLogin(ctx, providerName, state)
	if err != nil {
		return "", "", err
	}
	return authURL, state, nil
}
// OAuthBindCallback binds the provider identity obtained from code to userID.
// It checks that the user exists and is active, exchanges code with the
// provider, and fetches the profile. It then creates or refreshes the binding
// via upsertOAuthSocialAccount and returns the bound account's public info.
func (s *AuthService) OAuthBindCallback(ctx context.Context, userID int64, provider, code string) (*domain.SocialAccountInfo, error) {
	if s == nil || s.oauthManager == nil || s.socialRepo == nil || s.userRepo == nil {
		return nil, errors.New("social account binding is not fully configured")
	}
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, err
	}
	if err = s.ensureUserActive(user); err != nil {
		return nil, err
	}

	providerID := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
	providerToken, err := s.oauthManager.ExchangeCode(providerID, strings.TrimSpace(code))
	if err != nil {
		return nil, err
	}
	profile, err := s.oauthManager.GetUserInfo(providerID, providerToken)
	if err != nil {
		return nil, err
	}
	if profile == nil {
		return nil, errors.New("oauth user info is empty")
	}

	bound, err := s.upsertOAuthSocialAccount(ctx, userID, providerID, profile)
	if err != nil {
		return nil, err
	}
	return bound.ToInfo(), nil
}
// upsertOAuthSocialAccount binds the OAuth profile oauthUser from provider to
// userID. It either creates a new social account or refreshes an existing
// binding owned by the same user.
//
// It fails when:
//   - userID already has a different identity bound for provider;
//   - the identity is bound to another user (auth.ErrOAuthAlreadyBound);
//   - the profile has an empty open_id.
//
// The last check ensures no binding with an empty identifier is ever looked up
// or stored; such a binding could otherwise be matched by any later profile
// that also lacks an open_id.
func (s *AuthService) upsertOAuthSocialAccount(
	ctx context.Context,
	userID int64,
	provider auth.OAuthProvider,
	oauthUser *auth.OAuthUser,
) (*domain.SocialAccount, error) {
	if s == nil || s.socialRepo == nil || s.userRepo == nil {
		return nil, errors.New("social account binding is not configured")
	}
	if oauthUser == nil {
		return nil, errors.New("oauth user info is empty")
	}
	normalizedProvider := strings.ToLower(strings.TrimSpace(string(provider)))
	openID := strings.TrimSpace(oauthUser.OpenID)
	if openID == "" {
		return nil, errors.New("oauth user open id is empty")
	}
	accounts, err := s.GetSocialAccounts(ctx, userID)
	if err != nil {
		return nil, err
	}
	if currentProviderBinding := findSocialAccountByProvider(accounts, normalizedProvider); currentProviderBinding != nil &&
		!strings.EqualFold(strings.TrimSpace(currentProviderBinding.OpenID), openID) {
		return nil, errors.New("provider already bound to current account")
	}
	existing, err := s.socialRepo.GetByProviderAndOpenID(ctx, normalizedProvider, openID)
	if err != nil {
		return nil, err
	}
	if existing != nil {
		if existing.UserID != userID {
			return nil, auth.ErrOAuthAlreadyBound
		}
		// Same user re-binding the same identity: refresh the stored profile.
		existing.UnionID = oauthUser.UnionID
		existing.Nickname = oauthUser.Nickname
		existing.Avatar = oauthUser.Avatar
		existing.Gender = oauthUser.Gender
		existing.Email = oauthUser.Email
		existing.Phone = oauthUser.Phone
		existing.Status = domain.SocialAccountStatusActive
		if oauthUser.Extra != nil {
			existing.Extra = oauthUser.Extra
		}
		if err := s.socialRepo.Update(ctx, existing); err != nil {
			return nil, err
		}
		return existing, nil
	}
	account := &domain.SocialAccount{
		UserID:   userID,
		Provider: normalizedProvider,
		OpenID:   openID,
		UnionID:  oauthUser.UnionID,
		Nickname: oauthUser.Nickname,
		Avatar:   oauthUser.Avatar,
		Gender:   oauthUser.Gender,
		Email:    oauthUser.Email,
		Phone:    oauthUser.Phone,
		Status:   domain.SocialAccountStatusActive,
	}
	if oauthUser.Extra != nil {
		account.Extra = oauthUser.Extra
	}
	if err := s.socialRepo.Create(ctx, account); err != nil {
		return nil, err
	}
	return account, nil
}
// verifySensitiveAction re-authenticates user before a sensitive operation.
//
// A non-blank currentPassword takes precedence and must match the stored
// password hash. Otherwise a non-blank totpCode is checked as a TOTP or
// recovery code; a matched recovery code is consumed. A user with neither a
// password nor TOTP configured cannot perform sensitive actions.
//
// Only the emptiness test uses the trimmed password. The password itself is
// verified exactly as typed, because Register hashes and Login verifies the
// raw, untrimmed value; trimming here would reject valid passwords that begin
// or end with whitespace.
func (s *AuthService) verifySensitiveAction(
	ctx context.Context,
	user *domain.User,
	currentPassword string,
	totpCode string,
) error {
	if user == nil {
		return errors.New("user is required")
	}
	code := strings.TrimSpace(totpCode)
	hasPassword := strings.TrimSpace(user.Password) != ""
	hasTOTP := user.TOTPEnabled && strings.TrimSpace(user.TOTPSecret) != ""
	// A user with neither a password nor TOTP has no way to re-authenticate.
	if !hasPassword && !hasTOTP {
		return errors.New("请先设置密码或启用两步验证")
	}
	if strings.TrimSpace(currentPassword) != "" {
		if !hasPassword || !auth.VerifyPassword(user.Password, currentPassword) {
			return errors.New("当前密码不正确")
		}
		return nil
	}
	if code != "" {
		if !hasTOTP {
			return errors.New("TOTP verification is not available")
		}
		return s.verifyTOTPCodeOrRecoveryCode(ctx, user, code)
	}
	return errors.New("password or TOTP verification is required")
}
// verifyTOTPCodeOrRecoveryCode accepts either a current TOTP code for user or
// one of the user's remaining recovery codes.
//
// Recovery codes are stored in user.TOTPRecoveryCodes as a JSON array of
// (hashed) codes. A matching code is removed from that list, and the updated
// list is persisted via userRepo.UpdateTOTP, whose error is returned. A list
// that fails to parse leaves no recovery codes to match.
func (s *AuthService) verifyTOTPCodeOrRecoveryCode(ctx context.Context, user *domain.User, code string) error {
	if user == nil {
		return errors.New("user is required")
	}
	if !user.TOTPEnabled || strings.TrimSpace(user.TOTPSecret) == "" {
		return errors.New("TOTP verification is not available")
	}
	if auth.NewTOTPManager().ValidateCode(user.TOTPSecret, code) {
		return nil
	}

	var stored []string
	if strings.TrimSpace(user.TOTPRecoveryCodes) != "" {
		// Best effort: a corrupt list simply yields nothing to match against.
		_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &stored)
	}
	pos, ok := auth.VerifyRecoveryCode(code, stored)
	if !ok {
		return errors.New("TOTP code or recovery code is invalid")
	}

	// Recovery codes are single-use: drop the matched entry and persist.
	remaining := append(stored[:pos], stored[pos+1:]...)
	encoded, err := json.Marshal(remaining)
	if err != nil {
		return err
	}
	user.TOTPRecoveryCodes = string(encoded)
	return s.userRepo.UpdateTOTP(ctx, user)
}
// VerifyTOTP verifies a TOTP (or recovery) code for userID. Verification is
// skipped when all of the following hold:
//   - deviceID names a device of that user;
//   - DeviceService reports the device as trusted;
//   - the trust has not expired (a nil TrustExpiresAt means no expiry).
//
// A failed user lookup is reported as "用户不存在". A device lookup that
// returns neither a device nor an error is treated as "not trusted" rather
// than dereferenced.
//
// NOTE(review): deviceID is supplied by the client; skipping TOTP on it is only
// as strong as the guarantee that device IDs cannot be guessed or replayed.
func (s *AuthService) VerifyTOTP(ctx context.Context, userID int64, code, deviceID string) error {
	if s == nil || s.userRepo == nil {
		return errors.New("auth service is not fully configured")
	}
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return errors.New("用户不存在")
	}
	// Trusted, unexpired devices may skip the TOTP step.
	if deviceID != "" && s.deviceService != nil {
		device, err := s.deviceService.GetDeviceByDeviceID(ctx, userID, deviceID)
		if err == nil && device != nil && device.IsTrusted {
			if device.TrustExpiresAt == nil || device.TrustExpiresAt.After(time.Now()) {
				return nil
			}
		}
	}
	return s.verifyTOTPCodeOrRecoveryCode(ctx, user, code)
}
// findSocialAccountByProvider returns the first non-nil account whose Provider
// matches provider case-insensitively, ignoring surrounding whitespace on both
// sides. It returns nil when there is no match.
func findSocialAccountByProvider(accounts []*domain.SocialAccount, provider string) *domain.SocialAccount {
	want := strings.ToLower(strings.TrimSpace(provider))
	for i := range accounts {
		candidate := accounts[i]
		if candidate != nil && strings.EqualFold(strings.TrimSpace(candidate.Provider), want) {
			return candidate
		}
	}
	return nil
}
// availableLoginMethodCount counts the login methods available to user:
//   - a non-blank password;
//   - e-mail code login (emailCodeSvc configured and an e-mail on file);
//   - SMS code login (smsCodeSvc configured and a phone on file);
//   - each active social account whose provider does not match excludeProvider.
//
// A nil user has no methods.
func (s *AuthService) availableLoginMethodCount(
	user *domain.User,
	accounts []*domain.SocialAccount,
	excludeProvider string,
) int {
	if user == nil {
		return 0
	}
	hasText := func(v string) bool { return strings.TrimSpace(v) != "" }

	methods := 0
	if hasText(user.Password) {
		methods++
	}
	if s.emailCodeSvc != nil && hasText(domain.DerefStr(user.Email)) {
		methods++
	}
	if s.smsCodeSvc != nil && hasText(domain.DerefStr(user.Phone)) {
		methods++
	}

	skip := strings.ToLower(strings.TrimSpace(excludeProvider))
	for _, acct := range accounts {
		if acct == nil || acct.Status != domain.SocialAccountStatusActive ||
			strings.EqualFold(strings.TrimSpace(acct.Provider), skip) {
			continue
		}
		methods++
	}
	return methods
}
// generateLoginResponse mints an access/refresh token pair for user, caches
// the user's info, and packages both into a LoginResponse.
//
// remember selects the JWT manager call: GenerateTokenPairWithRemember when
// true, GenerateTokenPair otherwise. An error is returned, and nothing is
// cached, when any of these hold:
//   - the service or its JWT manager is nil;
//   - user is nil;
//   - token generation fails.
func (s *AuthService) generateLoginResponse(ctx context.Context, user *domain.User, remember bool) (*LoginResponse, error) {
	switch {
	case s == nil || s.jwtManager == nil:
		return nil, errors.New("jwt manager is not configured")
	case user == nil:
		return nil, errors.New("user is required")
	}

	var (
		access, refresh string
		err             error
	)
	if !remember {
		access, refresh, err = s.jwtManager.GenerateTokenPair(user.ID, user.Username)
	} else {
		access, refresh, err = s.jwtManager.GenerateTokenPairWithRemember(user.ID, user.Username, remember)
	}
	if err != nil {
		return nil, err
	}

	s.cacheUserInfo(ctx, user)
	return &LoginResponse{
		AccessToken:  access,
		RefreshToken: refresh,
		ExpiresIn:    s.accessTokenTTLSeconds(),
		User:         s.buildUserInfo(user),
	}, nil
}
// generateLoginResponseWithoutRemember builds a login response without the
// "remember me" option; it is generateLoginResponse with remember=false.
func (s *AuthService) generateLoginResponseWithoutRemember(ctx context.Context, user *domain.User) (*LoginResponse, error) {
	const rememberMe = false
	return s.generateLoginResponse(ctx, user, rememberMe)
}
// BindSocialAccount links the external identity (provider, openID) to the
// active user identified by userID.
//
// provider is trimmed and lower-cased; openID is trimmed and otherwise
// compared exactly. Binding a pair the user already holds is a no-op that
// returns nil. An error is returned when:
//   - the service, its social repository, or its user repository is nil;
//   - the user cannot be loaded or ensureUserActive rejects it;
//   - provider or openID is blank after trimming;
//   - the user already has a different identity bound for this provider;
//   - the pair is already bound to another user (auth.ErrOAuthAlreadyBound).
func (s *AuthService) BindSocialAccount(ctx context.Context, userID int64, provider, openID string) error {
	if s == nil || s.socialRepo == nil || s.userRepo == nil {
		return errors.New("social account binding is not configured")
	}
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return err
	}
	if err := s.ensureUserActive(user); err != nil {
		return err
	}
	normalizedProvider := strings.ToLower(strings.TrimSpace(provider))
	normalizedOpenID := strings.TrimSpace(openID)
	if normalizedProvider == "" || normalizedOpenID == "" {
		return errors.New("provider and open_id are required")
	}
	accounts, err := s.GetSocialAccounts(ctx, userID)
	if err != nil {
		return err
	}
	if current := findSocialAccountByProvider(accounts, normalizedProvider); current != nil {
		// OpenIDs are opaque provider identifiers and may be case-sensitive,
		// so compare them exactly. A case-insensitive match would let a case
		// variant of the bound ID past this guard. With a case-sensitive
		// repository lookup, that would create a second binding for the same
		// provider.
		if strings.TrimSpace(current.OpenID) != normalizedOpenID {
			return errors.New("provider already bound to current account")
		}
		// This exact identity is already bound to this user.
		return nil
	}
	existing, err := s.socialRepo.GetByProviderAndOpenID(ctx, normalizedProvider, normalizedOpenID)
	if err != nil {
		return err
	}
	if existing != nil {
		if existing.UserID == userID {
			return nil
		}
		return auth.ErrOAuthAlreadyBound
	}
	return s.socialRepo.Create(ctx, &domain.SocialAccount{
		UserID:   userID,
		Provider: normalizedProvider,
		OpenID:   normalizedOpenID,
		Status:   domain.SocialAccountStatusActive,
	})
}
// UnbindSocialAccount removes the user's social-account binding for provider
// (trimmed and lower-cased).
//
// Checks run in this order; the first failure is returned:
//   - the service and its social and user repositories are configured;
//   - the user loads and passes ensureUserActive;
//   - the user has a binding for provider (otherwise auth.ErrOAuthNotFound);
//   - the sensitive-action check on currentPassword / totpCode passes;
//   - at least one other login method would remain after removal.
//
// On success the repository deletes the provider's binding(s) for the user.
func (s *AuthService) UnbindSocialAccount(ctx context.Context, userID int64, provider, currentPassword, totpCode string) error {
	if s == nil || s.socialRepo == nil || s.userRepo == nil {
		return errors.New("social account binding is not configured")
	}

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return err
	}
	if err = s.ensureUserActive(user); err != nil {
		return err
	}

	linked, err := s.GetSocialAccounts(ctx, userID)
	if err != nil {
		return err
	}

	target := strings.ToLower(strings.TrimSpace(provider))
	if findSocialAccountByProvider(linked, target) == nil {
		return auth.ErrOAuthNotFound
	}
	if err = s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
		return err
	}
	if remaining := s.availableLoginMethodCount(user, linked, target); remaining == 0 {
		return errors.New("at least one login method must remain after unbinding")
	}
	return s.socialRepo.DeleteByProviderAndUserID(ctx, target, userID)
}
// GetSocialAccounts lists the social accounts bound to userID.
//
// On success the returned slice is never nil. An unconfigured service (nil
// receiver or nil social repository) and a nil repository result both
// produce an empty, non-nil slice. Repository errors are returned as-is
// with a nil slice.
func (s *AuthService) GetSocialAccounts(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
	empty := []*domain.SocialAccount{}
	if s == nil || s.socialRepo == nil {
		return empty, nil
	}

	accounts, err := s.socialRepo.GetByUserID(ctx, userID)
	switch {
	case err != nil:
		return nil, err
	case accounts == nil:
		return empty, nil
	default:
		return accounts, nil
	}
}
// GetEnabledOAuthProviders returns the providers the OAuth manager reports
// as enabled. It always returns a non-nil slice: an empty one when the
// service or its OAuth manager is nil, or when the manager returns nil.
func (s *AuthService) GetEnabledOAuthProviders() []auth.OAuthProviderInfo {
	var providers []auth.OAuthProviderInfo
	if s != nil && s.oauthManager != nil {
		providers = s.oauthManager.GetEnabledProviders()
	}
	if providers == nil {
		providers = []auth.OAuthProviderInfo{}
	}
	return providers
}
// LoginByCode signs a user in with a phone number and an SMS verification
// code (purpose "login"). It returns a token pair without "remember me".
//
// Flow:
//  1. Trim the phone number and require it to be non-empty.
//  2. Verify the trimmed code with the SMS code service.
//  3. Load the user by phone and require the account to be active.
//  4. Update the last-login record (best effort) and issue tokens.
//  5. Record the success: login log, anomaly record, EventUserLogin event.
//
// Every failure after the phone check writes a failed login-log entry
// before returning.
//
// Success is recorded and published only after tokens have been issued, so
// a token-generation failure (for example an unconfigured JWT manager) is
// never reported as a successful login. generateLoginResponse caches the
// user info itself, so there is no separate cache call here.
func (s *AuthService) LoginByCode(ctx context.Context, phone, code, ip string) (*LoginResponse, error) {
	if s == nil || s.smsCodeSvc == nil || s.userRepo == nil {
		return nil, errors.New("sms code login is not configured")
	}
	phone = strings.TrimSpace(phone)
	if phone == "" {
		return nil, errors.New("手机号不能为空")
	}
	// The code is checked before the account lookup, so the "not registered"
	// response below is only reachable by a caller holding a valid code.
	if err := s.smsCodeSvc.VerifyCode(ctx, phone, "login", strings.TrimSpace(code)); err != nil {
		s.writeLoginLog(ctx, nil, domain.LoginTypeSMSCode, ip, false, err.Error())
		return nil, err
	}
	user, err := s.userRepo.GetByPhone(ctx, phone)
	if err != nil && !isUserNotFoundError(err) {
		s.writeLoginLog(ctx, nil, domain.LoginTypeSMSCode, ip, false, err.Error())
		return nil, err
	}
	// Treat a (nil, nil) lookup like not-found instead of dereferencing a nil user.
	if err != nil || user == nil {
		s.writeLoginLog(ctx, nil, domain.LoginTypeSMSCode, ip, false, "手机号未注册")
		return nil, errors.New("手机号未注册")
	}
	if err := s.ensureUserActive(user); err != nil {
		s.writeLoginLog(ctx, &user.ID, domain.LoginTypeSMSCode, ip, false, err.Error())
		s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", false)
		return nil, err
	}
	// Update last-login before issuing tokens so that the user info cached
	// by generateLoginResponse is written after this update.
	s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "sms_code")
	resp, err := s.generateLoginResponseWithoutRemember(ctx, user)
	if err != nil {
		s.writeLoginLog(ctx, &user.ID, domain.LoginTypeSMSCode, ip, false, err.Error())
		return nil, err
	}
	s.writeLoginLog(ctx, &user.ID, domain.LoginTypeSMSCode, ip, true, "")
	s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", true)
	s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
		"user_id":  user.ID,
		"username": user.Username,
		"ip":       ip,
		"method":   "sms_code",
	})
	return resp, nil
}