Files
user-system/internal/auth/oauth_utils_test.go
long-agent 582ad7a069 test: add comprehensive test coverage and improve code quality
- Add new test files for auth, service, and handler modules
- Improve test organization and coverage
- Refactor code for better maintainability
- Add captcha, settings, stats, and theme handler tests
- Add auth module tests (CAS, OAuth, password, SSO, state)
- Add service layer tests for auth, export, permissions, roles
- All Go tests pass (exit code 0)
- All frontend tests pass (325 tests in 59 files)
2026-04-17 20:43:50 +08:00

406 lines
11 KiB
Go

package auth
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)
// TestGenerateState checks that GenerateState succeeds, returns a non-empty
// value, and avoids the '+' and '/' characters of standard base64, which are
// not safe to embed unescaped in a URL.
func TestGenerateState(t *testing.T) {
	token, genErr := GenerateState()
	if genErr != nil {
		t.Fatalf("GenerateState() error = %v", genErr)
	}

	if len(token) == 0 {
		t.Error("GenerateState() returned empty state")
	}

	// '+' and '/' only occur in the standard base64 alphabet; the URL-safe
	// alphabet substitutes '-' and '_'.
	if strings.IndexAny(token, "+/") >= 0 {
		t.Error("GenerateState() should use URL-safe base64 encoding")
	}
}
// TestValidateState checks the single-use contract of ValidateState: a freshly
// generated state validates once, a second validation of the same state fails,
// and a value that was never generated is rejected.
func TestValidateState(t *testing.T) {
	token, genErr := GenerateState()
	if genErr != nil {
		t.Fatalf("GenerateState() error = %v", genErr)
	}

	// The steps run in order: the second step depends on the first having
	// consumed the token.
	steps := []struct {
		input string
		want  bool
		msg   string
	}{
		{token, true, "ValidateState() returned false for valid state"},
		{token, false, "ValidateState() should return false for consumed state"},
		{"invalid-state", false, "ValidateState() returned true for invalid state"},
	}
	for _, step := range steps {
		if ValidateState(step.input) != step.want {
			t.Error(step.msg)
		}
	}
}
// TestValidateState_Expired checks that ValidateState rejects a state whose
// stored timestamp has been back-dated one hour into the past.
func TestValidateState_Expired(t *testing.T) {
	token, genErr := GenerateState()
	if genErr != nil {
		t.Fatalf("GenerateState() error = %v", genErr)
	}

	// Overwrite the stored time under the store's lock so the entry is already
	// past its deadline when ValidateState inspects it.
	func() {
		stateStore.mu.Lock()
		defer stateStore.mu.Unlock()
		stateStore.states[token] = time.Now().Add(-time.Hour)
	}()

	if ValidateState(token) {
		t.Error("ValidateState() should return false for expired state")
	}
}
// TestCleanupStates checks that CleanupStates evicts an entry whose stored
// timestamp is in the past while keeping freshly generated states.
//
// The package-level store is swapped for an empty map so the assertions only
// see entries created here. The original map is put back when the test
// finishes, so other tests in the package do not observe the wipe.
func TestCleanupStates(t *testing.T) {
	stateStore.mu.Lock()
	saved := stateStore.states
	stateStore.states = make(map[string]time.Time)
	stateStore.mu.Unlock()
	t.Cleanup(func() {
		stateStore.mu.Lock()
		stateStore.states = saved
		stateStore.mu.Unlock()
	})

	// Check generation errors: an ignored failure would leave an empty key and
	// surface later as a misleading "removed valid state" failure.
	state1, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState() error = %v", err)
	}
	state2, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState() error = %v", err)
	}

	// Insert an entry whose timestamp is already an hour in the past.
	stateStore.mu.Lock()
	stateStore.states["expired-state"] = time.Now().Add(-time.Hour)
	stateStore.mu.Unlock()

	CleanupStates()

	stateStore.mu.RLock()
	defer stateStore.mu.RUnlock()

	if _, ok := stateStore.states["expired-state"]; ok {
		t.Error("CleanupStates() did not remove expired state")
	}
	if _, ok := stateStore.states[state1]; !ok {
		t.Error("CleanupStates() removed valid state1")
	}
	if _, ok := stateStore.states[state2]; !ok {
		t.Error("CleanupStates() removed valid state2")
	}
}
// TestGet checks that Get issues a GET request to the given URL and hands the
// server's 200 response back to the caller.
func TestGet(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			t.Errorf("Expected GET request, got %s", r.Method)
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK")) // body content is not asserted
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	resp, reqErr := Get(srv.URL)
	if reqErr != nil {
		t.Fatalf("Get() error = %v", reqErr)
	}
	defer resp.Body.Close()

	if got, want := resp.StatusCode, http.StatusOK; got != want {
		t.Errorf("Get() status = %d, want %d", got, want)
	}
}
// TestPostForm checks that PostForm sends a POST request whose form-encoded
// body carries the supplied values, and that the server's 200 status reaches
// the caller.
func TestPostForm(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("Expected POST request, got %s", r.Method)
		}
		// PostFormValue only parses form-encoded request bodies. This check
		// therefore fails both when the value is missing and when the body was
		// sent in some other encoding.
		if got := r.PostFormValue("key"); got != "value" {
			t.Errorf("PostForm() sent key = %q, want %q", got, "value")
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK")) // body content is not asserted
	}))
	defer server.Close()

	data := url.Values{}
	data.Set("key", "value")
	resp, err := PostForm(server.URL, data)
	if err != nil {
		t.Fatalf("PostForm() error = %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("PostForm() status = %d, want %d", resp.StatusCode, http.StatusOK)
	}
}
// TestGetJSON checks that GetJSON decodes a JSON response body into the
// caller-supplied value.
func TestGetJSON(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]string{"message": "hello"})
	})
	srv := httptest.NewServer(handler)
	defer srv.Close()

	var payload struct {
		Message string `json:"message"`
	}
	if err := GetJSON(srv.URL, &payload); err != nil {
		t.Fatalf("GetJSON() error = %v", err)
	}

	if got := payload.Message; got != "hello" {
		t.Errorf("GetJSON() result.Message = %s, want hello", got)
	}
}
// TestGetJSON_NonOKStatus checks that GetJSON returns an error when the server
// answers with 500 Internal Server Error and an empty body.
func TestGetJSON_NonOKStatus(t *testing.T) {
	failing := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})
	srv := httptest.NewServer(failing)
	defer srv.Close()

	var sink struct{}
	if err := GetJSON(srv.URL, &sink); err == nil {
		t.Error("GetJSON() should return error for non-OK status")
	}
}
// TestPostFormJSON checks that PostFormJSON POSTs the supplied form values and
// decodes the JSON response body into the caller-supplied value.
func TestPostFormJSON(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("Expected POST request, got %s", r.Method)
		}
		// PostFormValue only parses form-encoded request bodies. This check
		// therefore fails both when grant_type is missing and when the body was
		// not sent form-encoded.
		if got := r.PostFormValue("grant_type"); got != "authorization_code" {
			t.Errorf("PostFormJSON() sent grant_type = %q, want %q", got, "authorization_code")
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]string{"token": "abc123"})
	}))
	defer server.Close()

	data := url.Values{}
	data.Set("grant_type", "authorization_code")
	var result struct {
		Token string `json:"token"`
	}
	if err := PostFormJSON(server.URL, data, &result); err != nil {
		t.Fatalf("PostFormJSON() error = %v", err)
	}
	if result.Token != "abc123" {
		t.Errorf("PostFormJSON() result.Token = %s, want abc123", result.Token)
	}
}
// TestPostFormJSON_NonOKStatus checks that PostFormJSON returns an error when
// the server answers with 400 Bad Request and an empty body.
func TestPostFormJSON_NonOKStatus(t *testing.T) {
	rejecting := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	})
	srv := httptest.NewServer(rejecting)
	defer srv.Close()

	var sink struct{}
	if err := PostFormJSON(srv.URL, url.Values{}, &sink); err == nil {
		t.Error("PostFormJSON() should return error for non-OK status")
	}
}
// TestBuildAuthURL checks that BuildAuthURL keeps the authorize endpoint's
// scheme, host and path, and adds the client_id, redirect_uri, scope, state
// and response_type=code query parameters with their values intact after URL
// decoding.
func TestBuildAuthURL(t *testing.T) {
	baseURL := "https://example.com/oauth/authorize"
	clientID := "test-client-id"
	redirectURI := "https://myapp.com/callback"
	scope := "openid email"
	state := "random-state"

	result := BuildAuthURL(baseURL, clientID, redirectURI, scope, state)
	u, err := url.Parse(result)
	if err != nil {
		t.Fatalf("BuildAuthURL() produced invalid URL: %v", err)
	}

	if u.Scheme != "https" {
		t.Errorf("BuildAuthURL() scheme = %s, want https", u.Scheme)
	}
	if u.Host != "example.com" {
		t.Errorf("BuildAuthURL() host = %s, want example.com", u.Host)
	}
	// A builder that dropped the endpoint path would still pass the scheme and
	// host checks above, so check the path explicitly.
	if u.Path != "/oauth/authorize" {
		t.Errorf("BuildAuthURL() path = %s, want /oauth/authorize", u.Path)
	}

	q := u.Query()
	params := []struct{ key, want string }{
		{"client_id", clientID},
		{"redirect_uri", redirectURI},
		{"scope", scope},
		{"state", state},
		{"response_type", "code"},
	}
	for _, p := range params {
		if got := q.Get(p.key); got != p.want {
			t.Errorf("BuildAuthURL() %s = %s, want %s", p.key, got, p.want)
		}
	}
}
// TestParseAccessTokenResponse checks that a JSON token response is decoded
// into the AccessToken, RefreshToken, ExpiresIn and TokenType fields.
func TestParseAccessTokenResponse(t *testing.T) {
	raw := []byte(`{
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"token_type": "Bearer"
}`)

	tok, parseErr := ParseAccessTokenResponse(raw)
	if parseErr != nil {
		t.Fatalf("ParseAccessTokenResponse() error = %v", parseErr)
	}

	if got := tok.AccessToken; got != "test-access-token" {
		t.Errorf("AccessToken = %s, want test-access-token", got)
	}
	if got := tok.RefreshToken; got != "test-refresh-token" {
		t.Errorf("RefreshToken = %s, want test-refresh-token", got)
	}
	if got := tok.ExpiresIn; got != 3600 {
		t.Errorf("ExpiresIn = %d, want 3600", got)
	}
	if got := tok.TokenType; got != "Bearer" {
		t.Errorf("TokenType = %s, want Bearer", got)
	}
}
// TestParseAccessTokenResponse_InvalidJSON checks that input which is not
// valid JSON is rejected with an error.
func TestParseAccessTokenResponse_InvalidJSON(t *testing.T) {
	if _, parseErr := ParseAccessTokenResponse([]byte("invalid json")); parseErr == nil {
		t.Error("ParseAccessTokenResponse() should return error for invalid JSON")
	}
}
// TestParseQueryAccessToken checks that the access_token value is extracted
// from a query-string style token response.
func TestParseQueryAccessToken(t *testing.T) {
	payload := "access_token=abc123&token_type=Bearer&expires_in=3600"

	got, parseErr := ParseQueryAccessToken(payload)
	if parseErr != nil {
		t.Fatalf("ParseQueryAccessToken() error = %v", parseErr)
	}
	if got != "abc123" {
		t.Errorf("ParseQueryAccessToken() = %s, want abc123", got)
	}
}
// TestParseQueryAccessToken_NoToken checks that a well-formed query string
// without an access_token key yields an empty token and no error.
func TestParseQueryAccessToken_NoToken(t *testing.T) {
	payload := "token_type=Bearer&expires_in=3600"

	got, parseErr := ParseQueryAccessToken(payload)
	if parseErr != nil {
		t.Fatalf("ParseQueryAccessToken() error = %v", parseErr)
	}
	if len(got) != 0 {
		t.Errorf("ParseQueryAccessToken() = %s, want empty", got)
	}
}
// TestParseQueryAccessToken_InvalidQuery checks that a query string containing
// a malformed percent-escape ("%zz") is reported as an error.
func TestParseQueryAccessToken_InvalidQuery(t *testing.T) {
	if _, parseErr := ParseQueryAccessToken("invalid%zz"); parseErr == nil {
		t.Error("ParseQueryAccessToken() should return error for invalid query string")
	}
}
// TestParseJSONPResponse checks that the JSON object wrapped in a JSONP
// callback is unwrapped and decoded into a map of field values.
func TestParseJSONPResponse(t *testing.T) {
	jsonp := `callback({"access_token":"abc123","expires_in":7200})`
	result, err := ParseJSONPResponse(jsonp)
	if err != nil {
		t.Fatalf("ParseJSONPResponse() error = %v", err)
	}

	if result["access_token"] != "abc123" {
		t.Errorf("ParseJSONPResponse() access_token = %v, want abc123", result["access_token"])
	}
	// Use the comma-ok form so a missing or non-float64 value is reported as a
	// test failure instead of panicking the whole test binary. JSON numbers
	// decoded into an interface value arrive as float64.
	if expiresIn, ok := result["expires_in"].(float64); !ok || expiresIn != 7200 {
		t.Errorf("ParseJSONPResponse() expires_in = %v, want 7200", result["expires_in"])
	}
}
// TestParseJSONPResponse_InvalidFormat checks that ParseJSONPResponse rejects
// input lacking the callback parentheses, as well as a callback whose payload
// is not valid JSON.
func TestParseJSONPResponse_InvalidFormat(t *testing.T) {
	cases := []struct {
		label string
		input string
	}{
		{label: "no parentheses", input: "invalid"},
		{label: "no opening", input: "invalid)"},
		{label: "no closing", input: "invalid("},
		{label: "invalid JSON", input: "callback(invalid json)"},
	}

	for _, tc := range cases {
		// Subtests run synchronously (no t.Parallel), so capturing tc is safe.
		t.Run(tc.label, func(t *testing.T) {
			if _, parseErr := ParseJSONPResponse(tc.input); parseErr == nil {
				t.Errorf("ParseJSONPResponse() should return error for %s", tc.label)
			}
		})
	}
}
// TestToOAuth2Config checks that ToOAuth2Config copies credentials, redirect
// URL and endpoints across. It also checks that the comma-separated Scope
// string is split into individual scopes in their original order.
func TestToOAuth2Config(t *testing.T) {
	config := &OAuthConfig{
		ClientID:     "test-client-id",
		ClientSecret: "test-client-secret",
		RedirectURI:  "https://myapp.com/callback",
		Scope:        "openid,email,profile",
		AuthURL:      "https://example.com/oauth/authorize",
		TokenURL:     "https://example.com/oauth/token",
	}
	oauth2Config := ToOAuth2Config(config)

	if oauth2Config.ClientID != config.ClientID {
		t.Errorf("ClientID = %s, want %s", oauth2Config.ClientID, config.ClientID)
	}
	if oauth2Config.ClientSecret != config.ClientSecret {
		t.Errorf("ClientSecret = %s, want %s", oauth2Config.ClientSecret, config.ClientSecret)
	}
	if oauth2Config.RedirectURL != config.RedirectURI {
		t.Errorf("RedirectURL = %s, want %s", oauth2Config.RedirectURL, config.RedirectURI)
	}

	// Checking only the length would accept wrong or reordered scope values.
	wantScopes := []string{"openid", "email", "profile"}
	if got := oauth2Config.Scopes; len(got) != len(wantScopes) {
		t.Errorf("Scopes length = %d, want %d", len(got), len(wantScopes))
	} else {
		for i, want := range wantScopes {
			if got[i] != want {
				t.Errorf("Scopes[%d] = %q, want %q", i, got[i], want)
			}
		}
	}

	if oauth2Config.Endpoint.AuthURL != config.AuthURL {
		t.Errorf("AuthURL = %s, want %s", oauth2Config.Endpoint.AuthURL, config.AuthURL)
	}
	if oauth2Config.Endpoint.TokenURL != config.TokenURL {
		t.Errorf("TokenURL = %s, want %s", oauth2Config.Endpoint.TokenURL, config.TokenURL)
	}
}
// TestGetJSON_ConnectionError checks that GetJSON returns an error when no
// server is listening at the target address.
func TestGetJSON_ConnectionError(t *testing.T) {
	// Start a server and close it immediately, so its loopback URL has no
	// listener and the connection attempt is refused. A hard-coded address
	// such as 127.0.0.1:1 is a valid port that is merely unused on most
	// machines.
	server := httptest.NewServer(http.NotFoundHandler())
	deadURL := server.URL
	server.Close()

	var result struct{}
	if err := GetJSON(deadURL, &result); err == nil {
		t.Error("GetJSON() should return error for connection failure")
	}
}
// TestPostFormJSON_ConnectionError checks that PostFormJSON returns an error
// when no server is listening at the target address.
func TestPostFormJSON_ConnectionError(t *testing.T) {
	// Start a server and close it immediately, so its loopback URL has no
	// listener and the connection attempt is refused. A hard-coded address
	// such as 127.0.0.1:1 is a valid port that is merely unused on most
	// machines.
	server := httptest.NewServer(http.NotFoundHandler())
	deadURL := server.URL
	server.Close()

	var result struct{}
	if err := PostFormJSON(deadURL, url.Values{}, &result); err == nil {
		t.Error("PostFormJSON() should return error for connection failure")
	}
}