fix: tighten password and surface persistence errors

This commit is contained in:
Your Name
2026-05-28 20:38:34 +08:00
parent caad1aba0c
commit 9cc5892565
7 changed files with 228 additions and 11 deletions

View File

@@ -306,6 +306,10 @@ func validatePasswordStrength(password string, minLength int, strict bool) error
return nil
}
if info.Length <= minLength && info.Score < 3 {
return errors.New("密码强度不足,短密码需至少包含三种字符类型")
}
if info.Score < 2 {
return errors.New("密码强度不足")
}

View File

@@ -0,0 +1,12 @@
package service
import "testing"
func TestValidatePasswordStrengthBoundaryRules(t *testing.T) {
t.Run("accepts boundary password with three character classes", func(t *testing.T) {
err := validatePasswordStrength("Abcd1234", defaultPasswordMinLen, false)
if err != nil {
t.Fatalf("expected 8-char password with three classes to pass: %v", err)
}
})
}

View File

@@ -138,8 +138,15 @@ func TestValidatePasswordStrength(t *testing.T) {
wantErr: true,
},
{
name: "valid_weak_password_non_strict",
password: "Abcd1234",
name: "boundary_password_requires_three_character_classes",
password: "abcd1234",
minLength: 8,
strict: false,
wantErr: true,
},
{
name: "longer_password_allows_two_character_classes",
password: "abcdefgh1234",
minLength: 8,
strict: false,
wantErr: false,

View File

@@ -47,9 +47,16 @@ func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPRe
// Hash recovery codes before storing (SEC-03 fix)
hashedCodes := make([]string, len(setup.RecoveryCodes))
for i, code := range setup.RecoveryCodes {
hashedCodes[i], _ = auth.HashRecoveryCode(code)
hashedCode, err := auth.HashRecoveryCode(code)
if err != nil {
return nil, fmt.Errorf("生成恢复码摘要失败: %w", err)
}
hashedCodes[i] = hashedCode
}
codesJSON, err := json.Marshal(hashedCodes)
if err != nil {
return nil, fmt.Errorf("序列化恢复码失败: %w", err)
}
codesJSON, _ := json.Marshal(hashedCodes)
user.TOTPRecoveryCodes = string(codesJSON)
if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
@@ -96,11 +103,13 @@ func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string
if !valid {
var hashedCodes []string
if user.TOTPRecoveryCodes != "" {
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes)
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
return fmt.Errorf("解析恢复码失败: %w", err)
}
}
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
if !matched {
return errors.New("\u9a8c\u8bc1\u7801\u6216\u6062\u590d\u7801\u9519\u8bef")
return errors.New("验证码或恢复码错误")
}
}
@@ -125,17 +134,24 @@ func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string)
var storedCodes []string
if user.TOTPRecoveryCodes != "" {
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes)
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes); err != nil {
return fmt.Errorf("解析恢复码失败: %w", err)
}
}
idx, matched := auth.ValidateRecoveryCode(code, storedCodes)
if !matched {
return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
return errors.New("验证码错误或已过期")
}
storedCodes = append(storedCodes[:idx], storedCodes[idx+1:]...)
codesJSON, _ := json.Marshal(storedCodes)
codesJSON, err := json.Marshal(storedCodes)
if err != nil {
return fmt.Errorf("序列化恢复码失败: %w", err)
}
user.TOTPRecoveryCodes = string(codesJSON)
_ = s.userRepo.UpdateTOTP(ctx, user)
if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
return fmt.Errorf("更新恢复码失败: %w", err)
}
return nil
}

View File

@@ -0,0 +1,112 @@
package service
import (
"context"
"errors"
"strings"
"testing"
"github.com/user-management-system/internal/domain"
)
type totpTestRepo struct {
user *domain.User
getErr error
updateTOTPErr error
}
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil }
func (r *totpTestRepo) Update(ctx context.Context, user *domain.User) error { return nil }
func (r *totpTestRepo) UpdateTOTP(ctx context.Context, user *domain.User) error {
if r.updateTOTPErr != nil {
return r.updateTOTPErr
}
copyUser := *user
r.user = &copyUser
return nil
}
func (r *totpTestRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *totpTestRepo) GetByID(ctx context.Context, id int64) (*domain.User, error) {
if r.getErr != nil {
return nil, r.getErr
}
if r.user == nil || r.user.ID != id {
return nil, errors.New("not found")
}
copyUser := *r.user
return &copyUser, nil
}
func (r *totpTestRepo) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
return nil, errors.New("not implemented")
}
func (r *totpTestRepo) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
return nil, errors.New("not implemented")
}
func (r *totpTestRepo) GetByPhone(ctx context.Context, phone string) (*domain.User, error) {
return nil, errors.New("not implemented")
}
func (r *totpTestRepo) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
return nil, 0, errors.New("not implemented")
}
func (r *totpTestRepo) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
return nil, 0, errors.New("not implemented")
}
func (r *totpTestRepo) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
return errors.New("not implemented")
}
func (r *totpTestRepo) UpdateLastLogin(ctx context.Context, id int64, ip string) error {
return errors.New("not implemented")
}
func (r *totpTestRepo) ExistsByUsername(ctx context.Context, username string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *totpTestRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
return nil, 0, errors.New("not implemented")
}
func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) {
repo := &totpTestRepo{user: &domain.User{
ID: 42,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "invalid-secret",
TOTPRecoveryCodes: "not-json",
}}
svc := NewTOTPService(repo)
err := svc.VerifyTOTP(context.Background(), 42, "recovery-code")
if err == nil {
t.Fatal("expected corrupted recovery-code payload to fail")
}
if !strings.Contains(err.Error(), "解析恢复码失败") {
t.Fatalf("expected decode error, got: %v", err)
}
}
func TestTOTPService_ReturnsUpdateErrorAfterRecoveryCodeConsumption(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 7,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "invalid-secret",
TOTPRecoveryCodes: `["RECOVERY-1"]`,
},
updateTOTPErr: errors.New("write failed"),
}
svc := NewTOTPService(repo)
err := svc.VerifyTOTP(context.Background(), 7, "RECOVERY-1")
if err == nil {
t.Fatal("expected update failure to be returned")
}
if !strings.Contains(err.Error(), "更新恢复码失败") {
t.Fatalf("expected update error, got: %v", err)
}
}