- Add atomicTOTPVerifier interface for atomic TOTP/recovery code verification
- Implement VerifyTOTPOrRecoveryCode in UserRepository within a transaction
- Update DisableTOTP to prefer the atomic verification path
- Add unit tests for atomic verification success/failure paths
- Maintain backward compatibility with a non-atomic fallback
Refs: TOTP verification atomicity completion
203 lines
6.0 KiB
Go
203 lines
6.0 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/user-management-system/internal/auth"
|
|
"github.com/user-management-system/internal/domain"
|
|
)
|
|
|
|
// atomicTOTPRecoveryCodeConsumer 原子性恢复码消费接口
|
|
type atomicTOTPRecoveryCodeConsumer interface {
|
|
ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error)
|
|
}
|
|
|
|
// atomicTOTPVerifier is an optional capability of the user repository: check
// atomically whether code is a valid TOTP code or recovery code for the user,
// WITHOUT consuming the recovery code.
//
// DisableTOTP uses it when available. valid == false is reported as a wrong
// code, and a non-nil err aborts the disable.
type atomicTOTPVerifier interface {
	VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (valid bool, err error)
}
|
|
|
|
// TOTPService manages 2FA setup, enable/disable, and verification.
|
|
type TOTPService struct {
|
|
userRepo userRepositoryInterface
|
|
totpManager *auth.TOTPManager
|
|
}
|
|
|
|
func NewTOTPService(userRepo userRepositoryInterface) *TOTPService {
|
|
return &TOTPService{
|
|
userRepo: userRepo,
|
|
totpManager: auth.NewTOTPManager(),
|
|
}
|
|
}
|
|
|
|
// SetupTOTPResponse is the result of SetupTOTP: the data a client needs to
// enrol an authenticator app.
type SetupTOTPResponse struct {
	// Secret is the newly generated TOTP shared secret.
	Secret string `json:"secret"`
	// QRCodeBase64 is the QR code for Secret, base64-encoded.
	// NOTE(review): presumably an image payload — confirm the format produced
	// by auth.TOTPManager.GenerateSecret.
	QRCodeBase64 string `json:"qr_code_base64"`
	// RecoveryCodes are the plaintext recovery codes. SetupTOTP persists only
	// their hashes, so this response is the only place the plaintext appears.
	RecoveryCodes []string `json:"recovery_codes"`
}
|
|
|
|
func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPResponse, error) {
|
|
user, err := s.userRepo.GetByID(ctx, userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
|
}
|
|
if user.TOTPEnabled {
|
|
return nil, errors.New("2FA \u5df2\u7ecf\u542f\u7528\uff0c\u5982\u9700\u91cd\u7f6e\u8bf7\u5148\u7981\u7528")
|
|
}
|
|
|
|
setup, err := s.totpManager.GenerateSecret(user.Username)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("\u751f\u6210 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
|
|
}
|
|
|
|
// Persist the generated secret and recovery codes before activation.
|
|
user.TOTPSecret = setup.Secret
|
|
// Hash recovery codes before storing (SEC-03 fix)
|
|
hashedCodes := make([]string, len(setup.RecoveryCodes))
|
|
for i, code := range setup.RecoveryCodes {
|
|
hashedCode, err := auth.HashRecoveryCode(code)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("生成恢复码摘要失败: %w", err)
|
|
}
|
|
hashedCodes[i] = hashedCode
|
|
}
|
|
codesJSON, err := json.Marshal(hashedCodes)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("序列化恢复码失败: %w", err)
|
|
}
|
|
user.TOTPRecoveryCodes = string(codesJSON)
|
|
|
|
if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
|
|
return nil, fmt.Errorf("\u4fdd\u5b58 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
|
|
}
|
|
|
|
return &SetupTOTPResponse{
|
|
Secret: setup.Secret,
|
|
QRCodeBase64: setup.QRCodeBase64,
|
|
RecoveryCodes: setup.RecoveryCodes,
|
|
}, nil
|
|
}
|
|
|
|
func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string) error {
|
|
user, err := s.userRepo.GetByID(ctx, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
|
}
|
|
if user.TOTPSecret == "" {
|
|
return errors.New("\u8bf7\u5148\u521d\u59cb\u5316 2FA\uff0c\u83b7\u53d6\u4e8c\u7ef4\u7801\u540e\u518d\u6fc0\u6d3b")
|
|
}
|
|
if user.TOTPEnabled {
|
|
return errors.New("2FA \u5df2\u542f\u7528")
|
|
}
|
|
|
|
if !s.totpManager.ValidateCode(user.TOTPSecret, code) {
|
|
return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
|
|
}
|
|
|
|
user.TOTPEnabled = true
|
|
return s.userRepo.UpdateTOTP(ctx, user)
|
|
}
|
|
|
|
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
|
|
user, err := s.userRepo.GetByID(ctx, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("用户不存在")
|
|
}
|
|
if !user.TOTPEnabled {
|
|
return errors.New("2FA 未启用")
|
|
}
|
|
|
|
// 尝试原子性验证(如果 repo 支持)
|
|
if verifier, ok := s.userRepo.(atomicTOTPVerifier); ok {
|
|
valid, err := verifier.VerifyTOTPOrRecoveryCode(ctx, userID, code)
|
|
if err != nil {
|
|
return fmt.Errorf("验证失败: %w", err)
|
|
}
|
|
if !valid {
|
|
return errors.New("验证码或恢复码错误")
|
|
}
|
|
// 验证通过,继续禁用
|
|
} else {
|
|
// 降级到非原子性验证(兼容性模式)
|
|
valid := s.totpManager.ValidateCode(user.TOTPSecret, code)
|
|
if !valid {
|
|
var hashedCodes []string
|
|
if user.TOTPRecoveryCodes != "" {
|
|
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
|
return fmt.Errorf("解析恢复码失败: %w", err)
|
|
}
|
|
}
|
|
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
|
if !matched {
|
|
return errors.New("验证码或恢复码错误")
|
|
}
|
|
}
|
|
}
|
|
|
|
user.TOTPEnabled = false
|
|
user.TOTPSecret = ""
|
|
user.TOTPRecoveryCodes = ""
|
|
return s.userRepo.UpdateTOTP(ctx, user)
|
|
}
|
|
|
|
func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error {
|
|
user, err := s.userRepo.GetByID(ctx, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("用户不存在")
|
|
}
|
|
if !user.TOTPEnabled {
|
|
return nil
|
|
}
|
|
|
|
if s.totpManager.ValidateCode(user.TOTPSecret, code) {
|
|
return nil
|
|
}
|
|
|
|
// 尝试原子性消费恢复码(如果 repo 支持)
|
|
if consumer, ok := s.userRepo.(atomicTOTPRecoveryCodeConsumer); ok {
|
|
_, consumed, err := consumer.ConsumeTOTPRecoveryCode(ctx, userID, code)
|
|
if err != nil {
|
|
return fmt.Errorf("消费恢复码失败: %w", err)
|
|
}
|
|
if consumed {
|
|
return nil
|
|
}
|
|
// 恢复码不匹配,继续返回通用错误
|
|
return errors.New("验证码错误或已过期")
|
|
}
|
|
|
|
// 降级到非原子性恢复码消费(兼容性模式)
|
|
var storedCodes []string
|
|
if user.TOTPRecoveryCodes != "" {
|
|
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes); err != nil {
|
|
return fmt.Errorf("解析恢复码失败: %w", err)
|
|
}
|
|
}
|
|
idx, matched := auth.VerifyRecoveryCode(code, storedCodes)
|
|
if !matched {
|
|
return errors.New("验证码错误或已过期")
|
|
}
|
|
|
|
storedCodes = append(storedCodes[:idx], storedCodes[idx+1:]...)
|
|
codesJSON, err := json.Marshal(storedCodes)
|
|
if err != nil {
|
|
return fmt.Errorf("序列化恢复码失败: %w", err)
|
|
}
|
|
user.TOTPRecoveryCodes = string(codesJSON)
|
|
if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
|
|
return fmt.Errorf("更新恢复码失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *TOTPService) GetTOTPStatus(ctx context.Context, userID int64) (bool, error) {
|
|
user, err := s.userRepo.GetByID(ctx, userID)
|
|
if err != nil {
|
|
return false, fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
|
}
|
|
return user.TOTPEnabled, nil
|
|
}
|