Files
lijiaoqiao/supply-api/internal/middleware/tracing.go

188 lines
5.0 KiB
Go
Raw Normal View History

package middleware
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
)
// ==================== P1-006 Distributed tracing integration ====================
// Implementation of the W3C Trace Context standard.
// Reference: https://www.w3.org/TR/trace-context/

// TraceContext holds the trace context associated with a single request.
// The fields are plain strings; the type itself performs no validation.
type TraceContext struct {
	TraceID    string // trace-id: 32 hex characters
	SpanID     string // span id (the traceparent "parent-id"): 16 hex characters
	TraceFlags string // trace-flags: 2 hex characters ("01" = sampled, "00" = not sampled)
}
// W3C Trace Context header format:
//
//	traceparent: 00-{trace-id}-{span-id}-{trace-flags}
//
// Example: 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01
const (
	// TraceContextVersion is the traceparent version this package emits in
	// FormatTraceParent and the only version ParseTraceParent accepts.
	TraceContextVersion = "00"
	// TraceFlagSampled is the trace-flags value marking a sampled trace.
	TraceFlagSampled = "01"
	// TraceFlagNotSampled is the trace-flags value marking an unsampled trace.
	TraceFlagNotSampled = "00"
)
// traceContextKey is the private key type under which a *TraceContext is
// stored in a context.Context. Because the type is unexported, no other
// package can construct an equal key, so values cannot collide.
type traceContextKey struct{}

