chatdesk-ui/auth_v2.169.0/internal/models/factor.go

399 lines
11 KiB
Go

package models
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/supabase/auth/internal/crypto"
"github.com/supabase/auth/internal/storage"
)
// FactorState is the verification state of an MFA factor. Its String form is
// the value written to Factor.Status by NewFactor and UpdateStatus.
type FactorState int

const (
	// FactorStateUnverified is the initial state given to factors built by
	// the NewTOTPFactor, NewPhoneFactor and NewWebAuthnFactor constructors.
	FactorStateUnverified FactorState = iota
	// FactorStateVerified marks a factor whose status is "verified".
	FactorStateVerified
)

// String returns "unverified" or "verified" for the declared states and the
// empty string for any other value.
func (s FactorState) String() string {
	if s == FactorStateVerified {
		return "verified"
	}
	if s == FactorStateUnverified {
		return "unverified"
	}
	return ""
}
// Factor type identifiers. The constructors in this file store one of these
// in Factor.FactorType.
const (
	// TOTP identifies a time-based one-time password factor.
	TOTP = "totp"
	// Phone identifies a phone factor (challenged with an OTP code).
	Phone = "phone"
	// WebAuthn identifies a WebAuthn credential factor.
	WebAuthn = "webauthn"
)
// AuthenticationMethod enumerates the ways a session can be authenticated.
// AuthenticationMethod.String yields the textual method name and
// ParseAuthenticationMethod converts such a name back into the enum value.
type AuthenticationMethod int

// Values are assigned with iota, so new methods must be appended at the end:
// reordering would change the numeric value of existing methods.
// NOTE(review): confirm whether these integers are persisted anywhere.
const (
	OAuth AuthenticationMethod = iota
	PasswordGrant
	OTP
	TOTPSignIn
	MFAPhone
	MFAWebAuthn
	SSOSAML
	Recovery
	Invite
	MagicLink
	EmailSignup
	EmailChange
	TokenRefresh
	Anonymous
)

// String returns the textual name of the authentication method, or the empty
// string for a value outside the declared constants.
func (authMethod AuthenticationMethod) String() string {
	switch authMethod {
	case OAuth:
		return "oauth"
	case PasswordGrant:
		return "password"
	case OTP:
		return "otp"
	case TOTPSignIn:
		return "totp"
	case MFAPhone:
		return "mfa/phone"
	case MFAWebAuthn:
		return "mfa/webauthn"
	case SSOSAML:
		return "sso/saml"
	case Recovery:
		return "recovery"
	case Invite:
		return "invite"
	case MagicLink:
		return "magiclink"
	case EmailSignup:
		return "email/signup"
	case EmailChange:
		return "email_change"
	case TokenRefresh:
		return "token_refresh"
	case Anonymous:
		return "anonymous"
	}
	return ""
}

