Files
tokens-reef/backend/internal/setup/setup_test.go
User db307b0d0f fix(security): add SQL injection defense for CREATE DATABASE
Add quoteIdentifier() function to safely quote PostgreSQL identifiers
following PostgreSQL's quoting rules (wrap in double quotes, escape
internal quotes by doubling).

This provides defense-in-depth for the CREATE DATABASE statement,
complementing the existing validateDBName() input validation.

Changes:
- Add quoteIdentifier() function with proper escaping
- Use quoted identifier in CREATE DATABASE statement
- Add comprehensive unit tests for quoteIdentifier()
2026-04-16 20:28:36 +08:00

142 lines
3.2 KiB
Go

package setup
import (
"os"
"strings"
"testing"
)
// TestDecideAdminBootstrap checks decideAdminBootstrap against representative
// (totalUsers, adminUsers) pairs, asserting both whether an admin should be
// created and the reason constant reported for that decision.
func TestDecideAdminBootstrap(t *testing.T) {
	t.Parallel()

	type bootstrapCase struct {
		name       string
		totalUsers int64
		adminUsers int64
		wantCreate bool
		wantReason string
	}

	cases := []bootstrapCase{
		{
			name:       "empty database should create admin",
			totalUsers: 0,
			adminUsers: 0,
			wantCreate: true,
			wantReason: adminBootstrapReasonEmptyDatabase,
		},
		{
			name:       "admin exists should skip",
			totalUsers: 10,
			adminUsers: 1,
			wantCreate: false,
			wantReason: adminBootstrapReasonAdminExists,
		},
		{
			name:       "users exist without admin should skip",
			totalUsers: 5,
			adminUsers: 0,
			wantCreate: false,
			wantReason: adminBootstrapReasonUsersExistWithoutAdmin,
		},
	}

	for i := range cases {
		// Take a per-iteration copy: parallel subtests only resume after the
		// loop has finished, so they must not share the loop variable
		// (required before Go 1.22's per-iteration loop semantics).
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			decision := decideAdminBootstrap(tc.totalUsers, tc.adminUsers)
			if decision.shouldCreate != tc.wantCreate {
				t.Fatalf("shouldCreate=%v, want %v", decision.shouldCreate, tc.wantCreate)
			}
			if decision.reason != tc.wantReason {
				t.Fatalf("reason=%q, want %q", decision.reason, tc.wantReason)
			}
		})
	}
}
// TestSetupDefaultAdminConcurrency verifies that setupDefaultAdminConcurrency
// follows RUN_MODE: "simple" yields simpleModeAdminConcurrency, while
// "standard" yields defaultUserConcurrency.
//
// The subtests mutate the environment with t.Setenv, which the testing
// package forbids in parallel tests, so nothing here calls t.Parallel.
func TestSetupDefaultAdminConcurrency(t *testing.T) {
	t.Run("simple mode admin uses higher concurrency", func(t *testing.T) {
		t.Setenv("RUN_MODE", "simple")

		actual := setupDefaultAdminConcurrency()
		if actual == simpleModeAdminConcurrency {
			return
		}
		t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", actual, simpleModeAdminConcurrency)
	})

	t.Run("standard mode keeps existing default", func(t *testing.T) {
		t.Setenv("RUN_MODE", "standard")

		actual := setupDefaultAdminConcurrency()
		if actual == defaultUserConcurrency {
			return
		}
		t.Fatalf("setupDefaultAdminConcurrency()=%d, want %d", actual, defaultUserConcurrency)
	})
}
// TestWriteConfigFileKeepsDefaultUserConcurrency verifies that writing a
// config with RUN_MODE=simple still records "user_concurrency: 5" in the
// generated file.
//
// t.Setenv is used, so this test must not call t.Parallel.
func TestWriteConfigFileKeepsDefaultUserConcurrency(t *testing.T) {
	t.Setenv("RUN_MODE", "simple")
	// NOTE(review): assumes GetConfigFilePath resolves under DATA_DIR, so the
	// file lands in this per-test temp dir — confirm against GetConfigFilePath.
	t.Setenv("DATA_DIR", t.TempDir())

	if err := writeConfigFile(&SetupConfig{}); err != nil {
		t.Fatalf("writeConfigFile() error = %v", err)
	}

	data, err := os.ReadFile(GetConfigFilePath())
	if err != nil {
		t.Fatalf("ReadFile() error = %v", err)
	}

	// Require the value to end at "5" on its line. A plain substring match
	// would also accept "user_concurrency: 50" or "user_concurrency: 500",
	// i.e. exactly the kind of larger value this test is meant to rule out.
	const want = "user_concurrency: 5"
	found := false
	for _, line := range strings.Split(string(data), "\n") {
		if strings.HasSuffix(strings.TrimSpace(line), want) {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("config missing default user concurrency, got:\n%s", string(data))
	}
}
// TestQuoteIdentifier checks that quoteIdentifier wraps its input in double
// quotes and doubles every embedded double quote, so that a hostile name
// cannot terminate the identifier and inject SQL into statements such as
// CREATE DATABASE.
func TestQuoteIdentifier(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "simple name",
			input:    "mydb",
			expected: `"mydb"`,
		},
		{
			name:     "name with underscore",
			input:    "my_db_123",
			expected: `"my_db_123"`,
		},
		{
			name:     "name with double quote (injection attempt)",
			input:    `my"; DROP TABLE users; --`,
			expected: `"my""; DROP TABLE users; --"`,
		},
		{
			name:     "name with multiple double quotes",
			input:    `my"db"test`,
			expected: `"my""db""test"`,
		},
		{
			// Adjacent quotes must each be doubled independently.
			name:     "name with consecutive double quotes",
			input:    `a""b`,
			expected: `"a""""b"`,
		},
		{
			// Boundary: the whole input is a single quote character.
			name:     "lone double quote",
			input:    `"`,
			expected: `""""`,
		},
		{
			// Quoting exists partly to stop PostgreSQL folding case; the
			// helper must not alter letter case itself.
			name:     "mixed case preserved",
			input:    "MyDB",
			expected: `"MyDB"`,
		},
		{
			name:     "name with spaces",
			input:    "my db",
			expected: `"my db"`,
		},
		{
			name:     "empty name",
			input:    "",
			expected: `""`,
		},
		{
			name:     "name starting with number",
			input:    "123db",
			expected: `"123db"`,
		},
	}
	for _, tc := range tests {
		tc := tc // per-iteration copy for parallel subtests (pre-Go 1.22 loop semantics)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got := quoteIdentifier(tc.input)
			if got != tc.expected {
				t.Fatalf("quoteIdentifier(%q) = %q, want %q", tc.input, got, tc.expected)
			}
		})
	}
}