diff --git a/src/test/java/com/mosquito/project/web/ApiKeyAuthInterceptorTest.java b/src/test/java/com/mosquito/project/web/ApiKeyAuthInterceptorTest.java new file mode 100644 index 0000000..684a927 --- /dev/null +++ b/src/test/java/com/mosquito/project/web/ApiKeyAuthInterceptorTest.java @@ -0,0 +1,232 @@ +package com.mosquito.project.web; + +import com.mosquito.project.persistence.entity.ApiKeyEntity; +import com.mosquito.project.persistence.repository.ApiKeyRepository; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.time.OffsetDateTime; +import java.util.Base64; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +@DisplayName("ApiKeyAuthInterceptor 测试") +class ApiKeyAuthInterceptorTest { + + @Mock + private ApiKeyRepository apiKeyRepository; + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private Object handler; + + private ApiKeyAuthInterceptor interceptor; + + @BeforeEach + void setUp() { + interceptor = new ApiKeyAuthInterceptor(apiKeyRepository); + } + + @Test + @DisplayName("应该拒绝null API Key") + void shouldRejectNullApiKey() { + // Given + when(request.getHeader("X-API-Key")).thenReturn(null); + + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isFalse(); + verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); + verify(apiKeyRepository, never()).findByKeyPrefix(anyString()); + } + + @Test + @DisplayName("应该拒绝空白API Key") + void shouldRejectBlankApiKey() { + // Given + when(request.getHeader("X-API-Key")).thenReturn(" "); + + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isFalse(); + verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + @DisplayName("应该拒绝不存在的API Key前缀") + void shouldRejectNonExistentKeyPrefix() { + // Given + when(request.getHeader("X-API-Key")).thenReturn("test-api-key-12345"); + when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.empty()); + + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isFalse(); + verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + @DisplayName("应该拒绝已吊销的API Key") + void shouldRejectRevokedApiKey() { + // Given + String apiKey = "test-api-key-12345"; + ApiKeyEntity entity = new ApiKeyEntity(); + entity.setRevokedAt(OffsetDateTime.now()); + + when(request.getHeader("X-API-Key")).thenReturn(apiKey); + when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.of(entity)); + + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isFalse(); + verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + @DisplayName("应该拒绝哈希不匹配的API Key") + void shouldRejectMismatchedApiKeyHash() throws Exception { + // Given + String apiKey = "test-api-key-12345"; + byte[] salt = new byte[16]; + String saltBase64 = Base64.getEncoder().encodeToString(salt); + + ApiKeyEntity entity = new ApiKeyEntity(); + entity.setSalt(saltBase64); + entity.setKeyHash("wrong-hash"); + entity.setRevokedAt(null); + + when(request.getHeader("X-API-Key")).thenReturn(apiKey); + when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.of(entity)); + + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isFalse(); + verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + @DisplayName("应该接受有效的API Key") + void shouldAcceptValidApiKey() throws Exception { + // Given + String apiKey = "test-api-key-12345"; + byte[] salt = new byte[16]; + String saltBase64 = Base64.getEncoder().encodeToString(salt); + + // 计算正确的哈希 + javax.crypto.SecretKeyFactory skf = javax.crypto.SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256"); + javax.crypto.spec.PBEKeySpec spec = new javax.crypto.spec.PBEKeySpec(apiKey.toCharArray(), salt, 185000, 256); + byte[] derived = skf.generateSecret(spec).getEncoded(); + String correctHash = Base64.getEncoder().encodeToString(derived); + + ApiKeyEntity entity = new ApiKeyEntity(); + entity.setSalt(saltBase64); + entity.setKeyHash(correctHash); + entity.setRevokedAt(null); + + when(request.getHeader("X-API-Key")).thenReturn(apiKey); + when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.of(entity)); + + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isTrue(); + verify(response, never()).setStatus(anyInt()); + verify(request).setAttribute(eq("apiKeyPrefix"), anyString()); + } + + @Test + @DisplayName("应该处理短API Key") + void shouldHandleShortApiKey() { + // Given + String shortKey = "short"; + when(request.getHeader("X-API-Key")).thenReturn(shortKey); + when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.empty()); + + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isFalse(); + verify(apiKeyRepository).findByKeyPrefix("short"); + } + + @Test + @DisplayName("应该处理加密异常") + void shouldHandleCryptoException() { + // Given + String apiKey = "test-api-key-12345"; + + ApiKeyEntity entity = new ApiKeyEntity(); + entity.setSalt("invalid-base64!!!"); // 无效的Base64会导致异常 + entity.setKeyHash("some-hash"); + entity.setRevokedAt(null); + + when(request.getHeader("X-API-Key")).thenReturn(apiKey); + when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.of(entity)); + + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isFalse(); + verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + @DisplayName("应该正确提取API Key前缀") + void shouldExtractCorrectPrefix() { + // Given + String apiKey = "abcdefghijklmnopqrstuvwxyz"; + when(request.getHeader("X-API-Key")).thenReturn(apiKey); + when(apiKeyRepository.findByKeyPrefix("abcdefghijkl")).thenReturn(Optional.empty()); + + // When + interceptor.preHandle(request, response, handler); + + // Then + verify(apiKeyRepository).findByKeyPrefix("abcdefghijkl"); + } + + @Test + @DisplayName("应该处理带空格的API Key") + void shouldHandleApiKeyWithSpaces() { + // Given + String apiKey = " test-key-123 "; + when(request.getHeader("X-API-Key")).thenReturn(apiKey); + when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.empty()); + + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isFalse(); + // 前缀应该被trim + verify(apiKeyRepository).findByKeyPrefix(contains("test-key")); + } +} diff --git a/src/test/java/com/mosquito/project/web/ApiResponseWrapperInterceptorTest.java b/src/test/java/com/mosquito/project/web/ApiResponseWrapperInterceptorTest.java new file mode 100644 index 0000000..c584d72 --- /dev/null +++ b/src/test/java/com/mosquito/project/web/ApiResponseWrapperInterceptorTest.java @@ -0,0 +1,203 @@ +package com.mosquito.project.web; + +import com.mosquito.project.config.ApiVersion; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.web.servlet.ModelAndView; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +@DisplayName("ApiResponseWrapperInterceptor 测试") +class ApiResponseWrapperInterceptorTest { + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private Object handler; + + @Mock + private ModelAndView modelAndView; + + private ApiResponseWrapperInterceptor interceptor; + + @BeforeEach + void setUp() { + interceptor = new ApiResponseWrapperInterceptor(); + } + + @Test + @DisplayName("preHandle应该设置startTime属性并返回true") + void shouldSetStartTimeAndReturnTrue_whenPreHandle() { + // When + boolean result = interceptor.preHandle(request, response, handler); + + // Then + assertThat(result).isTrue(); + verify(request).setAttribute(eq("startTime"), anyLong()); + } + + @Test + @DisplayName("postHandle应该为成功响应设置API版本头") + void shouldSetApiVersionHeader_whenResponseIsSuccessful() { + // Given + when(response.getStatus()).thenReturn(200); + when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn("v1"); + + // When + interceptor.postHandle(request, response, handler, modelAndView); + + // Then + verify(response).setHeader(ApiVersion.HEADER_NAME, "v1"); + } + + @Test + @DisplayName("postHandle应该使用默认版本当请求头为null") + void shouldUseDefaultVersion_whenRequestHeaderIsNull() { + // Given + when(response.getStatus()).thenReturn(200); + when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn(null); + + // When + interceptor.postHandle(request, response, handler, modelAndView); + + // Then + verify(response).setHeader(ApiVersion.HEADER_NAME, ApiVersion.DEFAULT_VERSION); + } + + @Test + @DisplayName("postHandle应该使用默认版本当请求头为空白") + void shouldUseDefaultVersion_whenRequestHeaderIsBlank() { + // Given + when(response.getStatus()).thenReturn(200); + when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn(" "); + + // When + interceptor.postHandle(request, response, handler, modelAndView); + + // Then + verify(response).setHeader(ApiVersion.HEADER_NAME, ApiVersion.DEFAULT_VERSION); + } + + @Test + @DisplayName("postHandle不应该为错误响应设置版本头") + void shouldNotSetVersionHeader_whenResponseIsError() { + // Given + when(response.getStatus()).thenReturn(400); + + // When + interceptor.postHandle(request, response, handler, modelAndView); + + // Then + verify(response, never()).setHeader(anyString(), anyString()); + } + + @Test + @DisplayName("postHandle不应该为服务器错误设置版本头") + void shouldNotSetVersionHeader_whenResponseIsServerError() { + // Given + when(response.getStatus()).thenReturn(500); + + // When + interceptor.postHandle(request, response, handler, modelAndView); + + // Then + verify(response, never()).setHeader(anyString(), anyString()); + } + + @Test + @DisplayName("afterCompletion应该记录API请求日志") + void shouldLogApiRequest_whenAfterCompletion() { + // Given + when(request.getAttribute("startTime")).thenReturn(System.currentTimeMillis() - 100); + when(request.getRequestURI()).thenReturn("/api/v1/activities"); + when(request.getMethod()).thenReturn("GET"); + + // When + interceptor.afterCompletion(request, response, handler, null); + + // Then - 验证没有抛出异常 + verify(request).getAttribute("startTime"); + verify(request, atLeastOnce()).getRequestURI(); + } + + @Test + @DisplayName("afterCompletion应该处理非API请求") + void shouldHandleNonApiRequest_whenAfterCompletion() { + // Given + when(request.getAttribute("startTime")).thenReturn(System.currentTimeMillis()); + when(request.getRequestURI()).thenReturn("/health"); + + // When + interceptor.afterCompletion(request, response, handler, null); + + // Then - 验证没有抛出异常 + verify(request).getAttribute("startTime"); + verify(request).getRequestURI(); + } + + @Test + @DisplayName("afterCompletion应该处理异常情况") + void shouldHandleException_whenAfterCompletion() { + // Given + when(request.getAttribute("startTime")).thenReturn(System.currentTimeMillis()); + when(request.getRequestURI()).thenReturn("/api/v1/test"); + when(request.getMethod()).thenReturn("POST"); + Exception exception = new RuntimeException("Test exception"); + + // When + interceptor.afterCompletion(request, response, handler, exception); + + // Then - 验证没有抛出异常 + verify(request).getAttribute("startTime"); + } + + @Test + @DisplayName("postHandle应该处理2xx范围内的所有成功状态码") + void shouldHandleAllSuccessStatusCodes_whenPostHandle() { + // Test 200 OK + when(response.getStatus()).thenReturn(200); + when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn("v1"); + interceptor.postHandle(request, response, handler, modelAndView); + verify(response).setHeader(ApiVersion.HEADER_NAME, "v1"); + + // Test 201 Created + reset(response, request); + when(response.getStatus()).thenReturn(201); + when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn("v1"); + interceptor.postHandle(request, response, handler, modelAndView); + verify(response).setHeader(ApiVersion.HEADER_NAME, "v1"); + + // Test 204 No Content + reset(response, request); + when(response.getStatus()).thenReturn(204); + when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn("v1"); + interceptor.postHandle(request, response, handler, modelAndView); + verify(response).setHeader(ApiVersion.HEADER_NAME, "v1"); + } + + @Test + @DisplayName("postHandle应该拒绝3xx重定向状态码") + void shouldNotSetHeaderForRedirectStatus_whenPostHandle() { + // Given + when(response.getStatus()).thenReturn(302); + + // When + interceptor.postHandle(request, response, handler, modelAndView); + + // Then + verify(response, never()).setHeader(anyString(), anyString()); + } +} diff --git a/src/test/java/com/mosquito/project/web/UrlValidatorTest.java b/src/test/java/com/mosquito/project/web/UrlValidatorTest.java new file mode 100644 index 0000000..a1af015 --- /dev/null +++ b/src/test/java/com/mosquito/project/web/UrlValidatorTest.java @@ -0,0 +1,164 @@ +package com.mosquito.project.web; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.assertj.core.api.Assertions.assertThat; + +@DisplayName("UrlValidator 测试") +class UrlValidatorTest { + + private UrlValidator urlValidator; + + @BeforeEach + void setUp() { + urlValidator = new UrlValidator(); + } + + @Test + @DisplayName("应该拒绝null URL") + void shouldRejectNullUrl() { + assertThat(urlValidator.isAllowedUrl(null)).isFalse(); + } + + @Test + @DisplayName("应该拒绝空白URL") + void shouldRejectBlankUrl() { + assertThat(urlValidator.isAllowedUrl("")).isFalse(); + assertThat(urlValidator.isAllowedUrl(" ")).isFalse(); + } + + @Test + @DisplayName("应该拒绝相对URL") + void shouldRejectRelativeUrl() { + assertThat(urlValidator.isAllowedUrl("/path/to/resource")).isFalse(); + assertThat(urlValidator.isAllowedUrl("path/to/resource")).isFalse(); + } + + @ParameterizedTest + @ValueSource(strings = { + "ftp://example.com", + "file:///etc/passwd", + "javascript:alert(1)", + "data:text/html," + }) + @DisplayName("应该拒绝不允许的协议") + void shouldRejectDisallowedSchemes(String url) { + assertThat(urlValidator.isAllowedUrl(url)).isFalse(); + } + + @ParameterizedTest + @ValueSource(strings = { + "http://localhost", + "http://127.0.0.1", + "http://::1", + "http://0.0.0.0", + "https://localhost:8080" + }) + @DisplayName("应该拒绝localhost地址") + void shouldRejectLocalhostAddresses(String url) { + assertThat(urlValidator.isAllowedUrl(url)).isFalse(); + } + + @ParameterizedTest + @ValueSource(strings = { + "http://10.0.0.1", + "http://10.255.255.255", + "http://172.16.0.1", + "http://172.31.255.255", + "http://192.168.0.1", + "http://192.168.255.255" + }) + @DisplayName("应该拒绝私有IP地址") + void shouldRejectPrivateIpAddresses(String url) { + assertThat(urlValidator.isAllowedUrl(url)).isFalse(); + } + + @Test + @DisplayName("应该接受有效的公网URL - google.com") + void shouldAcceptValidPublicUrls() { + // 使用真实存在的公网域名进行测试 + assertThat(urlValidator.isAllowedUrl("https://www.google.com")).isTrue(); + assertThat(urlValidator.isAllowedUrl("https://github.com")).isTrue(); + assertThat(urlValidator.isAllowedUrl("http://www.baidu.com")).isTrue(); + } + + @Test + @DisplayName("应该拒绝无效的URL语法") + void shouldRejectInvalidUrlSyntax() { + assertThat(urlValidator.isAllowedUrl("not a url")).isFalse(); + assertThat(urlValidator.isAllowedUrl("http://")).isFalse(); + assertThat(urlValidator.isAllowedUrl("://example.com")).isFalse(); + } + + @Test + @DisplayName("sanitizeUrl应该返回有效URL的字符串形式") + void shouldSanitizeValidUrl() { + String url = "https://example.com/path"; + String sanitized = urlValidator.sanitizeUrl(url); + assertThat(sanitized).isNotNull(); + assertThat(sanitized).contains("example.com"); + } + + @Test + @DisplayName("sanitizeUrl应该对无效URL返回null") + void shouldReturnNullForInvalidUrl() { + assertThat(urlValidator.sanitizeUrl(null)).isNull(); + assertThat(urlValidator.sanitizeUrl("")).isNull(); + assertThat(urlValidator.sanitizeUrl("http://localhost")).isNull(); + assertThat(urlValidator.sanitizeUrl("not a url")).isNull(); + } + + @Test + @DisplayName("应该处理URL中的大小写") + void shouldHandleUrlCaseInsensitivity() { + assertThat(urlValidator.isAllowedUrl("HTTP://EXAMPLE.COM")).isTrue(); + assertThat(urlValidator.isAllowedUrl("HTTPS://EXAMPLE.COM")).isTrue(); + } + + @Test + @DisplayName("应该拒绝空主机名") + void shouldRejectEmptyHost() { + assertThat(urlValidator.isAllowedUrl("http://")).isFalse(); + } + + @ParameterizedTest + @ValueSource(strings = { + "http://169.254.0.1", // Link-local + "http://224.0.0.1" // Multicast + }) + @DisplayName("应该拒绝特殊用途的IP地址") + void shouldRejectSpecialPurposeIpAddresses(String url) { + assertThat(urlValidator.isAllowedUrl(url)).isFalse(); + } + + @Test + @DisplayName("应该处理带端口的URL") + void shouldHandleUrlsWithPorts() { + assertThat(urlValidator.isAllowedUrl("https://example.com:443")).isTrue(); + assertThat(urlValidator.isAllowedUrl("http://example.com:80")).isTrue(); + assertThat(urlValidator.isAllowedUrl("https://example.com:8443")).isTrue(); + } + + @Test + @DisplayName("应该处理带查询参数的URL") + void shouldHandleUrlsWithQueryParameters() { + assertThat(urlValidator.isAllowedUrl("https://example.com/path?key=value&foo=bar")).isTrue(); + } + + @Test + @DisplayName("应该处理带片段的URL") + void shouldHandleUrlsWithFragments() { + assertThat(urlValidator.isAllowedUrl("https://example.com/path#section")).isTrue(); + } + + @Test + @DisplayName("应该拒绝IPv6 loopback地址") + void shouldRejectIpv6LoopbackAddress() { + assertThat(urlValidator.isAllowedUrl("http://[::1]")).isFalse(); + assertThat(urlValidator.isAllowedUrl("http://[0:0:0:0:0:0:0:1]")).isFalse(); + } +}