685 lines
22 KiB
Go
685 lines
22 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/fatih/structs"
|
|
"github.com/gofrs/uuid"
|
|
jwt "github.com/golang-jwt/jwt/v5"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/supabase/auth/internal/api/provider"
|
|
"github.com/supabase/auth/internal/conf"
|
|
"github.com/supabase/auth/internal/models"
|
|
"github.com/supabase/auth/internal/observability"
|
|
"github.com/supabase/auth/internal/storage"
|
|
"github.com/supabase/auth/internal/utilities"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
|
|
type ExternalProviderClaims struct {
|
|
AuthMicroserviceClaims
|
|
Provider string `json:"provider"`
|
|
InviteToken string `json:"invite_token,omitempty"`
|
|
Referrer string `json:"referrer,omitempty"`
|
|
FlowStateID string `json:"flow_state_id"`
|
|
LinkingTargetID string `json:"linking_target_id,omitempty"`
|
|
}
|
|
|
|
// ExternalProviderRedirect redirects the request to the oauth provider
|
|
func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error {
|
|
rurl, err := a.GetExternalProviderRedirectURL(w, r, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
http.Redirect(w, r, rurl, http.StatusFound)
|
|
return nil
|
|
}
|
|
|
|
// GetExternalProviderRedirectURL returns the URL to start the oauth flow with the corresponding oauth provider
|
|
func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request, linkingTargetUser *models.User) (string, error) {
|
|
ctx := r.Context()
|
|
db := a.db.WithContext(ctx)
|
|
config := a.config
|
|
|
|
query := r.URL.Query()
|
|
providerType := query.Get("provider")
|
|
scopes := query.Get("scopes")
|
|
codeChallenge := query.Get("code_challenge")
|
|
codeChallengeMethod := query.Get("code_challenge_method")
|
|
|
|
p, err := a.Provider(ctx, providerType, scopes)
|
|
if err != nil {
|
|
return "", badRequestError(ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err)
|
|
}
|
|
|
|
inviteToken := query.Get("invite_token")
|
|
if inviteToken != "" {
|
|
_, userErr := models.FindUserByConfirmationToken(db, inviteToken)
|
|
if userErr != nil {
|
|
if models.IsNotFoundError(userErr) {
|
|
return "", notFoundError(ErrorCodeUserNotFound, "User identified by token not found")
|
|
}
|
|
return "", internalServerError("Database error finding user").WithInternalError(userErr)
|
|
}
|
|
}
|
|
|
|
redirectURL := utilities.GetReferrer(r, config)
|
|
log := observability.GetLogEntry(r).Entry
|
|
log.WithField("provider", providerType).Info("Redirecting to external provider")
|
|
if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
|
|
return "", err
|
|
}
|
|
flowType := getFlowFromChallenge(codeChallenge)
|
|
|
|
flowStateID := ""
|
|
if isPKCEFlow(flowType) {
|
|
flowState, err := generateFlowState(a.db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
flowStateID = flowState.ID.String()
|
|
}
|
|
|
|
claims := ExternalProviderClaims{
|
|
AuthMicroserviceClaims: AuthMicroserviceClaims{
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
|
|
},
|
|
SiteURL: config.SiteURL,
|
|
InstanceID: uuid.Nil.String(),
|
|
},
|
|
Provider: providerType,
|
|
InviteToken: inviteToken,
|
|
Referrer: redirectURL,
|
|
FlowStateID: flowStateID,
|
|
}
|
|
|
|
if linkingTargetUser != nil {
|
|
// this means that the user is performing manual linking
|
|
claims.LinkingTargetID = linkingTargetUser.ID.String()
|
|
}
|
|
|
|
tokenString, err := signJwt(&config.JWT, claims)
|
|
if err != nil {
|
|
return "", internalServerError("Error creating state").WithInternalError(err)
|
|
}
|
|
|
|
authUrlParams := make([]oauth2.AuthCodeOption, 0)
|
|
query.Del("scopes")
|
|
query.Del("provider")
|
|
query.Del("code_challenge")
|
|
query.Del("code_challenge_method")
|
|
for key := range query {
|
|
if key == "workos_provider" {
|
|
// See https://workos.com/docs/reference/sso/authorize/get
|
|
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key)))
|
|
} else {
|
|
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key)))
|
|
}
|
|
}
|
|
|
|
authURL := p.AuthCodeURL(tokenString, authUrlParams...)
|
|
|
|
return authURL, nil
|
|
}
|
|
|
|
// ExternalProviderCallback handles the callback endpoint in the external oauth provider flow
|
|
func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) error {
|
|
rurl := a.getExternalRedirectURL(r)
|
|
u, err := url.Parse(rurl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
redirectErrors(a.internalExternalProviderCallback, w, r, u)
|
|
return nil
|
|
}
|
|
|
|
func (a *API) handleOAuthCallback(r *http.Request) (*OAuthProviderData, error) {
|
|
ctx := r.Context()
|
|
providerType := getExternalProviderType(ctx)
|
|
|
|
var oAuthResponseData *OAuthProviderData
|
|
var err error
|
|
switch providerType {
|
|
case "twitter":
|
|
// future OAuth1.0 providers will use this method
|
|
oAuthResponseData, err = a.oAuth1Callback(ctx, providerType)
|
|
default:
|
|
oAuthResponseData, err = a.oAuthCallback(ctx, r, providerType)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return oAuthResponseData, nil
|
|
}
|
|
|
|
func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Request) error {
|
|
ctx := r.Context()
|
|
db := a.db.WithContext(ctx)
|
|
|
|
var grantParams models.GrantParams
|
|
grantParams.FillGrantParams(r)
|
|
|
|
providerType := getExternalProviderType(ctx)
|
|
data, err := a.handleOAuthCallback(r)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
userData := data.userData
|
|
if len(userData.Emails) <= 0 {
|
|
return internalServerError("Error getting user email from external provider")
|
|
}
|
|
userData.Metadata.EmailVerified = false
|
|
for _, email := range userData.Emails {
|
|
if email.Primary {
|
|
userData.Metadata.Email = email.Email
|
|
userData.Metadata.EmailVerified = email.Verified
|
|
break
|
|
} else {
|
|
userData.Metadata.Email = email.Email
|
|
userData.Metadata.EmailVerified = email.Verified
|
|
}
|
|
}
|
|
providerAccessToken := data.token
|
|
providerRefreshToken := data.refreshToken
|
|
|
|
var flowState *models.FlowState
|
|
// if there's a non-empty FlowStateID we perform PKCE Flow
|
|
if flowStateID := getFlowStateID(ctx); flowStateID != "" {
|
|
flowState, err = models.FindFlowStateByID(a.db, flowStateID)
|
|
if models.IsNotFoundError(err) {
|
|
return unprocessableEntityError(ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err)
|
|
} else if err != nil {
|
|
return internalServerError("Failed to find flow state").WithInternalError(err)
|
|
}
|
|
|
|
}
|
|
|
|
var user *models.User
|
|
var token *AccessTokenResponse
|
|
err = db.Transaction(func(tx *storage.Connection) error {
|
|
var terr error
|
|
if targetUser := getTargetUser(ctx); targetUser != nil {
|
|
if user, terr = a.linkIdentityToUser(r, ctx, tx, userData, providerType); terr != nil {
|
|
return terr
|
|
}
|
|
} else if inviteToken := getInviteToken(ctx); inviteToken != "" {
|
|
if user, terr = a.processInvite(r, tx, userData, inviteToken, providerType); terr != nil {
|
|
return terr
|
|
}
|
|
} else {
|
|
if user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType); terr != nil {
|
|
return terr
|
|
}
|
|
}
|
|
if flowState != nil {
|
|
// This means that the callback is using PKCE
|
|
flowState.ProviderAccessToken = providerAccessToken
|
|
flowState.ProviderRefreshToken = providerRefreshToken
|
|
flowState.UserID = &(user.ID)
|
|
issueTime := time.Now()
|
|
flowState.AuthCodeIssuedAt = &issueTime
|
|
|
|
terr = tx.Update(flowState)
|
|
} else {
|
|
token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams)
|
|
}
|
|
|
|
if terr != nil {
|
|
return oauthError("server_error", terr.Error())
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rurl := a.getExternalRedirectURL(r)
|
|
if flowState != nil {
|
|
// This means that the callback is using PKCE
|
|
// Set the flowState.AuthCode to the query param here
|
|
rurl, err = a.prepPKCERedirectURL(rurl, flowState.AuthCode)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else if token != nil {
|
|
q := url.Values{}
|
|
q.Set("provider_token", providerAccessToken)
|
|
// Because not all providers give out a refresh token
|
|
// See corresponding OAuth2 spec: <https://www.rfc-editor.org/rfc/rfc6749.html#section-5.1>
|
|
if providerRefreshToken != "" {
|
|
q.Set("provider_refresh_token", providerRefreshToken)
|
|
}
|
|
|
|
rurl = token.AsRedirectURL(rurl, q)
|
|
|
|
}
|
|
|
|
http.Redirect(w, r, rurl, http.StatusFound)
|
|
return nil
|
|
}
|
|
|
|
func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.Request, userData *provider.UserProvidedData, providerType string) (*models.User, error) {
|
|
ctx := r.Context()
|
|
aud := a.requestAud(ctx, r)
|
|
config := a.config
|
|
|
|
var user *models.User
|
|
var identity *models.Identity
|
|
var identityData map[string]interface{}
|
|
if userData.Metadata != nil {
|
|
identityData = structs.Map(userData.Metadata)
|
|
}
|
|
|
|
decision, terr := models.DetermineAccountLinking(tx, config, userData.Emails, aud, providerType, userData.Metadata.Subject)
|
|
if terr != nil {
|
|
return nil, terr
|
|
}
|
|
|
|
switch decision.Decision {
|
|
case models.LinkAccount:
|
|
user = decision.User
|
|
|
|
if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil {
|
|
return nil, terr
|
|
}
|
|
|
|
if terr = user.UpdateUserMetaData(tx, identityData); terr != nil {
|
|
return nil, terr
|
|
}
|
|
|
|
if terr = user.UpdateAppMetaDataProviders(tx); terr != nil {
|
|
return nil, terr
|
|
}
|
|
|
|
case models.CreateAccount:
|
|
if config.DisableSignup {
|
|
return nil, unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance")
|
|
}
|
|
|
|
params := &SignupParams{
|
|
Provider: providerType,
|
|
Email: decision.CandidateEmail.Email,
|
|
Aud: aud,
|
|
Data: identityData,
|
|
}
|
|
|
|
isSSOUser := false
|
|
if strings.HasPrefix(decision.LinkingDomain, "sso:") {
|
|
isSSOUser = true
|
|
}
|
|
|
|
// because params above sets no password, this method is not
|
|
// computationally hard so it can be used within a database
|
|
// transaction
|
|
user, terr = params.ToUserModel(isSSOUser)
|
|
if terr != nil {
|
|
return nil, terr
|
|
}
|
|
|
|
if user, terr = a.signupNewUser(tx, user); terr != nil {
|
|
return nil, terr
|
|
}
|
|
|
|
if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil {
|
|
return nil, terr
|
|
}
|
|
user.Identities = append(user.Identities, *identity)
|
|
case models.AccountExists:
|
|
user = decision.User
|
|
identity = decision.Identities[0]
|
|
|
|
identity.IdentityData = identityData
|
|
if terr = tx.UpdateOnly(identity, "identity_data", "last_sign_in_at"); terr != nil {
|
|
return nil, terr
|
|
}
|
|
if terr = user.UpdateUserMetaData(tx, identityData); terr != nil {
|
|
return nil, terr
|
|
}
|
|
if terr = user.UpdateAppMetaDataProviders(tx); terr != nil {
|
|
return nil, terr
|
|
}
|
|
|
|
case models.MultipleAccounts:
|
|
return nil, internalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)
|
|
|
|
default:
|
|
return nil, internalServerError("Unknown automatic linking decision: %v", decision.Decision)
|
|
}
|
|
|
|
if user.IsBanned() {
|
|
return nil, forbiddenError(ErrorCodeUserBanned, "User is banned")
|
|
}
|
|
|
|
if !user.IsConfirmed() {
|
|
// The user may have other unconfirmed email + password
|
|
// combination, phone or oauth identities. These identities
|
|
// need to be removed when a new oauth identity is being added
|
|
// to prevent pre-account takeover attacks from happening.
|
|
if terr = user.RemoveUnconfirmedIdentities(tx, identity); terr != nil {
|
|
return nil, internalServerError("Error updating user").WithInternalError(terr)
|
|
}
|
|
if decision.CandidateEmail.Verified || config.Mailer.Autoconfirm {
|
|
if terr := models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{
|
|
"provider": providerType,
|
|
}); terr != nil {
|
|
return nil, terr
|
|
}
|
|
// fall through to auto-confirm and issue token
|
|
if terr = user.Confirm(tx); terr != nil {
|
|
return nil, internalServerError("Error updating user").WithInternalError(terr)
|
|
}
|
|
} else {
|
|
emailConfirmationSent := false
|
|
if decision.CandidateEmail.Email != "" {
|
|
if terr = a.sendConfirmation(r, tx, user, models.ImplicitFlow); terr != nil {
|
|
return nil, terr
|
|
}
|
|
emailConfirmationSent = true
|
|
}
|
|
if !config.Mailer.AllowUnverifiedEmailSignIns {
|
|
if emailConfirmationSent {
|
|
return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)))
|
|
}
|
|
return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType)))
|
|
}
|
|
}
|
|
} else {
|
|
if terr := models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{
|
|
"provider": providerType,
|
|
}); terr != nil {
|
|
return nil, terr
|
|
}
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *provider.UserProvidedData, inviteToken, providerType string) (*models.User, error) {
|
|
user, err := models.FindUserByConfirmationToken(tx, inviteToken)
|
|
if err != nil {
|
|
if models.IsNotFoundError(err) {
|
|
return nil, notFoundError(ErrorCodeInviteNotFound, "Invite not found")
|
|
}
|
|
return nil, internalServerError("Database error finding user").WithInternalError(err)
|
|
}
|
|
|
|
var emailData *provider.Email
|
|
var emails []string
|
|
for i, e := range userData.Emails {
|
|
emails = append(emails, e.Email)
|
|
if user.GetEmail() == e.Email {
|
|
emailData = &userData.Emails[i]
|
|
break
|
|
}
|
|
}
|
|
|
|
if emailData == nil {
|
|
return nil, badRequestError(ErrorCodeValidationFailed, "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", "))
|
|
}
|
|
|
|
var identityData map[string]interface{}
|
|
if userData.Metadata != nil {
|
|
identityData = structs.Map(userData.Metadata)
|
|
}
|
|
identity, err := a.createNewIdentity(tx, user, providerType, identityData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := user.UpdateAppMetaData(tx, map[string]interface{}{
|
|
"provider": providerType,
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := user.UpdateAppMetaDataProviders(tx); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := user.UpdateUserMetaData(tx, identityData); err != nil {
|
|
return nil, internalServerError("Database error updating user").WithInternalError(err)
|
|
}
|
|
|
|
if err := models.NewAuditLogEntry(r, tx, user, models.InviteAcceptedAction, "", map[string]interface{}{
|
|
"provider": providerType,
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// an account with a previously unconfirmed email + password
|
|
// combination or phone may exist. so now that there is an
|
|
// OAuth identity bound to this user, and since they have not
|
|
// confirmed their email or phone, they are unaware that a
|
|
// potentially malicious door exists into their account; thus
|
|
// the password and phone needs to be removed.
|
|
if err := user.RemoveUnconfirmedIdentities(tx, identity); err != nil {
|
|
return nil, internalServerError("Error updating user").WithInternalError(err)
|
|
}
|
|
|
|
// confirm because they were able to respond to invite email
|
|
if err := user.Confirm(tx); err != nil {
|
|
return nil, err
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) {
|
|
var state string
|
|
switch r.Method {
|
|
case http.MethodPost:
|
|
state = r.FormValue("state")
|
|
default:
|
|
state = r.URL.Query().Get("state")
|
|
}
|
|
if state == "" {
|
|
return ctx, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing")
|
|
}
|
|
config := a.config
|
|
claims := ExternalProviderClaims{}
|
|
p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods))
|
|
_, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) {
|
|
if kid, ok := token.Header["kid"]; ok {
|
|
if kidStr, ok := kid.(string); ok {
|
|
return conf.FindPublicKeyByKid(kidStr, &config.JWT)
|
|
}
|
|
}
|
|
if alg, ok := token.Header["alg"]; ok {
|
|
if alg == jwt.SigningMethodHS256.Name {
|
|
// preserve backward compatibility for cases where the kid is not set
|
|
return []byte(config.JWT.Secret), nil
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("missing kid")
|
|
})
|
|
if err != nil {
|
|
return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
|
|
}
|
|
if claims.Provider == "" {
|
|
return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)")
|
|
}
|
|
if claims.InviteToken != "" {
|
|
ctx = withInviteToken(ctx, claims.InviteToken)
|
|
}
|
|
if claims.Referrer != "" {
|
|
ctx = withExternalReferrer(ctx, claims.Referrer)
|
|
}
|
|
if claims.FlowStateID != "" {
|
|
ctx = withFlowStateID(ctx, claims.FlowStateID)
|
|
}
|
|
if claims.LinkingTargetID != "" {
|
|
linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
|
|
if err != nil {
|
|
return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)")
|
|
}
|
|
u, err := models.FindUserByID(a.db, linkingTargetUserID)
|
|
if err != nil {
|
|
if models.IsNotFoundError(err) {
|
|
return nil, unprocessableEntityError(ErrorCodeUserNotFound, "Linking target user not found")
|
|
}
|
|
return nil, internalServerError("Database error loading user").WithInternalError(err)
|
|
}
|
|
ctx = withTargetUser(ctx, u)
|
|
}
|
|
ctx = withExternalProviderType(ctx, claims.Provider)
|
|
return withSignature(ctx, state), nil
|
|
}
|
|
|
|
// Provider returns a Provider interface for the given name.
|
|
func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, error) {
|
|
config := a.config
|
|
name = strings.ToLower(name)
|
|
|
|
switch name {
|
|
case "apple":
|
|
return provider.NewAppleProvider(ctx, config.External.Apple)
|
|
case "azure":
|
|
return provider.NewAzureProvider(config.External.Azure, scopes)
|
|
case "bitbucket":
|
|
return provider.NewBitbucketProvider(config.External.Bitbucket)
|
|
case "discord":
|
|
return provider.NewDiscordProvider(config.External.Discord, scopes)
|
|
case "facebook":
|
|
return provider.NewFacebookProvider(config.External.Facebook, scopes)
|
|
case "figma":
|
|
return provider.NewFigmaProvider(config.External.Figma, scopes)
|
|
case "fly":
|
|
return provider.NewFlyProvider(config.External.Fly, scopes)
|
|
case "github":
|
|
return provider.NewGithubProvider(config.External.Github, scopes)
|
|
case "gitlab":
|
|
return provider.NewGitlabProvider(config.External.Gitlab, scopes)
|
|
case "google":
|
|
return provider.NewGoogleProvider(ctx, config.External.Google, scopes)
|
|
case "kakao":
|
|
return provider.NewKakaoProvider(config.External.Kakao, scopes)
|
|
case "keycloak":
|
|
return provider.NewKeycloakProvider(config.External.Keycloak, scopes)
|
|
case "linkedin":
|
|
return provider.NewLinkedinProvider(config.External.Linkedin, scopes)
|
|
case "linkedin_oidc":
|
|
return provider.NewLinkedinOIDCProvider(config.External.LinkedinOIDC, scopes)
|
|
case "notion":
|
|
return provider.NewNotionProvider(config.External.Notion)
|
|
case "spotify":
|
|
return provider.NewSpotifyProvider(config.External.Spotify, scopes)
|
|
case "slack":
|
|
return provider.NewSlackProvider(config.External.Slack, scopes)
|
|
case "slack_oidc":
|
|
return provider.NewSlackOIDCProvider(config.External.SlackOIDC, scopes)
|
|
case "twitch":
|
|
return provider.NewTwitchProvider(config.External.Twitch, scopes)
|
|
case "twitter":
|
|
return provider.NewTwitterProvider(config.External.Twitter, scopes)
|
|
case "vercel_marketplace":
|
|
return provider.NewVercelMarketplaceProvider(config.External.VercelMarketplace, scopes)
|
|
case "workos":
|
|
return provider.NewWorkOSProvider(config.External.WorkOS)
|
|
case "zoom":
|
|
return provider.NewZoomProvider(config.External.Zoom)
|
|
default:
|
|
return nil, fmt.Errorf("Provider %s could not be found", name)
|
|
}
|
|
}
|
|
|
|
func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) {
|
|
ctx := r.Context()
|
|
log := observability.GetLogEntry(r).Entry
|
|
errorID := utilities.GetRequestID(ctx)
|
|
err := handler(w, r)
|
|
if err != nil {
|
|
q := getErrorQueryString(err, errorID, log, u.Query())
|
|
u.RawQuery = q.Encode()
|
|
|
|
// TODO: deprecate returning error details in the query fragment
|
|
hq := url.Values{}
|
|
if q.Get("error") != "" {
|
|
hq.Set("error", q.Get("error"))
|
|
}
|
|
if q.Get("error_description") != "" {
|
|
hq.Set("error_description", q.Get("error_description"))
|
|
}
|
|
if q.Get("error_code") != "" {
|
|
hq.Set("error_code", q.Get("error_code"))
|
|
}
|
|
u.Fragment = hq.Encode()
|
|
http.Redirect(w, r, u.String(), http.StatusFound)
|
|
}
|
|
}
|
|
|
|
func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q url.Values) *url.Values {
|
|
switch e := err.(type) {
|
|
case *HTTPError:
|
|
if e.ErrorCode == ErrorCodeSignupDisabled {
|
|
q.Set("error", "access_denied")
|
|
} else if e.ErrorCode == ErrorCodeUserBanned {
|
|
q.Set("error", "access_denied")
|
|
} else if e.ErrorCode == ErrorCodeProviderEmailNeedsVerification {
|
|
q.Set("error", "access_denied")
|
|
} else if str, ok := oauthErrorMap[e.HTTPStatus]; ok {
|
|
q.Set("error", str)
|
|
} else {
|
|
q.Set("error", "server_error")
|
|
}
|
|
if e.HTTPStatus >= http.StatusInternalServerError {
|
|
e.ErrorID = errorID
|
|
// this will get us the stack trace too
|
|
log.WithError(e.Cause()).Error(e.Error())
|
|
} else {
|
|
log.WithError(e.Cause()).Info(e.Error())
|
|
}
|
|
q.Set("error_description", e.Message)
|
|
q.Set("error_code", e.ErrorCode)
|
|
case *OAuthError:
|
|
q.Set("error", e.Err)
|
|
q.Set("error_description", e.Description)
|
|
log.WithError(e.Cause()).Info(e.Error())
|
|
case ErrorCause:
|
|
return getErrorQueryString(e.Cause(), errorID, log, q)
|
|
default:
|
|
error_type, error_description := "server_error", err.Error()
|
|
|
|
// Provide better error messages for certain user-triggered Postgres errors.
|
|
if pgErr := utilities.NewPostgresError(e); pgErr != nil {
|
|
error_description = pgErr.Message
|
|
if oauthErrorType, ok := oauthErrorMap[pgErr.HttpStatusCode]; ok {
|
|
error_type = oauthErrorType
|
|
}
|
|
}
|
|
|
|
q.Set("error", error_type)
|
|
q.Set("error_description", error_description)
|
|
}
|
|
return &q
|
|
}
|
|
|
|
func (a *API) getExternalRedirectURL(r *http.Request) string {
|
|
ctx := r.Context()
|
|
config := a.config
|
|
if config.External.RedirectURL != "" {
|
|
return config.External.RedirectURL
|
|
}
|
|
if er := getExternalReferrer(ctx); er != "" {
|
|
return er
|
|
}
|
|
return config.SiteURL
|
|
}
|
|
|
|
func (a *API) createNewIdentity(tx *storage.Connection, user *models.User, providerType string, identityData map[string]interface{}) (*models.Identity, error) {
|
|
identity, err := models.NewIdentity(user, providerType, identityData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if terr := tx.Create(identity); terr != nil {
|
|
return nil, internalServerError("Error creating identity").WithInternalError(terr)
|
|
}
|
|
|
|
return identity, nil
|
|
}
|