152 lines
3.6 KiB
Go
152 lines
3.6 KiB
Go
package api
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/supabase/auth/internal/conf"
|
|
)
|
|
|
|
func TestIsValidCodeChallenge(t *testing.T) {
|
|
cases := []struct {
|
|
challenge string
|
|
isValid bool
|
|
expectedError error
|
|
}{
|
|
{
|
|
challenge: "invalid",
|
|
isValid: false,
|
|
expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength),
|
|
},
|
|
{
|
|
challenge: "codechallengecontainsinvalidcharacterslike@$^&*",
|
|
isValid: false,
|
|
expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"),
|
|
},
|
|
{
|
|
challenge: "validchallengevalidchallengevalidchallengevalidchallenge",
|
|
isValid: true,
|
|
expectedError: nil,
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
valid, err := isValidCodeChallenge(c.challenge)
|
|
require.Equal(t, c.isValid, valid)
|
|
require.Equal(t, c.expectedError, err)
|
|
}
|
|
}
|
|
|
|
func TestIsValidPKCEParams(t *testing.T) {
|
|
cases := []struct {
|
|
challengeMethod string
|
|
challenge string
|
|
expected error
|
|
}{
|
|
{
|
|
challengeMethod: "",
|
|
challenge: "",
|
|
expected: nil,
|
|
},
|
|
{
|
|
challengeMethod: "test",
|
|
challenge: "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttest",
|
|
expected: nil,
|
|
},
|
|
{
|
|
challengeMethod: "test",
|
|
challenge: "",
|
|
expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage),
|
|
},
|
|
{
|
|
challengeMethod: "",
|
|
challenge: "test",
|
|
expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage),
|
|
},
|
|
}
|
|
|
|
for i, c := range cases {
|
|
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
|
err := validatePKCEParams(c.challengeMethod, c.challenge)
|
|
require.Equal(t, c.expected, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRequestAud(ts *testing.T) {
|
|
mockAPI := API{
|
|
config: &conf.GlobalConfiguration{
|
|
JWT: conf.JWTConfiguration{
|
|
Aud: "authenticated",
|
|
Secret: "test-secret",
|
|
},
|
|
},
|
|
}
|
|
|
|
cases := []struct {
|
|
desc string
|
|
headers map[string]string
|
|
payload map[string]interface{}
|
|
expectedAud string
|
|
}{
|
|
{
|
|
desc: "Valid audience slice",
|
|
headers: map[string]string{
|
|
audHeaderName: "my_custom_aud",
|
|
},
|
|
payload: map[string]interface{}{
|
|
"aud": "authenticated",
|
|
},
|
|
expectedAud: "my_custom_aud",
|
|
},
|
|
{
|
|
desc: "Valid custom audience",
|
|
payload: map[string]interface{}{
|
|
"aud": "my_custom_aud",
|
|
},
|
|
expectedAud: "my_custom_aud",
|
|
},
|
|
{
|
|
desc: "Invalid audience",
|
|
payload: map[string]interface{}{
|
|
"aud": "",
|
|
},
|
|
expectedAud: mockAPI.config.JWT.Aud,
|
|
},
|
|
{
|
|
desc: "Missing audience",
|
|
payload: map[string]interface{}{
|
|
"sub": "d6044b6e-b0ec-4efe-a055-0d2d6ff1dbd8",
|
|
},
|
|
expectedAud: mockAPI.config.JWT.Aud,
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
ts.Run(c.desc, func(t *testing.T) {
|
|
claims := jwt.MapClaims(c.payload)
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
signed, err := token.SignedString([]byte(mockAPI.config.JWT.Secret))
|
|
require.NoError(t, err)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer: %s", signed))
|
|
for k, v := range c.headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
// set the token in the request context for requestAud
|
|
ctx, err := mockAPI.parseJWTClaims(signed, req)
|
|
require.NoError(t, err)
|
|
aud := mockAPI.requestAud(ctx, req)
|
|
require.Equal(t, c.expectedAud, aud)
|
|
})
|
|
}
|
|
|
|
}
|