feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers

This commit is contained in:
2026-04-02 11:19:50 +08:00
parent e59a77bc49
commit dcc1f186f8
298 changed files with 62603 additions and 0 deletions

1450
internal/service/auth.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,116 @@
package service
import (
"context"
"errors"
"strings"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
)
var ErrAdminBootstrapUnavailable = errors.New("管理员初始化入口已关闭")
type BootstrapAdminRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
Email string `json:"email"`
Nickname string `json:"nickname"`
}
func (s *AuthService) BootstrapAdmin(ctx context.Context, req *BootstrapAdminRequest, ip string) (*LoginResponse, error) {
if req == nil {
return nil, errors.New("管理员初始化请求不能为空")
}
if s == nil || s.userRepo == nil || s.userRoleRepo == nil || s.roleRepo == nil || s.jwtManager == nil {
return nil, errors.New("管理员初始化能力未正确配置")
}
if !s.IsAdminBootstrapRequired(ctx) {
return nil, ErrAdminBootstrapUnavailable
}
username := strings.TrimSpace(req.Username)
email := strings.TrimSpace(req.Email)
nickname := strings.TrimSpace(req.Nickname)
if username == "" {
return nil, errors.New("用户名不能为空")
}
if strings.TrimSpace(req.Password) == "" {
return nil, errors.New("密码不能为空")
}
if err := s.validatePassword(req.Password); err != nil {
return nil, err
}
exists, err := s.userRepo.ExistsByUsername(ctx, username)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("用户名已存在")
}
if email != "" {
exists, err = s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("邮箱已存在")
}
}
adminRole, err := s.roleRepo.GetByCode(ctx, adminRoleCode)
if err != nil {
return nil, err
}
if adminRole == nil || adminRole.Status != domain.RoleStatusEnabled {
return nil, errors.New("管理员角色不可用")
}
passwordHash, err := auth.HashPassword(req.Password)
if err != nil {
return nil, err
}
if nickname == "" {
nickname = username
}
user := &domain.User{
Username: username,
Email: domain.StrPtr(email),
Password: passwordHash,
Nickname: nickname,
Status: domain.UserStatusActive,
}
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
if err := s.userRoleRepo.BatchCreate(ctx, []*domain.UserRole{
{UserID: user.ID, RoleID: adminRole.ID},
}); err != nil {
_ = s.userRepo.Delete(ctx, user.ID)
return nil, err
}
s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "admin_bootstrap")
s.cacheUserInfo(ctx, user)
s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, true, "")
s.publishEvent(ctx, domain.EventUserRegistered, map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"role": adminRoleCode,
"source": "admin_bootstrap",
})
s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"ip": ip,
"method": "admin_bootstrap",
})
return s.generateLoginResponseWithoutRemember(ctx, user)
}

View File

@@ -0,0 +1,99 @@
package service
import (
"context"
"errors"
"log"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"gorm.io/gorm"
)
const adminRoleCode = "admin"
type AuthCapabilities struct {
Password bool `json:"password"`
EmailActivation bool `json:"email_activation"`
EmailCode bool `json:"email_code"`
SMSCode bool `json:"sms_code"`
PasswordReset bool `json:"password_reset"`
AdminBootstrapRequired bool `json:"admin_bootstrap_required"`
OAuthProviders []auth.OAuthProviderInfo `json:"oauth_providers"`
}
func (s *AuthService) SupportsEmailActivation() bool {
return s != nil && s.emailActivationSvc != nil
}
func (s *AuthService) SupportsEmailCodeLogin() bool {
return s != nil && s.emailCodeSvc != nil
}
func (s *AuthService) SupportsSMSCodeLogin() bool {
return s != nil && s.smsCodeSvc != nil
}
func (s *AuthService) GetAuthCapabilities(ctx context.Context) AuthCapabilities {
if ctx == nil {
ctx = context.Background()
}
return AuthCapabilities{
Password: true,
EmailActivation: s.SupportsEmailActivation(),
EmailCode: s.SupportsEmailCodeLogin(),
SMSCode: s.SupportsSMSCodeLogin(),
AdminBootstrapRequired: s.IsAdminBootstrapRequired(ctx),
OAuthProviders: s.GetEnabledOAuthProviders(),
}
}
func (s *AuthService) IsAdminBootstrapRequired(ctx context.Context) bool {
if s == nil || s.userRepo == nil || s.roleRepo == nil || s.userRoleRepo == nil {
return false
}
if ctx == nil {
ctx = context.Background()
}
adminRole, err := s.roleRepo.GetByCode(ctx, adminRoleCode)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return true
}
log.Printf("auth: resolve auth capabilities failed while loading admin role: %v", err)
return false
}
userIDs, err := s.userRoleRepo.GetUserIDByRoleID(ctx, adminRole.ID)
if err != nil {
log.Printf("auth: resolve auth capabilities failed while loading admin users: role_id=%d err=%v", adminRole.ID, err)
return false
}
if len(userIDs) == 0 {
return true
}
hadUnexpectedLookupError := false
for _, userID := range userIDs {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if isUserNotFoundError(err) {
continue
}
hadUnexpectedLookupError = true
log.Printf("auth: resolve auth capabilities failed while loading admin user: user_id=%d err=%v", userID, err)
continue
}
if user != nil && user.Status == domain.UserStatusActive {
return false
}
}
if hadUnexpectedLookupError {
return false
}
return true
}

View File

@@ -0,0 +1,299 @@
package service
import (
"context"
"errors"
"strings"
"github.com/user-management-system/internal/domain"
)
func (s *AuthService) SendEmailBindCode(ctx context.Context, userID int64, email string) error {
if s == nil || s.userRepo == nil || s.emailCodeSvc == nil {
return errors.New("email binding is not configured")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
if err := s.ensureUserActive(user); err != nil {
return err
}
normalizedEmail := strings.TrimSpace(email)
if normalizedEmail == "" {
return errors.New("email is required")
}
if strings.EqualFold(strings.TrimSpace(domain.DerefStr(user.Email)), normalizedEmail) {
return errors.New("email is already bound to the current account")
}
exists, err := s.userRepo.ExistsByEmail(ctx, normalizedEmail)
if err != nil {
return err
}
if exists {
return errors.New("email already in use")
}
return s.emailCodeSvc.SendEmailCode(ctx, normalizedEmail, "bind")
}
func (s *AuthService) BindEmail(
ctx context.Context,
userID int64,
email string,
code string,
currentPassword string,
totpCode string,
) error {
if s == nil || s.userRepo == nil || s.emailCodeSvc == nil {
return errors.New("email binding is not configured")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
if err := s.ensureUserActive(user); err != nil {
return err
}
normalizedEmail := strings.TrimSpace(email)
if normalizedEmail == "" {
return errors.New("email is required")
}
if strings.EqualFold(strings.TrimSpace(domain.DerefStr(user.Email)), normalizedEmail) {
return errors.New("email is already bound to the current account")
}
exists, err := s.userRepo.ExistsByEmail(ctx, normalizedEmail)
if err != nil {
return err
}
if exists {
return errors.New("email already in use")
}
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
return err
}
if err := s.emailCodeSvc.VerifyEmailCode(ctx, normalizedEmail, "bind", strings.TrimSpace(code)); err != nil {
return err
}
user.Email = domain.StrPtr(normalizedEmail)
if err := s.userRepo.Update(ctx, user); err != nil {
return err
}
s.cacheUserInfo(ctx, user)
s.publishEvent(ctx, domain.EventUserUpdated, map[string]interface{}{
"user_id": user.ID,
"email": normalizedEmail,
"action": "bind_email",
})
return nil
}
func (s *AuthService) UnbindEmail(ctx context.Context, userID int64, currentPassword, totpCode string) error {
if s == nil || s.userRepo == nil {
return errors.New("email binding is not configured")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
if err := s.ensureUserActive(user); err != nil {
return err
}
if strings.TrimSpace(domain.DerefStr(user.Email)) == "" {
return errors.New("email is not bound")
}
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
return err
}
accounts, err := s.GetSocialAccounts(ctx, userID)
if err != nil {
return err
}
if s.availableLoginMethodCountAfterContactRemoval(user, accounts, true, false) == 0 {
return errors.New("at least one login method must remain after unbinding")
}
user.Email = nil
if err := s.userRepo.Update(ctx, user); err != nil {
return err
}
s.cacheUserInfo(ctx, user)
s.publishEvent(ctx, domain.EventUserUpdated, map[string]interface{}{
"user_id": user.ID,
"action": "unbind_email",
})
return nil
}
func (s *AuthService) SendPhoneBindCode(ctx context.Context, userID int64, phone string) (*SendCodeResponse, error) {
if s == nil || s.userRepo == nil || s.smsCodeSvc == nil {
return nil, errors.New("phone binding is not configured")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
if err := s.ensureUserActive(user); err != nil {
return nil, err
}
normalizedPhone := strings.TrimSpace(phone)
if normalizedPhone == "" {
return nil, errors.New("phone is required")
}
if strings.TrimSpace(domain.DerefStr(user.Phone)) == normalizedPhone {
return nil, errors.New("phone is already bound to the current account")
}
exists, err := s.userRepo.ExistsByPhone(ctx, normalizedPhone)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("phone already in use")
}
return s.smsCodeSvc.SendCode(ctx, &SendCodeRequest{
Phone: normalizedPhone,
Purpose: "bind",
})
}
func (s *AuthService) BindPhone(
ctx context.Context,
userID int64,
phone string,
code string,
currentPassword string,
totpCode string,
) error {
if s == nil || s.userRepo == nil || s.smsCodeSvc == nil {
return errors.New("phone binding is not configured")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
if err := s.ensureUserActive(user); err != nil {
return err
}
normalizedPhone := strings.TrimSpace(phone)
if normalizedPhone == "" {
return errors.New("phone is required")
}
if strings.TrimSpace(domain.DerefStr(user.Phone)) == normalizedPhone {
return errors.New("phone is already bound to the current account")
}
exists, err := s.userRepo.ExistsByPhone(ctx, normalizedPhone)
if err != nil {
return err
}
if exists {
return errors.New("phone already in use")
}
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
return err
}
if err := s.smsCodeSvc.VerifyCode(ctx, normalizedPhone, "bind", strings.TrimSpace(code)); err != nil {
return err
}
user.Phone = domain.StrPtr(normalizedPhone)
if err := s.userRepo.Update(ctx, user); err != nil {
return err
}
s.cacheUserInfo(ctx, user)
s.publishEvent(ctx, domain.EventUserUpdated, map[string]interface{}{
"user_id": user.ID,
"phone": normalizedPhone,
"action": "bind_phone",
})
return nil
}
func (s *AuthService) UnbindPhone(ctx context.Context, userID int64, currentPassword, totpCode string) error {
if s == nil || s.userRepo == nil {
return errors.New("phone binding is not configured")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
if err := s.ensureUserActive(user); err != nil {
return err
}
if strings.TrimSpace(domain.DerefStr(user.Phone)) == "" {
return errors.New("phone is not bound")
}
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
return err
}
accounts, err := s.GetSocialAccounts(ctx, userID)
if err != nil {
return err
}
if s.availableLoginMethodCountAfterContactRemoval(user, accounts, false, true) == 0 {
return errors.New("at least one login method must remain after unbinding")
}
user.Phone = nil
if err := s.userRepo.Update(ctx, user); err != nil {
return err
}
s.cacheUserInfo(ctx, user)
s.publishEvent(ctx, domain.EventUserUpdated, map[string]interface{}{
"user_id": user.ID,
"action": "unbind_phone",
})
return nil
}
func (s *AuthService) availableLoginMethodCountAfterContactRemoval(
user *domain.User,
accounts []*domain.SocialAccount,
removeEmail bool,
removePhone bool,
) int {
if user == nil {
return 0
}
count := 0
if strings.TrimSpace(user.Password) != "" {
count++
}
if !removeEmail && s.emailCodeSvc != nil && strings.TrimSpace(domain.DerefStr(user.Email)) != "" {
count++
}
if !removePhone && s.smsCodeSvc != nil && strings.TrimSpace(domain.DerefStr(user.Phone)) != "" {
count++
}
for _, account := range accounts {
if account == nil || account.Status != domain.SocialAccountStatusActive {
continue
}
count++
}
return count
}

View File

@@ -0,0 +1,201 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
)
func (s *AuthService) SetEmailActivationService(svc *EmailActivationService) {
s.emailActivationSvc = svc
}
func (s *AuthService) SetEmailCodeService(svc *EmailCodeService) {
s.emailCodeSvc = svc
}
func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterRequest) (*UserInfo, error) {
if err := s.validatePassword(req.Password); err != nil {
return nil, err
}
if err := s.verifyPhoneRegistration(ctx, req); err != nil {
return nil, err
}
exists, err := s.userRepo.ExistsByUsername(ctx, req.Username)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("username already exists")
}
if req.Email != "" {
exists, err = s.userRepo.ExistsByEmail(ctx, req.Email)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("email already exists")
}
}
if req.Phone != "" {
exists, err = s.userRepo.ExistsByPhone(ctx, req.Phone)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("phone already exists")
}
}
hashedPassword, err := auth.HashPassword(req.Password)
if err != nil {
return nil, err
}
initialStatus := domain.UserStatusActive
if s.emailActivationSvc != nil && req.Email != "" {
initialStatus = domain.UserStatusInactive
}
user := &domain.User{
Username: req.Username,
Email: domain.StrPtr(req.Email),
Phone: domain.StrPtr(req.Phone),
Password: hashedPassword,
Nickname: req.Nickname,
Status: initialStatus,
}
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
s.bestEffortAssignDefaultRoles(ctx, user.ID, "register_with_activation")
if s.emailActivationSvc != nil && req.Email != "" {
nickname := req.Nickname
if nickname == "" {
nickname = req.Username
}
go func() {
if err := s.emailActivationSvc.SendActivationEmail(ctx, user.ID, req.Email, nickname); err != nil {
log.Printf("auth: send activation email failed, user_id=%d email=%s err=%v", user.ID, req.Email, err)
}
}()
}
userInfo := s.buildUserInfo(user)
s.publishEvent(ctx, domain.EventUserRegistered, userInfo)
return userInfo, nil
}
func (s *AuthService) ActivateEmail(ctx context.Context, token string) error {
if s.emailActivationSvc == nil {
return errors.New("email activation service is not configured")
}
userID, err := s.emailActivationSvc.ValidateActivationToken(ctx, token)
if err != nil {
return err
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("user not found: %w", err)
}
if user.Status == domain.UserStatusActive {
return errors.New("account already activated")
}
if user.Status != domain.UserStatusInactive {
return errors.New("account status does not allow activation")
}
return s.userRepo.UpdateStatus(ctx, userID, domain.UserStatusActive)
}
func (s *AuthService) ResendActivationEmail(ctx context.Context, email string) error {
if s.emailActivationSvc == nil {
return errors.New("email activation service is not configured")
}
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if isUserNotFoundError(err) {
return nil
}
return err
}
if user.Status == domain.UserStatusActive {
return nil
}
if user.Status != domain.UserStatusInactive {
return errors.New("account status does not allow activation")
}
nickname := user.Nickname
if nickname == "" {
nickname = user.Username
}
return s.emailActivationSvc.SendActivationEmail(ctx, user.ID, email, nickname)
}
func (s *AuthService) SendEmailLoginCode(ctx context.Context, email string) error {
if s.emailCodeSvc == nil {
return errors.New("email code service is not configured")
}
_, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if isUserNotFoundError(err) {
return nil
}
return err
}
return s.emailCodeSvc.SendEmailCode(ctx, email, "login")
}
func (s *AuthService) LoginByEmailCode(ctx context.Context, email, code, ip string) (*LoginResponse, error) {
if s.emailCodeSvc == nil {
return nil, errors.New("email code login is disabled")
}
if err := s.emailCodeSvc.VerifyEmailCode(ctx, email, "login", code); err != nil {
s.writeLoginLog(ctx, nil, domain.LoginTypeEmailCode, ip, false, err.Error())
return nil, err
}
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if isUserNotFoundError(err) {
s.writeLoginLog(ctx, nil, domain.LoginTypeEmailCode, ip, false, "email not registered")
return nil, errors.New("email not registered")
}
s.writeLoginLog(ctx, nil, domain.LoginTypeEmailCode, ip, false, err.Error())
return nil, err
}
if err := s.ensureUserActive(user); err != nil {
s.writeLoginLog(ctx, &user.ID, domain.LoginTypeEmailCode, ip, false, err.Error())
s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", false)
return nil, err
}
s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "email_code")
s.writeLoginLog(ctx, &user.ID, domain.LoginTypeEmailCode, ip, true, "")
s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", true)
s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"ip": ip,
"method": "email_code",
})
return s.generateLoginResponseWithoutRemember(ctx, user)
}

View File

@@ -0,0 +1,369 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"gorm.io/gorm"
)
type oauthRegistrar interface {
RegisterProvider(provider auth.OAuthProvider, config *auth.OAuthConfig)
}
func (s *AuthService) RegisterOAuthProvider(provider auth.OAuthProvider, cfg *auth.OAuthConfig) {
if cfg == nil {
return
}
if registrar, ok := s.oauthManager.(oauthRegistrar); ok {
registrar.RegisterProvider(provider, cfg)
}
}
func (s *AuthService) findUserForLogin(ctx context.Context, account string) (*domain.User, error) {
user, err := s.userRepo.GetByUsername(ctx, account)
if err == nil {
return user, nil
}
if !isUserNotFoundError(err) {
return nil, fmt.Errorf("lookup user by username failed: %w", err)
}
user, err = s.userRepo.GetByEmail(ctx, account)
if err == nil {
return user, nil
}
if !isUserNotFoundError(err) {
return nil, fmt.Errorf("lookup user by email failed: %w", err)
}
user, err = s.userRepo.GetByPhone(ctx, account)
if err != nil && !isUserNotFoundError(err) {
return nil, fmt.Errorf("lookup user by phone failed: %w", err)
}
return user, err
}
func isUserNotFoundError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return true
}
lowerErr := strings.ToLower(strings.TrimSpace(err.Error()))
return strings.Contains(lowerErr, "record not found") ||
strings.Contains(lowerErr, "user not found") ||
strings.Contains(err.Error(), "用户不存在") ||
strings.Contains(lowerErr, "not found")
}
func (s *AuthService) bestEffortAssignDefaultRoles(ctx context.Context, userID int64, source string) {
if s == nil || s.userRoleRepo == nil || s.roleRepo == nil {
return
}
defaultRoles, err := s.roleRepo.GetDefaultRoles(ctx)
if err != nil {
log.Printf("auth: load default roles failed, source=%s user_id=%d err=%v", source, userID, err)
return
}
if len(defaultRoles) == 0 {
return
}
userRoles := make([]*domain.UserRole, 0, len(defaultRoles))
for _, role := range defaultRoles {
userRoles = append(userRoles, &domain.UserRole{
UserID: userID,
RoleID: role.ID,
})
}
if err := s.userRoleRepo.BatchCreate(ctx, userRoles); err != nil {
log.Printf("auth: assign default roles failed, source=%s user_id=%d role_count=%d err=%v", source, userID, len(userRoles), err)
}
}
func (s *AuthService) bestEffortUpdateLastLogin(ctx context.Context, userID int64, ip, source string) {
if s == nil || s.userRepo == nil {
return
}
if err := s.userRepo.UpdateLastLogin(ctx, userID, ip); err != nil {
log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
}
}
func loginAttemptKey(account string, user *domain.User) string {
if user != nil {
return fmt.Sprintf("login_attempt:user:%d", user.ID)
}
return "login_attempt:account:" + strings.ToLower(strings.TrimSpace(account))
}
func attemptCount(value interface{}) int {
if count, ok := intValue(value); ok {
return count
}
return 0
}
func intValue(value interface{}) (int, bool) {
switch v := value.(type) {
case int:
return v, true
case int64:
return int(v), true
case float64:
return int(v), true
case json.Number:
n, err := v.Int64()
if err != nil {
return 0, false
}
return int(n), true
default:
return 0, false
}
}
func int64Value(value interface{}) (int64, bool) {
switch v := value.(type) {
case int64:
return v, true
case int:
return int64(v), true
case float64:
return int64(v), true
case json.Number:
n, err := v.Int64()
if err != nil {
return 0, false
}
return n, true
default:
return 0, false
}
}
func (s *AuthService) verifyPhoneRegistration(ctx context.Context, req *RegisterRequest) error {
if req == nil || req.Phone == "" {
return nil
}
if s.smsCodeSvc == nil {
return errors.New("手机注册未启用")
}
if req.PhoneCode == "" {
return errors.New("手机验证码不能为空")
}
return s.smsCodeSvc.VerifyCode(ctx, req.Phone, "register", req.PhoneCode)
}
const (
oauthStateCachePrefix = "oauth_state:"
oauthHandoffCachePrefix = "oauth_handoff:"
oauthStateTTL = 10 * time.Minute
oauthHandoffTTL = time.Minute
)
type OAuthStatePurpose string
const (
OAuthStatePurposeLogin OAuthStatePurpose = "login"
OAuthStatePurposeBind OAuthStatePurpose = "bind"
)
type OAuthStatePayload struct {
Purpose OAuthStatePurpose `json:"purpose"`
ReturnTo string `json:"return_to"`
UserID int64 `json:"user_id,omitempty"`
}
func generateOAuthEphemeralCode() (string, error) {
buffer := make([]byte, 32)
if _, err := rand.Read(buffer); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buffer), nil
}
func (s *AuthService) CreateOAuthState(ctx context.Context, returnTo string) (string, error) {
return s.createOAuthStatePayload(ctx, &OAuthStatePayload{
Purpose: OAuthStatePurposeLogin,
ReturnTo: strings.TrimSpace(returnTo),
})
}
func (s *AuthService) CreateOAuthBindState(ctx context.Context, userID int64, returnTo string) (string, error) {
if userID <= 0 {
return "", errors.New("oauth binding user is required")
}
return s.createOAuthStatePayload(ctx, &OAuthStatePayload{
Purpose: OAuthStatePurposeBind,
ReturnTo: strings.TrimSpace(returnTo),
UserID: userID,
})
}
func (s *AuthService) createOAuthStatePayload(ctx context.Context, payload *OAuthStatePayload) (string, error) {
if s == nil || s.cache == nil {
return "", errors.New("oauth state storage unavailable")
}
if payload == nil {
return "", errors.New("oauth state payload is required")
}
if payload.Purpose == "" {
payload.Purpose = OAuthStatePurposeLogin
}
state, err := generateOAuthEphemeralCode()
if err != nil {
return "", err
}
if err := s.cache.Set(ctx, oauthStateCachePrefix+state, payload, oauthStateTTL, oauthStateTTL); err != nil {
return "", err
}
return state, nil
}
func (s *AuthService) ConsumeOAuthState(ctx context.Context, state string) (string, error) {
payload, err := s.ConsumeOAuthStatePayload(ctx, state)
if err != nil {
return "", err
}
if payload == nil {
return "", nil
}
return strings.TrimSpace(payload.ReturnTo), nil
}
func (s *AuthService) ConsumeOAuthStatePayload(ctx context.Context, state string) (*OAuthStatePayload, error) {
if s == nil || s.cache == nil {
return nil, errors.New("oauth state storage unavailable")
}
cacheKey := oauthStateCachePrefix + strings.TrimSpace(state)
value, ok := s.cache.Get(ctx, cacheKey)
if !ok {
return nil, errors.New("OAuth state validation failed")
}
_ = s.cache.Delete(ctx, cacheKey)
switch typed := value.(type) {
case *OAuthStatePayload:
payload := *typed
if payload.Purpose == "" {
payload.Purpose = OAuthStatePurposeLogin
}
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
return &payload, nil
case OAuthStatePayload:
payload := typed
if payload.Purpose == "" {
payload.Purpose = OAuthStatePurposeLogin
}
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
return &payload, nil
case string:
return &OAuthStatePayload{
Purpose: OAuthStatePurposeLogin,
ReturnTo: strings.TrimSpace(typed),
}, nil
case nil:
return &OAuthStatePayload{Purpose: OAuthStatePurposeLogin}, nil
case map[string]interface{}:
payloadBytes, err := json.Marshal(typed)
if err != nil {
return nil, err
}
var payload OAuthStatePayload
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
return nil, err
}
if payload.Purpose == "" {
payload.Purpose = OAuthStatePurposeLogin
}
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
return &payload, nil
default:
return &OAuthStatePayload{
Purpose: OAuthStatePurposeLogin,
ReturnTo: strings.TrimSpace(fmt.Sprint(typed)),
}, nil
}
}
func (s *AuthService) CreateOAuthHandoff(ctx context.Context, loginResp *LoginResponse) (string, error) {
if s == nil || s.cache == nil {
return "", errors.New("oauth handoff storage unavailable")
}
if loginResp == nil {
return "", errors.New("oauth handoff payload is required")
}
code, err := generateOAuthEphemeralCode()
if err != nil {
return "", err
}
if err := s.cache.Set(ctx, oauthHandoffCachePrefix+code, loginResp, oauthHandoffTTL, oauthHandoffTTL); err != nil {
return "", err
}
return code, nil
}
func (s *AuthService) ConsumeOAuthHandoff(ctx context.Context, code string) (*LoginResponse, error) {
if s == nil || s.cache == nil {
return nil, errors.New("oauth handoff storage unavailable")
}
cacheKey := oauthHandoffCachePrefix + strings.TrimSpace(code)
value, ok := s.cache.Get(ctx, cacheKey)
if !ok {
return nil, errors.New("OAuth handoff code is invalid or expired")
}
_ = s.cache.Delete(ctx, cacheKey)
switch typed := value.(type) {
case *LoginResponse:
return typed, nil
case LoginResponse:
resp := typed
return &resp, nil
case map[string]interface{}:
payload, err := json.Marshal(typed)
if err != nil {
return nil, err
}
var resp LoginResponse
if err := json.Unmarshal(payload, &resp); err != nil {
return nil, err
}
return &resp, nil
default:
payload, err := json.Marshal(typed)
if err != nil {
return nil, err
}
var resp LoginResponse
if err := json.Unmarshal(payload, &resp); err != nil {
return nil, err
}
return &resp, nil
}
}

