feat(risk-control): add content moderation audit

This commit is contained in:
shaw
2026-05-07 09:01:48 +08:00
parent a1106e8167
commit fff4a300c6
54 changed files with 6840 additions and 34 deletions

View File

@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
coderws "github.com/coder/websocket"
@@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
}
type contentModerationHandlerSettingRepo struct {
values map[string]string
}
func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
if value, ok := r.values[key]; ok {
return &service.Setting{Key: key, Value: value}, nil
}
return nil, service.ErrSettingNotFound
}
func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
if value, ok := r.values[key]; ok {
return value, nil
}
return "", service.ErrSettingNotFound
}
func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error {
if r.values == nil {
r.values = map[string]string{}
}
r.values[key] = value
return nil
}
func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := map[string]string{}
for _, key := range keys {
if value, ok := r.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
if r.values == nil {
r.values = map[string]string{}
}
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(r.values))
for key, value := range r.values {
out[key] = value
}
return out, nil
}
func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error {
delete(r.values, key)
return nil
}
type contentModerationHandlerTestRepo struct {
logs []service.ContentModerationLog
}
func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
if log != nil {
r.logs = append(r.logs, *log)
}
return nil
}
func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
return 0, nil
}
func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
return &service.ContentModerationCleanupResult{}, nil
}
func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) {
gin.SetMode(gin.TestMode)
moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/v1/moderations", r.URL.Path)
_, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`))
}))
defer moderationServer.Close()
cfg := &service.ContentModerationConfig{
Enabled: true,
Mode: service.ContentModerationModePreBlock,
BaseURL: moderationServer.URL,
Model: "omni-moderation-latest",
APIKeys: []string{"sk-test"},
SampleRate: 100,
AllGroups: true,
BlockMessage: "内容审计测试阻断",
}
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationHandlerTestRepo{}
settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{
service.SettingKeyRiskControlEnabled: "true",
service.SettingKeyContentModerationConfig: string(rawCfg),
}}
moderationSvc := service.NewContentModerationService(
settingRepo,
repo,
nil,
nil,
nil,
nil,
nil,
)
decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{
UserID: 1,
Endpoint: "/v1/responses",
Provider: "openai",
Model: "gpt-5.5",
Protocol: service.ContentModerationProtocolOpenAIResponses,
Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`),
})
require.NoError(t, err)
require.True(t, decision.Blocked)
repo.logs = nil
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
contentModerationService: moderationSvc,
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second),
}
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{
"type":"response.create",
"model":"gpt-5.5",
"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]
}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, payload, readErr := clientConn.Read(readCtx)
cancelRead()
if readErr == nil {
require.Contains(t, string(payload), "content_policy_violation")
require.Contains(t, string(payload), "内容审计测试阻断")
} else {
var closeErr coderws.CloseError
require.ErrorAs(t, readErr, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
}
require.Len(t, repo.logs, 1)
require.True(t, repo.logs[0].Flagged)
require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action)
require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt)
}
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,