170 lines
4.9 KiB
Go
170 lines
4.9 KiB
Go
package models
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/supabase/auth/internal/storage"
|
|
|
|
"github.com/gofrs/uuid"
|
|
)
|
|
|
|
// Error messages produced by (*FlowState).VerifyPKCE.
const (
	// InvalidCodeChallengeError is the message used when the supplied code
	// verifier does not correspond to the stored code challenge.
	InvalidCodeChallengeError = "code challenge does not match previously saved code verifier"

	// InvalidCodeMethodError is the message used when the stored code
	// challenge method is neither "s256" nor "plain".
	InvalidCodeMethodError = "code challenge method not supported"
)
|
|
|
|
type FlowState struct {
|
|
ID uuid.UUID `json:"id" db:"id"`
|
|
UserID *uuid.UUID `json:"user_id,omitempty" db:"user_id"`
|
|
AuthCode string `json:"auth_code" db:"auth_code"`
|
|
AuthenticationMethod string `json:"authentication_method" db:"authentication_method"`
|
|
CodeChallenge string `json:"code_challenge" db:"code_challenge"`
|
|
CodeChallengeMethod string `json:"code_challenge_method" db:"code_challenge_method"`
|
|
ProviderType string `json:"provider_type" db:"provider_type"`
|
|
ProviderAccessToken string `json:"provider_access_token" db:"provider_access_token"`
|
|
ProviderRefreshToken string `json:"provider_refresh_token" db:"provider_refresh_token"`
|
|
AuthCodeIssuedAt *time.Time `json:"auth_code_issued_at" db:"auth_code_issued_at"`
|
|
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
|
}
|
|
|
|
// CodeChallengeMethod enumerates the supported PKCE code challenge
// transformations.
type CodeChallengeMethod int

const (
	// SHA256 is the zero value of CodeChallengeMethod.
	SHA256 CodeChallengeMethod = iota
	// Plain has the value 1.
	Plain
)
|
|
|
|
func (codeChallengeMethod CodeChallengeMethod) String() string {
|
|
switch codeChallengeMethod {
|
|
case SHA256:
|
|
return "s256"
|
|
case Plain:
|
|
return "plain"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func ParseCodeChallengeMethod(codeChallengeMethod string) (CodeChallengeMethod, error) {
|
|
switch strings.ToLower(codeChallengeMethod) {
|
|
case "s256":
|
|
return SHA256, nil
|
|
case "plain":
|
|
return Plain, nil
|
|
}
|
|
return 0, fmt.Errorf("unsupported code_challenge method %q", codeChallengeMethod)
|
|
}
|
|
|
|
// FlowType enumerates the kinds of authentication flow named in this file.
type FlowType int

const (
	// PKCEFlow is the zero value of FlowType.
	PKCEFlow FlowType = iota
	// ImplicitFlow has the value 1.
	ImplicitFlow
)
|
|
|
|
func (flowType FlowType) String() string {
|
|
switch flowType {
|
|
case PKCEFlow:
|
|
return "pkce"
|
|
case ImplicitFlow:
|
|
return "implicit"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (FlowState) TableName() string {
|
|
tableName := "flow_state"
|
|
return tableName
|
|
}
|
|
|
|
func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) *FlowState {
|
|
id := uuid.Must(uuid.NewV4())
|
|
authCode := uuid.Must(uuid.NewV4())
|
|
flowState := &FlowState{
|
|
ID: id,
|
|
ProviderType: providerType,
|
|
CodeChallenge: codeChallenge,
|
|
CodeChallengeMethod: codeChallengeMethod.String(),
|
|
AuthCode: authCode.String(),
|
|
AuthenticationMethod: authenticationMethod.String(),
|
|
UserID: userID,
|
|
}
|
|
return flowState
|
|
}
|
|
|
|
func FindFlowStateByAuthCode(tx *storage.Connection, authCode string) (*FlowState, error) {
|
|
obj := &FlowState{}
|
|
if err := tx.Eager().Q().Where("auth_code = ?", authCode).First(obj); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows {
|
|
return nil, FlowStateNotFoundError{}
|
|
}
|
|
return nil, errors.Wrap(err, "error finding flow state")
|
|
}
|
|
|
|
return obj, nil
|
|
}
|
|
|
|
func FindFlowStateByID(tx *storage.Connection, id string) (*FlowState, error) {
|
|
obj := &FlowState{}
|
|
if err := tx.Eager().Q().Where("id = ?", id).First(obj); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows {
|
|
return nil, FlowStateNotFoundError{}
|
|
}
|
|
return nil, errors.Wrap(err, "error finding flow state")
|
|
}
|
|
|
|
return obj, nil
|
|
}
|
|
|
|
func FindFlowStateByUserID(tx *storage.Connection, id string, authenticationMethod AuthenticationMethod) (*FlowState, error) {
|
|
obj := &FlowState{}
|
|
if err := tx.Eager().Q().Where("user_id = ? and authentication_method = ?", id, authenticationMethod).Last(obj); err != nil {
|
|
if errors.Cause(err) == sql.ErrNoRows {
|
|
return nil, FlowStateNotFoundError{}
|
|
}
|
|
return nil, errors.Wrap(err, "error finding flow state")
|
|
}
|
|
|
|
return obj, nil
|
|
}
|
|
|
|
func (f *FlowState) VerifyPKCE(codeVerifier string) error {
|
|
switch f.CodeChallengeMethod {
|
|
case SHA256.String():
|
|
hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier))
|
|
encodedCodeVerifier := base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:])
|
|
if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(encodedCodeVerifier)) != 1 {
|
|
return errors.New(InvalidCodeChallengeError)
|
|
}
|
|
case Plain.String():
|
|
if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(codeVerifier)) != 1 {
|
|
return errors.New(InvalidCodeChallengeError)
|
|
}
|
|
default:
|
|
return errors.New(InvalidCodeMethodError)
|
|
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *FlowState) IsExpired(expiryDuration time.Duration) bool {
|
|
if f.AuthCodeIssuedAt != nil && f.AuthenticationMethod == MagicLink.String() {
|
|
return time.Now().After(f.AuthCodeIssuedAt.Add(expiryDuration))
|
|
}
|
|
return time.Now().After(f.CreatedAt.Add(expiryDuration))
|
|
}
|
|
|
|
func (f *FlowState) RecordAuthCodeIssuedAtTime(tx *storage.Connection) error {
|
|
issueTime := time.Now()
|
|
f.AuthCodeIssuedAt = &issueTime
|
|
if err := tx.Update(f); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|