package security

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestNewValidator checks that NewValidator stores its password-policy
// arguments (minimum length, require-special, require-number) in the
// corresponding unexported fields.
func TestNewValidator(t *testing.T) {
	v := NewValidator(8, true, true)
	assert.NotNil(t, v)
	assert.Equal(t, 8, v.passwordMinLength)
	assert.True(t, v.passwordRequireSpecial)
	assert.True(t, v.passwordRequireNumber)

	v2 := NewValidator(6, false, false)
	assert.Equal(t, 6, v2.passwordMinLength)
	assert.False(t, v2.passwordRequireSpecial)
	assert.False(t, v2.passwordRequireNumber)
}

// TestValidator_ValidateEmail checks ValidateEmail against malformed
// addresses (missing '@', missing local part, missing domain) and common
// valid forms (dots, plus-tags, subdomains).
func TestValidator_ValidateEmail(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		email    string
		expected bool
	}{
		{"empty", "", false},
		{"invalid", "invalid", false},
		{"no at", "test.example.com", false},
		{"no domain", "test@", false},
		{"no user", "@example.com", false},
		{"valid simple", "test@example.com", true},
		{"valid with dot", "test.user@example.com", true},
		{"valid with plus", "test+tag@example.com", true},
		{"valid subdomain", "test@mail.example.com", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := v.ValidateEmail(tt.email)
			assert.Equal(t, tt.expected, got)
		})
	}
}

// TestValidator_ValidatePhone checks ValidatePhone. Per the cases below,
// only 11-digit numbers with the prefixes 13/15/18/19 are accepted; wrong
// lengths and prefixes such as 12x/10x are rejected.
func TestValidator_ValidatePhone(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		phone    string
		expected bool
	}{
		{"empty", "", false},
		{"invalid format", "12345678901", false},
		{"too short", "1380013800", false},
		{"too long", "138001380001", false},
		{"invalid prefix 1", "12800138000", false},
		{"invalid prefix 2", "10800138000", false},
		{"valid 13x", "13800138000", true},
		{"valid 15x", "15800138000", true},
		{"valid 18x", "18800138000", true},
		{"valid 19x", "19800138000", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := v.ValidatePhone(tt.phone)
			assert.Equal(t, tt.expected, got)
		})
	}
}

// TestValidator_ValidateUsername checks ValidateUsername. Per the cases
// below: the name must start with a letter, may contain letters, digits
// and underscores, and must be between 4 and 20 characters long
// ("abc" is too short, a 20-char name passes, a 21-char name fails).
func TestValidator_ValidateUsername(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		username string
		expected bool
	}{
		{"empty", "", false},
		{"too short", "abc", false},
		{"starts with number", "1abc", false},
		{"starts with underscore", "_abc", false},
		{"contains special", "abc@123", false},
		{"valid lowercase", "abc123", true},
		{"valid uppercase", "Abc123", true},
		{"valid with underscore", "abc_123", true},
		{"valid max length", "abcd1234abcd1234abcd", true},
		{"too long", "abcd1234abcd1234abcd1", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := v.ValidateUsername(tt.username)
			assert.Equal(t, tt.expected, got)
		})
	}
}

// TestValidator_ValidatePassword checks ValidatePassword with a policy of
// minimum length 8, special character required and digit required. The
// cases also expect both an uppercase and a lowercase letter to be required.
func TestValidator_ValidatePassword(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		password string
		expected bool
	}{
		{"too short", "Abc1!", false},
		{"no number", "Abcdefgh!", false},
		{"no special", "Abcdefgh1", false},
		{"no uppercase", "abcdefgh1!", false},
		{"no lowercase", "ABCDEFGH1!", false},
		{"valid complex", "Abcdef1!", true},
		{"valid longer", "Abcdefgh123!", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := v.ValidatePassword(tt.password)
			assert.Equal(t, tt.expected, got)
		})
	}
}

// TestValidator_ValidateURL checks that ValidateURL requires a scheme and
// accepts http/https URLs with optional path, query and fragment.
func TestValidator_ValidateURL(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		url      string
		expected bool
	}{
		{"empty", "", false},
		{"no scheme", "example.com", false},
		{"http", "http://example.com", true},
		{"https", "https://example.com", true},
		{"with path", "https://example.com/path", true},
		{"with query", "https://example.com?foo=bar", true},
		{"with fragment", "https://example.com#section", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := v.ValidateURL(tt.url)
			assert.Equal(t, tt.expected, got)
		})
	}
}

