fix: tighten password and surface persistence errors
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -87,10 +88,16 @@ func (m *OperationLogMiddleware) Record() gin.HandlerFunc {
|
|||||||
UserAgent: c.Request.UserAgent(),
|
UserAgent: c.Request.UserAgent(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m == nil || m.repo == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
go func(entry *domain.OperationLog) {
|
go func(entry *domain.OperationLog) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_ = m.repo.Create(ctx, entry)
|
if err := m.repo.Create(ctx, entry); err != nil {
|
||||||
|
log.Printf("[operation-log] create failed: %v", err)
|
||||||
|
}
|
||||||
}(logEntry)
|
}(logEntry)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
59
internal/api/middleware/operation_log_test.go
Normal file
59
internal/api/middleware/operation_log_test.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOperationLogRecord_AllowsNilRepository(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.Use((&OperationLogMiddleware{}).Record())
|
||||||
|
router.POST("/operation-log", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusCreated, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"password":"secret","token":"abc"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/operation-log", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("unexpected status: got %d want %d", recorder.Code, http.StatusCreated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeParams_MasksSensitiveFields(t *testing.T) {
|
||||||
|
sanitized := sanitizeParams([]byte(`{"password":"secret","nested":"ok","token":"abc"}`))
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(sanitized), &payload); err != nil {
|
||||||
|
t.Fatalf("sanitized payload should remain valid json: %v", err)
|
||||||
|
}
|
||||||
|
if payload["password"] != "***" {
|
||||||
|
t.Fatalf("password should be masked, got: %#v", payload["password"])
|
||||||
|
}
|
||||||
|
if payload["token"] != "***" {
|
||||||
|
t.Fatalf("token should be masked, got: %#v", payload["token"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeParams_FallbacksForNonJSONPayload(t *testing.T) {
|
||||||
|
longText := strings.Repeat("x", 600)
|
||||||
|
sanitized := sanitizeParams([]byte(longText))
|
||||||
|
if len(sanitized) != 503 {
|
||||||
|
t.Fatalf("expected truncated fallback length 503, got %d", len(sanitized))
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(sanitized, "...") {
|
||||||
|
t.Fatalf("expected truncated fallback to end with ellipsis: %q", sanitized[len(sanitized)-3:])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -306,6 +306,10 @@ func validatePasswordStrength(password string, minLength int, strict bool) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if info.Length <= minLength && info.Score < 3 {
|
||||||
|
return errors.New("密码强度不足,短密码需至少包含三种字符类型")
|
||||||
|
}
|
||||||
|
|
||||||
if info.Score < 2 {
|
if info.Score < 2 {
|
||||||
return errors.New("密码强度不足")
|
return errors.New("密码强度不足")
|
||||||
}
|
}
|
||||||
|
|||||||
12
internal/service/auth_password_internal_test.go
Normal file
12
internal/service/auth_password_internal_test.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestValidatePasswordStrengthBoundaryRules(t *testing.T) {
|
||||||
|
t.Run("accepts boundary password with three character classes", func(t *testing.T) {
|
||||||
|
err := validatePasswordStrength("Abcd1234", defaultPasswordMinLen, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected 8-char password with three classes to pass: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -138,8 +138,15 @@ func TestValidatePasswordStrength(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid_weak_password_non_strict",
|
name: "boundary_password_requires_three_character_classes",
|
||||||
password: "Abcd1234",
|
password: "abcd1234",
|
||||||
|
minLength: 8,
|
||||||
|
strict: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "longer_password_allows_two_character_classes",
|
||||||
|
password: "abcdefgh1234",
|
||||||
minLength: 8,
|
minLength: 8,
|
||||||
strict: false,
|
strict: false,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
|||||||
@@ -47,9 +47,16 @@ func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPRe
|
|||||||
// Hash recovery codes before storing (SEC-03 fix)
|
// Hash recovery codes before storing (SEC-03 fix)
|
||||||
hashedCodes := make([]string, len(setup.RecoveryCodes))
|
hashedCodes := make([]string, len(setup.RecoveryCodes))
|
||||||
for i, code := range setup.RecoveryCodes {
|
for i, code := range setup.RecoveryCodes {
|
||||||
hashedCodes[i], _ = auth.HashRecoveryCode(code)
|
hashedCode, err := auth.HashRecoveryCode(code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("生成恢复码摘要失败: %w", err)
|
||||||
|
}
|
||||||
|
hashedCodes[i] = hashedCode
|
||||||
|
}
|
||||||
|
codesJSON, err := json.Marshal(hashedCodes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("序列化恢复码失败: %w", err)
|
||||||
}
|
}
|
||||||
codesJSON, _ := json.Marshal(hashedCodes)
|
|
||||||
user.TOTPRecoveryCodes = string(codesJSON)
|
user.TOTPRecoveryCodes = string(codesJSON)
|
||||||
|
|
||||||
if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
|
if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
|
||||||
@@ -96,11 +103,13 @@ func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string
|
|||||||
if !valid {
|
if !valid {
|
||||||
var hashedCodes []string
|
var hashedCodes []string
|
||||||
if user.TOTPRecoveryCodes != "" {
|
if user.TOTPRecoveryCodes != "" {
|
||||||
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes)
|
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||||
|
return fmt.Errorf("解析恢复码失败: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||||
if !matched {
|
if !matched {
|
||||||
return errors.New("\u9a8c\u8bc1\u7801\u6216\u6062\u590d\u7801\u9519\u8bef")
|
return errors.New("验证码或恢复码错误")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,17 +134,24 @@ func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string)
|
|||||||
|
|
||||||
var storedCodes []string
|
var storedCodes []string
|
||||||
if user.TOTPRecoveryCodes != "" {
|
if user.TOTPRecoveryCodes != "" {
|
||||||
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes)
|
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes); err != nil {
|
||||||
|
return fmt.Errorf("解析恢复码失败: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
idx, matched := auth.ValidateRecoveryCode(code, storedCodes)
|
idx, matched := auth.ValidateRecoveryCode(code, storedCodes)
|
||||||
if !matched {
|
if !matched {
|
||||||
return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
|
return errors.New("验证码错误或已过期")
|
||||||
}
|
}
|
||||||
|
|
||||||
storedCodes = append(storedCodes[:idx], storedCodes[idx+1:]...)
|
storedCodes = append(storedCodes[:idx], storedCodes[idx+1:]...)
|
||||||
codesJSON, _ := json.Marshal(storedCodes)
|
codesJSON, err := json.Marshal(storedCodes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("序列化恢复码失败: %w", err)
|
||||||
|
}
|
||||||
user.TOTPRecoveryCodes = string(codesJSON)
|
user.TOTPRecoveryCodes = string(codesJSON)
|
||||||
_ = s.userRepo.UpdateTOTP(ctx, user)
|
if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
|
||||||
|
return fmt.Errorf("更新恢复码失败: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
112
internal/service/totp_internal_test.go
Normal file
112
internal/service/totp_internal_test.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/user-management-system/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type totpTestRepo struct {
|
||||||
|
user *domain.User
|
||||||
|
getErr error
|
||||||
|
updateTOTPErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil }
|
||||||
|
func (r *totpTestRepo) Update(ctx context.Context, user *domain.User) error { return nil }
|
||||||
|
func (r *totpTestRepo) UpdateTOTP(ctx context.Context, user *domain.User) error {
|
||||||
|
if r.updateTOTPErr != nil {
|
||||||
|
return r.updateTOTPErr
|
||||||
|
}
|
||||||
|
copyUser := *user
|
||||||
|
r.user = ©User
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (r *totpTestRepo) GetByID(ctx context.Context, id int64) (*domain.User, error) {
|
||||||
|
if r.getErr != nil {
|
||||||
|
return nil, r.getErr
|
||||||
|
}
|
||||||
|
if r.user == nil || r.user.ID != id {
|
||||||
|
return nil, errors.New("not found")
|
||||||
|
}
|
||||||
|
copyUser := *r.user
|
||||||
|
return ©User, nil
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) GetByPhone(ctx context.Context, phone string) (*domain.User, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
|
||||||
|
return nil, 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
|
||||||
|
return nil, 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) UpdateLastLogin(ctx context.Context, id int64, ip string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) ExistsByUsername(ctx context.Context, username string) (bool, error) {
|
||||||
|
return false, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||||
|
return false, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
|
||||||
|
return false, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
|
||||||
|
return nil, 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) {
|
||||||
|
repo := &totpTestRepo{user: &domain.User{
|
||||||
|
ID: 42,
|
||||||
|
Username: "totp-user",
|
||||||
|
TOTPEnabled: true,
|
||||||
|
TOTPSecret: "invalid-secret",
|
||||||
|
TOTPRecoveryCodes: "not-json",
|
||||||
|
}}
|
||||||
|
svc := NewTOTPService(repo)
|
||||||
|
|
||||||
|
err := svc.VerifyTOTP(context.Background(), 42, "recovery-code")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected corrupted recovery-code payload to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "解析恢复码失败") {
|
||||||
|
t.Fatalf("expected decode error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPService_ReturnsUpdateErrorAfterRecoveryCodeConsumption(t *testing.T) {
|
||||||
|
repo := &totpTestRepo{
|
||||||
|
user: &domain.User{
|
||||||
|
ID: 7,
|
||||||
|
Username: "totp-user",
|
||||||
|
TOTPEnabled: true,
|
||||||
|
TOTPSecret: "invalid-secret",
|
||||||
|
TOTPRecoveryCodes: `["RECOVERY-1"]`,
|
||||||
|
},
|
||||||
|
updateTOTPErr: errors.New("write failed"),
|
||||||
|
}
|
||||||
|
svc := NewTOTPService(repo)
|
||||||
|
|
||||||
|
err := svc.VerifyTOTP(context.Background(), 7, "RECOVERY-1")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected update failure to be returned")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "更新恢复码失败") {
|
||||||
|
t.Fatalf("expected update error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user