343
internal/service/captcha.go Normal file
View File

@@ -0,0 +1,343 @@
package service
import (
"bytes"
"context"
crand "crypto/rand"
"encoding/hex"
"errors"
"fmt"
"image"
"image/color"
"image/draw"
"image/png"
"math/big"
"math/rand"
"strings"
"time"
"github.com/user-management-system/internal/cache"
)
const (
captchaWidth = 120
captchaHeight = 40
captchaLength = 4 // 验证码位数
captchaTTL = 5 * time.Minute
)
// captchaChars 验证码字符集(去掉容易混淆的字符 0/O/1/I/l
const captchaChars = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjkmnpqrstuvwxyz"
// CaptchaService 图形验证码服务
type CaptchaService struct {
cache *cache.CacheManager
}
// NewCaptchaService 创建验证码服务
func NewCaptchaService(cache *cache.CacheManager) *CaptchaService {
return &CaptchaService{cache: cache}
}
// CaptchaResult 验证码生成结果
type CaptchaResult struct {
CaptchaID string // 验证码IDUUID
ImageData []byte // PNG图片字节
}
// Generate 生成图形验证码
func (s *CaptchaService) Generate(ctx context.Context) (*CaptchaResult, error) {
// 生成随机验证码文字
text, err := s.randomText(captchaLength)
if err != nil {
return nil, fmt.Errorf("生成验证码文本失败: %w", err)
}
// 生成验证码ID
captchaID, err := s.generateID()
if err != nil {
return nil, fmt.Errorf("生成验证码ID失败: %w", err)
}
// 生成图片
imgData, err := s.renderImage(text)
if err != nil {
return nil, fmt.Errorf("生成验证码图片失败: %w", err)
}
// 存入缓存(不区分大小写,存小写)
cacheKey := "captcha:" + captchaID
s.cache.Set(ctx, cacheKey, strings.ToLower(text), captchaTTL, captchaTTL)
return &CaptchaResult{
CaptchaID: captchaID,
ImageData: imgData,
}, nil
}
// Verify 验证验证码(验证后立即删除,防止重放)
func (s *CaptchaService) Verify(ctx context.Context, captchaID, answer string) bool {
if captchaID == "" || answer == "" {
return false
}
cacheKey := "captcha:" + captchaID
val, ok := s.cache.Get(ctx, cacheKey)
if !ok {
return false
}
// 删除验证码(一次性使用)
s.cache.Delete(ctx, cacheKey)
expected, ok := val.(string)
if !ok {
return false
}
return strings.ToLower(answer) == expected
}
// VerifyWithoutDelete 验证验证码但不删除(用于测试)
func (s *CaptchaService) VerifyWithoutDelete(ctx context.Context, captchaID, answer string) bool {
if captchaID == "" || answer == "" {
return false
}
cacheKey := "captcha:" + captchaID
val, ok := s.cache.Get(ctx, cacheKey)
if !ok {
return false
}
expected, ok := val.(string)
if !ok {
return false
}
return strings.ToLower(answer) == expected
}
// ValidateCaptcha 验证验证码(对外暴露,验证后删除)
func (s *CaptchaService) ValidateCaptcha(ctx context.Context, captchaID, answer string) error {
if captchaID == "" {
return errors.New("验证码ID不能为空")
}
if answer == "" {
return errors.New("验证码答案不能为空")
}
if !s.Verify(ctx, captchaID, answer) {
return errors.New("验证码错误或已过期")
}
return nil
}
// randomText 生成随机验证码文字
func (s *CaptchaService) randomText(length int) (string, error) {
chars := []byte(captchaChars)
result := make([]byte, length)
for i := range result {
n, err := crand.Int(crand.Reader, big.NewInt(int64(len(chars))))
if err != nil {
return "", err
}
result[i] = chars[n.Int64()]
}
return string(result), nil
}
// generateID 生成验证码IDcrypto/rand 保证全局唯一,无碰撞)
func (s *CaptchaService) generateID() (string, error) {
b := make([]byte, 16)
if _, err := crand.Read(b); err != nil {
return "", err
}
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), hex.EncodeToString(b)), nil
}
// renderImage 将文字渲染为PNG验证码图片纯Go实现无外部字体依赖
func (s *CaptchaService) renderImage(text string) ([]byte, error) {
// 创建 RGBA 图像
img := image.NewRGBA(image.Rect(0, 0, captchaWidth, captchaHeight))
// 随机背景色(浅色)
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
bgColor := color.RGBA{
R: uint8(220 + rng.Intn(35)),
G: uint8(220 + rng.Intn(35)),
B: uint8(220 + rng.Intn(35)),
A: 255,
}
draw.Draw(img, img.Bounds(), &image.Uniform{bgColor}, image.Point{}, draw.Src)
// 绘制干扰线
for i := 0; i < 5; i++ {
lineColor := color.RGBA{
R: uint8(rng.Intn(200)),
G: uint8(rng.Intn(200)),
B: uint8(rng.Intn(200)),
A: 255,
}
x1 := rng.Intn(captchaWidth)
y1 := rng.Intn(captchaHeight)
x2 := rng.Intn(captchaWidth)
y2 := rng.Intn(captchaHeight)
drawLine(img, x1, y1, x2, y2, lineColor)
}
// 绘制文字(使用像素字体)
for i, ch := range text {
charColor := color.RGBA{
R: uint8(rng.Intn(150)),
G: uint8(rng.Intn(150)),
B: uint8(rng.Intn(150)),
A: 255,
}
x := 10 + i*25 + rng.Intn(5)
y := 8 + rng.Intn(12)
drawChar(img, x, y, byte(ch), charColor)
}
// 绘制干扰点
for i := 0; i < 80; i++ {
dotColor := color.RGBA{
R: uint8(rng.Intn(255)),
G: uint8(rng.Intn(255)),
B: uint8(rng.Intn(255)),
A: uint8(100 + rng.Intn(100)),
}
img.Set(rng.Intn(captchaWidth), rng.Intn(captchaHeight), dotColor)
}
// 编码为 PNG
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// drawLine 画直线Bresenham算法
func drawLine(img *image.RGBA, x1, y1, x2, y2 int, c color.RGBA) {
dx := abs(x2 - x1)
dy := abs(y2 - y1)
sx, sy := 1, 1
if x1 > x2 {
sx = -1
}
if y1 > y2 {
sy = -1
}
err := dx - dy
for {
img.Set(x1, y1, c)
if x1 == x2 && y1 == y2 {
break
}
e2 := 2 * err
if e2 > -dy {
err -= dy
x1 += sx
}
if e2 < dx {
err += dx
y1 += sy
}
}
}
func abs(x int) int {
if x < 0 {
return -x
}
return x
}
// pixelFont 5x7 像素字体位图ASCII 32-127
// 每个字符用5个uint8表示5列每个uint8的低7位是每行是否亮起
var pixelFont = map[byte][5]uint8{
'0': {0x3E, 0x51, 0x49, 0x45, 0x3E},
'1': {0x00, 0x42, 0x7F, 0x40, 0x00},
'2': {0x42, 0x61, 0x51, 0x49, 0x46},
'3': {0x21, 0x41, 0x45, 0x4B, 0x31},
'4': {0x18, 0x14, 0x12, 0x7F, 0x10},
'5': {0x27, 0x45, 0x45, 0x45, 0x39},
'6': {0x3C, 0x4A, 0x49, 0x49, 0x30},
'7': {0x01, 0x71, 0x09, 0x05, 0x03},
'8': {0x36, 0x49, 0x49, 0x49, 0x36},
'9': {0x06, 0x49, 0x49, 0x29, 0x1E},
'A': {0x7E, 0x11, 0x11, 0x11, 0x7E},
'B': {0x7F, 0x49, 0x49, 0x49, 0x36},
'C': {0x3E, 0x41, 0x41, 0x41, 0x22},
'D': {0x7F, 0x41, 0x41, 0x22, 0x1C},
'E': {0x7F, 0x49, 0x49, 0x49, 0x41},
'F': {0x7F, 0x09, 0x09, 0x09, 0x01},
'G': {0x3E, 0x41, 0x49, 0x49, 0x7A},
'H': {0x7F, 0x08, 0x08, 0x08, 0x7F},
'J': {0x20, 0x40, 0x41, 0x3F, 0x01},
'K': {0x7F, 0x08, 0x14, 0x22, 0x41},
'L': {0x7F, 0x40, 0x40, 0x40, 0x40},
'M': {0x7F, 0x02, 0x0C, 0x02, 0x7F},
'N': {0x7F, 0x04, 0x08, 0x10, 0x7F},
'P': {0x7F, 0x09, 0x09, 0x09, 0x06},
'Q': {0x3E, 0x41, 0x51, 0x21, 0x5E},
'R': {0x7F, 0x09, 0x19, 0x29, 0x46},
'S': {0x46, 0x49, 0x49, 0x49, 0x31},
'T': {0x01, 0x01, 0x7F, 0x01, 0x01},
'U': {0x3F, 0x40, 0x40, 0x40, 0x3F},
'V': {0x1F, 0x20, 0x40, 0x20, 0x1F},
'W': {0x3F, 0x40, 0x38, 0x40, 0x3F},
'X': {0x63, 0x14, 0x08, 0x14, 0x63},
'Y': {0x07, 0x08, 0x70, 0x08, 0x07},
'Z': {0x61, 0x51, 0x49, 0x45, 0x43},
'a': {0x20, 0x54, 0x54, 0x54, 0x78},
'b': {0x7F, 0x48, 0x44, 0x44, 0x38},
'c': {0x38, 0x44, 0x44, 0x44, 0x20},
'd': {0x38, 0x44, 0x44, 0x48, 0x7F},
'e': {0x38, 0x54, 0x54, 0x54, 0x18},
'f': {0x08, 0x7E, 0x09, 0x01, 0x02},
'g': {0x0C, 0x52, 0x52, 0x52, 0x3E},
'h': {0x7F, 0x08, 0x04, 0x04, 0x78},
'j': {0x20, 0x40, 0x44, 0x3D, 0x00},
'k': {0x7F, 0x10, 0x28, 0x44, 0x00},
'm': {0x7C, 0x04, 0x18, 0x04, 0x78},
'n': {0x7C, 0x08, 0x04, 0x04, 0x78},
'p': {0x7C, 0x14, 0x14, 0x14, 0x08},
'q': {0x08, 0x14, 0x14, 0x18, 0x7C},
'r': {0x7C, 0x08, 0x04, 0x04, 0x08},
's': {0x48, 0x54, 0x54, 0x54, 0x20},
't': {0x04, 0x3F, 0x44, 0x40, 0x20},
'u': {0x3C, 0x40, 0x40, 0x20, 0x7C},
'v': {0x1C, 0x20, 0x40, 0x20, 0x1C},
'w': {0x3C, 0x40, 0x30, 0x40, 0x3C},
'x': {0x44, 0x28, 0x10, 0x28, 0x44},
'y': {0x0C, 0x50, 0x50, 0x50, 0x3C},
'z': {0x44, 0x64, 0x54, 0x4C, 0x44},
}
// drawChar 在图像上绘制单个字符
func drawChar(img *image.RGBA, x, y int, ch byte, c color.RGBA) {
glyph, ok := pixelFont[ch]
if !ok {
// 未知字符画个方块
for dy := 0; dy < 7; dy++ {
for dx := 0; dx < 5; dx++ {
img.Set(x+dx*2, y+dy*2, c)
}
}
return
}
for col, colData := range glyph {
for row := 0; row < 7; row++ {
if colData&(1<<uint(row)) != 0 {
// 放大2倍绘制
img.Set(x+col*2, y+row*2, c)
img.Set(x+col*2+1, y+row*2, c)
img.Set(x+col*2, y+row*2+1, c)
img.Set(x+col*2+1, y+row*2+1, c)
}
}
}
}

View File

@@ -0,0 +1,41 @@
package service
import "errors"
var (
ErrRateLimitExceeded = errors.New("rate limit exceeded")
ErrValidationFailed = errors.New("validation failed")
)
type classifiedError struct {
message string
cause error
}
func (e *classifiedError) Error() string {
if e.message != "" {
return e.message
}
if e.cause != nil {
return e.cause.Error()
}
return ""
}
func (e *classifiedError) Unwrap() error {
return e.cause
}
func newRateLimitError(message string) error {
return &classifiedError{
message: message,
cause: ErrRateLimitExceeded,
}
}
func newValidationError(message string) error {
return &classifiedError{
message: message,
cause: ErrValidationFailed,
}
}

View File

@@ -0,0 +1,319 @@
package service
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// CustomFieldService 自定义字段服务
type CustomFieldService struct {
fieldRepo *repository.CustomFieldRepository
valueRepo *repository.UserCustomFieldValueRepository
}
// NewCustomFieldService 创建自定义字段服务
func NewCustomFieldService(
fieldRepo *repository.CustomFieldRepository,
valueRepo *repository.UserCustomFieldValueRepository,
) *CustomFieldService {
return &CustomFieldService{
fieldRepo: fieldRepo,
valueRepo: valueRepo,
}
}
// CreateFieldRequest 创建字段请求
type CreateFieldRequest struct {
Name string `json:"name" binding:"required"`
FieldKey string `json:"field_key" binding:"required"`
Type int `json:"type" binding:"required"`
Required bool `json:"required"`
Default string `json:"default"`
MinLen int `json:"min_len"`
MaxLen int `json:"max_len"`
MinVal float64 `json:"min_val"`
MaxVal float64 `json:"max_val"`
Options string `json:"options"`
Sort int `json:"sort"`
}
// UpdateFieldRequest 更新字段请求
type UpdateFieldRequest struct {
Name string `json:"name"`
Type int `json:"type"`
Required *bool `json:"required"`
Default string `json:"default"`
MinLen int `json:"min_len"`
MaxLen int `json:"max_len"`
MinVal float64 `json:"min_val"`
MaxVal float64 `json:"max_val"`
Options string `json:"options"`
Sort int `json:"sort"`
Status *int `json:"status"`
}
// CreateField 创建自定义字段
func (s *CustomFieldService) CreateField(ctx context.Context, req *CreateFieldRequest) (*domain.CustomField, error) {
// 检查field_key是否已存在
existing, err := s.fieldRepo.GetByFieldKey(ctx, req.FieldKey)
if err == nil && existing != nil {
return nil, errors.New("字段标识符已存在")
}
field := &domain.CustomField{
Name: req.Name,
FieldKey: req.FieldKey,
Type: domain.CustomFieldType(req.Type),
Required: req.Required,
DefaultVal: req.Default,
MinLen: req.MinLen,
MaxLen: req.MaxLen,
MinVal: req.MinVal,
MaxVal: req.MaxVal,
Options: req.Options,
Sort: req.Sort,
Status: 1,
}
if err := s.fieldRepo.Create(ctx, field); err != nil {
return nil, err
}
return field, nil
}
// UpdateField 更新自定义字段
func (s *CustomFieldService) UpdateField(ctx context.Context, id int64, req *UpdateFieldRequest) (*domain.CustomField, error) {
field, err := s.fieldRepo.GetByID(ctx, id)
if err != nil {
return nil, errors.New("字段不存在")
}
if req.Name != "" {
field.Name = req.Name
}
if req.Type > 0 {
field.Type = domain.CustomFieldType(req.Type)
}
if req.Required != nil {
field.Required = *req.Required
}
if req.Default != "" {
field.DefaultVal = req.Default
}
if req.MinLen > 0 {
field.MinLen = req.MinLen
}
if req.MaxLen > 0 {
field.MaxLen = req.MaxLen
}
if req.MinVal > 0 {
field.MinVal = req.MinVal
}
if req.MaxVal > 0 {
field.MaxVal = req.MaxVal
}
if req.Options != "" {
field.Options = req.Options
}
if req.Sort > 0 {
field.Sort = req.Sort
}
if req.Status != nil {
field.Status = *req.Status
}
if err := s.fieldRepo.Update(ctx, field); err != nil {
return nil, err
}
return field, nil
}
// DeleteField 删除自定义字段
func (s *CustomFieldService) DeleteField(ctx context.Context, id int64) error {
field, err := s.fieldRepo.GetByID(ctx, id)
if err != nil {
return errors.New("字段不存在")
}
// 删除字段定义
if err := s.fieldRepo.Delete(ctx, id); err != nil {
return err
}
// 清理用户的该字段值(可选,取决于业务需求)
_ = field
return nil
}
// GetField 获取自定义字段
func (s *CustomFieldService) GetField(ctx context.Context, id int64) (*domain.CustomField, error) {
return s.fieldRepo.GetByID(ctx, id)
}
// ListFields 获取所有启用的自定义字段
func (s *CustomFieldService) ListFields(ctx context.Context) ([]*domain.CustomField, error) {
return s.fieldRepo.List(ctx)
}
// ListAllFields 获取所有自定义字段
func (s *CustomFieldService) ListAllFields(ctx context.Context) ([]*domain.CustomField, error) {
return s.fieldRepo.ListAll(ctx)
}
// SetUserFieldValue 设置用户的自定义字段值
func (s *CustomFieldService) SetUserFieldValue(ctx context.Context, userID int64, fieldKey string, value string) error {
// 获取字段定义
field, err := s.fieldRepo.GetByFieldKey(ctx, fieldKey)
if err != nil {
return errors.New("字段不存在")
}
// 验证值
if err := s.validateFieldValue(field, value); err != nil {
return err
}
return s.valueRepo.Set(ctx, userID, field.ID, fieldKey, value)
}
// BatchSetUserFieldValues 批量设置用户的自定义字段值
func (s *CustomFieldService) BatchSetUserFieldValues(ctx context.Context, userID int64, values map[string]string) error {
// 获取所有启用的字段定义
fields, err := s.fieldRepo.List(ctx)
if err != nil {
return err
}
fieldMap := make(map[string]*domain.CustomField)
for _, f := range fields {
fieldMap[f.FieldKey] = f
}
// 验证每个值
for fieldKey, value := range values {
field, ok := fieldMap[fieldKey]
if !ok {
return fmt.Errorf("字段不存在: %s", fieldKey)
}
if err := s.validateFieldValue(field, value); err != nil {
return err
}
}
// 批量设置值
return s.valueRepo.BatchSet(ctx, userID, values)
}
// GetUserFieldValues 获取用户的所有自定义字段值
func (s *CustomFieldService) GetUserFieldValues(ctx context.Context, userID int64) ([]*domain.CustomFieldValueResponse, error) {
// 获取所有启用的字段定义
fields, err := s.fieldRepo.List(ctx)
if err != nil {
return nil, err
}
// 获取用户的字段值
values, err := s.valueRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, err
}
// 构建字段值映射
valueMap := make(map[int64]*domain.UserCustomFieldValue)
for _, v := range values {
valueMap[v.FieldID] = v
}
// 构建响应
fieldMap := make(map[string]*domain.CustomField)
for _, f := range fields {
fieldMap[f.FieldKey] = f
}
var result []*domain.CustomFieldValueResponse
for _, field := range fields {
resp := &domain.CustomFieldValueResponse{
FieldKey: field.FieldKey,
}
if val, ok := valueMap[field.ID]; ok {
resp.Value = val.GetValueAsInterface(field)
} else if field.DefaultVal != "" {
resp.Value = field.DefaultVal
} else {
resp.Value = nil
}
result = append(result, resp)
}
return result, nil
}
// DeleteUserFieldValue 删除用户的自定义字段值
func (s *CustomFieldService) DeleteUserFieldValue(ctx context.Context, userID int64, fieldKey string) error {
field, err := s.fieldRepo.GetByFieldKey(ctx, fieldKey)
if err != nil {
return errors.New("字段不存在")
}
return s.valueRepo.Delete(ctx, userID, field.ID)
}
// validateFieldValue 验证字段值
func (s *CustomFieldService) validateFieldValue(field *domain.CustomField, value string) error {
// 检查必填
if field.Required && value == "" {
return errors.New("字段值不能为空")
}
// 如果值为空且有默认值,跳过验证
if value == "" && field.DefaultVal != "" {
return nil
}
switch field.Type {
case domain.CustomFieldTypeString:
// 字符串长度验证
if field.MinLen > 0 && len(value) < field.MinLen {
return fmt.Errorf("值长度不能小于%d", field.MinLen)
}
if field.MaxLen > 0 && len(value) > field.MaxLen {
return fmt.Errorf("值长度不能大于%d", field.MaxLen)
}
case domain.CustomFieldTypeNumber:
// 数字验证
numVal, err := strconv.ParseFloat(value, 64)
if err != nil {
return errors.New("值必须是数字")
}
if field.MinVal > 0 && numVal < field.MinVal {
return fmt.Errorf("值不能小于%.2f", field.MinVal)
}
if field.MaxVal > 0 && numVal > field.MaxVal {
return fmt.Errorf("值不能大于%.2f", field.MaxVal)
}
case domain.CustomFieldTypeBoolean:
// 布尔验证
if value != "true" && value != "false" && value != "1" && value != "0" {
return errors.New("值必须是布尔值(true/false/1/0)")
}
case domain.CustomFieldTypeDate:
// 日期验证
_, err := time.Parse("2006-01-02", value)
if err != nil {
return errors.New("值必须是有效的日期格式(YYYY-MM-DD)")
}
}
return nil
}

