Files
user-system/internal/service/totp.go
Your Name 363c77d020 feat: atomic TOTP verification for DisableTOTP
- Add atomicTOTPVerifier interface for atomic TOTP/recovery code verification
- Implement VerifyTOTPOrRecoveryCode in UserRepository with transaction
- Update DisableTOTP to prefer atomic verification path
- Add unit tests for atomic verification success/failure paths
- Maintain backward compatibility with non-atomic fallback

Refs: TOTP verification atomicity completion
2026-05-29 12:47:05 +08:00

203 lines
6.0 KiB
Go

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
)
// atomicTOTPRecoveryCodeConsumer is an optional repository capability for
// consuming a TOTP recovery code atomically.
//
// VerifyTOTP type-asserts the user repository against this interface and
// uses it when the repository provides it. The bool result reports whether
// the supplied code was consumed.
// NOTE(review): the returned *domain.User is presumably the updated user
// record; confirm against the repository implementation.
type atomicTOTPRecoveryCodeConsumer interface {
	ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error)
}
// atomicTOTPVerifier is an optional repository capability for checking a
// TOTP code or a recovery code atomically, without consuming the recovery
// code.
//
// DisableTOTP type-asserts the user repository against this interface and
// prefers it over the in-process check when the repository provides it. The
// bool result reports whether the supplied code was accepted.
type atomicTOTPVerifier interface {
	VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error)
}
// TOTPService manages 2FA setup, enable/disable, and verification.
type TOTPService struct {
	// userRepo loads users and persists their TOTP fields.
	userRepo userRepositoryInterface
	// totpManager generates secrets and validates TOTP codes.
	totpManager *auth.TOTPManager
}
// NewTOTPService builds a TOTPService that reads and writes users through
// userRepo and uses a TOTP manager obtained from auth.NewTOTPManager.
func NewTOTPService(userRepo userRepositoryInterface) *TOTPService {
	svc := new(TOTPService)
	svc.userRepo = userRepo
	svc.totpManager = auth.NewTOTPManager()
	return svc
}
// SetupTOTPResponse is the payload returned by SetupTOTP.
type SetupTOTPResponse struct {
	// Secret is the newly generated TOTP secret.
	Secret string `json:"secret"`
	// QRCodeBase64 is the QR code value produced together with the secret.
	QRCodeBase64 string `json:"qr_code_base64"`
	// RecoveryCodes holds the plaintext recovery codes; SetupTOTP stores
	// only their hashes.
	RecoveryCodes []string `json:"recovery_codes"`
}
// SetupTOTP generates a fresh TOTP secret and recovery codes for the user and
// stores them without turning 2FA on; EnableTOTP performs the activation.
//
// Only hashes of the recovery codes are persisted (as a JSON array); the
// plaintext codes are handed back to the caller in the response. An error is
// returned when the user cannot be loaded, when 2FA is already enabled, or
// when generating, hashing, serializing, or saving the data fails.
func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPResponse, error) {
	account, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		// Every lookup failure is reported as "user does not exist"; the
		// underlying error is not propagated.
		return nil, errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	if account.TOTPEnabled {
		return nil, errors.New("2FA \u5df2\u7ecf\u542f\u7528\uff0c\u5982\u9700\u91cd\u7f6e\u8bf7\u5148\u7981\u7528")
	}

	generated, err := s.totpManager.GenerateSecret(account.Username)
	if err != nil {
		return nil, fmt.Errorf("\u751f\u6210 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
	}

	// The secret and recovery codes are stored now, ahead of activation.
	account.TOTPSecret = generated.Secret

	// Recovery codes are hashed before being stored (SEC-03 fix).
	digests := make([]string, 0, len(generated.RecoveryCodes))
	for _, plain := range generated.RecoveryCodes {
		digest, hashErr := auth.HashRecoveryCode(plain)
		if hashErr != nil {
			return nil, fmt.Errorf("生成恢复码摘要失败: %w", hashErr)
		}
		digests = append(digests, digest)
	}

	encoded, err := json.Marshal(digests)
	if err != nil {
		return nil, fmt.Errorf("序列化恢复码失败: %w", err)
	}
	account.TOTPRecoveryCodes = string(encoded)

	if err = s.userRepo.UpdateTOTP(ctx, account); err != nil {
		return nil, fmt.Errorf("\u4fdd\u5b58 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
	}

	resp := &SetupTOTPResponse{
		Secret:        generated.Secret,
		QRCodeBase64:  generated.QRCodeBase64,
		RecoveryCodes: generated.RecoveryCodes,
	}
	return resp, nil
}
// EnableTOTP activates 2FA for the user after confirming a TOTP code against
// the secret stored by SetupTOTP.
//
// It returns an error when the user cannot be loaded, when no secret has been
// set up yet, when 2FA is already enabled, or when code does not validate.
// On success the enabled flag is set and saved through the repository, whose
// error (if any) is returned as-is.
func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string) error {
	account, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		// Every lookup failure is reported as "user does not exist".
		return errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
	}

	// Cases are evaluated top to bottom, so the code is only validated once
	// a secret exists and 2FA is still off.
	switch {
	case account.TOTPSecret == "":
		return errors.New("\u8bf7\u5148\u521d\u59cb\u5316 2FA\uff0c\u83b7\u53d6\u4e8c\u7ef4\u7801\u540e\u518d\u6fc0\u6d3b")
	case account.TOTPEnabled:
		return errors.New("2FA \u5df2\u542f\u7528")
	case !s.totpManager.ValidateCode(account.TOTPSecret, code):
		return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
	}

	account.TOTPEnabled = true
	return s.userRepo.UpdateTOTP(ctx, account)
}
// DisableTOTP turns 2FA off for the user once code has been accepted as
// either a TOTP code or a recovery code.
//
// When the repository implements atomicTOTPVerifier the check is delegated
// to it. Otherwise the code is validated in-process against the stored secret
// and, failing that, against the stored recovery-code hashes. A matching
// recovery code is not removed here because all TOTP data is cleared anyway.
//
// It returns an error when the user cannot be loaded, when 2FA is not
// enabled, when verification fails or the code is rejected, or when the
// repository update fails.
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
	account, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		// Every lookup failure is reported as "user does not exist".
		return errors.New("用户不存在")
	}
	if !account.TOTPEnabled {
		return errors.New("2FA 未启用")
	}

	atomicRepo, supportsAtomic := s.userRepo.(atomicTOTPVerifier)
	switch {
	case supportsAtomic:
		// Preferred path: let the repository verify atomically.
		accepted, verifyErr := atomicRepo.VerifyTOTPOrRecoveryCode(ctx, userID, code)
		if verifyErr != nil {
			return fmt.Errorf("验证失败: %w", verifyErr)
		}
		if !accepted {
			return errors.New("验证码或恢复码错误")
		}
	case !s.totpManager.ValidateCode(account.TOTPSecret, code):
		// Compatibility path (non-atomic): the TOTP code did not validate,
		// so try the stored recovery-code hashes.
		var digests []string
		if raw := account.TOTPRecoveryCodes; raw != "" {
			if decodeErr := json.Unmarshal([]byte(raw), &digests); decodeErr != nil {
				return fmt.Errorf("解析恢复码失败: %w", decodeErr)
			}
		}
		if _, matched := auth.VerifyRecoveryCode(code, digests); !matched {
			return errors.New("验证码或恢复码错误")
		}
	}

	// Verification passed: wipe every piece of TOTP state.
	account.TOTPEnabled = false
	account.TOTPSecret = ""
	account.TOTPRecoveryCodes = ""
	return s.userRepo.UpdateTOTP(ctx, account)
}
// VerifyTOTP checks code for the user, accepting either a valid TOTP code or
// an unused recovery code. A recovery code that matches is consumed.
//
// It returns nil when 2FA is not enabled for the user or when the code is
// accepted. When the repository implements atomicTOTPRecoveryCodeConsumer,
// recovery-code consumption is delegated to it; otherwise the stored hashes
// are checked in-process, the matching entry is removed, and the remaining
// list is saved back.
//
// It returns an error when the user cannot be loaded, when the code matches
// neither a TOTP code nor a recovery code, or when consuming, decoding,
// encoding, or saving the recovery codes fails.
func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error {
	account, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		// Every lookup failure is reported as "user does not exist".
		return errors.New("用户不存在")
	}

	// Nothing to verify when 2FA is off; otherwise a valid TOTP code is
	// sufficient. Short-circuiting keeps ValidateCode from running when 2FA
	// is disabled.
	if !account.TOTPEnabled || s.totpManager.ValidateCode(account.TOTPSecret, code) {
		return nil
	}

	// Preferred path: consume the recovery code atomically if the
	// repository supports it.
	if atomicRepo, ok := s.userRepo.(atomicTOTPRecoveryCodeConsumer); ok {
		_, consumed, consumeErr := atomicRepo.ConsumeTOTPRecoveryCode(ctx, userID, code)
		switch {
		case consumeErr != nil:
			return fmt.Errorf("消费恢复码失败: %w", consumeErr)
		case consumed:
			return nil
		default:
			// No recovery code matched: report the generic error.
			return errors.New("验证码错误或已过期")
		}
	}

	// Compatibility path: non-atomic recovery-code consumption.
	var remaining []string
	if raw := account.TOTPRecoveryCodes; raw != "" {
		if decodeErr := json.Unmarshal([]byte(raw), &remaining); decodeErr != nil {
			return fmt.Errorf("解析恢复码失败: %w", decodeErr)
		}
	}

	pos, matched := auth.VerifyRecoveryCode(code, remaining)
	if !matched {
		return errors.New("验证码错误或已过期")
	}

	// Drop the used code so it cannot be replayed, then persist the rest.
	remaining = append(remaining[:pos], remaining[pos+1:]...)
	encoded, err := json.Marshal(remaining)
	if err != nil {
		return fmt.Errorf("序列化恢复码失败: %w", err)
	}
	account.TOTPRecoveryCodes = string(encoded)

	if saveErr := s.userRepo.UpdateTOTP(ctx, account); saveErr != nil {
		return fmt.Errorf("更新恢复码失败: %w", saveErr)
	}
	return nil
}
// GetTOTPStatus reports whether 2FA is currently enabled for the user.
//
// It returns false together with an error when the user cannot be loaded;
// every lookup failure is reported as "user does not exist".
func (s *TOTPService) GetTOTPStatus(ctx context.Context, userID int64) (bool, error) {
	account, lookupErr := s.userRepo.GetByID(ctx, userID)
	if lookupErr != nil {
		return false, fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	enabled := account.TOTPEnabled
	return enabled, nil
}