feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers

This commit is contained in:
2026-04-02 11:19:50 +08:00
parent e59a77bc49
commit dcc1f186f8
298 changed files with 62603 additions and 0 deletions

View File

@@ -0,0 +1,232 @@
package logredact
import (
"encoding/json"
"regexp"
"sort"
"strings"
"sync"
)
// maxRedactDepth 限制递归深度以防止栈溢出
const maxRedactDepth = 32
var defaultSensitiveKeys = map[string]struct{}{
"authorization_code": {},
"code": {},
"code_verifier": {},
"access_token": {},
"refresh_token": {},
"id_token": {},
"client_secret": {},
"password": {},
}
var defaultSensitiveKeyList = []string{
"authorization_code",
"code",
"code_verifier",
"access_token",
"refresh_token",
"id_token",
"client_secret",
"password",
}
type textRedactPatterns struct {
reJSONLike *regexp.Regexp
reQueryLike *regexp.Regexp
rePlain *regexp.Regexp
}
var (
reGOCSPX = regexp.MustCompile(`GOCSPX-[0-9A-Za-z_-]{24,}`)
reAIza = regexp.MustCompile(`AIza[0-9A-Za-z_-]{35}`)
defaultTextRedactPatterns = compileTextRedactPatterns(nil)
extraTextPatternCache sync.Map // map[string]*textRedactPatterns
)
func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
if input == nil {
return map[string]any{}
}
keys := buildKeySet(extraKeys)
redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any)
if !ok {
return map[string]any{}
}
return redacted
}
func RedactJSON(raw []byte, extraKeys ...string) string {
if len(raw) == 0 {
return ""
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return "<non-json payload redacted>"
}
keys := buildKeySet(extraKeys)
redacted := redactValueWithDepth(value, keys, 0)
encoded, err := json.Marshal(redacted)
if err != nil {
return "<redacted>"
}
return string(encoded)
}
// RedactText 对非结构化文本做轻量脱敏。
//
// 规则:
// - 如果文本本身是 JSON则按 RedactJSON 处理。
// - 否则尝试对常见 key=value / key:"value" 片段做脱敏。
//
// 注意:该函数用于日志/错误信息兜底,不保证覆盖所有格式。
func RedactText(input string, extraKeys ...string) string {
input = strings.TrimSpace(input)
if input == "" {
return ""
}
raw := []byte(input)
if json.Valid(raw) {
return RedactJSON(raw, extraKeys...)
}
patterns := getTextRedactPatterns(extraKeys)
out := input
out = reGOCSPX.ReplaceAllString(out, "GOCSPX-***")
out = reAIza.ReplaceAllString(out, "AIza***")
out = patterns.reJSONLike.ReplaceAllString(out, `$1***$3`)
out = patterns.reQueryLike.ReplaceAllString(out, `$1=***`)
out = patterns.rePlain.ReplaceAllString(out, `$1$2***`)
return out
}
func compileTextRedactPatterns(extraKeys []string) *textRedactPatterns {
keyAlt := buildKeyAlternation(extraKeys)
return &textRedactPatterns{
// JSON-like: "access_token":"..."
reJSONLike: regexp.MustCompile(`(?i)("(?:` + keyAlt + `)"\s*:\s*")([^"]*)(")`),
// Query-like: access_token=...
reQueryLike: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))=([^&\s]+)`),
// Plain: access_token: ... / access_token = ...
rePlain: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))\b(\s*[:=]\s*)([^,\s]+)`),
}
}
func getTextRedactPatterns(extraKeys []string) *textRedactPatterns {
normalizedExtraKeys := normalizeAndSortExtraKeys(extraKeys)
if len(normalizedExtraKeys) == 0 {
return defaultTextRedactPatterns
}
cacheKey := strings.Join(normalizedExtraKeys, ",")
if cached, ok := extraTextPatternCache.Load(cacheKey); ok {
if patterns, ok := cached.(*textRedactPatterns); ok {
return patterns
}
}
compiled := compileTextRedactPatterns(normalizedExtraKeys)
actual, _ := extraTextPatternCache.LoadOrStore(cacheKey, compiled)
if patterns, ok := actual.(*textRedactPatterns); ok {
return patterns
}
return compiled
}
func normalizeAndSortExtraKeys(extraKeys []string) []string {
if len(extraKeys) == 0 {
return nil
}
seen := make(map[string]struct{}, len(extraKeys))
keys := make([]string, 0, len(extraKeys))
for _, key := range extraKeys {
normalized := normalizeKey(key)
if normalized == "" {
continue
}
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
keys = append(keys, normalized)
}
sort.Strings(keys)
return keys
}
func buildKeyAlternation(extraKeys []string) string {
seen := make(map[string]struct{}, len(defaultSensitiveKeyList)+len(extraKeys))
keys := make([]string, 0, len(defaultSensitiveKeyList)+len(extraKeys))
for _, k := range defaultSensitiveKeyList {
seen[k] = struct{}{}
keys = append(keys, regexp.QuoteMeta(k))
}
for _, k := range extraKeys {
n := normalizeKey(k)
if n == "" {
continue
}
if _, ok := seen[n]; ok {
continue
}
seen[n] = struct{}{}
keys = append(keys, regexp.QuoteMeta(n))
}
return strings.Join(keys, "|")
}
func buildKeySet(extraKeys []string) map[string]struct{} {
keys := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys))
for k := range defaultSensitiveKeys {
keys[k] = struct{}{}
}
for _, key := range extraKeys {
normalized := normalizeKey(key)
if normalized == "" {
continue
}
keys[normalized] = struct{}{}
}
return keys
}
func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any {
if depth > maxRedactDepth {
return "<depth limit exceeded>"
}
switch v := value.(type) {
case map[string]any:
out := make(map[string]any, len(v))
for k, val := range v {
if isSensitiveKey(k, keys) {
out[k] = "***"
continue
}
out[k] = redactValueWithDepth(val, keys, depth+1)
}
return out
case []any:
out := make([]any, len(v))
for i, item := range v {
out[i] = redactValueWithDepth(item, keys, depth+1)
}
return out
default:
return value
}
}
func isSensitiveKey(key string, keys map[string]struct{}) bool {
_, ok := keys[normalizeKey(key)]
return ok
}
func normalizeKey(key string) string {
return strings.ToLower(strings.TrimSpace(key))
}

