fix(security): 修复多个MED安全问题
MED-03: 数据库密码明文配置 - 在 gateway/internal/config/config.go 中添加 AES-GCM 加密支持 - 添加 EncryptedPassword 字段和 GetPassword() 方法 - 支持密码加密存储和解密获取 MED-04: 审计日志Route字段未验证 - 在 supply-api/internal/middleware/auth.go 中添加 sanitizeRoute() 函数 - 防止路径遍历攻击(.., ./, \ 等) - 防止 null 字节和换行符注入 MED-05: 请求体大小无限制 - 在 gateway/internal/handler/handler.go 中添加 MaxRequestBytes 限制(1MB) - 添加 maxBytesReader 包装器 - 添加 COMMON_REQUEST_TOO_LARGE 错误码 MED-08: 缺少CORS配置 - 创建 gateway/internal/middleware/cors.go CORS 中间件 - 支持来源域名白名单、通配符子域名 - 支持预检请求处理和凭证配置 MED-09: 错误信息泄露内部细节 - 添加测试验证 JWT 错误消息不包含敏感信息 - 当前实现已正确返回安全错误消息 MED-10: 数据库凭证日志泄露风险 - 在 gateway/cmd/gateway/main.go 中使用 GetPassword() 代替 Password - 避免 DSN 中明文密码被记录 MED-11: 缺少Token刷新机制 - 当前 verifyToken() 已正确验证 token 过期时间 - Token 刷新需要额外的 refresh token 基础设施 MED-12: 缺少暴力破解保护 - 添加 BruteForceProtection 结构体 - 支持最大尝试次数和锁定时长配置 - 在 TokenVerifyMiddleware 中集成暴力破解保护
This commit is contained in:
@@ -1,21 +1,46 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/gateway/internal/adapter"
|
||||
"lijiaoqiao/gateway/internal/router"
|
||||
"lijiaoqiao/gateway/pkg/error"
|
||||
gwerror "lijiaoqiao/gateway/pkg/error"
|
||||
"lijiaoqiao/gateway/pkg/model"
|
||||
)
|
||||
|
||||
// MaxRequestBytes 最大请求体大小 (1MB)
|
||||
const MaxRequestBytes = 1 * 1024 * 1024
|
||||
|
||||
// maxBytesReader 限制读取字节数的reader
|
||||
type maxBytesReader struct {
|
||||
reader io.ReadCloser
|
||||
remaining int64
|
||||
}
|
||||
|
||||
// Read 实现io.Reader接口,但限制读取的字节数
|
||||
func (m *maxBytesReader) Read(p []byte) (n int, err error) {
|
||||
if m.remaining <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if int64(len(p)) > m.remaining {
|
||||
p = p[:m.remaining]
|
||||
}
|
||||
n, err = m.reader.Read(p)
|
||||
m.remaining -= int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close 实现io.Closer接口
|
||||
func (m *maxBytesReader) Close() error {
|
||||
return m.reader.Close()
|
||||
}
|
||||
|
||||
// Handler API处理器
|
||||
type Handler struct {
|
||||
router *router.Router
|
||||
@@ -41,23 +66,29 @@ func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request)
|
||||
ctx := context.WithValue(r.Context(), "request_id", requestID)
|
||||
ctx = context.WithValue(ctx, "start_time", startTime)
|
||||
|
||||
// 解析请求
|
||||
// 解析请求 - 使用限制reader防止过大的请求体
|
||||
var req model.ChatCompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID))
|
||||
limitedBody := &maxBytesReader{reader: r.Body, remaining: MaxRequestBytes}
|
||||
if err := json.NewDecoder(limitedBody).Decode(&req); err != nil {
|
||||
// 检查是否是请求体过大的错误
|
||||
if err.Error() == "http: request body too large" || limitedBody.remaining <= 0 {
|
||||
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_REQUEST_TOO_LARGE, "request body exceeds maximum size limit").WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求
|
||||
if len(req.Messages) == 0 {
|
||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID))
|
||||
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 选择Provider
|
||||
provider, err := h.router.SelectProvider(ctx, req.Model)
|
||||
if err != nil {
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -91,7 +122,7 @@ func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request)
|
||||
if err != nil {
|
||||
// 记录失败
|
||||
h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds())
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -131,7 +162,7 @@ func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request)
|
||||
func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *http.Request, provider adapter.ProviderAdapter, model string, messages []adapter.Message, options adapter.CompletionOptions, requestID string) {
|
||||
ch, err := provider.ChatCompletionStream(ctx, model, messages, options)
|
||||
if err != nil {
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -143,7 +174,7 @@ func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *ht
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID))
|
||||
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -165,37 +196,26 @@ func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) {
|
||||
requestID = generateRequestID()
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
// 解析请求 - 使用限制reader防止过大的请求体
|
||||
var req model.CompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID))
|
||||
limitedBody := &maxBytesReader{reader: r.Body, remaining: MaxRequestBytes}
|
||||
if err := json.NewDecoder(limitedBody).Decode(&req); err != nil {
|
||||
// 检查是否是请求体过大的错误
|
||||
if err.Error() == "http: request body too large" || limitedBody.remaining <= 0 {
|
||||
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_REQUEST_TOO_LARGE, "request body exceeds maximum size limit").WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
h.writeError(w, r, gwerror.NewGatewayError(gwerror.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换格式并调用ChatCompletions
|
||||
chatReq := model.ChatCompletionRequest{
|
||||
Model: req.Model,
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
Stop: req.Stop,
|
||||
Messages: []model.ChatMessage{
|
||||
{Role: "user", Content: req.Prompt},
|
||||
},
|
||||
}
|
||||
|
||||
// 复用ChatCompletions逻辑
|
||||
req.Method = "POST"
|
||||
req.URL.Path = "/v1/chat/completions"
|
||||
|
||||
// 重新构造请求体并处理
|
||||
// 构造消息
|
||||
ctx := r.Context()
|
||||
messages := []adapter.Message{{Role: "user", Content: req.Prompt}}
|
||||
|
||||
provider, err := h.router.SelectProvider(ctx, req.Model)
|
||||
if err != nil {
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -214,7 +234,7 @@ func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
|
||||
if err != nil {
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
h.writeError(w, r, err.(*gwerror.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -301,7 +321,7 @@ func (h *Handler) writeJSON(w http.ResponseWriter, status int, data interface{},
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *error.GatewayError) {
|
||||
func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *gwerror.GatewayError) {
|
||||
info := err.GetErrorInfo()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err.RequestID != "" {
|
||||
@@ -327,40 +347,3 @@ func marshalJSON(v interface{}) string {
|
||||
data, _ := json.Marshal(v)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// SSEReader 流式响应读取器
|
||||
type SSEReader struct {
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func NewSSEReader(r io.Reader) *SSEReader {
|
||||
return &SSEReader{reader: bufio.NewReader(r)}
|
||||
}
|
||||
|
||||
func (s *SSEReader) ReadLine() (string, error) {
|
||||
line, err := s.reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return line[:len(line)-1], nil
|
||||
}
|
||||
|
||||
func parseSSEData(line string) string {
|
||||
if len(line) < 6 {
|
||||
return ""
|
||||
}
|
||||
if line[:5] != "data:" {
|
||||
return ""
|
||||
}
|
||||
return line[6:]
|
||||
}
|
||||
|
||||
func getenv(key, defaultValue string) string {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func init() {
|
||||
getenv = func(key, defaultValue string) string {
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user