// TestValidator_ValidateIP checks that ValidateIP accepts both IPv4 and
// IPv6 addresses (full and compressed) and rejects out-of-range octets
// and non-address strings.
func TestValidator_ValidateIP(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		ip       string
		expected bool
	}{
		{"empty", "", false},
		{"invalid", "not-an-ip", false},
		{"IPv4 valid", "192.168.1.1", true},
		{"IPv4 invalid", "192.168.1.256", false},
		{"IPv6 valid", "::1", true},
		{"IPv6 valid full", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", true},
		{"IPv6 compressed", "fe80::1", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := v.ValidateIP(tt.ip)
			assert.Equal(t, tt.expected, got)
		})
	}
}

// TestValidator_ValidateIPv4 checks that ValidateIPv4 accepts only IPv4
// addresses; IPv6 inputs must be rejected.
func TestValidator_ValidateIPv4(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		ip       string
		expected bool
	}{
		{"empty", "", false},
		{"IPv4 valid", "192.168.1.1", true},
		{"IPv4 invalid", "192.168.1.256", false},
		{"IPv6 localhost", "::1", false}, // IPv6 should fail IPv4 validation
		{"IPv6 full", "2001:0db8:85a3::8a2e:0370:7334", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := v.ValidateIPv4(tt.ip)
			assert.Equal(t, tt.expected, got)
		})
	}
}

// TestValidator_ValidateIPv6 checks that ValidateIPv6 accepts only IPv6
// addresses (full and compressed); IPv4 inputs must be rejected.
func TestValidator_ValidateIPv6(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		ip       string
		expected bool
	}{
		{"empty", "", false},
		{"IPv4 valid", "192.168.1.1", false}, // IPv4 should fail IPv6 validation
		{"IPv6 localhost", "::1", true},
		{"IPv6 full", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", true},
		{"IPv6 compressed", "fe80::1", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := v.ValidateIPv6(tt.ip)
			assert.Equal(t, tt.expected, got)
		})
	}
}

// TestValidator_SanitizeSQL checks that SanitizeSQL output never contains
// SQL comment markers ("--", "/*", "*/") for any of the inputs below.
//
// NOTE(review): the expected column records the intended sanitized output
// but is NOT asserted. These values have not been verified against the
// SanitizeSQL implementation (e.g. whitespace left behind after removing a
// block comment). Confirm them before switching to an assert.Equal check.
func TestValidator_SanitizeSQL(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		input    string
		expected string // intended output; currently documentation only
	}{
		{"empty", "", ""},
		{"normal text", "hello world", "hello world"},
		{"quote escape", "'test'", "''test''"},
		{"backslash escape", "\\test", "\\test"},
		{"remove comment", "select; -- comment", "select "},
		{"remove block comment", "select /* comment */ from", "select from"},
		{"remove union", "select union select", "select "},
		{"remove drop", "drop table users", ""},
		{"remove insert", "insert into users", ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := v.SanitizeSQL(tt.input)
			// The output must not contain any dangerous comment patterns.
			assert.NotContains(t, got, "--")
			assert.NotContains(t, got, "/*")
			assert.NotContains(t, got, "*/")
		})
	}
}

// TestValidator_SanitizeXSS checks that SanitizeXSS strips the dangerous
// fragment named by checkNot from each payload. Cases with an empty
// checkNot only verify that the call succeeds.
func TestValidator_SanitizeXSS(t *testing.T) {
	v := NewValidator(8, true, true)

	tests := []struct {
		name     string
		input    string
		checkNot string
	}{
		{"empty", "", ""},
		{"normal text", "hello world", ""},
		{"remove script", "<script>alert('xss')</script>", "script"},
		{"remove iframe", `<iframe src="http://evil.example.com"></iframe>`, "iframe"},
		{"remove javascript", "javascript:alert(1)", "javascript:"},
		{"remove event handler", `<img src="x" onerror="alert(1)">`, "onerror"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.checkNot != "" {
				// Guard against vacuous cases: the payload itself must contain
				// the fragment, otherwise NotContains below proves nothing.
				assert.Contains(t, tt.input, tt.checkNot)
			}

			got := v.SanitizeXSS(tt.input)
			if tt.checkNot != "" {
				assert.NotContains(t, got, tt.checkNot)
			}
		})
	}
}