package models import ( "database/sql" "fmt" "sort" "strings" "time" "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/supabase/auth/internal/storage" ) type AuthenticatorAssuranceLevel int const ( AAL1 AuthenticatorAssuranceLevel = iota AAL2 AAL3 ) func (aal AuthenticatorAssuranceLevel) String() string { switch aal { case AAL1: return "aal1" case AAL2: return "aal2" case AAL3: return "aal3" default: return "" } } // AMREntry represents a method that a user has logged in together with the corresponding time type AMREntry struct { Method string `json:"method"` Timestamp int64 `json:"timestamp"` Provider string `json:"provider,omitempty"` } type sortAMREntries struct { Array []AMREntry } func (s sortAMREntries) Len() int { return len(s.Array) } func (s sortAMREntries) Less(i, j int) bool { return s.Array[i].Timestamp < s.Array[j].Timestamp } func (s sortAMREntries) Swap(i, j int) { s.Array[j], s.Array[i] = s.Array[i], s.Array[j] } type Session struct { ID uuid.UUID `json:"-" db:"id"` UserID uuid.UUID `json:"user_id" db:"user_id"` // NotAfter is overriden by timeboxed sessions. NotAfter *time.Time `json:"not_after,omitempty" db:"not_after"` CreatedAt time.Time `json:"created_at" db:"created_at"` UpdatedAt time.Time `json:"updated_at" db:"updated_at"` FactorID *uuid.UUID `json:"factor_id" db:"factor_id"` AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"` AAL *string `json:"aal" db:"aal"` RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"` UserAgent *string `json:"user_agent,omitempty" db:"user_agent"` IP *string `json:"ip,omitempty" db:"ip"` Tag *string `json:"tag" db:"tag"` } func (Session) TableName() string { tableName := "sessions" return tableName } func (s *Session) LastRefreshedAt(refreshTokenTime *time.Time) time.Time { refreshedAt := s.RefreshedAt if refreshedAt == nil || refreshedAt.IsZero() { if refreshTokenTime != nil { rtt := *refreshTokenTime if rtt.IsZero() { return s.CreatedAt } else if rtt.After(s.CreatedAt) { return rtt } } return s.CreatedAt } return *refreshedAt } func (s *Session) UpdateOnlyRefreshInfo(tx *storage.Connection) error { // TODO(kangmingtay): The underlying database type uses timestamp without timezone, // so we need to convert the value to UTC before updating it. // In the future, we should add a migration to update the type to contain the timezone. *s.RefreshedAt = s.RefreshedAt.UTC() return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip") } type SessionValidityReason = int const ( SessionValid SessionValidityReason = iota SessionPastNotAfter = iota SessionPastTimebox = iota SessionTimedOut = iota ) func (s *Session) CheckValidity(now time.Time, refreshTokenTime *time.Time, timebox, inactivityTimeout *time.Duration) SessionValidityReason { if s.NotAfter != nil && now.After(*s.NotAfter) { return SessionPastNotAfter } if timebox != nil && *timebox != 0 && now.After(s.CreatedAt.Add(*timebox)) { return SessionPastTimebox } if inactivityTimeout != nil && *inactivityTimeout != 0 && now.After(s.LastRefreshedAt(refreshTokenTime).Add(*inactivityTimeout)) { return SessionTimedOut } return SessionValid } func (s *Session) DetermineTag(tags []string) string { if len(tags) == 0 { return "" } if s.Tag == nil { return tags[0] } tag := *s.Tag if tag == "" { return tags[0] } for _, t := range tags { if t == tag { return tag } } return tags[0] } func NewSession(userID uuid.UUID, factorID *uuid.UUID) (*Session, error) { id := uuid.Must(uuid.NewV4()) defaultAAL := AAL1.String() session := &Session{ ID: id, AAL: &defaultAAL, UserID: userID, FactorID: factorID, } return session, nil } // FindSessionByID looks up a Session by the provided id. If forUpdate is set // to true, then the SELECT statement used by the query has the form SELECT ... // FOR UPDATE SKIP LOCKED. This means that a FOR UPDATE lock will only be // acquired if there's no other lock. In case there is a lock, a // IsNotFound(err) error will be retured. func FindSessionByID(tx *storage.Connection, id uuid.UUID, forUpdate bool) (*Session, error) { session := &Session{} if forUpdate { // pop does not provide us with a way to execute FOR UPDATE // queries which lock the rows affected by the query from // being accessed by any other transaction that also uses FOR // UPDATE if err := tx.RawQuery(fmt.Sprintf("SELECT * FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", session.TableName()), id).First(session); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, SessionNotFoundError{} } return nil, err } } // once the rows are locked (if forUpdate was true), we can query again using pop if err := tx.Eager().Q().Where("id = ?", id).First(session); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, SessionNotFoundError{} } return nil, errors.Wrap(err, "error finding session") } return session, nil } func FindSessionByUserID(tx *storage.Connection, userId uuid.UUID) (*Session, error) { session := &Session{} if err := tx.Eager().Q().Where("user_id = ?", userId).Order("created_at asc").First(session); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, SessionNotFoundError{} } return nil, errors.Wrap(err, "error finding session") } return session, nil } func FindSessionsByFactorID(tx *storage.Connection, factorID uuid.UUID) ([]*Session, error) { sessions := []*Session{} if err := tx.Q().Where("factor_id = ?", factorID).All(&sessions); err != nil { return nil, err } return sessions, nil } // FindAllSessionsForUser finds all of the sessions for a user. If forUpdate is // set, it will first lock on the user row which can be used to prevent issues // with concurrency. If the lock is acquired, it will return a // UserNotFoundError and the operation should be retried. If there are no // sessions for the user, a nil result is returned without an error. func FindAllSessionsForUser(tx *storage.Connection, userId uuid.UUID, forUpdate bool) ([]*Session, error) { if forUpdate { user := &User{} if err := tx.RawQuery(fmt.Sprintf("SELECT id FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", user.TableName()), userId).First(user); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, UserNotFoundError{} } return nil, err } } var sessions []*Session if err := tx.Where("user_id = ?", userId).All(&sessions); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, nil } return nil, err } return sessions, nil } func updateFactorAssociatedSessions(tx *storage.Connection, userID, factorID uuid.UUID, aal string) error { return tx.RawQuery("UPDATE "+(&pop.Model{Value: Session{}}).TableName()+" set aal = ?, factor_id = ? WHERE user_id = ? AND factor_id = ?", aal, nil, userID, factorID).Exec() } func InvalidateSessionsWithAALLessThan(tx *storage.Connection, userID uuid.UUID, level string) error { return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE user_id = ? AND aal < ?", userID, level).Exec() } // Logout deletes all sessions for a user. func Logout(tx *storage.Connection, userId uuid.UUID) error { return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE user_id = ?", userId).Exec() } // LogoutSession deletes the current session for a user func LogoutSession(tx *storage.Connection, sessionId uuid.UUID) error { return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id = ?", sessionId).Exec() } // LogoutAllExceptMe deletes all sessions for a user except the current one func LogoutAllExceptMe(tx *storage.Connection, sessionId uuid.UUID, userID uuid.UUID) error { return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id != ? AND user_id = ?", sessionId, userID).Exec() } func (s *Session) UpdateAALAndAssociatedFactor(tx *storage.Connection, aal AuthenticatorAssuranceLevel, factorID *uuid.UUID) error { s.FactorID = factorID aalAsString := aal.String() s.AAL = &aalAsString return tx.UpdateOnly(s, "aal", "factor_id") } func (s *Session) CalculateAALAndAMR(user *User) (aal AuthenticatorAssuranceLevel, amr []AMREntry, err error) { amr, aal = []AMREntry{}, AAL1 for _, claim := range s.AMRClaims { if claim.IsAAL2Claim() { aal = AAL2 } amr = append(amr, AMREntry{Method: claim.GetAuthenticationMethod(), Timestamp: claim.UpdatedAt.Unix()}) } // makes sure that the AMR claims are always ordered most-recent first // sort in ascending order sort.Sort(sortAMREntries{ Array: amr, }) // now reverse for descending order _ = sort.Reverse(sortAMREntries{ Array: amr, }) lastIndex := len(amr) - 1 if lastIndex > -1 && amr[lastIndex].Method == SSOSAML.String() { // initial AMR claim is from sso/saml, we need to add information // about the provider that was used for the authentication identities := user.Identities if len(identities) == 1 { identity := identities[0] if identity.IsForSSOProvider() { amr[lastIndex].Provider = strings.TrimPrefix(identity.Provider, "sso:") } } // otherwise we can't identify that this user account has only // one SSO identity, so we are not encoding the provider at // this time } return aal, amr, nil } func (s *Session) GetAAL() string { if s.AAL == nil { return "" } return *(s.AAL) } func (s *Session) IsAAL2() bool { return s.GetAAL() == AAL2.String() } // FindCurrentlyActiveRefreshToken returns the currently active refresh // token in the session. This is the last created (ordered by the serial // primary key) non-revoked refresh token for the session. func (s *Session) FindCurrentlyActiveRefreshToken(tx *storage.Connection) (*RefreshToken, error) { var activeRefreshToken RefreshToken if err := tx.Q().Where("session_id = ? and revoked is false", s.ID).Order("id desc").First(&activeRefreshToken); err != nil { if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) { return nil, RefreshTokenNotFoundError{} } return nil, err } return &activeRefreshToken, nil }