105 lines
2.8 KiB
Go
105 lines
2.8 KiB
Go
package models
|
|
|
|
import (
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
"github.com/supabase/auth/internal/conf"
|
|
"github.com/supabase/auth/internal/storage"
|
|
"github.com/supabase/auth/internal/storage/test"
|
|
)
|
|
|
|
type SessionsTestSuite struct {
|
|
suite.Suite
|
|
db *storage.Connection
|
|
Config *conf.GlobalConfiguration
|
|
}
|
|
|
|
func (ts *SessionsTestSuite) SetupTest() {
|
|
TruncateAll(ts.db)
|
|
email := "test@example.com"
|
|
user, err := NewUser("", email, "secret", ts.Config.JWT.Aud, nil)
|
|
require.NoError(ts.T(), err)
|
|
|
|
err = ts.db.Create(user)
|
|
require.NoError(ts.T(), err)
|
|
}
|
|
|
|
func TestSession(t *testing.T) {
|
|
globalConfig, err := conf.LoadGlobal(modelsTestConfig)
|
|
require.NoError(t, err)
|
|
conn, err := test.SetupDBConnection(globalConfig)
|
|
require.NoError(t, err)
|
|
ts := &SessionsTestSuite{
|
|
db: conn,
|
|
Config: globalConfig,
|
|
}
|
|
defer ts.db.Close()
|
|
suite.Run(t, ts)
|
|
}
|
|
|
|
func (ts *SessionsTestSuite) TestFindBySessionIDWithForUpdate() {
|
|
u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud)
|
|
require.NoError(ts.T(), err)
|
|
session, err := NewSession(u.ID, nil)
|
|
require.NoError(ts.T(), err)
|
|
require.NoError(ts.T(), ts.db.Create(session))
|
|
|
|
found, err := FindSessionByID(ts.db, session.ID, true)
|
|
require.NoError(ts.T(), err)
|
|
|
|
require.Equal(ts.T(), session.ID, found.ID)
|
|
}
|
|
|
|
func (ts *SessionsTestSuite) AddClaimAndReloadSession(session *Session, claim AuthenticationMethod) *Session {
|
|
err := AddClaimToSession(ts.db, session.ID, claim)
|
|
require.NoError(ts.T(), err)
|
|
session, err = FindSessionByID(ts.db, session.ID, false)
|
|
require.NoError(ts.T(), err)
|
|
return session
|
|
}
|
|
|
|
func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
|
|
totalDistinctClaims := 3
|
|
u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud)
|
|
require.NoError(ts.T(), err)
|
|
session, err := NewSession(u.ID, nil)
|
|
require.NoError(ts.T(), err)
|
|
require.NoError(ts.T(), ts.db.Create(session))
|
|
|
|
session = ts.AddClaimAndReloadSession(session, PasswordGrant)
|
|
|
|
firstClaimAddedTime := time.Now()
|
|
session = ts.AddClaimAndReloadSession(session, TOTPSignIn)
|
|
|
|
_, _, err = session.CalculateAALAndAMR(u)
|
|
require.NoError(ts.T(), err)
|
|
|
|
session = ts.AddClaimAndReloadSession(session, TOTPSignIn)
|
|
|
|
session = ts.AddClaimAndReloadSession(session, SSOSAML)
|
|
|
|
aal, amr, err := session.CalculateAALAndAMR(u)
|
|
require.NoError(ts.T(), err)
|
|
|
|
require.Equal(ts.T(), AAL2, aal)
|
|
require.Equal(ts.T(), totalDistinctClaims, len(amr))
|
|
|
|
found := false
|
|
for _, claim := range session.AMRClaims {
|
|
if claim.GetAuthenticationMethod() == TOTPSignIn.String() {
|
|
require.True(ts.T(), firstClaimAddedTime.Before(claim.UpdatedAt))
|
|
found = true
|
|
}
|
|
}
|
|
|
|
for _, claim := range amr {
|
|
if claim.Method == SSOSAML.String() {
|
|
require.NotNil(ts.T(), claim.Provider)
|
|
}
|
|
}
|
|
require.True(ts.T(), found)
|
|
}
|