// ParseAuthenticationMethod converts a method name into an AuthenticationMethod.
//
// Every name produced by AuthenticationMethod.String is accepted, so
// ParseAuthenticationMethod(m.String()) returns m for every declared method.
// In addition:
//   - any name ending in "signup" is treated as "email/signup";
//   - "mfa/sms" is accepted as an alternative spelling of MFAPhone, for
//     backward compatibility with callers that use it.
//
// Any other name yields an error.
func ParseAuthenticationMethod(authMethod string) (AuthenticationMethod, error) {
	if strings.HasSuffix(authMethod, "signup") {
		authMethod = "email/signup"
	}
	switch authMethod {
	case "oauth":
		return OAuth, nil
	case "password":
		return PasswordGrant, nil
	case "otp":
		return OTP, nil
	case "totp":
		return TOTPSignIn, nil
	case "recovery":
		return Recovery, nil
	case "invite":
		return Invite, nil
	case "sso/saml":
		return SSOSAML, nil
	case "magiclink":
		return MagicLink, nil
	case "email/signup":
		return EmailSignup, nil
	case "email_change":
		return EmailChange, nil
	case "token_refresh":
		return TokenRefresh, nil
	case "mfa/phone", "mfa/sms":
		// "mfa/phone" is what MFAPhone.String() emits; "mfa/sms" is the
		// spelling this parser historically accepted.
		return MFAPhone, nil
	case "mfa/webauthn":
		return MFAWebAuthn, nil
	case "anonymous":
		return Anonymous, nil
	}
	return 0, fmt.Errorf("unsupported authentication method %q", authMethod)
}
// Factor is a multi-factor authentication factor belonging to a user. It is
// stored in the "mfa_factors" table (see Factor.TableName).
type Factor struct {
	// ID is the primary key; NewFactor assigns a random v4 UUID.
	ID uuid.UUID `json:"id" db:"id"`
	// TODO: Consider removing this nested user field. We don't use it.
	User User `json:"-" belongs_to:"user"`
	// UserID is the ID of the owning user; excluded from JSON encoding.
	UserID    uuid.UUID `json:"-" db:"user_id"`
	CreatedAt time.Time `json:"created_at" db:"created_at"`
	UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
	// Status holds a FactorState.String() value ("unverified" or "verified").
	Status string `json:"status" db:"status"`
	// FriendlyName is an optional label; omitted from JSON when empty.
	FriendlyName string `json:"friendly_name,omitempty" db:"friendly_name"`
	// Secret is excluded from JSON encoding. It holds either the plaintext
	// secret or its encrypted-string encoding (see SetSecret / GetSecret).
	Secret string `json:"-" db:"secret"`
	// FactorType is TOTP, Phone or WebAuthn for factors built by the
	// constructors in this file.
	FactorType string      `json:"factor_type" db:"factor_type"`
	Challenge  []Challenge `json:"-" has_many:"challenges"`
	// Phone is set for phone factors (NewPhoneFactor, UpdatePhone).
	Phone storage.NullString `json:"phone" db:"phone"`
	// LastChallengedAt is updated by WriteChallengeToDatabase.
	LastChallengedAt *time.Time `json:"last_challenged_at" db:"last_challenged_at"`
	// WebAuthnCredential is persisted as JSON via WebAuthnCredential.Value but
	// excluded from the factor's own JSON encoding.
	WebAuthnCredential *WebAuthnCredential `json:"-" db:"web_authn_credential"`
	// WebAuthnAAGUID is the authenticator AAGUID recorded by
	// SaveWebAuthnCredential; nil when the credential carried none.
	WebAuthnAAGUID *uuid.UUID `json:"web_authn_aaguid,omitempty" db:"web_authn_aaguid"`
}
// WebAuthnCredential wraps webauthn.Credential so it can be stored in and read
// back from a database column as JSON. Its Value and Scan methods match the
// database/sql/driver.Valuer and database/sql.Scanner interfaces.
type WebAuthnCredential struct {
	webauthn.Credential
}

// Value encodes the credential as JSON. A nil receiver is stored as NULL.
func (wc *WebAuthnCredential) Value() (driver.Value, error) {
	if wc != nil {
		return json.Marshal(wc)
	}
	return nil, nil
}

// Scan decodes a JSON column value into the credential. NULL and empty input
// reset the credential to its zero value; []byte and string are accepted
// (the postgres driver is expected to deliver json/jsonb as []byte, string is
// handled defensively). Any other source type is an error.
func (wc *WebAuthnCredential) Scan(value interface{}) error {
	var raw []byte
	switch src := value.(type) {
	case nil:
		wc.Credential = webauthn.Credential{}
		return nil
	case []byte:
		raw = src
	case string:
		raw = []byte(src)
	default:
		return fmt.Errorf("unsupported type for web_authn_credential: %T", value)
	}
	if len(raw) == 0 {
		wc.Credential = webauthn.Credential{}
		return nil
	}
	return json.Unmarshal(raw, &wc.Credential)
}
// TableName returns the name of the table backing Factor, "mfa_factors". It
// is what (&pop.Model{Value: Factor{}}).TableName() resolves to in the raw
// queries in this file.
func (Factor) TableName() string {
	return "mfa_factors"
}
// NewFactor builds, without persisting, a Factor owned by user. The factor gets
// a freshly generated v4 UUID, the given friendly name and factor type, and
// Status set to state.String(). It panics if a UUID cannot be generated.
func NewFactor(user *User, friendlyName string, factorType string, state FactorState) *Factor {
	newID := uuid.Must(uuid.NewV4())
	return &Factor{
		ID:           newID,
		UserID:       user.ID,
		Status:       state.String(),
		FriendlyName: friendlyName,
		FactorType:   factorType,
	}
}
// NewTOTPFactor returns an unverified, unsaved TOTP factor for user.
func NewTOTPFactor(user *User, friendlyName string) *Factor {
	return NewFactor(user, friendlyName, TOTP, FactorStateUnverified)
}

