Files
user-system/internal/auth/providers/provider_crypto_test.go

170 lines
4.8 KiB
Go
Raw Normal View History

package providers
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"net/url"
"strings"
"testing"
)
// generateRSAKeyForTest returns a freshly generated 1024-bit RSA private key
// read from crypto/rand, failing the calling test immediately if generation
// returns an error. The modulus size is small (presumably to keep test setup
// fast) and is not intended for production keys.
func generateRSAKeyForTest(t *testing.T) *rsa.PrivateKey {
	t.Helper()

	const bits = 1024
	priv, genErr := rsa.GenerateKey(rand.Reader, bits)
	if genErr != nil {
		t.Fatalf("generate rsa key failed: %v", genErr)
	}
	return priv
}
// marshalPKCS8PEMForTest serializes key as unencrypted PKCS#8 DER, wraps it in
// a PEM block of type "PRIVATE KEY" and returns the resulting PEM text. The
// calling test is failed immediately if PKCS#8 marshalling returns an error.
func marshalPKCS8PEMForTest(t *testing.T, key *rsa.PrivateKey) string {
	t.Helper()

	der, marshalErr := x509.MarshalPKCS8PrivateKey(key)
	if marshalErr != nil {
		t.Fatalf("marshal PKCS#8 failed: %v", marshalErr)
	}

	block := &pem.Block{Type: "PRIVATE KEY", Bytes: der}
	encoded := pem.EncodeToMemory(block)
	return string(encoded)
}
// TestParseAlipayPrivateKeySupportsRawPKCS8AndPKCS1 checks that
// parseAlipayPrivateKey accepts both key encodings exercised below:
//   - bare standard-base64 PKCS#8 DER, with no PEM header or footer;
//   - a PEM "RSA PRIVATE KEY" block carrying PKCS#1 DER.
//
// For each encoding the parsed modulus N and private exponent D must match
// the original key.
func TestParseAlipayPrivateKeySupportsRawPKCS8AndPKCS1(t *testing.T) {
	original := generateRSAKeyForTest(t)

	// Encoding 1: PKCS#8 DER, base64-encoded without PEM armor.
	der8, marshalErr := x509.MarshalPKCS8PrivateKey(original)
	if marshalErr != nil {
		t.Fatalf("marshal PKCS#8 failed: %v", marshalErr)
	}
	bare8 := base64.StdEncoding.EncodeToString(der8)

	fromPKCS8, parseErr := parseAlipayPrivateKey(bare8)
	if parseErr != nil {
		t.Fatalf("parse raw PKCS#8 key failed: %v", parseErr)
	}
	if !(fromPKCS8.N.Cmp(original.N) == 0 && fromPKCS8.D.Cmp(original.D) == 0) {
		t.Fatal("parsed raw PKCS#8 key does not match original key")
	}

	// Encoding 2: PKCS#1 DER inside a traditional "RSA PRIVATE KEY" PEM block.
	block1 := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(original),
	}
	armored1 := string(pem.EncodeToMemory(block1))

	fromPKCS1, parseErr := parseAlipayPrivateKey(armored1)
	if parseErr != nil {
		t.Fatalf("parse PKCS#1 key failed: %v", parseErr)
	}
	if !(fromPKCS1.N.Cmp(original.N) == 0 && fromPKCS1.D.Cmp(original.D) == 0) {
		t.Fatal("parsed PKCS#1 key does not match original key")
	}
}
// TestParseAlipayPrivateKeyRejectsInvalidPEM ensures parseAlipayPrivateKey
// returns an error for a string that is neither PEM nor base64-encoded key
// material.
func TestParseAlipayPrivateKeyRejectsInvalidPEM(t *testing.T) {
	const garbage = "not-a-valid-private-key"

	_, err := parseAlipayPrivateKey(garbage)
	if err == nil {
		t.Fatal("expected invalid private key parsing to fail")
	}
}
// TestAlipayProviderSignParamsProducesVerifiableSignature checks that
// signParams returns a standard-base64 RSA PKCS#1 v1.5 signature over the
// SHA-256 digest of the canonical sign string. The sign string is the params
// sorted by key and joined as "k=v" pairs with "&". Any pre-existing "sign"
// entry is excluded from the signed content.
func TestAlipayProviderSignParamsProducesVerifiableSignature(t *testing.T) {
	signer := generateRSAKeyForTest(t)
	pemKey := marshalPKCS8PEMForTest(t, signer)
	provider := NewAlipayProvider("app-id", pemKey, "https://admin.example.com/login/oauth/callback", false)

	encoded, err := provider.signParams(map[string]string{
		"sign":   "should-be-ignored",
		"code":   "auth-code",
		"app_id": "app-id",
		"method": "alipay.system.oauth.token",
	})
	switch {
	case err != nil:
		t.Fatalf("signParams failed: %v", err)
	case encoded == "":
		t.Fatal("expected non-empty signature")
	}

	raw, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		t.Fatalf("decode signature failed: %v", err)
	}

	// Keys in ascending order; the "sign" entry must not be part of the input.
	const expectedContent = "app_id=app-id&code=auth-code&method=alipay.system.oauth.token"
	digest := sha256.Sum256([]byte(expectedContent))
	if verifyErr := rsa.VerifyPKCS1v15(&signer.PublicKey, crypto.SHA256, digest[:], raw); verifyErr != nil {
		t.Fatalf("signature verification failed: %v", verifyErr)
	}
}
// TestTwitterProviderPKCEHelpersAndAuthURL exercises the Twitter provider's
// PKCE helpers and the authorization URL it builds. It checks that:
//   - GenerateCodeVerifier yields distinct, non-empty values drawn only from
//     the RFC 7636 unreserved alphabet;
//   - GenerateCodeChallenge currently returns the verifier unchanged;
//   - GetAuthURL returns a verifier, a state and the provider's redirect URI;
//   - the URL's query string echoes those values alongside client_id and
//     code_challenge_method=plain.
func TestTwitterProviderPKCEHelpersAndAuthURL(t *testing.T) {
	provider := NewTwitterProvider("twitter-client", "https://admin.example.com/login/oauth/callback")
	verifierA, err := provider.GenerateCodeVerifier()
	if err != nil {
		t.Fatalf("GenerateCodeVerifier(first) failed: %v", err)
	}
	verifierB, err := provider.GenerateCodeVerifier()
	if err != nil {
		t.Fatalf("GenerateCodeVerifier(second) failed: %v", err)
	}
	if verifierA == "" || verifierB == "" {
		t.Fatal("expected non-empty code verifiers")
	}
	if verifierA == verifierB {
		t.Fatal("expected code verifiers to differ across calls")
	}

	// RFC 7636 §4.1 limits code verifiers to the unreserved characters
	// [A-Z] [a-z] [0-9] "-" "." "_" "~". Unpadded base64url is a subset of that
	// set, so this rejects "=" (padding) and also "+" and "/", which would
	// indicate standard rather than URL-safe base64.
	// NOTE(review): §4.1 also requires a length of 43-128 characters; consider
	// asserting that once the generator's entropy size is confirmed.
	notUnreserved := func(r rune) bool {
		switch {
		case 'A' <= r && r <= 'Z', 'a' <= r && r <= 'z', '0' <= r && r <= '9':
			return false
		default:
			return !strings.ContainsRune("-._~", r)
		}
	}
	for _, v := range []string{verifierA, verifierB} {
		if i := strings.IndexFunc(v, notUnreserved); i >= 0 {
			t.Fatalf("expected code verifier %q to be a base64url value without padding (RFC 7636 unreserved characters only); first disallowed character at byte %d", v, i)
		}
	}

	// The provider currently uses the "plain" PKCE method, so the challenge is
	// the verifier itself. RFC 7636 §4.2 recommends S256 for any client that
	// can compute SHA-256. Update these expectations if the provider switches.
	if provider.GenerateCodeChallenge(verifierA) != verifierA {
		t.Fatal("expected current code challenge implementation to mirror the verifier")
	}

	authURL, err := provider.GetAuthURL()
	if err != nil {
		t.Fatalf("GetAuthURL failed: %v", err)
	}
	if authURL.CodeVerifier == "" || authURL.State == "" {
		t.Fatal("expected auth url response to include verifier and state")
	}
	if authURL.Redirect != provider.RedirectURI {
		t.Fatalf("expected redirect %q, got %q", provider.RedirectURI, authURL.Redirect)
	}

	parsed, err := url.Parse(authURL.URL)
	if err != nil {
		t.Fatalf("parse auth url failed: %v", err)
	}
	query := parsed.Query()
	if query.Get("client_id") != "twitter-client" {
		t.Fatalf("expected twitter client_id, got %q", query.Get("client_id"))
	}
	if query.Get("redirect_uri") != provider.RedirectURI {
		t.Fatalf("expected redirect_uri %q, got %q", provider.RedirectURI, query.Get("redirect_uri"))
	}
	if query.Get("code_challenge") != authURL.CodeVerifier {
		t.Fatalf("expected code challenge to equal verifier, got %q", query.Get("code_challenge"))
	}
	if query.Get("code_challenge_method") != "plain" {
		t.Fatalf("expected code_challenge_method plain, got %q", query.Get("code_challenge_method"))
	}
	if query.Get("state") != authURL.State {
		t.Fatalf("expected state %q, got %q", authURL.State, query.Get("state"))
	}
}