fix startup bootstrap recovery and local verification

This commit is contained in:
2026-04-23 10:27:13 +08:00
parent 32b2c23a04
commit fa0aacc559
9 changed files with 211 additions and 59 deletions

View File

@@ -150,6 +150,9 @@ func runMainServer() {
log.Fatalf("Failed to initialize application: %v", err)
}
defer app.Cleanup()
if err := app.Bootstrap(); err != nil {
log.Fatalf("Failed to bootstrap application state: %v", err)
}
// 启动服务器
go func() {

View File

@@ -18,14 +18,16 @@ import (
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
type Application struct {
Server *http.Server
Cleanup func()
Server *http.Server
Cleanup func()
Bootstrap func() error
}
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
@@ -53,9 +55,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// Cleanup function provider
provideCleanup,
provideBootstrap,
// Application struct
wire.Struct(new(Application), "Server", "Cleanup"),
wire.Struct(new(Application), "Server", "Cleanup", "Bootstrap"),
)
return nil, nil
}
@@ -71,6 +74,28 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
}
}
func provideBootstrap(settingService *service.SettingService, userRepo service.UserRepository, cfg *config.Config) func() error {
return newBootstrapFunc(settingService.InitializeDefaultSettings, setup.RecoverAutoSetupAdmin, userRepo, cfg)
}
func newBootstrapFunc(
initDefaults func(context.Context) error,
recoverAdmin func(context.Context, service.UserRepository, *config.Config) error,
userRepo service.UserRepository,
cfg *config.Config,
) func() error {
return func() error {
ctx := context.Background()
if err := initDefaults(ctx); err != nil {
return err
}
if err := recoverAdmin(ctx, userRepo, cfg); err != nil {
return err
}
return nil
}
}
func provideCleanup(
entClient *ent.Client,
rdb *redis.Client,

View File

@@ -17,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/redis/go-redis/v9"
"log"
"net/http"
@@ -260,9 +261,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService)
bootstrap := provideBootstrap(settingService, userRepository, configConfig)
application := &Application{
Server: httpServer,
Cleanup: v,
Server: httpServer,
Cleanup: v,
Bootstrap: bootstrap,
}
return application, nil
}
@@ -270,8 +273,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// wire.go:
type Application struct {
Server *http.Server
Cleanup func()
Server *http.Server
Cleanup func()
Bootstrap func() error
}
func providePrivacyClientFactory() service.PrivacyClientFactory {
@@ -285,6 +289,23 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
}
}
func provideBootstrap(settingService *service.SettingService, userRepo service.UserRepository, cfg *config.Config) func() error {
return newBootstrapFunc(settingService.InitializeDefaultSettings, setup.RecoverAutoSetupAdmin, userRepo, cfg)
}
func newBootstrapFunc(initDefaults func(context.Context) error, recoverAdmin func(context.Context, service.UserRepository, *config.Config) error, userRepo service.UserRepository, cfg *config.Config) func() error {
return func() error {
ctx := context.Background()
if err := initDefaults(ctx); err != nil {
return err
}
if err := recoverAdmin(ctx, userRepo, cfg); err != nil {
return err
}
return nil
}
}
func provideCleanup(
entClient *ent.Client,
rdb *redis.Client,

View File

@@ -1,6 +1,8 @@
package main
import (
"context"
"errors"
"testing"
"time"
@@ -83,3 +85,47 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
cleanup()
})
}
func TestNewBootstrapFunc_RunsDefaultsBeforeRecovery(t *testing.T) {
cfg := &config.Config{}
order := make([]string, 0, 2)
bootstrap := newBootstrapFunc(
func(context.Context) error {
order = append(order, "defaults")
return nil
},
func(_ context.Context, gotRepo service.UserRepository, got *config.Config) error {
require.Nil(t, gotRepo)
require.Same(t, cfg, got)
order = append(order, "recover")
return nil
},
nil,
cfg,
)
require.NoError(t, bootstrap())
require.Equal(t, []string{"defaults", "recover"}, order)
}
func TestNewBootstrapFunc_StopsWhenDefaultsFail(t *testing.T) {
cfg := &config.Config{}
wantErr := errors.New("defaults failed")
recoverCalled := false
bootstrap := newBootstrapFunc(
func(context.Context) error {
return wantErr
},
func(context.Context, service.UserRepository, *config.Config) error {
recoverCalled = true
return nil
},
nil,
cfg,
)
require.ErrorIs(t, bootstrap(), wantErr)
require.False(t, recoverCalled)
}

View File

