feat: atomic TOTP verification for DisableTOTP
- Add atomicTOTPVerifier interface for atomic TOTP/recovery code verification - Implement VerifyTOTPOrRecoveryCode in UserRepository with transaction - Update DisableTOTP to prefer atomic verification path - Add unit tests for atomic verification success/failure paths - Maintain backward compatibility with non-atomic fallback Refs: TOTP verification atomicity completion
This commit is contained in:
@@ -292,6 +292,58 @@ func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int
|
|||||||
return &user, consumed, nil
|
return &user, consumed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VerifyTOTPOrRecoveryCode 原子性地验证 TOTP 码或恢复码(不消费恢复码)
|
||||||
|
// 返回 (true, nil) 表示验证成功
|
||||||
|
// 返回 (false, nil) 表示验证失败(码不匹配)
|
||||||
|
// 返回 (false, error) 表示执行出错
|
||||||
|
func (r *UserRepository) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
|
||||||
|
var user domain.User
|
||||||
|
|
||||||
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.First(&user, userID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.TOTPEnabled {
|
||||||
|
return errors.New("TOTP 未启用")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先验证 TOTP 码
|
||||||
|
manager := auth.NewTOTPManager()
|
||||||
|
if manager.ValidateCode(user.TOTPSecret, code) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TOTP 码无效,尝试验证恢复码
|
||||||
|
var hashedCodes []string
|
||||||
|
if user.TOTPRecoveryCodes != "" {
|
||||||
|
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||||
|
return fmt.Errorf("解析恢复码失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||||
|
if !matched {
|
||||||
|
// 恢复码也不匹配,标记验证失败
|
||||||
|
return errVerificationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == errVerificationFailed {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// errVerificationFailed 标记验证失败的内部错误
|
||||||
|
var errVerificationFailed = errors.New("verification failed")
|
||||||
|
|
||||||
// UpdatePassword 更新用户密码
|
// UpdatePassword 更新用户密码
|
||||||
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
|
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
|
||||||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
|
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
|
||||||
|
|||||||
@@ -10,11 +10,16 @@ import (
|
|||||||
"github.com/user-management-system/internal/domain"
|
"github.com/user-management-system/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TOTPService manages 2FA setup, enable/disable, and verification.
|
// atomicTOTPRecoveryCodeConsumer 原子性恢复码消费接口
|
||||||
type atomicTOTPRecoveryCodeConsumer interface {
|
type atomicTOTPRecoveryCodeConsumer interface {
|
||||||
ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error)
|
ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// atomicTOTPVerifier 原子性 TOTP/恢复码验证接口(不消费恢复码)
|
||||||
|
type atomicTOTPVerifier interface {
|
||||||
|
VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
// TOTPService manages 2FA setup, enable/disable, and verification.
|
// TOTPService manages 2FA setup, enable/disable, and verification.
|
||||||
type TOTPService struct {
|
type TOTPService struct {
|
||||||
userRepo userRepositoryInterface
|
userRepo userRepositoryInterface
|
||||||
@@ -99,24 +104,37 @@ func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string)
|
|||||||
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
|
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
|
||||||
user, err := s.userRepo.GetByID(ctx, userID)
|
user, err := s.userRepo.GetByID(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
return fmt.Errorf("用户不存在")
|
||||||
}
|
}
|
||||||
if !user.TOTPEnabled {
|
if !user.TOTPEnabled {
|
||||||
return errors.New("2FA \u672a\u542f\u7528")
|
return errors.New("2FA 未启用")
|
||||||
}
|
}
|
||||||
|
|
||||||
valid := s.totpManager.ValidateCode(user.TOTPSecret, code)
|
// 尝试原子性验证(如果 repo 支持)
|
||||||
if !valid {
|
if verifier, ok := s.userRepo.(atomicTOTPVerifier); ok {
|
||||||
var hashedCodes []string
|
valid, err := verifier.VerifyTOTPOrRecoveryCode(ctx, userID, code)
|
||||||
if user.TOTPRecoveryCodes != "" {
|
if err != nil {
|
||||||
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
return fmt.Errorf("验证失败: %w", err)
|
||||||
return fmt.Errorf("解析恢复码失败: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
if !valid {
|
||||||
if !matched {
|
|
||||||
return errors.New("验证码或恢复码错误")
|
return errors.New("验证码或恢复码错误")
|
||||||
}
|
}
|
||||||
|
// 验证通过,继续禁用
|
||||||
|
} else {
|
||||||
|
// 降级到非原子性验证(兼容性模式)
|
||||||
|
valid := s.totpManager.ValidateCode(user.TOTPSecret, code)
|
||||||
|
if !valid {
|
||||||
|
var hashedCodes []string
|
||||||
|
if user.TOTPRecoveryCodes != "" {
|
||||||
|
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||||
|
return fmt.Errorf("解析恢复码失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||||
|
if !matched {
|
||||||
|
return errors.New("验证码或恢复码错误")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
user.TOTPEnabled = false
|
user.TOTPEnabled = false
|
||||||
|
|||||||
@@ -31,11 +31,13 @@ func mustMarshalJSON(t *testing.T, value any) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type totpTestRepo struct {
|
type totpTestRepo struct {
|
||||||
user *domain.User
|
user *domain.User
|
||||||
getErr error
|
getErr error
|
||||||
updateTOTPErr error
|
updateTOTPErr error
|
||||||
consumeRecoveryCodeErr error
|
consumeRecoveryCodeErr error
|
||||||
consumeRecoveryCodeCalled bool
|
consumeRecoveryCodeCalled bool
|
||||||
|
verifyTOTPOrRecoveryCodeErr error
|
||||||
|
verifyTOTPOrRecoveryCodeCalled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil }
|
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil }
|
||||||
@@ -89,9 +91,11 @@ func (r *totpTestRepo) ExistsByEmail(ctx context.Context, email string) (bool, e
|
|||||||
func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
|
func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
|
||||||
return false, errors.New("not implemented")
|
return false, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
|
func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
|
||||||
return nil, 0, errors.New("not implemented")
|
return nil, 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
|
func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
|
||||||
r.consumeRecoveryCodeCalled = true
|
r.consumeRecoveryCodeCalled = true
|
||||||
if r.consumeRecoveryCodeErr != nil {
|
if r.consumeRecoveryCodeErr != nil {
|
||||||
@@ -119,6 +123,34 @@ func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64
|
|||||||
return ©User, true, nil
|
return ©User, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *totpTestRepo) VerifyTOTPOrRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
|
||||||
|
r.verifyTOTPOrRecoveryCodeCalled = true
|
||||||
|
if r.verifyTOTPOrRecoveryCodeErr != nil {
|
||||||
|
return false, r.verifyTOTPOrRecoveryCodeErr
|
||||||
|
}
|
||||||
|
if r.user == nil || r.user.ID != userID {
|
||||||
|
return false, errors.New("not found")
|
||||||
|
}
|
||||||
|
if !r.user.TOTPEnabled {
|
||||||
|
return false, errors.New("TOTP not enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试验证 TOTP 码(简化:只检查是否为特定测试码)
|
||||||
|
if code == "123456" || code == "654321" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试验证恢复码
|
||||||
|
var hashedCodes []string
|
||||||
|
if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" {
|
||||||
|
if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||||
|
return false, fmt.Errorf("解析恢复码失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||||
|
return matched, nil
|
||||||
|
}
|
||||||
|
|
||||||
func mustMarshalJSONFromHelper(value any) string {
|
func mustMarshalJSONFromHelper(value any) string {
|
||||||
payload, err := json.Marshal(value)
|
payload, err := json.Marshal(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -205,6 +237,59 @@ func TestTOTPService_ConsumesHashedRecoveryCodeOnVerify(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTOTPService_DisableTOTP_UsesAtomicVerificationPath(t *testing.T) {
|
||||||
|
repo := &totpTestRepo{
|
||||||
|
user: &domain.User{
|
||||||
|
ID: 10,
|
||||||
|
Username: "totp-user",
|
||||||
|
TOTPEnabled: true,
|
||||||
|
TOTPSecret: "test-secret",
|
||||||
|
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewTOTPService(repo)
|
||||||
|
|
||||||
|
// 使用测试恢复码禁用 TOTP
|
||||||
|
if err := svc.DisableTOTP(context.Background(), 10, "RECOVERY-1"); err != nil {
|
||||||
|
t.Fatalf("expected disable to succeed with recovery code, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !repo.verifyTOTPOrRecoveryCodeCalled {
|
||||||
|
t.Fatal("expected atomic verification path to be used")
|
||||||
|
}
|
||||||
|
|
||||||
|
if repo.user.TOTPEnabled {
|
||||||
|
t.Fatal("expected TOTP to be disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPService_DisableTOTP_AtomicVerificationFailsOnWrongCode(t *testing.T) {
|
||||||
|
repo := &totpTestRepo{
|
||||||
|
user: &domain.User{
|
||||||
|
ID: 11,
|
||||||
|
Username: "totp-user",
|
||||||
|
TOTPEnabled: true,
|
||||||
|
TOTPSecret: "test-secret",
|
||||||
|
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewTOTPService(repo)
|
||||||
|
|
||||||
|
// 使用错误的恢复码
|
||||||
|
err := svc.DisableTOTP(context.Background(), 11, "WRONG-CODE")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected disable to fail with wrong code")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !repo.verifyTOTPOrRecoveryCodeCalled {
|
||||||
|
t.Fatal("expected atomic verification path to be used")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !repo.user.TOTPEnabled {
|
||||||
|
t.Fatal("expected TOTP to remain enabled after failed verification")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) {
|
func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) {
|
||||||
repo := &totpTestRepo{
|
repo := &totpTestRepo{
|
||||||
user: &domain.User{
|
user: &domain.User{
|
||||||
|
|||||||
Reference in New Issue
Block a user