276
internal/service/device.go Normal file
View File

@@ -0,0 +1,276 @@
package service
import (
"context"
"errors"
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// DeviceService 设备服务
type DeviceService struct {
deviceRepo *repository.DeviceRepository
userRepo *repository.UserRepository
}
// NewDeviceService 创建设备服务
func NewDeviceService(
deviceRepo *repository.DeviceRepository,
userRepo *repository.UserRepository,
) *DeviceService {
return &DeviceService{
deviceRepo: deviceRepo,
userRepo: userRepo,
}
}
// CreateDeviceRequest 创建设备请求
type CreateDeviceRequest struct {
DeviceID string `json:"device_id" binding:"required"`
DeviceName string `json:"device_name"`
DeviceType int `json:"device_type"`
DeviceOS string `json:"device_os"`
DeviceBrowser string `json:"device_browser"`
IP string `json:"ip"`
Location string `json:"location"`
}
// UpdateDeviceRequest 更新设备请求
type UpdateDeviceRequest struct {
DeviceName string `json:"device_name"`
DeviceType int `json:"device_type"`
DeviceOS string `json:"device_os"`
DeviceBrowser string `json:"device_browser"`
IP string `json:"ip"`
Location string `json:"location"`
Status int `json:"status"`
}
// CreateDevice 创建设备
func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *CreateDeviceRequest) (*domain.Device, error) {
// 检查用户是否存在
_, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, errors.New("用户不存在")
}
// 检查设备是否已存在
exists, err := s.deviceRepo.Exists(ctx, userID, req.DeviceID)
if err != nil {
return nil, err
}
if exists {
// 设备已存在,更新最后活跃时间
device, err := s.deviceRepo.GetByDeviceID(ctx, userID, req.DeviceID)
if err != nil {
return nil, err
}
device.LastActiveTime = time.Now()
return device, s.deviceRepo.Update(ctx, device)
}
// 创建设备
device := &domain.Device{
UserID: userID,
DeviceID: req.DeviceID,
DeviceName: req.DeviceName,
DeviceType: domain.DeviceType(req.DeviceType),
DeviceOS: req.DeviceOS,
DeviceBrowser: req.DeviceBrowser,
IP: req.IP,
Location: req.Location,
Status: domain.DeviceStatusActive,
}
if err := s.deviceRepo.Create(ctx, device); err != nil {
return nil, err
}
return device, nil
}
// UpdateDevice 更新设备
func (s *DeviceService) UpdateDevice(ctx context.Context, deviceID int64, req *UpdateDeviceRequest) (*domain.Device, error) {
device, err := s.deviceRepo.GetByID(ctx, deviceID)
if err != nil {
return nil, errors.New("设备不存在")
}
// 更新字段
if req.DeviceName != "" {
device.DeviceName = req.DeviceName
}
if req.DeviceType >= 0 {
device.DeviceType = domain.DeviceType(req.DeviceType)
}
if req.DeviceOS != "" {
device.DeviceOS = req.DeviceOS
}
if req.DeviceBrowser != "" {
device.DeviceBrowser = req.DeviceBrowser
}
if req.IP != "" {
device.IP = req.IP
}
if req.Location != "" {
device.Location = req.Location
}
if req.Status >= 0 {
device.Status = domain.DeviceStatus(req.Status)
}
if err := s.deviceRepo.Update(ctx, device); err != nil {
return nil, err
}
return device, nil
}
// DeleteDevice 删除设备
func (s *DeviceService) DeleteDevice(ctx context.Context, deviceID int64) error {
return s.deviceRepo.Delete(ctx, deviceID)
}
// GetDevice 获取设备信息
func (s *DeviceService) GetDevice(ctx context.Context, deviceID int64) (*domain.Device, error) {
return s.deviceRepo.GetByID(ctx, deviceID)
}
// GetUserDevices 获取用户设备列表
func (s *DeviceService) GetUserDevices(ctx context.Context, userID int64, page, pageSize int) ([]*domain.Device, int64, error) {
offset := (page - 1) * pageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
return s.deviceRepo.ListByUserID(ctx, userID, offset, pageSize)
}
// UpdateDeviceStatus 更新设备状态
func (s *DeviceService) UpdateDeviceStatus(ctx context.Context, deviceID int64, status domain.DeviceStatus) error {
return s.deviceRepo.UpdateStatus(ctx, deviceID, status)
}
// UpdateLastActiveTime 更新最后活跃时间
func (s *DeviceService) UpdateLastActiveTime(ctx context.Context, deviceID int64) error {
return s.deviceRepo.UpdateLastActiveTime(ctx, deviceID)
}
// GetActiveDevices 获取活跃设备
func (s *DeviceService) GetActiveDevices(ctx context.Context, page, pageSize int) ([]*domain.Device, int64, error) {
offset := (page - 1) * pageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
return s.deviceRepo.ListByStatus(ctx, domain.DeviceStatusActive, offset, pageSize)
}
// TrustDevice 设置设备为信任状态
func (s *DeviceService) TrustDevice(ctx context.Context, deviceID int64, trustDuration time.Duration) error {
device, err := s.deviceRepo.GetByID(ctx, deviceID)
if err != nil {
return errors.New("设备不存在")
}
var trustExpiresAt *time.Time
if trustDuration > 0 {
expiresAt := time.Now().Add(trustDuration)
trustExpiresAt = &expiresAt
}
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
}
// TrustDeviceByDeviceID 根据设备标识字符串设置设备为信任状态
func (s *DeviceService) TrustDeviceByDeviceID(ctx context.Context, userID int64, deviceID string, trustDuration time.Duration) error {
device, err := s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
if err != nil {
return errors.New("设备不存在")
}
var trustExpiresAt *time.Time
if trustDuration > 0 {
expiresAt := time.Now().Add(trustDuration)
trustExpiresAt = &expiresAt
}
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
}
// UntrustDevice 取消设备信任状态
func (s *DeviceService) UntrustDevice(ctx context.Context, deviceID int64) error {
device, err := s.deviceRepo.GetByID(ctx, deviceID)
if err != nil {
return errors.New("设备不存在")
}
return s.deviceRepo.UntrustDevice(ctx, device.ID)
}
// LogoutAllOtherDevices 登出所有其他设备
func (s *DeviceService) LogoutAllOtherDevices(ctx context.Context, userID int64, currentDeviceID int64) error {
return s.deviceRepo.DeleteAllByUserIDExcept(ctx, userID, currentDeviceID)
}
// GetTrustedDevices 获取用户的信任设备列表
func (s *DeviceService) GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
return s.deviceRepo.GetTrustedDevices(ctx, userID)
}
// GetAllDevicesRequest 获取所有设备请求参数
type GetAllDevicesRequest struct {
Page int
PageSize int
UserID int64 `form:"user_id"`
Status int `form:"status"`
IsTrusted *bool `form:"is_trusted"`
Keyword string `form:"keyword"`
}
// GetAllDevices 获取所有设备(管理员用)
func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesRequest) ([]*domain.Device, int64, error) {
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 20
}
if req.PageSize > 100 {
req.PageSize = 100
}
offset := (req.Page - 1) * req.PageSize
params := &repository.ListDevicesParams{
UserID: req.UserID,
Keyword: req.Keyword,
Offset: offset,
Limit: req.PageSize,
}
// 处理状态筛选
if req.Status >= 0 {
params.Status = domain.DeviceStatus(req.Status)
}
// 处理信任状态筛选
if req.IsTrusted != nil {
params.IsTrusted = req.IsTrusted
}
return s.deviceRepo.ListAll(ctx, params)
}
// GetDeviceByDeviceID 根据设备标识获取设备(用于设备信任检查)
func (s *DeviceService) GetDeviceByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
return s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
}

308
internal/service/email.go Normal file
View File

@@ -0,0 +1,308 @@
package service
import (
"context"
cryptorand "crypto/rand"
"encoding/hex"
"fmt"
"log"
"net/url"
"net/smtp"
"strings"
"time"
)
type EmailProvider interface {
SendMail(ctx context.Context, to, subject, htmlBody string) error
}
type SMTPEmailConfig struct {
Host string
Port int
Username string
Password string
FromEmail string
FromName string
TLS bool
}
type SMTPEmailProvider struct {
cfg SMTPEmailConfig
}
func NewSMTPEmailProvider(cfg SMTPEmailConfig) EmailProvider {
return &SMTPEmailProvider{cfg: cfg}
}
func (p *SMTPEmailProvider) SendMail(ctx context.Context, to, subject, htmlBody string) error {
_ = ctx
var authInfo smtp.Auth
if p.cfg.Username != "" || p.cfg.Password != "" {
authInfo = smtp.PlainAuth("", p.cfg.Username, p.cfg.Password, p.cfg.Host)
}
from := p.cfg.FromEmail
if p.cfg.FromName != "" {
from = fmt.Sprintf("%s <%s>", p.cfg.FromName, p.cfg.FromEmail)
}
headers := []string{
fmt.Sprintf("From: %s", from),
fmt.Sprintf("To: %s", to),
fmt.Sprintf("Subject: %s", subject),
"MIME-Version: 1.0",
"Content-Type: text/html; charset=UTF-8",
"",
}
message := strings.Join(headers, "\r\n") + htmlBody
addr := fmt.Sprintf("%s:%d", p.cfg.Host, p.cfg.Port)
return smtp.SendMail(addr, authInfo, p.cfg.FromEmail, []string{to}, []byte(message))
}
type MockEmailProvider struct{}
func (m *MockEmailProvider) SendMail(ctx context.Context, to, subject, htmlBody string) error {
_ = ctx
log.Printf("[email-mock] to=%s subject=%s body_bytes=%d", to, subject, len(htmlBody))
return nil
}
type EmailCodeConfig struct {
CodeTTL time.Duration
ResendCooldown time.Duration
MaxDailyLimit int
SiteURL string
SiteName string
}
func DefaultEmailCodeConfig() EmailCodeConfig {
return EmailCodeConfig{
CodeTTL: 5 * time.Minute,
ResendCooldown: time.Minute,
MaxDailyLimit: 10,
SiteURL: "http://localhost:8080",
SiteName: "User Management System",
}
}
type EmailCodeService struct {
provider EmailProvider
cache cacheInterface
cfg EmailCodeConfig
}
func NewEmailCodeService(provider EmailProvider, cache cacheInterface, cfg EmailCodeConfig) *EmailCodeService {
if cfg.CodeTTL <= 0 {
cfg.CodeTTL = 5 * time.Minute
}
if cfg.ResendCooldown <= 0 {
cfg.ResendCooldown = time.Minute
}
if cfg.MaxDailyLimit <= 0 {
cfg.MaxDailyLimit = 10
}
return &EmailCodeService{
provider: provider,
cache: cache,
cfg: cfg,
}
}
func (s *EmailCodeService) SendEmailCode(ctx context.Context, email, purpose string) error {
cooldownKey := fmt.Sprintf("email_cooldown:%s:%s", purpose, email)
if _, ok := s.cache.Get(ctx, cooldownKey); ok {
return newRateLimitError(fmt.Sprintf("\u64cd\u4f5c\u8fc7\u4e8e\u9891\u7e41\uff0c\u8bf7 %d \u79d2\u540e\u518d\u8bd5", int(s.cfg.ResendCooldown.Seconds())))
}
dailyKey := fmt.Sprintf("email_daily:%s:%s", email, time.Now().Format("2006-01-02"))
var dailyCount int
if value, ok := s.cache.Get(ctx, dailyKey); ok {
if count, ok := intValue(value); ok {
dailyCount = count
}
}
if dailyCount >= s.cfg.MaxDailyLimit {
return newRateLimitError("\u4eca\u65e5\u53d1\u9001\u6b21\u6570\u5df2\u8fbe\u4e0a\u9650\uff0c\u8bf7\u660e\u5929\u518d\u8bd5")
}
code, err := generateEmailCode()
if err != nil {
return err
}
codeKey := fmt.Sprintf("email_code:%s:%s", purpose, email)
if err := s.cache.Set(ctx, codeKey, code, s.cfg.CodeTTL, s.cfg.CodeTTL); err != nil {
return fmt.Errorf("store email code failed: %w", err)
}
if err := s.cache.Set(ctx, cooldownKey, true, s.cfg.ResendCooldown, s.cfg.ResendCooldown); err != nil {
_ = s.cache.Delete(ctx, codeKey)
return fmt.Errorf("store email cooldown failed: %w", err)
}
if err := s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour, 24*time.Hour); err != nil {
_ = s.cache.Delete(ctx, codeKey)
_ = s.cache.Delete(ctx, cooldownKey)
return fmt.Errorf("store email daily counter failed: %w", err)
}
subject, body := buildEmailCodeContent(purpose, code, s.cfg.SiteName, s.cfg.CodeTTL)
if err := s.provider.SendMail(ctx, email, subject, body); err != nil {
_ = s.cache.Delete(ctx, codeKey)
_ = s.cache.Delete(ctx, cooldownKey)
return fmt.Errorf("email delivery failed: %w", err)
}
return nil
}
func (s *EmailCodeService) VerifyEmailCode(ctx context.Context, email, purpose, code string) error {
if strings.TrimSpace(code) == "" {
return fmt.Errorf("verification code is required")
}
codeKey := fmt.Sprintf("email_code:%s:%s", purpose, email)
value, ok := s.cache.Get(ctx, codeKey)
if !ok {
return fmt.Errorf("verification code expired or missing")
}
storedCode, ok := value.(string)
if !ok || storedCode != code {
return fmt.Errorf("verification code is invalid")
}
if err := s.cache.Delete(ctx, codeKey); err != nil {
return fmt.Errorf("consume email code failed: %w", err)
}
return nil
}
type EmailActivationService struct {
provider EmailProvider
cache cacheInterface
tokenTTL time.Duration
siteURL string
siteName string
}
func NewEmailActivationService(provider EmailProvider, cache cacheInterface, siteURL, siteName string) *EmailActivationService {
return &EmailActivationService{
provider: provider,
cache: cache,
tokenTTL: 24 * time.Hour,
siteURL: siteURL,
siteName: siteName,
}
}
func (s *EmailActivationService) SendActivationEmail(ctx context.Context, userID int64, email, username string) error {
tokenBytes := make([]byte, 32)
if _, err := cryptorand.Read(tokenBytes); err != nil {
return fmt.Errorf("generate activation token failed: %w", err)
}
token := hex.EncodeToString(tokenBytes)
cacheKey := fmt.Sprintf("email_activation:%s", token)
if err := s.cache.Set(ctx, cacheKey, userID, s.tokenTTL, s.tokenTTL); err != nil {
return fmt.Errorf("store activation token failed: %w", err)
}
activationURL := buildFrontendActivationURL(s.siteURL, token)
subject := fmt.Sprintf("[%s] Activate Your Account", s.siteName)
body := buildActivationEmailBody(username, activationURL, s.siteName, s.tokenTTL)
return s.provider.SendMail(ctx, email, subject, body)
}
func buildFrontendActivationURL(siteURL, token string) string {
base := strings.TrimRight(strings.TrimSpace(siteURL), "/")
if base == "" {
base = DefaultEmailCodeConfig().SiteURL
}
return fmt.Sprintf("%s/activate-account?token=%s", base, url.QueryEscape(token))
}
func (s *EmailActivationService) ValidateActivationToken(ctx context.Context, token string) (int64, error) {
token = strings.TrimSpace(token)
if token == "" {
return 0, fmt.Errorf("activation token is required")
}
cacheKey := fmt.Sprintf("email_activation:%s", token)
value, ok := s.cache.Get(ctx, cacheKey)
if !ok {
return 0, fmt.Errorf("activation token expired or missing")
}
userID, ok := int64Value(value)
if !ok {
return 0, fmt.Errorf("activation token payload is invalid")
}
if err := s.cache.Delete(ctx, cacheKey); err != nil {
return 0, fmt.Errorf("consume activation token failed: %w", err)
}
return userID, nil
}
func buildEmailCodeContent(purpose, code, siteName string, ttl time.Duration) (subject, body string) {
purposeText := map[string]string{
"login": "login verification",
"register": "registration verification",
"reset": "password reset",
"bind": "binding verification",
}
label := purposeText[purpose]
if label == "" {
label = "identity verification"
}
subject = fmt.Sprintf("[%s] Your %s code: %s", siteName, label, code)
body = fmt.Sprintf(`<!DOCTYPE html>
<html>
<body style="font-family:Arial,sans-serif;max-width:600px;margin:0 auto;padding:20px;">
<h2 style="color:#333;">%s</h2>
<p>Your %s code is:</p>
<div style="background:#f5f5f5;padding:20px;text-align:center;margin:20px 0;border-radius:8px;">
<span style="font-size:36px;font-weight:bold;color:#2563eb;letter-spacing:8px;">%s</span>
</div>
<p>This code expires in <strong>%d minutes</strong>.</p>
<p style="color:#999;font-size:12px;">If you did not request this code, you can ignore this email.</p>
</body>
</html>`, siteName, label, code, int(ttl.Minutes()))
return subject, body
}
func buildActivationEmailBody(username, activationURL, siteName string, ttl time.Duration) string {
return fmt.Sprintf(`<!DOCTYPE html>
<html>
<body style="font-family:Arial,sans-serif;max-width:600px;margin:0 auto;padding:20px;">
<h2 style="color:#333;">Welcome to %s</h2>
<p>Hello <strong>%s</strong>,</p>
<p>Please click the button below to activate your account.</p>
<div style="text-align:center;margin:30px 0;">
<a href="%s"
style="background:#2563eb;color:#fff;padding:14px 32px;text-decoration:none;border-radius:8px;font-size:16px;font-weight:bold;">
Activate Account
</a>
</div>
<p>If the button does not work, copy this link into your browser:</p>
<p style="word-break:break-all;color:#2563eb;">%s</p>
<p>This link expires in <strong>%d hours</strong>.</p>
</body>
</html>`, siteName, username, activationURL, activationURL, int(ttl.Hours()))
}
func generateEmailCode() (string, error) {
buffer := make([]byte, 3)
if _, err := cryptorand.Read(buffer); err != nil {
return "", fmt.Errorf("generate email code failed: %w", err)
}
value := int(buffer[0])<<16 | int(buffer[1])<<8 | int(buffer[2])
value = value % 1000000
if value < 100000 {
value += 100000
}
return fmt.Sprintf("%06d", value), nil
}