View File

@@ -0,0 +1,84 @@
package logredact
import (
"strings"
"testing"
)
func TestRedactText_JSONLike(t *testing.T) {
in := `{"access_token":"ya29.a0AfH6SMDUMMY","refresh_token":"1//0gDUMMY","other":"ok"}`
out := RedactText(in)
if out == in {
t.Fatalf("expected redaction, got unchanged")
}
if want := `"access_token":"***"`; !strings.Contains(out, want) {
t.Fatalf("expected %q in %q", want, out)
}
if want := `"refresh_token":"***"`; !strings.Contains(out, want) {
t.Fatalf("expected %q in %q", want, out)
}
}
func TestRedactText_QueryLike(t *testing.T) {
in := "access_token=ya29.a0AfH6SMDUMMY refresh_token=1//0gDUMMY"
out := RedactText(in)
if strings.Contains(out, "ya29") || strings.Contains(out, "1//0") {
t.Fatalf("expected tokens redacted, got %q", out)
}
}
func TestRedactText_GOCSPX(t *testing.T) {
in := "client_secret=GOCSPX-your-client-secret"
out := RedactText(in)
if strings.Contains(out, "your-client-secret") {
t.Fatalf("expected secret redacted, got %q", out)
}
if !strings.Contains(out, "client_secret=***") {
t.Fatalf("expected key redacted, got %q", out)
}
}
func TestRedactText_ExtraKeyCacheUsesNormalizedSortedKey(t *testing.T) {
clearExtraTextPatternCache()
out1 := RedactText("custom_secret=abc", "Custom_Secret", " custom_secret ")
out2 := RedactText("custom_secret=xyz", "custom_secret")
if !strings.Contains(out1, "custom_secret=***") {
t.Fatalf("expected custom key redacted in first call, got %q", out1)
}
if !strings.Contains(out2, "custom_secret=***") {
t.Fatalf("expected custom key redacted in second call, got %q", out2)
}
if got := countExtraTextPatternCacheEntries(); got != 1 {
t.Fatalf("expected 1 cached pattern set, got %d", got)
}
}
func TestRedactText_DefaultPathDoesNotUseExtraCache(t *testing.T) {
clearExtraTextPatternCache()
out := RedactText("access_token=abc")
if !strings.Contains(out, "access_token=***") {
t.Fatalf("expected default key redacted, got %q", out)
}
if got := countExtraTextPatternCacheEntries(); got != 0 {
t.Fatalf("expected extra cache to remain empty, got %d", got)
}
}
func clearExtraTextPatternCache() {
extraTextPatternCache.Range(func(key, value any) bool {
extraTextPatternCache.Delete(key)
return true
})
}
func countExtraTextPatternCacheEntries() int {
count := 0
extraTextPatternCache.Range(func(key, value any) bool {
count++
return true
})
return count
}

