Merge pull request #2030 from KnowSky404/feature/account-bulk-edit-scope-and-compact
feat: support filtered account bulk edit and align compact OpenAI bulk fields
This commit is contained in:
@@ -134,19 +134,29 @@ type UpdateAccountRequest struct {
|
||||
|
||||
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
|
||||
type BulkUpdateAccountsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
Filters *BulkUpdateAccountFilters `json:"filters"`
|
||||
Name string `json:"name"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
LoadFactor *int `json:"load_factor"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
type BulkUpdateAccountFilters struct {
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Group string `json:"group"`
|
||||
Search string `json:"search"`
|
||||
PrivacyMode string `json:"privacy_mode"`
|
||||
}
|
||||
|
||||
// CheckMixedChannelRequest represents check mixed channel risk request
|
||||
@@ -1369,6 +1379,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 && req.Filters == nil {
|
||||
response.BadRequest(c, "account_ids or filters is required")
|
||||
return
|
||||
}
|
||||
// base_rpm 输入校验:负值归零,超过 10000 截断
|
||||
sanitizeExtraBaseRPM(req.Extra)
|
||||
|
||||
@@ -1394,6 +1408,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
|
||||
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
|
||||
AccountIDs: req.AccountIDs,
|
||||
Filters: toServiceBulkUpdateAccountFilters(req.Filters),
|
||||
Name: req.Name,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
@@ -1429,6 +1444,20 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func toServiceBulkUpdateAccountFilters(filters *BulkUpdateAccountFilters) *service.BulkUpdateAccountFilters {
|
||||
if filters == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.BulkUpdateAccountFilters{
|
||||
Platform: filters.Platform,
|
||||
Type: filters.Type,
|
||||
Status: filters.Status,
|
||||
Group: filters.Group,
|
||||
Search: filters.Search,
|
||||
PrivacyMode: filters.PrivacyMode,
|
||||
}
|
||||
}
|
||||
|
||||
// ========== OAuth Handlers ==========
|
||||
|
||||
// GenerateAuthURLRequest represents the request for generating auth URL
|
||||
|
||||
@@ -196,3 +196,29 @@ func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
|
||||
require.Equal(t, float64(2), data["success"])
|
||||
require.Equal(t, float64(0), data["failed"])
|
||||
}
|
||||
|
||||
func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"filters": map[string]any{
|
||||
"platform": "openai",
|
||||
"type": "oauth",
|
||||
"status": "active",
|
||||
"group": "12",
|
||||
"privacy_mode": "blocked",
|
||||
"search": "bulk-target",
|
||||
},
|
||||
"schedulable": true,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -291,6 +292,7 @@ type UpdateAccountInput struct {
|
||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||
type BulkUpdateAccountsInput struct {
|
||||
AccountIDs []int64
|
||||
Filters *BulkUpdateAccountFilters
|
||||
Name string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
@@ -307,6 +309,15 @@ type BulkUpdateAccountsInput struct {
|
||||
SkipMixedChannelCheck bool
|
||||
}
|
||||
|
||||
type BulkUpdateAccountFilters struct {
|
||||
Platform string
|
||||
Type string
|
||||
Status string
|
||||
Group string
|
||||
Search string
|
||||
PrivacyMode string
|
||||
}
|
||||
|
||||
// BulkUpdateAccountResult captures the result for a single account update.
|
||||
type BulkUpdateAccountResult struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
@@ -2286,6 +2297,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
// BulkUpdateAccounts updates multiple accounts in one request.
|
||||
// It merges credentials/extra keys instead of overwriting the whole object.
|
||||
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
||||
if len(input.AccountIDs) == 0 && input.Filters != nil {
|
||||
accountIDs, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
input.AccountIDs = accountIDs
|
||||
}
|
||||
|
||||
result := &BulkUpdateAccountsResult{
|
||||
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||
@@ -2401,6 +2420,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) {
|
||||
if filters == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
groupID := int64(0)
|
||||
switch strings.TrimSpace(filters.Group) {
|
||||
case "":
|
||||
case "ungrouped":
|
||||
groupID = AccountListGroupUngrouped
|
||||
default:
|
||||
parsedGroupID, err := strconv.ParseInt(strings.TrimSpace(filters.Group), 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid group filter: %w", err)
|
||||
}
|
||||
groupID = parsedGroupID
|
||||
}
|
||||
|
||||
const pageSize = 500
|
||||
page := 1
|
||||
accountIDs := make([]int64, 0, pageSize)
|
||||
|
||||
for {
|
||||
accounts, total, err := s.ListAccounts(
|
||||
ctx,
|
||||
page,
|
||||
pageSize,
|
||||
filters.Platform,
|
||||
filters.Type,
|
||||
filters.Status,
|
||||
filters.Search,
|
||||
groupID,
|
||||
filters.PrivacyMode,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, account := range accounts {
|
||||
accountIDs = append(accountIDs, account.ID)
|
||||
}
|
||||
if int64(len(accountIDs)) >= total || len(accounts) == 0 {
|
||||
return accountIDs, nil
|
||||
}
|
||||
page++
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
|
||||
@@ -5,8 +5,10 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct {
|
||||
getByIDCalled []int64
|
||||
listByGroupData map[int64][]Account
|
||||
listByGroupErr map[int64]error
|
||||
listData []Account
|
||||
listResult *pagination.PaginationResult
|
||||
listErr error
|
||||
listCalled bool
|
||||
lastListParams pagination.PaginationParams
|
||||
lastListFilters struct {
|
||||
platform string
|
||||
accountType string
|
||||
status string
|
||||
search string
|
||||
groupID int64
|
||||
privacyMode string
|
||||
}
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||
@@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listCalled = true
|
||||
s.lastListParams = params
|
||||
s.lastListFilters.platform = platform
|
||||
s.lastListFilters.accountType = accountType
|
||||
s.lastListFilters.status = status
|
||||
s.lastListFilters.search = search
|
||||
s.lastListFilters.groupID = groupID
|
||||
s.lastListFilters.privacyMode = privacyMode
|
||||
if s.listErr != nil {
|
||||
return nil, nil, s.listErr
|
||||
}
|
||||
if s.listResult != nil {
|
||||
return s.listData, s.listResult, nil
|
||||
}
|
||||
return s.listData, &pagination.PaginationResult{Total: int64(len(s.listData))}, nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
@@ -170,3 +203,46 @@ func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon
|
||||
// No BindGroups should have been called since the check runs before any write.
|
||||
require.Empty(t, repo.bindGroupsCalls)
|
||||
}
|
||||
|
||||
func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{
|
||||
listData: []Account{
|
||||
{ID: 7},
|
||||
{ID: 11},
|
||||
},
|
||||
listResult: &pagination.PaginationResult{Total: 2},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
schedulable := true
|
||||
input := &BulkUpdateAccountsInput{
|
||||
Schedulable: &schedulable,
|
||||
}
|
||||
|
||||
filtersField := reflect.ValueOf(input).Elem().FieldByName("Filters")
|
||||
require.True(t, filtersField.IsValid(), "BulkUpdateAccountsInput should expose Filters for filter-target bulk update")
|
||||
require.Equal(t, reflect.Ptr, filtersField.Kind(), "BulkUpdateAccountsInput.Filters should be a pointer field")
|
||||
|
||||
filtersValue := reflect.New(filtersField.Type().Elem())
|
||||
filtersValue.Elem().FieldByName("Platform").SetString(PlatformOpenAI)
|
||||
filtersValue.Elem().FieldByName("Type").SetString(AccountTypeOAuth)
|
||||
filtersValue.Elem().FieldByName("Status").SetString(StatusActive)
|
||||
filtersValue.Elem().FieldByName("Group").SetString("12")
|
||||
filtersValue.Elem().FieldByName("PrivacyMode").SetString(PrivacyModeCFBlocked)
|
||||
filtersValue.Elem().FieldByName("Search").SetString("bulk-target")
|
||||
filtersField.Set(filtersValue)
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.True(t, repo.listCalled, "expected filter-target bulk update to resolve matching IDs via account list filters")
|
||||
require.Equal(t, PlatformOpenAI, repo.lastListFilters.platform)
|
||||
require.Equal(t, AccountTypeOAuth, repo.lastListFilters.accountType)
|
||||
require.Equal(t, StatusActive, repo.lastListFilters.status)
|
||||
require.Equal(t, "bulk-target", repo.lastListFilters.search)
|
||||
require.Equal(t, int64(12), repo.lastListFilters.groupID)
|
||||
require.Equal(t, PrivacyModeCFBlocked, repo.lastListFilters.privacyMode)
|
||||
require.Equal(t, []int64{7, 11}, repo.bulkUpdateIDs)
|
||||
require.Equal(t, 2, result.Success)
|
||||
require.Equal(t, 0, result.Failed)
|
||||
require.Equal(t, []int64{7, 11}, result.SuccessIDs)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user