269 lines
7.5 KiB
Go
269 lines
7.5 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"crypto/rand"
|
|||
|
|
"encoding/hex"
|
|||
|
|
"fmt"
|
|||
|
|
"strings"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
|||
|
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
var (
|
|||
|
|
ErrPromoCodeNotFound = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found")
|
|||
|
|
ErrPromoCodeExpired = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired")
|
|||
|
|
ErrPromoCodeDisabled = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled")
|
|||
|
|
ErrPromoCodeMaxUsed = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses")
|
|||
|
|
ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code")
|
|||
|
|
ErrPromoCodeInvalid = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code")
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// PromoService 优惠码服务
|
|||
|
|
type PromoService struct {
|
|||
|
|
promoRepo PromoCodeRepository
|
|||
|
|
userRepo UserRepository
|
|||
|
|
billingCacheService *BillingCacheService
|
|||
|
|
entClient *dbent.Client
|
|||
|
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewPromoService 创建优惠码服务实例
|
|||
|
|
func NewPromoService(
|
|||
|
|
promoRepo PromoCodeRepository,
|
|||
|
|
userRepo UserRepository,
|
|||
|
|
billingCacheService *BillingCacheService,
|
|||
|
|
entClient *dbent.Client,
|
|||
|
|
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
|||
|
|
) *PromoService {
|
|||
|
|
return &PromoService{
|
|||
|
|
promoRepo: promoRepo,
|
|||
|
|
userRepo: userRepo,
|
|||
|
|
billingCacheService: billingCacheService,
|
|||
|
|
entClient: entClient,
|
|||
|
|
authCacheInvalidator: authCacheInvalidator,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ValidatePromoCode 验证优惠码(注册前调用)
|
|||
|
|
// 返回 nil, nil 表示空码(不报错)
|
|||
|
|
func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) {
|
|||
|
|
code = strings.TrimSpace(code)
|
|||
|
|
if code == "" {
|
|||
|
|
return nil, nil // 空码不报错,直接返回
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
promoCode, err := s.promoRepo.GetByCode(ctx, code)
|
|||
|
|
if err != nil {
|
|||
|
|
// 保留原始错误类型,不要统一映射为 NotFound
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return promoCode, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// validatePromoCodeStatus 验证优惠码状态
|
|||
|
|
func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error {
|
|||
|
|
if !promoCode.CanUse() {
|
|||
|
|
if promoCode.IsExpired() {
|
|||
|
|
return ErrPromoCodeExpired
|
|||
|
|
}
|
|||
|
|
if promoCode.Status == PromoCodeStatusDisabled {
|
|||
|
|
return ErrPromoCodeDisabled
|
|||
|
|
}
|
|||
|
|
if promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses {
|
|||
|
|
return ErrPromoCodeMaxUsed
|
|||
|
|
}
|
|||
|
|
return ErrPromoCodeInvalid
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ApplyPromoCode 应用优惠码(注册成功后调用)
|
|||
|
|
// 使用事务和行锁确保并发安全
|
|||
|
|
func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error {
|
|||
|
|
code = strings.TrimSpace(code)
|
|||
|
|
if code == "" {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 开启事务
|
|||
|
|
tx, err := s.entClient.Tx(ctx)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("begin transaction: %w", err)
|
|||
|
|
}
|
|||
|
|
defer func() { _ = tx.Rollback() }()
|
|||
|
|
|
|||
|
|
txCtx := dbent.NewTxContext(ctx, tx)
|
|||
|
|
|
|||
|
|
// 在事务中获取并锁定优惠码记录(FOR UPDATE)
|
|||
|
|
promoCode, err := s.promoRepo.GetByCodeForUpdate(txCtx, code)
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 在事务中验证优惠码状态
|
|||
|
|
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 在事务中检查用户是否已使用过此优惠码
|
|||
|
|
existing, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promoCode.ID, userID)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("check existing usage: %w", err)
|
|||
|
|
}
|
|||
|
|
if existing != nil {
|
|||
|
|
return ErrPromoCodeAlreadyUsed
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 增加用户余额
|
|||
|
|
if err := s.userRepo.UpdateBalance(txCtx, userID, promoCode.BonusAmount); err != nil {
|
|||
|
|
return fmt.Errorf("update user balance: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 创建使用记录
|
|||
|
|
usage := &PromoCodeUsage{
|
|||
|
|
PromoCodeID: promoCode.ID,
|
|||
|
|
UserID: userID,
|
|||
|
|
BonusAmount: promoCode.BonusAmount,
|
|||
|
|
UsedAt: time.Now(),
|
|||
|
|
}
|
|||
|
|
if err := s.promoRepo.CreateUsage(txCtx, usage); err != nil {
|
|||
|
|
return fmt.Errorf("create usage record: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 增加使用次数
|
|||
|
|
if err := s.promoRepo.IncrementUsedCount(txCtx, promoCode.ID); err != nil {
|
|||
|
|
return fmt.Errorf("increment used count: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := tx.Commit(); err != nil {
|
|||
|
|
return fmt.Errorf("commit transaction: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount)
|
|||
|
|
|
|||
|
|
// 失效余额缓存
|
|||
|
|
if s.billingCacheService != nil {
|
|||
|
|
go func() {
|
|||
|
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|||
|
|
defer cancel()
|
|||
|
|
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
|||
|
|
}()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) {
|
|||
|
|
if bonusAmount == 0 || s.authCacheInvalidator == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GenerateRandomCode 生成随机优惠码
|
|||
|
|
func (s *PromoService) GenerateRandomCode() (string, error) {
|
|||
|
|
bytes := make([]byte, 8)
|
|||
|
|
if _, err := rand.Read(bytes); err != nil {
|
|||
|
|
return "", fmt.Errorf("generate random bytes: %w", err)
|
|||
|
|
}
|
|||
|
|
return strings.ToUpper(hex.EncodeToString(bytes)), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create 创建优惠码
|
|||
|
|
func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) {
|
|||
|
|
code := strings.TrimSpace(input.Code)
|
|||
|
|
if code == "" {
|
|||
|
|
// 自动生成
|
|||
|
|
var err error
|
|||
|
|
code, err = s.GenerateRandomCode()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
promoCode := &PromoCode{
|
|||
|
|
Code: strings.ToUpper(code),
|
|||
|
|
BonusAmount: input.BonusAmount,
|
|||
|
|
MaxUses: input.MaxUses,
|
|||
|
|
UsedCount: 0,
|
|||
|
|
Status: PromoCodeStatusActive,
|
|||
|
|
ExpiresAt: input.ExpiresAt,
|
|||
|
|
Notes: input.Notes,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := s.promoRepo.Create(ctx, promoCode); err != nil {
|
|||
|
|
return nil, fmt.Errorf("create promo code: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return promoCode, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByID 根据ID获取优惠码
|
|||
|
|
func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) {
|
|||
|
|
code, err := s.promoRepo.GetByID(ctx, id)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return code, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Update 更新优惠码
|
|||
|
|
func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) {
|
|||
|
|
promoCode, err := s.promoRepo.GetByID(ctx, id)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if input.Code != nil {
|
|||
|
|
promoCode.Code = strings.ToUpper(strings.TrimSpace(*input.Code))
|
|||
|
|
}
|
|||
|
|
if input.BonusAmount != nil {
|
|||
|
|
promoCode.BonusAmount = *input.BonusAmount
|
|||
|
|
}
|
|||
|
|
if input.MaxUses != nil {
|
|||
|
|
promoCode.MaxUses = *input.MaxUses
|
|||
|
|
}
|
|||
|
|
if input.Status != nil {
|
|||
|
|
promoCode.Status = *input.Status
|
|||
|
|
}
|
|||
|
|
if input.ExpiresAt != nil {
|
|||
|
|
promoCode.ExpiresAt = input.ExpiresAt
|
|||
|
|
}
|
|||
|
|
if input.Notes != nil {
|
|||
|
|
promoCode.Notes = *input.Notes
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := s.promoRepo.Update(ctx, promoCode); err != nil {
|
|||
|
|
return nil, fmt.Errorf("update promo code: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return promoCode, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Delete 删除优惠码
|
|||
|
|
func (s *PromoService) Delete(ctx context.Context, id int64) error {
|
|||
|
|
if err := s.promoRepo.Delete(ctx, id); err != nil {
|
|||
|
|
return fmt.Errorf("delete promo code: %w", err)
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// List 获取优惠码列表
|
|||
|
|
func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) {
|
|||
|
|
return s.promoRepo.ListWithFilters(ctx, params, status, search)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListUsages 获取使用记录
|
|||
|
|
func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) {
|
|||
|
|
return s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params)
|
|||
|
|
}
|