Files
user-system/internal/auth/totp_test.go

207 lines
5.4 KiB
Go
Raw Normal View History

package auth
import (
"strings"
"testing"
)
// TestTOTPManager_GenerateAndValidate checks that a freshly generated TOTP
// setup is fully populated (secret, QR code, RecoveryCodeCount recovery codes)
// and that the code GenerateCurrentCode derives from that secret is accepted
// by ValidateCode.
func TestTOTPManager_GenerateAndValidate(t *testing.T) {
	mgr := NewTOTPManager()

	// Provision a new secret for a test account.
	setup, err := mgr.GenerateSecret("testuser@example.com")
	if err != nil {
		t.Fatalf("GenerateSecret 失败: %v", err)
	}

	// Every case is fatal, so the first missing/incorrect field ends the test.
	switch {
	case setup.Secret == "":
		t.Fatal("生成的 Secret 不应为空")
	case setup.QRCodeBase64 == "":
		t.Fatal("QRCode Base64 不应为空")
	case len(setup.RecoveryCodes) != RecoveryCodeCount:
		t.Fatalf("恢复码数量期望 %d实际 %d", RecoveryCodeCount, len(setup.RecoveryCodes))
	}
	t.Logf("生成 Secret: %s", setup.Secret)
	t.Logf("恢复码示例: %s", setup.RecoveryCodes[0])

	// Round trip: derive the current code from the secret, then validate it.
	current, err := mgr.GenerateCurrentCode(setup.Secret)
	if err != nil {
		t.Fatalf("GenerateCurrentCode 失败: %v", err)
	}
	if valid := mgr.ValidateCode(setup.Secret, current); !valid {
		t.Fatalf("有效 TOTP 码应该通过验证code=%s", current)
	}
	t.Logf("TOTP 验证通过code=%s", current)
}
// TestTOTPManager_InvalidCode checks that a fixed, arbitrary code is rejected
// for a freshly generated secret. A fixed code can coincide with the real
// current code by chance, so an accepted code skips the test instead of
// failing it.
func TestTOTPManager_InvalidCode(t *testing.T) {
	const bogus = "000000"

	mgr := NewTOTPManager()
	setup, err := mgr.GenerateSecret("user")
	if err != nil {
		t.Fatalf("GenerateSecret 失败: %v", err)
	}

	accepted := mgr.ValidateCode(setup.Secret, bogus)
	if !accepted {
		t.Log("无效验证码正确拒绝")
		return
	}
	// Unlucky collision with the genuine code: skip rather than report failure.
	t.Skip("000000 碰巧是有效码,跳过测试")
}
// TestTOTPManager_RecoveryCodeFormat checks that every generated recovery code
// has the form XXXXX-XXXXX: exactly two hyphen-separated groups of five bytes
// each.
//
// Each malformed code is reported with t.Errorf so that all offending codes
// appear in a single run.
func TestTOTPManager_RecoveryCodeFormat(t *testing.T) {
	m := NewTOTPManager()
	setup, err := m.GenerateSecret("user2")
	if err != nil {
		t.Fatalf("GenerateSecret 失败: %v", err)
	}
	// Without this guard an empty slice would make the format check pass vacuously.
	if len(setup.RecoveryCodes) == 0 {
		t.Fatal("GenerateSecret returned no recovery codes")
	}
	for i, code := range setup.RecoveryCodes {
		parts := strings.Split(code, "-")
		if len(parts) != 2 {
			t.Errorf("恢复码 [%d] 格式错误(期望 XXXXX-XXXXX): %s", i, code)
			// t.Errorf does not stop the test, and parts may hold a single
			// element here, so indexing parts[1] below would panic.
			continue
		}
		if len(parts[0]) != 5 || len(parts[1]) != 5 {
			t.Errorf("恢复码 [%d] 各部分长度应为 5: %s", i, code)
		}
	}
}
// TestValidateRecoveryCode checks plaintext recovery-code matching against a
// stored list. It covers an exact match, a case-insensitive match, trimming of
// surrounding spaces, and rejection of a code that is not in the list. The
// first failing case stops the test.
func TestValidateRecoveryCode(t *testing.T) {
	stored := []string{"ABCDE-FGHIJ", "KLMNO-PQRST", "UVWXY-ZABCD"}

	positive := []struct {
		input   string
		wantIdx int
		failFmt string
	}{
		// Exact match.
		{"ABCDE-FGHIJ", 0, "有效恢复码应该匹配idx=%d ok=%v"},
		// Case-insensitive match.
		{"klmno-pqrst", 1, "大小写不敏感匹配失败idx=%d ok=%v"},
		// Leading/trailing spaces are ignored.
		{" UVWXY-ZABCD ", 2, "去除空格匹配失败idx=%d ok=%v"},
	}
	for _, tc := range positive {
		idx, ok := ValidateRecoveryCode(tc.input, stored)
		if !ok || idx != tc.wantIdx {
			t.Fatalf(tc.failFmt, idx, ok)
		}
	}

	// A code absent from the list must be rejected.
	if _, ok := ValidateRecoveryCode("XXXXX-YYYYY", stored); ok {
		t.Fatal("无效恢复码不应该匹配")
	}
	t.Log("恢复码验证全部通过")
}
// TestHashRecoveryCode checks three properties of HashRecoveryCode:
//   - it returns a non-empty result;
//   - it is deterministic (hashing the same input twice gives the same value);
//   - two different inputs hash to different values.
func TestHashRecoveryCode(t *testing.T) {
	const code = "ABCDE-FGHIJ"

	// mustHash hashes input and aborts the test with failFmt if hashing errors.
	mustHash := func(input, failFmt string) string {
		t.Helper()
		digest, err := HashRecoveryCode(input)
		if err != nil {
			t.Fatalf(failFmt, err)
		}
		return digest
	}

	first := mustHash(code, "HashRecoveryCode failed: %v")
	if first == "" {
		t.Fatal("HashRecoveryCode should return non-empty hash")
	}

	// The test requires deterministic output for identical input.
	if again := mustHash(code, "HashRecoveryCode second call failed: %v"); again != first {
		t.Error("Same code should produce same hash")
	}

	// Distinct inputs must not collide.
	if other := mustHash("DIFFERENT-CODE", "HashRecoveryCode for different code failed: %v"); other == first {
		t.Error("Different codes should produce different hashes")
	}
	t.Logf("Hashed code: %s", first)
}
// TestVerifyRecoveryCode checks matching of a plaintext recovery code against
// a list of values produced by HashRecoveryCode:
//   - each stored code is found at its own index;
//   - an unknown code is rejected;
//   - an empty hash list never matches.
//
// The first failure stops the test.
func TestVerifyRecoveryCode(t *testing.T) {
	plain := []string{"ABCDE-FGHIJ", "KLMNO-PQRST", "UVWXY-ZABCD"}

	// Build the stored (hashed) list in the same order as plain.
	hashed := make([]string, 0, len(plain))
	for _, p := range plain {
		h, err := HashRecoveryCode(p)
		if err != nil {
			t.Fatalf("HashRecoveryCode failed: %v", err)
		}
		hashed = append(hashed, h)
	}

	// Failure message for each expected index, in lookup order.
	failFmts := []string{
		"Valid recovery code should match, idx=%d ok=%v",
		"Second code match failed, idx=%d ok=%v",
		"Third code match failed, idx=%d ok=%v",
	}
	for want, p := range plain {
		idx, ok := VerifyRecoveryCode(p, hashed)
		if !ok || idx != want {
			t.Fatalf(failFmts[want], idx, ok)
		}
	}

	// A code that was never hashed must not match.
	if _, ok := VerifyRecoveryCode("XXXXX-YYYYY", hashed); ok {
		t.Fatal("Invalid recovery code should not match")
	}

	// Nothing can match against an empty list.
	if _, ok := VerifyRecoveryCode("ABCDE-FGHIJ", []string{}); ok {
		t.Fatal("Should not match against empty list")
	}
	t.Log("VerifyRecoveryCode tests passed")
}
// TestVerifyRecoveryCode_TimingSafety checks that VerifyRecoveryCode reports
// the correct index when the match sits at the start, middle, or end of the
// hash list, and that a code absent from the list is rejected.
//
// NOTE: despite its name, this test does not measure execution time. It
// cannot prove that the comparison is constant-time; it only checks that the
// reported index is correct for every match position. A real timing check
// would need a statistical benchmark and is out of scope for a unit test.
func TestVerifyRecoveryCode_TimingSafety(t *testing.T) {
	codes := []string{"CODE1-AAAAA", "CODE2-BBBBB", "CODE3-CCCCC"}
	hashedCodes := make([]string, len(codes))
	for i, code := range codes {
		hashed, err := HashRecoveryCode(code)
		if err != nil {
			// Ignoring this would leave "" in the list, and the lookups below
			// would then fail for a misleading reason.
			t.Fatalf("HashRecoveryCode failed: %v", err)
		}
		hashedCodes[i] = hashed
	}

	// Match at the start of the list.
	idx1, ok1 := VerifyRecoveryCode("CODE1-AAAAA", hashedCodes)
	if !ok1 || idx1 != 0 {
		t.Errorf("First code match failed, idx=%d ok=%v", idx1, ok1)
	}

	// Match in the middle of the list.
	idx2, ok2 := VerifyRecoveryCode("CODE2-BBBBB", hashedCodes)
	if !ok2 || idx2 != 1 {
		t.Errorf("Middle code match failed, idx=%d ok=%v", idx2, ok2)
	}

	// Match at the end of the list.
	idx3, ok3 := VerifyRecoveryCode("CODE3-CCCCC", hashedCodes)
	if !ok3 || idx3 != 2 {
		t.Errorf("Last code match failed, idx=%d ok=%v", idx3, ok3)
	}

	// A code that is not in the list must be rejected.
	if _, ok := VerifyRecoveryCode("CODE4-DDDDD", hashedCodes); ok {
		t.Error("Non-matching code should not match")
	}
	t.Log("Timing safety test passed")
}