feat(risk-control): add content moderation audit
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
coderws "github.com/coder/websocket"
|
||||
@@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
// contentModerationHandlerSettingRepo is an in-memory settings store used by
// the content moderation handler tests. Settings are kept as raw strings
// keyed by setting name.
type contentModerationHandlerSettingRepo struct {
	// values maps a setting key to its stored string value.
	values map[string]string
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
if value, ok := r.values[key]; ok {
|
||||
return &service.Setting{Key: key, Value: value}, nil
|
||||
}
|
||||
return nil, service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if value, ok := r.values[key]; ok {
|
||||
return value, nil
|
||||
}
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error {
|
||||
if r.values == nil {
|
||||
r.values = map[string]string{}
|
||||
}
|
||||
r.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
out := map[string]string{}
|
||||
for _, key := range keys {
|
||||
if value, ok := r.values[key]; ok {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = map[string]string{}
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
out := make(map[string]string, len(r.values))
|
||||
for key, value := range r.values {
|
||||
out[key] = value
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error {
|
||||
delete(r.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
type contentModerationHandlerTestRepo struct {
|
||||
logs []service.ContentModerationLog
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
|
||||
if log != nil {
|
||||
r.logs = append(r.logs, *log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) {
|
||||
return &service.ContentModerationCleanupResult{}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/v1/moderations", r.URL.Path)
|
||||
_, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`))
|
||||
}))
|
||||
defer moderationServer.Close()
|
||||
|
||||
cfg := &service.ContentModerationConfig{
|
||||
Enabled: true,
|
||||
Mode: service.ContentModerationModePreBlock,
|
||||
BaseURL: moderationServer.URL,
|
||||
Model: "omni-moderation-latest",
|
||||
APIKeys: []string{"sk-test"},
|
||||
SampleRate: 100,
|
||||
AllGroups: true,
|
||||
BlockMessage: "内容审计测试阻断",
|
||||
}
|
||||
rawCfg, err := json.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := &contentModerationHandlerTestRepo{}
|
||||
settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{
|
||||
service.SettingKeyRiskControlEnabled: "true",
|
||||
service.SettingKeyContentModerationConfig: string(rawCfg),
|
||||
}}
|
||||
moderationSvc := service.NewContentModerationService(
|
||||
settingRepo,
|
||||
repo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{
|
||||
UserID: 1,
|
||||
Endpoint: "/v1/responses",
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.5",
|
||||
Protocol: service.ContentModerationProtocolOpenAIResponses,
|
||||
Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Blocked)
|
||||
repo.logs = nil
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
contentModerationService: moderationSvc,
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.5",
|
||||
"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]
|
||||
}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, payload, readErr := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr == nil {
|
||||
require.Contains(t, string(payload), "content_policy_violation")
|
||||
require.Contains(t, string(payload), "内容审计测试阻断")
|
||||
} else {
|
||||
var closeErr coderws.CloseError
|
||||
require.ErrorAs(t, readErr, &closeErr)
|
||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
|
||||
}
|
||||
require.Len(t, repo.logs, 1)
|
||||
require.True(t, repo.logs[0].Flagged)
|
||||
require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action)
|
||||
require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt)
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||
|
||||
Reference in New Issue
Block a user