507 lines
16 KiB
Go
507 lines
16 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"time"
|
|
|
|
"fmt"
|
|
|
|
"github.com/gofrs/uuid"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/xeipuuv/gojsonschema"
|
|
|
|
"github.com/supabase/auth/internal/hooks"
|
|
"github.com/supabase/auth/internal/metering"
|
|
"github.com/supabase/auth/internal/models"
|
|
"github.com/supabase/auth/internal/observability"
|
|
"github.com/supabase/auth/internal/storage"
|
|
)
|
|
|
|
// AccessTokenClaims is a struct thats used for JWT claims
|
|
type AccessTokenClaims struct {
|
|
jwt.RegisteredClaims
|
|
Email string `json:"email"`
|
|
Phone string `json:"phone"`
|
|
AppMetaData map[string]interface{} `json:"app_metadata"`
|
|
UserMetaData map[string]interface{} `json:"user_metadata"`
|
|
Role string `json:"role"`
|
|
AuthenticatorAssuranceLevel string `json:"aal,omitempty"`
|
|
AuthenticationMethodReference []models.AMREntry `json:"amr,omitempty"`
|
|
SessionId string `json:"session_id,omitempty"`
|
|
IsAnonymous bool `json:"is_anonymous"`
|
|
}
|
|
|
|
// AccessTokenResponse represents an OAuth2 success response
|
|
type AccessTokenResponse struct {
|
|
Token string `json:"access_token"`
|
|
TokenType string `json:"token_type"` // Bearer
|
|
ExpiresIn int `json:"expires_in"`
|
|
ExpiresAt int64 `json:"expires_at"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
User *models.User `json:"user"`
|
|
ProviderAccessToken string `json:"provider_token,omitempty"`
|
|
ProviderRefreshToken string `json:"provider_refresh_token,omitempty"`
|
|
WeakPassword *WeakPasswordError `json:"weak_password,omitempty"`
|
|
}
|
|
|
|
// AsRedirectURL encodes the AccessTokenResponse as a redirect URL that
|
|
// includes the access token response data in a URL fragment.
|
|
func (r *AccessTokenResponse) AsRedirectURL(redirectURL string, extraParams url.Values) string {
|
|
extraParams.Set("access_token", r.Token)
|
|
extraParams.Set("token_type", r.TokenType)
|
|
extraParams.Set("expires_in", strconv.Itoa(r.ExpiresIn))
|
|
extraParams.Set("expires_at", strconv.FormatInt(r.ExpiresAt, 10))
|
|
extraParams.Set("refresh_token", r.RefreshToken)
|
|
|
|
return redirectURL + "#" + extraParams.Encode()
|
|
}
|
|
|
|
// PasswordGrantParams are the parameters the ResourceOwnerPasswordGrant method accepts
|
|
type PasswordGrantParams struct {
|
|
Email string `json:"email"`
|
|
Phone string `json:"phone"`
|
|
Password string `json:"password"`
|
|
}
|
|
|
|
// PKCEGrantParams are the parameters the PKCEGrant method accepts
|
|
type PKCEGrantParams struct {
|
|
AuthCode string `json:"auth_code"`
|
|
CodeVerifier string `json:"code_verifier"`
|
|
}
|
|
|
|
const useCookieHeader = "x-use-cookie"
|
|
const InvalidLoginMessage = "Invalid login credentials"
|
|
|
|
// Token is the endpoint for OAuth access token requests
|
|
func (a *API) Token(w http.ResponseWriter, r *http.Request) error {
|
|
ctx := r.Context()
|
|
grantType := r.FormValue("grant_type")
|
|
switch grantType {
|
|
case "password":
|
|
return a.ResourceOwnerPasswordGrant(ctx, w, r)
|
|
case "refresh_token":
|
|
return a.RefreshTokenGrant(ctx, w, r)
|
|
case "id_token":
|
|
return a.IdTokenGrant(ctx, w, r)
|
|
case "pkce":
|
|
return a.PKCE(ctx, w, r)
|
|
default:
|
|
return badRequestError(ErrorCodeInvalidCredentials, "unsupported_grant_type")
|
|
}
|
|
}
|
|
|
|
// ResourceOwnerPasswordGrant implements the password grant type flow
|
|
func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
|
|
db := a.db.WithContext(ctx)
|
|
|
|
params := &PasswordGrantParams{}
|
|
if err := retrieveRequestParams(r, params); err != nil {
|
|
return err
|
|
}
|
|
|
|
aud := a.requestAud(ctx, r)
|
|
config := a.config
|
|
|
|
if params.Email != "" && params.Phone != "" {
|
|
return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on login.")
|
|
}
|
|
var user *models.User
|
|
var grantParams models.GrantParams
|
|
var provider string
|
|
var err error
|
|
|
|
grantParams.FillGrantParams(r)
|
|
|
|
if params.Email != "" {
|
|
provider = "email"
|
|
if !config.External.Email.Enabled {
|
|
return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled")
|
|
}
|
|
user, err = models.FindUserByEmailAndAudience(db, params.Email, aud)
|
|
} else if params.Phone != "" {
|
|
provider = "phone"
|
|
if !config.External.Phone.Enabled {
|
|
return unprocessableEntityError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled")
|
|
}
|
|
params.Phone = formatPhoneNumber(params.Phone)
|
|
user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud)
|
|
} else {
|
|
return badRequestError(ErrorCodeValidationFailed, "missing email or phone")
|
|
}
|
|
|
|
if err != nil {
|
|
if models.IsNotFoundError(err) {
|
|
return badRequestError(ErrorCodeInvalidCredentials, InvalidLoginMessage)
|
|
}
|
|
return internalServerError("Database error querying schema").WithInternalError(err)
|
|
}
|
|
|
|
if !user.HasPassword() {
|
|
return badRequestError(ErrorCodeInvalidCredentials, InvalidLoginMessage)
|
|
}
|
|
|
|
if user.IsBanned() {
|
|
return badRequestError(ErrorCodeUserBanned, "User is banned")
|
|
}
|
|
|
|
isValidPassword, shouldReEncrypt, err := user.Authenticate(ctx, db, params.Password, config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var weakPasswordError *WeakPasswordError
|
|
if isValidPassword {
|
|
if err := a.checkPasswordStrength(ctx, params.Password); err != nil {
|
|
if wpe, ok := err.(*WeakPasswordError); ok {
|
|
weakPasswordError = wpe
|
|
} else {
|
|
observability.GetLogEntry(r).Entry.WithError(err).Warn("Password strength check on sign-in failed")
|
|
}
|
|
}
|
|
|
|
if shouldReEncrypt {
|
|
if err := user.SetPassword(ctx, params.Password, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil {
|
|
return err
|
|
}
|
|
|
|
// directly change this in the database without
|
|
// calling user.UpdatePassword() because this
|
|
// is not a password change, just encryption
|
|
// change in the database
|
|
if err := db.UpdateOnly(user, "encrypted_password"); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
if config.Hook.PasswordVerificationAttempt.Enabled {
|
|
input := hooks.PasswordVerificationAttemptInput{
|
|
UserID: user.ID,
|
|
Valid: isValidPassword,
|
|
}
|
|
output := hooks.PasswordVerificationAttemptOutput{}
|
|
if err := a.invokeHook(nil, r, &input, &output); err != nil {
|
|
return err
|
|
}
|
|
|
|
if output.Decision == hooks.HookRejection {
|
|
if output.Message == "" {
|
|
output.Message = hooks.DefaultPasswordHookRejectionMessage
|
|
}
|
|
if output.ShouldLogoutUser {
|
|
if err := models.Logout(a.db, user.ID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return badRequestError(ErrorCodeInvalidCredentials, output.Message)
|
|
}
|
|
}
|
|
if !isValidPassword {
|
|
return badRequestError(ErrorCodeInvalidCredentials, InvalidLoginMessage)
|
|
}
|
|
|
|
if params.Email != "" && !user.IsConfirmed() {
|
|
return badRequestError(ErrorCodeEmailNotConfirmed, "Email not confirmed")
|
|
} else if params.Phone != "" && !user.IsPhoneConfirmed() {
|
|
return badRequestError(ErrorCodePhoneNotConfirmed, "Phone not confirmed")
|
|
}
|
|
|
|
var token *AccessTokenResponse
|
|
err = db.Transaction(func(tx *storage.Connection) error {
|
|
var terr error
|
|
if terr = models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{
|
|
"provider": provider,
|
|
}); terr != nil {
|
|
return terr
|
|
}
|
|
token, terr = a.issueRefreshToken(r, tx, user, models.PasswordGrant, grantParams)
|
|
if terr != nil {
|
|
return terr
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
token.WeakPassword = weakPasswordError
|
|
|
|
metering.RecordLogin("password", user.ID)
|
|
return sendJSON(w, http.StatusOK, token)
|
|
}
|
|
|
|
func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
|
|
db := a.db.WithContext(ctx)
|
|
var grantParams models.GrantParams
|
|
|
|
// There is a slight problem with this as it will pick-up the
|
|
// User-Agent and IP addresses from the server if used on the server
|
|
// side. Currently there's no mechanism to distinguish, but the server
|
|
// can be told to at least propagate the User-Agent header.
|
|
grantParams.FillGrantParams(r)
|
|
|
|
params := &PKCEGrantParams{}
|
|
if err := retrieveRequestParams(r, params); err != nil {
|
|
return err
|
|
}
|
|
|
|
if params.AuthCode == "" || params.CodeVerifier == "" {
|
|
return badRequestError(ErrorCodeValidationFailed, "invalid request: both auth code and code verifier should be non-empty")
|
|
}
|
|
|
|
flowState, err := models.FindFlowStateByAuthCode(db, params.AuthCode)
|
|
// Sanity check in case user ID was not set properly
|
|
if models.IsNotFoundError(err) || flowState.UserID == nil {
|
|
return notFoundError(ErrorCodeFlowStateNotFound, "invalid flow state, no valid flow state found")
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
if flowState.IsExpired(a.config.External.FlowStateExpiryDuration) {
|
|
return unprocessableEntityError(ErrorCodeFlowStateExpired, "invalid flow state, flow state has expired")
|
|
}
|
|
|
|
user, err := models.FindUserByID(db, *flowState.UserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := flowState.VerifyPKCE(params.CodeVerifier); err != nil {
|
|
return badRequestError(ErrorCodeBadCodeVerifier, err.Error())
|
|
}
|
|
|
|
var token *AccessTokenResponse
|
|
err = db.Transaction(func(tx *storage.Connection) error {
|
|
var terr error
|
|
authMethod, err := models.ParseAuthenticationMethod(flowState.AuthenticationMethod)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if terr := models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{
|
|
"provider_type": flowState.ProviderType,
|
|
}); terr != nil {
|
|
return terr
|
|
}
|
|
token, terr = a.issueRefreshToken(r, tx, user, authMethod, grantParams)
|
|
if terr != nil {
|
|
// error type is already handled in issueRefreshToken
|
|
return terr
|
|
}
|
|
token.ProviderAccessToken = flowState.ProviderAccessToken
|
|
// Because not all providers give out a refresh token
|
|
// See corresponding OAuth2 spec: <https://www.rfc-editor.org/rfc/rfc6749.html#section-5.1>
|
|
if flowState.ProviderRefreshToken != "" {
|
|
token.ProviderRefreshToken = flowState.ProviderRefreshToken
|
|
}
|
|
if terr = tx.Destroy(flowState); terr != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return sendJSON(w, http.StatusOK, token)
|
|
}
|
|
|
|
func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) {
|
|
config := a.config
|
|
if sessionId == nil {
|
|
return "", 0, internalServerError("Session is required to issue access token")
|
|
}
|
|
sid := sessionId.String()
|
|
session, terr := models.FindSessionByID(tx, *sessionId, false)
|
|
if terr != nil {
|
|
return "", 0, terr
|
|
}
|
|
aal, amr, terr := session.CalculateAALAndAMR(user)
|
|
if terr != nil {
|
|
return "", 0, terr
|
|
}
|
|
|
|
issuedAt := time.Now().UTC()
|
|
expiresAt := issuedAt.Add(time.Second * time.Duration(config.JWT.Exp))
|
|
|
|
claims := &hooks.AccessTokenClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: user.ID.String(),
|
|
Audience: jwt.ClaimStrings{user.Aud},
|
|
IssuedAt: jwt.NewNumericDate(issuedAt),
|
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
|
Issuer: config.JWT.Issuer,
|
|
},
|
|
Email: user.GetEmail(),
|
|
Phone: user.GetPhone(),
|
|
AppMetaData: user.AppMetaData,
|
|
UserMetaData: user.UserMetaData,
|
|
Role: user.Role,
|
|
SessionId: sid,
|
|
AuthenticatorAssuranceLevel: aal.String(),
|
|
AuthenticationMethodReference: amr,
|
|
IsAnonymous: user.IsAnonymous,
|
|
}
|
|
|
|
var gotrueClaims jwt.Claims = claims
|
|
if config.Hook.CustomAccessToken.Enabled {
|
|
input := hooks.CustomAccessTokenInput{
|
|
UserID: user.ID,
|
|
Claims: claims,
|
|
AuthenticationMethod: authenticationMethod.String(),
|
|
}
|
|
|
|
output := hooks.CustomAccessTokenOutput{}
|
|
|
|
err := a.invokeHook(tx, r, &input, &output)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
gotrueClaims = jwt.MapClaims(output.Claims)
|
|
}
|
|
|
|
signed, err := signJwt(&config.JWT, gotrueClaims)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
return signed, expiresAt.Unix(), nil
|
|
}
|
|
|
|
func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) {
|
|
config := a.config
|
|
|
|
now := time.Now()
|
|
user.LastSignInAt = &now
|
|
|
|
var tokenString string
|
|
var expiresAt int64
|
|
var refreshToken *models.RefreshToken
|
|
|
|
err := conn.Transaction(func(tx *storage.Connection) error {
|
|
var terr error
|
|
|
|
refreshToken, terr = models.GrantAuthenticatedUser(tx, user, grantParams)
|
|
if terr != nil {
|
|
return internalServerError("Database error granting user").WithInternalError(terr)
|
|
}
|
|
|
|
terr = models.AddClaimToSession(tx, *refreshToken.SessionId, authenticationMethod)
|
|
if terr != nil {
|
|
return terr
|
|
}
|
|
|
|
tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, refreshToken.SessionId, authenticationMethod)
|
|
if terr != nil {
|
|
// Account for Hook Error
|
|
httpErr, ok := terr.(*HTTPError)
|
|
if ok {
|
|
return httpErr
|
|
}
|
|
return internalServerError("error generating jwt token").WithInternalError(terr)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &AccessTokenResponse{
|
|
Token: tokenString,
|
|
TokenType: "bearer",
|
|
ExpiresIn: config.JWT.Exp,
|
|
ExpiresAt: expiresAt,
|
|
RefreshToken: refreshToken.Token,
|
|
User: user,
|
|
}, nil
|
|
}
|
|
|
|
func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) {
|
|
ctx := r.Context()
|
|
config := a.config
|
|
var tokenString string
|
|
var expiresAt int64
|
|
var refreshToken *models.RefreshToken
|
|
currentClaims := getClaims(ctx)
|
|
sessionId, err := uuid.FromString(currentClaims.SessionId)
|
|
if err != nil {
|
|
return nil, internalServerError("Cannot read SessionId claim as UUID").WithInternalError(err)
|
|
}
|
|
|
|
err = tx.Transaction(func(tx *storage.Connection) error {
|
|
if terr := models.AddClaimToSession(tx, sessionId, authenticationMethod); terr != nil {
|
|
return terr
|
|
}
|
|
session, terr := models.FindSessionByID(tx, sessionId, false)
|
|
if terr != nil {
|
|
return terr
|
|
}
|
|
currentToken, terr := models.FindTokenBySessionID(tx, &session.ID)
|
|
if terr != nil {
|
|
return terr
|
|
}
|
|
if err := tx.Load(user, "Identities"); err != nil {
|
|
return err
|
|
}
|
|
// Swap to ensure current token is the latest one
|
|
refreshToken, terr = models.GrantRefreshTokenSwap(r, tx, user, currentToken)
|
|
if terr != nil {
|
|
return terr
|
|
}
|
|
aal, _, terr := session.CalculateAALAndAMR(user)
|
|
if terr != nil {
|
|
return terr
|
|
}
|
|
|
|
if err := session.UpdateAALAndAssociatedFactor(tx, aal, grantParams.FactorID); err != nil {
|
|
return err
|
|
}
|
|
|
|
tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, &session.ID, authenticationMethod)
|
|
if terr != nil {
|
|
httpErr, ok := terr.(*HTTPError)
|
|
if ok {
|
|
return httpErr
|
|
}
|
|
return internalServerError("error generating jwt token").WithInternalError(terr)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &AccessTokenResponse{
|
|
Token: tokenString,
|
|
TokenType: "bearer",
|
|
ExpiresIn: config.JWT.Exp,
|
|
ExpiresAt: expiresAt,
|
|
RefreshToken: refreshToken.Token,
|
|
User: user,
|
|
}, nil
|
|
}
|
|
|
|
func validateTokenClaims(outputClaims map[string]interface{}) error {
|
|
schemaLoader := gojsonschema.NewStringLoader(hooks.MinimumViableTokenSchema)
|
|
|
|
documentLoader := gojsonschema.NewGoLoader(outputClaims)
|
|
|
|
result, err := gojsonschema.Validate(schemaLoader, documentLoader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !result.Valid() {
|
|
var errorMessages string
|
|
|
|
for _, desc := range result.Errors() {
|
|
errorMessages += fmt.Sprintf("- %s\n", desc)
|
|
fmt.Printf("- %s\n", desc)
|
|
}
|
|
return fmt.Errorf("output claims do not conform to the expected schema: \n%s", errorMessages)
|
|
|
|
}
|
|
|
|
return nil
|
|
}
|