package service

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	"github.com/user-management-system/internal/auth"
	"github.com/user-management-system/internal/domain"
)

// atomicTOTPRecoveryCodeConsumer is an optional repository capability for
// consuming a recovery code atomically.
type atomicTOTPRecoveryCodeConsumer interface {
	ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error)
}

// atomicTOTPVerifier is an optional repository capability for atomically
// verifying a TOTP code or a recovery code without consuming the recovery code.
type atomicTOTPVerifier interface {
	VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error)
}

// TOTPService manages 2FA setup, enable/disable, and verification.
type TOTPService struct {
	userRepo    userRepositoryInterface
	totpManager *auth.TOTPManager
}

// NewTOTPService builds a TOTPService backed by userRepo and a TOTP manager
// obtained from auth.NewTOTPManager.
func NewTOTPService(userRepo userRepositoryInterface) *TOTPService {
	svc := new(TOTPService)
	svc.userRepo = userRepo
	svc.totpManager = auth.NewTOTPManager()
	return svc
}

// SetupTOTPResponse is the payload returned by SetupTOTP: the generated
// secret, its QR code, and the plaintext recovery codes.
type SetupTOTPResponse struct {
	Secret        string   `json:"secret"`
	QRCodeBase64  string   `json:"qr_code_base64"`
	RecoveryCodes []string `json:"recovery_codes"`
}

// SetupTOTP generates a new TOTP secret and recovery codes for the user and
// stores them through the repository without turning 2FA on.
//
// It fails when the user cannot be loaded, when 2FA is already enabled, or
// when generating, hashing, serializing, or saving the material fails. On
// success the response carries the plaintext recovery codes, while only their
// hashes are written to the user record.
func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPResponse, error) {
	account, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	if account.TOTPEnabled {
		return nil, errors.New("2FA \u5df2\u7ecf\u542f\u7528\uff0c\u5982\u9700\u91cd\u7f6e\u8bf7\u5148\u7981\u7528")
	}

	generated, err := s.totpManager.GenerateSecret(account.Username)
	if err != nil {
		return nil, fmt.Errorf("\u751f\u6210 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
	}

	// The secret and recovery codes are persisted before activation.
	account.TOTPSecret = generated.Secret

	// Recovery codes are hashed before being stored (SEC-03 fix).
	digests := make([]string, 0, len(generated.RecoveryCodes))
	for _, plain := range generated.RecoveryCodes {
		digest, hashErr := auth.HashRecoveryCode(plain)
		if hashErr != nil {
			return nil, fmt.Errorf("生成恢复码摘要失败: %w", hashErr)
		}
		digests = append(digests, digest)
	}

	encoded, err := json.Marshal(digests)
	if err != nil {
		return nil, fmt.Errorf("序列化恢复码失败: %w", err)
	}
	account.TOTPRecoveryCodes = string(encoded)

	if err = s.userRepo.UpdateTOTP(ctx, account); err != nil {
		return nil, fmt.Errorf("\u4fdd\u5b58 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
	}

	resp := &SetupTOTPResponse{
		Secret:        generated.Secret,
		QRCodeBase64:  generated.QRCodeBase64,
		RecoveryCodes: generated.RecoveryCodes,
	}
	return resp, nil
}

// EnableTOTP turns 2FA on for the user once code validates against the
// previously stored secret.
//
// It returns an error when the user cannot be loaded, when no secret has been
// set up yet, when 2FA is already enabled, when the code does not validate, or
// when the repository update fails.
func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string) error {
	account, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
	}

	// Cases are evaluated in order, so the code is only validated once the
	// secret is known to exist and 2FA is known to be off.
	switch {
	case account.TOTPSecret == "":
		return errors.New("\u8bf7\u5148\u521d\u59cb\u5316 2FA\uff0c\u83b7\u53d6\u4e8c\u7ef4\u7801\u540e\u518d\u6fc0\u6d3b")
	case account.TOTPEnabled:
		return errors.New("2FA \u5df2\u542f\u7528")
	case !s.totpManager.ValidateCode(account.TOTPSecret, code):
		return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
	}

	account.TOTPEnabled = true
	return s.userRepo.UpdateTOTP(ctx, account)
}

