chatdesk-ui/auth_v2.169.0/internal/api/token_test.go

858 lines
28 KiB
Go

package api
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
)
type TokenTestSuite struct {
suite.Suite
API *API
Config *conf.GlobalConfiguration
RefreshToken *models.RefreshToken
User *models.User
}
func TestToken(t *testing.T) {
os.Setenv("GOTRUE_RATE_LIMIT_HEADER", "My-Custom-Header")
api, config, err := setupAPIForTest()
require.NoError(t, err)
ts := &TokenTestSuite{
API: api,
Config: config,
}
defer api.db.Close()
suite.Run(t, ts)
}
func (ts *TokenTestSuite) SetupTest() {
ts.RefreshToken = nil
models.TruncateAll(ts.API.db)
// Create user & refresh token
u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
t := time.Now()
u.EmailConfirmedAt = &t
u.BannedUntil = nil
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
ts.User = u
ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{})
require.NoError(ts.T(), err, "Error creating refresh token")
ts.Config.Hook.CustomAccessToken.Enabled = false
}
func (ts *TokenTestSuite) TestSessionTimebox() {
timebox := 10 * time.Second
ts.API.config.Sessions.Timebox = &timebox
ts.API.overrideTime = func() time.Time {
return time.Now().Add(timebox).Add(time.Second)
}
defer func() {
ts.API.overrideTime = nil
ts.API.config.Sessions.Timebox = nil
}()
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
var firstResult struct {
ErrorCode string `json:"error_code"`
Message string `json:"msg"`
}
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
assert.Equal(ts.T(), ErrorCodeSessionExpired, firstResult.ErrorCode)
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired", firstResult.Message)
}
func (ts *TokenTestSuite) TestSessionInactivityTimeout() {
inactivityTimeout := 10 * time.Second
ts.API.config.Sessions.InactivityTimeout = &inactivityTimeout
ts.API.overrideTime = func() time.Time {
return time.Now().Add(inactivityTimeout).Add(time.Second)
}
defer func() {
ts.API.config.Sessions.InactivityTimeout = nil
ts.API.overrideTime = nil
}()
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
var firstResult struct {
ErrorCode string `json:"error_code"`
Message string `json:"msg"`
}
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
assert.Equal(ts.T(), ErrorCodeSessionExpired, firstResult.ErrorCode)
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Inactivity)", firstResult.Message)
}
func (ts *TokenTestSuite) TestFailedToSaveRefreshTokenResultCase() {
var buffer bytes.Buffer
// first refresh
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusOK, w.Code)
var firstResult struct {
RefreshToken string `json:"refresh_token"`
}
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
assert.NotEmpty(ts.T(), firstResult.RefreshToken)
// pretend that the browser wasn't able to save the firstResult,
// run again with the first refresh token
buffer = bytes.Buffer{}
// second refresh with the reused refresh token
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))
w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusOK, w.Code)
var secondResult struct {
RefreshToken string `json:"refresh_token"`
}
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&secondResult))
assert.NotEmpty(ts.T(), secondResult.RefreshToken)
// new refresh token is not being issued but the active one from
// the first refresh that failed to save is stored
assert.Equal(ts.T(), firstResult.RefreshToken, secondResult.RefreshToken)
}
func (ts *TokenTestSuite) TestSingleSessionPerUserNoTags() {
ts.API.config.Sessions.SinglePerUser = true
defer func() {
ts.API.config.Sessions.SinglePerUser = false
}()
firstRefreshToken := ts.RefreshToken
// just in case to give some delay between first and second session creation
time.Sleep(10 * time.Millisecond)
secondRefreshToken, err := models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{})
require.NoError(ts.T(), err)
require.NotEqual(ts.T(), *firstRefreshToken.SessionId, *secondRefreshToken.SessionId)
require.Equal(ts.T(), firstRefreshToken.UserID, secondRefreshToken.UserID)
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": firstRefreshToken.Token,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
assert.True(ts.T(), ts.API.config.Sessions.SinglePerUser)
var firstResult struct {
ErrorCode string `json:"error_code"`
Message string `json:"msg"`
}
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
assert.Equal(ts.T(), ErrorCodeSessionExpired, firstResult.ErrorCode)
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Revoked by Newer Login)", firstResult.Message)
}
func (ts *TokenTestSuite) TestRateLimitTokenRefresh() {
var buffer bytes.Buffer
req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("My-Custom-Header", "1.2.3.4")
// It rate limits after 30 requests
for i := 0; i < 30; i++ {
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code)
// It ignores X-Forwarded-For by default
req.Header.Set("X-Forwarded-For", "1.1.1.1")
w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code)
// It doesn't rate limit a new value for the limited header
req = httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("My-Custom-Header", "5.6.7.8")
w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}
func (ts *TokenTestSuite) TestTokenPasswordGrantSuccess() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": "test@example.com",
"password": "password",
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusOK, w.Code)
}
func (ts *TokenTestSuite) TestTokenRefreshTokenGrantSuccess() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusOK, w.Code)
}
func (ts *TokenTestSuite) TestTokenPasswordGrantFailure() {
u := ts.createBannedUser()
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": u.GetEmail(),
"password": "password",
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}
func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() {
authCode := "1234563"
codeVerifier := "4a9505b9-0857-42bb-ab3c-098b4d28ddc2"
invalidAuthCode := authCode + "123"
invalidVerifier := codeVerifier + "123"
codeChallenge := sha256.Sum256([]byte(codeVerifier))
challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:])
flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth, nil)
flowState.AuthCode = authCode
require.NoError(ts.T(), ts.API.db.Create(flowState))
cases := []struct {
desc string
authCode string
codeVerifier string
grantType string
expectedHTTPCode int
}{
{
desc: "Invalid Authcode",
authCode: invalidAuthCode,
codeVerifier: codeVerifier,
},
{
desc: "Invalid code verifier",
authCode: authCode,
codeVerifier: invalidVerifier,
},
{
desc: "Invalid auth code and verifier",
authCode: invalidAuthCode,
codeVerifier: invalidVerifier,
},
}
for _, v := range cases {
ts.Run(v.desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"code_verifier": v.codeVerifier,
"auth_code": v.authCode,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusNotFound, w.Code)
})
}
}
func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() {
_ = ts.createBannedUser()
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}
func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
originalSecurity := ts.API.config.Security
ts.API.config.Security.RefreshTokenRotationEnabled = true
ts.API.config.Security.RefreshTokenReuseInterval = 0
defer func() {
ts.API.config.Security = originalSecurity
}()
refreshTokens := []string{
ts.RefreshToken.Token,
}
for i := 0; i < 3; i += 1 {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": refreshTokens[len(refreshTokens)-1],
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusOK, w.Code)
var response struct {
RefreshToken string `json:"refresh_token"`
}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
refreshTokens = append(refreshTokens, response.RefreshToken)
}
// ensure that the 4 refresh tokens are setup correctly
for i, refreshToken := range refreshTokens {
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
require.NoError(ts.T(), err)
if i == len(refreshTokens)-1 {
require.False(ts.T(), token.Revoked)
} else {
require.True(ts.T(), token.Revoked)
}
}
// try to reuse the first (earliest) refresh token which should trigger the family revocation logic
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": refreshTokens[0],
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
var response struct {
ErrorCode string `json:"error_code"`
Message string `json:"msg"`
}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
require.Equal(ts.T(), ErrorCodeRefreshTokenAlreadyUsed, response.ErrorCode)
require.Equal(ts.T(), "Invalid Refresh Token: Already Used", response.Message)
// ensure that the refresh tokens are marked as revoked in the database
for _, refreshToken := range refreshTokens {
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
require.NoError(ts.T(), err)
require.True(ts.T(), token.Revoked)
}
// finally ensure that none of the refresh tokens can be reused any
// more, starting with the previously valid one
for i := len(refreshTokens) - 1; i >= 0; i -= 1 {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": refreshTokens[i],
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code, "For refresh token %d", i)
var response struct {
ErrorCode string `json:"error_code"`
Message string `json:"msg"`
}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
require.Equal(ts.T(), ErrorCodeRefreshTokenAlreadyUsed, response.ErrorCode, "For refresh token %d", i)
require.Equal(ts.T(), "Invalid Refresh Token: Already Used", response.Message, "For refresh token %d", i)
}
}
func (ts *TokenTestSuite) createBannedUser() *models.User {
u, err := models.NewUser("", "banned@example.com", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
t := time.Now()
u.EmailConfirmedAt = &t
t = t.Add(24 * time.Hour)
u.BannedUntil = &t
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test banned user")
ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{})
require.NoError(ts.T(), err, "Error creating refresh token")
return u
}
func (ts *TokenTestSuite) TestTokenRefreshWithExpiredSession() {
var err error
now := time.Now().UTC().Add(-1 * time.Second)
ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{
SessionNotAfter: &now,
})
require.NoError(ts.T(), err, "Error creating refresh token")
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}
func (ts *TokenTestSuite) TestTokenRefreshWithUnexpiredSession() {
var err error
now := time.Now().UTC().Add(1 * time.Second)
ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{
SessionNotAfter: &now,
})
require.NoError(ts.T(), err, "Error creating refresh token")
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusOK, w.Code)
}
func (ts *TokenTestSuite) TestMagicLinkPKCESignIn() {
var buffer bytes.Buffer
// Send OTP
codeVerifier := "4a9505b9-0857-42bb-ab3c-098b4d28ddc2"
codeChallenge := sha256.Sum256([]byte(codeVerifier))
challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:])
req := httptest.NewRequest(http.MethodPost, "/otp", &buffer)
req.Header.Set("Content-Type", "application/json")
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(OtpParams{
Email: "test@example.com",
CreateUser: true,
CodeChallengeMethod: "s256",
CodeChallenge: challenge,
}))
req = httptest.NewRequest(http.MethodPost, "/otp", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
// Verify OTP
requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", "magiclink", u.RecoveryToken)
req = httptest.NewRequest(http.MethodGet, requestUrl, &buffer)
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusSeeOther, w.Code)
rURL, _ := w.Result().Location()
u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
assert.True(ts.T(), u.IsConfirmed())
f, err := url.ParseQuery(rURL.RawQuery)
require.NoError(ts.T(), err)
authCode := f.Get("code")
assert.NotEmpty(ts.T(), authCode)
// Extract token and sign in
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"code_verifier": codeVerifier,
"auth_code": authCode,
}))
req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer)
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
verifyResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&verifyResp))
require.NotEmpty(ts.T(), verifyResp.Token)
}
func (ts *TokenTestSuite) TestPasswordVerificationHook() {
type verificationHookTestcase struct {
desc string
uri string
hookFunctionSQL string
expectedCode int
}
cases := []verificationHookTestcase{
{
desc: "Default success",
uri: "pg-functions://postgres/auth/password_verification_hook",
hookFunctionSQL: `
create or replace function password_verification_hook(input jsonb)
returns jsonb as $$
begin
return jsonb_build_object('decision', 'continue');
end; $$ language plpgsql;`,
expectedCode: http.StatusOK,
}, {
desc: "Reject- Enabled",
uri: "pg-functions://postgres/auth/password_verification_hook_reject",
hookFunctionSQL: `
create or replace function password_verification_hook_reject(input jsonb)
returns jsonb as $$
begin
return jsonb_build_object('decision', 'reject', 'message', 'You shall not pass!');
end; $$ language plpgsql;`,
expectedCode: http.StatusBadRequest,
},
}
for _, c := range cases {
ts.T().Run(c.desc, func(t *testing.T) {
ts.Config.Hook.PasswordVerificationAttempt.Enabled = true
ts.Config.Hook.PasswordVerificationAttempt.URI = c.uri
require.NoError(ts.T(), ts.Config.Hook.PasswordVerificationAttempt.PopulateExtensibilityPoint())
err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec()
require.NoError(t, err)
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": "test@example.com",
"password": "password",
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), c.expectedCode, w.Code)
cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.PasswordVerificationAttempt.HookName)
require.NoError(ts.T(), ts.API.db.RawQuery(cleanupHookSQL).Exec())
// Reset so it doesn't affect other tests
ts.Config.Hook.PasswordVerificationAttempt.Enabled = false
})
}
}
func (ts *TokenTestSuite) TestCustomAccessToken() {
type customAccessTokenTestcase struct {
desc string
uri string
hookFunctionSQL string
expectedClaims map[string]interface{}
shouldError bool
}
cases := []customAccessTokenTestcase{
{
desc: "Add a new claim",
uri: "pg-functions://postgres/auth/custom_access_token_add_claim",
hookFunctionSQL: ` create or replace function custom_access_token_add_claim(input jsonb) returns jsonb as $$ declare result jsonb; begin if jsonb_typeof(jsonb_object_field(input, 'claims')) is null then result := jsonb_build_object('error', jsonb_build_object('http_code', 400, 'message', 'Input does not contain claims field')); return result; end if;
input := jsonb_set(input, '{claims,newclaim}', '"newvalue"', true);
result := jsonb_build_object('claims', input->'claims');
return result;
end; $$ language plpgsql;`,
expectedClaims: map[string]interface{}{
"newclaim": "newvalue",
},
}, {
desc: "Delete the Role claim",
uri: "pg-functions://postgres/auth/custom_access_token_delete_claim",
hookFunctionSQL: `
create or replace function custom_access_token_delete_claim(input jsonb)
returns jsonb as $$
declare
result jsonb;
begin
input := jsonb_set(input, '{claims}', (input->'claims') - 'role');
result := jsonb_build_object('claims', input->'claims');
return result;
end; $$ language plpgsql;`,
expectedClaims: map[string]interface{}{},
shouldError: true,
}, {
desc: "Delete a non-required claim (UserMetadata)",
uri: "pg-functions://postgres/auth/custom_access_token_delete_usermetadata",
hookFunctionSQL: `
create or replace function custom_access_token_delete_usermetadata(input jsonb)
returns jsonb as $$
declare
result jsonb;
begin
input := jsonb_set(input, '{claims}', (input->'claims') - 'user_metadata');
result := jsonb_build_object('claims', input->'claims');
return result;
end; $$ language plpgsql;`,
// Not used
expectedClaims: map[string]interface{}{
"user_metadata": nil,
},
shouldError: false,
},
}
for _, c := range cases {
ts.T().Run(c.desc, func(t *testing.T) {
ts.Config.Hook.CustomAccessToken.Enabled = true
ts.Config.Hook.CustomAccessToken.URI = c.uri
require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint())
err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec()
require.NoError(t, err)
var buffer bytes.Buffer
require.NoError(t, json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
var tokenResponse struct {
AccessToken string `json:"access_token"`
}
require.NoError(t, json.NewDecoder(w.Result().Body).Decode(&tokenResponse))
if c.shouldError {
require.Equal(t, http.StatusInternalServerError, w.Code)
} else {
parts := strings.Split(tokenResponse.AccessToken, ".")
require.Equal(t, 3, len(parts), "Token should have 3 parts")
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
require.NoError(t, err)
var responseClaims map[string]interface{}
require.NoError(t, json.Unmarshal(payload, &responseClaims))
for key, expectedValue := range c.expectedClaims {
if expectedValue == nil {
// Since c.shouldError is false here, we only need to check if the claim should be removed
_, exists := responseClaims[key]
assert.False(t, exists, "Claim should be removed")
} else {
assert.Equal(t, expectedValue, responseClaims[key])
}
}
}
cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName)
require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec())
ts.Config.Hook.CustomAccessToken.Enabled = false
})
}
}
func (ts *TokenTestSuite) TestAllowSelectAuthenticationMethods() {
companyUser, err := models.NewUser("12345678", "test@company.com", "password", ts.Config.JWT.Aud, nil)
t := time.Now()
companyUser.EmailConfirmedAt = &t
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(companyUser), "Error saving new test user")
type allowSelectAuthMethodsTestcase struct {
desc string
uri string
email string
expectedError string
expectedStatus int
}
// Common hook function SQL definition
hookFunctionSQL := `
create or replace function auth.custom_access_token(event jsonb) returns jsonb language plpgsql as $$
declare
email_claim text;
authentication_method text;
begin
email_claim := event->'claims'->>'email';
authentication_method := event->>'authentication_method';
if authentication_method = 'password' and email_claim not like '%@company.com' then
return jsonb_build_object(
'error', jsonb_build_object(
'http_code', 403,
'message', 'only members on company.com can access with password authentication'
)
);
end if;
return event;
end;
$$;`
cases := []allowSelectAuthMethodsTestcase{
{
desc: "Error for non-protected domain with password authentication",
uri: "pg-functions://postgres/auth/custom_access_token",
email: "test@example.com",
expectedError: "only members on company.com can access with password authentication",
expectedStatus: http.StatusForbidden,
},
{
desc: "Allow access for protected domain with password authentication",
uri: "pg-functions://postgres/auth/custom_access_token",
email: companyUser.Email.String(),
expectedError: "",
expectedStatus: http.StatusOK,
},
}
for _, c := range cases {
ts.T().Run(c.desc, func(t *testing.T) {
// Enable and set up the custom access token hook
ts.Config.Hook.CustomAccessToken.Enabled = true
ts.Config.Hook.CustomAccessToken.URI = c.uri
require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint())
// Execute the common hook function SQL
err := ts.API.db.RawQuery(hookFunctionSQL).Exec()
require.NoError(t, err)
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": c.email,
"password": "password",
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(t, c.expectedStatus, w.Code, "Unexpected HTTP status code")
if c.expectedError != "" {
require.Contains(t, w.Body.String(), c.expectedError, "Expected error message not found")
} else {
require.NotContains(t, w.Body.String(), "error", "Unexpected error occurred")
}
// Delete the function and cleanup
cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName)
require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec())
ts.Config.Hook.CustomAccessToken.Enabled = false
})
}
}