package models import ( "database/sql" "database/sql/driver" "encoding/json" "fmt" "strings" "time" "github.com/go-webauthn/webauthn/webauthn" "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/storage" ) type FactorState int const ( FactorStateUnverified FactorState = iota FactorStateVerified ) func (factorState FactorState) String() string { switch factorState { case FactorStateUnverified: return "unverified" case FactorStateVerified: return "verified" } return "" } const TOTP = "totp" const Phone = "phone" const WebAuthn = "webauthn" type AuthenticationMethod int const ( OAuth AuthenticationMethod = iota PasswordGrant OTP TOTPSignIn MFAPhone MFAWebAuthn SSOSAML Recovery Invite MagicLink EmailSignup EmailChange TokenRefresh Anonymous ) func (authMethod AuthenticationMethod) String() string { switch authMethod { case OAuth: return "oauth" case PasswordGrant: return "password" case OTP: return "otp" case TOTPSignIn: return "totp" case Recovery: return "recovery" case Invite: return "invite" case SSOSAML: return "sso/saml" case MagicLink: return "magiclink" case EmailSignup: return "email/signup" case EmailChange: return "email_change" case TokenRefresh: return "token_refresh" case Anonymous: return "anonymous" case MFAPhone: return "mfa/phone" case MFAWebAuthn: return "mfa/webauthn" } return "" } func ParseAuthenticationMethod(authMethod string) (AuthenticationMethod, error) { if strings.HasSuffix(authMethod, "signup") { authMethod = "email/signup" } switch authMethod { case "oauth": return OAuth, nil case "password": return PasswordGrant, nil case "otp": return OTP, nil case "totp": return TOTPSignIn, nil case "recovery": return Recovery, nil case "invite": return Invite, nil case "sso/saml": return SSOSAML, nil case "magiclink": return MagicLink, nil case "email/signup": return EmailSignup, nil case "email_change": return EmailChange, nil case "token_refresh": return TokenRefresh, nil case "mfa/sms": return MFAPhone, nil case "mfa/webauthn": return MFAWebAuthn, nil } return 0, fmt.Errorf("unsupported authentication method %q", authMethod) } type Factor struct { ID uuid.UUID `json:"id" db:"id"` // TODO: Consider removing this nested user field. We don't use it. User User `json:"-" belongs_to:"user"` UserID uuid.UUID `json:"-" db:"user_id"` CreatedAt time.Time `json:"created_at" db:"created_at"` UpdatedAt time.Time `json:"updated_at" db:"updated_at"` Status string `json:"status" db:"status"` FriendlyName string `json:"friendly_name,omitempty" db:"friendly_name"` Secret string `json:"-" db:"secret"` FactorType string `json:"factor_type" db:"factor_type"` Challenge []Challenge `json:"-" has_many:"challenges"` Phone storage.NullString `json:"phone" db:"phone"` LastChallengedAt *time.Time `json:"last_challenged_at" db:"last_challenged_at"` WebAuthnCredential *WebAuthnCredential `json:"-" db:"web_authn_credential"` WebAuthnAAGUID *uuid.UUID `json:"web_authn_aaguid,omitempty" db:"web_authn_aaguid"` } type WebAuthnCredential struct { webauthn.Credential } func (wc *WebAuthnCredential) Value() (driver.Value, error) { if wc == nil { return nil, nil } return json.Marshal(wc) } func (wc *WebAuthnCredential) Scan(value interface{}) error { if value == nil { wc.Credential = webauthn.Credential{} return nil } // Handle byte and string as a precaution, in postgres driver, json/jsonb should be returned as []byte var data []byte switch v := value.(type) { case []byte: data = v case string: data = []byte(v) default: return fmt.Errorf("unsupported type for web_authn_credential: %T", value) } if len(data) == 0 { wc.Credential = webauthn.Credential{} return nil } return json.Unmarshal(data, &wc.Credential) } func (Factor) TableName() string { tableName := "mfa_factors" return tableName } func NewFactor(user *User, friendlyName string, factorType string, state FactorState) *Factor { id := uuid.Must(uuid.NewV4()) factor := &Factor{ ID: id, UserID: user.ID, Status: state.String(), FriendlyName: friendlyName, FactorType: factorType, } return factor } func NewTOTPFactor(user *User, friendlyName string) *Factor { return NewFactor(user, friendlyName, TOTP, FactorStateUnverified) } func NewPhoneFactor(user *User, phone, friendlyName string) *Factor { factor := NewFactor(user, friendlyName, Phone, FactorStateUnverified) factor.Phone = storage.NullString(phone) return factor } func NewWebAuthnFactor(user *User, friendlyName string) *Factor { factor := NewFactor(user, friendlyName, WebAuthn, FactorStateUnverified) return factor } func (f *Factor) SetSecret(secret string, encrypt bool, encryptionKeyID, encryptionKey string) error { f.Secret = secret if encrypt { es, err := crypto.NewEncryptedString(f.ID.String(), []byte(secret), encryptionKeyID, encryptionKey) if err != nil { return err } f.Secret = es.String() } return nil } func (f *Factor) GetSecret(decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (string, bool, error) { if es := crypto.ParseEncryptedString(f.Secret); es != nil { bytes, err := es.Decrypt(f.ID.String(), decryptionKeys) if err != nil { return "", false, err } return string(bytes), encrypt && es.ShouldReEncrypt(encryptionKeyID), nil } return f.Secret, encrypt, nil } func (f *Factor) SaveWebAuthnCredential(tx *storage.Connection, credential *webauthn.Credential) error { f.WebAuthnCredential = &WebAuthnCredential{ Credential: *credential, } if len(credential.Authenticator.AAGUID) > 0 { aaguidUUID, err := uuid.FromBytes(credential.Authenticator.AAGUID) if err != nil { return fmt.Errorf("WebAuthn authenticator AAGUID is not UUID: %w", err) } f.WebAuthnAAGUID = &aaguidUUID } else { f.WebAuthnAAGUID = nil } return tx.UpdateOnly(f, "web_authn_credential", "web_authn_aaguid", "updated_at") } func FindFactorByFactorID(conn *storage.Connection, factorID uuid.UUID) (*Factor, error) { var factor Factor err := conn.Find(&factor, factorID) if err != nil && errors.Cause(err) == sql.ErrNoRows { return nil, FactorNotFoundError{} } else if err != nil { return nil, err } return &factor, nil } func DeleteUnverifiedFactors(tx *storage.Connection, user *User, factorType string) error { if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Factor{}}).TableName()+" WHERE user_id = ? and status = ? and factor_type = ?", user.ID, FactorStateUnverified.String(), factorType).Exec(); err != nil { return err } return nil } func (f *Factor) CreateChallenge(ipAddress string) *Challenge { id := uuid.Must(uuid.NewV4()) challenge := &Challenge{ ID: id, FactorID: f.ID, IPAddress: ipAddress, } return challenge } func (f *Factor) WriteChallengeToDatabase(tx *storage.Connection, challenge *Challenge) error { if challenge.FactorID != f.ID { return errors.New("Can only write challenges that you own") } now := time.Now() f.LastChallengedAt = &now if terr := tx.Create(challenge); terr != nil { return terr } if err := tx.UpdateOnly(f, "last_challenged_at"); err != nil { return err } return nil } func (f *Factor) CreatePhoneChallenge(ipAddress string, otpCode string, encrypt bool, encryptionKeyID, encryptionKey string) (*Challenge, error) { phoneChallenge := f.CreateChallenge(ipAddress) if err := phoneChallenge.SetOtpCode(otpCode, encrypt, encryptionKeyID, encryptionKey); err != nil { return nil, err } return phoneChallenge, nil } // UpdateFriendlyName changes the friendly name func (f *Factor) UpdateFriendlyName(tx *storage.Connection, friendlyName string) error { f.FriendlyName = friendlyName return tx.UpdateOnly(f, "friendly_name", "updated_at") } func (f *Factor) UpdatePhone(tx *storage.Connection, phone string) error { f.Phone = storage.NullString(phone) return tx.UpdateOnly(f, "phone", "updated_at") } // UpdateStatus modifies the factor status func (f *Factor) UpdateStatus(tx *storage.Connection, state FactorState) error { f.Status = state.String() return tx.UpdateOnly(f, "status", "updated_at") } func (f *Factor) DowngradeSessionsToAAL1(tx *storage.Connection) error { sessions, err := FindSessionsByFactorID(tx, f.ID) if err != nil { return err } for _, session := range sessions { if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: AMRClaim{}}).TableName()+" WHERE session_id = ? AND authentication_method = ?", session.ID, f.FactorType).Exec(); err != nil { return err } } return updateFactorAssociatedSessions(tx, f.UserID, f.ID, AAL1.String()) } func (f *Factor) IsVerified() bool { return f.Status == FactorStateVerified.String() } func (f *Factor) IsUnverified() bool { return f.Status == FactorStateUnverified.String() } func (f *Factor) IsPhoneFactor() bool { return f.FactorType == Phone } func (f *Factor) FindChallengeByID(conn *storage.Connection, challengeID uuid.UUID) (*Challenge, error) { var challenge Challenge err := conn.Q().Where("id = ? and factor_id = ?", challengeID, f.ID).First(&challenge) if err != nil && errors.Cause(err) == sql.ErrNoRows { return nil, ChallengeNotFoundError{} } else if err != nil { return nil, err } return &challenge, nil } func DeleteFactorsByUserId(tx *storage.Connection, userId uuid.UUID) error { if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Factor{}}).TableName()+" WHERE user_id = ?", userId).Exec(); err != nil { return err } return nil } func DeleteExpiredFactors(tx *storage.Connection, validityDuration time.Duration) error { totalSeconds := int64(validityDuration / time.Second) validityInterval := fmt.Sprintf("interval '%d seconds'", totalSeconds) factorTable := (&pop.Model{Value: Factor{}}).TableName() challengeTable := (&pop.Model{Value: Challenge{}}).TableName() query := fmt.Sprintf(`delete from %q where status != 'verified' and not exists (select * from %q where %q.id = %q.factor_id ) and created_at + %s < current_timestamp;`, factorTable, challengeTable, factorTable, challengeTable, validityInterval) if err := tx.RawQuery(query).Exec(); err != nil { return err } return nil } func (f *Factor) FindLatestUnexpiredChallenge(tx *storage.Connection, expiryDuration float64) (*Challenge, error) { now := time.Now() var challenge Challenge expirationTime := now.Add(time.Duration(expiryDuration) * time.Second) err := tx.Where("sent_at > ? and factor_id = ?", expirationTime, f.ID). Order("sent_at desc"). First(&challenge) if err != nil && errors.Cause(err) == sql.ErrNoRows { return nil, ChallengeNotFoundError{} } else if err != nil { return nil, err } return &challenge, nil }