// NewPhoneFactor returns an unverified, unsaved phone factor for user with its
// Phone field set to phone.
func NewPhoneFactor(user *User, phone, friendlyName string) *Factor {
	f := NewFactor(user, friendlyName, Phone, FactorStateUnverified)
	f.Phone = storage.NullString(phone)
	return f
}

// NewWebAuthnFactor returns an unverified, unsaved WebAuthn factor for user.
// No credential is attached; SaveWebAuthnCredential is what sets one.
func NewWebAuthnFactor(user *User, friendlyName string) *Factor {
	return NewFactor(user, friendlyName, WebAuthn, FactorStateUnverified)
}
// SetSecret stores secret on the factor (in memory only; it does not persist).
//
// When encrypt is false the plaintext is stored as-is. When encrypt is true the
// secret is encrypted with crypto.NewEncryptedString, using the key
// encryptionKey identified by encryptionKeyID. The factor ID is passed in as
// well, so decryption must use the same ID (GetSecret does). The encoded
// result is stored.
//
// On error f.Secret is left unchanged. This means a failed encryption never
// leaves the plaintext secret on a factor that was meant to hold an encrypted
// one.
func (f *Factor) SetSecret(secret string, encrypt bool, encryptionKeyID, encryptionKey string) error {
	if !encrypt {
		f.Secret = secret
		return nil
	}
	es, err := crypto.NewEncryptedString(f.ID.String(), []byte(secret), encryptionKeyID, encryptionKey)
	if err != nil {
		return err
	}
	f.Secret = es.String()
	return nil
}
// GetSecret returns the factor's plaintext secret and a flag telling the caller
// whether the stored value should be re-encrypted.
//
// If f.Secret parses as an encrypted string, it is decrypted with
// decryptionKeys (keyed by the factor ID). The flag is then
// encrypt && es.ShouldReEncrypt(encryptionKeyID). If f.Secret is plaintext, it
// is returned unchanged and the flag equals encrypt. A decryption failure
// returns ("", false, err).
func (f *Factor) GetSecret(decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (string, bool, error) {
	es := crypto.ParseEncryptedString(f.Secret)
	if es == nil {
		// Stored in plaintext: it needs encrypting exactly when encryption is on.
		return f.Secret, encrypt, nil
	}
	plain, err := es.Decrypt(f.ID.String(), decryptionKeys)
	if err != nil {
		return "", false, err
	}
	return string(plain), encrypt && es.ShouldReEncrypt(encryptionKeyID), nil
}
// SaveWebAuthnCredential attaches credential to the factor and persists the
// web_authn_credential, web_authn_aaguid and updated_at columns.
//
// If the credential carries an authenticator AAGUID, it must be a valid 16-byte
// UUID and is stored in WebAuthnAAGUID; otherwise WebAuthnAAGUID is cleared.
// The AAGUID is validated before the factor is modified, so a malformed AAGUID
// returns an error and leaves f untouched. A nil credential is rejected with
// an error rather than causing a panic.
func (f *Factor) SaveWebAuthnCredential(tx *storage.Connection, credential *webauthn.Credential) error {
	if credential == nil {
		return errors.New("webauthn credential must not be nil")
	}

	var aaguid *uuid.UUID
	if len(credential.Authenticator.AAGUID) > 0 {
		parsed, err := uuid.FromBytes(credential.Authenticator.AAGUID)
		if err != nil {
			return fmt.Errorf("WebAuthn authenticator AAGUID is not UUID: %w", err)
		}
		aaguid = &parsed
	}

	// Only mutate the receiver once all validation has passed.
	f.WebAuthnCredential = &WebAuthnCredential{Credential: *credential}
	f.WebAuthnAAGUID = aaguid
	return tx.UpdateOnly(f, "web_authn_credential", "web_authn_aaguid", "updated_at")
}
// FindFactorByFactorID loads the factor with the given primary key. When no
// row matches it returns FactorNotFoundError; any other database error is
// returned as-is.
func FindFactorByFactorID(conn *storage.Connection, factorID uuid.UUID) (*Factor, error) {
	found := Factor{}
	if err := conn.Find(&found, factorID); err != nil {
		if errors.Cause(err) == sql.ErrNoRows {
			return nil, FactorNotFoundError{}
		}
		return nil, err
	}
	return &found, nil
}
// DeleteUnverifiedFactors deletes every factor of factorType that belongs to
// user and still has the "unverified" status. Values are bound as query
// parameters; only the model-derived table name is concatenated into the SQL.
func DeleteUnverifiedFactors(tx *storage.Connection, user *User, factorType string) error {
	table := (&pop.Model{Value: Factor{}}).TableName()
	return tx.RawQuery(
		"DELETE FROM "+table+" WHERE user_id = ? and status = ? and factor_type = ?",
		user.ID, FactorStateUnverified.String(), factorType,
	).Exec()
}
// CreateChallenge builds, without persisting, a Challenge for this factor. The
// challenge gets a freshly generated v4 UUID and the given IP address.
// WriteChallengeToDatabase is the method that saves it. It panics if a UUID
// cannot be generated.
func (f *Factor) CreateChallenge(ipAddress string) *Challenge {
	return &Challenge{
		ID:        uuid.Must(uuid.NewV4()),
		FactorID:  f.ID,
		IPAddress: ipAddress,
	}
}
// WriteChallengeToDatabase inserts challenge and persists the current time as
// the factor's LastChallengedAt.
//
// The challenge must belong to f (challenge.FactorID == f.ID). Otherwise an
// error is returned before any database call. Note that f.LastChallengedAt is
// set in memory before the insert, so it changes even if a later database call
// fails.
func (f *Factor) WriteChallengeToDatabase(tx *storage.Connection, challenge *Challenge) error {
	if challenge.FactorID != f.ID {
		return errors.New("Can only write challenges that you own")
	}

	challengedAt := time.Now()
	f.LastChallengedAt = &challengedAt

	if err := tx.Create(challenge); err != nil {
		return err
	}
	return tx.UpdateOnly(f, "last_challenged_at")
}
// CreatePhoneChallenge builds an unsaved challenge for this factor and stores
// otpCode on it. The encryption settings are passed through to
// Challenge.SetOtpCode. On error no challenge is returned.
func (f *Factor) CreatePhoneChallenge(ipAddress string, otpCode string, encrypt bool, encryptionKeyID, encryptionKey string) (*Challenge, error) {
	c := f.CreateChallenge(ipAddress)
	if err := c.SetOtpCode(otpCode, encrypt, encryptionKeyID, encryptionKey); err != nil {
		return nil, err
	}
	return c, nil
}
// UpdateFriendlyName sets the factor's friendly name and persists the
// friendly_name and updated_at columns.
func (f *Factor) UpdateFriendlyName(tx *storage.Connection, friendlyName string) error {
	f.FriendlyName = friendlyName
	columns := []string{"friendly_name", "updated_at"}
	return tx.UpdateOnly(f, columns...)
}

// UpdatePhone sets the factor's phone number and persists the phone and
// updated_at columns.
func (f *Factor) UpdatePhone(tx *storage.Connection, phone string) error {
	f.Phone = storage.NullString(phone)
	columns := []string{"phone", "updated_at"}
	return tx.UpdateOnly(f, columns...)
}

// UpdateStatus sets the factor's Status to state.String() and persists the
// status and updated_at columns.
func (f *Factor) UpdateStatus(tx *storage.Connection, state FactorState) error {
	f.Status = state.String()
	columns := []string{"status", "updated_at"}
	return tx.UpdateOnly(f, columns...)
}
// DowngradeSessionsToAAL1 downgrades the sessions associated with this factor.
// For each session returned by FindSessionsByFactorID, it deletes the AMR
// claims whose authentication_method equals f.FactorType. It then calls
// updateFactorAssociatedSessions with AAL1. The first error encountered is
// returned.
//
// NOTE(review): claims are matched on f.FactorType ("totp", "phone",
// "webauthn"), whereas AuthenticationMethod.String yields "mfa/phone" and
// "mfa/webauthn" for the phone and WebAuthn MFA methods. If AMR claims are
// recorded under those names they are not removed here — confirm which
// spelling is stored.
func (f *Factor) DowngradeSessionsToAAL1(tx *storage.Connection) error {
	sessions, err := FindSessionsByFactorID(tx, f.ID)
	if err != nil {
		return err
	}

	amrTable := (&pop.Model{Value: AMRClaim{}}).TableName()
	for i := range sessions {
		deleteErr := tx.RawQuery(
			"DELETE FROM "+amrTable+" WHERE session_id = ? AND authentication_method = ?",
			sessions[i].ID, f.FactorType,
		).Exec()
		if deleteErr != nil {
			return deleteErr
		}
	}

	return updateFactorAssociatedSessions(tx, f.UserID, f.ID, AAL1.String())
}
// IsVerified reports whether the factor's status is "verified".
func (f *Factor) IsVerified() bool {
	verified := FactorStateVerified.String()
	return f.Status == verified
}

// IsUnverified reports whether the factor's status is "unverified".
func (f *Factor) IsUnverified() bool {
	unverified := FactorStateUnverified.String()
	return f.Status == unverified
}

// IsPhoneFactor reports whether the factor's type is Phone.
func (f *Factor) IsPhoneFactor() bool {
	return Phone == f.FactorType
}
// FindChallengeByID loads the challenge with the given ID, restricted to
// challenges belonging to this factor. If no such row exists (including a
// challenge that belongs to another factor), it returns ChallengeNotFoundError.
// Other database errors are returned unchanged.
func (f *Factor) FindChallengeByID(conn *storage.Connection, challengeID uuid.UUID) (*Challenge, error) {
	challenge := Challenge{}
	query := conn.Q().Where("id = ? and factor_id = ?", challengeID, f.ID)
	if err := query.First(&challenge); err != nil {
		if errors.Cause(err) == sql.ErrNoRows {
			return nil, ChallengeNotFoundError{}
		}
		return nil, err
	}
	return &challenge, nil
}
// DeleteFactorsByUserId deletes every factor belonging to the given user,
// regardless of type or status.
func DeleteFactorsByUserId(tx *storage.Connection, userID uuid.UUID) error {
	query := "DELETE FROM " + (&pop.Model{Value: Factor{}}).TableName() + " WHERE user_id = ?"
	return tx.RawQuery(query, userID).Exec()
}
// DeleteExpiredFactors deletes factors that meet all of these conditions:
//   - they are not "verified";
//   - they have no challenges;
//   - they were created more than validityDuration ago.
//
// validityDuration is truncated to whole seconds.
func DeleteExpiredFactors(tx *storage.Connection, validityDuration time.Duration) error {
	factorTable := (&pop.Model{Value: Factor{}}).TableName()
	challengeTable := (&pop.Model{Value: Challenge{}}).TableName()

	// The interval is rendered from an integer and the table names come from
	// the models, so interpolating them into the SQL text is not injectable.
	seconds := int64(validityDuration / time.Second)
	validityInterval := fmt.Sprintf("interval '%d seconds'", seconds)

	query := fmt.Sprintf(`delete from %q where status != 'verified' and not exists (select * from %q where %q.id = %q.factor_id ) and created_at + %s < current_timestamp;`, factorTable, challengeTable, factorTable, challengeTable, validityInterval)
	return tx.RawQuery(query).Exec()
}
// FindLatestUnexpiredChallenge returns the most recently sent challenge for this
// factor that is still inside its validity window. A challenge qualifies when
// its sent_at is later than now minus expiryDuration seconds; fractional
// seconds are honoured. It returns ChallengeNotFoundError when no such
// challenge exists. Other database errors are returned unchanged.
//
// NOTE(review): this filters and orders on a sent_at column of the challenges
// table — confirm that column exists in the Challenge schema.
func (f *Factor) FindLatestUnexpiredChallenge(tx *storage.Connection, expiryDuration float64) (*Challenge, error) {
	// A challenge is unexpired while sent_at + expiry > now, i.e.
	// sent_at > now - expiry. Scale to nanoseconds before converting so
	// sub-second parts of expiryDuration are not truncated away.
	cutoff := time.Now().Add(-time.Duration(expiryDuration * float64(time.Second)))

	var challenge Challenge
	err := tx.Where("sent_at > ? and factor_id = ?", cutoff, f.ID).
		Order("sent_at desc").
		First(&challenge)
	if err != nil {
		if errors.Cause(err) == sql.ErrNoRows {
			return nil, ChallengeNotFoundError{}
		}
		return nil, err
	}
	return &challenge, nil
}