// DisableTOTP turns 2FA off for the user and clears the stored secret and
// recovery codes, provided code is a valid TOTP code or recovery code.
//
// When the repository implements atomicTOTPVerifier the check is delegated to
// it; otherwise the code is validated locally against the secret and then, if
// that fails, against the stored recovery-code hashes.
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
	account, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return errors.New("用户不存在")
	}
	if !account.TOTPEnabled {
		return errors.New("2FA 未启用")
	}

	verifier, atomic := s.userRepo.(atomicTOTPVerifier)
	switch {
	case atomic:
		// Preferred path: the repository verifies atomically.
		accepted, verifyErr := verifier.VerifyTOTPOrRecoveryCode(ctx, userID, code)
		if verifyErr != nil {
			return fmt.Errorf("验证失败: %w", verifyErr)
		}
		if !accepted {
			return errors.New("验证码或恢复码错误")
		}
	case !s.totpManager.ValidateCode(account.TOTPSecret, code):
		// Compatibility fallback (non-atomic): the TOTP code was rejected, so
		// try the stored recovery-code hashes.
		var digests []string
		if account.TOTPRecoveryCodes != "" {
			if decodeErr := json.Unmarshal([]byte(account.TOTPRecoveryCodes), &digests); decodeErr != nil {
				return fmt.Errorf("解析恢复码失败: %w", decodeErr)
			}
		}
		if _, hit := auth.VerifyRecoveryCode(code, digests); !hit {
			return errors.New("验证码或恢复码错误")
		}
	}

	account.TOTPEnabled = false
	account.TOTPSecret = ""
	account.TOTPRecoveryCodes = ""
	return s.userRepo.UpdateTOTP(ctx, account)
}

// VerifyTOTP checks code for the user and returns nil when it is accepted.
//
// It returns nil immediately when 2FA is not enabled or when code validates
// against the stored secret. Otherwise code is treated as a recovery code:
// consumed through atomicTOTPRecoveryCodeConsumer when the repository
// implements it, or matched against the stored hashes, removed from the list,
// and written back through UpdateTOTP.
func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error {
	account, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return errors.New("用户不存在")
	}

	// Short-circuit keeps the code from being validated when 2FA is off.
	if !account.TOTPEnabled || s.totpManager.ValidateCode(account.TOTPSecret, code) {
		return nil
	}

	// Preferred path: let the repository consume the recovery code atomically.
	if consumer, atomic := s.userRepo.(atomicTOTPRecoveryCodeConsumer); atomic {
		_, used, consumeErr := consumer.ConsumeTOTPRecoveryCode(ctx, userID, code)
		if consumeErr != nil {
			return fmt.Errorf("消费恢复码失败: %w", consumeErr)
		}
		if !used {
			// No recovery code matched; report the generic failure.
			return errors.New("验证码错误或已过期")
		}
		return nil
	}

	// Compatibility fallback: non-atomic recovery-code consumption.
	var digests []string
	if account.TOTPRecoveryCodes != "" {
		if decodeErr := json.Unmarshal([]byte(account.TOTPRecoveryCodes), &digests); decodeErr != nil {
			return fmt.Errorf("解析恢复码失败: %w", decodeErr)
		}
	}

	pos, hit := auth.VerifyRecoveryCode(code, digests)
	if !hit {
		return errors.New("验证码错误或已过期")
	}

	// Drop the matched entry so the same recovery code cannot be used again.
	digests = append(digests[:pos], digests[pos+1:]...)

	encoded, err := json.Marshal(digests)
	if err != nil {
		return fmt.Errorf("序列化恢复码失败: %w", err)
	}
	account.TOTPRecoveryCodes = string(encoded)

	if err = s.userRepo.UpdateTOTP(ctx, account); err != nil {
		return fmt.Errorf("更新恢复码失败: %w", err)
	}
	return nil
}

// GetTOTPStatus reports whether 2FA is enabled for the user. It returns an
// error when the user cannot be loaded from the repository.
func (s *TOTPService) GetTOTPStatus(ctx context.Context, userID int64) (bool, error) {
	account, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return false, errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	return account.TOTPEnabled, nil
}