Files

149 lines
4.2 KiB
Go
Raw Permalink Normal View History

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/user-management-system/internal/auth"
)
// TOTPService manages 2FA setup, enable/disable, and verification.
//
// It loads users through userRepo, mutates their TOTP fields (TOTPSecret,
// TOTPEnabled, TOTPRecoveryCodes) and persists them with userRepo.UpdateTOTP.
type TOTPService struct {
	// userRepo loads users (GetByID) and saves their TOTP state (UpdateTOTP).
	userRepo userRepositoryInterface
	// totpManager generates new secrets (GenerateSecret) and checks
	// one-time codes against a stored secret (ValidateCode).
	totpManager *auth.TOTPManager
}
// NewTOTPService returns a TOTPService backed by userRepo for user storage
// and by a newly constructed auth.TOTPManager for secret generation and
// code validation.
func NewTOTPService(userRepo userRepositoryInterface) *TOTPService {
	svc := &TOTPService{userRepo: userRepo}
	svc.totpManager = auth.NewTOTPManager()
	return svc
}
// SetupTOTPResponse is returned by SetupTOTP and carries everything the client
// needs to enroll an authenticator app.
//
// RecoveryCodes holds the plaintext recovery codes; only their hashes are
// persisted, so this response is the one place the plaintext is exposed.
type SetupTOTPResponse struct {
	// Secret is the newly generated TOTP secret.
	Secret string `json:"secret"`
	// QRCodeBase64 is the QR payload produced by auth.TOTPManager.GenerateSecret
	// (presumably a base64-encoded image — confirm against the auth package).
	QRCodeBase64 string `json:"qr_code_base64"`
	// RecoveryCodes are the plaintext one-time recovery codes.
	RecoveryCodes []string `json:"recovery_codes"`
}
// SetupTOTP provisions a new TOTP secret and a fresh set of recovery codes for
// the user identified by userID, without enabling 2FA. EnableTOTP turns 2FA on
// once the user submits a valid code for the pending secret.
//
// Only hashes of the recovery codes are persisted (SEC-03); the plaintext codes
// are returned to the caller in the response. Calling SetupTOTP again before
// activation replaces the pending secret and codes.
//
// It returns an error if the user cannot be loaded, if 2FA is already enabled,
// or if the secret or recovery codes cannot be generated, hashed, or saved.
func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPResponse, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	if user.TOTPEnabled {
		return nil, errors.New("2FA \u5df2\u7ecf\u542f\u7528\uff0c\u5982\u9700\u91cd\u7f6e\u8bf7\u5148\u7981\u7528")
	}

	setup, err := s.totpManager.GenerateSecret(user.Username)
	if err != nil {
		return nil, fmt.Errorf("\u751f\u6210 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
	}

	// Hash recovery codes before storing (SEC-03 fix). A hashing failure must
	// abort the setup. Otherwise whatever HashRecoveryCode returned alongside
	// the error (likely "") would be stored, and the user would receive a
	// recovery code that presumably can never be redeemed.
	hashedCodes := make([]string, len(setup.RecoveryCodes))
	for i, code := range setup.RecoveryCodes {
		hashed, hashErr := auth.HashRecoveryCode(code)
		if hashErr != nil {
			return nil, fmt.Errorf("\u751f\u6210\u6062\u590d\u7801\u5931\u8d25: %w", hashErr)
		}
		hashedCodes[i] = hashed
	}
	codesJSON, err := json.Marshal(hashedCodes)
	if err != nil {
		return nil, fmt.Errorf("\u751f\u6210\u6062\u590d\u7801\u5931\u8d25: %w", err)
	}

	// Persist the pending secret and hashed recovery codes. TOTPEnabled is
	// left false (checked above) until EnableTOTP succeeds. The user is only
	// mutated once every fallible step has succeeded.
	user.TOTPSecret = setup.Secret
	user.TOTPRecoveryCodes = string(codesJSON)
	if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
		return nil, fmt.Errorf("\u4fdd\u5b58 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
	}

	return &SetupTOTPResponse{
		Secret:        setup.Secret,
		QRCodeBase64:  setup.QRCodeBase64,
		RecoveryCodes: setup.RecoveryCodes,
	}, nil
}
// EnableTOTP activates 2FA for a user who has already run SetupTOTP.
//
// It rejects the request if the user cannot be loaded, if no pending secret
// exists, if 2FA is already on, or if code does not validate against the
// pending secret. On success it sets TOTPEnabled and persists the user,
// returning any error from the repository.
func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string) error {
	u, lookupErr := s.userRepo.GetByID(ctx, userID)
	if lookupErr != nil {
		return errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
	}

	// The checks are evaluated in order; the first failing one wins.
	switch {
	case u.TOTPSecret == "":
		return errors.New("\u8bf7\u5148\u521d\u59cb\u5316 2FA\uff0c\u83b7\u53d6\u4e8c\u7ef4\u7801\u540e\u518d\u6fc0\u6d3b")
	case u.TOTPEnabled:
		return errors.New("2FA \u5df2\u542f\u7528")
	case !s.totpManager.ValidateCode(u.TOTPSecret, code):
		return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
	}

	u.TOTPEnabled = true
	return s.userRepo.UpdateTOTP(ctx, u)
}
// DisableTOTP turns 2FA off for the user identified by userID.
//
// The caller must prove possession of the second factor. Either code is a
// valid TOTP code for the stored secret, or it matches one of the stored
// (hashed) recovery codes via auth.VerifyRecoveryCode. On success the secret
// and recovery codes are cleared and the user is persisted; the repository's
// error, if any, is returned.
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
	u, lookupErr := s.userRepo.GetByID(ctx, userID)
	if lookupErr != nil {
		return errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	if !u.TOTPEnabled {
		return errors.New("2FA \u672a\u542f\u7528")
	}

	if !s.totpManager.ValidateCode(u.TOTPSecret, code) {
		// Fall back to the recovery codes. A decode failure is deliberately
		// ignored: it leaves the list empty, so the match below fails closed.
		var hashes []string
		if raw := u.TOTPRecoveryCodes; raw != "" {
			_ = json.Unmarshal([]byte(raw), &hashes)
		}
		if _, ok := auth.VerifyRecoveryCode(code, hashes); !ok {
			return errors.New("\u9a8c\u8bc1\u7801\u6216\u6062\u590d\u7801\u9519\u8bef")
		}
	}

	u.TOTPSecret, u.TOTPRecoveryCodes = "", ""
	u.TOTPEnabled = false
	return s.userRepo.UpdateTOTP(ctx, u)
}
// VerifyTOTP checks a second-factor code for the user identified by userID.
//
// It returns nil when 2FA is not enabled for the user, or when code is a valid
// TOTP code for the stored secret. Otherwise code is tried as a recovery code.
// A matching recovery code is consumed: it is removed from the stored set and
// the updated set is persisted before nil is returned, so each recovery code
// works only once. If that removal cannot be saved, an error is returned
// instead of success, because the code would otherwise remain reusable.
//
// SetupTOTP stores only hashes of the recovery codes, so they are matched with
// auth.VerifyRecoveryCode, the same matcher DisableTOTP uses on that data.
// NOTE(review): this previously called auth.ValidateRecoveryCode. Confirm
// against the auth package that it was not hash-aware.
func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	if !user.TOTPEnabled {
		return nil
	}
	if s.totpManager.ValidateCode(user.TOTPSecret, code) {
		return nil
	}

	var hashedCodes []string
	if user.TOTPRecoveryCodes != "" {
		// If the stored value is not valid JSON, Unmarshal leaves hashedCodes
		// empty and the match below fails closed.
		_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes)
	}
	idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
	if !matched {
		return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
	}

	// Consume the matched recovery code so it cannot be replayed.
	remaining := append(hashedCodes[:idx], hashedCodes[idx+1:]...)
	codesJSON, err := json.Marshal(remaining)
	if err != nil {
		return fmt.Errorf("\u4fdd\u5b58\u6062\u590d\u7801\u5931\u8d25: %w", err)
	}
	user.TOTPRecoveryCodes = string(codesJSON)
	if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
		return fmt.Errorf("\u4fdd\u5b58\u6062\u590d\u7801\u5931\u8d25: %w", err)
	}
	return nil
}
// GetTOTPStatus reports whether 2FA is currently enabled for the user
// identified by userID. It returns false and an error if the user cannot be
// loaded.
func (s *TOTPService) GetTOTPStatus(ctx context.Context, userID int64) (bool, error) {
	if u, lookupErr := s.userRepo.GetByID(ctx, userID); lookupErr == nil {
		return u.TOTPEnabled, nil
	}
	return false, errors.New("\u7528\u6237\u4e0d\u5b58\u5728")
}