View File

@@ -0,0 +1,117 @@
package responseheaders
import (
"net/http"
"strings"
"github.com/user-management-system/internal/config"
)
// defaultAllowed 定义允许透传的响应头白名单
// 注意:以下头部由 Go HTTP 包自动处理,不应手动设置:
// - content-length: 由 ResponseWriter 根据实际写入数据自动设置
// - transfer-encoding: 由 HTTP 库根据需要自动添加/移除
// - connection: 由 HTTP 库管理连接复用
var defaultAllowed = map[string]struct{}{
"content-type": {},
"content-encoding": {},
"content-language": {},
"cache-control": {},
"etag": {},
"last-modified": {},
"expires": {},
"vary": {},
"date": {},
"x-request-id": {},
"x-ratelimit-limit-requests": {},
"x-ratelimit-limit-tokens": {},
"x-ratelimit-remaining-requests": {},
"x-ratelimit-remaining-tokens": {},
"x-ratelimit-reset-requests": {},
"x-ratelimit-reset-tokens": {},
"retry-after": {},
"location": {},
"www-authenticate": {},
}
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
var hopByHopHeaders = map[string]struct{}{
"content-length": {},
"transfer-encoding": {},
"connection": {},
}
type CompiledHeaderFilter struct {
allowed map[string]struct{}
forceRemove map[string]struct{}
}
var defaultCompiledHeaderFilter = CompileHeaderFilter(config.ResponseHeaderConfig{})
func CompileHeaderFilter(cfg config.ResponseHeaderConfig) *CompiledHeaderFilter {
allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
for key := range defaultAllowed {
allowed[key] = struct{}{}
}
// 关闭时只使用默认白名单additional/force_remove 不生效
if cfg.Enabled {
for _, key := range cfg.AdditionalAllowed {
normalized := strings.ToLower(strings.TrimSpace(key))
if normalized == "" {
continue
}
allowed[normalized] = struct{}{}
}
}
forceRemove := map[string]struct{}{}
if cfg.Enabled {
forceRemove = make(map[string]struct{}, len(cfg.ForceRemove))
for _, key := range cfg.ForceRemove {
normalized := strings.ToLower(strings.TrimSpace(key))
if normalized == "" {
continue
}
forceRemove[normalized] = struct{}{}
}
}
return &CompiledHeaderFilter{
allowed: allowed,
forceRemove: forceRemove,
}
}
func FilterHeaders(src http.Header, filter *CompiledHeaderFilter) http.Header {
if filter == nil {
filter = defaultCompiledHeaderFilter
}
filtered := make(http.Header, len(src))
for key, values := range src {
lower := strings.ToLower(key)
if _, blocked := filter.forceRemove[lower]; blocked {
continue
}
if _, ok := filter.allowed[lower]; !ok {
continue
}
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop {
continue
}
for _, value := range values {
filtered.Add(key, value)
}
}
return filtered
}
func WriteFilteredHeaders(dst http.Header, src http.Header, filter *CompiledHeaderFilter) {
filtered := FilterHeaders(src, filter)
for key, values := range filtered {
for _, value := range values {
dst.Add(key, value)
}
}
}

