430 lines
13 KiB
Go
430 lines
13 KiB
Go
//go:build llm_script && !scripts_pkg
|
|
|
|
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// Default location of the manual subscription-plan seed, used when the -seed
// flag is not given. A relative path is resolved against the process's
// working directory by os.ReadFile.
const (
	defaultManualSubscriptionSeedPath = "seeds/subscription_plan_manual_seed.json"
)
|
|
|
|
// manualSubscriptionImportConfig carries the command-line options for a
// single import run.
type manualSubscriptionImportConfig struct {
	// SeedPath is the seed JSON file to load (set from the -seed flag).
	SeedPath string

	// DryRun, when true, validates the seed and prints a summary without
	// opening or writing to the database (set from the -dry-run flag).
	DryRun bool
}
|
|
|
|
// manualSubscriptionSeedEnvelope is the top-level object of the seed JSON
// file: a verification timestamp plus the list of plan items.
type manualSubscriptionSeedEnvelope struct {
	// CheckedAt records when the seed was last checked; it must be RFC 3339
	// to pass validation.
	CheckedAt string                       `json:"checkedAt"`
	Items     []manualSubscriptionSeedItem `json:"items"`
}

// manualSubscriptionSeedItem is one subscription plan exactly as written in
// the seed JSON. Values are kept raw here; validation and defaulting happen
// when the item is converted into a manualSubscriptionRow.
type manualSubscriptionSeedItem struct {
	// Model provider details (resolved/created in the model_provider table).
	ProviderName    string `json:"providerName"`
	ProviderNameCn  string `json:"providerNameCn"`
	ProviderCountry string `json:"providerCountry"`
	ProviderWebsite string `json:"providerWebsite"`

	// Operator details (resolved/created in the operator table).
	OperatorName    string `json:"operatorName"`
	OperatorNameCn  string `json:"operatorNameCn"`
	OperatorCountry string `json:"operatorCountry"`
	OperatorWebsite string `json:"operatorWebsite"`
	OperatorType    string `json:"operatorType"`

	// Plan identity and pricing.
	PlanFamily   string  `json:"planFamily"`
	PlanCode     string  `json:"planCode"`
	PlanName     string  `json:"planName"`
	Tier         string  `json:"tier"`
	BillingCycle string  `json:"billingCycle"`
	Currency     string  `json:"currency"`
	ListPrice    float64 `json:"listPrice"`
	PriceUnit    string  `json:"priceUnit"`

	// Quota and scope.
	QuotaValue    int64    `json:"quotaValue"`
	QuotaUnit     string   `json:"quotaUnit"`
	ContextWindow int      `json:"contextWindow"`
	PlanScope     string   `json:"planScope"`
	ModelScope    []string `json:"modelScope"`

	// Provenance. The dates are parsed at import time with the layouts
	// "2006-01-02 15:04:05" (PublishedAt) and "2006-01-02" (EffectiveDate).
	SourceURL     string `json:"sourceURL"`
	PublishedAt   string `json:"publishedAt"`
	EffectiveDate string `json:"effectiveDate"`
	Notes         string `json:"notes"`
}
|
|
|
|
// manualSubscriptionRow is a validated seed item, built by
// buildManualSubscriptionRows, with defaults applied and ModelScope already
// encoded as a JSON string. It carries everything the upsert needs for the
// model_provider, operator, and subscription_plan tables.
type manualSubscriptionRow struct {
	ProviderName, ProviderNameCn, ProviderCountry, ProviderWebsite string

	OperatorName, OperatorNameCn, OperatorCountry, OperatorWebsite, OperatorType string

	PlanFamily, PlanCode, PlanName, Tier, BillingCycle, Currency string

	ListPrice float64
	PriceUnit string

	QuotaValue int64
	QuotaUnit  string

	ContextWindow int

	// PlanScope is free text; ModelScope holds the JSON encoding of the
	// seed's modelScope list.
	PlanScope, ModelScope string

	// PublishedAt and EffectiveDate stay as the raw seed strings and are
	// parsed just before they are written.
	SourceURL, PublishedAt, EffectiveDate, Notes string
}
|
|
|
|
func main() {
|
|
loadManualSubscriptionEnv()
|
|
|
|
var seedPath string
|
|
var dryRun bool
|
|
|
|
flag.StringVar(&seedPath, "seed", defaultManualSubscriptionSeedPath, "手工订阅套餐 seed JSON 路径")
|
|
flag.BoolVar(&dryRun, "dry-run", false, "仅校验并打印摘要,不写入数据库")
|
|
flag.Parse()
|
|
|
|
cfg := manualSubscriptionImportConfig{
|
|
SeedPath: seedPath,
|
|
DryRun: dryRun,
|
|
}
|
|
|
|
var db *sql.DB
|
|
var err error
|
|
if !cfg.DryRun {
|
|
dsn := os.Getenv("DATABASE_URL")
|
|
if dsn == "" {
|
|
dsn = "postgres://long@/llm_intelligence?host=/var/run/postgresql"
|
|
}
|
|
db, err = sql.Open("postgres", dsn)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "open db: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer db.Close()
|
|
}
|
|
|
|
if err := runManualSubscriptionImport(cfg, db, os.Stdout); err != nil {
|
|
fmt.Fprintf(os.Stderr, "import_manual_subscription_seed: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func loadManualSubscriptionEnv() {
|
|
for _, path := range []string{".env.local", ".env"} {
|
|
loadManualSubscriptionEnvFile(path)
|
|
}
|
|
}
|
|
|
|
// loadManualSubscriptionEnvFile reads KEY=VALUE lines from the dotenv-style
// file at path and sets each variable in the process environment unless it is
// already set. A missing or unreadable file is ignored, because the env files
// are optional.
//
// Accepted syntax:
//   - blank lines and lines starting with '#' are skipped;
//   - an optional leading "export " is allowed;
//   - lines without '=' or with an empty key are skipped;
//   - a value wrapped in one matching pair of '"' or '\'' has that pair
//     removed. Other quote characters are kept verbatim.
func loadManualSubscriptionEnvFile(path string) {
	data, err := os.ReadFile(path)
	if err != nil {
		// Best effort: the env files are optional.
		return
	}

	for _, line := range strings.Split(string(data), "\n") {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		// Accept shell-style "export KEY=value" lines.
		line = strings.TrimSpace(strings.TrimPrefix(line, "export "))

		key, value, ok := strings.Cut(line, "=")
		if !ok {
			continue
		}
		key = strings.TrimSpace(key)
		if key == "" {
			continue
		}
		if _, exists := os.LookupEnv(key); exists {
			continue
		}
		// Setenv only fails for malformed keys; such lines are skipped on a
		// best-effort basis, matching how other malformed lines are treated.
		_ = os.Setenv(key, unquoteManualEnvValue(strings.TrimSpace(value)))
	}
}

// unquoteManualEnvValue strips exactly one pair of matching surrounding single
// or double quotes from value, returning value unchanged otherwise.
func unquoteManualEnvValue(value string) string {
	if len(value) >= 2 {
		first, last := value[0], value[len(value)-1]
		if first == last && (first == '"' || first == '\'') {
			return value[1 : len(value)-1]
		}
	}
	return value
}
|
|
|
|
func runManualSubscriptionImport(cfg manualSubscriptionImportConfig, db *sql.DB, out io.Writer) error {
|
|
envelope, err := loadManualSubscriptionSeed(cfg.SeedPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rows, err := buildManualSubscriptionRows(envelope)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(rows) == 0 {
|
|
return fmt.Errorf("seed is empty")
|
|
}
|
|
|
|
if cfg.DryRun {
|
|
_, err = fmt.Fprintf(
|
|
out,
|
|
"source=manual-subscription-seed checked_at=%s rows=%d operators=%s families=%s dry_run=true\n",
|
|
envelope.CheckedAt,
|
|
len(rows),
|
|
summarizeManualCount(rows, func(row manualSubscriptionRow) string { return row.OperatorName }),
|
|
summarizeManualCount(rows, func(row manualSubscriptionRow) string { return row.PlanFamily }),
|
|
)
|
|
return err
|
|
}
|
|
if db == nil {
|
|
return fmt.Errorf("db is required when dry-run=false")
|
|
}
|
|
|
|
if err := upsertManualSubscriptionRows(db, rows); err != nil {
|
|
return err
|
|
}
|
|
|
|
var tableRows int
|
|
if err := db.QueryRow(`SELECT COUNT(*) FROM subscription_plan`).Scan(&tableRows); err != nil {
|
|
return fmt.Errorf("count subscription_plan: %w", err)
|
|
}
|
|
|
|
_, err = fmt.Fprintf(
|
|
out,
|
|
"source=manual-subscription-seed checked_at=%s rows=%d table_rows=%d operators=%s families=%s dry_run=false\n",
|
|
envelope.CheckedAt,
|
|
len(rows),
|
|
tableRows,
|
|
summarizeManualCount(rows, func(row manualSubscriptionRow) string { return row.OperatorName }),
|
|
summarizeManualCount(rows, func(row manualSubscriptionRow) string { return row.PlanFamily }),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func loadManualSubscriptionSeed(path string) (manualSubscriptionSeedEnvelope, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return manualSubscriptionSeedEnvelope{}, fmt.Errorf("read seed %s: %w", path, err)
|
|
}
|
|
|
|
var envelope manualSubscriptionSeedEnvelope
|
|
if err := json.Unmarshal(data, &envelope); err != nil {
|
|
return manualSubscriptionSeedEnvelope{}, fmt.Errorf("unmarshal seed %s: %w", path, err)
|
|
}
|
|
return envelope, nil
|
|
}
|
|
|
|
func buildManualSubscriptionRows(envelope manualSubscriptionSeedEnvelope) ([]manualSubscriptionRow, error) {
|
|
if _, err := time.Parse(time.RFC3339, envelope.CheckedAt); err != nil {
|
|
return nil, fmt.Errorf("parse checkedAt: %w", err)
|
|
}
|
|
|
|
validPlanFamilies := map[string]bool{
|
|
"token_plan": true,
|
|
"coding_plan": true,
|
|
"package_plan": true,
|
|
}
|
|
|
|
rows := make([]manualSubscriptionRow, 0, len(envelope.Items))
|
|
seenCodes := make(map[string]struct{}, len(envelope.Items))
|
|
for _, item := range envelope.Items {
|
|
if strings.TrimSpace(item.PlanCode) == "" {
|
|
return nil, fmt.Errorf("planCode is required")
|
|
}
|
|
if _, exists := seenCodes[item.PlanCode]; exists {
|
|
return nil, fmt.Errorf("duplicate planCode %q", item.PlanCode)
|
|
}
|
|
seenCodes[item.PlanCode] = struct{}{}
|
|
if !validPlanFamilies[item.PlanFamily] {
|
|
return nil, fmt.Errorf("invalid planFamily %q for %s", item.PlanFamily, item.PlanCode)
|
|
}
|
|
if strings.TrimSpace(item.ProviderName) == "" || strings.TrimSpace(item.OperatorName) == "" {
|
|
return nil, fmt.Errorf("provider/operator is required for %s", item.PlanCode)
|
|
}
|
|
if strings.TrimSpace(item.SourceURL) == "" {
|
|
return nil, fmt.Errorf("sourceURL is required for %s", item.PlanCode)
|
|
}
|
|
|
|
modelScope, _ := json.Marshal(item.ModelScope)
|
|
rows = append(rows, manualSubscriptionRow{
|
|
ProviderName: item.ProviderName,
|
|
ProviderNameCn: item.ProviderNameCn,
|
|
ProviderCountry: defaultManualIfEmpty(item.ProviderCountry, "unknown"),
|
|
ProviderWebsite: item.ProviderWebsite,
|
|
OperatorName: item.OperatorName,
|
|
OperatorNameCn: item.OperatorNameCn,
|
|
OperatorCountry: defaultManualIfEmpty(item.OperatorCountry, "unknown"),
|
|
OperatorWebsite: item.OperatorWebsite,
|
|
OperatorType: defaultManualIfEmpty(item.OperatorType, "official"),
|
|
PlanFamily: item.PlanFamily,
|
|
PlanCode: item.PlanCode,
|
|
PlanName: item.PlanName,
|
|
Tier: item.Tier,
|
|
BillingCycle: defaultManualIfEmpty(item.BillingCycle, "monthly"),
|
|
Currency: defaultManualIfEmpty(item.Currency, "CNY"),
|
|
ListPrice: item.ListPrice,
|
|
PriceUnit: item.PriceUnit,
|
|
QuotaValue: item.QuotaValue,
|
|
QuotaUnit: item.QuotaUnit,
|
|
ContextWindow: item.ContextWindow,
|
|
PlanScope: item.PlanScope,
|
|
ModelScope: string(modelScope),
|
|
SourceURL: item.SourceURL,
|
|
PublishedAt: item.PublishedAt,
|
|
EffectiveDate: item.EffectiveDate,
|
|
Notes: item.Notes,
|
|
})
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
func upsertManualSubscriptionRows(db *sql.DB, rows []manualSubscriptionRow) error {
|
|
for _, row := range rows {
|
|
providerID, err := ensureManualSubscriptionProvider(db, row)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
operatorID, err := ensureManualSubscriptionOperator(db, row)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
publishedAt, err := time.Parse("2006-01-02 15:04:05", row.PublishedAt)
|
|
if err != nil {
|
|
return fmt.Errorf("parse publishedAt for %s: %w", row.PlanCode, err)
|
|
}
|
|
effectiveDate, err := time.Parse("2006-01-02", row.EffectiveDate)
|
|
if err != nil {
|
|
return fmt.Errorf("parse effectiveDate for %s: %w", row.PlanCode, err)
|
|
}
|
|
|
|
_, err = db.Exec(
|
|
`INSERT INTO subscription_plan (
|
|
provider_id, operator_id, plan_family, plan_code, plan_name, tier,
|
|
billing_cycle, currency, list_price, price_unit, quota_value, quota_unit,
|
|
context_window, plan_scope, model_scope, source_url, published_at, effective_date, notes
|
|
) VALUES (
|
|
$1, $2, $3, $4, $5, $6,
|
|
$7, $8, $9, $10, $11, $12,
|
|
$13, $14, $15, $16, $17, $18, $19
|
|
)
|
|
ON CONFLICT (provider_id, plan_code, effective_date)
|
|
DO UPDATE SET
|
|
operator_id = EXCLUDED.operator_id,
|
|
plan_family = EXCLUDED.plan_family,
|
|
plan_name = EXCLUDED.plan_name,
|
|
tier = EXCLUDED.tier,
|
|
billing_cycle = EXCLUDED.billing_cycle,
|
|
currency = EXCLUDED.currency,
|
|
list_price = EXCLUDED.list_price,
|
|
price_unit = EXCLUDED.price_unit,
|
|
quota_value = EXCLUDED.quota_value,
|
|
quota_unit = EXCLUDED.quota_unit,
|
|
context_window = EXCLUDED.context_window,
|
|
plan_scope = EXCLUDED.plan_scope,
|
|
model_scope = EXCLUDED.model_scope,
|
|
source_url = EXCLUDED.source_url,
|
|
published_at = EXCLUDED.published_at,
|
|
notes = EXCLUDED.notes,
|
|
updated_at = CURRENT_TIMESTAMP`,
|
|
providerID, operatorID, row.PlanFamily, row.PlanCode, row.PlanName, row.Tier,
|
|
row.BillingCycle, row.Currency, row.ListPrice, row.PriceUnit, manualNullInt64(row.QuotaValue), manualNullIfEmpty(row.QuotaUnit),
|
|
manualNullInt(row.ContextWindow), manualNullIfEmpty(row.PlanScope), row.ModelScope, row.SourceURL, publishedAt, effectiveDate, manualNullIfEmpty(row.Notes),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("upsert subscription_plan %s: %w", row.PlanCode, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ensureManualSubscriptionProvider(db *sql.DB, row manualSubscriptionRow) (int64, error) {
|
|
var providerID int64
|
|
err := db.QueryRow(`SELECT id FROM model_provider WHERE name = $1`, row.ProviderName).Scan(&providerID)
|
|
if err == nil {
|
|
return providerID, nil
|
|
}
|
|
if err != sql.ErrNoRows {
|
|
return 0, err
|
|
}
|
|
|
|
err = db.QueryRow(
|
|
`INSERT INTO model_provider (name, name_cn, country, website, status)
|
|
VALUES ($1, $2, $3, $4, 'active')
|
|
RETURNING id`,
|
|
row.ProviderName, manualNullIfEmpty(row.ProviderNameCn), row.ProviderCountry, manualNullIfEmpty(row.ProviderWebsite),
|
|
).Scan(&providerID)
|
|
return providerID, err
|
|
}
|
|
|
|
func ensureManualSubscriptionOperator(db *sql.DB, row manualSubscriptionRow) (int64, error) {
|
|
var operatorID int64
|
|
err := db.QueryRow(`SELECT id FROM operator WHERE name = $1`, row.OperatorName).Scan(&operatorID)
|
|
if err == nil {
|
|
return operatorID, nil
|
|
}
|
|
if err != sql.ErrNoRows {
|
|
return 0, err
|
|
}
|
|
|
|
err = db.QueryRow(
|
|
`INSERT INTO operator (name, name_cn, country, website, description, status, type)
|
|
VALUES ($1, $2, $3, $4, $5, 'active', $6)
|
|
RETURNING id`,
|
|
row.OperatorName, manualNullIfEmpty(row.OperatorNameCn), row.OperatorCountry, manualNullIfEmpty(row.OperatorWebsite),
|
|
fmt.Sprintf("%s manual subscription seed", row.OperatorName), row.OperatorType,
|
|
).Scan(&operatorID)
|
|
return operatorID, err
|
|
}
|
|
|
|
func summarizeManualCount(rows []manualSubscriptionRow, getter func(manualSubscriptionRow) string) string {
|
|
counts := make(map[string]int)
|
|
keys := make([]string, 0)
|
|
for _, row := range rows {
|
|
key := getter(row)
|
|
if _, exists := counts[key]; !exists {
|
|
keys = append(keys, key)
|
|
}
|
|
counts[key]++
|
|
}
|
|
sort.Strings(keys)
|
|
|
|
parts := make([]string, 0, len(keys))
|
|
for _, key := range keys {
|
|
parts = append(parts, fmt.Sprintf("%s:%d", key, counts[key]))
|
|
}
|
|
return strings.Join(parts, ",")
|
|
}
|
|
|
|
// defaultManualIfEmpty returns fallback when value is empty or only
// whitespace. Otherwise it returns value unchanged, without trimming.
func defaultManualIfEmpty(value string, fallback string) string {
	if blank := strings.TrimSpace(value) == ""; !blank {
		return value
	}
	return fallback
}
|
|
|
|
// manualNullIfEmpty maps an empty or whitespace-only string to a nil
// interface, which database/sql sends as SQL NULL. Any other value is
// returned unchanged (untrimmed) as a string.
func manualNullIfEmpty(value string) any {
	var param any
	if strings.TrimSpace(value) != "" {
		param = value
	}
	return param
}
|
|
|
|
// manualNullInt treats 0 as "not provided" and maps it to a nil interface
// (SQL NULL). Any other value is returned as an int.
func manualNullInt(value int) any {
	switch value {
	case 0:
		return nil
	default:
		return value
	}
}
|
|
|
|
// manualNullInt64 treats 0 as "not provided" and maps it to a nil interface
// (SQL NULL). Any other value is returned as an int64.
func manualNullInt64(value int64) any {
	if value != 0 {
		return value
	}
	return nil
}
|