fix confirmed deployment risks
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled

This commit is contained in:
2026-04-25 09:24:17 +08:00
parent 75d03e4713
commit 649eb23091
10 changed files with 258 additions and 19 deletions

View File

@@ -1,12 +1,23 @@
package admin
import (
"sync"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
restorePasswordMaxFailures = 5
restorePasswordBlockWindow = 15 * time.Minute
)
var backupRestoreAttemptLimiter = newRestoreAttemptLimiter(restorePasswordMaxFailures, restorePasswordBlockWindow)
type BackupHandler struct {
backupService *service.BackupService
userService *service.UserService
@@ -165,6 +176,64 @@ type RestoreBackupRequest struct {
Password string `json:"password" binding:"required"`
}
type restoreAttemptState struct {
failuresUntil time.Time
failureCount int
}
type restoreAttemptLimiter struct {
mu sync.Mutex
maxFailures int
window time.Duration
states map[int64]restoreAttemptState
}
func newRestoreAttemptLimiter(maxFailures int, window time.Duration) *restoreAttemptLimiter {
return &restoreAttemptLimiter{
maxFailures: maxFailures,
window: window,
states: make(map[int64]restoreAttemptState),
}
}
func (l *restoreAttemptLimiter) allow(userID int64, now time.Time) (bool, time.Duration) {
l.mu.Lock()
defer l.mu.Unlock()
state, ok := l.states[userID]
if !ok {
return true, 0
}
if now.After(state.failuresUntil) {
delete(l.states, userID)
return true, 0
}
if state.failureCount < l.maxFailures {
return true, 0
}
return false, state.failuresUntil.Sub(now)
}
func (l *restoreAttemptLimiter) recordFailure(userID int64, now time.Time) {
l.mu.Lock()
defer l.mu.Unlock()
state := l.states[userID]
if now.After(state.failuresUntil) {
state = restoreAttemptState{}
}
state.failureCount++
state.failuresUntil = now.Add(l.window)
l.states[userID] = state
}
func (l *restoreAttemptLimiter) recordSuccess(userID int64) {
l.mu.Lock()
defer l.mu.Unlock()
delete(l.states, userID)
}
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
@@ -185,6 +254,11 @@ func (h *BackupHandler) RestoreBackup(c *gin.Context) {
return
}
if allowed, _ := backupRestoreAttemptLimiter.allow(sub.UserID, time.Now()); !allowed {
response.ErrorFrom(c, infraerrors.TooManyRequests("RESTORE_PASSWORD_RATE_LIMITED", "too many failed password attempts, please try again later"))
return
}
// 获取管理员用户并验证密码
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
if err != nil {
@@ -192,9 +266,11 @@ func (h *BackupHandler) RestoreBackup(c *gin.Context) {
return
}
if !user.CheckPassword(req.Password) {
backupRestoreAttemptLimiter.recordFailure(sub.UserID, time.Now())
response.BadRequest(c, "incorrect admin password")
return
}
backupRestoreAttemptLimiter.recordSuccess(sub.UserID)
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
if err != nil {

View File

@@ -0,0 +1,56 @@
package admin
import (
"testing"
"time"
)
func TestRestoreAttemptLimiterBlocksAfterMaxFailures(t *testing.T) {
limiter := newRestoreAttemptLimiter(2, time.Minute)
now := time.Unix(1700000000, 0)
if limited, _ := limiter.allow(1, now); !limited {
t.Fatalf("first attempt should be allowed")
}
limiter.recordFailure(1, now)
if limited, _ := limiter.allow(1, now.Add(time.Second)); !limited {
t.Fatalf("second attempt should still be allowed")
}
limiter.recordFailure(1, now.Add(2*time.Second))
allowed, retryAfter := limiter.allow(1, now.Add(3*time.Second))
if allowed {
t.Fatal("limiter should block after hitting max failures")
}
if retryAfter <= 0 {
t.Fatalf("retryAfter should be positive, got %v", retryAfter)
}
}
func TestRestoreAttemptLimiterResetsAfterSuccess(t *testing.T) {
limiter := newRestoreAttemptLimiter(2, time.Minute)
now := time.Unix(1700000000, 0)
limiter.recordFailure(1, now)
limiter.recordSuccess(1)
allowed, retryAfter := limiter.allow(1, now.Add(time.Second))
if !allowed {
t.Fatalf("limiter should reset after success, retryAfter=%v", retryAfter)
}
}
func TestRestoreAttemptLimiterExpiresBlockWindow(t *testing.T) {
limiter := newRestoreAttemptLimiter(1, time.Minute)
now := time.Unix(1700000000, 0)
limiter.recordFailure(1, now)
allowed, _ := limiter.allow(1, now.Add(61*time.Second))
if !allowed {
t.Fatal("limiter should allow attempts after block window expires")
}
}

View File

@@ -170,12 +170,11 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
@@ -270,7 +269,11 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
}
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
if err := h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken); err != nil {
slog.Warn("login_2fa_delete_session_failed", "user_id", session.UserID, "error", err)
response.InternalError(c, "Failed to finalize 2FA login session")
return
}
h.respondWithTokenPair(c, user)
}

View File

