Files
lijiaoqiao/gateway/internal/app/bootstrap.go

223 lines
6.2 KiB
Go

package app
import (
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"lijiaoqiao/gateway/internal/config"
	"lijiaoqiao/gateway/internal/handler"
	"lijiaoqiao/gateway/internal/middleware"
	"lijiaoqiao/gateway/internal/ratelimit"
	"lijiaoqiao/gateway/internal/router"
)
// BuildServer assembles the gateway's *http.Server from cfg.
//
// cfg is copied and defaults are filled in by normalizeConfig before use, so
// the caller's struct is not mutated. It returns an error when cfg is nil,
// when the auth configuration fails config.ValidateAuthConfig, or when the
// router, token runtime, or auditor cannot be constructed.
func BuildServer(cfg *config.Config) (*http.Server, error) {
	if cfg == nil {
		return nil, fmt.Errorf("config is required")
	}
	normalized := normalizeConfig(*cfg)
	if err := config.ValidateAuthConfig(normalized.Auth); err != nil {
		return nil, err
	}
	r, err := buildRouter(&normalized)
	if err != nil {
		return nil, err
	}
	limiter := buildLimiter(normalized.RateLimit)

	// Build the token runtime before the auditor. The auditor may be backed by
	// a database (NOTE(review): confirm whether NewDatabaseAuditEmitter opens a
	// connection), and nothing here would release it if a later constructor
	// failed.
	tokenRuntime, err := buildTokenRuntime(normalized.Auth)
	if err != nil {
		return nil, err
	}
	auditor, err := buildAuditor(normalized)
	if err != nil {
		return nil, err
	}

	authConfig := middleware.AuthMiddlewareConfig{
		Verifier:       tokenRuntime,
		StatusResolver: tokenRuntime,
		Authorizer:     middleware.NewScopeRoleAuthorizer(),
		Auditor:        auditor,
		ProtectedPrefixes: []string{
			"/v1/chat/completions",
			"/v1/completions",
			"/api/v1/chat/completions",
			"/api/v1/completions",
			"/api/v1/supply",
			"/api/v1/platform",
		},
		ExcludedPrefixes: []string{"/health", "/healthz", "/metrics", "/readyz"},
		Now:              time.Now,
	}
	h := handler.NewHandler(r)
	server := &http.Server{
		// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:8080"); plain
		// "%s:%d" formatting would produce an address net/http cannot parse.
		Addr:         net.JoinHostPort(normalized.Server.Host, fmt.Sprintf("%d", normalized.Server.Port)),
		Handler:      BuildMux(h, limiter, authConfig),
		ReadTimeout:  normalized.Server.ReadTimeout,
		WriteTimeout: normalized.Server.WriteTimeout,
		IdleTimeout:  normalized.Server.IdleTimeout,
	}
	return server, nil
}
// BuildMux registers the gateway routes on a new ServeMux and wraps the mux in
// CORS handling configured by middleware.DefaultCORSConfig.
//
// The chat and completion endpoints are served under both /v1 and /api/v1.
// Each goes through limitHandler (a nil limiter disables rate limiting) and
// then the token-auth chain built from authConfig. /v1/models and the
// health/readiness endpoints are registered without either wrapper.
func BuildMux(h *handler.Handler, limiter *ratelimit.Middleware, authConfig middleware.AuthMiddlewareConfig) http.Handler {
	// Compose each middleware chain once at startup instead of rebuilding the
	// rate-limit wrapper inside a closure on every request.
	chat := limitHandler(limiter, middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.ChatCompletionsHandle)))
	completions := limitHandler(limiter, middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.CompletionsHandle)))

	mux := http.NewServeMux()
	mux.Handle("/v1/chat/completions", chat)
	mux.Handle("/v1/completions", completions)
	mux.HandleFunc("/v1/models", h.ModelsHandle)
	mux.Handle("/api/v1/chat/completions", chat)
	mux.Handle("/api/v1/completions", completions)
	mux.HandleFunc("/health", h.HealthHandle)
	mux.HandleFunc("/healthz", h.HealthHandle)
	mux.HandleFunc("/readyz", h.HealthHandle)
	return middleware.CORSMiddleware(middleware.DefaultCORSConfig())(mux)
}
// buildRouter creates a router that uses the load-balancing strategy named by
// cfg.Router.Strategy, then registers every configured provider under its
// configured Name. It fails when no providers are configured, and it returns
// the first error reported by buildProvider unchanged.
func buildRouter(cfg *config.Config) (*router.Router, error) {
	providers := cfg.Providers
	if len(providers) == 0 {
		return nil, fmt.Errorf("at least one provider must be configured")
	}

	rt := router.NewRouter(resolveStrategy(cfg.Router.Strategy))
	for _, pc := range providers {
		p, buildErr := buildProvider(pc)
		if buildErr != nil {
			return nil, buildErr
		}
		rt.RegisterProvider(pc.Name, p)
	}
	return rt, nil
}
// buildLimiter selects the rate-limiting algorithm named by cfg.Algorithm and
// wraps it in a ratelimit.Middleware.
//
// "sliding_window" (case-insensitive, surrounding whitespace ignored) selects
// a sliding-window limiter over a one-minute window using cfg.DefaultRPM. Any
// other value, including an unrecognized one, falls back to the token-bucket
// limiter built from DefaultRPM, DefaultTPM and BurstMultiplier.
func buildLimiter(cfg config.RateLimitConfig) *ratelimit.Middleware {
	// Trim like resolveStrategy and buildTokenRuntime do, so a stray space in
	// the config does not silently select the token bucket.
	if strings.EqualFold(strings.TrimSpace(cfg.Algorithm), "sliding_window") {
		limiter := ratelimit.NewSlidingWindowLimiter(time.Minute, cfg.DefaultRPM)
		return ratelimit.NewMiddleware(limiter)
	}
	limiter := ratelimit.NewTokenBucketLimiter(cfg.DefaultRPM, cfg.DefaultTPM, cfg.BurstMultiplier)
	return ratelimit.NewMiddleware(limiter)
}
// buildAuditor returns the audit sink used by the auth middleware.
//
// When cfg.Database.Host is blank it returns an in-memory emitter. Otherwise
// it builds a postgres:// connection URL from cfg.Database and creates a
// database-backed emitter. Construction errors are wrapped with %w.
func buildAuditor(cfg config.Config) (middleware.AuditEmitter, error) {
	db := cfg.Database
	if strings.TrimSpace(db.Host) == "" {
		return middleware.NewMemoryAuditEmitter(), nil
	}
	// Build the URL with net/url so that credentials containing reserved
	// characters ('@', ':', '/', '%', ...) are percent-escaped, and
	// JoinHostPort so that IPv6 hosts are bracketed. Naive string formatting
	// yields a malformed DSN in both cases.
	dsnURL := &url.URL{
		Scheme: "postgres",
		User:   url.UserPassword(db.User, db.GetPassword()),
		Host:   net.JoinHostPort(db.Host, fmt.Sprintf("%d", db.Port)),
		Path:   "/" + db.Database,
		// NOTE(review): TLS is disabled for the audit database; confirm this
		// is intended outside local deployments.
		RawQuery: "sslmode=disable",
	}
	auditor, err := middleware.NewDatabaseAuditEmitter(dsnURL.String(), time.Now)
	if err != nil {
		return nil, fmt.Errorf("create database audit emitter: %w", err)
	}
	return auditor, nil
}
// buildTokenRuntime constructs the token verifier and status resolver selected
// by cfg.TokenRuntimeMode (whitespace-trimmed, case-insensitive):
//
//   - "" or "inmemory": an in-memory runtime.
//   - "remote_introspection": a remote runtime configured with
//     cfg.TokenRuntimeURL and an HTTP client with a bounded timeout.
//
// Any other mode returns an error.
func buildTokenRuntime(cfg config.AuthConfig) (interface {
	middleware.TokenVerifier
	middleware.TokenStatusResolver
}, error) {
	// http.DefaultClient has no timeout. Unless the remote runtime applies its
	// own deadline, a stalled introspection endpoint could hold request
	// goroutines indefinitely, so bound every call.
	const introspectionTimeout = 10 * time.Second

	switch strings.ToLower(strings.TrimSpace(cfg.TokenRuntimeMode)) {
	case "", "inmemory":
		return middleware.NewInMemoryTokenRuntime(time.Now), nil
	case "remote_introspection":
		client := &http.Client{Timeout: introspectionTimeout}
		return middleware.NewRemoteTokenRuntime(cfg.TokenRuntimeURL, client, time.Now), nil
	default:
		return nil, fmt.Errorf("unsupported token runtime mode: %s", cfg.TokenRuntimeMode)
	}
}
// resolveStrategy maps a configured strategy name (whitespace-trimmed and
// lower-cased) onto a router.LoadBalancerStrategy. It checks round-robin,
// weighted and availability in that order. Empty or unrecognized names fall
// back to router.StrategyLatency.
func resolveStrategy(strategy string) router.LoadBalancerStrategy {
	name := strings.ToLower(strings.TrimSpace(strategy))
	known := []router.LoadBalancerStrategy{
		router.StrategyRoundRobin,
		router.StrategyWeighted,
		router.StrategyAvailability,
	}
	for _, candidate := range known {
		if string(candidate) == name {
			return candidate
		}
	}
	return router.StrategyLatency
}
// normalizeConfig returns a copy of cfg with defaults filled in for unset
// fields. cfg is taken by value, so the caller's struct is untouched.
//
// "Unset" means the zero value. A blank or whitespace-only Server.Host also
// counts as unset. Non-zero values, including negative ones, are kept as-is.
//
// Defaults:
//   - Server: host "0.0.0.0", port 8080, read/write timeouts 30s, idle 120s.
//   - Router: strategy router.StrategyLatency.
//   - RateLimit: 60 RPM, 60000 TPM, burst multiplier 1.5, "token_bucket".
//   - Auth: token runtime mode "inmemory".
func normalizeConfig(cfg config.Config) config.Config {
	srv := &cfg.Server
	if strings.TrimSpace(srv.Host) == "" {
		srv.Host = "0.0.0.0"
	}
	if srv.Port == 0 {
		srv.Port = 8080
	}
	if srv.ReadTimeout == 0 {
		srv.ReadTimeout = 30 * time.Second
	}
	if srv.WriteTimeout == 0 {
		srv.WriteTimeout = 30 * time.Second
	}
	if srv.IdleTimeout == 0 {
		srv.IdleTimeout = 120 * time.Second
	}

	if cfg.Router.Strategy == "" {
		cfg.Router.Strategy = string(router.StrategyLatency)
	}

	rl := &cfg.RateLimit
	if rl.DefaultRPM == 0 {
		rl.DefaultRPM = 60
	}
	if rl.DefaultTPM == 0 {
		rl.DefaultTPM = 60000
	}
	if rl.BurstMultiplier == 0 {
		rl.BurstMultiplier = 1.5
	}
	if rl.Algorithm == "" {
		rl.Algorithm = "token_bucket"
	}

	if cfg.Auth.TokenRuntimeMode == "" {
		cfg.Auth.TokenRuntimeMode = "inmemory"
	}
	return cfg
}
// limitHandler wraps next with the rate-limiting middleware. A nil limiter
// disables rate limiting, and next is then returned unchanged.
func limitHandler(limiter *ratelimit.Middleware, next http.Handler) http.Handler {
	if limiter == nil {
		return next
	}
	// Wrap once at construction time. The previous form re-invoked
	// limiter.Limit, and allocated a new wrapper, on every request.
	return http.HandlerFunc(limiter.Limit(next.ServeHTTP))
}