@@ -33,7 +33,7 @@ func TestUserRegistrationAndLogin(t *testing.T) {
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/auth/register", body, "")
resp, err := doRequest(t, "POST", "/api/v1/auth/register", body, "")
if err != nil {
t.Skipf("注册接口不可用,跳过用户流程测试: %v", err)
return
@@ -64,7 +64,7 @@ func TestUserRegistrationAndLogin(t *testing.T) {
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
resp, err := doRequest(t, "POST", "/api/v1/auth/login", body, "")
if err != nil {
t.Fatalf("登录请求失败: %v", err)
}
@@ -111,7 +111,7 @@ func TestUserRegistrationAndLogin(t *testing.T) {
// 步骤 3: 使用 JWT 获取当前用户信息
t.Run("获取当前用户信息", func(t *testing.T) {
resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken)
resp, err := doRequest(t, "GET", "/api/v1/auth/me", nil, accessToken)
if err != nil {
t.Fatalf("请求失败: %v", err)
}
@@ -144,7 +144,7 @@ func TestAPIKeyLifecycle(t *testing.T) {
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/keys", body, accessToken)
resp, err := doRequest(t, "POST", "/api/v1/keys", body, accessToken)
if err != nil {
t.Fatalf("创建 API Key 请求失败: %v", err)
}
@@ -215,7 +215,7 @@ func TestAPIKeyLifecycle(t *testing.T) {
// 步骤 3: 查询用量记录
t.Run("查询用量记录", func(t *testing.T) {
resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken)
resp, err := doRequest(t, "GET", "/api/v1/usage/dashboard/stats", nil, accessToken)
if err != nil {
t.Fatalf("用量查询请求失败: %v", err)
}
@@ -279,7 +279,7 @@ func loginTestUser(t *testing.T) string {
}
body, _ := json.Marshal(payload)
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
resp, err := doRequest(t, "POST", "/api/v1/auth/login", body, "")
if err != nil {
return ""
}

View File

@@ -1538,44 +1538,6 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
s.Require().GreaterOrEqual(len(trend), 2)
}
// --- GetAPIKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
s.createUsageLog(user, apiKey2, account, 50, 100, 0.5, base)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
s.createUsageLog(user, apiKey, account, 50, 100, 0.5, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
s.Require().Len(trend, 2)
}
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"os"
"strconv"
@@ -14,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -404,11 +406,11 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) {
defer cancel()
var totalUsers int64
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users").Scan(&totalUsers); err != nil {
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM public.users").Scan(&totalUsers); err != nil {
return false, "", err
}
var adminUsers int64
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil {
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM public.users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil {
return false, "", err
}
decision := decideAdminBootstrap(totalUsers, adminUsers)
@@ -442,7 +444,7 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) {
_, err = db.ExecContext(
ctx,
`INSERT INTO users (email, password_hash, role, balance, concurrency, status, created_at, updated_at)
`INSERT INTO public.users (email, password_hash, role, balance, concurrency, status, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
admin.Email,
admin.PasswordHash,
@@ -706,3 +708,68 @@ func AutoSetupFromEnv() error {
logger.LegacyPrintf("setup", "%s", "Auto setup completed successfully!")
return nil
}
// RecoverAutoSetupAdmin repairs an interrupted bootstrap by creating the admin
// user when the initialized application state still has no users.
func RecoverAutoSetupAdmin(ctx context.Context, userRepo service.UserRepository, cfg *config.Config) error {
if cfg == nil || userRepo == nil {
return nil
}
if ctx == nil {
ctx = context.Background()
}
if _, err := userRepo.GetFirstAdmin(ctx); err == nil {
logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonAdminExists)
return nil
} else if !errors.Is(err, service.ErrUserNotFound) {
return err
}
_, page, err := userRepo.List(ctx, pagination.PaginationParams{Page: 1, PageSize: 1})
if err != nil {
return err
}
if page != nil && page.Total > 0 {
logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonUsersExistWithoutAdmin)
return nil
}
adminEmail := strings.TrimSpace(cfg.Default.AdminEmail)
if adminEmail == "" {
adminEmail = "admin@sub2api.local"
}
adminPassword := getEnvOrDefault("ADMIN_PASSWORD", cfg.Default.AdminPassword)
if strings.TrimSpace(adminPassword) == "" {
password, genErr := generateSecret(16)
if genErr != nil {
return fmt.Errorf("failed to generate admin password: %w", genErr)
}
adminPassword = password
fmt.Printf("Generated admin password (one-time): %s\n", adminPassword)
fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
}
admin := &service.User{
Email: getEnvOrDefault("ADMIN_EMAIL", adminEmail),
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 0,
Concurrency: setupDefaultAdminConcurrency(),
}
if err := admin.SetPassword(adminPassword); err != nil {
return err
}
if err := userRepo.Create(ctx, admin); err != nil {
if errors.Is(err, service.ErrEmailExists) {
logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonAdminExists)
return nil
}
return err
}
logger.LegacyPrintf("setup", "startup admin recovery result: created=true reason=%s", adminBootstrapReasonEmptyDatabase)
return nil
}

View File

@@ -126,8 +126,12 @@ func TestQuoteIdentifier_SafetyInvariant(t *testing.T) {
quoted := quoteIdentifier(attack)
// Invariant 1: Output always starts and ends with exactly one double quote
if !strings.HasPrefix(quoted, `"`) { t.Errorf("must start with double quote") }
if !strings.HasSuffix(quoted, `"`) { t.Errorf("must end with double quote") }
if !strings.HasPrefix(quoted, `"`) {
t.Errorf("must start with double quote")
}
if !strings.HasSuffix(quoted, `"`) {
t.Errorf("must end with double quote")
}
// Invariant 2: All internal double quotes are escaped (doubled)
inner := quoted[1 : len(quoted)-1]
@@ -139,19 +143,28 @@ func TestQuoteIdentifier_SafetyInvariant(t *testing.T) {
// Invariant 3: When used in SQL, the result is a single valid identifier
sql := fmt.Sprintf("CREATE DATABASE %s", quoted)
if !strings.Contains(sql, quoted) { t.Error("SQL must contain the exact quoted identifier") }
if !strings.Contains(sql, quoted) {
t.Error("SQL must contain the exact quoted identifier")
}
})
}
}
func min(a, b int) int { if a < b { return a }; return b }
func min(a, b int) int {
if a < b {
return a
}
return b
}
func hashString(s string) int {
h := 0
for _, c := range s {
h = h*31 + int(c)
}
if h < 0 { h = -h }
if h < 0 {
h = -h
}
return h % 10000
}