534
internal/service/export.go Normal file
View File

@@ -0,0 +1,534 @@
package service
import (
"bytes"
"context"
"encoding/csv"
"fmt"
"strings"
"time"
"github.com/xuri/excelize/v2"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
const (
ExportFormatCSV = "csv"
ExportFormatXLSX = "xlsx"
)
// ExportUsersRequest defines the supported export filters and output options.
type ExportUsersRequest struct {
Format string
Fields []string
Keyword string
Status *int
}
type exportColumn struct {
Key string
Header string
Value func(*domain.User) string
}
var defaultExportColumns = []exportColumn{
{Key: "id", Header: "ID", Value: func(u *domain.User) string { return fmt.Sprintf("%d", u.ID) }},
{Key: "username", Header: "用户名", Value: func(u *domain.User) string { return u.Username }},
{Key: "email", Header: "邮箱", Value: func(u *domain.User) string { return domain.DerefStr(u.Email) }},
{Key: "phone", Header: "手机号", Value: func(u *domain.User) string { return domain.DerefStr(u.Phone) }},
{Key: "nickname", Header: "昵称", Value: func(u *domain.User) string { return u.Nickname }},
{Key: "avatar", Header: "头像", Value: func(u *domain.User) string { return u.Avatar }},
{Key: "gender", Header: "性别", Value: func(u *domain.User) string { return genderLabel(u.Gender) }},
{Key: "status", Header: "状态", Value: func(u *domain.User) string { return userStatusLabel(u.Status) }},
{Key: "region", Header: "地区", Value: func(u *domain.User) string { return u.Region }},
{Key: "bio", Header: "个人简介", Value: func(u *domain.User) string { return u.Bio }},
{Key: "totp_enabled", Header: "TOTP已启用", Value: func(u *domain.User) string { return boolLabel(u.TOTPEnabled) }},
{Key: "last_login_time", Header: "最后登录时间", Value: func(u *domain.User) string { return timeLabel(u.LastLoginTime) }},
{Key: "last_login_ip", Header: "最后登录IP", Value: func(u *domain.User) string { return u.LastLoginIP }},
{Key: "created_at", Header: "注册时间", Value: func(u *domain.User) string { return u.CreatedAt.Format("2006-01-02 15:04:05") }},
}
// ExportService 用户数据导入导出服务
type ExportService struct {
userRepo *repository.UserRepository
roleRepo *repository.RoleRepository
}
// NewExportService 创建导入导出服务
func NewExportService(
userRepo *repository.UserRepository,
roleRepo *repository.RoleRepository,
) *ExportService {
return &ExportService{
userRepo: userRepo,
roleRepo: roleRepo,
}
}
// ExportUsers exports users as CSV or XLSX.
func (s *ExportService) ExportUsers(ctx context.Context, req *ExportUsersRequest) ([]byte, string, string, error) {
if req == nil {
req = &ExportUsersRequest{}
}
format, err := normalizeExportFormat(req.Format)
if err != nil {
return nil, "", "", err
}
columns, err := resolveExportColumns(req.Fields)
if err != nil {
return nil, "", "", err
}
users, err := s.listUsersForExport(ctx, req)
if err != nil {
return nil, "", "", err
}
filename := fmt.Sprintf("users_%s.%s", time.Now().Format("20060102_150405"), format)
switch format {
case ExportFormatCSV:
data, err := buildCSVExport(columns, users)
if err != nil {
return nil, "", "", err
}
return data, filename, "text/csv; charset=utf-8", nil
case ExportFormatXLSX:
data, err := buildXLSXExport(columns, users)
if err != nil {
return nil, "", "", err
}
return data, filename, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
default:
return nil, "", "", fmt.Errorf("不支持的导出格式: %s", req.Format)
}
}
// ExportUsersCSV keeps backward compatibility for callers that still expect CSV-only export.
func (s *ExportService) ExportUsersCSV(ctx context.Context) ([]byte, string, error) {
data, filename, _, err := s.ExportUsers(ctx, &ExportUsersRequest{Format: ExportFormatCSV})
return data, filename, err
}
// ExportUsersXLSX exports users as Excel.
func (s *ExportService) ExportUsersXLSX(ctx context.Context) ([]byte, string, error) {
data, filename, _, err := s.ExportUsers(ctx, &ExportUsersRequest{Format: ExportFormatXLSX})
return data, filename, err
}
func (s *ExportService) listUsersForExport(ctx context.Context, req *ExportUsersRequest) ([]*domain.User, error) {
var allUsers []*domain.User
offset := 0
batchSize := 500
for {
var (
users []*domain.User
total int64
err error
)
if req.Keyword != "" || req.Status != nil {
filter := &repository.AdvancedFilter{
Keyword: req.Keyword,
Status: -1,
SortBy: "created_at",
SortOrder: "desc",
Offset: offset,
Limit: batchSize,
}
if req.Status != nil {
filter.Status = *req.Status
}
users, total, err = s.userRepo.AdvancedSearch(ctx, filter)
if err != nil {
return nil, fmt.Errorf("查询用户失败: %w", err)
}
allUsers = append(allUsers, users...)
offset += len(users)
if offset >= int(total) || len(users) == 0 {
break
}
continue
}
users, _, err = s.userRepo.List(ctx, offset, batchSize)
if err != nil {
return nil, fmt.Errorf("查询用户失败: %w", err)
}
allUsers = append(allUsers, users...)
if len(users) < batchSize {
break
}
offset += batchSize
}
return allUsers, nil
}
// ImportUsers imports users from CSV or XLSX.
func (s *ExportService) ImportUsers(ctx context.Context, data []byte, format string) (successCount, failCount int, errs []string) {
normalized, err := normalizeExportFormat(format)
if err != nil {
return 0, 0, []string{err.Error()}
}
var records [][]string
switch normalized {
case ExportFormatCSV:
records, err = parseCSVRecords(data)
case ExportFormatXLSX:
records, err = parseXLSXRecords(data)
default:
err = fmt.Errorf("不支持的导入格式: %s", format)
}
if err != nil {
return 0, 0, []string{err.Error()}
}
return s.importUsersRecords(ctx, records)
}
// ImportUsersCSV keeps backward compatibility for callers that still upload CSV.
func (s *ExportService) ImportUsersCSV(ctx context.Context, data []byte) (successCount, failCount int, errs []string) {
return s.ImportUsers(ctx, data, ExportFormatCSV)
}
// ImportUsersXLSX imports users from Excel.
func (s *ExportService) ImportUsersXLSX(ctx context.Context, data []byte) (successCount, failCount int, errs []string) {
return s.ImportUsers(ctx, data, ExportFormatXLSX)
}
func (s *ExportService) importUsersRecords(ctx context.Context, records [][]string) (successCount, failCount int, errs []string) {
if len(records) < 2 {
return 0, 0, []string{"导入文件为空或没有数据行"}
}
headers := records[0]
colIdx := buildColIndex(headers)
getCol := func(row []string, name string) string {
idx, ok := colIdx[name]
if !ok || idx >= len(row) {
return ""
}
return strings.TrimSpace(row[idx])
}
for i, row := range records[1:] {
lineNum := i + 2
username := getCol(row, "用户名")
password := getCol(row, "密码")
if username == "" || password == "" {
failCount++
errs = append(errs, fmt.Sprintf("第%d行用户名和密码不能为空", lineNum))
continue
}
exists, err := s.userRepo.ExistsByUsername(ctx, username)
if err != nil {
failCount++
errs = append(errs, fmt.Sprintf("第%d行检查用户名失败: %v", lineNum, err))
continue
}
if exists {
failCount++
errs = append(errs, fmt.Sprintf("第%d行用户名 '%s' 已存在", lineNum, username))
continue
}
hashedPwd, err := hashPassword(password)
if err != nil {
failCount++
errs = append(errs, fmt.Sprintf("第%d行密码加密失败: %v", lineNum, err))
continue
}
user := &domain.User{
Username: username,
Email: domain.StrPtr(getCol(row, "邮箱")),
Phone: domain.StrPtr(getCol(row, "手机号")),
Nickname: getCol(row, "昵称"),
Password: hashedPwd,
Region: getCol(row, "地区"),
Bio: getCol(row, "个人简介"),
Status: domain.UserStatusActive,
}
if err := s.userRepo.Create(ctx, user); err != nil {
failCount++
errs = append(errs, fmt.Sprintf("第%d行创建用户失败: %v", lineNum, err))
continue
}
successCount++
}
return successCount, failCount, errs
}
// GetImportTemplate keeps backward compatibility for callers that still expect CSV templates.
func (s *ExportService) GetImportTemplate() ([]byte, string) {
data, filename, _, _ := s.GetImportTemplateByFormat(ExportFormatCSV)
return data, filename
}
// GetImportTemplateByFormat returns a CSV or XLSX template for imports.
func (s *ExportService) GetImportTemplateByFormat(format string) ([]byte, string, string, error) {
normalized, err := normalizeExportFormat(format)
if err != nil {
return nil, "", "", err
}
headers := []string{"用户名", "密码", "邮箱", "手机号", "昵称", "性别", "地区", "个人简介"}
rows := [][]string{{
"john_doe", "Password123!", "john@example.com", "13800138000",
"约翰", "男", "北京", "这是个人简介",
}}
switch normalized {
case ExportFormatCSV:
data, err := buildCSVRecords(headers, rows)
if err != nil {
return nil, "", "", err
}
return data, "user_import_template.csv", "text/csv; charset=utf-8", nil
case ExportFormatXLSX:
data, err := buildXLSXRecords(headers, rows)
if err != nil {
return nil, "", "", err
}
return data, "user_import_template.xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
default:
return nil, "", "", fmt.Errorf("不支持的模板格式: %s", format)
}
}
func normalizeExportFormat(format string) (string, error) {
normalized := strings.ToLower(strings.TrimSpace(format))
if normalized == "" {
normalized = ExportFormatCSV
}
switch normalized {
case ExportFormatCSV, ExportFormatXLSX:
return normalized, nil
default:
return "", fmt.Errorf("不支持的格式: %s", format)
}
}
func resolveExportColumns(fields []string) ([]exportColumn, error) {
if len(fields) == 0 {
return defaultExportColumns, nil
}
columnMap := make(map[string]exportColumn, len(defaultExportColumns))
for _, col := range defaultExportColumns {
columnMap[col.Key] = col
}
selected := make([]exportColumn, 0, len(fields))
seen := make(map[string]struct{}, len(fields))
for _, field := range fields {
key := strings.ToLower(strings.TrimSpace(field))
if key == "" {
continue
}
if _, ok := seen[key]; ok {
continue
}
col, ok := columnMap[key]
if !ok {
return nil, fmt.Errorf("不支持的导出字段: %s", field)
}
selected = append(selected, col)
seen[key] = struct{}{}
}
if len(selected) == 0 {
return defaultExportColumns, nil
}
return selected, nil
}
func buildCSVExport(columns []exportColumn, users []*domain.User) ([]byte, error) {
headers := make([]string, 0, len(columns))
rows := make([][]string, 0, len(users))
for _, col := range columns {
headers = append(headers, col.Header)
}
for _, u := range users {
row := make([]string, 0, len(columns))
for _, col := range columns {
row = append(row, col.Value(u))
}
rows = append(rows, row)
}
return buildCSVRecords(headers, rows)
}
func buildCSVRecords(headers []string, rows [][]string) ([]byte, error) {
var buf bytes.Buffer
buf.Write([]byte{0xEF, 0xBB, 0xBF})
writer := csv.NewWriter(&buf)
if err := writer.Write(headers); err != nil {
return nil, fmt.Errorf("写CSV表头失败: %w", err)
}
for _, row := range rows {
if err := writer.Write(row); err != nil {
return nil, fmt.Errorf("写CSV行失败: %w", err)
}
}
writer.Flush()
if err := writer.Error(); err != nil {
return nil, fmt.Errorf("CSV Flush 失败: %w", err)
}
return buf.Bytes(), nil
}
func buildXLSXExport(columns []exportColumn, users []*domain.User) ([]byte, error) {
headers := make([]string, 0, len(columns))
rows := make([][]string, 0, len(users))
for _, col := range columns {
headers = append(headers, col.Header)
}
for _, u := range users {
row := make([]string, 0, len(columns))
for _, col := range columns {
row = append(row, col.Value(u))
}
rows = append(rows, row)
}
return buildXLSXRecords(headers, rows)
}
func buildXLSXRecords(headers []string, rows [][]string) ([]byte, error) {
file := excelize.NewFile()
defer file.Close()
sheet := file.GetSheetName(file.GetActiveSheetIndex())
if sheet == "" {
sheet = "Sheet1"
}
for idx, header := range headers {
cell, err := excelize.CoordinatesToCellName(idx+1, 1)
if err != nil {
return nil, fmt.Errorf("生成表头单元格失败: %w", err)
}
if err := file.SetCellValue(sheet, cell, header); err != nil {
return nil, fmt.Errorf("写入表头失败: %w", err)
}
}
for rowIdx, row := range rows {
for colIdx, value := range row {
cell, err := excelize.CoordinatesToCellName(colIdx+1, rowIdx+2)
if err != nil {
return nil, fmt.Errorf("生成数据单元格失败: %w", err)
}
if err := file.SetCellValue(sheet, cell, value); err != nil {
return nil, fmt.Errorf("写入单元格失败: %w", err)
}
}
}
var buf bytes.Buffer
if _, err := file.WriteTo(&buf); err != nil {
return nil, fmt.Errorf("生成Excel失败: %w", err)
}
return buf.Bytes(), nil
}
func parseCSVRecords(data []byte) ([][]string, error) {
if len(data) >= 3 && data[0] == 0xEF && data[1] == 0xBB && data[2] == 0xBF {
data = data[3:]
}
reader := csv.NewReader(bytes.NewReader(data))
records, err := reader.ReadAll()
if err != nil {
return nil, fmt.Errorf("CSV 解析失败: %w", err)
}
return records, nil
}
func parseXLSXRecords(data []byte) ([][]string, error) {
file, err := excelize.OpenReader(bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("Excel 解析失败: %w", err)
}
defer file.Close()
sheets := file.GetSheetList()
if len(sheets) == 0 {
return nil, fmt.Errorf("Excel 文件没有可用工作表")
}
rows, err := file.GetRows(sheets[0])
if err != nil {
return nil, fmt.Errorf("读取Excel行失败: %w", err)
}
return rows, nil
}
// ---- 辅助函数 ----
func genderLabel(g domain.Gender) string {
switch g {
case domain.GenderMale:
return "男"
case domain.GenderFemale:
return "女"
default:
return "未知"
}
}
func userStatusLabel(s domain.UserStatus) string {
switch s {
case domain.UserStatusActive:
return "已激活"
case domain.UserStatusInactive:
return "未激活"
case domain.UserStatusLocked:
return "已锁定"
case domain.UserStatusDisabled:
return "已禁用"
default:
return "未知"
}
}
func boolLabel(b bool) string {
if b {
return "是"
}
return "否"
}
func timeLabel(t *time.Time) string {
if t == nil {
return ""
}
return t.Format("2006-01-02 15:04:05")
}
// buildColIndex 将表头列名映射到列索引
func buildColIndex(headers []string) map[string]int {
idx := make(map[string]int, len(headers))
for i, h := range headers {
idx[h] = i
}
return idx
}
// hashPassword hashes imported passwords with the primary runtime algorithm.
func hashPassword(password string) (string, error) {
return auth.HashPassword(password)
}

View File

@@ -0,0 +1,157 @@
package service
import (
"net/http"
"strings"
)
// headerWireCasing 定义每个白名单 header 在真实 Claude CLI 抓包中的准确大小写。
// Go 的 HTTP server 解析请求时会将所有 header key 转为 Canonical 形式(如 x-app → X-App
// 此 map 用于在转发时恢复到真实的 wire format。
//
// 来源:对真实 Claude CLI (claude-cli/2.1.81) 到 api.anthropic.com 的 HTTPS 流量抓包。
var headerWireCasing = map[string]string{
// Title case
"accept": "Accept",
"user-agent": "User-Agent",
// X-Stainless-* 保持 SDK 原始大小写
"x-stainless-retry-count": "X-Stainless-Retry-Count",
"x-stainless-timeout": "X-Stainless-Timeout",
"x-stainless-lang": "X-Stainless-Lang",
"x-stainless-package-version": "X-Stainless-Package-Version",
"x-stainless-os": "X-Stainless-OS",
"x-stainless-arch": "X-Stainless-Arch",
"x-stainless-runtime": "X-Stainless-Runtime",
"x-stainless-runtime-version": "X-Stainless-Runtime-Version",
"x-stainless-helper-method": "x-stainless-helper-method",
// Anthropic SDK 自身设置的 header全小写
"anthropic-dangerous-direct-browser-access": "anthropic-dangerous-direct-browser-access",
"anthropic-version": "anthropic-version",
"anthropic-beta": "anthropic-beta",
"x-app": "x-app",
"content-type": "content-type",
"accept-language": "accept-language",
"sec-fetch-mode": "sec-fetch-mode",
"accept-encoding": "accept-encoding",
"authorization": "authorization",
}
// headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。
// 用于 debug log 按此顺序输出,便于与抓包结果直接对比。
var headerWireOrder = []string{
"Accept",
"X-Stainless-Retry-Count",
"X-Stainless-Timeout",
"X-Stainless-Lang",
"X-Stainless-Package-Version",
"X-Stainless-OS",
"X-Stainless-Arch",
"X-Stainless-Runtime",
"X-Stainless-Runtime-Version",
"anthropic-dangerous-direct-browser-access",
"anthropic-version",
"authorization",
"x-app",
"User-Agent",
"content-type",
"anthropic-beta",
"accept-language",
"sec-fetch-mode",
"accept-encoding",
"x-stainless-helper-method",
}
// headerWireOrderSet 用于快速判断某个 key 是否在 headerWireOrder 中(按 lowercase 匹配)。
var headerWireOrderSet map[string]struct{}
func init() {
headerWireOrderSet = make(map[string]struct{}, len(headerWireOrder))
for _, k := range headerWireOrder {
headerWireOrderSet[strings.ToLower(k)] = struct{}{}
}
}
// resolveWireCasing 将 Go canonical key如 X-Stainless-Os映射为真实 wire casing如 X-Stainless-OS
// 如果 map 中没有对应条目,返回原始 key 不变。
func resolveWireCasing(key string) string {
if wk, ok := headerWireCasing[strings.ToLower(key)]; ok {
return wk
}
return key
}
// setHeaderRaw sets a header bypassing Go's canonical-case normalization.
// The key is stored exactly as provided, preserving original casing.
//
// It first removes any existing value under the canonical key, the wire casing key,
// and the exact raw key, preventing duplicates from any source.
func setHeaderRaw(h http.Header, key, value string) {
h.Del(key) // remove canonical form (e.g. "Anthropic-Beta")
if wk := resolveWireCasing(key); wk != key {
delete(h, wk) // remove wire casing form if different
}
delete(h, key) // remove exact raw key if it differs from canonical
h[key] = []string{value}
}
// addHeaderRaw appends a header value bypassing Go's canonical-case normalization.
func addHeaderRaw(h http.Header, key, value string) {
h[key] = append(h[key], value)
}
// getHeaderRaw reads a header value, trying multiple key forms to handle the mismatch
// between Go canonical keys, wire casing keys, and raw keys:
// 1. exact key as provided
// 2. wire casing form (from headerWireCasing)
// 3. Go canonical form (via http.Header.Get)
func getHeaderRaw(h http.Header, key string) string {
// 1. exact key
if vals := h[key]; len(vals) > 0 {
return vals[0]
}
// 2. wire casing (e.g. looking up "Anthropic-Dangerous-Direct-Browser-Access" finds "anthropic-dangerous-direct-browser-access")
if wk := resolveWireCasing(key); wk != key {
if vals := h[wk]; len(vals) > 0 {
return vals[0]
}
}
// 3. canonical fallback
return h.Get(key)
}
// sortHeadersByWireOrder 按照真实 Claude CLI 的 header 顺序返回排序后的 key 列表。
// 在 headerWireOrder 中定义的 key 按其顺序排列,未定义的 key 追加到末尾。
func sortHeadersByWireOrder(h http.Header) []string {
// 构建 lowercase -> actual map key 的映射
present := make(map[string]string, len(h))
for k := range h {
present[strings.ToLower(k)] = k
}
result := make([]string, 0, len(h))
seen := make(map[string]struct{}, len(h))
// 先按 wire order 输出
for _, wk := range headerWireOrder {
lk := strings.ToLower(wk)
if actual, ok := present[lk]; ok {
if _, dup := seen[lk]; !dup {
result = append(result, actual)
seen[lk] = struct{}{}
}
}
}
// 再追加不在 wire order 中的 header
for k := range h {
lk := strings.ToLower(k)
if _, ok := seen[lk]; !ok {
result = append(result, k)
seen[lk] = struct{}{}
}
}
return result
}

View File

@@ -0,0 +1,257 @@
package service
import (
"bytes"
"context"
"encoding/csv"
"fmt"
"time"
"github.com/xuri/excelize/v2"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// LoginLogService 登录日志服务
type LoginLogService struct {
loginLogRepo *repository.LoginLogRepository
}
// NewLoginLogService 创建登录日志服务
func NewLoginLogService(loginLogRepo *repository.LoginLogRepository) *LoginLogService {
return &LoginLogService{loginLogRepo: loginLogRepo}
}
// RecordLogin 记录登录日志
func (s *LoginLogService) RecordLogin(ctx context.Context, req *RecordLoginRequest) error {
log := &domain.LoginLog{
LoginType: req.LoginType,
DeviceID: req.DeviceID,
IP: req.IP,
Location: req.Location,
Status: req.Status,
FailReason: req.FailReason,
}
if req.UserID != 0 {
log.UserID = &req.UserID
}
return s.loginLogRepo.Create(ctx, log)
}
// RecordLoginRequest 记录登录请求
type RecordLoginRequest struct {
UserID int64 `json:"user_id"`
LoginType int `json:"login_type"` // 1-用户名, 2-邮箱, 3-手机
DeviceID string `json:"device_id"`
IP string `json:"ip"`
Location string `json:"location"`
Status int `json:"status"` // 0-失败, 1-成功
FailReason string `json:"fail_reason"`
}
// ListLoginLogRequest 登录日志列表请求
type ListLoginLogRequest struct {
UserID int64 `json:"user_id"`
Status int `json:"status"`
Page int `json:"page"`
PageSize int `json:"page_size"`
StartAt string `json:"start_at"`
EndAt string `json:"end_at"`
}
// GetLoginLogs 获取登录日志列表
func (s *LoginLogService) GetLoginLogs(ctx context.Context, req *ListLoginLogRequest) ([]*domain.LoginLog, int64, error) {
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 20
}
offset := (req.Page - 1) * req.PageSize
// 按用户 ID 查询
if req.UserID > 0 {
return s.loginLogRepo.ListByUserID(ctx, req.UserID, offset, req.PageSize)
}
// 按时间范围查询
if req.StartAt != "" && req.EndAt != "" {
start, err1 := time.Parse(time.RFC3339, req.StartAt)
end, err2 := time.Parse(time.RFC3339, req.EndAt)
if err1 == nil && err2 == nil {
return s.loginLogRepo.ListByTimeRange(ctx, start, end, offset, req.PageSize)
}
}
// 按状态查询
if req.Status == 0 || req.Status == 1 {
return s.loginLogRepo.ListByStatus(ctx, req.Status, offset, req.PageSize)
}
return s.loginLogRepo.List(ctx, offset, req.PageSize)
}
// GetMyLoginLogs 获取当前用户的登录日志
func (s *LoginLogService) GetMyLoginLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.LoginLog, int64, error) {
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
offset := (page - 1) * pageSize
return s.loginLogRepo.ListByUserID(ctx, userID, offset, pageSize)
}
// CleanupOldLogs 清理旧日志(保留最近 N 天)
func (s *LoginLogService) CleanupOldLogs(ctx context.Context, retentionDays int) error {
return s.loginLogRepo.DeleteOlderThan(ctx, retentionDays)
}
// ExportLoginLogRequest 导出登录日志请求
type ExportLoginLogRequest struct {
UserID int64 `form:"user_id"`
Status int `form:"status"`
Format string `form:"format"`
StartAt string `form:"start_at"`
EndAt string `form:"end_at"`
}
// ExportLoginLogs 导出登录日志
func (s *LoginLogService) ExportLoginLogs(ctx context.Context, req *ExportLoginLogRequest) ([]byte, string, string, error) {
format := "csv"
if req.Format == "xlsx" {
format = "xlsx"
}
var startAt, endAt *time.Time
if req.StartAt != "" {
if t, err := time.Parse(time.RFC3339, req.StartAt); err == nil {
startAt = &t
}
}
if req.EndAt != "" {
if t, err := time.Parse(time.RFC3339, req.EndAt); err == nil {
endAt = &t
}
}
logs, err := s.loginLogRepo.ListAllForExport(ctx, req.UserID, req.Status, startAt, endAt)
if err != nil {
return nil, "", "", fmt.Errorf("查询登录日志失败: %w", err)
}
filename := fmt.Sprintf("login_logs_%s.%s", time.Now().Format("20060102_150405"), format)
if format == "xlsx" {
data, err := buildLoginLogXLSXExport(logs)
if err != nil {
return nil, "", "", err
}
return data, filename, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
}
data, err := buildLoginLogCSVExport(logs)
if err != nil {
return nil, "", "", err
}
return data, filename, "text/csv; charset=utf-8", nil
}
func buildLoginLogCSVExport(logs []*domain.LoginLog) ([]byte, error) {
headers := []string{"ID", "用户ID", "登录方式", "设备ID", "IP地址", "位置", "状态", "失败原因", "时间"}
rows := make([][]string, 0, len(logs)+1)
rows = append(rows, headers)
for _, log := range logs {
rows = append(rows, []string{
fmt.Sprintf("%d", log.ID),
fmt.Sprintf("%d", derefInt64(log.UserID)),
loginTypeLabel(log.LoginType),
log.DeviceID,
log.IP,
log.Location,
loginStatusLabel(log.Status),
log.FailReason,
log.CreatedAt.Format("2006-01-02 15:04:05"),
})
}
var buf bytes.Buffer
buf.Write([]byte{0xEF, 0xBB, 0xBF})
writer := csv.NewWriter(&buf)
if err := writer.WriteAll(rows); err != nil {
return nil, fmt.Errorf("写CSV失败: %w", err)
}
return buf.Bytes(), nil
}
func buildLoginLogXLSXExport(logs []*domain.LoginLog) ([]byte, error) {
file := excelize.NewFile()
defer file.Close()
sheet := file.GetSheetName(file.GetActiveSheetIndex())
if sheet == "" {
sheet = "Sheet1"
}
headers := []string{"ID", "用户ID", "登录方式", "设备ID", "IP地址", "位置", "状态", "失败原因", "时间"}
for idx, header := range headers {
cell, _ := excelize.CoordinatesToCellName(idx+1, 1)
_ = file.SetCellValue(sheet, cell, header)
}
for rowIdx, log := range logs {
row := []string{
fmt.Sprintf("%d", log.ID),
fmt.Sprintf("%d", derefInt64(log.UserID)),
loginTypeLabel(log.LoginType),
log.DeviceID,
log.IP,
log.Location,
loginStatusLabel(log.Status),
log.FailReason,
log.CreatedAt.Format("2006-01-02 15:04:05"),
}
for colIdx, value := range row {
cell, _ := excelize.CoordinatesToCellName(colIdx+1, rowIdx+2)
_ = file.SetCellValue(sheet, cell, value)
}
}
var buf bytes.Buffer
if _, err := file.WriteTo(&buf); err != nil {
return nil, fmt.Errorf("生成Excel失败: %w", err)
}
return buf.Bytes(), nil
}
func loginTypeLabel(t int) string {
switch t {
case 1:
return "密码登录"
case 2:
return "邮箱验证码"
case 3:
return "手机验证码"
case 4:
return "OAuth"
default:
return "未知"
}
}
func loginStatusLabel(s int) string {
if s == 1 {
return "成功"
}
return "失败"
}
func derefInt64(v *int64) int64 {
if v == nil {
return 0
}
return *v
}

