feat(risk-control): add content moderation audit
This commit is contained in:
234
backend/internal/handler/admin/content_moderation_handler.go
Normal file
234
backend/internal/handler/admin/content_moderation_handler.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ContentModerationHandler struct {
|
||||
service *service.ContentModerationService
|
||||
}
|
||||
|
||||
func NewContentModerationHandler(svc *service.ContentModerationService) *ContentModerationHandler {
|
||||
return &ContentModerationHandler{service: svc}
|
||||
}
|
||||
|
||||
type contentModerationConfigRequest struct {
|
||||
Enabled *bool `json:"enabled"`
|
||||
Mode *string `json:"mode"`
|
||||
BaseURL *string `json:"base_url"`
|
||||
Model *string `json:"model"`
|
||||
APIKey *string `json:"api_key"`
|
||||
APIKeys *[]string `json:"api_keys"`
|
||||
ClearAPIKey bool `json:"clear_api_key"`
|
||||
TimeoutMS *int `json:"timeout_ms"`
|
||||
SampleRate *int `json:"sample_rate"`
|
||||
AllGroups *bool `json:"all_groups"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
RecordNonHits *bool `json:"record_non_hits"`
|
||||
WorkerCount *int `json:"worker_count"`
|
||||
QueueSize *int `json:"queue_size"`
|
||||
BlockStatus *int `json:"block_status"`
|
||||
BlockMessage *string `json:"block_message"`
|
||||
EmailOnHit *bool `json:"email_on_hit"`
|
||||
AutoBanEnabled *bool `json:"auto_ban_enabled"`
|
||||
BanThreshold *int `json:"ban_threshold"`
|
||||
ViolationWindowHours *int `json:"violation_window_hours"`
|
||||
RetryCount *int `json:"retry_count"`
|
||||
HitRetentionDays *int `json:"hit_retention_days"`
|
||||
NonHitRetentionDays *int `json:"non_hit_retention_days"`
|
||||
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
|
||||
}
|
||||
|
||||
type contentModerationAPIKeyTestRequest struct {
|
||||
APIKeys []string `json:"api_keys"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Model string `json:"model"`
|
||||
TimeoutMS int `json:"timeout_ms"`
|
||||
Prompt string `json:"prompt"`
|
||||
Images []string `json:"images"`
|
||||
}
|
||||
|
||||
type contentModerationHashRequest struct {
|
||||
InputHash string `json:"input_hash"`
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) GetConfig(c *gin.Context) {
|
||||
cfg, err := h.service.GetConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
|
||||
var req contentModerationConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.service.UpdateConfig(c.Request.Context(), service.UpdateContentModerationConfigInput{
|
||||
Enabled: req.Enabled,
|
||||
Mode: req.Mode,
|
||||
BaseURL: req.BaseURL,
|
||||
Model: req.Model,
|
||||
APIKey: req.APIKey,
|
||||
APIKeys: req.APIKeys,
|
||||
ClearAPIKey: req.ClearAPIKey,
|
||||
TimeoutMS: req.TimeoutMS,
|
||||
SampleRate: req.SampleRate,
|
||||
AllGroups: req.AllGroups,
|
||||
GroupIDs: req.GroupIDs,
|
||||
RecordNonHits: req.RecordNonHits,
|
||||
WorkerCount: req.WorkerCount,
|
||||
QueueSize: req.QueueSize,
|
||||
BlockStatus: req.BlockStatus,
|
||||
BlockMessage: req.BlockMessage,
|
||||
EmailOnHit: req.EmailOnHit,
|
||||
AutoBanEnabled: req.AutoBanEnabled,
|
||||
BanThreshold: req.BanThreshold,
|
||||
ViolationWindowHours: req.ViolationWindowHours,
|
||||
RetryCount: req.RetryCount,
|
||||
HitRetentionDays: req.HitRetentionDays,
|
||||
NonHitRetentionDays: req.NonHitRetentionDays,
|
||||
PreHashCheckEnabled: req.PreHashCheckEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) TestAPIKeys(c *gin.Context) {
|
||||
var req contentModerationAPIKeyTestRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
result, err := h.service.TestAPIKeys(c.Request.Context(), service.TestContentModerationAPIKeysInput{
|
||||
APIKeys: req.APIKeys,
|
||||
BaseURL: req.BaseURL,
|
||||
Model: req.Model,
|
||||
TimeoutMS: req.TimeoutMS,
|
||||
Prompt: req.Prompt,
|
||||
Images: req.Images,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) GetStatus(c *gin.Context) {
|
||||
status, err := h.service.GetStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, status)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) ListLogs(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
filter := service.ContentModerationLogFilter{
|
||||
Pagination: pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
SortOrder: pagination.SortOrderDesc,
|
||||
},
|
||||
Result: c.Query("result"),
|
||||
Endpoint: c.Query("endpoint"),
|
||||
Search: c.Query("search"),
|
||||
}
|
||||
if raw := strings.TrimSpace(c.Query("group_id")); raw != "" {
|
||||
groupID, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil || groupID <= 0 {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
filter.GroupID = &groupID
|
||||
}
|
||||
if raw := strings.TrimSpace(c.Query("from")); raw != "" {
|
||||
t, _, err := parseContentModerationDate(raw)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid from")
|
||||
return
|
||||
}
|
||||
filter.From = &t
|
||||
}
|
||||
if raw := strings.TrimSpace(c.Query("to")); raw != "" {
|
||||
t, dateOnly, err := parseContentModerationDate(raw)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid to")
|
||||
return
|
||||
}
|
||||
if dateOnly {
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
}
|
||||
filter.To = &t
|
||||
}
|
||||
items, pageResult, err := h.service.ListLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, items, pageResult.Total, pageResult.Page, pageResult.PageSize)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) UnbanUser(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(strings.TrimSpace(c.Param("user_id")), 10, 64)
|
||||
if err != nil || userID <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
result, err := h.service.UnbanUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) DeleteFlaggedHash(c *gin.Context) {
|
||||
var req contentModerationHashRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
result, err := h.service.DeleteFlaggedInputHash(c.Request.Context(), req.InputHash)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *ContentModerationHandler) ClearFlaggedHashes(c *gin.Context) {
|
||||
result, err := h.service.ClearFlaggedInputHashes(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func parseContentModerationDate(raw string) (time.Time, bool, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return t, false, nil
|
||||
}
|
||||
t, err := time.Parse("2006-01-02", raw)
|
||||
return t, err == nil, err
|
||||
}
|
||||
@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
RiskControlEnabled: settings.RiskControlEnabled,
|
||||
AffiliateRebateRate: settings.AffiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
|
||||
@@ -497,6 +498,9 @@ type UpdateSettingsRequest struct {
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled *bool `json:"affiliate_enabled"`
|
||||
|
||||
// 风控中心功能开关
|
||||
RiskControlEnabled *bool `json:"risk_control_enabled"`
|
||||
|
||||
// OpenAI fast/flex policy (optional, only updated when provided)
|
||||
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
}
|
||||
@@ -1365,6 +1369,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.AffiliateEnabled
|
||||
}(),
|
||||
RiskControlEnabled: func() bool {
|
||||
if req.RiskControlEnabled != nil {
|
||||
return *req.RiskControlEnabled
|
||||
}
|
||||
return previousSettings.RiskControlEnabled
|
||||
}(),
|
||||
}
|
||||
|
||||
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
||||
@@ -1616,6 +1626,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: updatedSettings.AffiliateEnabled,
|
||||
|
||||
RiskControlEnabled: updatedSettings.RiskControlEnabled,
|
||||
}
|
||||
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
||||
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
||||
@@ -2004,6 +2016,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AffiliateEnabled != after.AffiliateEnabled {
|
||||
changed = append(changed, "affiliate_enabled")
|
||||
}
|
||||
if before.RiskControlEnabled != after.RiskControlEnabled {
|
||||
changed = append(changed, "risk_control_enabled")
|
||||
}
|
||||
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
|
||||
return changed
|
||||
}
|
||||
|
||||
130
backend/internal/handler/content_moderation_helper.go
Normal file
130
backend/internal/handler/content_moderation_helper.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (h *GatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
||||
if h == nil || h.contentModerationService == nil {
|
||||
return nil
|
||||
}
|
||||
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
|
||||
}
|
||||
|
||||
func contentModerationStatus(decision *service.ContentModerationDecision) int {
|
||||
if decision == nil || decision.StatusCode < 400 || decision.StatusCode > 599 {
|
||||
return http.StatusForbidden
|
||||
}
|
||||
return decision.StatusCode
|
||||
}
|
||||
|
||||
func contentModerationErrorCode(decision *service.ContentModerationDecision) string {
|
||||
return "content_policy_violation"
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
||||
if h == nil || h.contentModerationService == nil {
|
||||
return nil
|
||||
}
|
||||
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
|
||||
}
|
||||
|
||||
func runContentModeration(c *gin.Context, reqLog *zap.Logger, svc *service.ContentModerationService, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
||||
if svc == nil || c == nil || c.Request == nil {
|
||||
return nil
|
||||
}
|
||||
input := buildContentModerationInput(c, apiKey, subject, protocol, model, body)
|
||||
if reqLog != nil {
|
||||
reqLog.Info("content_moderation.gateway_check_start",
|
||||
zap.String("request_id", input.RequestID),
|
||||
zap.Int64("user_id", input.UserID),
|
||||
zap.Int64("api_key_id", input.APIKeyID),
|
||||
zap.String("api_key_name", input.APIKeyName),
|
||||
zap.Int64p("group_id", input.GroupID),
|
||||
zap.String("group_name", input.GroupName),
|
||||
zap.String("endpoint", input.Endpoint),
|
||||
zap.String("provider", input.Provider),
|
||||
zap.String("protocol", input.Protocol),
|
||||
zap.String("model", input.Model),
|
||||
zap.Int("body_bytes", len(body)),
|
||||
)
|
||||
}
|
||||
decision, err := svc.Check(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
if reqLog != nil {
|
||||
reqLog.Warn("content_moderation.check_failed", zap.Error(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if reqLog != nil && decision != nil {
|
||||
reqLog.Info("content_moderation.gateway_check_done",
|
||||
zap.String("request_id", input.RequestID),
|
||||
zap.Bool("allowed", decision.Allowed),
|
||||
zap.Bool("blocked", decision.Blocked),
|
||||
zap.Bool("flagged", decision.Flagged),
|
||||
zap.String("action", decision.Action),
|
||||
zap.Int("status_code", decision.StatusCode),
|
||||
zap.String("highest_category", decision.HighestCategory),
|
||||
zap.Float64("highest_score", decision.HighestScore),
|
||||
)
|
||||
}
|
||||
return decision
|
||||
}
|
||||
|
||||
func buildContentModerationInput(c *gin.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) service.ContentModerationCheckInput {
|
||||
input := service.ContentModerationCheckInput{
|
||||
RequestID: contentModerationRequestID(c.Request.Context()),
|
||||
UserID: subject.UserID,
|
||||
Endpoint: GetInboundEndpoint(c),
|
||||
Provider: contentModerationProvider(apiKey),
|
||||
Model: strings.TrimSpace(model),
|
||||
Protocol: protocol,
|
||||
Body: body,
|
||||
}
|
||||
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
input.Provider = strings.TrimSpace(forcedPlatform)
|
||||
}
|
||||
if apiKey != nil {
|
||||
input.APIKeyID = apiKey.ID
|
||||
input.APIKeyName = apiKey.Name
|
||||
if apiKey.User != nil {
|
||||
input.UserEmail = apiKey.User.Email
|
||||
}
|
||||
if apiKey.GroupID != nil {
|
||||
groupID := *apiKey.GroupID
|
||||
input.GroupID = &groupID
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
input.GroupName = apiKey.Group.Name
|
||||
}
|
||||
}
|
||||
if input.Endpoint == "" && c.Request != nil && c.Request.URL != nil {
|
||||
input.Endpoint = c.Request.URL.Path
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
func contentModerationProvider(apiKey *service.APIKey) string {
|
||||
if apiKey == nil || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(apiKey.Group.Platform)
|
||||
}
|
||||
|
||||
func contentModerationRequestID(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if requestID, ok := ctx.Value(ctxkey.RequestID).(string); ok {
|
||||
return strings.TrimSpace(requestID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -197,6 +197,9 @@ type SystemSettings struct {
|
||||
// Available Channels feature switch (user-facing aggregate view)
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
// 风控中心功能开关
|
||||
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
|
||||
@@ -256,6 +259,8 @@ type PublicSettings struct {
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
|
||||
RiskControlEnabled bool `json:"risk_control_enabled"`
|
||||
}
|
||||
|
||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||
|
||||
@@ -45,6 +45,7 @@ type GatewayHandler struct {
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
contentModerationService *service.ContentModerationService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
userMsgQueueHelper *UserMsgQueueHelper
|
||||
maxAccountSwitches int
|
||||
@@ -65,6 +66,7 @@ func NewGatewayHandler(
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
contentModerationService *service.ContentModerationService,
|
||||
userMsgQueueService *service.UserMessageQueueService,
|
||||
cfg *config.Config,
|
||||
settingService *service.SettingService,
|
||||
@@ -98,6 +100,7 @@ func NewGatewayHandler(
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
contentModerationService: contentModerationService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
userMsgQueueHelper: umqHelper,
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
@@ -189,6 +192,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
|
||||
@@ -91,6 +91,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.chatCompletionsErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// Error passthrough binding
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
|
||||
@@ -96,6 +96,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.responsesErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// Error passthrough binding
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
|
||||
@@ -185,6 +185,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
setOpsRequestContext(c, modelName, stream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked {
|
||||
googleError(c, contentModerationStatus(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||
reqModel := modelName // 保存映射前的原始模型名
|
||||
|
||||
@@ -33,6 +33,7 @@ type AdminHandlers struct {
|
||||
Channel *admin.ChannelHandler
|
||||
ChannelMonitor *admin.ChannelMonitorHandler
|
||||
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
|
||||
ContentModeration *admin.ContentModerationHandler
|
||||
Payment *admin.PaymentHandler
|
||||
Affiliate *admin.AffiliateHandler
|
||||
}
|
||||
|
||||
@@ -81,6 +81,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
|
||||
@@ -27,15 +27,16 @@ import (
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
imageLimiter *imageConcurrencyLimiter
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
contentModerationService *service.ContentModerationService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
imageLimiter *imageConcurrencyLimiter
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
|
||||
@@ -53,6 +54,7 @@ func NewOpenAIGatewayHandler(
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
contentModerationService *service.ContentModerationService,
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@@ -64,15 +66,16 @@ func NewOpenAIGatewayHandler(
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
contentModerationService: contentModerationService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
imageLimiter: &imageConcurrencyLimiter{},
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,6 +192,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
|
||||
if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
@@ -599,6 +607,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
|
||||
h.anthropicErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析渠道级模型映射
|
||||
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
@@ -1153,6 +1166,12 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, decision.Message)
|
||||
return
|
||||
}
|
||||
|
||||
if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
@@ -1268,6 +1287,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||||
}
|
||||
model := strings.TrimSpace(originalModel)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
if model == "" {
|
||||
model = reqModel
|
||||
}
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
@@ -1712,6 +1751,34 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
|
||||
_ = conn.CloseNow()
|
||||
}
|
||||
|
||||
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
|
||||
if conn == nil || decision == nil {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
message := strings.TrimSpace(decision.Message)
|
||||
if message == "" {
|
||||
message = "content moderation blocked this request"
|
||||
}
|
||||
payload, err := json.Marshal(gin.H{
|
||||
"event_id": "evt_content_moderation_blocked",
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"code": contentModerationErrorCode(decision),
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
payload = []byte(`{"event_id":"evt_content_moderation_blocked","type":"error","error":{"type":"invalid_request_error","code":"content_policy_violation","message":"content moderation blocked this request"}}`)
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
_ = conn.Write(writeCtx, coderws.MessageText, payload)
|
||||
}
|
||||
|
||||
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
||||
if err == nil {
|
||||
return "-", "-"
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
coderws "github.com/coder/websocket"
|
||||
@@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
type contentModerationHandlerSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
if value, ok := r.values[key]; ok {
|
||||
return &service.Setting{Key: key, Value: value}, nil
|
||||
}
|
||||
return nil, service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if value, ok := r.values[key]; ok {
|
||||
return value, nil
|
||||
}
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error {
|
||||
if r.values == nil {
|
||||
r.values = map[string]string{}
|
||||
}
|
||||
r.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
out := map[string]string{}
|
||||
for _, key := range keys {
|
||||
if value, ok := r.values[key]; ok {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = map[string]string{}
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
out := make(map[string]string, len(r.values))
|
||||
for key, value := range r.values {
|
||||
out[key] = value
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error {
|
||||
delete(r.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
type contentModerationHandlerTestRepo struct {
|
||||
logs []service.ContentModerationLog
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
|
||||
if log != nil {
|
||||
r.logs = append(r.logs, *log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
|
||||
return &service.ContentModerationCleanupResult{}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/v1/moderations", r.URL.Path)
|
||||
_, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`))
|
||||
}))
|
||||
defer moderationServer.Close()
|
||||
|
||||
cfg := &service.ContentModerationConfig{
|
||||
Enabled: true,
|
||||
Mode: service.ContentModerationModePreBlock,
|
||||
BaseURL: moderationServer.URL,
|
||||
Model: "omni-moderation-latest",
|
||||
APIKeys: []string{"sk-test"},
|
||||
SampleRate: 100,
|
||||
AllGroups: true,
|
||||
BlockMessage: "内容审计测试阻断",
|
||||
}
|
||||
rawCfg, err := json.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := &contentModerationHandlerTestRepo{}
|
||||
settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{
|
||||
service.SettingKeyRiskControlEnabled: "true",
|
||||
service.SettingKeyContentModerationConfig: string(rawCfg),
|
||||
}}
|
||||
moderationSvc := service.NewContentModerationService(
|
||||
settingRepo,
|
||||
repo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{
|
||||
UserID: 1,
|
||||
Endpoint: "/v1/responses",
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.5",
|
||||
Protocol: service.ContentModerationProtocolOpenAIResponses,
|
||||
Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Blocked)
|
||||
repo.logs = nil
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
contentModerationService: moderationSvc,
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.5",
|
||||
"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]
|
||||
}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, payload, readErr := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr == nil {
|
||||
require.Contains(t, string(payload), "content_policy_violation")
|
||||
require.Contains(t, string(payload), "内容审计测试阻断")
|
||||
} else {
|
||||
var closeErr coderws.CloseError
|
||||
require.ErrorAs(t, readErr, &closeErr)
|
||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
|
||||
}
|
||||
require.Len(t, repo.logs, 1)
|
||||
require.True(t, repo.logs[0].Flagged)
|
||||
require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action)
|
||||
require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt)
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||
|
||||
@@ -85,6 +85,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
|
||||
return
|
||||
}
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIImages, parsed.Model, parsed.ModerationBody()); decision != nil && decision.Blocked {
|
||||
h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message)
|
||||
return
|
||||
}
|
||||
imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
|
||||
if !acquired {
|
||||
return
|
||||
|
||||
@@ -77,5 +77,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
|
||||
RiskControlEnabled: settings.RiskControlEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ func ProvideAdminHandlers(
|
||||
channelHandler *admin.ChannelHandler,
|
||||
channelMonitorHandler *admin.ChannelMonitorHandler,
|
||||
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
|
||||
contentModerationHandler *admin.ContentModerationHandler,
|
||||
paymentHandler *admin.PaymentHandler,
|
||||
affiliateHandler *admin.AffiliateHandler,
|
||||
) *AdminHandlers {
|
||||
@@ -67,6 +68,7 @@ func ProvideAdminHandlers(
|
||||
Channel: channelHandler,
|
||||
ChannelMonitor: channelMonitorHandler,
|
||||
ChannelMonitorTemplate: channelMonitorTemplateHandler,
|
||||
ContentModeration: contentModerationHandler,
|
||||
Payment: paymentHandler,
|
||||
Affiliate: affiliateHandler,
|
||||
}
|
||||
@@ -170,6 +172,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewChannelHandler,
|
||||
admin.NewChannelMonitorHandler,
|
||||
admin.NewChannelMonitorRequestTemplateHandler,
|
||||
admin.NewContentModerationHandler,
|
||||
admin.NewPaymentHandler,
|
||||
admin.NewAffiliateHandler,
|
||||
|
||||
|
||||
Reference in New Issue
Block a user