858 lines
28 KiB
Go
858 lines
28 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
"github.com/supabase/auth/internal/conf"
|
|
"github.com/supabase/auth/internal/models"
|
|
)
|
|
|
|
// TokenTestSuite exercises the /token endpoint (and related OTP/verify flows)
// against the API instance returned by setupAPIForTest.
type TokenTestSuite struct {
	suite.Suite
	// API is the handler under test; its db is truncated before every test.
	API *API
	// Config is the configuration returned alongside API by setupAPIForTest.
	Config *conf.GlobalConfiguration

	// RefreshToken is issued for User in SetupTest. Some tests replace it
	// (for example createBannedUser, or tests that grant a session with a
	// SessionNotAfter bound).
	RefreshToken *models.RefreshToken
	// User is the confirmed test@example.com user created in SetupTest.
	User *models.User
}
|
|
|
|
func TestToken(t *testing.T) {
|
|
os.Setenv("GOTRUE_RATE_LIMIT_HEADER", "My-Custom-Header")
|
|
api, config, err := setupAPIForTest()
|
|
require.NoError(t, err)
|
|
|
|
ts := &TokenTestSuite{
|
|
API: api,
|
|
Config: config,
|
|
}
|
|
defer api.db.Close()
|
|
|
|
suite.Run(t, ts)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) SetupTest() {
|
|
ts.RefreshToken = nil
|
|
models.TruncateAll(ts.API.db)
|
|
|
|
// Create user & refresh token
|
|
u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil)
|
|
require.NoError(ts.T(), err, "Error creating test user model")
|
|
t := time.Now()
|
|
u.EmailConfirmedAt = &t
|
|
u.BannedUntil = nil
|
|
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
|
|
|
|
ts.User = u
|
|
ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{})
|
|
require.NoError(ts.T(), err, "Error creating refresh token")
|
|
ts.Config.Hook.CustomAccessToken.Enabled = false
|
|
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestSessionTimebox() {
|
|
timebox := 10 * time.Second
|
|
|
|
ts.API.config.Sessions.Timebox = &timebox
|
|
ts.API.overrideTime = func() time.Time {
|
|
return time.Now().Add(timebox).Add(time.Second)
|
|
}
|
|
|
|
defer func() {
|
|
ts.API.overrideTime = nil
|
|
ts.API.config.Sessions.Timebox = nil
|
|
}()
|
|
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": ts.RefreshToken.Token,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
|
|
|
|
var firstResult struct {
|
|
ErrorCode string `json:"error_code"`
|
|
Message string `json:"msg"`
|
|
}
|
|
|
|
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
|
|
assert.Equal(ts.T(), ErrorCodeSessionExpired, firstResult.ErrorCode)
|
|
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired", firstResult.Message)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestSessionInactivityTimeout() {
|
|
inactivityTimeout := 10 * time.Second
|
|
|
|
ts.API.config.Sessions.InactivityTimeout = &inactivityTimeout
|
|
ts.API.overrideTime = func() time.Time {
|
|
return time.Now().Add(inactivityTimeout).Add(time.Second)
|
|
}
|
|
|
|
defer func() {
|
|
ts.API.config.Sessions.InactivityTimeout = nil
|
|
ts.API.overrideTime = nil
|
|
}()
|
|
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": ts.RefreshToken.Token,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
|
|
|
|
var firstResult struct {
|
|
ErrorCode string `json:"error_code"`
|
|
Message string `json:"msg"`
|
|
}
|
|
|
|
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
|
|
assert.Equal(ts.T(), ErrorCodeSessionExpired, firstResult.ErrorCode)
|
|
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Inactivity)", firstResult.Message)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestFailedToSaveRefreshTokenResultCase() {
|
|
var buffer bytes.Buffer
|
|
|
|
// first refresh
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": ts.RefreshToken.Token,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusOK, w.Code)
|
|
|
|
var firstResult struct {
|
|
RefreshToken string `json:"refresh_token"`
|
|
}
|
|
|
|
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
|
|
assert.NotEmpty(ts.T(), firstResult.RefreshToken)
|
|
|
|
// pretend that the browser wasn't able to save the firstResult,
|
|
// run again with the first refresh token
|
|
buffer = bytes.Buffer{}
|
|
|
|
// second refresh with the reused refresh token
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": ts.RefreshToken.Token,
|
|
}))
|
|
|
|
w = httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusOK, w.Code)
|
|
|
|
var secondResult struct {
|
|
RefreshToken string `json:"refresh_token"`
|
|
}
|
|
|
|
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&secondResult))
|
|
assert.NotEmpty(ts.T(), secondResult.RefreshToken)
|
|
|
|
// new refresh token is not being issued but the active one from
|
|
// the first refresh that failed to save is stored
|
|
assert.Equal(ts.T(), firstResult.RefreshToken, secondResult.RefreshToken)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestSingleSessionPerUserNoTags() {
|
|
ts.API.config.Sessions.SinglePerUser = true
|
|
defer func() {
|
|
ts.API.config.Sessions.SinglePerUser = false
|
|
}()
|
|
|
|
firstRefreshToken := ts.RefreshToken
|
|
|
|
// just in case to give some delay between first and second session creation
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
secondRefreshToken, err := models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{})
|
|
|
|
require.NoError(ts.T(), err)
|
|
|
|
require.NotEqual(ts.T(), *firstRefreshToken.SessionId, *secondRefreshToken.SessionId)
|
|
require.Equal(ts.T(), firstRefreshToken.UserID, secondRefreshToken.UserID)
|
|
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": firstRefreshToken.Token,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
|
|
assert.True(ts.T(), ts.API.config.Sessions.SinglePerUser)
|
|
|
|
var firstResult struct {
|
|
ErrorCode string `json:"error_code"`
|
|
Message string `json:"msg"`
|
|
}
|
|
|
|
assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
|
|
assert.Equal(ts.T(), ErrorCodeSessionExpired, firstResult.ErrorCode)
|
|
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Revoked by Newer Login)", firstResult.Message)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestRateLimitTokenRefresh() {
|
|
var buffer bytes.Buffer
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("My-Custom-Header", "1.2.3.4")
|
|
|
|
// It rate limits after 30 requests
|
|
for i := 0; i < 30; i++ {
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
|
|
}
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code)
|
|
|
|
// It ignores X-Forwarded-For by default
|
|
req.Header.Set("X-Forwarded-For", "1.1.1.1")
|
|
w = httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code)
|
|
|
|
// It doesn't rate limit a new value for the limited header
|
|
req = httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("My-Custom-Header", "5.6.7.8")
|
|
w = httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestTokenPasswordGrantSuccess() {
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"email": "test@example.com",
|
|
"password": "password",
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusOK, w.Code)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestTokenRefreshTokenGrantSuccess() {
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": ts.RefreshToken.Token,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusOK, w.Code)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestTokenPasswordGrantFailure() {
|
|
u := ts.createBannedUser()
|
|
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"email": u.GetEmail(),
|
|
"password": "password",
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() {
|
|
authCode := "1234563"
|
|
codeVerifier := "4a9505b9-0857-42bb-ab3c-098b4d28ddc2"
|
|
invalidAuthCode := authCode + "123"
|
|
invalidVerifier := codeVerifier + "123"
|
|
codeChallenge := sha256.Sum256([]byte(codeVerifier))
|
|
challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:])
|
|
flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth, nil)
|
|
flowState.AuthCode = authCode
|
|
require.NoError(ts.T(), ts.API.db.Create(flowState))
|
|
cases := []struct {
|
|
desc string
|
|
authCode string
|
|
codeVerifier string
|
|
grantType string
|
|
expectedHTTPCode int
|
|
}{
|
|
{
|
|
desc: "Invalid Authcode",
|
|
authCode: invalidAuthCode,
|
|
codeVerifier: codeVerifier,
|
|
},
|
|
{
|
|
desc: "Invalid code verifier",
|
|
authCode: authCode,
|
|
codeVerifier: invalidVerifier,
|
|
},
|
|
{
|
|
desc: "Invalid auth code and verifier",
|
|
authCode: invalidAuthCode,
|
|
codeVerifier: invalidVerifier,
|
|
},
|
|
}
|
|
for _, v := range cases {
|
|
ts.Run(v.desc, func() {
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"code_verifier": v.codeVerifier,
|
|
"auth_code": v.authCode,
|
|
}))
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusNotFound, w.Code)
|
|
})
|
|
}
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() {
|
|
_ = ts.createBannedUser()
|
|
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": ts.RefreshToken.Token,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
|
|
originalSecurity := ts.API.config.Security
|
|
|
|
ts.API.config.Security.RefreshTokenRotationEnabled = true
|
|
ts.API.config.Security.RefreshTokenReuseInterval = 0
|
|
|
|
defer func() {
|
|
ts.API.config.Security = originalSecurity
|
|
}()
|
|
|
|
refreshTokens := []string{
|
|
ts.RefreshToken.Token,
|
|
}
|
|
|
|
for i := 0; i < 3; i += 1 {
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": refreshTokens[len(refreshTokens)-1],
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(ts.T(), http.StatusOK, w.Code)
|
|
|
|
var response struct {
|
|
RefreshToken string `json:"refresh_token"`
|
|
}
|
|
|
|
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
|
|
|
|
refreshTokens = append(refreshTokens, response.RefreshToken)
|
|
}
|
|
|
|
// ensure that the 4 refresh tokens are setup correctly
|
|
for i, refreshToken := range refreshTokens {
|
|
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
|
|
require.NoError(ts.T(), err)
|
|
|
|
if i == len(refreshTokens)-1 {
|
|
require.False(ts.T(), token.Revoked)
|
|
} else {
|
|
require.True(ts.T(), token.Revoked)
|
|
}
|
|
}
|
|
|
|
// try to reuse the first (earliest) refresh token which should trigger the family revocation logic
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": refreshTokens[0],
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
|
|
|
|
var response struct {
|
|
ErrorCode string `json:"error_code"`
|
|
Message string `json:"msg"`
|
|
}
|
|
|
|
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
|
|
require.Equal(ts.T(), ErrorCodeRefreshTokenAlreadyUsed, response.ErrorCode)
|
|
require.Equal(ts.T(), "Invalid Refresh Token: Already Used", response.Message)
|
|
|
|
// ensure that the refresh tokens are marked as revoked in the database
|
|
for _, refreshToken := range refreshTokens {
|
|
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
|
|
require.NoError(ts.T(), err)
|
|
|
|
require.True(ts.T(), token.Revoked)
|
|
}
|
|
|
|
// finally ensure that none of the refresh tokens can be reused any
|
|
// more, starting with the previously valid one
|
|
for i := len(refreshTokens) - 1; i >= 0; i -= 1 {
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": refreshTokens[i],
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code, "For refresh token %d", i)
|
|
|
|
var response struct {
|
|
ErrorCode string `json:"error_code"`
|
|
Message string `json:"msg"`
|
|
}
|
|
|
|
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
|
|
require.Equal(ts.T(), ErrorCodeRefreshTokenAlreadyUsed, response.ErrorCode, "For refresh token %d", i)
|
|
require.Equal(ts.T(), "Invalid Refresh Token: Already Used", response.Message, "For refresh token %d", i)
|
|
}
|
|
}
|
|
|
|
func (ts *TokenTestSuite) createBannedUser() *models.User {
|
|
u, err := models.NewUser("", "banned@example.com", "password", ts.Config.JWT.Aud, nil)
|
|
require.NoError(ts.T(), err, "Error creating test user model")
|
|
t := time.Now()
|
|
u.EmailConfirmedAt = &t
|
|
t = t.Add(24 * time.Hour)
|
|
u.BannedUntil = &t
|
|
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test banned user")
|
|
|
|
ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{})
|
|
require.NoError(ts.T(), err, "Error creating refresh token")
|
|
|
|
return u
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestTokenRefreshWithExpiredSession() {
|
|
var err error
|
|
|
|
now := time.Now().UTC().Add(-1 * time.Second)
|
|
|
|
ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{
|
|
SessionNotAfter: &now,
|
|
})
|
|
require.NoError(ts.T(), err, "Error creating refresh token")
|
|
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": ts.RefreshToken.Token,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestTokenRefreshWithUnexpiredSession() {
|
|
var err error
|
|
|
|
now := time.Now().UTC().Add(1 * time.Second)
|
|
|
|
ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{
|
|
SessionNotAfter: &now,
|
|
})
|
|
require.NoError(ts.T(), err, "Error creating refresh token")
|
|
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": ts.RefreshToken.Token,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusOK, w.Code)
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestMagicLinkPKCESignIn() {
|
|
var buffer bytes.Buffer
|
|
// Send OTP
|
|
codeVerifier := "4a9505b9-0857-42bb-ab3c-098b4d28ddc2"
|
|
codeChallenge := sha256.Sum256([]byte(codeVerifier))
|
|
challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:])
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/otp", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(OtpParams{
|
|
Email: "test@example.com",
|
|
CreateUser: true,
|
|
CodeChallengeMethod: "s256",
|
|
CodeChallenge: challenge,
|
|
}))
|
|
req = httptest.NewRequest(http.MethodPost, "/otp", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
require.Equal(ts.T(), http.StatusOK, w.Code)
|
|
|
|
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
|
|
require.NoError(ts.T(), err)
|
|
|
|
// Verify OTP
|
|
requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", "magiclink", u.RecoveryToken)
|
|
req = httptest.NewRequest(http.MethodGet, requestUrl, &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w = httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
assert.Equal(ts.T(), http.StatusSeeOther, w.Code)
|
|
rURL, _ := w.Result().Location()
|
|
|
|
u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
|
|
require.NoError(ts.T(), err)
|
|
assert.True(ts.T(), u.IsConfirmed())
|
|
|
|
f, err := url.ParseQuery(rURL.RawQuery)
|
|
require.NoError(ts.T(), err)
|
|
authCode := f.Get("code")
|
|
assert.NotEmpty(ts.T(), authCode)
|
|
// Extract token and sign in
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"code_verifier": codeVerifier,
|
|
"auth_code": authCode,
|
|
}))
|
|
req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
w = httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
require.Equal(ts.T(), http.StatusOK, w.Code)
|
|
verifyResp := &AccessTokenResponse{}
|
|
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&verifyResp))
|
|
require.NotEmpty(ts.T(), verifyResp.Token)
|
|
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestPasswordVerificationHook() {
|
|
type verificationHookTestcase struct {
|
|
desc string
|
|
uri string
|
|
hookFunctionSQL string
|
|
expectedCode int
|
|
}
|
|
cases := []verificationHookTestcase{
|
|
{
|
|
desc: "Default success",
|
|
uri: "pg-functions://postgres/auth/password_verification_hook",
|
|
hookFunctionSQL: `
|
|
create or replace function password_verification_hook(input jsonb)
|
|
returns jsonb as $$
|
|
begin
|
|
return jsonb_build_object('decision', 'continue');
|
|
end; $$ language plpgsql;`,
|
|
expectedCode: http.StatusOK,
|
|
}, {
|
|
desc: "Reject- Enabled",
|
|
uri: "pg-functions://postgres/auth/password_verification_hook_reject",
|
|
hookFunctionSQL: `
|
|
create or replace function password_verification_hook_reject(input jsonb)
|
|
returns jsonb as $$
|
|
begin
|
|
return jsonb_build_object('decision', 'reject', 'message', 'You shall not pass!');
|
|
end; $$ language plpgsql;`,
|
|
expectedCode: http.StatusBadRequest,
|
|
},
|
|
}
|
|
for _, c := range cases {
|
|
ts.T().Run(c.desc, func(t *testing.T) {
|
|
ts.Config.Hook.PasswordVerificationAttempt.Enabled = true
|
|
ts.Config.Hook.PasswordVerificationAttempt.URI = c.uri
|
|
require.NoError(ts.T(), ts.Config.Hook.PasswordVerificationAttempt.PopulateExtensibilityPoint())
|
|
|
|
err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec()
|
|
require.NoError(t, err)
|
|
var buffer bytes.Buffer
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"email": "test@example.com",
|
|
"password": "password",
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(ts.T(), c.expectedCode, w.Code)
|
|
cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.PasswordVerificationAttempt.HookName)
|
|
require.NoError(ts.T(), ts.API.db.RawQuery(cleanupHookSQL).Exec())
|
|
// Reset so it doesn't affect other tests
|
|
ts.Config.Hook.PasswordVerificationAttempt.Enabled = false
|
|
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestCustomAccessToken() {
|
|
type customAccessTokenTestcase struct {
|
|
desc string
|
|
uri string
|
|
hookFunctionSQL string
|
|
expectedClaims map[string]interface{}
|
|
shouldError bool
|
|
}
|
|
cases := []customAccessTokenTestcase{
|
|
{
|
|
desc: "Add a new claim",
|
|
uri: "pg-functions://postgres/auth/custom_access_token_add_claim",
|
|
hookFunctionSQL: ` create or replace function custom_access_token_add_claim(input jsonb) returns jsonb as $$ declare result jsonb; begin if jsonb_typeof(jsonb_object_field(input, 'claims')) is null then result := jsonb_build_object('error', jsonb_build_object('http_code', 400, 'message', 'Input does not contain claims field')); return result; end if;
|
|
input := jsonb_set(input, '{claims,newclaim}', '"newvalue"', true);
|
|
result := jsonb_build_object('claims', input->'claims');
|
|
return result;
|
|
end; $$ language plpgsql;`,
|
|
expectedClaims: map[string]interface{}{
|
|
"newclaim": "newvalue",
|
|
},
|
|
}, {
|
|
desc: "Delete the Role claim",
|
|
uri: "pg-functions://postgres/auth/custom_access_token_delete_claim",
|
|
hookFunctionSQL: `
|
|
create or replace function custom_access_token_delete_claim(input jsonb)
|
|
returns jsonb as $$
|
|
declare
|
|
result jsonb;
|
|
begin
|
|
input := jsonb_set(input, '{claims}', (input->'claims') - 'role');
|
|
result := jsonb_build_object('claims', input->'claims');
|
|
return result;
|
|
end; $$ language plpgsql;`,
|
|
expectedClaims: map[string]interface{}{},
|
|
shouldError: true,
|
|
}, {
|
|
desc: "Delete a non-required claim (UserMetadata)",
|
|
uri: "pg-functions://postgres/auth/custom_access_token_delete_usermetadata",
|
|
hookFunctionSQL: `
|
|
create or replace function custom_access_token_delete_usermetadata(input jsonb)
|
|
returns jsonb as $$
|
|
declare
|
|
result jsonb;
|
|
begin
|
|
input := jsonb_set(input, '{claims}', (input->'claims') - 'user_metadata');
|
|
result := jsonb_build_object('claims', input->'claims');
|
|
return result;
|
|
end; $$ language plpgsql;`,
|
|
// Not used
|
|
expectedClaims: map[string]interface{}{
|
|
"user_metadata": nil,
|
|
},
|
|
shouldError: false,
|
|
},
|
|
}
|
|
for _, c := range cases {
|
|
ts.T().Run(c.desc, func(t *testing.T) {
|
|
ts.Config.Hook.CustomAccessToken.Enabled = true
|
|
ts.Config.Hook.CustomAccessToken.URI = c.uri
|
|
require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint())
|
|
|
|
err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec()
|
|
require.NoError(t, err)
|
|
|
|
var buffer bytes.Buffer
|
|
require.NoError(t, json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"refresh_token": ts.RefreshToken.Token,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
|
|
var tokenResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
}
|
|
require.NoError(t, json.NewDecoder(w.Result().Body).Decode(&tokenResponse))
|
|
if c.shouldError {
|
|
require.Equal(t, http.StatusInternalServerError, w.Code)
|
|
} else {
|
|
parts := strings.Split(tokenResponse.AccessToken, ".")
|
|
require.Equal(t, 3, len(parts), "Token should have 3 parts")
|
|
|
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
require.NoError(t, err)
|
|
|
|
var responseClaims map[string]interface{}
|
|
require.NoError(t, json.Unmarshal(payload, &responseClaims))
|
|
|
|
for key, expectedValue := range c.expectedClaims {
|
|
if expectedValue == nil {
|
|
// Since c.shouldError is false here, we only need to check if the claim should be removed
|
|
_, exists := responseClaims[key]
|
|
assert.False(t, exists, "Claim should be removed")
|
|
} else {
|
|
assert.Equal(t, expectedValue, responseClaims[key])
|
|
}
|
|
}
|
|
}
|
|
|
|
cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName)
|
|
require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec())
|
|
ts.Config.Hook.CustomAccessToken.Enabled = false
|
|
})
|
|
}
|
|
}
|
|
|
|
func (ts *TokenTestSuite) TestAllowSelectAuthenticationMethods() {
|
|
|
|
companyUser, err := models.NewUser("12345678", "test@company.com", "password", ts.Config.JWT.Aud, nil)
|
|
t := time.Now()
|
|
companyUser.EmailConfirmedAt = &t
|
|
require.NoError(ts.T(), err, "Error creating test user model")
|
|
require.NoError(ts.T(), ts.API.db.Create(companyUser), "Error saving new test user")
|
|
|
|
type allowSelectAuthMethodsTestcase struct {
|
|
desc string
|
|
uri string
|
|
email string
|
|
expectedError string
|
|
expectedStatus int
|
|
}
|
|
|
|
// Common hook function SQL definition
|
|
hookFunctionSQL := `
|
|
create or replace function auth.custom_access_token(event jsonb) returns jsonb language plpgsql as $$
|
|
declare
|
|
email_claim text;
|
|
authentication_method text;
|
|
begin
|
|
email_claim := event->'claims'->>'email';
|
|
authentication_method := event->>'authentication_method';
|
|
|
|
if authentication_method = 'password' and email_claim not like '%@company.com' then
|
|
return jsonb_build_object(
|
|
'error', jsonb_build_object(
|
|
'http_code', 403,
|
|
'message', 'only members on company.com can access with password authentication'
|
|
)
|
|
);
|
|
end if;
|
|
|
|
return event;
|
|
end;
|
|
$$;`
|
|
|
|
cases := []allowSelectAuthMethodsTestcase{
|
|
{
|
|
desc: "Error for non-protected domain with password authentication",
|
|
uri: "pg-functions://postgres/auth/custom_access_token",
|
|
email: "test@example.com",
|
|
expectedError: "only members on company.com can access with password authentication",
|
|
expectedStatus: http.StatusForbidden,
|
|
},
|
|
{
|
|
desc: "Allow access for protected domain with password authentication",
|
|
uri: "pg-functions://postgres/auth/custom_access_token",
|
|
email: companyUser.Email.String(),
|
|
expectedError: "",
|
|
expectedStatus: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
ts.T().Run(c.desc, func(t *testing.T) {
|
|
// Enable and set up the custom access token hook
|
|
ts.Config.Hook.CustomAccessToken.Enabled = true
|
|
ts.Config.Hook.CustomAccessToken.URI = c.uri
|
|
require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint())
|
|
|
|
// Execute the common hook function SQL
|
|
err := ts.API.db.RawQuery(hookFunctionSQL).Exec()
|
|
require.NoError(t, err)
|
|
|
|
var buffer bytes.Buffer
|
|
|
|
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
|
|
"email": c.email,
|
|
"password": "password",
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
ts.API.handler.ServeHTTP(w, req)
|
|
|
|
require.Equal(t, c.expectedStatus, w.Code, "Unexpected HTTP status code")
|
|
if c.expectedError != "" {
|
|
require.Contains(t, w.Body.String(), c.expectedError, "Expected error message not found")
|
|
} else {
|
|
require.NotContains(t, w.Body.String(), "error", "Unexpected error occurred")
|
|
}
|
|
|
|
// Delete the function and cleanup
|
|
cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName)
|
|
require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec())
|
|
ts.Config.Hook.CustomAccessToken.Enabled = false
|
|
})
|
|
}
|
|
}
|