156 lines
4.6 KiB
Go
156 lines
4.6 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
|
|
"github.com/mrjones/oauth"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/supabase/auth/internal/api/provider"
|
|
"github.com/supabase/auth/internal/observability"
|
|
"github.com/supabase/auth/internal/utilities"
|
|
)
|
|
|
|
// OAuthProviderData contains the userData and token returned by the oauth provider
|
|
type OAuthProviderData struct {
|
|
userData *provider.UserProvidedData
|
|
token string
|
|
refreshToken string
|
|
code string
|
|
}
|
|
|
|
// loadFlowState parses the `state` query parameter as a JWS payload,
|
|
// extracting the provider requested
|
|
func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
|
|
ctx := r.Context()
|
|
oauthToken := r.URL.Query().Get("oauth_token")
|
|
if oauthToken != "" {
|
|
ctx = withRequestToken(ctx, oauthToken)
|
|
}
|
|
oauthVerifier := r.URL.Query().Get("oauth_verifier")
|
|
if oauthVerifier != "" {
|
|
ctx = withOAuthVerifier(ctx, oauthVerifier)
|
|
}
|
|
|
|
var err error
|
|
ctx, err = a.loadExternalState(ctx, r)
|
|
if err != nil {
|
|
u, uerr := url.ParseRequestURI(a.config.SiteURL)
|
|
if uerr != nil {
|
|
return ctx, internalServerError("site url is improperly formatted").WithInternalError(uerr)
|
|
}
|
|
|
|
q := getErrorQueryString(err, utilities.GetRequestID(ctx), observability.GetLogEntry(r).Entry, u.Query())
|
|
u.RawQuery = q.Encode()
|
|
|
|
http.Redirect(w, r, u.String(), http.StatusSeeOther)
|
|
}
|
|
return ctx, err
|
|
}
|
|
|
|
func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) {
|
|
var rq url.Values
|
|
if err := r.ParseForm(); r.Method == http.MethodPost && err == nil {
|
|
rq = r.Form
|
|
} else {
|
|
rq = r.URL.Query()
|
|
}
|
|
|
|
extError := rq.Get("error")
|
|
if extError != "" {
|
|
return nil, oauthError(extError, rq.Get("error_description"))
|
|
}
|
|
|
|
oauthCode := rq.Get("code")
|
|
if oauthCode == "" {
|
|
return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing")
|
|
}
|
|
|
|
oAuthProvider, err := a.OAuthProvider(ctx, providerType)
|
|
if err != nil {
|
|
return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err)
|
|
}
|
|
|
|
log := observability.GetLogEntry(r).Entry
|
|
log.WithFields(logrus.Fields{
|
|
"provider": providerType,
|
|
"code": oauthCode,
|
|
}).Debug("Exchanging oauth code")
|
|
|
|
token, err := oAuthProvider.GetOAuthToken(oauthCode)
|
|
if err != nil {
|
|
return nil, internalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err)
|
|
}
|
|
|
|
userData, err := oAuthProvider.GetUserData(ctx, token)
|
|
if err != nil {
|
|
return nil, internalServerError("Error getting user profile from external provider").WithInternalError(err)
|
|
}
|
|
|
|
switch externalProvider := oAuthProvider.(type) {
|
|
case *provider.AppleProvider:
|
|
// apple only returns user info the first time
|
|
oauthUser := rq.Get("user")
|
|
if oauthUser != "" {
|
|
err := externalProvider.ParseUser(oauthUser, userData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
return &OAuthProviderData{
|
|
userData: userData,
|
|
token: token.AccessToken,
|
|
refreshToken: token.RefreshToken,
|
|
code: oauthCode,
|
|
}, nil
|
|
}
|
|
|
|
func (a *API) oAuth1Callback(ctx context.Context, providerType string) (*OAuthProviderData, error) {
|
|
oAuthProvider, err := a.OAuthProvider(ctx, providerType)
|
|
if err != nil {
|
|
return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err)
|
|
}
|
|
oauthToken := getRequestToken(ctx)
|
|
oauthVerifier := getOAuthVerifier(ctx)
|
|
var accessToken *oauth.AccessToken
|
|
var userData *provider.UserProvidedData
|
|
if twitterProvider, ok := oAuthProvider.(*provider.TwitterProvider); ok {
|
|
accessToken, err = twitterProvider.Consumer.AuthorizeToken(&oauth.RequestToken{
|
|
Token: oauthToken,
|
|
}, oauthVerifier)
|
|
if err != nil {
|
|
return nil, internalServerError("Unable to retrieve access token").WithInternalError(err)
|
|
}
|
|
userData, err = twitterProvider.FetchUserData(ctx, accessToken)
|
|
if err != nil {
|
|
return nil, internalServerError("Error getting user email from external provider").WithInternalError(err)
|
|
}
|
|
}
|
|
|
|
return &OAuthProviderData{
|
|
userData: userData,
|
|
token: accessToken.Token,
|
|
refreshToken: "",
|
|
}, nil
|
|
|
|
}
|
|
|
|
// OAuthProvider returns the corresponding oauth provider as an OAuthProvider interface
|
|
func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthProvider, error) {
|
|
providerCandidate, err := a.Provider(ctx, name, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch p := providerCandidate.(type) {
|
|
case provider.OAuthProvider:
|
|
return p, nil
|
|
default:
|
|
return nil, fmt.Errorf("Provider %v cannot be used for OAuth", name)
|
|
}
|
|
}
|