feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers
This commit is contained in:
232
internal/util/logredact/redact.go
Normal file
232
internal/util/logredact/redact.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package logredact
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// maxRedactDepth limits recursion depth to guard against stack overflow on
// deeply nested payloads.
const maxRedactDepth = 32

// defaultSensitiveKeys is the built-in set of (lowercase) key names whose
// values are always redacted.
var defaultSensitiveKeys = map[string]struct{}{
	"authorization_code": {},
	"code": {},
	"code_verifier": {},
	"access_token": {},
	"refresh_token": {},
	"id_token": {},
	"client_secret": {},
	"password": {},
}

// defaultSensitiveKeyList mirrors defaultSensitiveKeys as an ordered slice,
// used when building regex alternations.
var defaultSensitiveKeyList = []string{
	"authorization_code",
	"code",
	"code_verifier",
	"access_token",
	"refresh_token",
	"id_token",
	"client_secret",
	"password",
}

// textRedactPatterns bundles the compiled regexes RedactText applies for one
// combination of sensitive keys.
type textRedactPatterns struct {
	reJSONLike *regexp.Regexp // "key":"value" fragments
	reQueryLike *regexp.Regexp // key=value fragments
	rePlain *regexp.Regexp // key: value / key = value fragments
}

var (
	// reGOCSPX matches GOCSPX-prefixed secrets (Google OAuth client secret format).
	reGOCSPX = regexp.MustCompile(`GOCSPX-[0-9A-Za-z_-]{24,}`)
	// reAIza matches AIza-prefixed tokens (Google API key format).
	reAIza = regexp.MustCompile(`AIza[0-9A-Za-z_-]{35}`)

	// defaultTextRedactPatterns covers only the default sensitive keys.
	defaultTextRedactPatterns = compileTextRedactPatterns(nil)
	// extraTextPatternCache caches compiled pattern sets keyed by the sorted,
	// comma-joined extra keys. map[string]*textRedactPatterns
	extraTextPatternCache sync.Map
)
|
||||
|
||||
func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
|
||||
if input == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
keys := buildKeySet(extraKeys)
|
||||
redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any)
|
||||
if !ok {
|
||||
return map[string]any{}
|
||||
}
|
||||
return redacted
|
||||
}
|
||||
|
||||
func RedactJSON(raw []byte, extraKeys ...string) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return "<non-json payload redacted>"
|
||||
}
|
||||
keys := buildKeySet(extraKeys)
|
||||
redacted := redactValueWithDepth(value, keys, 0)
|
||||
encoded, err := json.Marshal(redacted)
|
||||
if err != nil {
|
||||
return "<redacted>"
|
||||
}
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
// RedactText 对非结构化文本做轻量脱敏。
|
||||
//
|
||||
// 规则:
|
||||
// - 如果文本本身是 JSON,则按 RedactJSON 处理。
|
||||
// - 否则尝试对常见 key=value / key:"value" 片段做脱敏。
|
||||
//
|
||||
// 注意:该函数用于日志/错误信息兜底,不保证覆盖所有格式。
|
||||
func RedactText(input string, extraKeys ...string) string {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
raw := []byte(input)
|
||||
if json.Valid(raw) {
|
||||
return RedactJSON(raw, extraKeys...)
|
||||
}
|
||||
|
||||
patterns := getTextRedactPatterns(extraKeys)
|
||||
|
||||
out := input
|
||||
out = reGOCSPX.ReplaceAllString(out, "GOCSPX-***")
|
||||
out = reAIza.ReplaceAllString(out, "AIza***")
|
||||
out = patterns.reJSONLike.ReplaceAllString(out, `$1***$3`)
|
||||
out = patterns.reQueryLike.ReplaceAllString(out, `$1=***`)
|
||||
out = patterns.rePlain.ReplaceAllString(out, `$1$2***`)
|
||||
return out
|
||||
}
|
||||
|
||||
// compileTextRedactPatterns builds the three text-redaction regexes covering
// the default sensitive keys plus extraKeys (expected to be normalized by the
// caller; see buildKeyAlternation).
func compileTextRedactPatterns(extraKeys []string) *textRedactPatterns {
	keyAlt := buildKeyAlternation(extraKeys)
	return &textRedactPatterns{
		// JSON-like: "access_token":"..."
		reJSONLike: regexp.MustCompile(`(?i)("(?:` + keyAlt + `)"\s*:\s*")([^"]*)(")`),
		// Query-like: access_token=...
		reQueryLike: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))=([^&\s]+)`),
		// Plain: access_token: ... / access_token = ...
		rePlain: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))\b(\s*[:=]\s*)([^,\s]+)`),
	}
}
|
||||
|
||||
func getTextRedactPatterns(extraKeys []string) *textRedactPatterns {
|
||||
normalizedExtraKeys := normalizeAndSortExtraKeys(extraKeys)
|
||||
if len(normalizedExtraKeys) == 0 {
|
||||
return defaultTextRedactPatterns
|
||||
}
|
||||
|
||||
cacheKey := strings.Join(normalizedExtraKeys, ",")
|
||||
if cached, ok := extraTextPatternCache.Load(cacheKey); ok {
|
||||
if patterns, ok := cached.(*textRedactPatterns); ok {
|
||||
return patterns
|
||||
}
|
||||
}
|
||||
|
||||
compiled := compileTextRedactPatterns(normalizedExtraKeys)
|
||||
actual, _ := extraTextPatternCache.LoadOrStore(cacheKey, compiled)
|
||||
if patterns, ok := actual.(*textRedactPatterns); ok {
|
||||
return patterns
|
||||
}
|
||||
return compiled
|
||||
}
|
||||
|
||||
func normalizeAndSortExtraKeys(extraKeys []string) []string {
|
||||
if len(extraKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(extraKeys))
|
||||
keys := make([]string, 0, len(extraKeys))
|
||||
for _, key := range extraKeys {
|
||||
normalized := normalizeKey(key)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
keys = append(keys, normalized)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func buildKeyAlternation(extraKeys []string) string {
|
||||
seen := make(map[string]struct{}, len(defaultSensitiveKeyList)+len(extraKeys))
|
||||
keys := make([]string, 0, len(defaultSensitiveKeyList)+len(extraKeys))
|
||||
for _, k := range defaultSensitiveKeyList {
|
||||
seen[k] = struct{}{}
|
||||
keys = append(keys, regexp.QuoteMeta(k))
|
||||
}
|
||||
for _, k := range extraKeys {
|
||||
n := normalizeKey(k)
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
keys = append(keys, regexp.QuoteMeta(n))
|
||||
}
|
||||
return strings.Join(keys, "|")
|
||||
}
|
||||
|
||||
func buildKeySet(extraKeys []string) map[string]struct{} {
|
||||
keys := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys))
|
||||
for k := range defaultSensitiveKeys {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for _, key := range extraKeys {
|
||||
normalized := normalizeKey(key)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
keys[normalized] = struct{}{}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any {
|
||||
if depth > maxRedactDepth {
|
||||
return "<depth limit exceeded>"
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(v))
|
||||
for k, val := range v {
|
||||
if isSensitiveKey(k, keys) {
|
||||
out[k] = "***"
|
||||
continue
|
||||
}
|
||||
out[k] = redactValueWithDepth(val, keys, depth+1)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
out[i] = redactValueWithDepth(item, keys, depth+1)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func isSensitiveKey(key string, keys map[string]struct{}) bool {
|
||||
_, ok := keys[normalizeKey(key)]
|
||||
return ok
|
||||
}
|
||||
|
||||
// normalizeKey trims surrounding whitespace from key and lowercases it.
func normalizeKey(key string) string {
	trimmed := strings.TrimSpace(key)
	return strings.ToLower(trimmed)
}
|
||||
84
internal/util/logredact/redact_test.go
Normal file
84
internal/util/logredact/redact_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package logredact
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestRedactText_JSONLike verifies that a JSON payload passed to RedactText
// has its sensitive token values masked while other values survive.
func TestRedactText_JSONLike(t *testing.T) {
	in := `{"access_token":"ya29.a0AfH6SMDUMMY","refresh_token":"1//0gDUMMY","other":"ok"}`
	out := RedactText(in)
	if out == in {
		t.Fatalf("expected redaction, got unchanged")
	}
	if want := `"access_token":"***"`; !strings.Contains(out, want) {
		t.Fatalf("expected %q in %q", want, out)
	}
	if want := `"refresh_token":"***"`; !strings.Contains(out, want) {
		t.Fatalf("expected %q in %q", want, out)
	}
}
|
||||
|
||||
// TestRedactText_QueryLike verifies that key=value token fragments are masked.
func TestRedactText_QueryLike(t *testing.T) {
	in := "access_token=ya29.a0AfH6SMDUMMY refresh_token=1//0gDUMMY"
	out := RedactText(in)
	if strings.Contains(out, "ya29") || strings.Contains(out, "1//0") {
		t.Fatalf("expected tokens redacted, got %q", out)
	}
}
|
||||
|
||||
// TestRedactText_GOCSPX verifies that a client_secret value is masked by the
// key-based rule even when the GOCSPX token-format rule does not match.
func TestRedactText_GOCSPX(t *testing.T) {
	in := "client_secret=GOCSPX-your-client-secret"
	out := RedactText(in)
	if strings.Contains(out, "your-client-secret") {
		t.Fatalf("expected secret redacted, got %q", out)
	}
	if !strings.Contains(out, "client_secret=***") {
		t.Fatalf("expected key redacted, got %q", out)
	}
}
|
||||
|
||||
// TestRedactText_ExtraKeyCacheUsesNormalizedSortedKey verifies that differing
// spellings of the same extra key collapse to one cached pattern set.
func TestRedactText_ExtraKeyCacheUsesNormalizedSortedKey(t *testing.T) {
	clearExtraTextPatternCache()

	out1 := RedactText("custom_secret=abc", "Custom_Secret", " custom_secret ")
	out2 := RedactText("custom_secret=xyz", "custom_secret")
	if !strings.Contains(out1, "custom_secret=***") {
		t.Fatalf("expected custom key redacted in first call, got %q", out1)
	}
	if !strings.Contains(out2, "custom_secret=***") {
		t.Fatalf("expected custom key redacted in second call, got %q", out2)
	}

	if got := countExtraTextPatternCacheEntries(); got != 1 {
		t.Fatalf("expected 1 cached pattern set, got %d", got)
	}
}
|
||||
|
||||
// TestRedactText_DefaultPathDoesNotUseExtraCache verifies that calls with no
// extra keys use the precompiled default patterns without touching the cache.
func TestRedactText_DefaultPathDoesNotUseExtraCache(t *testing.T) {
	clearExtraTextPatternCache()

	out := RedactText("access_token=abc")
	if !strings.Contains(out, "access_token=***") {
		t.Fatalf("expected default key redacted, got %q", out)
	}
	if got := countExtraTextPatternCacheEntries(); got != 0 {
		t.Fatalf("expected extra cache to remain empty, got %d", got)
	}
}
|
||||
|
||||
func clearExtraTextPatternCache() {
|
||||
extraTextPatternCache.Range(func(key, value any) bool {
|
||||
extraTextPatternCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func countExtraTextPatternCacheEntries() int {
|
||||
count := 0
|
||||
extraTextPatternCache.Range(func(key, value any) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
117
internal/util/responseheaders/responseheaders.go
Normal file
117
internal/util/responseheaders/responseheaders.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package responseheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
// defaultAllowed is the allowlist of response headers that may be passed
// through to clients.
// Note: the following headers are handled automatically by Go's HTTP stack and
// must not be set manually:
//   - content-length: set by the ResponseWriter from the data actually written
//   - transfer-encoding: added/removed by the HTTP library as needed
//   - connection: managed by the HTTP library for connection reuse
var defaultAllowed = map[string]struct{}{
	"content-type": {},
	"content-encoding": {},
	"content-language": {},
	"cache-control": {},
	"etag": {},
	"last-modified": {},
	"expires": {},
	"vary": {},
	"date": {},
	"x-request-id": {},
	"x-ratelimit-limit-requests": {},
	"x-ratelimit-limit-tokens": {},
	"x-ratelimit-remaining-requests": {},
	"x-ratelimit-remaining-tokens": {},
	"x-ratelimit-reset-requests": {},
	"x-ratelimit-reset-tokens": {},
	"retry-after": {},
	"location": {},
	"www-authenticate": {},
}

// hopByHopHeaders are hop-by-hop headers that are always skipped; the HTTP
// library manages them automatically.
var hopByHopHeaders = map[string]struct{}{
	"content-length": {},
	"transfer-encoding": {},
	"connection": {},
}

// CompiledHeaderFilter holds the precomputed lowercase header sets derived
// from a ResponseHeaderConfig.
type CompiledHeaderFilter struct {
	allowed map[string]struct{} // headers permitted to pass through
	forceRemove map[string]struct{} // headers stripped even when allowed
}

// defaultCompiledHeaderFilter is the allowlist-only filter used when
// FilterHeaders receives a nil filter.
var defaultCompiledHeaderFilter = CompileHeaderFilter(config.ResponseHeaderConfig{})
|
||||
|
||||
func CompileHeaderFilter(cfg config.ResponseHeaderConfig) *CompiledHeaderFilter {
|
||||
allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
|
||||
for key := range defaultAllowed {
|
||||
allowed[key] = struct{}{}
|
||||
}
|
||||
// 关闭时只使用默认白名单,additional/force_remove 不生效
|
||||
if cfg.Enabled {
|
||||
for _, key := range cfg.AdditionalAllowed {
|
||||
normalized := strings.ToLower(strings.TrimSpace(key))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
allowed[normalized] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
forceRemove := map[string]struct{}{}
|
||||
if cfg.Enabled {
|
||||
forceRemove = make(map[string]struct{}, len(cfg.ForceRemove))
|
||||
for _, key := range cfg.ForceRemove {
|
||||
normalized := strings.ToLower(strings.TrimSpace(key))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
forceRemove[normalized] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return &CompiledHeaderFilter{
|
||||
allowed: allowed,
|
||||
forceRemove: forceRemove,
|
||||
}
|
||||
}
|
||||
|
||||
func FilterHeaders(src http.Header, filter *CompiledHeaderFilter) http.Header {
|
||||
if filter == nil {
|
||||
filter = defaultCompiledHeaderFilter
|
||||
}
|
||||
|
||||
filtered := make(http.Header, len(src))
|
||||
for key, values := range src {
|
||||
lower := strings.ToLower(key)
|
||||
if _, blocked := filter.forceRemove[lower]; blocked {
|
||||
continue
|
||||
}
|
||||
if _, ok := filter.allowed[lower]; !ok {
|
||||
continue
|
||||
}
|
||||
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
|
||||
if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
filtered.Add(key, value)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func WriteFilteredHeaders(dst http.Header, src http.Header, filter *CompiledHeaderFilter) {
|
||||
filtered := FilterHeaders(src, filter)
|
||||
for key, values := range filtered {
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
67
internal/util/responseheaders/responseheaders_test.go
Normal file
67
internal/util/responseheaders/responseheaders_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package responseheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
// TestFilterHeadersDisabledUsesDefaultAllowlist verifies that with the feature
// disabled only the default allowlist applies: ForceRemove is ignored, unknown
// headers are dropped, and hop-by-hop headers are removed.
func TestFilterHeadersDisabledUsesDefaultAllowlist(t *testing.T) {
	src := http.Header{}
	src.Add("Content-Type", "application/json")
	src.Add("X-Request-Id", "req-123")
	src.Add("X-Test", "ok")
	src.Add("Connection", "keep-alive")
	src.Add("Content-Length", "123")

	cfg := config.ResponseHeaderConfig{
		Enabled: false,
		ForceRemove: []string{"x-request-id"},
	}

	filtered := FilterHeaders(src, CompileHeaderFilter(cfg))
	if filtered.Get("Content-Type") != "application/json" {
		t.Fatalf("expected Content-Type passthrough, got %q", filtered.Get("Content-Type"))
	}
	if filtered.Get("X-Request-Id") != "req-123" {
		t.Fatalf("expected X-Request-Id allowed, got %q", filtered.Get("X-Request-Id"))
	}
	if filtered.Get("X-Test") != "" {
		t.Fatalf("expected X-Test removed, got %q", filtered.Get("X-Test"))
	}
	if filtered.Get("Connection") != "" {
		t.Fatalf("expected Connection to be removed, got %q", filtered.Get("Connection"))
	}
	if filtered.Get("Content-Length") != "" {
		t.Fatalf("expected Content-Length to be removed, got %q", filtered.Get("Content-Length"))
	}
}
|
||||
|
||||
// TestFilterHeadersEnabledUsesAllowlist verifies that with the feature enabled
// AdditionalAllowed headers pass, ForceRemove headers are stripped, and
// headers in neither list are dropped.
func TestFilterHeadersEnabledUsesAllowlist(t *testing.T) {
	src := http.Header{}
	src.Add("Content-Type", "application/json")
	src.Add("X-Extra", "ok")
	src.Add("X-Remove", "nope")
	src.Add("X-Blocked", "nope")

	cfg := config.ResponseHeaderConfig{
		Enabled: true,
		AdditionalAllowed: []string{"x-extra"},
		ForceRemove: []string{"x-remove"},
	}

	filtered := FilterHeaders(src, CompileHeaderFilter(cfg))
	if filtered.Get("Content-Type") != "application/json" {
		t.Fatalf("expected Content-Type allowed, got %q", filtered.Get("Content-Type"))
	}
	if filtered.Get("X-Extra") != "ok" {
		t.Fatalf("expected X-Extra allowed, got %q", filtered.Get("X-Extra"))
	}
	if filtered.Get("X-Remove") != "" {
		t.Fatalf("expected X-Remove removed, got %q", filtered.Get("X-Remove"))
	}
	if filtered.Get("X-Blocked") != "" {
		t.Fatalf("expected X-Blocked removed, got %q", filtered.Get("X-Blocked"))
	}
}
|
||||
170
internal/util/soraerror/soraerror.go
Normal file
170
internal/util/soraerror/soraerror.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package soraerror
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
	// cfRayPattern matches "cf-ray: <id>" style fragments in response bodies.
	cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
	// cRayPattern matches the cRay: '<id>' field embedded in Cloudflare
	// challenge scripts.
	cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
	// htmlChallenge lists lowercase markers that identify a Cloudflare
	// challenge HTML page; they are compared against a lowercased body preview.
	htmlChallenge = []string{
		"window._cf_chl_opt",
		"just a moment",
		"enable javascript and cookies to continue",
		"__cf_chl_",
		"challenge-platform",
	}
)
|
||||
|
||||
// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
|
||||
func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
|
||||
return false
|
||||
}
|
||||
|
||||
if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
|
||||
return true
|
||||
}
|
||||
|
||||
preview := strings.ToLower(TruncateBody(body, 4096))
|
||||
for _, marker := range htmlChallenge {
|
||||
if strings.Contains(preview, marker) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
contentType := ""
|
||||
if headers != nil {
|
||||
contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
|
||||
}
|
||||
if strings.Contains(contentType, "text/html") &&
|
||||
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
|
||||
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractCloudflareRayID extracts cf-ray from headers or response body.
|
||||
func ExtractCloudflareRayID(headers http.Header, body []byte) string {
|
||||
if headers != nil {
|
||||
rayID := strings.TrimSpace(headers.Get("cf-ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
}
|
||||
|
||||
preview := TruncateBody(body, 8192)
|
||||
if matches := cfRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// FormatCloudflareChallengeMessage appends cf-ray info when available.
|
||||
func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
rayID := ExtractCloudflareRayID(headers, body)
|
||||
if rayID == "" {
|
||||
return base
|
||||
}
|
||||
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
|
||||
}
|
||||
|
||||
// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
|
||||
func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||
trimmed := strings.TrimSpace(string(body))
|
||||
if trimmed == "" {
|
||||
return "", ""
|
||||
}
|
||||
if !json.Valid([]byte(trimmed)) {
|
||||
return "", truncateMessage(trimmed, 256)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
|
||||
return "", truncateMessage(trimmed, 256)
|
||||
}
|
||||
|
||||
code := firstNonEmpty(
|
||||
extractNestedString(payload, "error", "code"),
|
||||
extractRootString(payload, "code"),
|
||||
)
|
||||
message := firstNonEmpty(
|
||||
extractNestedString(payload, "error", "message"),
|
||||
extractRootString(payload, "message"),
|
||||
extractNestedString(payload, "error", "detail"),
|
||||
extractRootString(payload, "detail"),
|
||||
)
|
||||
return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
|
||||
}
|
||||
|
||||
// TruncateBody trims body and caps it at max bytes for logging/inspection,
// appending a marker when shortened. A non-positive max defaults to 512.
func TruncateBody(body []byte, max int) string {
	if max <= 0 {
		max = 512
	}
	text := strings.TrimSpace(string(body))
	if len(text) > max {
		// NOTE(review): byte-based slicing may split a multi-byte UTF-8 rune;
		// presumably acceptable for log previews.
		text = text[:max] + "...(truncated)"
	}
	return text
}
|
||||
|
||||
// truncateMessage caps s at max bytes, appending a marker when shortened.
// A non-positive max yields the empty string.
func truncateMessage(s string, max int) string {
	switch {
	case max <= 0:
		return ""
	case len(s) > max:
		return s[:max] + "...(truncated)"
	default:
		return s
	}
}
|
||||
|
||||
// firstNonEmpty returns the first value that is non-empty after trimming
// whitespace; the value itself is returned untrimmed. Returns "" when all
// values are blank.
func firstNonEmpty(values ...string) string {
	for _, candidate := range values {
		if strings.TrimSpace(candidate) == "" {
			continue
		}
		return candidate
	}
	return ""
}
|
||||
|
||||
// extractRootString returns m[key] when it holds a string; otherwise "".
func extractRootString(m map[string]any, key string) string {
	if m == nil {
		return ""
	}
	if s, ok := m[key].(string); ok {
		return s
	}
	return ""
}
|
||||
|
||||
// extractNestedString returns m[parent][key] when parent holds a JSON object
// and key holds a string within it; otherwise "".
func extractNestedString(m map[string]any, parent, key string) string {
	if m == nil {
		return ""
	}
	child, ok := m[parent].(map[string]any)
	if !ok {
		return ""
	}
	s, _ := child[key].(string)
	return s
}
|
||||
47
internal/util/soraerror/soraerror_test.go
Normal file
47
internal/util/soraerror/soraerror_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package soraerror
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIsCloudflareChallengeResponse covers the cf-mitigated header path, the
// HTML challenge-marker path, and a status code outside 403/429.
func TestIsCloudflareChallengeResponse(t *testing.T) {
	headers := make(http.Header)
	headers.Set("cf-mitigated", "challenge")
	require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))

	require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
	require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
}
|
||||
|
||||
// TestExtractCloudflareRayID covers extraction from the cf-ray header and,
// absent headers, from the cRay field in a challenge-script body.
func TestExtractCloudflareRayID(t *testing.T) {
	headers := make(http.Header)
	headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
	require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))

	body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
	require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
}
|
||||
|
||||
// TestExtractUpstreamErrorCodeAndMessage covers the nested error object
// layout, the root-level layout, and a plain-text fallback.
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
	code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
	require.Equal(t, "cf_shield_429", code)
	require.Equal(t, "rate limited", msg)

	code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
	require.Equal(t, "unsupported_country_code", code)
	require.Equal(t, "not available", msg)

	code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
	require.Equal(t, "", code)
	require.Equal(t, "plain text", msg)
}
|
||||
|
||||
// TestFormatCloudflareChallengeMessage verifies that the cf-ray suffix is
// appended when the header is present.
func TestFormatCloudflareChallengeMessage(t *testing.T) {
	headers := make(http.Header)
	headers.Set("cf-ray", "9d03b68c086027a1-SEA")
	msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
	require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
}
|
||||
175
internal/util/urlvalidator/validator.go
Normal file
175
internal/util/urlvalidator/validator.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package urlvalidator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ValidationOptions controls host-level checks performed by ValidateHTTPURL.
type ValidationOptions struct {
	AllowedHosts []string // optional host allowlist; supports *.example.com wildcards
	RequireAllowlist bool // reject when no allowlist entries are configured
	AllowPrivate bool // permit localhost and private/loopback literal IPs
}
|
||||
|
||||
// ValidateHTTPURL validates an outbound HTTP/HTTPS URL.
|
||||
//
|
||||
// It provides a single validation entry point that supports:
|
||||
// - scheme 校验(https 或可选允许 http)
|
||||
// - 可选 allowlist(支持 *.example.com 通配)
|
||||
// - allow_private_hosts 策略(阻断 localhost/私网字面量 IP)
|
||||
//
|
||||
// 注意:DNS Rebinding 防护(解析后 IP 校验)应在实际发起请求时执行,避免 TOCTOU。
|
||||
func ValidateHTTPURL(raw string, allowInsecureHTTP bool, opts ValidationOptions) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("url is required")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return "", fmt.Errorf("invalid url: %s", trimmed)
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
if scheme != "https" && (!allowInsecureHTTP || scheme != "http") {
|
||||
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
|
||||
}
|
||||
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host == "" {
|
||||
return "", errors.New("invalid host")
|
||||
}
|
||||
if !opts.AllowPrivate && isBlockedHost(host) {
|
||||
return "", fmt.Errorf("host is not allowed: %s", host)
|
||||
}
|
||||
|
||||
if port := parsed.Port(); port != "" {
|
||||
num, err := strconv.Atoi(port)
|
||||
if err != nil || num <= 0 || num > 65535 {
|
||||
return "", fmt.Errorf("invalid port: %s", port)
|
||||
}
|
||||
}
|
||||
|
||||
allowlist := normalizeAllowlist(opts.AllowedHosts)
|
||||
if opts.RequireAllowlist && len(allowlist) == 0 {
|
||||
return "", errors.New("allowlist is not configured")
|
||||
}
|
||||
if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
|
||||
return "", fmt.Errorf("host is not allowed: %s", host)
|
||||
}
|
||||
|
||||
parsed.Path = strings.TrimRight(parsed.Path, "/")
|
||||
parsed.RawPath = ""
|
||||
return strings.TrimRight(parsed.String(), "/"), nil
|
||||
}
|
||||
|
||||
// ValidateURLFormat performs minimal format validation: it only checks that
// the URL parses and the scheme is acceptable; no allowlist, private-host, or
// SSRF checks are applied. Trailing path slashes are trimmed from the result.
func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return "", errors.New("url is required")
	}

	parsed, err := url.Parse(trimmed)
	if err != nil || parsed.Scheme == "" || parsed.Host == "" {
		return "", fmt.Errorf("invalid url: %s", trimmed)
	}

	scheme := strings.ToLower(parsed.Scheme)
	if scheme != "https" && (!allowInsecureHTTP || scheme != "http") {
		return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
	}

	host := strings.TrimSpace(parsed.Hostname())
	if host == "" {
		return "", errors.New("invalid host")
	}

	if port := parsed.Port(); port != "" {
		num, err := strconv.Atoi(port)
		if err != nil || num <= 0 || num > 65535 {
			return "", fmt.Errorf("invalid port: %s", port)
		}
	}

	// Only trim trailing slashes when they can belong to the path. Blindly
	// trimming the raw string would also strip slashes that are part of the
	// query string or fragment, corrupting the URL.
	if parsed.RawQuery == "" && parsed.Fragment == "" {
		trimmed = strings.TrimRight(trimmed, "/")
	}
	return trimmed, nil
}
|
||||
|
||||
// ValidateHTTPSURL validates raw as an HTTPS-only URL; plain http is never
// accepted regardless of configuration.
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
	return ValidateHTTPURL(raw, false, opts)
}
|
||||
|
||||
// ValidateResolvedIP verifies that every IP the hostname resolves to is safe
// (not loopback, private, link-local, or unspecified). Call it at request
// time to mitigate DNS-rebinding attacks. Resolution is capped at 5 seconds.
func ValidateResolvedIP(host string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
	if err != nil {
		return fmt.Errorf("dns resolution failed: %w", err)
	}

	for _, ip := range ips {
		blocked := ip.IsLoopback() ||
			ip.IsPrivate() ||
			ip.IsLinkLocalUnicast() ||
			ip.IsLinkLocalMulticast() ||
			ip.IsUnspecified()
		if blocked {
			return fmt.Errorf("resolved ip %s is not allowed", ip.String())
		}
	}
	return nil
}
|
||||
|
||||
// normalizeAllowlist lowercases and trims allowlist entries, drops blanks, and
// strips any port suffix so entries match on hostname only. Returns nil for
// empty input.
func normalizeAllowlist(values []string) []string {
	if len(values) == 0 {
		return nil
	}
	out := make([]string, 0, len(values))
	for _, raw := range values {
		entry := strings.ToLower(strings.TrimSpace(raw))
		if entry == "" {
			continue
		}
		if hostOnly, _, err := net.SplitHostPort(entry); err == nil {
			entry = hostOnly
		}
		out = append(out, entry)
	}
	return out
}
|
||||
|
||||
// isAllowedHost reports whether host matches an allowlist entry. Entries
// beginning with "*." match the bare domain and any subdomain of it; all other
// entries must match exactly.
func isAllowedHost(host string, allowlist []string) bool {
	for _, entry := range allowlist {
		if entry == "" {
			continue
		}
		if strings.HasPrefix(entry, "*.") {
			bare := entry[2:]
			if host == bare || strings.HasSuffix(host, "."+bare) {
				return true
			}
			continue
		}
		if entry == host {
			return true
		}
	}
	return false
}
|
||||
|
||||
// isBlockedHost reports whether host is "localhost" (or a *.localhost name) or
// a literal IP in a loopback/private/link-local/unspecified range. Non-IP
// hostnames other than localhost are not blocked here; resolved-IP checks are
// handled separately by ValidateResolvedIP.
func isBlockedHost(host string) bool {
	if host == "localhost" || strings.HasSuffix(host, ".localhost") {
		return true
	}
	ip := net.ParseIP(host)
	if ip == nil {
		return false
	}
	return ip.IsLoopback() ||
		ip.IsPrivate() ||
		ip.IsLinkLocalUnicast() ||
		ip.IsLinkLocalMulticast() ||
		ip.IsUnspecified()
}
|
||||
75
internal/util/urlvalidator/validator_test.go
Normal file
75
internal/util/urlvalidator/validator_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package urlvalidator
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestValidateURLFormat covers empty/unparseable input, scheme enforcement,
// port validation, and trailing-slash normalization.
func TestValidateURLFormat(t *testing.T) {
	if _, err := ValidateURLFormat("", false); err == nil {
		t.Fatalf("expected empty url to fail")
	}
	if _, err := ValidateURLFormat("://bad", false); err == nil {
		t.Fatalf("expected invalid url to fail")
	}
	if _, err := ValidateURLFormat("http://example.com", false); err == nil {
		t.Fatalf("expected http to fail when allow_insecure_http is false")
	}
	if _, err := ValidateURLFormat("https://example.com", false); err != nil {
		t.Fatalf("expected https to pass, got %v", err)
	}
	if _, err := ValidateURLFormat("http://example.com", true); err != nil {
		t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err)
	}
	if _, err := ValidateURLFormat("https://example.com:bad", true); err == nil {
		t.Fatalf("expected invalid port to fail")
	}

	// Verify a trailing slash is removed.
	normalized, err := ValidateURLFormat("https://example.com/", false)
	if err != nil {
		t.Fatalf("expected trailing slash url to pass, got %v", err)
	}
	if normalized != "https://example.com" {
		t.Fatalf("expected trailing slash to be removed, got %s", normalized)
	}

	// Verify multiple trailing slashes are removed.
	normalized, err = ValidateURLFormat("https://example.com///", false)
	if err != nil {
		t.Fatalf("expected multiple trailing slashes to pass, got %v", err)
	}
	if normalized != "https://example.com" {
		t.Fatalf("expected all trailing slashes to be removed, got %s", normalized)
	}

	// Verify the trailing slash is removed from a URL with a path.
	normalized, err = ValidateURLFormat("https://example.com/api/v1/", false)
	if err != nil {
		t.Fatalf("expected trailing slash url with path to pass, got %v", err)
	}
	if normalized != "https://example.com/api/v1" {
		t.Fatalf("expected trailing slash to be removed from path, got %s", normalized)
	}
}
|
||||
|
||||
// TestValidateHTTPURL covers scheme enforcement, the RequireAllowlist policy,
// exact and wildcard allowlist matching, and private-host blocking.
func TestValidateHTTPURL(t *testing.T) {
	if _, err := ValidateHTTPURL("http://example.com", false, ValidationOptions{}); err == nil {
		t.Fatalf("expected http to fail when allow_insecure_http is false")
	}
	if _, err := ValidateHTTPURL("http://example.com", true, ValidationOptions{}); err != nil {
		t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err)
	}
	if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{RequireAllowlist: true}); err == nil {
		t.Fatalf("expected require allowlist to fail when empty")
	}
	if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err == nil {
		t.Fatalf("expected host not in allowlist to fail")
	}
	if _, err := ValidateHTTPURL("https://api.example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err != nil {
		t.Fatalf("expected allowlisted host to pass, got %v", err)
	}
	if _, err := ValidateHTTPURL("https://sub.api.example.com", false, ValidationOptions{AllowedHosts: []string{"*.example.com"}}); err != nil {
		t.Fatalf("expected wildcard allowlist to pass, got %v", err)
	}
	if _, err := ValidateHTTPURL("https://localhost", false, ValidationOptions{AllowPrivate: false}); err == nil {
		t.Fatalf("expected localhost to be blocked when allow_private_hosts is false")
	}
}
|
||||
Reference in New Issue
Block a user