View File

@@ -0,0 +1,67 @@
package responseheaders
import (
"net/http"
"testing"
"github.com/user-management-system/internal/config"
)
func TestFilterHeadersDisabledUsesDefaultAllowlist(t *testing.T) {
src := http.Header{}
src.Add("Content-Type", "application/json")
src.Add("X-Request-Id", "req-123")
src.Add("X-Test", "ok")
src.Add("Connection", "keep-alive")
src.Add("Content-Length", "123")
cfg := config.ResponseHeaderConfig{
Enabled: false,
ForceRemove: []string{"x-request-id"},
}
filtered := FilterHeaders(src, CompileHeaderFilter(cfg))
if filtered.Get("Content-Type") != "application/json" {
t.Fatalf("expected Content-Type passthrough, got %q", filtered.Get("Content-Type"))
}
if filtered.Get("X-Request-Id") != "req-123" {
t.Fatalf("expected X-Request-Id allowed, got %q", filtered.Get("X-Request-Id"))
}
if filtered.Get("X-Test") != "" {
t.Fatalf("expected X-Test removed, got %q", filtered.Get("X-Test"))
}
if filtered.Get("Connection") != "" {
t.Fatalf("expected Connection to be removed, got %q", filtered.Get("Connection"))
}
if filtered.Get("Content-Length") != "" {
t.Fatalf("expected Content-Length to be removed, got %q", filtered.Get("Content-Length"))
}
}
func TestFilterHeadersEnabledUsesAllowlist(t *testing.T) {
src := http.Header{}
src.Add("Content-Type", "application/json")
src.Add("X-Extra", "ok")
src.Add("X-Remove", "nope")
src.Add("X-Blocked", "nope")
cfg := config.ResponseHeaderConfig{
Enabled: true,
AdditionalAllowed: []string{"x-extra"},
ForceRemove: []string{"x-remove"},
}
filtered := FilterHeaders(src, CompileHeaderFilter(cfg))
if filtered.Get("Content-Type") != "application/json" {
t.Fatalf("expected Content-Type allowed, got %q", filtered.Get("Content-Type"))
}
if filtered.Get("X-Extra") != "ok" {
t.Fatalf("expected X-Extra allowed, got %q", filtered.Get("X-Extra"))
}
if filtered.Get("X-Remove") != "" {
t.Fatalf("expected X-Remove removed, got %q", filtered.Get("X-Remove"))
}
if filtered.Get("X-Blocked") != "" {
t.Fatalf("expected X-Blocked removed, got %q", filtered.Get("X-Blocked"))
}
}

View File

@@ -0,0 +1,170 @@
package soraerror
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
)
var (
cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
htmlChallenge = []string{
"window._cf_chl_opt",
"just a moment",
"enable javascript and cookies to continue",
"__cf_chl_",
"challenge-platform",
}
)
// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
return false
}
if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
return true
}
preview := strings.ToLower(TruncateBody(body, 4096))
for _, marker := range htmlChallenge {
if strings.Contains(preview, marker) {
return true
}
}
contentType := ""
if headers != nil {
contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
}
if strings.Contains(contentType, "text/html") &&
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
return true
}
return false
}
// ExtractCloudflareRayID extracts cf-ray from headers or response body.
func ExtractCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := TruncateBody(body, 8192)
if matches := cfRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
}
// FormatCloudflareChallengeMessage appends cf-ray info when available.
func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := ExtractCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
}
// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !json.Valid([]byte(trimmed)) {
return "", truncateMessage(trimmed, 256)
}
var payload map[string]any
if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
return "", truncateMessage(trimmed, 256)
}
code := firstNonEmpty(
extractNestedString(payload, "error", "code"),
extractRootString(payload, "code"),
)
message := firstNonEmpty(
extractNestedString(payload, "error", "message"),
extractRootString(payload, "message"),
extractNestedString(payload, "error", "detail"),
extractRootString(payload, "detail"),
)
return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
}
// TruncateBody truncates body text for logging/inspection.
func TruncateBody(body []byte, max int) string {
if max <= 0 {
max = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= max {
return raw
}
return raw[:max] + "...(truncated)"
}
func truncateMessage(s string, max int) string {
if max <= 0 {
return ""
}
if len(s) <= max {
return s
}
return s[:max] + "...(truncated)"
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
if strings.TrimSpace(v) != "" {
return v
}
}
return ""
}
func extractRootString(m map[string]any, key string) string {
if m == nil {
return ""
}
v, ok := m[key]
if !ok {
return ""
}
s, _ := v.(string)
return s
}
func extractNestedString(m map[string]any, parent, key string) string {
if m == nil {
return ""
}
node, ok := m[parent]
if !ok {
return ""
}
child, ok := node.(map[string]any)
if !ok {
return ""
}
s, _ := child[key].(string)
return s
}