View File

@@ -0,0 +1,115 @@
package service
import (
"context"
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// OperationLogService 操作日志服务
type OperationLogService struct {
operationLogRepo *repository.OperationLogRepository
}
// NewOperationLogService 创建操作日志服务
func NewOperationLogService(operationLogRepo *repository.OperationLogRepository) *OperationLogService {
return &OperationLogService{operationLogRepo: operationLogRepo}
}
// RecordOperation 记录操作日志
func (s *OperationLogService) RecordOperation(ctx context.Context, req *RecordOperationRequest) error {
log := &domain.OperationLog{
OperationType: req.OperationType,
OperationName: req.OperationName,
RequestMethod: req.RequestMethod,
RequestPath: req.RequestPath,
RequestParams: req.RequestParams,
ResponseStatus: req.ResponseStatus,
IP: req.IP,
UserAgent: req.UserAgent,
}
if req.UserID != 0 {
log.UserID = &req.UserID
}
return s.operationLogRepo.Create(ctx, log)
}
// RecordOperationRequest 记录操作请求
type RecordOperationRequest struct {
UserID int64 `json:"user_id"`
OperationType string `json:"operation_type"`
OperationName string `json:"operation_name"`
RequestMethod string `json:"request_method"`
RequestPath string `json:"request_path"`
RequestParams string `json:"request_params"`
ResponseStatus int `json:"response_status"`
IP string `json:"ip"`
UserAgent string `json:"user_agent"`
}
// ListOperationLogRequest 操作日志列表请求
type ListOperationLogRequest struct {
UserID int64 `json:"user_id"`
Method string `json:"method"`
Keyword string `json:"keyword"`
Page int `json:"page"`
PageSize int `json:"page_size"`
StartAt string `json:"start_at"`
EndAt string `json:"end_at"`
}
// GetOperationLogs 获取操作日志列表
func (s *OperationLogService) GetOperationLogs(ctx context.Context, req *ListOperationLogRequest) ([]*domain.OperationLog, int64, error) {
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 20
}
offset := (req.Page - 1) * req.PageSize
// 按关键词搜索
if req.Keyword != "" {
return s.operationLogRepo.Search(ctx, req.Keyword, offset, req.PageSize)
}
// 按用户 ID 查询
if req.UserID > 0 {
return s.operationLogRepo.ListByUserID(ctx, req.UserID, offset, req.PageSize)
}
// 按 HTTP 方法查询
if req.Method != "" {
return s.operationLogRepo.ListByMethod(ctx, req.Method, offset, req.PageSize)
}
// 按时间范围查询
if req.StartAt != "" && req.EndAt != "" {
start, err1 := time.Parse(time.RFC3339, req.StartAt)
end, err2 := time.Parse(time.RFC3339, req.EndAt)
if err1 == nil && err2 == nil {
return s.operationLogRepo.ListByTimeRange(ctx, start, end, offset, req.PageSize)
}
}
return s.operationLogRepo.List(ctx, offset, req.PageSize)
}
// GetMyOperationLogs 获取当前用户的操作日志
func (s *OperationLogService) GetMyOperationLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.OperationLog, int64, error) {
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
offset := (page - 1) * pageSize
return s.operationLogRepo.ListByUserID(ctx, userID, offset, pageSize)
}
// CleanupOldLogs 清理旧日志(保留最近 N 天)
func (s *OperationLogService) CleanupOldLogs(ctx context.Context, retentionDays int) error {
return s.operationLogRepo.DeleteOlderThan(ctx, retentionDays)
}

View File

@@ -0,0 +1,272 @@
package service
import (
"context"
cryptorand "crypto/rand"
"encoding/hex"
"errors"
"fmt"
"log"
"net/smtp"
"time"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/security"
)
// PasswordResetConfig controls reset-token issuance and SMTP delivery.
type PasswordResetConfig struct {
TokenTTL time.Duration
SMTPHost string
SMTPPort int
SMTPUser string
SMTPPass string
FromEmail string
SiteURL string
PasswordMinLen int
PasswordRequireSpecial bool
PasswordRequireNumber bool
}
func DefaultPasswordResetConfig() *PasswordResetConfig {
return &PasswordResetConfig{
TokenTTL: 15 * time.Minute,
SMTPHost: "",
SMTPPort: 587,
SMTPUser: "",
SMTPPass: "",
FromEmail: "noreply@example.com",
SiteURL: "http://localhost:8080",
PasswordMinLen: 8,
PasswordRequireSpecial: false,
PasswordRequireNumber: false,
}
}
type PasswordResetService struct {
userRepo userRepositoryInterface
cache *cache.CacheManager
config *PasswordResetConfig
}
func NewPasswordResetService(
userRepo userRepositoryInterface,
cache *cache.CacheManager,
config *PasswordResetConfig,
) *PasswordResetService {
if config == nil {
config = DefaultPasswordResetConfig()
}
return &PasswordResetService{
userRepo: userRepo,
cache: cache,
config: config,
}
}
func (s *PasswordResetService) ForgotPassword(ctx context.Context, email string) error {
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
return nil
}
tokenBytes := make([]byte, 32)
if _, err := cryptorand.Read(tokenBytes); err != nil {
return fmt.Errorf("生成重置Token失败: %w", err)
}
resetToken := hex.EncodeToString(tokenBytes)
cacheKey := "pwd_reset:" + resetToken
ttl := s.config.TokenTTL
if err := s.cache.Set(ctx, cacheKey, user.ID, ttl, ttl); err != nil {
return fmt.Errorf("缓存重置Token失败: %w", err)
}
go s.sendResetEmail(domain.DerefStr(user.Email), user.Username, resetToken)
return nil
}
func (s *PasswordResetService) ResetPassword(ctx context.Context, token, newPassword string) error {
if token == "" || newPassword == "" {
return errors.New("参数不完整")
}
cacheKey := "pwd_reset:" + token
val, ok := s.cache.Get(ctx, cacheKey)
if !ok {
return errors.New("重置链接已失效或不存在,请重新申请")
}
userID, ok := int64Value(val)
if !ok {
return errors.New("重置Token数据异常")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return errors.New("用户不存在")
}
if err := s.doResetPassword(ctx, user, newPassword); err != nil {
return err
}
if err := s.cache.Delete(ctx, cacheKey); err != nil {
return fmt.Errorf("清理重置Token失败: %w", err)
}
return nil
}
func (s *PasswordResetService) ValidateResetToken(ctx context.Context, token string) (bool, error) {
if token == "" {
return false, errors.New("token不能为空")
}
_, ok := s.cache.Get(ctx, "pwd_reset:"+token)
return ok, nil
}
func (s *PasswordResetService) sendResetEmail(email, username, token string) {
if s.config.SMTPHost == "" {
return
}
resetURL := fmt.Sprintf("%s/reset-password?token=%s", s.config.SiteURL, token)
subject := "密码重置请求"
body := fmt.Sprintf(`您好 %s
您收到此邮件,是因为有人请求重置账户密码。
请点击以下链接重置密码(链接将在 %s 后失效):
%s
如果不是您本人操作,请忽略此邮件,您的密码不会被修改。
用户管理系统团队`, username, s.config.TokenTTL.String(), resetURL)
var authInfo smtp.Auth
if s.config.SMTPUser != "" || s.config.SMTPPass != "" {
authInfo = smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, s.config.SMTPHost)
}
msg := fmt.Sprintf(
"From: %s\r\nTo: %s\r\nSubject: %s\r\nContent-Type: text/plain; charset=UTF-8\r\n\r\n%s",
s.config.FromEmail,
email,
subject,
body,
)
addr := fmt.Sprintf("%s:%d", s.config.SMTPHost, s.config.SMTPPort)
if err := smtp.SendMail(addr, authInfo, s.config.FromEmail, []string{email}, []byte(msg)); err != nil {
log.Printf("password-reset-email: send failed to=%s err=%v", email, err)
}
}
// ForgotPasswordByPhoneRequest 短信密码重置请求
type ForgotPasswordByPhoneRequest struct {
Phone string `json:"phone" binding:"required"`
}
// ForgotPasswordByPhone 通过手机验证码重置密码 - 发送验证码
func (s *PasswordResetService) ForgotPasswordByPhone(ctx context.Context, phone string) (string, error) {
user, err := s.userRepo.GetByPhone(ctx, phone)
if err != nil {
return "", nil // 用户不存在不提示,防止用户枚举
}
// 生成6位数字验证码
code, err := generateSMSCode()
if err != nil {
return "", fmt.Errorf("生成验证码失败: %w", err)
}
// 存储验证码关联用户ID
cacheKey := fmt.Sprintf("pwd_reset_sms:%s", phone)
ttl := s.config.TokenTTL
if err := s.cache.Set(ctx, cacheKey, user.ID, ttl, ttl); err != nil {
return "", fmt.Errorf("缓存验证码失败: %w", err)
}
// 存储验证码到另一个key用于后续校验
codeKey := fmt.Sprintf("pwd_reset_sms_code:%s", phone)
if err := s.cache.Set(ctx, codeKey, code, ttl, ttl); err != nil {
return "", fmt.Errorf("缓存验证码失败: %w", err)
}
return code, nil
}
// ResetPasswordByPhoneRequest 通过手机验证码重置密码请求
type ResetPasswordByPhoneRequest struct {
Phone string `json:"phone" binding:"required"`
Code string `json:"code" binding:"required"`
NewPassword string `json:"new_password" binding:"required"`
}
// ResetPasswordByPhone 通过手机验证码重置密码 - 验证并重置
func (s *PasswordResetService) ResetPasswordByPhone(ctx context.Context, req *ResetPasswordByPhoneRequest) error {
if req.Phone == "" || req.Code == "" || req.NewPassword == "" {
return errors.New("参数不完整")
}
codeKey := fmt.Sprintf("pwd_reset_sms_code:%s", req.Phone)
storedCode, ok := s.cache.Get(ctx, codeKey)
if !ok {
return errors.New("验证码已失效,请重新获取")
}
code, ok := storedCode.(string)
if !ok || code != req.Code {
return errors.New("验证码不正确")
}
// 获取用户ID
cacheKey := fmt.Sprintf("pwd_reset_sms:%s", req.Phone)
val, ok := s.cache.Get(ctx, cacheKey)
if !ok {
return errors.New("验证码已失效,请重新获取")
}
userID, ok := int64Value(val)
if !ok {
return errors.New("验证码数据异常")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return errors.New("用户不存在")
}
if err := s.doResetPassword(ctx, user, req.NewPassword); err != nil {
return err
}
// 清理验证码
s.cache.Delete(ctx, codeKey)
s.cache.Delete(ctx, cacheKey)
return nil
}
func (s *PasswordResetService) doResetPassword(ctx context.Context, user *domain.User, newPassword string) error {
policy := security.PasswordPolicy{
MinLength: s.config.PasswordMinLen,
RequireSpecial: s.config.PasswordRequireSpecial,
RequireNumber: s.config.PasswordRequireNumber,
}.Normalize()
if err := policy.Validate(newPassword); err != nil {
return err
}
hashedPassword, err := auth.HashPassword(newPassword)
if err != nil {
return fmt.Errorf("密码加密失败: %w", err)
}
user.Password = hashedPassword
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("更新密码失败: %w", err)
}
return nil
}

View File

@@ -0,0 +1,223 @@
package service
import (
"context"
"errors"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// PermissionService 权限服务
type PermissionService struct {
permissionRepo *repository.PermissionRepository
}
// NewPermissionService 创建权限服务
func NewPermissionService(
permissionRepo *repository.PermissionRepository,
) *PermissionService {
return &PermissionService{
permissionRepo: permissionRepo,
}
}
// CreatePermissionRequest 创建权限请求
type CreatePermissionRequest struct {
Name string `json:"name" binding:"required"`
Code string `json:"code" binding:"required"`
Type int `json:"type" binding:"required"`
Description string `json:"description"`
ParentID *int64 `json:"parent_id"`
Path string `json:"path"`
Method string `json:"method"`
Sort int `json:"sort"`
Icon string `json:"icon"`
}
// UpdatePermissionRequest 更新权限请求
type UpdatePermissionRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ParentID *int64 `json:"parent_id"`
Path string `json:"path"`
Method string `json:"method"`
Sort int `json:"sort"`
Icon string `json:"icon"`
}
// CreatePermission 创建权限
func (s *PermissionService) CreatePermission(ctx context.Context, req *CreatePermissionRequest) (*domain.Permission, error) {
// 检查权限代码是否已存在
exists, err := s.permissionRepo.ExistsByCode(ctx, req.Code)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("权限代码已存在")
}
// 检查父权限是否存在
if req.ParentID != nil {
_, err := s.permissionRepo.GetByID(ctx, *req.ParentID)
if err != nil {
return nil, errors.New("父权限不存在")
}
}
// 创建权限
permission := &domain.Permission{
Name: req.Name,
Code: req.Code,
Type: domain.PermissionType(req.Type),
Description: req.Description,
ParentID: req.ParentID,
Level: 1,
Path: req.Path,
Method: req.Method,
Sort: req.Sort,
Icon: req.Icon,
Status: domain.PermissionStatusEnabled,
}
if req.ParentID != nil {
permission.Level = 2
}
if err := s.permissionRepo.Create(ctx, permission); err != nil {
return nil, err
}
return permission, nil
}
// UpdatePermission 更新权限
func (s *PermissionService) UpdatePermission(ctx context.Context, permissionID int64, req *UpdatePermissionRequest) (*domain.Permission, error) {
permission, err := s.permissionRepo.GetByID(ctx, permissionID)
if err != nil {
return nil, errors.New("权限不存在")
}
// 检查父权限是否存在
if req.ParentID != nil {
if *req.ParentID == permissionID {
return nil, errors.New("不能将权限设置为自己的父权限")
}
_, err := s.permissionRepo.GetByID(ctx, *req.ParentID)
if err != nil {
return nil, errors.New("父权限不存在")
}
permission.ParentID = req.ParentID
}
// 更新字段
if req.Name != "" {
permission.Name = req.Name
}
if req.Description != "" {
permission.Description = req.Description
}
if req.Path != "" {
permission.Path = req.Path
}
if req.Method != "" {
permission.Method = req.Method
}
if req.Sort > 0 {
permission.Sort = req.Sort
}
if req.Icon != "" {
permission.Icon = req.Icon
}
if err := s.permissionRepo.Update(ctx, permission); err != nil {
return nil, err
}
return permission, nil
}
// DeletePermission 删除权限
func (s *PermissionService) DeletePermission(ctx context.Context, permissionID int64) error {
_, err := s.permissionRepo.GetByID(ctx, permissionID)
if err != nil {
return errors.New("权限不存在")
}
// 检查是否有子权限
children, err := s.permissionRepo.ListByParentID(ctx, permissionID)
if err == nil && len(children) > 0 {
return errors.New("存在子权限,无法删除")
}
return s.permissionRepo.Delete(ctx, permissionID)
}
// GetPermission 获取权限信息
func (s *PermissionService) GetPermission(ctx context.Context, permissionID int64) (*domain.Permission, error) {
return s.permissionRepo.GetByID(ctx, permissionID)
}
// ListPermissions 获取权限列表
type ListPermissionRequest struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Type int `json:"type"`
Status int `json:"status"`
Keyword string `json:"keyword"`
}
func (s *PermissionService) ListPermissions(ctx context.Context, req *ListPermissionRequest) ([]*domain.Permission, int64, error) {
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 20
}
offset := (req.Page - 1) * req.PageSize
if req.Keyword != "" {
return s.permissionRepo.Search(ctx, req.Keyword, offset, req.PageSize)
}
// Type > 0 表示按类型过滤0 表示不过滤(查全部)
if req.Type > 0 {
return s.permissionRepo.ListByType(ctx, domain.PermissionType(req.Type), offset, req.PageSize)
}
// Status > 0 表示按状态过滤0 表示不过滤(查全部)
if req.Status > 0 {
return s.permissionRepo.ListByStatus(ctx, domain.PermissionStatus(req.Status), offset, req.PageSize)
}
return s.permissionRepo.List(ctx, offset, req.PageSize)
}
// UpdatePermissionStatus 更新权限状态
func (s *PermissionService) UpdatePermissionStatus(ctx context.Context, permissionID int64, status domain.PermissionStatus) error {
return s.permissionRepo.UpdateStatus(ctx, permissionID, status)
}
// GetPermissionTree 获取权限树
func (s *PermissionService) GetPermissionTree(ctx context.Context) ([]*domain.Permission, error) {
// 获取所有权限
permissions, _, err := s.permissionRepo.List(ctx, 0, 1000)
if err != nil {
return nil, err
}
// 构建树形结构
return s.buildPermissionTree(permissions, 0), nil
}
// buildPermissionTree 构建权限树
func (s *PermissionService) buildPermissionTree(permissions []*domain.Permission, parentID int64) []*domain.Permission {
var tree []*domain.Permission
for _, perm := range permissions {
if (parentID == 0 && perm.ParentID == nil) || (perm.ParentID != nil && *perm.ParentID == parentID) {
perm.Children = s.buildPermissionTree(permissions, perm.ID)
tree = append(tree, perm)
}
}
return tree
}