// WithTraceContext returns a derived copy of ctx that carries tc.
// The parent ctx is left unchanged.
func WithTraceContext(ctx context.Context, tc *TraceContext) context.Context {
	key := traceContextKey{}
	derived := context.WithValue(ctx, key, tc)
	return derived
}
// GetTraceContext returns the *TraceContext stored in ctx by WithTraceContext.
//
// The boolean is true only when a non-nil trace context is present, so a
// caller that checks ok may dereference the result safely. A typed nil
// stored via WithTraceContext(ctx, nil) is reported as absent (nil, false);
// otherwise the type assertion would succeed and hand back (nil, true).
func GetTraceContext(ctx context.Context) (*TraceContext, bool) {
	tc, ok := ctx.Value(traceContextKey{}).(*TraceContext)
	if !ok || tc == nil {
		return nil, false
	}
	return tc, true
}
// ParseTraceParent 解析traceparent header
func ParseTraceParent(traceParent string) (*TraceContext, error) {
if traceParent == "" {
return nil, fmt.Errorf("traceparent header is empty")
}
// 格式: 00-{trace-id}-{span-id}-{trace-flags}
// 长度检查
if len(traceParent) < 55 { // 00- + 32 + - + 16 + - + 02
return nil, fmt.Errorf("invalid traceparent format")
}
// 检查版本
version := traceParent[0:2]
if version != TraceContextVersion {
return nil, fmt.Errorf("unsupported trace context version: %s", version)
}
// 提取各部分
// 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01
// 0123456789012345678901234567890123456789012345678901234
// 0 1 2 3 4 5
traceID := traceParent[3:35]
spanID := traceParent[36:52]
traceFlags := traceParent[53:55]
// 验证trace-id长度 (必须是32字符)
if len(traceID) != 32 {
return nil, fmt.Errorf("invalid trace-id length: %d", len(traceID))
}
// 验证span-id长度 (必须是16字符)
if len(spanID) != 16 {
return nil, fmt.Errorf("invalid span-id length: %d", len(spanID))
}
// 验证trace-flags
if traceFlags != TraceFlagSampled && traceFlags != TraceFlagNotSampled {
return nil, fmt.Errorf("invalid trace-flags: %s", traceFlags)
}
return &TraceContext{
TraceID: traceID,
SpanID: spanID,
TraceFlags: traceFlags,
}, nil
}
// FormatTraceParent renders tc as a traceparent header value in the form
// "{version}-{trace-id}-{span-id}-{trace-flags}". Fields are emitted
// verbatim; no validation is performed here.
func (tc *TraceContext) FormatTraceParent() string {
	return fmt.Sprintf(
		"%s-%s-%s-%s",
		TraceContextVersion,
		tc.TraceID,
		tc.SpanID,
		tc.TraceFlags,
	)
}
// GenerateTraceID returns a new trace-id of 32 hex characters
// (16 bytes, two hex digits each) produced by generateRandomHex.
func GenerateTraceID() string {
	const traceIDHexLen = 16 * 2
	return generateRandomHex(traceIDHexLen)
}
// GenerateSpanID returns a new span-id of 16 hex characters
// (8 bytes, two hex digits each) produced by generateRandomHex.
func GenerateSpanID() string {
	const spanIDHexLen = 8 * 2
	return generateRandomHex(spanIDHexLen)
}
// NewTraceContext starts a brand-new trace. It fills in a freshly generated
// trace-id and span-id and sets the flags to TraceFlagSampled.
func NewTraceContext() *TraceContext {
	tc := new(TraceContext)
	tc.TraceID = GenerateTraceID()
	tc.SpanID = GenerateSpanID()
	tc.TraceFlags = TraceFlagSampled
	return tc
}
// NewChildSpanContext returns a new context in the same trace as tc. The
// trace-id and trace-flags are inherited, and the span-id is replaced by a
// freshly generated one. tc itself is not modified.
func (tc *TraceContext) NewChildSpanContext() *TraceContext {
	child := *tc
	child.SpanID = GenerateSpanID()
	return &child
}
// IsSampled reports whether tc's trace-flags are exactly TraceFlagSampled.
func (tc *TraceContext) IsSampled() bool {
	switch tc.TraceFlags {
	case TraceFlagSampled:
		return true
	default:
		return false
	}
}
// LogFields returns the trace identifiers as key/value pairs for structured
// logging: "trace_id" maps to tc.TraceID and "span_id" to tc.SpanID.
// A new map is returned on every call, so callers may modify it freely.
func (tc *TraceContext) LogFields() map[string]string {
	return map[string]string{
		"trace_id": tc.TraceID,
		"span_id":  tc.SpanID,
	}
}
// generateRandomHex returns a string of exactly length lowercase hex
// characters built from crypto/rand bytes.
//
// NOTE(review): if rand.Read reports an error, the buffer is filled with the
// fixed pattern 0, 7, 14, ... instead. Every ID generated while the failure
// persists is then identical, so traces would collide. Since Go 1.24,
// crypto/rand.Read crashes the program rather than returning an error, so
// this branch is only reachable on older toolchains. Confirm the module's
// Go version before relying on it.
func generateRandomHex(length int) string {
	// Two hex digits per byte; round up so odd lengths still get enough digits.
	raw := make([]byte, (length+1)/2)
	if _, err := rand.Read(raw); err != nil {
		for idx := 0; idx < len(raw); idx++ {
			raw[idx] = byte(idx * 7 % 256)
		}
	}
	digits := hex.EncodeToString(raw)
	return digits[:length]
}
// TracingMiddleware is HTTP middleware (P1-006) that attaches a *TraceContext
// to every request's context via WithTraceContext.
//
// If the request carries a traceparent header that ParseTraceParent accepts,
// the request joins that trace. The trace-id and trace-flags are kept, and a
// new span-id is generated for the work done by this service. The incoming
// parent-id identifies the caller's span, not ours, so reusing it would make
// this service indistinguishable from its caller in traces and logs. If the
// header is missing or invalid, a new trace is started with NewTraceContext.
//
// NOTE(review): TraceContext has no field for the parent span id, so the
// caller's parent-id is not retained past this point.
func TracingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var tc *TraceContext
		// ParseTraceParent returns an error for an empty or malformed header,
		// in which case a fresh trace is started rather than propagating bad data.
		if parent, err := ParseTraceParent(r.Header.Get("traceparent")); err == nil {
			tc = parent.NewChildSpanContext()
		} else {
			tc = NewTraceContext()
		}
		next.ServeHTTP(w, r.WithContext(WithTraceContext(r.Context(), tc)))
	})
}