357 lines
10 KiB
Go
357 lines
10 KiB
Go
package models
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gobuffalo/pop/v6"
|
|
"github.com/gofrs/uuid"
|
|
"github.com/pkg/errors"
|
|
"github.com/supabase/auth/internal/storage"
|
|
)
|
|
|
|
// AuthenticatorAssuranceLevel is an integer enumeration of the assurance
// levels a session can carry.
type AuthenticatorAssuranceLevel int

// The known assurance levels, numbered consecutively starting at zero.
const (
	AAL1 AuthenticatorAssuranceLevel = iota
	AAL2
	AAL3
)

// String returns the lowercase name of the level ("aal1", "aal2" or "aal3").
// Any value that is not one of the declared constants yields the empty string.
func (aal AuthenticatorAssuranceLevel) String() string {
	names := [...]string{AAL1: "aal1", AAL2: "aal2", AAL3: "aal3"}

	// Reject anything outside the declared range before indexing.
	if aal < AAL1 || aal > AAL3 {
		return ""
	}

	return names[aal]
}
|
|
|
|
// AMREntry describes one authentication method a user has logged in with,
// together with the time associated with it and, optionally, a provider.
type AMREntry struct {
	Method    string `json:"method"`
	Timestamp int64  `json:"timestamp"`
	Provider  string `json:"provider,omitempty"`
}

// sortAMREntries wraps a slice of AMREntry values and exposes Len, Less and
// Swap so the slice can be ordered by ascending Timestamp.
type sortAMREntries struct {
	Array []AMREntry
}

// Len reports the number of wrapped entries.
func (e sortAMREntries) Len() int {
	count := len(e.Array)
	return count
}

// Less reports whether the entry at index i has a strictly smaller Timestamp
// than the entry at index j.
func (e sortAMREntries) Less(i, j int) bool {
	left, right := e.Array[i], e.Array[j]
	return left.Timestamp < right.Timestamp
}

// Swap exchanges the entries at indexes i and j in the underlying slice.
func (e sortAMREntries) Swap(i, j int) {
	held := e.Array[i]
	e.Array[i] = e.Array[j]
	e.Array[j] = held
}
|
|
|
|
// Session is the database model for a row in the table named by TableName.
// Field tags carry the JSON names and the database column names.
type Session struct {
	// ID is the primary identifier of the session; it is excluded from JSON.
	ID uuid.UUID `json:"-" db:"id"`
	// UserID identifies the user the session belongs to.
	UserID uuid.UUID `json:"user_id" db:"user_id"`

	// NotAfter is overridden by timeboxed sessions.
	NotAfter *time.Time `json:"not_after,omitempty" db:"not_after"`

	CreatedAt time.Time `json:"created_at" db:"created_at"`
	UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
	// FactorID optionally references a factor associated with the session.
	FactorID *uuid.UUID `json:"factor_id" db:"factor_id"`
	// AMRClaims is loaded through the has_many association "amr_claims".
	AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"`
	// AAL holds the string form of the assurance level (see GetAAL).
	AAL *string `json:"aal" db:"aal"`

	// RefreshedAt is the preferred source for LastRefreshedAt when it is set
	// and non-zero.
	RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"`
	// UserAgent and IP are persisted together with RefreshedAt by
	// UpdateOnlyRefreshInfo.
	UserAgent *string `json:"user_agent,omitempty" db:"user_agent"`
	IP        *string `json:"ip,omitempty" db:"ip"`

	// Tag is the value DetermineTag checks against the allowed tags.
	Tag *string `json:"tag" db:"tag"`
}
|
|
|
|
func (Session) TableName() string {
|
|
tableName := "sessions"
|
|
return tableName
|
|
}
|
|
|
|
func (s *Session) LastRefreshedAt(refreshTokenTime *time.Time) time.Time {
|
|
refreshedAt := s.RefreshedAt
|
|
|
|
if refreshedAt == nil || refreshedAt.IsZero() {
|
|
if refreshTokenTime != nil {
|
|
rtt := *refreshTokenTime
|
|
|
|
if rtt.IsZero() {
|
|
return s.CreatedAt
|
|
} else if rtt.After(s.CreatedAt) {
|
|
return rtt
|
|
}
|
|
}
|
|
|
|
return s.CreatedAt
|
|
}
|
|
|
|
return *refreshedAt
|
|
}
|
|
|
|
func (s *Session) UpdateOnlyRefreshInfo(tx *storage.Connection) error {
|
|
// TODO(kangmingtay): The underlying database type uses timestamp without timezone,
|
|
// so we need to convert the value to UTC before updating it.
|
|
// In the future, we should add a migration to update the type to contain the timezone.
|
|
*s.RefreshedAt = s.RefreshedAt.UTC()
|
|
return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip")
|
|
}
|
|
|
|
// SessionValidityReason is the result code returned by Session.CheckValidity.
// It is declared as an alias of int, not as a distinct type.
type SessionValidityReason = int