View File

@@ -0,0 +1,122 @@
# Codex Running in OpenCode
You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles.
## CRITICAL: Tool Replacements
<critical_rule priority="0">
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
- NEVER use: apply_patch, applyPatch
- ALWAYS use: edit tool for ALL file modifications
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
</critical_rule>
<critical_rule priority="0">
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
- NEVER use: update_plan, updatePlan, read_plan, readPlan
- ALWAYS use: todowrite for task/plan updates, todoread to read plans
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
</critical_rule>
## Available OpenCode Tools
**File Operations:**
- `write` - Create new files
- Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode.
- `edit` - Modify existing files (REPLACES apply_patch)
- Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing.
- `read` - Read file contents
**Search/Discovery:**
- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`.
- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set.
- `list` - List directories (requires absolute paths)
**Execution:**
- `bash` - Run shell commands
- No workdir parameter; do not include it in tool calls.
- Always include a short description for the command.
- Do not use cd; use absolute paths in commands.
- Quote paths containing spaces with double quotes.
- Chain multiple commands with ';' or '&&'; avoid newlines.
- Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features.
- Do not use `ls`/`cat` in bash; use `list`/`read` tools instead.
- For deletions (rm), verify by listing parent dir with `list`.
**Network:**
- `webfetch` - Fetch web content
- Use fully-formed URLs (http/https; http auto-upgrades to https).
- Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required.
- Read-only; short cache window.
**Task Management:**
- `todowrite` - Manage tasks/plans (REPLACES update_plan)
- `todoread` - Read current plan
## Substitution Rules
Base instruction says: You MUST use instead:
apply_patch → edit
update_plan → todowrite
read_plan → todoread
**Path Usage:** Use per-tool conventions to avoid conflicts:
- Tool calls: `read`, `edit`, `write`, `list` require absolute paths.
- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed.
- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls.
- Tool schema overrides general path preferences—do not convert required absolute paths to relative.
## Verification Checklist
Before file/plan modifications:
1. Am I using "edit" NOT "apply_patch"?
2. Am I using "todowrite" NOT "update_plan"?
3. Is this tool in the approved list above?
4. Am I following each tool's path requirements?
If ANY answer is NO → STOP and correct before proceeding.
## OpenCode Working Style
**Communication:**
- Send brief preambles (8-12 words) before tool calls, building on prior context
- Provide progress updates during longer tasks
**Execution:**
- Keep working autonomously until query is fully resolved before yielding
- Don't return to user with partial solutions
**Code Approach:**
- New projects: Be ambitious and creative
- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise
**Testing:**
- If tests exist: Start specific to your changes, then broader validation
## Advanced Tools
**Task Tool (Sub-Agents):**
- Use the Task tool (functions.task) to launch sub-agents
- Check the Task tool description for current agent types and their capabilities
- Useful for complex analysis, specialized workflows, or tasks requiring isolated context
- The agent list is dynamically generated - refer to tool schema for available agents
**Parallelization:**
- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently.
- Reserve sequential calls for ordered or data-dependent steps.
**MCP Tools:**
- Model Context Protocol servers provide additional capabilities
- MCP tools are prefixed: `mcp__<server-name>__<tool-name>`
- Check your available tools for MCP integrations
- Use when the tool's functionality matches your task needs
## What Remains from Codex
Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations.
## Approvals & Safety
- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise.
- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification.
- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval.
- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`).

View File

@@ -0,0 +1,63 @@
<user_instructions priority="0">
<environment_override priority="0">
YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references.
</environment_override>
<tool_replacements priority="0">
<critical_rule priority="0">
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
- NEVER use: apply_patch, applyPatch
- ALWAYS use: edit tool for ALL file modifications
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
</critical_rule>
<critical_rule priority="0">
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
- NEVER use: update_plan, updatePlan
- ALWAYS use: todowrite for ALL task/plan operations
- Use todoread to read current plan
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
</critical_rule>
</tool_replacements>
<available_tools priority="0">
File Operations:
• write - Create new files
• edit - Modify existing files (REPLACES apply_patch)
• patch - Apply diff patches
• read - Read file contents
Search/Discovery:
• grep - Search file contents
• glob - Find files by pattern
• list - List directories (use relative paths)
Execution:
• bash - Run shell commands
Network:
• webfetch - Fetch web content
Task Management:
• todowrite - Manage tasks/plans (REPLACES update_plan)
• todoread - Read current plan
</available_tools>
<substitution_rules priority="0">
Base instruction says: You MUST use instead:
apply_patch → edit
update_plan → todowrite
read_plan → todoread
absolute paths → relative paths
</substitution_rules>
<verification_checklist priority="0">
Before file/plan modifications:
1. Am I using "edit" NOT "apply_patch"?
2. Am I using "todowrite" NOT "update_plan"?
3. Is this tool in the approved list above?
4. Am I using relative paths?
If ANY answer is NO → STOP and correct before proceeding.
</verification_checklist>
</user_instructions>

View File

@@ -0,0 +1,216 @@
package service
import (
"context"
"sync/atomic"
"github.com/user-management-system/internal/pkg/ctxkey"
)
type requestMetadataContextKey struct{}
var requestMetadataKey = requestMetadataContextKey{}
type RequestMetadata struct {
IsMaxTokensOneHaikuRequest *bool
ThinkingEnabled *bool
PrefetchedStickyAccountID *int64
PrefetchedStickyGroupID *int64
SingleAccountRetry *bool
AccountSwitchCount *int
}
var (
requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64
requestMetadataFallbackThinkingEnabledTotal atomic.Int64
requestMetadataFallbackPrefetchedStickyAccount atomic.Int64
requestMetadataFallbackPrefetchedStickyGroup atomic.Int64
requestMetadataFallbackSingleAccountRetryTotal atomic.Int64
requestMetadataFallbackAccountSwitchCountTotal atomic.Int64
)
func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) {
return requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load(),
requestMetadataFallbackThinkingEnabledTotal.Load(),
requestMetadataFallbackPrefetchedStickyAccount.Load(),
requestMetadataFallbackPrefetchedStickyGroup.Load(),
requestMetadataFallbackSingleAccountRetryTotal.Load(),
requestMetadataFallbackAccountSwitchCountTotal.Load()
}
func metadataFromContext(ctx context.Context) *RequestMetadata {
if ctx == nil {
return nil
}
md, _ := ctx.Value(requestMetadataKey).(*RequestMetadata)
return md
}
func updateRequestMetadata(
ctx context.Context,
bridgeOldKeys bool,
update func(md *RequestMetadata),
legacyBridge func(ctx context.Context) context.Context,
) context.Context {
if ctx == nil {
return nil
}
current := metadataFromContext(ctx)
next := &RequestMetadata{}
if current != nil {
*next = *current
}
update(next)
ctx = context.WithValue(ctx, requestMetadataKey, next)
if bridgeOldKeys && legacyBridge != nil {
ctx = legacyBridge(ctx)
}
return ctx
}
func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.IsMaxTokensOneHaikuRequest = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value)
})
}
func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.ThinkingEnabled = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.ThinkingEnabled, value)
})
}
func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
account := accountID
group := groupID
md.PrefetchedStickyAccountID = &account
md.PrefetchedStickyGroupID = &group
}, func(base context.Context) context.Context {
bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID)
return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID)
})
}
func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.SingleAccountRetry = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.SingleAccountRetry, value)
})
}
func WithAccountSwitchCount(ctx context.Context, value int, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.AccountSwitchCount = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.AccountSwitchCount, value)
})
}
func IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) {
if md := metadataFromContext(ctx); md != nil && md.IsMaxTokensOneHaikuRequest != nil {
return *md.IsMaxTokensOneHaikuRequest, true
}
if ctx == nil {
return false, false
}
if value, ok := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok {
requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1)
return value, true
}
return false, false
}
func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) {
if md := metadataFromContext(ctx); md != nil && md.ThinkingEnabled != nil {
return *md.ThinkingEnabled, true
}
if ctx == nil {
return false, false
}
if value, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
requestMetadataFallbackThinkingEnabledTotal.Add(1)
return value, true
}
return false, false
}
func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyGroupID != nil {
return *md.PrefetchedStickyGroupID, true
}
if ctx == nil {
return 0, false
}
v := ctx.Value(ctxkey.PrefetchedStickyGroupID)
switch t := v.(type) {
case int64:
requestMetadataFallbackPrefetchedStickyGroup.Add(1)
return t, true
case int:
requestMetadataFallbackPrefetchedStickyGroup.Add(1)
return int64(t), true
}
return 0, false
}
func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) {
if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyAccountID != nil {
return *md.PrefetchedStickyAccountID, true
}
if ctx == nil {
return 0, false
}
v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
switch t := v.(type) {
case int64:
requestMetadataFallbackPrefetchedStickyAccount.Add(1)
return t, true
case int:
requestMetadataFallbackPrefetchedStickyAccount.Add(1)
return int64(t), true
}
return 0, false
}
func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) {
if md := metadataFromContext(ctx); md != nil && md.SingleAccountRetry != nil {
return *md.SingleAccountRetry, true
}
if ctx == nil {
return false, false
}
if value, ok := ctx.Value(ctxkey.SingleAccountRetry).(bool); ok {
requestMetadataFallbackSingleAccountRetryTotal.Add(1)
return value, true
}
return false, false
}
func AccountSwitchCountFromContext(ctx context.Context) (int, bool) {
if md := metadataFromContext(ctx); md != nil && md.AccountSwitchCount != nil {
return *md.AccountSwitchCount, true
}
if ctx == nil {
return 0, false
}
v := ctx.Value(ctxkey.AccountSwitchCount)
switch t := v.(type) {
case int:
requestMetadataFallbackAccountSwitchCountTotal.Add(1)
return t, true
case int64:
requestMetadataFallbackAccountSwitchCountTotal.Add(1)
return int(t), true
}
return 0, false
}

284
internal/service/role.go Normal file
View File

@@ -0,0 +1,284 @@
package service
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// RoleService 角色服务
type RoleService struct {
roleRepo *repository.RoleRepository
rolePermissionRepo *repository.RolePermissionRepository
}
// NewRoleService 创建角色服务
func NewRoleService(
roleRepo *repository.RoleRepository,
rolePermissionRepo *repository.RolePermissionRepository,
) *RoleService {
return &RoleService{
roleRepo: roleRepo,
rolePermissionRepo: rolePermissionRepo,
}
}
// CreateRoleRequest 创建角色请求
type CreateRoleRequest struct {
Name string `json:"name" binding:"required"`
Code string `json:"code" binding:"required"`
Description string `json:"description"`
ParentID *int64 `json:"parent_id"`
}
// UpdateRoleRequest 更新角色请求
type UpdateRoleRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ParentID *int64 `json:"parent_id"`
}
// CreateRole 创建角色
func (s *RoleService) CreateRole(ctx context.Context, req *CreateRoleRequest) (*domain.Role, error) {
// 检查角色代码是否已存在
exists, err := s.roleRepo.ExistsByCode(ctx, req.Code)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("角色代码已存在")
}
// 设置角色层级
level := 1
if req.ParentID != nil {
parentRole, err := s.roleRepo.GetByID(ctx, *req.ParentID)
if err != nil {
return nil, errors.New("父角色不存在")
}
level = parentRole.Level + 1
}
// 创建角色
role := &domain.Role{
Name: req.Name,
Code: req.Code,
Description: req.Description,
ParentID: req.ParentID,
Level: level,
Status: domain.RoleStatusEnabled,
}
if err := s.roleRepo.Create(ctx, role); err != nil {
return nil, err
}
return role, nil
}
const maxRoleDepth = 5 // 角色继承深度上限,可配置
// UpdateRole 更新角色
func (s *RoleService) UpdateRole(ctx context.Context, roleID int64, req *UpdateRoleRequest) (*domain.Role, error) {
role, err := s.roleRepo.GetByID(ctx, roleID)
if err != nil {
return nil, errors.New("角色不存在")
}
// 检查父角色是否存在
if req.ParentID != nil {
if *req.ParentID == roleID {
return nil, errors.New("不能将角色设置为自己的父角色")
}
// 检测循环继承:检查新父角色的祖先链是否包含当前角色
if err := s.checkCircularInheritance(ctx, roleID, *req.ParentID); err != nil {
return nil, err
}
// 检测继承深度:计算新父角色的深度 + 1
if err := s.checkInheritanceDepth(ctx, *req.ParentID, maxRoleDepth-1); err != nil {
return nil, err
}
role.ParentID = req.ParentID
}
// 更新字段
if req.Name != "" {
role.Name = req.Name
}
if req.Description != "" {
role.Description = req.Description
}
if err := s.roleRepo.Update(ctx, role); err != nil {
return nil, err
}
return role, nil
}
// checkCircularInheritance 检测循环继承
// 如果将 childID 的父角色设为 parentID检查 parentID 的祖先链是否包含 childID
func (s *RoleService) checkCircularInheritance(ctx context.Context, childID, parentID int64) error {
ancestorIDs, err := s.roleRepo.GetAncestorIDs(ctx, parentID)
if err != nil {
return err
}
for _, ancestorID := range ancestorIDs {
if ancestorID == childID {
return errors.New("检测到循环继承,操作被拒绝")
}
}
return nil
}
// checkInheritanceDepth 检测继承深度是否超限
func (s *RoleService) checkInheritanceDepth(ctx context.Context, roleID int64, maxDepth int) error {
if maxDepth <= 0 {
return errors.New("继承深度超限最大支持5层")
}
depth := 0
currentID := roleID
for {
role, err := s.roleRepo.GetByID(ctx, currentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
break
}
return err
}
if role.ParentID == nil {
break
}
depth++
if depth > maxDepth {
return errors.New("继承深度超限最大支持5层")
}
currentID = *role.ParentID
}
return nil
}
// DeleteRole 删除角色
func (s *RoleService) DeleteRole(ctx context.Context, roleID int64) error {
role, err := s.roleRepo.GetByID(ctx, roleID)
if err != nil {
return errors.New("角色不存在")
}
// 系统角色不能删除
if role.IsSystem {
return errors.New("系统角色不能删除")
}
// 检查是否有子角色
children, err := s.roleRepo.ListByParentID(ctx, roleID)
if err == nil && len(children) > 0 {
return errors.New("存在子角色,无法删除")
}
// 删除角色权限关联
if err := s.rolePermissionRepo.DeleteByRoleID(ctx, roleID); err != nil {
return err
}
// 删除角色
return s.roleRepo.Delete(ctx, roleID)
}
// GetRole 获取角色信息
func (s *RoleService) GetRole(ctx context.Context, roleID int64) (*domain.Role, error) {
return s.roleRepo.GetByID(ctx, roleID)
}
// ListRoles 获取角色列表
type ListRoleRequest struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Status int `json:"status"`
Keyword string `json:"keyword"`
}
func (s *RoleService) ListRoles(ctx context.Context, req *ListRoleRequest) ([]*domain.Role, int64, error) {
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 20
}
offset := (req.Page - 1) * req.PageSize
if req.Keyword != "" {
return s.roleRepo.Search(ctx, req.Keyword, offset, req.PageSize)
}
// Status > 0 表示按状态过滤0 表示不过滤(查全部)
if req.Status > 0 {
return s.roleRepo.ListByStatus(ctx, domain.RoleStatus(req.Status), offset, req.PageSize)
}
return s.roleRepo.List(ctx, offset, req.PageSize)
}
// UpdateRoleStatus 更新角色状态
func (s *RoleService) UpdateRoleStatus(ctx context.Context, roleID int64, status domain.RoleStatus) error {
role, err := s.roleRepo.GetByID(ctx, roleID)
if err != nil {
return errors.New("角色不存在")
}
// 系统角色不能禁用
if role.IsSystem && status == domain.RoleStatusDisabled {
return errors.New("系统角色不能禁用")
}
return s.roleRepo.UpdateStatus(ctx, roleID, status)
}
// GetRolePermissions 获取角色权限(包含继承的父角色权限)
func (s *RoleService) GetRolePermissions(ctx context.Context, roleID int64) ([]*domain.Permission, error) {
// 收集所有角色ID包括当前角色和所有父角色
allRoleIDs := []int64{roleID}
ancestorIDs, err := s.roleRepo.GetAncestorIDs(ctx, roleID)
if err != nil {
return nil, err
}
allRoleIDs = append(allRoleIDs, ancestorIDs...)
// 批量获取所有角色的权限ID
permissionIDs, err := s.rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, allRoleIDs)
if err != nil {
return nil, err
}
// 批量获取权限详情
permissions, err := s.rolePermissionRepo.GetPermissionsByIDs(ctx, permissionIDs)
if err != nil {
return nil, err
}
return permissions, nil
}
// AssignPermissions 分配权限
func (s *RoleService) AssignPermissions(ctx context.Context, roleID int64, permissionIDs []int64) error {
// 删除原有权限
if err := s.rolePermissionRepo.DeleteByRoleID(ctx, roleID); err != nil {
return err
}
// 创建新权限关联
var rolePermissions []*domain.RolePermission
for _, permissionID := range permissionIDs {
rolePermissions = append(rolePermissions, &domain.RolePermission{
RoleID: roleID,
PermissionID: permissionID,
})
}
return s.rolePermissionRepo.BatchCreate(ctx, rolePermissions)
}

462
internal/service/sms.go Normal file
View File

