Files
user-system/internal/api/middleware/cors_test.go
Your Name e5da23cea2 test: add CORS middleware tests
Add tests for CORS functionality:
- validateCORSConfig (valid and invalid configs)
- SetCORSConfig (update and validation)
- resolveAllowedOrigin (exact match, wildcard, case insensitive)
- CORS middleware (allow/forbid origins, OPTIONS handling)

Coverage: middleware 36.4% → 37.4%
2026-05-29 21:06:43 +08:00

216 lines
5.3 KiB
Go

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/user-management-system/internal/config"
)
// TestValidateCORSConfig feeds a table of configurations to
// validateCORSConfig and checks, per case, whether an error is reported.
func TestValidateCORSConfig(t *testing.T) {
	cases := []struct {
		name    string
		cfg     config.CORSConfig
		wantErr bool
	}{
		{
			name: "valid config with specific origins",
			cfg:  config.CORSConfig{AllowedOrigins: []string{"https://example.com"}, AllowCredentials: true},
		},
		{
			name: "valid config with wildcard no credentials",
			cfg:  config.CORSConfig{AllowedOrigins: []string{"*"}, AllowCredentials: false},
		},
		{
			// The only case in this table that is expected to be rejected.
			name:    "invalid config with wildcard and credentials",
			cfg:     config.CORSConfig{AllowedOrigins: []string{"*"}, AllowCredentials: true},
			wantErr: true,
		},
		{
			name: "empty origins",
			cfg:  config.CORSConfig{AllowedOrigins: []string{}, AllowCredentials: false},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateCORSConfig(tc.cfg)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestSetCORSConfig checks that SetCORSConfig accepts a valid configuration
// and leaves it in the package-level corsConfig, and that it returns an error
// for a wildcard origin combined with credentials.
func TestSetCORSConfig(t *testing.T) {
	// Put the package-level configuration back once the test has finished.
	saved := corsConfig
	t.Cleanup(func() { corsConfig = saved })

	t.Run("valid config", func(t *testing.T) {
		want := config.CORSConfig{
			AllowedOrigins:   []string{"https://example.com"},
			AllowCredentials: true,
		}

		assert.NoError(t, SetCORSConfig(want))
		assert.Equal(t, want, corsConfig)
	})

	t.Run("invalid config", func(t *testing.T) {
		err := SetCORSConfig(config.CORSConfig{
			AllowedOrigins:   []string{"*"},
			AllowCredentials: true,
		})

		assert.Error(t, err)
	})
}
// TestResolveAllowedOrigin is a table-driven test for resolveAllowedOrigin.
// Each case supplies a request origin, an allow-list and the credentials
// flag, and states the expected returned origin and allowed verdict.
func TestResolveAllowedOrigin(t *testing.T) {
	cases := []struct {
		name             string
		origin           string
		allowedOrigins   []string
		allowCredentials bool
		wantOrigin       string
		wantAllowed      bool
	}{
		{
			name:             "exact match",
			origin:           "https://example.com",
			allowedOrigins:   []string{"https://example.com"},
			allowCredentials: true,
			wantOrigin:       "https://example.com",
			wantAllowed:      true,
		},
		{
			// Without credentials the literal wildcard is expected back.
			name:             "wildcard without credentials",
			origin:           "https://any.com",
			allowedOrigins:   []string{"*"},
			allowCredentials: false,
			wantOrigin:       "*",
			wantAllowed:      true,
		},
		{
			// With credentials the concrete request origin is expected back.
			name:             "wildcard with credentials returns origin",
			origin:           "https://any.com",
			allowedOrigins:   []string{"*"},
			allowCredentials: true,
			wantOrigin:       "https://any.com",
			wantAllowed:      true,
		},
		{
			name:             "no match",
			origin:           "https://evil.com",
			allowedOrigins:   []string{"https://example.com"},
			allowCredentials: false,
			wantOrigin:       "",
			wantAllowed:      false,
		},
		{
			// The expected origin keeps the casing of the request origin.
			name:             "case insensitive match",
			origin:           "HTTPS://EXAMPLE.COM",
			allowedOrigins:   []string{"https://example.com"},
			allowCredentials: false,
			wantOrigin:       "HTTPS://EXAMPLE.COM",
			wantAllowed:      true,
		},
		{
			name:             "empty origins list",
			origin:           "https://example.com",
			allowedOrigins:   []string{},
			allowCredentials: false,
			wantOrigin:       "",
			wantAllowed:      false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resolved, allowed := resolveAllowedOrigin(tc.origin, tc.allowedOrigins, tc.allowCredentials)

			assert.Equal(t, tc.wantOrigin, resolved)
			assert.Equal(t, tc.wantAllowed, allowed)
		})
	}
}
// TestCORS runs requests through a gin router that has the CORS middleware
// installed and checks the status code and CORS response headers for an
// allowed origin, a disallowed origin, an OPTIONS request, and a request
// that carries no Origin header.
func TestCORS(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Install a test configuration in the package-level corsConfig and put the
	// previous value back when the test returns.
	originalConfig := corsConfig
	defer func() { corsConfig = originalConfig }()
	corsConfig = config.CORSConfig{
		AllowedOrigins:   []string{"https://example.com"},
		AllowCredentials: true,
	}

	router := gin.New()
	router.Use(CORS())
	router.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "OK")
	})

	// serve builds a request for /test with the given method, sets the Origin
	// header when origin is non-empty, runs it through the router and returns
	// the recorder. A request construction error fails the calling subtest
	// instead of being silently discarded.
	serve := func(t *testing.T, method, origin string) *httptest.ResponseRecorder {
		t.Helper()

		req, err := http.NewRequest(method, "/test", nil)
		if err != nil {
			t.Fatalf("building %s request: %v", method, err)
		}
		if origin != "" {
			req.Header.Set("Origin", origin)
		}

		w := httptest.NewRecorder()
		router.ServeHTTP(w, req)
		return w
	}

	t.Run("allow valid origin", func(t *testing.T) {
		w := serve(t, http.MethodGet, "https://example.com")

		assert.Equal(t, http.StatusOK, w.Code)
		assert.Equal(t, "https://example.com", w.Header().Get("Access-Control-Allow-Origin"))
		assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"))
	})

	t.Run("forbid invalid origin", func(t *testing.T) {
		w := serve(t, http.MethodGet, "https://evil.com")

		assert.Equal(t, http.StatusForbidden, w.Code)
	})

	t.Run("handle OPTIONS request", func(t *testing.T) {
		w := serve(t, http.MethodOptions, "https://example.com")

		assert.Equal(t, http.StatusNoContent, w.Code)
		assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods"))
		assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
	})

	t.Run("no origin header", func(t *testing.T) {
		w := serve(t, http.MethodGet, "")

		assert.Equal(t, http.StatusOK, w.Code)
	})
}