View File

@@ -0,0 +1,47 @@
package soraerror
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestIsCloudflareChallengeResponse(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-mitigated", "challenge")
require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
}
func TestExtractCloudflareRayID(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
}
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
require.Equal(t, "cf_shield_429", code)
require.Equal(t, "rate limited", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
require.Equal(t, "unsupported_country_code", code)
require.Equal(t, "not available", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
require.Equal(t, "", code)
require.Equal(t, "plain text", msg)
}
func TestFormatCloudflareChallengeMessage(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
}

View File

@@ -0,0 +1,175 @@
package urlvalidator
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
)
type ValidationOptions struct {
AllowedHosts []string
RequireAllowlist bool
AllowPrivate bool
}
// ValidateHTTPURL validates an outbound HTTP/HTTPS URL.
//
// It provides a single validation entry point that supports:
// - scheme 校验https 或可选允许 http
// - 可选 allowlist支持 *.example.com 通配)
// - allow_private_hosts 策略(阻断 localhost/私网字面量 IP
//
// 注意DNS Rebinding 防护(解析后 IP 校验)应在实际发起请求时执行,避免 TOCTOU。
func ValidateHTTPURL(raw string, allowInsecureHTTP bool, opts ValidationOptions) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("url is required")
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid url: %s", trimmed)
}
scheme := strings.ToLower(parsed.Scheme)
if scheme != "https" && (!allowInsecureHTTP || scheme != "http") {
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return "", errors.New("invalid host")
}
if !opts.AllowPrivate && isBlockedHost(host) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
if port := parsed.Port(); port != "" {
num, err := strconv.Atoi(port)
if err != nil || num <= 0 || num > 65535 {
return "", fmt.Errorf("invalid port: %s", port)
}
}
allowlist := normalizeAllowlist(opts.AllowedHosts)
if opts.RequireAllowlist && len(allowlist) == 0 {
return "", errors.New("allowlist is not configured")
}
if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
parsed.Path = strings.TrimRight(parsed.Path, "/")
parsed.RawPath = ""
return strings.TrimRight(parsed.String(), "/"), nil
}
func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) {
// 最小格式校验:仅保证 URL 可解析且 scheme 合规,不做白名单/私网/SSRF 校验
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("url is required")
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid url: %s", trimmed)
}
scheme := strings.ToLower(parsed.Scheme)
if scheme != "https" && (!allowInsecureHTTP || scheme != "http") {
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
}
host := strings.TrimSpace(parsed.Hostname())
if host == "" {
return "", errors.New("invalid host")
}
if port := parsed.Port(); port != "" {
num, err := strconv.Atoi(port)
if err != nil || num <= 0 || num > 65535 {
return "", fmt.Errorf("invalid port: %s", port)
}
}
return strings.TrimRight(trimmed, "/"), nil
}
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
return ValidateHTTPURL(raw, false, opts)
}
// ValidateResolvedIP 验证 DNS 解析后的 IP 地址是否安全
// 用于防止 DNS Rebinding 攻击:在实际 HTTP 请求时调用此函数验证解析后的 IP
func ValidateResolvedIP(host string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
if err != nil {
return fmt.Errorf("dns resolution failed: %w", err)
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return fmt.Errorf("resolved ip %s is not allowed", ip.String())
}
}
return nil
}
func normalizeAllowlist(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, v := range values {
entry := strings.ToLower(strings.TrimSpace(v))
if entry == "" {
continue
}
if host, _, err := net.SplitHostPort(entry); err == nil {
entry = host
}
normalized = append(normalized, entry)
}
return normalized
}
func isAllowedHost(host string, allowlist []string) bool {
for _, entry := range allowlist {
if entry == "" {
continue
}
if strings.HasPrefix(entry, "*.") {
suffix := strings.TrimPrefix(entry, "*.")
if host == suffix || strings.HasSuffix(host, "."+suffix) {
return true
}
continue
}
if host == entry {
return true
}
}
return false
}
func isBlockedHost(host string) bool {
if host == "localhost" || strings.HasSuffix(host, ".localhost") {
return true
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return true
}
}
return false
}