@@ -0,0 +1,462 @@
package service
import (
"context"
cryptorand "crypto/rand"
"encoding/json"
"fmt"
"log"
"regexp"
"strings"
"time"
aliyunopenapiutil "github.com/alibabacloud-go/darabonba-openapi/v2/utils"
aliyunsms "github.com/alibabacloud-go/dysmsapi-20170525/v5/client"
"github.com/alibabacloud-go/tea/dara"
tccommon "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
tcprofile "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
tcsms "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/sms/v20210111"
)
var (
validPhonePattern = regexp.MustCompile(`^((\+86|86)?1[3-9]\d{9}|\+[1-9]\d{6,14})$`)
mainlandPhonePattern = regexp.MustCompile(`^1[3-9]\d{9}$`)
mainlandPhone86Pattern = regexp.MustCompile(`^86(1[3-9]\d{9})$`)
mainlandPhone0086Pattern = regexp.MustCompile(`^0086(1[3-9]\d{9})$`)
verificationCodeCharset10 = 1000000
)
// SMSProvider sends one verification code to one phone number.
type SMSProvider interface {
SendVerificationCode(ctx context.Context, phone, code string) error
}
// MockSMSProvider is a test helper and is not wired into the server runtime.
type MockSMSProvider struct{}
func (m *MockSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
_ = ctx
// 安全:不在日志中记录完整验证码,仅显示部分信息用于调试
maskedCode := "****"
if len(code) >= 4 {
maskedCode = strings.Repeat("*", len(code)-4) + code[len(code)-4:]
}
log.Printf("[sms-mock] phone=%s code=%s ttl=5m", phone, maskedCode)
return nil
}
type aliyunSMSClient interface {
SendSms(request *aliyunsms.SendSmsRequest) (*aliyunsms.SendSmsResponse, error)
}
type tencentSMSClient interface {
SendSmsWithContext(ctx context.Context, request *tcsms.SendSmsRequest) (*tcsms.SendSmsResponse, error)
}
type AliyunSMSConfig struct {
AccessKeyID string
AccessKeySecret string
SignName string
TemplateCode string
Endpoint string
RegionID string
CodeParamName string
}
type AliyunSMSProvider struct {
cfg AliyunSMSConfig
client aliyunSMSClient
}
func NewAliyunSMSProvider(cfg AliyunSMSConfig) (SMSProvider, error) {
cfg = normalizeAliyunSMSConfig(cfg)
if cfg.AccessKeyID == "" || cfg.AccessKeySecret == "" || cfg.SignName == "" || cfg.TemplateCode == "" {
return nil, fmt.Errorf("aliyun SMS config is incomplete")
}
client, err := newAliyunSMSClient(cfg)
if err != nil {
return nil, fmt.Errorf("create aliyun SMS client failed: %w", err)
}
return &AliyunSMSProvider{
cfg: cfg,
client: client,
}, nil
}
func newAliyunSMSClient(cfg AliyunSMSConfig) (aliyunSMSClient, error) {
client, err := aliyunsms.NewClient(&aliyunopenapiutil.Config{
AccessKeyId: dara.String(cfg.AccessKeyID),
AccessKeySecret: dara.String(cfg.AccessKeySecret),
Endpoint: stringPointerOrNil(cfg.Endpoint),
RegionId: dara.String(cfg.RegionID),
})
if err != nil {
return nil, err
}
return client, nil
}
func (a *AliyunSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
_ = ctx
templateParam, err := json.Marshal(map[string]string{
a.cfg.CodeParamName: code,
})
if err != nil {
return fmt.Errorf("marshal aliyun SMS template param failed: %w", err)
}
resp, err := a.client.SendSms(
new(aliyunsms.SendSmsRequest).
SetPhoneNumbers(normalizePhoneForSMS(phone)).
SetSignName(a.cfg.SignName).
SetTemplateCode(a.cfg.TemplateCode).
SetTemplateParam(string(templateParam)),
)
if err != nil {
return fmt.Errorf("aliyun SMS request failed: %w", err)
}
if resp == nil || resp.Body == nil {
return fmt.Errorf("aliyun SMS returned empty response")
}
body := resp.Body
if !strings.EqualFold(dara.StringValue(body.Code), "OK") {
return fmt.Errorf(
"aliyun SMS rejected: code=%s message=%s request_id=%s",
valueOrDefault(dara.StringValue(body.Code), "unknown"),
valueOrDefault(dara.StringValue(body.Message), "unknown"),
valueOrDefault(dara.StringValue(body.RequestId), "unknown"),
)
}
return nil
}
type TencentSMSConfig struct {
SecretID string
SecretKey string
AppID string
SignName string
TemplateID string
Region string
Endpoint string
}
type TencentSMSProvider struct {
cfg TencentSMSConfig
client tencentSMSClient
}
func NewTencentSMSProvider(cfg TencentSMSConfig) (SMSProvider, error) {
cfg = normalizeTencentSMSConfig(cfg)
if cfg.SecretID == "" || cfg.SecretKey == "" || cfg.AppID == "" || cfg.SignName == "" || cfg.TemplateID == "" {
return nil, fmt.Errorf("tencent SMS config is incomplete")
}
client, err := newTencentSMSClient(cfg)
if err != nil {
return nil, fmt.Errorf("create tencent SMS client failed: %w", err)
}
return &TencentSMSProvider{
cfg: cfg,
client: client,
}, nil
}
func newTencentSMSClient(cfg TencentSMSConfig) (tencentSMSClient, error) {
clientProfile := tcprofile.NewClientProfile()
clientProfile.HttpProfile.ReqTimeout = 30
if cfg.Endpoint != "" {
clientProfile.HttpProfile.Endpoint = cfg.Endpoint
}
client, err := tcsms.NewClient(
tccommon.NewCredential(cfg.SecretID, cfg.SecretKey),
cfg.Region,
clientProfile,
)
if err != nil {
return nil, err
}
return client, nil
}
func (t *TencentSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
req := tcsms.NewSendSmsRequest()
req.PhoneNumberSet = []*string{tccommon.StringPtr(normalizePhoneForSMS(phone))}
req.SmsSdkAppId = tccommon.StringPtr(t.cfg.AppID)
req.SignName = tccommon.StringPtr(t.cfg.SignName)
req.TemplateId = tccommon.StringPtr(t.cfg.TemplateID)
req.TemplateParamSet = []*string{tccommon.StringPtr(code)}
resp, err := t.client.SendSmsWithContext(ctx, req)
if err != nil {
return fmt.Errorf("tencent SMS request failed: %w", err)
}
if resp == nil || resp.Response == nil {
return fmt.Errorf("tencent SMS returned empty response")
}
if len(resp.Response.SendStatusSet) == 0 {
return fmt.Errorf(
"tencent SMS returned empty status list: request_id=%s",
valueOrDefault(pointerString(resp.Response.RequestId), "unknown"),
)
}
status := resp.Response.SendStatusSet[0]
if !strings.EqualFold(pointerString(status.Code), "Ok") {
return fmt.Errorf(
"tencent SMS rejected: code=%s message=%s request_id=%s",
valueOrDefault(pointerString(status.Code), "unknown"),
valueOrDefault(pointerString(status.Message), "unknown"),
valueOrDefault(pointerString(resp.Response.RequestId), "unknown"),
)
}
return nil
}
type SMSCodeConfig struct {
CodeTTL time.Duration
ResendCooldown time.Duration
MaxDailyLimit int
}
func DefaultSMSCodeConfig() SMSCodeConfig {
return SMSCodeConfig{
CodeTTL: 5 * time.Minute,
ResendCooldown: time.Minute,
MaxDailyLimit: 10,
}
}
type SMSCodeService struct {
provider SMSProvider
cache cacheInterface
cfg SMSCodeConfig
}
type cacheInterface interface {
Get(ctx context.Context, key string) (interface{}, bool)
Set(ctx context.Context, key string, value interface{}, l1TTL, l2TTL time.Duration) error
Delete(ctx context.Context, key string) error
}
func NewSMSCodeService(provider SMSProvider, cacheManager cacheInterface, cfg SMSCodeConfig) *SMSCodeService {
if cfg.CodeTTL <= 0 {
cfg.CodeTTL = 5 * time.Minute
}
if cfg.ResendCooldown <= 0 {
cfg.ResendCooldown = time.Minute
}
if cfg.MaxDailyLimit <= 0 {
cfg.MaxDailyLimit = 10
}
return &SMSCodeService{
provider: provider,
cache: cacheManager,
cfg: cfg,
}
}
type SendCodeRequest struct {
Phone string `json:"phone" binding:"required"`
Purpose string `json:"purpose"`
Scene string `json:"scene"`
}
type SendCodeResponse struct {
ExpiresIn int `json:"expires_in"`
Cooldown int `json:"cooldown"`
}
func (s *SMSCodeService) SendCode(ctx context.Context, req *SendCodeRequest) (*SendCodeResponse, error) {
if s == nil || s.provider == nil || s.cache == nil {
return nil, fmt.Errorf("sms code service is not configured")
}
if req == nil {
return nil, newValidationError("\u8bf7\u6c42\u4e0d\u80fd\u4e3a\u7a7a")
}
phone := strings.TrimSpace(req.Phone)
if !isValidPhone(phone) {
return nil, newValidationError("\u624b\u673a\u53f7\u7801\u683c\u5f0f\u4e0d\u6b63\u786e")
}
purpose := strings.TrimSpace(req.Purpose)
if purpose == "" {
purpose = strings.TrimSpace(req.Scene)
}
cooldownKey := fmt.Sprintf("sms_cooldown:%s", phone)
if _, ok := s.cache.Get(ctx, cooldownKey); ok {
return nil, newRateLimitError(fmt.Sprintf("\u64cd\u4f5c\u8fc7\u4e8e\u9891\u7e41\uff0c\u8bf7 %d \u79d2\u540e\u518d\u8bd5", int(s.cfg.ResendCooldown.Seconds())))
}
dailyKey := fmt.Sprintf("sms_daily:%s:%s", phone, time.Now().Format("2006-01-02"))
var dailyCount int
if val, ok := s.cache.Get(ctx, dailyKey); ok {
if n, ok := intValue(val); ok {
dailyCount = n
}
}
if dailyCount >= s.cfg.MaxDailyLimit {
return nil, newRateLimitError(fmt.Sprintf("\u4eca\u65e5\u53d1\u9001\u6b21\u6570\u5df2\u8fbe\u4e0a\u9650\uff08%d\u6b21\uff09\uff0c\u8bf7\u660e\u65e5\u518d\u8bd5", s.cfg.MaxDailyLimit))
}
code, err := generateSMSCode()
if err != nil {
return nil, fmt.Errorf("generate sms code failed: %w", err)
}
codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
if err := s.cache.Set(ctx, codeKey, code, s.cfg.CodeTTL, s.cfg.CodeTTL); err != nil {
return nil, fmt.Errorf("store sms code failed: %w", err)
}
if err := s.cache.Set(ctx, cooldownKey, true, s.cfg.ResendCooldown, s.cfg.ResendCooldown); err != nil {
_ = s.cache.Delete(ctx, codeKey)
return nil, fmt.Errorf("store sms cooldown failed: %w", err)
}
if err := s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour, 24*time.Hour); err != nil {
_ = s.cache.Delete(ctx, codeKey)
_ = s.cache.Delete(ctx, cooldownKey)
return nil, fmt.Errorf("store sms daily counter failed: %w", err)
}
if err := s.provider.SendVerificationCode(ctx, phone, code); err != nil {
_ = s.cache.Delete(ctx, codeKey)
_ = s.cache.Delete(ctx, cooldownKey)
return nil, fmt.Errorf("\u77ed\u4fe1\u53d1\u9001\u5931\u8d25: %w", err)
}
return &SendCodeResponse{
ExpiresIn: int(s.cfg.CodeTTL.Seconds()),
Cooldown: int(s.cfg.ResendCooldown.Seconds()),
}, nil
}
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, purpose, code string) error {
if s == nil || s.cache == nil {
return fmt.Errorf("sms code service is not configured")
}
if strings.TrimSpace(code) == "" {
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u80fd\u4e3a\u7a7a")
}
phone = strings.TrimSpace(phone)
purpose = strings.TrimSpace(purpose)
codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
val, ok := s.cache.Get(ctx, codeKey)
if !ok {
return fmt.Errorf("\u9a8c\u8bc1\u7801\u5df2\u8fc7\u671f\u6216\u4e0d\u5b58\u5728")
}
stored, ok := val.(string)
if !ok || stored != code {
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u6b63\u786e")
}
if err := s.cache.Delete(ctx, codeKey); err != nil {
return fmt.Errorf("consume sms code failed: %w", err)
}
return nil
}
func isValidPhone(phone string) bool {
return validPhonePattern.MatchString(strings.TrimSpace(phone))
}
func generateSMSCode() (string, error) {
b := make([]byte, 4)
if _, err := cryptorand.Read(b); err != nil {
return "", err
}
n := int(b[0])<<24 | int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if n < 0 {
n = -n
}
n = n % verificationCodeCharset10
if n < 100000 {
n += 100000
}
return fmt.Sprintf("%06d", n), nil
}
func normalizeAliyunSMSConfig(cfg AliyunSMSConfig) AliyunSMSConfig {
cfg.AccessKeyID = strings.TrimSpace(cfg.AccessKeyID)
cfg.AccessKeySecret = strings.TrimSpace(cfg.AccessKeySecret)
cfg.SignName = strings.TrimSpace(cfg.SignName)
cfg.TemplateCode = strings.TrimSpace(cfg.TemplateCode)
cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
cfg.RegionID = strings.TrimSpace(cfg.RegionID)
cfg.CodeParamName = strings.TrimSpace(cfg.CodeParamName)
if cfg.RegionID == "" {
cfg.RegionID = "cn-hangzhou"
}
if cfg.CodeParamName == "" {
cfg.CodeParamName = "code"
}
return cfg
}
func normalizeTencentSMSConfig(cfg TencentSMSConfig) TencentSMSConfig {
cfg.SecretID = strings.TrimSpace(cfg.SecretID)
cfg.SecretKey = strings.TrimSpace(cfg.SecretKey)
cfg.AppID = strings.TrimSpace(cfg.AppID)
cfg.SignName = strings.TrimSpace(cfg.SignName)
cfg.TemplateID = strings.TrimSpace(cfg.TemplateID)
cfg.Region = strings.TrimSpace(cfg.Region)
cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
if cfg.Region == "" {
cfg.Region = "ap-guangzhou"
}
return cfg
}
func normalizePhoneForSMS(phone string) string {
phone = strings.TrimSpace(phone)
switch {
case mainlandPhonePattern.MatchString(phone):
return "+86" + phone
case mainlandPhone86Pattern.MatchString(phone):
return "+" + phone
case mainlandPhone0086Pattern.MatchString(phone):
return "+86" + mainlandPhone0086Pattern.ReplaceAllString(phone, "$1")
default:
return phone
}
}
func stringPointerOrNil(value string) *string {
if value == "" {
return nil
}
return dara.String(value)
}
func pointerString(value *string) string {
if value == nil {
return ""
}
return *value
}
func valueOrDefault(value, fallback string) string {
if strings.TrimSpace(value) == "" {
return fallback
}
return value
}

124
internal/service/stats.go Normal file
View File

@@ -0,0 +1,124 @@
package service
import (
"context"
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// StatsService 统计服务
type StatsService struct {
userRepo *repository.UserRepository
loginLogRepo *repository.LoginLogRepository
}
// NewStatsService 创建统计服务
func NewStatsService(
userRepo *repository.UserRepository,
loginLogRepo *repository.LoginLogRepository,
) *StatsService {
return &StatsService{
userRepo: userRepo,
loginLogRepo: loginLogRepo,
}
}
// UserStats 用户统计数据
type UserStats struct {
TotalUsers int64 `json:"total_users"`
ActiveUsers int64 `json:"active_users"`
InactiveUsers int64 `json:"inactive_users"`
LockedUsers int64 `json:"locked_users"`
DisabledUsers int64 `json:"disabled_users"`
NewUsersToday int64 `json:"new_users_today"`
NewUsersWeek int64 `json:"new_users_week"`
NewUsersMonth int64 `json:"new_users_month"`
}
// LoginStats 登录统计数据
type LoginStats struct {
LoginsTodaySuccess int64 `json:"logins_today_success"`
LoginsTodayFailed int64 `json:"logins_today_failed"`
LoginsWeek int64 `json:"logins_week"`
}
// DashboardStats 仪表盘综合统计
type DashboardStats struct {
Users UserStats `json:"users"`
Logins LoginStats `json:"logins"`
}
// GetUserStats 获取用户统计
func (s *StatsService) GetUserStats(ctx context.Context) (*UserStats, error) {
stats := &UserStats{}
// 统计总用户数
_, total, err := s.userRepo.List(ctx, 0, 1)
if err != nil {
return nil, err
}
stats.TotalUsers = total
// 按状态统计
statusCounts := map[domain.UserStatus]*int64{
domain.UserStatusActive: &stats.ActiveUsers,
domain.UserStatusInactive: &stats.InactiveUsers,
domain.UserStatusLocked: &stats.LockedUsers,
domain.UserStatusDisabled: &stats.DisabledUsers,
}
for status, countPtr := range statusCounts {
_, cnt, err := s.userRepo.ListByStatus(ctx, status, 0, 1)
if err == nil {
*countPtr = cnt
}
}
// 今日新增
stats.NewUsersToday = s.countNewUsers(ctx, daysAgo(0))
// 本周新增
stats.NewUsersWeek = s.countNewUsers(ctx, daysAgo(7))
// 本月新增
stats.NewUsersMonth = s.countNewUsers(ctx, daysAgo(30))
return stats, nil
}
// countNewUsers 统计指定时间之后的新增用户数
func (s *StatsService) countNewUsers(ctx context.Context, since time.Time) int64 {
_, count, err := s.userRepo.ListCreatedAfter(ctx, since, 0, 0)
if err != nil {
return 0
}
return count
}
// GetDashboardStats 获取仪表盘综合统计
func (s *StatsService) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
userStats, err := s.GetUserStats(ctx)
if err != nil {
return nil, err
}
loginStats := &LoginStats{}
// 今日登录成功/失败
today := daysAgo(0)
if s.loginLogRepo != nil {
loginStats.LoginsTodaySuccess = s.loginLogRepo.CountByResultSince(ctx, true, today)
loginStats.LoginsTodayFailed = s.loginLogRepo.CountByResultSince(ctx, false, today)
loginStats.LoginsWeek = s.loginLogRepo.CountByResultSince(ctx, true, daysAgo(7))
}
return &DashboardStats{
Users: *userStats,
Logins: *loginStats,
}, nil
}
// daysAgo 返回N天前的时间当天0点
func daysAgo(n int) time.Time {
now := time.Now()
start := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
return start.AddDate(0, 0, -n)
}

206
internal/service/theme.go Normal file
View File