// Result codes for Session.CheckValidity, numbered 0 through 3 by iota.
const (
	// SessionValid means none of the checks in CheckValidity failed.
	SessionValid SessionValidityReason = iota
	// SessionPastNotAfter means the current time is after the session's NotAfter.
	SessionPastNotAfter = iota
	// SessionPastTimebox means the current time is after CreatedAt plus the timebox.
	SessionPastTimebox = iota
	// SessionTimedOut means the inactivity timeout elapsed since the last refresh.
	SessionTimedOut = iota
)
|
|
|
|
func (s *Session) CheckValidity(now time.Time, refreshTokenTime *time.Time, timebox, inactivityTimeout *time.Duration) SessionValidityReason {
|
|
if s.NotAfter != nil && now.After(*s.NotAfter) {
|
|
return SessionPastNotAfter
|
|
}
|
|
|
|
if timebox != nil && *timebox != 0 && now.After(s.CreatedAt.Add(*timebox)) {
|
|
return SessionPastTimebox
|
|
}
|
|
|
|
if inactivityTimeout != nil && *inactivityTimeout != 0 && now.After(s.LastRefreshedAt(refreshTokenTime).Add(*inactivityTimeout)) {
|
|
return SessionTimedOut
|
|
}
|
|
|
|
return SessionValid
|
|
}
|
|
|
|
func (s *Session) DetermineTag(tags []string) string {
|
|
if len(tags) == 0 {
|
|
return ""
|
|
}
|
|
|
|
if s.Tag == nil {
|
|
return tags[0]
|
|
}
|
|
|
|
tag := *s.Tag
|
|
if tag == "" {
|
|
return tags[0]
|
|
}
|
|
|
|
for _, t := range tags {
|
|
if t == tag {
|
|
return tag
|
|
}
|
|
}
|
|
|
|
return tags[0]
|
|
}
|
|
|
|
func NewSession(userID uuid.UUID, factorID *uuid.UUID) (*Session, error) {
|
|
id := uuid.Must(uuid.NewV4())
|
|
|
|
defaultAAL := AAL1.String()
|
|
|
|
session := &Session{
|
|
ID: id,
|
|
AAL: &defaultAAL,
|
|
UserID: userID,
|
|
FactorID: factorID,
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// FindSessionByID looks up a Session by the provided id. If forUpdate is set
|
|
// to true, then the SELECT statement used by the query has the form SELECT ...
|
|
// FOR UPDATE SKIP LOCKED. This means that a FOR UPDATE lock will only be
|
|
// acquired if there's no other lock. In case there is a lock, a
|
|
// IsNotFound(err) error will be retured.
|
|
func FindSessionByID(tx *storage.Connection, id uuid.UUID, forUpdate bool) (*Session, error) {
|
|
session := &Session{}
|
|
|
|
if forUpdate {
|
|
// pop does not provide us with a way to execute FOR UPDATE
|
|
// queries which lock the rows affected by the query from
|
|
// being accessed by any other transaction that also uses FOR
|
|
// UPDATE
|
|
if err := tx.RawQuery(fmt.Sprintf("SELECT * FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", session.TableName()), id).First(session); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows {
|
|
return nil, SessionNotFoundError{}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// once the rows are locked (if forUpdate was true), we can query again using pop
|
|
if err := tx.Eager().Q().Where("id = ?", id).First(session); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows {
|
|
return nil, SessionNotFoundError{}
|
|
}
|
|
return nil, errors.Wrap(err, "error finding session")
|
|
}
|
|
return session, nil
|
|
}
|
|
|
|
func FindSessionByUserID(tx *storage.Connection, userId uuid.UUID) (*Session, error) {
|
|
session := &Session{}
|
|
if err := tx.Eager().Q().Where("user_id = ?", userId).Order("created_at asc").First(session); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows {
|
|
return nil, SessionNotFoundError{}
|
|
}
|
|
return nil, errors.Wrap(err, "error finding session")
|
|
}
|
|
return session, nil
|
|
}
|
|
|
|
func FindSessionsByFactorID(tx *storage.Connection, factorID uuid.UUID) ([]*Session, error) {
|
|
sessions := []*Session{}
|
|
if err := tx.Q().Where("factor_id = ?", factorID).All(&sessions); err != nil {
|
|
return nil, err
|
|
}
|
|
return sessions, nil
|
|
}
|
|
|
|
// FindAllSessionsForUser finds all of the sessions for a user. If forUpdate is
|
|
// set, it will first lock on the user row which can be used to prevent issues
|
|
// with concurrency. If the lock is acquired, it will return a
|
|
// UserNotFoundError and the operation should be retried. If there are no
|
|
// sessions for the user, a nil result is returned without an error.
|
|
func FindAllSessionsForUser(tx *storage.Connection, userId uuid.UUID, forUpdate bool) ([]*Session, error) {
|
|
if forUpdate {
|
|
user := &User{}
|
|
if err := tx.RawQuery(fmt.Sprintf("SELECT id FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", user.TableName()), userId).First(user); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows {
|
|
return nil, UserNotFoundError{}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
var sessions []*Session
|
|
if err := tx.Where("user_id = ?", userId).All(&sessions); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
func updateFactorAssociatedSessions(tx *storage.Connection, userID, factorID uuid.UUID, aal string) error {
|
|
return tx.RawQuery("UPDATE "+(&pop.Model{Value: Session{}}).TableName()+" set aal = ?, factor_id = ? WHERE user_id = ? AND factor_id = ?", aal, nil, userID, factorID).Exec()
|
|
}
|
|
|
|
func InvalidateSessionsWithAALLessThan(tx *storage.Connection, userID uuid.UUID, level string) error {
|
|
return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE user_id = ? AND aal < ?", userID, level).Exec()
|
|
}
|
|
|
|
// Logout deletes all sessions for a user.
|
|
func Logout(tx *storage.Connection, userId uuid.UUID) error {
|
|
return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE user_id = ?", userId).Exec()
|
|
}
|
|
|
|
// LogoutSession deletes the current session for a user
|
|
func LogoutSession(tx *storage.Connection, sessionId uuid.UUID) error {
|
|
return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id = ?", sessionId).Exec()
|
|
}
|
|
|
|
// LogoutAllExceptMe deletes all sessions for a user except the current one
|
|
func LogoutAllExceptMe(tx *storage.Connection, sessionId uuid.UUID, userID uuid.UUID) error {
|
|
return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id != ? AND user_id = ?", sessionId, userID).Exec()
|
|
}
|
|
|
|
func (s *Session) UpdateAALAndAssociatedFactor(tx *storage.Connection, aal AuthenticatorAssuranceLevel, factorID *uuid.UUID) error {
|
|
s.FactorID = factorID
|
|
aalAsString := aal.String()
|
|
s.AAL = &aalAsString
|
|
return tx.UpdateOnly(s, "aal", "factor_id")
|
|
}
|
|
|
|
func (s *Session) CalculateAALAndAMR(user *User) (aal AuthenticatorAssuranceLevel, amr []AMREntry, err error) {
|
|
amr, aal = []AMREntry{}, AAL1
|
|
for _, claim := range s.AMRClaims {
|
|
if claim.IsAAL2Claim() {
|
|
aal = AAL2
|
|
}
|
|
amr = append(amr, AMREntry{Method: claim.GetAuthenticationMethod(), Timestamp: claim.UpdatedAt.Unix()})
|
|
}
|
|
|
|
// makes sure that the AMR claims are always ordered most-recent first
|
|
|
|
// sort in ascending order
|
|
sort.Sort(sortAMREntries{
|
|
Array: amr,
|
|
})
|
|
|
|
// now reverse for descending order
|
|
_ = sort.Reverse(sortAMREntries{
|
|
Array: amr,
|
|
})
|
|
|
|
lastIndex := len(amr) - 1
|
|
|
|
if lastIndex > -1 && amr[lastIndex].Method == SSOSAML.String() {
|
|
// initial AMR claim is from sso/saml, we need to add information
|
|
// about the provider that was used for the authentication
|
|
identities := user.Identities
|
|
|
|
if len(identities) == 1 {
|
|
identity := identities[0]
|
|
|
|
if identity.IsForSSOProvider() {
|
|
amr[lastIndex].Provider = strings.TrimPrefix(identity.Provider, "sso:")
|
|
}
|
|
}
|
|
|
|
// otherwise we can't identify that this user account has only
|
|
// one SSO identity, so we are not encoding the provider at
|
|
// this time
|
|
}
|
|
|
|
return aal, amr, nil
|
|
}
|
|
|
|
func (s *Session) GetAAL() string {
|
|
if s.AAL == nil {
|
|
return ""
|
|
}
|
|
return *(s.AAL)
|
|
}
|
|
|
|
func (s *Session) IsAAL2() bool {
|
|
return s.GetAAL() == AAL2.String()
|
|
}
|
|
|
|
// FindCurrentlyActiveRefreshToken returns the currently active refresh
|
|
// token in the session. This is the last created (ordered by the serial
|
|
// primary key) non-revoked refresh token for the session.
|
|
func (s *Session) FindCurrentlyActiveRefreshToken(tx *storage.Connection) (*RefreshToken, error) {
|
|
var activeRefreshToken RefreshToken
|
|
|
|
if err := tx.Q().Where("session_id = ? and revoked is false", s.ID).Order("id desc").First(&activeRefreshToken); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) {
|
|
return nil, RefreshTokenNotFoundError{}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return &activeRefreshToken, nil
|
|
}
|