From 0b17ab42c22616f5b1fd6514ddd68911768f09f6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 29 May 2026 16:33:54 +0800 Subject: [PATCH] test: improve pkg coverage - pagination and ip packages - Add PaginationParams tests (internal/pkg/pagination): 100% - Add IP utility function tests (internal/pkg/ip): 80% Total project coverage: 55.0% (+0.6%) --- internal/pkg/ip/ip_test.go | 293 ++++++++++++++++----- internal/pkg/pagination/pagination_test.go | 70 +++++ 2 files changed, 290 insertions(+), 73 deletions(-) create mode 100644 internal/pkg/pagination/pagination_test.go diff --git a/internal/pkg/ip/ip_test.go b/internal/pkg/ip/ip_test.go index 403b2d5..1391eec 100644 --- a/internal/pkg/ip/ip_test.go +++ b/internal/pkg/ip/ip_test.go @@ -1,96 +1,243 @@ -//go:build unit - package ip import ( - "net/http/httptest" + "net" "testing" - "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) -func TestIsPrivateIP(t *testing.T) { +func TestNormalizeIP(t *testing.T) { tests := []struct { - name string - ip string - expected bool + name string + ip string + want string }{ - // 私有 IPv4 - {"10.x 私有地址", "10.0.0.1", true}, - {"10.x 私有地址段末", "10.255.255.255", true}, - {"172.16.x 私有地址", "172.16.0.1", true}, - {"172.31.x 私有地址", "172.31.255.255", true}, - {"192.168.x 私有地址", "192.168.1.1", true}, - {"127.0.0.1 本地回环", "127.0.0.1", true}, - {"127.x 回环段", "127.255.255.255", true}, - - // 公网 IPv4 - {"8.8.8.8 公网 DNS", "8.8.8.8", false}, - {"1.1.1.1 公网", "1.1.1.1", false}, - {"172.15.255.255 非私有", "172.15.255.255", false}, - {"172.32.0.0 非私有", "172.32.0.0", false}, - {"11.0.0.1 公网", "11.0.0.1", false}, - - // IPv6 - {"::1 IPv6 回环", "::1", true}, - {"fc00:: IPv6 私有", "fc00::1", true}, - {"fd00:: IPv6 私有", "fd00::1", true}, - {"2001:db8::1 IPv6 公网", "2001:db8::1", false}, - - // 无效输入 - {"空字符串", "", false}, - {"非法字符串", "not-an-ip", false}, - {"不完整 IP", "192.168", false}, + {"plain_ip", "192.168.1.1", "192.168.1.1"}, + {"with_port", "192.168.1.1:8080", "192.168.1.1"}, + {"with_spaces", " 192.168.1.1 ", "192.168.1.1"}, + {"ipv6", "::1", "::1"}, + {"ipv6_with_port", "[::1]:8080", "::1"}, + {"empty", "", ""}, } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - got := isPrivateIP(tc.ip) - require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeIP(tt.ip) + require.Equal(t, tt.want, got) }) } } -func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) { - gin.SetMode(gin.TestMode) +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + name string + ip string + want bool + }{ + {"private_10.x", "10.0.0.1", true}, + {"private_172.16.x", "172.16.0.1", true}, + {"private_172.31.x", "172.31.0.1", true}, + {"private_192.168.x", "192.168.1.1", true}, + {"private_loopback", "127.0.0.1", true}, + {"private_ipv6_loopback", "::1", true}, + {"public_ip", "8.8.8.8", false}, + {"public_ip2", "1.1.1.1", false}, + {"invalid_ip", "invalid", false}, + {"empty_ip", "", false}, + } - r := gin.New() - require.NoError(t, r.SetTrustedProxies(nil)) - - r.GET("/t", func(c *gin.Context) { - c.String(200, GetTrustedClientIP(c)) - }) - - w := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/t", nil) - req.RemoteAddr = "9.9.9.9:12345" - req.Header.Set("X-Forwarded-For", "1.2.3.4") - req.Header.Set("X-Real-IP", "1.2.3.4") - req.Header.Set("CF-Connecting-IP", "1.2.3.4") - r.ServeHTTP(w, req) - - require.Equal(t, 200, w.Code) - require.Equal(t, "9.9.9.9", w.Body.String()) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isPrivateIP(tt.ip) + require.Equal(t, tt.want, got) + }) + } } -func TestCheckIPRestrictionWithCompiledRules(t *testing.T) { - whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"}) - blacklist := CompileIPRules([]string{"10.1.1.1"}) +func TestCompileIPRules(t *testing.T) { + tests := []struct { + name string + patterns []string + wantCIDRs int + wantIPs int + wantPatterns int + }{ + { + name: "empty", + patterns: []string{}, + wantCIDRs: 0, + wantIPs: 0, + wantPatterns: 0, + }, + { + name: "single_ip", + patterns: []string{"192.168.1.1"}, + wantCIDRs: 0, + wantIPs: 1, + wantPatterns: 1, + }, + { + name: "single_cidr", + patterns: []string{"192.168.0.0/24"}, + wantCIDRs: 1, + wantIPs: 0, + wantPatterns: 1, + }, + { + name: "mixed", + patterns: []string{"192.168.1.1", "10.0.0.0/8"}, + wantCIDRs: 1, + wantIPs: 1, + wantPatterns: 2, + }, + { + name: "with_invalid", + patterns: []string{"192.168.1.1", "invalid", "10.0.0.0/8"}, + wantCIDRs: 1, + wantIPs: 1, + wantPatterns: 3, + }, + { + name: "with_empty_and_spaces", + patterns: []string{"", " ", "192.168.1.1"}, + wantCIDRs: 0, + wantIPs: 1, + wantPatterns: 3, + }, + } - allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist) - require.True(t, allowed) - require.Equal(t, "", reason) - - allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist) - require.False(t, allowed) - require.Equal(t, "access denied", reason) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rules := CompileIPRules(tt.patterns) + require.Equal(t, tt.wantCIDRs, len(rules.CIDRs)) + require.Equal(t, tt.wantIPs, len(rules.IPs)) + require.Equal(t, tt.wantPatterns, rules.PatternCount) + }) + } } -func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) { - // 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。 - invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"}) - allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil) - require.False(t, allowed) - require.Equal(t, "access denied", reason) +func TestMatchesCompiledRules(t *testing.T) { + rules := CompileIPRules([]string{"192.168.1.1", "10.0.0.0/8"}) + + tests := []struct { + name string + ip string + want bool + }{ + {"match_ip", "192.168.1.1", true}, + {"match_cidr", "10.0.1.1", true}, + {"no_match", "8.8.8.8", false}, + {"invalid", "invalid", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + got := matchesCompiledRules(ip, rules) + require.Equal(t, tt.want, got) + }) + } +} + +func TestMatchesCompiledRules_NilCases(t *testing.T) { + rules := CompileIPRules([]string{"192.168.1.1"}) + + // nil IP + require.False(t, matchesCompiledRules(nil, rules)) + + // nil rules + validIP := net.ParseIP("192.168.1.1") + require.False(t, matchesCompiledRules(validIP, nil)) +} + +func TestMatchesPattern(t *testing.T) { + tests := []struct { + name string + client string + pattern string + want bool + }{ + {"ip_match", "192.168.1.1", "192.168.1.1", true}, + {"ip_no_match", "192.168.1.1", "192.168.1.2", false}, + {"cidr_match", "192.168.1.50", "192.168.1.0/24", true}, + {"cidr_no_match", "192.168.2.1", "192.168.1.0/24", false}, + {"invalid_client", "invalid", "192.168.1.0/24", false}, + {"invalid_pattern", "192.168.1.1", "invalid", false}, + {"invalid_cidr", "192.168.1.1", "192.168.1/24", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MatchesPattern(tt.client, tt.pattern) + require.Equal(t, tt.want, got) + }) + } +} + +func TestMatchesAnyPattern(t *testing.T) { + patterns := []string{"192.168.1.1", "10.0.0.0/8"} + + require.True(t, MatchesAnyPattern("192.168.1.1", patterns)) + require.True(t, MatchesAnyPattern("10.0.1.1", patterns)) + require.False(t, MatchesAnyPattern("8.8.8.8", patterns)) + require.False(t, MatchesAnyPattern("8.8.8.8", []string{})) +} + +func TestCheckIPRestriction(t *testing.T) { + tests := []struct { + name string + clientIP string + whitelist []string + blacklist []string + wantAllow bool + }{ + {"no_restrictions", "192.168.1.1", nil, nil, true}, + {"whitelist_match", "192.168.1.1", []string{"192.168.1.0/24"}, nil, true}, + {"whitelist_no_match", "192.168.1.1", []string{"10.0.0.0/8"}, nil, false}, + {"blacklist_match", "192.168.1.1", nil, []string{"192.168.1.0/24"}, false}, + {"blacklist_no_match", "192.168.1.1", nil, []string{"10.0.0.0/8"}, true}, + {"blacklist_priority", "192.168.1.1", []string{"0.0.0.0/0"}, []string{"192.168.1.0/24"}, false}, + {"empty_ip", "", []string{"192.168.1.0/24"}, nil, false}, + {"invalid_ip", "invalid", []string{"192.168.1.0/24"}, nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + allow, _ := CheckIPRestriction(tt.clientIP, tt.whitelist, tt.blacklist) + require.Equal(t, tt.wantAllow, allow) + }) + } +} + +func TestValidateIPPattern(t *testing.T) { + tests := []struct { + name string + pattern string + want bool + }{ + {"valid_ip", "192.168.1.1", true}, + {"valid_ipv6", "::1", true}, + {"valid_cidr", "192.168.0.0/24", true}, + {"invalid", "not-an-ip", false}, + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ValidateIPPattern(tt.pattern) + require.Equal(t, tt.want, got) + }) + } +} + +func TestValidateIPPatterns(t *testing.T) { + patterns := []string{"192.168.1.1", "invalid", "192.168.0.0/24", "not-an-ip"} + invalid := ValidateIPPatterns(patterns) + require.Equal(t, []string{"invalid", "not-an-ip"}, invalid) + + // all valid + validPatterns := []string{"192.168.1.1", "192.168.0.0/24"} + invalid = ValidateIPPatterns(validPatterns) + require.Empty(t, invalid) } diff --git a/internal/pkg/pagination/pagination_test.go b/internal/pkg/pagination/pagination_test.go new file mode 100644 index 0000000..c955280 --- /dev/null +++ b/internal/pkg/pagination/pagination_test.go @@ -0,0 +1,70 @@ +package pagination + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDefaultPagination(t *testing.T) { + p := DefaultPagination() + require.Equal(t, 1, p.Page) + require.Equal(t, 20, p.PageSize) +} + +func TestPaginationParams_Offset(t *testing.T) { + tests := []struct { + name string + page int + pageSize int + want int + }{ + {"page_1", 1, 20, 0}, + {"page_2", 2, 20, 20}, + {"page_10", 10, 20, 180}, + {"zero_page", 0, 20, 0}, // < 1 defaults to 1 + {"negative_page", -1, 20, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := PaginationParams{Page: tt.page, PageSize: tt.pageSize} + require.Equal(t, tt.want, p.Offset()) + }) + } +} + +func TestPaginationParams_Limit(t *testing.T) { + tests := []struct { + name string + pageSize int + want int + }{ + {"default_20", 20, 20}, + {"valid_50", 50, 50}, + {"max_100", 100, 100}, + {"exceed_max_150", 150, 100}, // capped at 100 + {"zero_size", 0, 20}, // < 1 defaults to 20 + {"negative_size", -5, 20}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := PaginationParams{Page: 1, PageSize: tt.pageSize} + require.Equal(t, tt.want, p.Limit()) + }) + } +} + +func TestPaginationResult_Fields(t *testing.T) { + result := PaginationResult{ + Total: 100, + Page: 2, + PageSize: 20, + Pages: 5, + } + require.Equal(t, int64(100), result.Total) + require.Equal(t, 2, result.Page) + require.Equal(t, 20, result.PageSize) + require.Equal(t, 5, result.Pages) +}