@@ -80,7 +80,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo)
if err != nil {
slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
writeSuccessResponse(c, providerKey)
writeUnavailableResponse(c, providerKey)
return
}
@@ -137,12 +137,22 @@ type wxpaySuccessResponse struct {
Message string `json:"message"`
}
type wxpayFailureResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}
// WeChat Pay webhook success response constants.
const (
wxpaySuccessCode = "SUCCESS"
wxpaySuccessMessage = "成功"
)
const (
wxpayFailureCode = "FAIL"
unavailableMessage = "provider unavailable"
)
// writeSuccessResponse sends the provider-specific success response.
// WeChat Pay requires JSON {"code":"SUCCESS","message":"成功"};
// Stripe expects an empty 200; others accept plain text "success".
@@ -156,3 +166,14 @@ func writeSuccessResponse(c *gin.Context, providerKey string) {
c.String(http.StatusOK, "success")
}
}
func writeUnavailableResponse(c *gin.Context, providerKey string) {
switch providerKey {
case payment.TypeWxpay:
c.JSON(http.StatusServiceUnavailable, wxpayFailureResponse{Code: wxpayFailureCode, Message: unavailableMessage})
case payment.TypeStripe:
c.String(http.StatusServiceUnavailable, "")
default:
c.String(http.StatusServiceUnavailable, unavailableMessage)
}
}

View File

@@ -97,3 +97,64 @@ func TestWebhookConstants(t *testing.T) {
assert.Equal(t, 200, webhookLogTruncateLen)
})
}
func TestWriteUnavailableResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
providerKey string
wantCode int
wantContentType string
wantBody string
checkJSON bool
wantJSONCode string
wantJSONMessage string
}{
{
name: "wxpay returns JSON failure to trigger retry",
providerKey: "wxpay",
wantCode: http.StatusServiceUnavailable,
wantContentType: "application/json",
checkJSON: true,
wantJSONCode: "FAIL",
wantJSONMessage: "provider unavailable",
},
{
name: "stripe returns 503 with empty body",
providerKey: "stripe",
wantCode: http.StatusServiceUnavailable,
wantContentType: "text/plain",
wantBody: "",
},
{
name: "easypay returns 503 plain text",
providerKey: "easypay",
wantCode: http.StatusServiceUnavailable,
wantContentType: "text/plain",
wantBody: "provider unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
writeUnavailableResponse(c, tt.providerKey)
assert.Equal(t, tt.wantCode, w.Code)
assert.Contains(t, w.Header().Get("Content-Type"), tt.wantContentType)
if tt.checkJSON {
var resp map[string]string
err := json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, err)
assert.Equal(t, tt.wantJSONCode, resp["code"])
assert.Equal(t, tt.wantJSONMessage, resp["message"])
} else {
assert.Equal(t, tt.wantBody, w.Body.String())
}
})
}
}

View File

@@ -400,35 +400,29 @@ func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
}
// Login 用户登录返回JWT token
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *User, error) {
func (s *AuthService) Login(ctx context.Context, email, password string) (*User, error) {
// 查找用户
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return "", nil, ErrInvalidCredentials
return nil, ErrInvalidCredentials
}
// 记录数据库错误但不暴露给用户
logger.LegacyPrintf("service.auth", "[Auth] Database error during login: %v", err)
return "", nil, ErrServiceUnavailable
return nil, ErrServiceUnavailable
}
// 验证密码
if !s.CheckPassword(password, user.PasswordHash) {
return "", nil, ErrInvalidCredentials
return nil, ErrInvalidCredentials
}
// 检查用户状态
if !user.IsActive() {
return "", nil, ErrUserNotActive
return nil, ErrUserNotActive
}
// 生成JWT token
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
return user, nil
}
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:

View File

@@ -43,7 +43,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 500 * 1024 * 1024
defaultMaxLineSize = 1 * 1024 * 1024
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.

View File

@@ -0,0 +1,22 @@
import { describe, expect, it } from 'vitest'
import { sanitizeHtml } from '../sanitize'
describe('sanitizeHtml', () => {
it('removes unsafe script tags', () => {
const raw = '<div>safe</div><script>alert(1)</script>'
const sanitized = sanitizeHtml(raw)
expect(sanitized).toContain('<div>safe</div>')
expect(sanitized).not.toContain('<script>')
expect(sanitized).not.toContain('alert(1)')
})
it('removes inline event handlers', () => {
const raw = '<img src="x" onerror="alert(1)"><p onclick="evil()">ok</p>'
const sanitized = sanitizeHtml(raw)
expect(sanitized).not.toContain('onerror')
expect(sanitized).not.toContain('onclick')
})
})

View File

@@ -4,3 +4,8 @@ export function sanitizeSvg(svg: string): string {
if (!svg) return ''
return DOMPurify.sanitize(svg, { USE_PROFILES: { svg: true, svgFilters: true } })
}
export function sanitizeHtml(html: string): string {
if (!html) return ''
return DOMPurify.sanitize(html)
}

View File

@@ -8,8 +8,7 @@
class="h-screen w-full border-0"
allowfullscreen
></iframe>
<!-- HTML mode - SECURITY: homeContent is admin-only setting, XSS risk is acceptable -->
<div v-else v-html="homeContent"></div>
<div v-else v-html="sanitizedHomeContent"></div>
</div>
<!-- Default Home Page -->
@@ -410,6 +409,7 @@ import { useI18n } from 'vue-i18n'
import { useAuthStore, useAppStore } from '@/stores'
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue'
import Icon from '@/components/icons/Icon.vue'
import { sanitizeHtml } from '@/utils/sanitize'
const { t } = useI18n()
@@ -422,6 +422,7 @@ const siteLogo = computed(() => appStore.cachedPublicSettings?.site_logo || appS
const siteSubtitle = computed(() => appStore.cachedPublicSettings?.site_subtitle || 'AI API Gateway Platform')
const docUrl = computed(() => appStore.cachedPublicSettings?.doc_url || appStore.docUrl || '')
const homeContent = computed(() => appStore.cachedPublicSettings?.home_content || '')
const sanitizedHomeContent = computed(() => sanitizeHtml(homeContent.value))
// Check if homeContent is a URL (for iframe display)
const isHomeContentUrl = computed(() => {