328 lines
11 KiB
Go
328 lines
11 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"encoding/xml"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/crewjam/saml"
|
|
"github.com/fatih/structs"
|
|
"github.com/gofrs/uuid"
|
|
"github.com/supabase/auth/internal/api/provider"
|
|
"github.com/supabase/auth/internal/models"
|
|
"github.com/supabase/auth/internal/observability"
|
|
"github.com/supabase/auth/internal/storage"
|
|
"github.com/supabase/auth/internal/utilities"
|
|
)
|
|
|
|
func (a *API) samlDestroyRelayState(ctx context.Context, relayState *models.SAMLRelayState) error {
|
|
db := a.db.WithContext(ctx)
|
|
|
|
// It's OK to destroy the RelayState, as a user will
|
|
// likely initiate a completely new login flow, instead
|
|
// of reusing the same one.
|
|
|
|
return db.Transaction(func(tx *storage.Connection) error {
|
|
return tx.Destroy(relayState)
|
|
})
|
|
}
|
|
|
|
func IsSAMLMetadataStale(idpMetadata *saml.EntityDescriptor, samlProvider models.SAMLProvider) bool {
|
|
now := time.Now()
|
|
|
|
hasValidityExpired := !idpMetadata.ValidUntil.IsZero() && now.After(idpMetadata.ValidUntil)
|
|
hasCacheDurationExceeded := idpMetadata.CacheDuration != 0 && now.After(samlProvider.UpdatedAt.Add(idpMetadata.CacheDuration))
|
|
|
|
// if metadata XML does not publish validity or caching information, update once in 24 hours
|
|
needsForceUpdate := idpMetadata.ValidUntil.IsZero() && idpMetadata.CacheDuration == 0 && now.After(samlProvider.UpdatedAt.Add(24*time.Hour))
|
|
|
|
return hasValidityExpired || hasCacheDurationExceeded || needsForceUpdate
|
|
}
|
|
|
|
func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error {
|
|
if err := a.handleSamlAcs(w, r); err != nil {
|
|
u, uerr := url.Parse(a.config.SiteURL)
|
|
if uerr != nil {
|
|
return internalServerError("site url is improperly formattted").WithInternalError(err)
|
|
}
|
|
|
|
q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query())
|
|
u.RawQuery = q.Encode()
|
|
http.Redirect(w, r, u.String(), http.StatusSeeOther)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handleSamlAcs implements the main Assertion Consumer Service endpoint behavior.
|
|
func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
|
|
ctx := r.Context()
|
|
|
|
db := a.db.WithContext(ctx)
|
|
config := a.config
|
|
log := observability.GetLogEntry(r).Entry
|
|
|
|
relayStateValue := r.FormValue("RelayState")
|
|
relayStateUUID := uuid.FromStringOrNil(relayStateValue)
|
|
relayStateURL, _ := url.ParseRequestURI(relayStateValue)
|
|
|
|
entityId := ""
|
|
initiatedBy := ""
|
|
redirectTo := ""
|
|
var requestIds []string
|
|
|
|
var flowState *models.FlowState
|
|
if relayStateUUID != uuid.Nil {
|
|
// relay state is a valid UUID, therefore this is likely a SP initiated flow
|
|
|
|
relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID)
|
|
if models.IsNotFoundError(err) {
|
|
return notFoundError(ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?")
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
if time.Since(relayState.CreatedAt) >= a.config.SAML.RelayStateValidityPeriod {
|
|
if err := a.samlDestroyRelayState(ctx, relayState); err != nil {
|
|
return internalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err)
|
|
}
|
|
|
|
return unprocessableEntityError(ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?")
|
|
}
|
|
|
|
// TODO: add abuse detection to bind the RelayState UUID with a
|
|
// HTTP-Only cookie
|
|
|
|
ssoProvider, err := models.FindSSOProviderByID(db, relayState.SSOProviderID)
|
|
if err != nil {
|
|
return internalServerError("Unable to find SSO Provider from SAML RelayState")
|
|
}
|
|
|
|
initiatedBy = "sp"
|
|
entityId = ssoProvider.SAMLProvider.EntityID
|
|
redirectTo = relayState.RedirectTo
|
|
requestIds = append(requestIds, relayState.RequestID)
|
|
if relayState.FlowState != nil {
|
|
flowState = relayState.FlowState
|
|
}
|
|
|
|
if err := a.samlDestroyRelayState(ctx, relayState); err != nil {
|
|
return err
|
|
}
|
|
} else if relayStateValue == "" || relayStateURL != nil {
|
|
// RelayState may be a URL in which case it's the URL where the
|
|
// IdP is telling us to redirect the user to
|
|
|
|
if r.FormValue("SAMLart") != "" {
|
|
// SAML Artifact responses are possible only when
|
|
// RelayState can be used to identify the Identity
|
|
// Provider.
|
|
return badRequestError(ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow")
|
|
}
|
|
|
|
samlResponse := r.FormValue("SAMLResponse")
|
|
if samlResponse == "" {
|
|
return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is missing")
|
|
}
|
|
|
|
responseXML, err := base64.StdEncoding.DecodeString(samlResponse)
|
|
if err != nil {
|
|
return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string")
|
|
}
|
|
|
|
var peekResponse saml.Response
|
|
err = xml.Unmarshal(responseXML, &peekResponse)
|
|
if err != nil {
|
|
return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err)
|
|
}
|
|
|
|
initiatedBy = "idp"
|
|
entityId = peekResponse.Issuer.Value
|
|
redirectTo = relayStateValue
|
|
} else {
|
|
// RelayState can't be identified, so SAML flow can't continue
|
|
return badRequestError(ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL")
|
|
}
|
|
|
|
ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId)
|
|
if models.IsNotFoundError(err) {
|
|
return notFoundError(ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider")
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
idpMetadata, err := ssoProvider.SAMLProvider.EntityDescriptor()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
samlMetadataModified := false
|
|
|
|
if ssoProvider.SAMLProvider.MetadataURL == nil {
|
|
if !idpMetadata.ValidUntil.IsZero() && time.Until(idpMetadata.ValidUntil) <= (30*24*60)*time.Second {
|
|
logentry := log.WithField("sso_provider_id", ssoProvider.ID.String())
|
|
logentry = logentry.WithField("expires_in", time.Until(idpMetadata.ValidUntil).String())
|
|
logentry = logentry.WithField("valid_until", idpMetadata.ValidUntil)
|
|
logentry = logentry.WithField("saml_entity_id", ssoProvider.SAMLProvider.EntityID)
|
|
|
|
logentry.Warn("SAML Metadata for identity provider will expire soon! Update its metadata_xml!")
|
|
}
|
|
} else if *ssoProvider.SAMLProvider.MetadataURL != "" && IsSAMLMetadataStale(idpMetadata, ssoProvider.SAMLProvider) {
|
|
rawMetadata, err := fetchSAMLMetadata(ctx, *ssoProvider.SAMLProvider.MetadataURL)
|
|
if err != nil {
|
|
// Fail silently but raise warning and continue with existing metadata
|
|
logentry := log.WithField("sso_provider_id", ssoProvider.ID.String())
|
|
logentry = logentry.WithField("expires_in", time.Until(idpMetadata.ValidUntil).String())
|
|
logentry = logentry.WithField("valid_until", idpMetadata.ValidUntil)
|
|
logentry = logentry.WithError(err)
|
|
logentry.Warn("SAML Metadata could not be retrieved, continuing with existing metadata")
|
|
} else {
|
|
ssoProvider.SAMLProvider.MetadataXML = string(rawMetadata)
|
|
samlMetadataModified = true
|
|
}
|
|
}
|
|
|
|
serviceProvider := a.getSAMLServiceProvider(idpMetadata, initiatedBy == "idp")
|
|
spAssertion, err := serviceProvider.ParseResponse(r, requestIds)
|
|
if err != nil {
|
|
if ire, ok := err.(*saml.InvalidResponseError); ok {
|
|
return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr)
|
|
}
|
|
|
|
return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err)
|
|
}
|
|
|
|
assertion := SAMLAssertion{
|
|
spAssertion,
|
|
}
|
|
|
|
userID := assertion.UserID()
|
|
if userID == "" {
|
|
return badRequestError(ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user")
|
|
}
|
|
|
|
claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping)
|
|
|
|
email, ok := claims["email"].(string)
|
|
if !ok || email == "" {
|
|
// mapping does not identify the email attribute, try to figure it out
|
|
email = assertion.Email()
|
|
}
|
|
|
|
if email == "" {
|
|
return badRequestError(ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address")
|
|
} else {
|
|
claims["email"] = email
|
|
}
|
|
|
|
jsonClaims, err := json.Marshal(claims)
|
|
if err != nil {
|
|
return internalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err)
|
|
}
|
|
|
|
providerClaims := &provider.Claims{}
|
|
if err := json.Unmarshal(jsonClaims, providerClaims); err != nil {
|
|
return internalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err)
|
|
}
|
|
|
|
providerClaims.Subject = userID
|
|
providerClaims.Issuer = ssoProvider.SAMLProvider.EntityID
|
|
providerClaims.Email = email
|
|
providerClaims.EmailVerified = true
|
|
|
|
providerClaimsMap := structs.Map(providerClaims)
|
|
|
|
// remove all of the parsed claims, so that the rest can go into CustomClaims
|
|
for key := range providerClaimsMap {
|
|
delete(claims, key)
|
|
}
|
|
|
|
providerClaims.CustomClaims = claims
|
|
|
|
var userProvidedData provider.UserProvidedData
|
|
|
|
userProvidedData.Emails = append(userProvidedData.Emails, provider.Email{
|
|
Email: email,
|
|
Verified: true,
|
|
Primary: true,
|
|
})
|
|
|
|
// userProvidedData.Provider.Type = "saml"
|
|
// userProvidedData.Provider.ID = ssoProvider.ID.String()
|
|
// userProvidedData.Provider.SAMLEntityID = ssoProvider.SAMLProvider.EntityID
|
|
// userProvidedData.Provider.SAMLInitiatedBy = initiatedBy
|
|
|
|
userProvidedData.Metadata = providerClaims
|
|
|
|
// TODO: below
|
|
// refreshTokenParams.SSOProviderID = ssoProvider.ID
|
|
// refreshTokenParams.InitiatedByProvider = initiatedBy == "idp"
|
|
// refreshTokenParams.NotBefore = assertion.NotBefore()
|
|
// refreshTokenParams.NotAfter = assertion.NotAfter()
|
|
|
|
notAfter := assertion.NotAfter()
|
|
|
|
var grantParams models.GrantParams
|
|
|
|
grantParams.FillGrantParams(r)
|
|
|
|
if !notAfter.IsZero() {
|
|
grantParams.SessionNotAfter = ¬After
|
|
}
|
|
|
|
var token *AccessTokenResponse
|
|
if samlMetadataModified {
|
|
if err := db.UpdateColumns(&ssoProvider.SAMLProvider, "metadata_xml", "updated_at"); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := db.Transaction(func(tx *storage.Connection) error {
|
|
var terr error
|
|
var user *models.User
|
|
|
|
// accounts potentially created via SAML can contain non-unique email addresses in the auth.users table
|
|
if user, terr = a.createAccountFromExternalIdentity(tx, r, &userProvidedData, "sso:"+ssoProvider.ID.String()); terr != nil {
|
|
return terr
|
|
}
|
|
if flowState != nil {
|
|
// This means that the callback is using PKCE
|
|
flowState.UserID = &(user.ID)
|
|
if terr := tx.Update(flowState); terr != nil {
|
|
return terr
|
|
}
|
|
}
|
|
|
|
token, terr = a.issueRefreshToken(r, tx, user, models.SSOSAML, grantParams)
|
|
|
|
if terr != nil {
|
|
return internalServerError("Unable to issue refresh token from SAML Assertion").WithInternalError(terr)
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if !utilities.IsRedirectURLValid(config, redirectTo) {
|
|
redirectTo = config.SiteURL
|
|
}
|
|
if flowState != nil {
|
|
// This means that the callback is using PKCE
|
|
// Set the flowState.AuthCode to the query param here
|
|
redirectTo, err = a.prepPKCERedirectURL(redirectTo, flowState.AuthCode)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
http.Redirect(w, r, redirectTo, http.StatusFound)
|
|
return nil
|
|
|
|
}
|
|
http.Redirect(w, r, token.AsRedirectURL(redirectTo, url.Values{}), http.StatusFound)
|
|
|
|
return nil
|
|
}
|