@@ -0,0 +1,206 @@
package service
import (
"context"
"errors"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// ThemeService 主题服务
type ThemeService struct {
themeRepo *repository.ThemeConfigRepository
}
// NewThemeService 创建主题服务
func NewThemeService(themeRepo *repository.ThemeConfigRepository) *ThemeService {
return &ThemeService{themeRepo: themeRepo}
}
// CreateThemeRequest 创建主题请求
type CreateThemeRequest struct {
Name string `json:"name" binding:"required"`
LogoURL string `json:"logo_url"`
FaviconURL string `json:"favicon_url"`
PrimaryColor string `json:"primary_color"`
SecondaryColor string `json:"secondary_color"`
BackgroundColor string `json:"background_color"`
TextColor string `json:"text_color"`
CustomCSS string `json:"custom_css"`
CustomJS string `json:"custom_js"`
IsDefault bool `json:"is_default"`
}
// UpdateThemeRequest 更新主题请求
type UpdateThemeRequest struct {
LogoURL string `json:"logo_url"`
FaviconURL string `json:"favicon_url"`
PrimaryColor string `json:"primary_color"`
SecondaryColor string `json:"secondary_color"`
BackgroundColor string `json:"background_color"`
TextColor string `json:"text_color"`
CustomCSS string `json:"custom_css"`
CustomJS string `json:"custom_js"`
Enabled *bool `json:"enabled"`
IsDefault *bool `json:"is_default"`
}
// CreateTheme 创建主题
func (s *ThemeService) CreateTheme(ctx context.Context, req *CreateThemeRequest) (*domain.ThemeConfig, error) {
// 检查主题名称是否已存在
existing, err := s.themeRepo.GetByName(ctx, req.Name)
if err == nil && existing != nil {
return nil, errors.New("主题名称已存在")
}
theme := &domain.ThemeConfig{
Name: req.Name,
LogoURL: req.LogoURL,
FaviconURL: req.FaviconURL,
PrimaryColor: req.PrimaryColor,
SecondaryColor: req.SecondaryColor,
BackgroundColor: req.BackgroundColor,
TextColor: req.TextColor,
CustomCSS: req.CustomCSS,
CustomJS: req.CustomJS,
IsDefault: req.IsDefault,
Enabled: true,
}
// 如果设置为默认,先清除其他默认
if req.IsDefault {
if err := s.clearDefaultThemes(ctx); err != nil {
return nil, err
}
}
if err := s.themeRepo.Create(ctx, theme); err != nil {
return nil, err
}
return theme, nil
}
// UpdateTheme 更新主题
func (s *ThemeService) UpdateTheme(ctx context.Context, id int64, req *UpdateThemeRequest) (*domain.ThemeConfig, error) {
theme, err := s.themeRepo.GetByID(ctx, id)
if err != nil {
return nil, errors.New("主题不存在")
}
if req.LogoURL != "" {
theme.LogoURL = req.LogoURL
}
if req.FaviconURL != "" {
theme.FaviconURL = req.FaviconURL
}
if req.PrimaryColor != "" {
theme.PrimaryColor = req.PrimaryColor
}
if req.SecondaryColor != "" {
theme.SecondaryColor = req.SecondaryColor
}
if req.BackgroundColor != "" {
theme.BackgroundColor = req.BackgroundColor
}
if req.TextColor != "" {
theme.TextColor = req.TextColor
}
if req.CustomCSS != "" {
theme.CustomCSS = req.CustomCSS
}
if req.CustomJS != "" {
theme.CustomJS = req.CustomJS
}
if req.Enabled != nil {
theme.Enabled = *req.Enabled
}
if req.IsDefault != nil && *req.IsDefault {
if err := s.clearDefaultThemes(ctx); err != nil {
return nil, err
}
theme.IsDefault = true
}
if err := s.themeRepo.Update(ctx, theme); err != nil {
return nil, err
}
return theme, nil
}
// DeleteTheme 删除主题
func (s *ThemeService) DeleteTheme(ctx context.Context, id int64) error {
theme, err := s.themeRepo.GetByID(ctx, id)
if err != nil {
return errors.New("主题不存在")
}
if theme.IsDefault {
return errors.New("不能删除默认主题")
}
return s.themeRepo.Delete(ctx, id)
}
// GetTheme 获取主题
func (s *ThemeService) GetTheme(ctx context.Context, id int64) (*domain.ThemeConfig, error) {
return s.themeRepo.GetByID(ctx, id)
}
// ListThemes 获取所有已启用主题
func (s *ThemeService) ListThemes(ctx context.Context) ([]*domain.ThemeConfig, error) {
return s.themeRepo.List(ctx)
}
// ListAllThemes 获取所有主题
func (s *ThemeService) ListAllThemes(ctx context.Context) ([]*domain.ThemeConfig, error) {
return s.themeRepo.ListAll(ctx)
}
// GetDefaultTheme 获取默认主题
func (s *ThemeService) GetDefaultTheme(ctx context.Context) (*domain.ThemeConfig, error) {
return s.themeRepo.GetDefault(ctx)
}
// SetDefaultTheme 设置默认主题
func (s *ThemeService) SetDefaultTheme(ctx context.Context, id int64) error {
theme, err := s.themeRepo.GetByID(ctx, id)
if err != nil {
return errors.New("主题不存在")
}
if !theme.Enabled {
return errors.New("不能将禁用的主题设为默认")
}
return s.themeRepo.SetDefault(ctx, id)
}
// GetActiveTheme 获取当前生效的主题
func (s *ThemeService) GetActiveTheme(ctx context.Context) (*domain.ThemeConfig, error) {
theme, err := s.themeRepo.GetDefault(ctx)
if err != nil {
// 返回默认配置
return domain.DefaultThemeConfig(), nil
}
return theme, nil
}
// clearDefaultThemes 清除所有默认主题标记
func (s *ThemeService) clearDefaultThemes(ctx context.Context) error {
themes, err := s.themeRepo.ListAll(ctx)
if err != nil {
return err
}
for _, t := range themes {
if t.IsDefault {
t.IsDefault = false
if err := s.themeRepo.Update(ctx, t); err != nil {
return err
}
}
}
return nil
}

148
internal/service/totp.go Normal file
View File

@@ -0,0 +1,148 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/user-management-system/internal/auth"
)
// TOTPService manages 2FA setup, enable/disable, and verification.
type TOTPService struct {
userRepo userRepositoryInterface
totpManager *auth.TOTPManager
}
func NewTOTPService(userRepo userRepositoryInterface) *TOTPService {
return &TOTPService{
userRepo: userRepo,
totpManager: auth.NewTOTPManager(),
}
}
type SetupTOTPResponse struct {
Secret string `json:"secret"`
QRCodeBase64 string `json:"qr_code_base64"`
RecoveryCodes []string `json:"recovery_codes"`
}
func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPResponse, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
}
if user.TOTPEnabled {
return nil, errors.New("2FA \u5df2\u7ecf\u542f\u7528\uff0c\u5982\u9700\u91cd\u7f6e\u8bf7\u5148\u7981\u7528")
}
setup, err := s.totpManager.GenerateSecret(user.Username)
if err != nil {
return nil, fmt.Errorf("\u751f\u6210 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
}
// Persist the generated secret and recovery codes before activation.
user.TOTPSecret = setup.Secret
// Hash recovery codes before storing (SEC-03 fix)
hashedCodes := make([]string, len(setup.RecoveryCodes))
for i, code := range setup.RecoveryCodes {
hashedCodes[i], _ = auth.HashRecoveryCode(code)
}
codesJSON, _ := json.Marshal(hashedCodes)
user.TOTPRecoveryCodes = string(codesJSON)
if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
return nil, fmt.Errorf("\u4fdd\u5b58 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
}
return &SetupTOTPResponse{
Secret: setup.Secret,
QRCodeBase64: setup.QRCodeBase64,
RecoveryCodes: setup.RecoveryCodes,
}, nil
}
func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
}
if user.TOTPSecret == "" {
return errors.New("\u8bf7\u5148\u521d\u59cb\u5316 2FA\uff0c\u83b7\u53d6\u4e8c\u7ef4\u7801\u540e\u518d\u6fc0\u6d3b")
}
if user.TOTPEnabled {
return errors.New("2FA \u5df2\u542f\u7528")
}
if !s.totpManager.ValidateCode(user.TOTPSecret, code) {
return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
}
user.TOTPEnabled = true
return s.userRepo.UpdateTOTP(ctx, user)
}
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
}
if !user.TOTPEnabled {
return errors.New("2FA \u672a\u542f\u7528")
}
valid := s.totpManager.ValidateCode(user.TOTPSecret, code)
if !valid {
var hashedCodes []string
if user.TOTPRecoveryCodes != "" {
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes)
}
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
if !matched {
return errors.New("\u9a8c\u8bc1\u7801\u6216\u6062\u590d\u7801\u9519\u8bef")
}
}
user.TOTPEnabled = false
user.TOTPSecret = ""
user.TOTPRecoveryCodes = ""
return s.userRepo.UpdateTOTP(ctx, user)
}
func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
}
if !user.TOTPEnabled {
return nil
}
if s.totpManager.ValidateCode(user.TOTPSecret, code) {
return nil
}
var storedCodes []string
if user.TOTPRecoveryCodes != "" {
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes)
}
idx, matched := auth.ValidateRecoveryCode(code, storedCodes)
if !matched {
return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
}
storedCodes = append(storedCodes[:idx], storedCodes[idx+1:]...)
codesJSON, _ := json.Marshal(storedCodes)
user.TOTPRecoveryCodes = string(codesJSON)
_ = s.userRepo.UpdateTOTP(ctx, user)
return nil
}
func (s *TOTPService) GetTOTPStatus(ctx context.Context, userID int64) (bool, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return false, fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
}
return user.TOTPEnabled, nil
}

View File

@@ -0,0 +1,133 @@
package service
import (
"context"
"errors"
"strings"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// UserService 用户服务
type UserService struct {
userRepo *repository.UserRepository
userRoleRepo *repository.UserRoleRepository
roleRepo *repository.RoleRepository
passwordHistoryRepo *repository.PasswordHistoryRepository
}
const passwordHistoryLimit = 5 // 保留最近5条密码历史
// NewUserService 创建用户服务实例
func NewUserService(
userRepo *repository.UserRepository,
userRoleRepo *repository.UserRoleRepository,
roleRepo *repository.RoleRepository,
passwordHistoryRepo *repository.PasswordHistoryRepository,
) *UserService {
return &UserService{
userRepo: userRepo,
userRoleRepo: userRoleRepo,
roleRepo: roleRepo,
passwordHistoryRepo: passwordHistoryRepo,
}
}
// ChangePassword 修改用户密码(含历史记录检查)
func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
if s.userRepo == nil {
return errors.New("user repository is not configured")
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return errors.New("用户不存在")
}
// 验证旧密码
if strings.TrimSpace(oldPassword) == "" {
return errors.New("请输入当前密码")
}
if !auth.VerifyPassword(user.Password, oldPassword) {
return errors.New("当前密码不正确")
}
// 检查新密码强度
if strings.TrimSpace(newPassword) == "" {
return errors.New("新密码不能为空")
}
if err := validatePasswordStrength(newPassword, 8, false); err != nil {
return err
}
// 检查密码历史
if s.passwordHistoryRepo != nil {
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, userID, passwordHistoryLimit)
if err == nil && len(histories) > 0 {
for _, h := range histories {
if auth.VerifyPassword(h.PasswordHash, newPassword) {
return errors.New("新密码不能与最近5次密码相同")
}
}
}
// 保存新密码到历史记录
newHashedPassword, hashErr := auth.HashPassword(newPassword)
if hashErr != nil {
return errors.New("密码哈希失败")
}
go func() {
_ = s.passwordHistoryRepo.Create(context.Background(), &domain.PasswordHistory{
UserID: userID,
PasswordHash: newHashedPassword,
})
_ = s.passwordHistoryRepo.DeleteOldRecords(context.Background(), userID, passwordHistoryLimit)
}()
}
// 更新密码
newHashedPassword, err := auth.HashPassword(newPassword)
if err != nil {
return errors.New("密码哈希失败")
}
user.Password = newHashedPassword
return s.userRepo.Update(ctx, user)
}
// GetByID 根据ID获取用户
func (s *UserService) GetByID(ctx context.Context, id int64) (*domain.User, error) {
return s.userRepo.GetByID(ctx, id)
}
// GetByEmail 根据邮箱获取用户
func (s *UserService) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
return s.userRepo.GetByEmail(ctx, email)
}
// Create 创建用户
func (s *UserService) Create(ctx context.Context, user *domain.User) error {
return s.userRepo.Create(ctx, user)
}
// Update 更新用户
func (s *UserService) Update(ctx context.Context, user *domain.User) error {
return s.userRepo.Update(ctx, user)
}
// Delete 删除用户
func (s *UserService) Delete(ctx context.Context, id int64) error {
return s.userRepo.Delete(ctx, id)
}
// List 获取用户列表
func (s *UserService) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
return s.userRepo.List(ctx, offset, limit)
}
// UpdateStatus 更新用户状态
func (s *UserService) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
return s.userRepo.UpdateStatus(ctx, id, status)
}

484
internal/service/webhook.go Normal file
View File

@@ -0,0 +1,484 @@
package service
import (
"bytes"
"context"
"crypto/hmac"
cryptorand "crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
"gorm.io/gorm"
)
// WebhookService Webhook 服务
type WebhookService struct {
db *gorm.DB
repo *repository.WebhookRepository
queue chan *deliveryTask
workers int
config WebhookServiceConfig
wg sync.WaitGroup
once sync.Once
}
type WebhookServiceConfig struct {
Enabled bool
SecretHeader string
TimeoutSec int
MaxRetries int
RetryBackoff string
WorkerCount int
QueueSize int
}
// deliveryTask 投递任务
type deliveryTask struct {
webhook *domain.Webhook
eventType domain.WebhookEventType
payload []byte
attempt int
}
// WebhookEvent 发布的事件结构
type WebhookEvent struct {
EventID string `json:"event_id"`
EventType domain.WebhookEventType `json:"event_type"`
Timestamp time.Time `json:"timestamp"`
Data interface{} `json:"data"`
}
// NewWebhookService 创建 Webhook 服务
func NewWebhookService(db *gorm.DB, cfgs ...WebhookServiceConfig) *WebhookService {
cfg := defaultWebhookServiceConfig()
if len(cfgs) > 0 {
cfg = cfgs[0]
}
if cfg.WorkerCount <= 0 {
cfg.WorkerCount = defaultWebhookServiceConfig().WorkerCount
}
if cfg.QueueSize <= 0 {
cfg.QueueSize = defaultWebhookServiceConfig().QueueSize
}
if cfg.SecretHeader == "" {
cfg.SecretHeader = defaultWebhookServiceConfig().SecretHeader
}
if cfg.TimeoutSec <= 0 {
cfg.TimeoutSec = defaultWebhookServiceConfig().TimeoutSec
}
if cfg.MaxRetries <= 0 {
cfg.MaxRetries = defaultWebhookServiceConfig().MaxRetries
}
if cfg.RetryBackoff == "" {
cfg.RetryBackoff = defaultWebhookServiceConfig().RetryBackoff
}
svc := &WebhookService{
db: db,
repo: repository.NewWebhookRepository(db),
queue: make(chan *deliveryTask, cfg.QueueSize),
workers: cfg.WorkerCount,
config: cfg,
}
svc.startWorkers()
return svc
}
func defaultWebhookServiceConfig() WebhookServiceConfig {
return WebhookServiceConfig{
Enabled: true,
SecretHeader: "X-Webhook-Signature",
TimeoutSec: 10,
MaxRetries: 3,
RetryBackoff: "exponential",
WorkerCount: 4,
QueueSize: 1000,
}
}
// startWorkers 启动后台投递 worker
func (s *WebhookService) startWorkers() {
s.once.Do(func() {
for i := 0; i < s.workers; i++ {
s.wg.Add(1)
go func() {
defer s.wg.Done()
for task := range s.queue {
s.deliver(task)
}
}()
}
})
}
// Publish 发布事件:找到订阅该事件的所有 Webhook异步投递
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
if !s.config.Enabled {
return
}
// 查询所有活跃 Webhook
webhooks, err := s.repo.ListActive(ctx)
if err != nil {
return
}
// 构建事件载荷
eventID, err := generateEventID()
if err != nil {
slog.Error("generate event ID failed", "error", err)
return
}
event := &WebhookEvent{
EventID: eventID,
EventType: eventType,
Timestamp: time.Now().UTC(),
Data: data,
}
payloadBytes, err := json.Marshal(event)
if err != nil {
return
}
for i := range webhooks {
wh := webhooks[i]
// 检查是否订阅了该事件类型
if !webhookSubscribesTo(wh, eventType) {
continue
}
task := &deliveryTask{
webhook: wh,
eventType: eventType,
payload: payloadBytes,
attempt: 1,
}
// 非阻塞投递到队列
select {
case s.queue <- task:
default:
// 队列满时记录但不阻塞
}
}
}
// deliver 执行单次 HTTP 投递
func (s *WebhookService) deliver(task *deliveryTask) {
wh := task.webhook
// NEW-SEC-01 修复:检查 URL 安全性
if !isSafeURL(wh.URL) {
s.recordDelivery(task, 0, "", "webhook URL 不安全: 可能存在 SSRF 风险", false)
return
}
timeout := time.Duration(wh.TimeoutSec) * time.Second
if timeout <= 0 {
timeout = time.Duration(s.config.TimeoutSec) * time.Second
}
if timeout <= 0 {
timeout = 10 * time.Second
}
client := &http.Client{Timeout: timeout}
req, err := http.NewRequest("POST", wh.URL, bytes.NewReader(task.payload))
if err != nil {
s.recordDelivery(task, 0, "", err.Error(), false)
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "UserManagementSystem-Webhook/1.0")
req.Header.Set("X-Webhook-Event", string(task.eventType))
req.Header.Set("X-Webhook-Attempt", fmt.Sprintf("%d", task.attempt))
// HMAC 签名
if wh.Secret != "" {
sig := computeHMAC(task.payload, wh.Secret)
req.Header.Set(s.config.SecretHeader, "sha256="+sig)
}
// 使用带超时的 context 避免请求无限等待
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
resp, err := client.Do(req.WithContext(ctx))
if err != nil {
s.handleFailure(task, 0, "", err.Error())
return
}
defer resp.Body.Close()
var respBuf bytes.Buffer
respBuf.ReadFrom(resp.Body)
success := resp.StatusCode >= 200 && resp.StatusCode < 300
if !success {
s.handleFailure(task, resp.StatusCode, respBuf.String(), "非 2xx 响应")
return
}
s.recordDelivery(task, resp.StatusCode, respBuf.String(), "", true)
}
// handleFailure 处理投递失败(重试逻辑)
func (s *WebhookService) handleFailure(task *deliveryTask, statusCode int, body, errMsg string) {
s.recordDelivery(task, statusCode, body, errMsg, false)
// 指数退避重试
if task.attempt < task.webhook.MaxRetries {
backoff := time.Second
if s.config.RetryBackoff == "fixed" {
backoff = 2 * time.Second
} else {
backoff = time.Duration(1<<uint(task.attempt)) * time.Second
}
time.AfterFunc(backoff, func() {
task.attempt++
select {
case s.queue <- task:
default:
}
})
}
}
// recordDelivery 记录投递日志
func (s *WebhookService) recordDelivery(task *deliveryTask, statusCode int, body, errMsg string, success bool) {
now := time.Now()
delivery := &domain.WebhookDelivery{
WebhookID: task.webhook.ID,
EventType: task.eventType,
Payload: string(task.payload),
StatusCode: statusCode,
ResponseBody: body,
Attempt: task.attempt,
Success: success,
Error: errMsg,
}
if success {
delivery.DeliveredAt = &now
}
_ = s.repo.CreateDelivery(context.Background(), delivery)
}
// CreateWebhook 创建 Webhook
func (s *WebhookService) CreateWebhook(ctx context.Context, req *CreateWebhookRequest, createdBy int64) (*domain.Webhook, error) {
eventsJSON, err := json.Marshal(req.Events)
if err != nil {
return nil, fmt.Errorf("序列化事件列表失败")
}
secret := req.Secret
if secret == "" {
generatedSecret, err := generateWebhookSecret()
if err != nil {
return nil, fmt.Errorf("generate webhook secret failed: %w", err)
}
secret = generatedSecret
}
wh := &domain.Webhook{
Name: req.Name,
URL: req.URL,
Secret: secret,
Events: string(eventsJSON),
Status: domain.WebhookStatusActive,
MaxRetries: s.config.MaxRetries,
TimeoutSec: s.config.TimeoutSec,
CreatedBy: createdBy,
}
if err := s.repo.Create(ctx, wh); err != nil {
return nil, err
}
return wh, nil
}
// UpdateWebhook 更新 Webhook
func (s *WebhookService) UpdateWebhook(ctx context.Context, id int64, req *UpdateWebhookRequest) error {
updates := map[string]interface{}{}
if req.Name != "" {
updates["name"] = req.Name
}
if req.URL != "" {
updates["url"] = req.URL
}
if len(req.Events) > 0 {
b, _ := json.Marshal(req.Events)
updates["events"] = string(b)
}
if req.Status != nil {
updates["status"] = *req.Status
}
return s.repo.Update(ctx, id, updates)
}
// DeleteWebhook 删除 Webhook
func (s *WebhookService) DeleteWebhook(ctx context.Context, id int64) error {
return s.repo.Delete(ctx, id)
}
func (s *WebhookService) GetWebhook(ctx context.Context, id int64) (*domain.Webhook, error) {
return s.repo.GetByID(ctx, id)
}
// ListWebhooks 获取 Webhook 列表(不分页)
func (s *WebhookService) ListWebhooks(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
return s.repo.ListByCreator(ctx, createdBy)
}
// ListWebhooksPaginated 获取 Webhook 列表(分页)
func (s *WebhookService) ListWebhooksPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
return s.repo.ListByCreatorPaginated(ctx, createdBy, offset, limit)
}
// GetWebhookDeliveries 获取投递记录
func (s *WebhookService) GetWebhookDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) {
return s.repo.ListDeliveries(ctx, webhookID, limit)
}
// ---- Request/Response 结构 ----
// CreateWebhookRequest 创建 Webhook 请求
type CreateWebhookRequest struct {
Name string `json:"name" binding:"required"`
URL string `json:"url" binding:"required,url"`
Secret string `json:"secret"`
Events []domain.WebhookEventType `json:"events" binding:"required,min=1"`
}
// UpdateWebhookRequest 更新 Webhook 请求
type UpdateWebhookRequest struct {
Name string `json:"name"`
URL string `json:"url"`
Events []domain.WebhookEventType `json:"events"`
Status *domain.WebhookStatus `json:"status"`
}
// ---- 辅助函数 ----
// webhookSubscribesTo 检查 Webhook 是否订阅了指定事件类型
func webhookSubscribesTo(w *domain.Webhook, eventType domain.WebhookEventType) bool {
var events []domain.WebhookEventType
if err := json.Unmarshal([]byte(w.Events), &events); err != nil {
return false
}
for _, e := range events {
if e == eventType || e == "*" {
return true
}
}
return false
}
// SubscribesTo 检查 Webhook 是否订阅了指定事件类型(为 domain.Webhook 添加方法,通过包装实现)
// 注意:此函数在 domain 包外部无法直接扩展,使用独立函数代替
// isSafeURL 检查 URL 是否安全(防止 SSRF 攻击)
// NEW-SEC-01 修复:添加完整的 URL 安全检查
func isSafeURL(rawURL string) bool {
u, err := url.Parse(rawURL)
if err != nil || u.Scheme == "" {
return false
}
// 只允许 http/https
if u.Scheme != "http" && u.Scheme != "https" {
return false
}
host := u.Hostname()
// 禁止 localhost
if host == "localhost" || host == "127.0.0.1" || host == "::1" {
return false
}
// 检查内网 IP
if ip := net.ParseIP(host); ip != nil {
if isPrivateIP(ip) {
return false
}
}
// 检查内网域名
if strings.HasSuffix(host, ".internal") ||
strings.HasSuffix(host, ".local") ||
strings.HasSuffix(host, ".corp") ||
strings.HasSuffix(host, ".lan") ||
strings.HasSuffix(host, ".intranet") {
return false
}
// 检查知名内网服务地址
blockedHosts := []string{
"metadata.google.internal", // GCP 元数据服务
"169.254.169.254", // AWS/Azure/GCP 元数据服务
"metadata.azure.internal", // Azure 元数据服务
"100.100.100.200", // 阿里云元数据服务
}
for _, blocked := range blockedHosts {
if host == blocked {
return false
}
}
return true
}
// isPrivateIP 检查是否为内网 IP
func isPrivateIP(ip net.IP) bool {
privateRanges := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"::1/128",
"fc00::/7",
}
for _, cidr := range privateRanges {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
continue
}
if network.Contains(ip) {
return true
}
}
return false
}
// computeHMAC 计算 HMAC-SHA256 签名
func computeHMAC(payload []byte, secret string) string {
mac := hmac.New(sha256.New, []byte(secret))
mac.Write(payload)
return hex.EncodeToString(mac.Sum(nil))
}
// generateEventID 生成随机事件 ID
func generateEventID() (string, error) {
b := make([]byte, 8)
if _, err := cryptorand.Read(b); err != nil {
return "", fmt.Errorf("generate event ID failed: %w", err)
}
return "evt_" + hex.EncodeToString(b), nil
}
// generateWebhookSecret 生成随机 Webhook 签名密钥
func generateWebhookSecret() (string, error) {
b := make([]byte, 24)
if _, err := cryptorand.Read(b); err != nil {
return "", fmt.Errorf("generate webhook secret failed: %w", err)
}
return strings.ToLower(hex.EncodeToString(b)), nil
}