223 lines
6.2 KiB
Go
223 lines
6.2 KiB
Go
package app
|
|
|
|
import (
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"lijiaoqiao/gateway/internal/config"
	"lijiaoqiao/gateway/internal/handler"
	"lijiaoqiao/gateway/internal/middleware"
	"lijiaoqiao/gateway/internal/ratelimit"
	"lijiaoqiao/gateway/internal/router"
)
|
|
|
|
func BuildServer(cfg *config.Config) (*http.Server, error) {
|
|
if cfg == nil {
|
|
return nil, fmt.Errorf("config is required")
|
|
}
|
|
|
|
normalized := normalizeConfig(*cfg)
|
|
|
|
if err := config.ValidateAuthConfig(normalized.Auth); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r, err := buildRouter(&normalized)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
limiter := buildLimiter(normalized.RateLimit)
|
|
auditor, err := buildAuditor(normalized)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tokenRuntime, err := buildTokenRuntime(normalized.Auth)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
authConfig := middleware.AuthMiddlewareConfig{
|
|
Verifier: tokenRuntime,
|
|
StatusResolver: tokenRuntime,
|
|
Authorizer: middleware.NewScopeRoleAuthorizer(),
|
|
Auditor: auditor,
|
|
ProtectedPrefixes: []string{
|
|
"/v1/chat/completions",
|
|
"/v1/completions",
|
|
"/api/v1/chat/completions",
|
|
"/api/v1/completions",
|
|
"/api/v1/supply",
|
|
"/api/v1/platform",
|
|
},
|
|
ExcludedPrefixes: []string{"/health", "/healthz", "/metrics", "/readyz"},
|
|
Now: time.Now,
|
|
}
|
|
|
|
h := handler.NewHandler(r)
|
|
server := &http.Server{
|
|
Addr: fmt.Sprintf("%s:%d", normalized.Server.Host, normalized.Server.Port),
|
|
Handler: BuildMux(h, limiter, authConfig),
|
|
ReadTimeout: normalized.Server.ReadTimeout,
|
|
WriteTimeout: normalized.Server.WriteTimeout,
|
|
IdleTimeout: normalized.Server.IdleTimeout,
|
|
}
|
|
|
|
return server, nil
|
|
}
|
|
|
|
func BuildMux(h *handler.Handler, limiter *ratelimit.Middleware, authConfig middleware.AuthMiddlewareConfig) http.Handler {
|
|
mux := http.NewServeMux()
|
|
|
|
chatHandler := middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.ChatCompletionsHandle))
|
|
completionsHandler := middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.CompletionsHandle))
|
|
|
|
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
|
limitHandler(limiter, chatHandler).ServeHTTP(w, r)
|
|
})
|
|
|
|
mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
|
limitHandler(limiter, completionsHandler).ServeHTTP(w, r)
|
|
})
|
|
|
|
mux.HandleFunc("/v1/models", h.ModelsHandle)
|
|
|
|
mux.HandleFunc("/api/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
|
limitHandler(limiter, chatHandler).ServeHTTP(w, r)
|
|
})
|
|
mux.HandleFunc("/api/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
|
limitHandler(limiter, completionsHandler).ServeHTTP(w, r)
|
|
})
|
|
|
|
mux.HandleFunc("/health", h.HealthHandle)
|
|
mux.HandleFunc("/healthz", h.HealthHandle)
|
|
mux.HandleFunc("/readyz", h.HealthHandle)
|
|
|
|
return middleware.CORSMiddleware(middleware.DefaultCORSConfig())(mux)
|
|
}
|
|
|
|
func buildRouter(cfg *config.Config) (*router.Router, error) {
|
|
if len(cfg.Providers) == 0 {
|
|
return nil, fmt.Errorf("at least one provider must be configured")
|
|
}
|
|
|
|
r := router.NewRouter(resolveStrategy(cfg.Router.Strategy))
|
|
for _, providerCfg := range cfg.Providers {
|
|
provider, err := buildProvider(providerCfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
r.RegisterProvider(providerCfg.Name, provider)
|
|
}
|
|
|
|
return r, nil
|
|
}
|
|
|
|
func buildLimiter(cfg config.RateLimitConfig) *ratelimit.Middleware {
|
|
if strings.EqualFold(cfg.Algorithm, "sliding_window") {
|
|
limiter := ratelimit.NewSlidingWindowLimiter(time.Minute, cfg.DefaultRPM)
|
|
return ratelimit.NewMiddleware(limiter)
|
|
}
|
|
|
|
limiter := ratelimit.NewTokenBucketLimiter(cfg.DefaultRPM, cfg.DefaultTPM, cfg.BurstMultiplier)
|
|
return ratelimit.NewMiddleware(limiter)
|
|
}
|
|
|
|
func buildAuditor(cfg config.Config) (middleware.AuditEmitter, error) {
|
|
if strings.TrimSpace(cfg.Database.Host) == "" {
|
|
return middleware.NewMemoryAuditEmitter(), nil
|
|
}
|
|
|
|
dsn := fmt.Sprintf(
|
|
"postgres://%s:%s@%s:%d/%s?sslmode=disable",
|
|
cfg.Database.User,
|
|
cfg.Database.GetPassword(),
|
|
cfg.Database.Host,
|
|
cfg.Database.Port,
|
|
cfg.Database.Database,
|
|
)
|
|
|
|
auditor, err := middleware.NewDatabaseAuditEmitter(dsn, time.Now)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create database audit emitter: %w", err)
|
|
}
|
|
|
|
return auditor, nil
|
|
}
|
|
|
|
func buildTokenRuntime(cfg config.AuthConfig) (interface {
|
|
middleware.TokenVerifier
|
|
middleware.TokenStatusResolver
|
|
}, error) {
|
|
switch strings.ToLower(strings.TrimSpace(cfg.TokenRuntimeMode)) {
|
|
case "", "inmemory":
|
|
return middleware.NewInMemoryTokenRuntime(time.Now), nil
|
|
case "remote_introspection":
|
|
return middleware.NewRemoteTokenRuntime(cfg.TokenRuntimeURL, http.DefaultClient, time.Now), nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported token runtime mode: %s", cfg.TokenRuntimeMode)
|
|
}
|
|
}
|
|
|
|
func resolveStrategy(strategy string) router.LoadBalancerStrategy {
|
|
switch strings.ToLower(strings.TrimSpace(strategy)) {
|
|
case string(router.StrategyRoundRobin):
|
|
return router.StrategyRoundRobin
|
|
case string(router.StrategyWeighted):
|
|
return router.StrategyWeighted
|
|
case string(router.StrategyAvailability):
|
|
return router.StrategyAvailability
|
|
default:
|
|
return router.StrategyLatency
|
|
}
|
|
}
|
|
|
|
func normalizeConfig(cfg config.Config) config.Config {
|
|
if strings.TrimSpace(cfg.Server.Host) == "" {
|
|
cfg.Server.Host = "0.0.0.0"
|
|
}
|
|
if cfg.Server.Port == 0 {
|
|
cfg.Server.Port = 8080
|
|
}
|
|
if cfg.Server.ReadTimeout == 0 {
|
|
cfg.Server.ReadTimeout = 30 * time.Second
|
|
}
|
|
if cfg.Server.WriteTimeout == 0 {
|
|
cfg.Server.WriteTimeout = 30 * time.Second
|
|
}
|
|
if cfg.Server.IdleTimeout == 0 {
|
|
cfg.Server.IdleTimeout = 120 * time.Second
|
|
}
|
|
if cfg.Router.Strategy == "" {
|
|
cfg.Router.Strategy = string(router.StrategyLatency)
|
|
}
|
|
if cfg.RateLimit.DefaultRPM == 0 {
|
|
cfg.RateLimit.DefaultRPM = 60
|
|
}
|
|
if cfg.RateLimit.DefaultTPM == 0 {
|
|
cfg.RateLimit.DefaultTPM = 60000
|
|
}
|
|
if cfg.RateLimit.BurstMultiplier == 0 {
|
|
cfg.RateLimit.BurstMultiplier = 1.5
|
|
}
|
|
if cfg.RateLimit.Algorithm == "" {
|
|
cfg.RateLimit.Algorithm = "token_bucket"
|
|
}
|
|
if cfg.Auth.TokenRuntimeMode == "" {
|
|
cfg.Auth.TokenRuntimeMode = "inmemory"
|
|
}
|
|
return cfg
|
|
}
|
|
|
|
func limitHandler(limiter *ratelimit.Middleware, next http.Handler) http.Handler {
|
|
if limiter == nil {
|
|
return next
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
limiter.Limit(next.ServeHTTP)(w, r)
|
|
})
|
|
}
|