View File

@@ -0,0 +1,75 @@
package urlvalidator
import "testing"
func TestValidateURLFormat(t *testing.T) {
if _, err := ValidateURLFormat("", false); err == nil {
t.Fatalf("expected empty url to fail")
}
if _, err := ValidateURLFormat("://bad", false); err == nil {
t.Fatalf("expected invalid url to fail")
}
if _, err := ValidateURLFormat("http://example.com", false); err == nil {
t.Fatalf("expected http to fail when allow_insecure_http is false")
}
if _, err := ValidateURLFormat("https://example.com", false); err != nil {
t.Fatalf("expected https to pass, got %v", err)
}
if _, err := ValidateURLFormat("http://example.com", true); err != nil {
t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err)
}
if _, err := ValidateURLFormat("https://example.com:bad", true); err == nil {
t.Fatalf("expected invalid port to fail")
}
// 验证末尾斜杠被移除
normalized, err := ValidateURLFormat("https://example.com/", false)
if err != nil {
t.Fatalf("expected trailing slash url to pass, got %v", err)
}
if normalized != "https://example.com" {
t.Fatalf("expected trailing slash to be removed, got %s", normalized)
}
// 验证多个末尾斜杠被移除
normalized, err = ValidateURLFormat("https://example.com///", false)
if err != nil {
t.Fatalf("expected multiple trailing slashes to pass, got %v", err)
}
if normalized != "https://example.com" {
t.Fatalf("expected all trailing slashes to be removed, got %s", normalized)
}
// 验证带路径的 URL 末尾斜杠被移除
normalized, err = ValidateURLFormat("https://example.com/api/v1/", false)
if err != nil {
t.Fatalf("expected trailing slash url with path to pass, got %v", err)
}
if normalized != "https://example.com/api/v1" {
t.Fatalf("expected trailing slash to be removed from path, got %s", normalized)
}
}
func TestValidateHTTPURL(t *testing.T) {
if _, err := ValidateHTTPURL("http://example.com", false, ValidationOptions{}); err == nil {
t.Fatalf("expected http to fail when allow_insecure_http is false")
}
if _, err := ValidateHTTPURL("http://example.com", true, ValidationOptions{}); err != nil {
t.Fatalf("expected http to pass when allow_insecure_http is true, got %v", err)
}
if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{RequireAllowlist: true}); err == nil {
t.Fatalf("expected require allowlist to fail when empty")
}
if _, err := ValidateHTTPURL("https://example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err == nil {
t.Fatalf("expected host not in allowlist to fail")
}
if _, err := ValidateHTTPURL("https://api.example.com", false, ValidationOptions{AllowedHosts: []string{"api.example.com"}}); err != nil {
t.Fatalf("expected allowlisted host to pass, got %v", err)
}
if _, err := ValidateHTTPURL("https://sub.api.example.com", false, ValidationOptions{AllowedHosts: []string{"*.example.com"}}); err != nil {
t.Fatalf("expected wildcard allowlist to pass, got %v", err)
}
if _, err := ValidateHTTPURL("https://localhost", false, ValidationOptions{AllowPrivate: false}); err == nil {
t.Fatalf("expected localhost to be blocked when allow_private_hosts is false")
}
}