287 lines
7.0 KiB
Go
287 lines
7.0 KiB
Go
package models
|
|
|
|
import (
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofrs/uuid"
|
|
"github.com/pkg/errors"
|
|
"github.com/supabase/auth/internal/storage"
|
|
)
|
|
|
|
// OneTimeTokenType enumerates the kinds of one-time token handled by this
// package. The numeric values are assigned by iota in declaration order, so
// new kinds must be appended at the end to keep existing values stable.
type OneTimeTokenType int

// Known one-time token kinds. The names suggest the flow each kind belongs
// to; the flows themselves are implemented outside this declaration.
const (
	// ConfirmationToken is the zero value of OneTimeTokenType.
	ConfirmationToken OneTimeTokenType = iota
	// ReauthenticationToken has the numeric value 1.
	ReauthenticationToken
	// RecoveryToken has the numeric value 2.
	RecoveryToken
	// EmailChangeTokenNew has the numeric value 3.
	EmailChangeTokenNew
	// EmailChangeTokenCurrent has the numeric value 4.
	EmailChangeTokenCurrent
	// PhoneChangeToken has the numeric value 5.
	PhoneChangeToken
)
|
|
|
|
func (t OneTimeTokenType) String() string {
|
|
switch t {
|
|
case ConfirmationToken:
|
|
return "confirmation_token"
|
|
|
|
case ReauthenticationToken:
|
|
return "reauthentication_token"
|
|
|
|
case RecoveryToken:
|
|
return "recovery_token"
|
|
|
|
case EmailChangeTokenNew:
|
|
return "email_change_token_new"
|
|
|
|
case EmailChangeTokenCurrent:
|
|
return "email_change_token_current"
|
|
|
|
case PhoneChangeToken:
|
|
return "phone_change_token"
|
|
|
|
default:
|
|
panic("OneTimeToken: unreachable case")
|
|
}
|
|
}
|
|
|
|
func ParseOneTimeTokenType(s string) (OneTimeTokenType, error) {
|
|
switch s {
|
|
case "confirmation_token":
|
|
return ConfirmationToken, nil
|
|
|
|
case "reauthentication_token":
|
|
return ReauthenticationToken, nil
|
|
|
|
case "recovery_token":
|
|
return RecoveryToken, nil
|
|
|
|
case "email_change_token_new":
|
|
return EmailChangeTokenNew, nil
|
|
|
|
case "email_change_token_current":
|
|
return EmailChangeTokenCurrent, nil
|
|
|
|
case "phone_change_token":
|
|
return PhoneChangeToken, nil
|
|
|
|
default:
|
|
return 0, fmt.Errorf("OneTimeTokenType: unrecognized string %q", s)
|
|
}
|
|
}
|
|
|
|
func (t OneTimeTokenType) Value() (driver.Value, error) {
|
|
return t.String(), nil
|
|
}
|
|
|
|
func (t *OneTimeTokenType) Scan(src interface{}) error {
|
|
s, ok := src.(string)
|
|
if !ok {
|
|
return fmt.Errorf("OneTimeTokenType: scan type is not string but is %T", src)
|
|
}
|
|
|
|
parsed, err := ParseOneTimeTokenType(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
*t = parsed
|
|
return nil
|
|
}
|
|
|
|
// OneTimeTokenNotFoundError is the error returned by FindOneTimeToken when no
// row matches the requested hash and token type(s). It carries no data.
type OneTimeTokenNotFoundError struct{}

// Error implements the error interface with a fixed message.
func (OneTimeTokenNotFoundError) Error() string {
	const message = "One-time token not found"
	return message
}
|
|
|
|
type OneTimeToken struct {
|
|
ID uuid.UUID `json:"id" db:"id"`
|
|
|
|
UserID uuid.UUID `json:"user_id" db:"user_id"`
|
|
TokenType OneTimeTokenType `json:"token_type" db:"token_type"`
|
|
|
|
TokenHash string `json:"token_hash" db:"token_hash"`
|
|
RelatesTo string `json:"relates_to" db:"relates_to"`
|
|
|
|
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
|
}
|
|
|
|
func (OneTimeToken) TableName() string {
|
|
return "one_time_tokens"
|
|
}
|
|
|
|
func ClearAllOneTimeTokensForUser(tx *storage.Connection, userID uuid.UUID) error {
|
|
return tx.Q().Where("user_id = ?", userID).Delete(OneTimeToken{})
|
|
}
|
|
|
|
func ClearOneTimeTokenForUser(tx *storage.Connection, userID uuid.UUID, tokenType OneTimeTokenType) error {
|
|
if err := tx.Q().Where("token_type = ? and user_id = ?", tokenType, userID).Delete(OneTimeToken{}); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func CreateOneTimeToken(tx *storage.Connection, userID uuid.UUID, relatesTo, tokenHash string, tokenType OneTimeTokenType) error {
|
|
if err := ClearOneTimeTokenForUser(tx, userID, tokenType); err != nil {
|
|
return err
|
|
}
|
|
|
|
oneTimeToken := &OneTimeToken{
|
|
ID: uuid.Must(uuid.NewV4()),
|
|
UserID: userID,
|
|
TokenType: tokenType,
|
|
TokenHash: tokenHash,
|
|
RelatesTo: strings.ToLower(relatesTo),
|
|
}
|
|
|
|
if err := tx.Eager().Create(oneTimeToken); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func FindOneTimeToken(tx *storage.Connection, tokenHash string, tokenTypes ...OneTimeTokenType) (*OneTimeToken, error) {
|
|
oneTimeToken := &OneTimeToken{}
|
|
|
|
query := tx.Eager().Q()
|
|
|
|
switch len(tokenTypes) {
|
|
case 2:
|
|
query = query.Where("(token_type = ? or token_type = ?) and token_hash = ?", tokenTypes[0], tokenTypes[1], tokenHash)
|
|
|
|
case 1:
|
|
query = query.Where("token_type = ? and token_hash = ?", tokenTypes[0], tokenHash)
|
|
|
|
default:
|
|
panic("at most 2 token types are accepted")
|
|
}
|
|
|
|
if err := query.First(oneTimeToken); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows {
|
|
return nil, OneTimeTokenNotFoundError{}
|
|
}
|
|
|
|
return nil, errors.Wrap(err, "error finding one time token")
|
|
}
|
|
|
|
return oneTimeToken, nil
|
|
}
|
|
|
|
// FindUserByConfirmationToken finds users with the matching confirmation token.
|
|
func FindUserByConfirmationOrRecoveryToken(tx *storage.Connection, token string) (*User, error) {
|
|
ott, err := FindOneTimeToken(tx, token, ConfirmationToken, RecoveryToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return FindUserByID(tx, ott.UserID)
|
|
}
|
|
|
|
// FindUserByConfirmationToken finds users with the matching confirmation token.
|
|
func FindUserByConfirmationToken(tx *storage.Connection, token string) (*User, error) {
|
|
ott, err := FindOneTimeToken(tx, token, ConfirmationToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return FindUserByID(tx, ott.UserID)
|
|
}
|
|
|
|
// FindUserByRecoveryToken finds a user with the matching recovery token.
|
|
func FindUserByRecoveryToken(tx *storage.Connection, token string) (*User, error) {
|
|
ott, err := FindOneTimeToken(tx, token, RecoveryToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return FindUserByID(tx, ott.UserID)
|
|
}
|
|
|
|
// FindUserByEmailChangeToken finds a user with the matching email change token.
|
|
func FindUserByEmailChangeToken(tx *storage.Connection, token string) (*User, error) {
|
|
ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent, EmailChangeTokenNew)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return FindUserByID(tx, ott.UserID)
|
|
}
|
|
|
|
// FindUserByEmailChangeCurrentAndAudience finds a user with the matching email change and audience.
|
|
func FindUserByEmailChangeCurrentAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) {
|
|
ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if ott == nil {
|
|
ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenCurrent)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if ott == nil {
|
|
return nil, err
|
|
}
|
|
|
|
user, err := FindUserByID(tx, ott.UserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if user.Aud != aud && strings.EqualFold(user.GetEmail(), email) {
|
|
return nil, UserNotFoundError{}
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// FindUserByEmailChangeNewAndAudience finds a user with the matching email change and audience.
|
|
func FindUserByEmailChangeNewAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) {
|
|
ott, err := FindOneTimeToken(tx, token, EmailChangeTokenNew)
|
|
if err != nil && !IsNotFoundError(err) {
|
|
return nil, err
|
|
}
|
|
|
|
if ott == nil {
|
|
ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenNew)
|
|
if err != nil && !IsNotFoundError(err) {
|
|
return nil, err
|
|
}
|
|
}
|
|
if ott == nil {
|
|
return nil, err
|
|
}
|
|
|
|
user, err := FindUserByID(tx, ott.UserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if user.Aud != aud && strings.EqualFold(user.EmailChange, email) {
|
|
return nil, UserNotFoundError{}
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// FindUserForEmailChange finds a user requesting for an email change
|
|
func FindUserForEmailChange(tx *storage.Connection, email, token, aud string, secureEmailChangeEnabled bool) (*User, error) {
|
|
if secureEmailChangeEnabled {
|
|
if user, err := FindUserByEmailChangeCurrentAndAudience(tx, email, token, aud); err == nil {
|
|
return user, err
|
|
} else if !IsNotFoundError(err) {
|
|
return nil, err
|
|
}
|
|
}
|
|
return FindUserByEmailChangeNewAndAudience(tx, email, token, aud)
|
|
}
|