fix startup bootstrap recovery and local verification
This commit is contained in:
@@ -150,6 +150,9 @@ func runMainServer() {
|
||||
log.Fatalf("Failed to initialize application: %v", err)
|
||||
}
|
||||
defer app.Cleanup()
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
log.Fatalf("Failed to bootstrap application state: %v", err)
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
go func() {
|
||||
|
||||
@@ -18,14 +18,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/setup"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type Application struct {
|
||||
Server *http.Server
|
||||
Cleanup func()
|
||||
Server *http.Server
|
||||
Cleanup func()
|
||||
Bootstrap func() error
|
||||
}
|
||||
|
||||
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
@@ -53,9 +55,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
|
||||
// Cleanup function provider
|
||||
provideCleanup,
|
||||
provideBootstrap,
|
||||
|
||||
// Application struct
|
||||
wire.Struct(new(Application), "Server", "Cleanup"),
|
||||
wire.Struct(new(Application), "Server", "Cleanup", "Bootstrap"),
|
||||
)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -71,6 +74,28 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
}
|
||||
}
|
||||
|
||||
func provideBootstrap(settingService *service.SettingService, userRepo service.UserRepository, cfg *config.Config) func() error {
|
||||
return newBootstrapFunc(settingService.InitializeDefaultSettings, setup.RecoverAutoSetupAdmin, userRepo, cfg)
|
||||
}
|
||||
|
||||
func newBootstrapFunc(
|
||||
initDefaults func(context.Context) error,
|
||||
recoverAdmin func(context.Context, service.UserRepository, *config.Config) error,
|
||||
userRepo service.UserRepository,
|
||||
cfg *config.Config,
|
||||
) func() error {
|
||||
return func() error {
|
||||
ctx := context.Background()
|
||||
if err := initDefaults(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := recoverAdmin(ctx, userRepo, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func provideCleanup(
|
||||
entClient *ent.Client,
|
||||
rdb *redis.Client,
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/setup"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -260,9 +261,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService)
|
||||
bootstrap := provideBootstrap(settingService, userRepository, configConfig)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
Bootstrap: bootstrap,
|
||||
}
|
||||
return application, nil
|
||||
}
|
||||
@@ -270,8 +273,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
// wire.go:
|
||||
|
||||
type Application struct {
|
||||
Server *http.Server
|
||||
Cleanup func()
|
||||
Server *http.Server
|
||||
Cleanup func()
|
||||
Bootstrap func() error
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
@@ -285,6 +289,23 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
}
|
||||
}
|
||||
|
||||
func provideBootstrap(settingService *service.SettingService, userRepo service.UserRepository, cfg *config.Config) func() error {
|
||||
return newBootstrapFunc(settingService.InitializeDefaultSettings, setup.RecoverAutoSetupAdmin, userRepo, cfg)
|
||||
}
|
||||
|
||||
func newBootstrapFunc(initDefaults func(context.Context) error, recoverAdmin func(context.Context, service.UserRepository, *config.Config) error, userRepo service.UserRepository, cfg *config.Config) func() error {
|
||||
return func() error {
|
||||
ctx := context.Background()
|
||||
if err := initDefaults(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := recoverAdmin(ctx, userRepo, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func provideCleanup(
|
||||
entClient *ent.Client,
|
||||
rdb *redis.Client,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -83,3 +85,47 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
cleanup()
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewBootstrapFunc_RunsDefaultsBeforeRecovery(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
order := make([]string, 0, 2)
|
||||
|
||||
bootstrap := newBootstrapFunc(
|
||||
func(context.Context) error {
|
||||
order = append(order, "defaults")
|
||||
return nil
|
||||
},
|
||||
func(_ context.Context, gotRepo service.UserRepository, got *config.Config) error {
|
||||
require.Nil(t, gotRepo)
|
||||
require.Same(t, cfg, got)
|
||||
order = append(order, "recover")
|
||||
return nil
|
||||
},
|
||||
nil,
|
||||
cfg,
|
||||
)
|
||||
|
||||
require.NoError(t, bootstrap())
|
||||
require.Equal(t, []string{"defaults", "recover"}, order)
|
||||
}
|
||||
|
||||
func TestNewBootstrapFunc_StopsWhenDefaultsFail(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
wantErr := errors.New("defaults failed")
|
||||
recoverCalled := false
|
||||
|
||||
bootstrap := newBootstrapFunc(
|
||||
func(context.Context) error {
|
||||
return wantErr
|
||||
},
|
||||
func(context.Context, service.UserRepository, *config.Config) error {
|
||||
recoverCalled = true
|
||||
return nil
|
||||
},
|
||||
nil,
|
||||
cfg,
|
||||
)
|
||||
|
||||
require.ErrorIs(t, bootstrap(), wantErr)
|
||||
require.False(t, recoverCalled)
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func TestUserRegistrationAndLogin(t *testing.T) {
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/auth/register", body, "")
|
||||
resp, err := doRequest(t, "POST", "/api/v1/auth/register", body, "")
|
||||
if err != nil {
|
||||
t.Skipf("注册接口不可用,跳过用户流程测试: %v", err)
|
||||
return
|
||||
@@ -64,7 +64,7 @@ func TestUserRegistrationAndLogin(t *testing.T) {
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
|
||||
resp, err := doRequest(t, "POST", "/api/v1/auth/login", body, "")
|
||||
if err != nil {
|
||||
t.Fatalf("登录请求失败: %v", err)
|
||||
}
|
||||
@@ -111,7 +111,7 @@ func TestUserRegistrationAndLogin(t *testing.T) {
|
||||
|
||||
// 步骤 3: 使用 JWT 获取当前用户信息
|
||||
t.Run("获取当前用户信息", func(t *testing.T) {
|
||||
resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken)
|
||||
resp, err := doRequest(t, "GET", "/api/v1/auth/me", nil, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("请求失败: %v", err)
|
||||
}
|
||||
@@ -144,7 +144,7 @@ func TestAPIKeyLifecycle(t *testing.T) {
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/keys", body, accessToken)
|
||||
resp, err := doRequest(t, "POST", "/api/v1/keys", body, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("创建 API Key 请求失败: %v", err)
|
||||
}
|
||||
@@ -215,7 +215,7 @@ func TestAPIKeyLifecycle(t *testing.T) {
|
||||
|
||||
// 步骤 3: 查询用量记录
|
||||
t.Run("查询用量记录", func(t *testing.T) {
|
||||
resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken)
|
||||
resp, err := doRequest(t, "GET", "/api/v1/usage/dashboard/stats", nil, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("用量查询请求失败: %v", err)
|
||||
}
|
||||
@@ -279,7 +279,7 @@ func loginTestUser(t *testing.T) string {
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
|
||||
resp, err := doRequest(t, "POST", "/api/v1/auth/login", body, "")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1538,44 +1538,6 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
// --- GetAPIKeyUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
|
||||
s.createUsageLog(user, apiKey2, account, 50, 100, 0.5, base)
|
||||
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||
s.Require().NoError(err, "GetAPIKeyUsageTrend")
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
|
||||
s.createUsageLog(user, apiKey, account, 50, 100, 0.5, base.Add(1*time.Hour))
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
|
||||
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
|
||||
// --- ListWithFilters (additional filter tests) ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -404,11 +406,11 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) {
|
||||
defer cancel()
|
||||
|
||||
var totalUsers int64
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users").Scan(&totalUsers); err != nil {
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM public.users").Scan(&totalUsers); err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
var adminUsers int64
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil {
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM public.users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
decision := decideAdminBootstrap(totalUsers, adminUsers)
|
||||
@@ -442,7 +444,7 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) {
|
||||
|
||||
_, err = db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO users (email, password_hash, role, balance, concurrency, status, created_at, updated_at)
|
||||
`INSERT INTO public.users (email, password_hash, role, balance, concurrency, status, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
|
||||
admin.Email,
|
||||
admin.PasswordHash,
|
||||
@@ -706,3 +708,68 @@ func AutoSetupFromEnv() error {
|
||||
logger.LegacyPrintf("setup", "%s", "Auto setup completed successfully!")
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecoverAutoSetupAdmin repairs an interrupted bootstrap by creating the admin
|
||||
// user when the initialized application state still has no users.
|
||||
func RecoverAutoSetupAdmin(ctx context.Context, userRepo service.UserRepository, cfg *config.Config) error {
|
||||
if cfg == nil || userRepo == nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if _, err := userRepo.GetFirstAdmin(ctx); err == nil {
|
||||
logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonAdminExists)
|
||||
return nil
|
||||
} else if !errors.Is(err, service.ErrUserNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
_, page, err := userRepo.List(ctx, pagination.PaginationParams{Page: 1, PageSize: 1})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if page != nil && page.Total > 0 {
|
||||
logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonUsersExistWithoutAdmin)
|
||||
return nil
|
||||
}
|
||||
|
||||
adminEmail := strings.TrimSpace(cfg.Default.AdminEmail)
|
||||
if adminEmail == "" {
|
||||
adminEmail = "admin@sub2api.local"
|
||||
}
|
||||
|
||||
adminPassword := getEnvOrDefault("ADMIN_PASSWORD", cfg.Default.AdminPassword)
|
||||
if strings.TrimSpace(adminPassword) == "" {
|
||||
password, genErr := generateSecret(16)
|
||||
if genErr != nil {
|
||||
return fmt.Errorf("failed to generate admin password: %w", genErr)
|
||||
}
|
||||
adminPassword = password
|
||||
fmt.Printf("Generated admin password (one-time): %s\n", adminPassword)
|
||||
fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||
}
|
||||
|
||||
admin := &service.User{
|
||||
Email: getEnvOrDefault("ADMIN_EMAIL", adminEmail),
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
Balance: 0,
|
||||
Concurrency: setupDefaultAdminConcurrency(),
|
||||
}
|
||||
if err := admin.SetPassword(adminPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := userRepo.Create(ctx, admin); err != nil {
|
||||
if errors.Is(err, service.ErrEmailExists) {
|
||||
logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonAdminExists)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("setup", "startup admin recovery result: created=true reason=%s", adminBootstrapReasonEmptyDatabase)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -126,8 +126,12 @@ func TestQuoteIdentifier_SafetyInvariant(t *testing.T) {
|
||||
quoted := quoteIdentifier(attack)
|
||||
|
||||
// Invariant 1: Output always starts and ends with exactly one double quote
|
||||
if !strings.HasPrefix(quoted, `"`) { t.Errorf("must start with double quote") }
|
||||
if !strings.HasSuffix(quoted, `"`) { t.Errorf("must end with double quote") }
|
||||
if !strings.HasPrefix(quoted, `"`) {
|
||||
t.Errorf("must start with double quote")
|
||||
}
|
||||
if !strings.HasSuffix(quoted, `"`) {
|
||||
t.Errorf("must end with double quote")
|
||||
}
|
||||
|
||||
// Invariant 2: All internal double quotes are escaped (doubled)
|
||||
inner := quoted[1 : len(quoted)-1]
|
||||
@@ -139,19 +143,28 @@ func TestQuoteIdentifier_SafetyInvariant(t *testing.T) {
|
||||
|
||||
// Invariant 3: When used in SQL, the result is a single valid identifier
|
||||
sql := fmt.Sprintf("CREATE DATABASE %s", quoted)
|
||||
if !strings.Contains(sql, quoted) { t.Error("SQL must contain the exact quoted identifier") }
|
||||
if !strings.Contains(sql, quoted) {
|
||||
t.Error("SQL must contain the exact quoted identifier")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int { if a < b { return a }; return b }
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func hashString(s string) int {
|
||||
h := 0
|
||||
for _, c := range s {
|
||||
h = h*31 + int(c)
|
||||
}
|
||||
if h < 0 { h = -h }
|
||||
if h < 0 {
|
||||
h = -h
|
||||
}
|
||||
return h % 10000
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user