402 lines
11 KiB
Go
402 lines
11 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/supabase/auth/internal/models"
|
|
"github.com/supabase/auth/internal/observability"
|
|
"github.com/supabase/auth/internal/security"
|
|
"github.com/supabase/auth/internal/utilities"
|
|
|
|
"github.com/didip/tollbooth/v5"
|
|
"github.com/didip/tollbooth/v5/limiter"
|
|
jwt "github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// FunctionHooks maps an event name to the list of hooks registered for that
// event. Its UnmarshalJSON method also accepts the legacy encoding in which
// each event carries a single hook string instead of a list.
type FunctionHooks map[string][]string
|
|
|
|
type AuthMicroserviceClaims struct {
|
|
jwt.RegisteredClaims
|
|
SiteURL string `json:"site_url"`
|
|
InstanceID string `json:"id"`
|
|
FunctionHooks FunctionHooks `json:"function_hooks"`
|
|
}
|
|
|
|
func (f *FunctionHooks) UnmarshalJSON(b []byte) error {
|
|
var raw map[string][]string
|
|
err := json.Unmarshal(b, &raw)
|
|
if err == nil {
|
|
*f = FunctionHooks(raw)
|
|
return nil
|
|
}
|
|
// If unmarshaling into map[string][]string fails, try legacy format.
|
|
var legacy map[string]string
|
|
err = json.Unmarshal(b, &legacy)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if *f == nil {
|
|
*f = make(FunctionHooks)
|
|
}
|
|
for event, hook := range legacy {
|
|
(*f)[event] = []string{hook}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// emailRateLimitCounter is the "gotrue_email_rate_limit_counter" metric; its
// registered description is "Number of times an email rate limit has been
// triggered".
var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered")
|
|
|
|
func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
|
|
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
|
|
c := req.Context()
|
|
|
|
if limitHeader := a.config.RateLimitHeader; limitHeader != "" {
|
|
key := req.Header.Get(limitHeader)
|
|
|
|
if key == "" {
|
|
log := observability.GetLogEntry(req).Entry
|
|
log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied")
|
|
return c, nil
|
|
} else {
|
|
err := tollbooth.LimitByKeys(lmt, []string{key})
|
|
if err != nil {
|
|
return c, tooManyRequestsError(ErrorCodeOverRequestRateLimit, "Request rate limit reached")
|
|
}
|
|
}
|
|
}
|
|
return c, nil
|
|
}
|
|
}
|
|
|
|
func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) {
|
|
t, err := a.extractBearerToken(req)
|
|
if err != nil || t == "" {
|
|
return nil, err
|
|
}
|
|
|
|
ctx, err := a.parseJWTClaims(t, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return a.requireAdmin(ctx)
|
|
}
|
|
|
|
func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (context.Context, error) {
|
|
ctx := req.Context()
|
|
config := a.config
|
|
|
|
if !config.External.Email.Enabled {
|
|
return nil, badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled")
|
|
}
|
|
|
|
return ctx, nil
|
|
}
|
|
|
|
func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.Context, error) {
|
|
ctx := req.Context()
|
|
config := a.config
|
|
|
|
if !config.Security.Captcha.Enabled {
|
|
return ctx, nil
|
|
}
|
|
if _, err := a.requireAdminCredentials(w, req); err == nil {
|
|
// skip captcha validation if authorization header contains an admin role
|
|
return ctx, nil
|
|
}
|
|
if shouldIgnore := isIgnoreCaptchaRoute(req); shouldIgnore {
|
|
return ctx, nil
|
|
}
|
|
|
|
body := &security.GotrueRequest{}
|
|
if err := retrieveRequestParams(req, body); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
verificationResult, err := security.VerifyRequest(body, utilities.GetIPAddress(req), strings.TrimSpace(config.Security.Captcha.Secret), config.Security.Captcha.Provider)
|
|
if err != nil {
|
|
return nil, internalServerError("captcha verification process failed").WithInternalError(err)
|
|
}
|
|
|
|
if !verificationResult.Success {
|
|
return nil, badRequestError(ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", "))
|
|
}
|
|
|
|
return ctx, nil
|
|
}
|
|
|
|
// isIgnoreCaptchaRoute reports whether captcha verification is skipped for
// the request: true only for the "/token" path when the "grant_type" form
// value is anything other than "password" (e.g. id_token, refresh_token,
// pkce). The form value is not read for other paths.
func isIgnoreCaptchaRoute(req *http.Request) bool {
	if req.URL.Path != "/token" {
		return false
	}
	return req.FormValue("grant_type") != "password"
}
|
|
|
|
func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (context.Context, error) {
|
|
ctx := req.Context()
|
|
config := a.config
|
|
|
|
xForwardedHost := req.Header.Get("X-Forwarded-Host")
|
|
xForwardedProto := req.Header.Get("X-Forwarded-Proto")
|
|
reqHost := req.URL.Hostname()
|
|
|
|
if len(config.Mailer.ExternalHosts) > 0 {
|
|
// this server is configured to accept multiple external hosts, validate the host from the X-Forwarded-Host or Host headers
|
|
|
|
hostname := ""
|
|
protocol := "https"
|
|
|
|
if xForwardedHost != "" {
|
|
for _, host := range config.Mailer.ExternalHosts {
|
|
if host == xForwardedHost {
|
|
hostname = host
|
|
break
|
|
}
|
|
}
|
|
} else if reqHost != "" {
|
|
for _, host := range config.Mailer.ExternalHosts {
|
|
if host == reqHost {
|
|
hostname = host
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if hostname != "" {
|
|
if hostname == "localhost" {
|
|
// allow the use of HTTP only if the accepted hostname was localhost
|
|
if xForwardedProto == "http" || req.URL.Scheme == "http" {
|
|
protocol = "http"
|
|
}
|
|
}
|
|
|
|
externalHostURL, err := url.ParseRequestURI(fmt.Sprintf("%s://%s", protocol, hostname))
|
|
if err != nil {
|
|
return ctx, err
|
|
}
|
|
|
|
return withExternalHost(ctx, externalHostURL), nil
|
|
}
|
|
}
|
|
|
|
if xForwardedHost != "" || reqHost != "" {
|
|
// host has been provided to the request, but it hasn't been
|
|
// added to the allow list, raise a log message
|
|
// in Supabase platform the X-Forwarded-Host and full request
|
|
// URL are likely sanitzied before they reach the server
|
|
|
|
fields := make(logrus.Fields)
|
|
|
|
if xForwardedHost != "" {
|
|
fields["x_forwarded_host"] = xForwardedHost
|
|
}
|
|
|
|
if xForwardedProto != "" {
|
|
fields["x_forwarded_proto"] = xForwardedProto
|
|
}
|
|
|
|
if reqHost != "" {
|
|
fields["request_url_host"] = reqHost
|
|
|
|
if req.URL.Scheme != "" {
|
|
fields["request_url_scheme"] = req.URL.Scheme
|
|
}
|
|
}
|
|
|
|
logrus.WithFields(fields).Info("Request received external host in X-Forwarded-Host or Host headers, but the values have not been added to GOTRUE_MAILER_EXTERNAL_HOSTS and will not be used. To suppress this message add the host, or sanitize the headers before the request reaches Auth.")
|
|
}
|
|
|
|
// either the provided external hosts don't match the allow list, or
|
|
// the server is not configured to accept multiple hosts -- use the
|
|
// configured external URL instead
|
|
|
|
externalHostURL, err := url.ParseRequestURI(config.API.ExternalURL)
|
|
if err != nil {
|
|
return ctx, err
|
|
}
|
|
|
|
return withExternalHost(ctx, externalHostURL), nil
|
|
}
|
|
|
|
func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
|
|
ctx := req.Context()
|
|
if !a.config.SAML.Enabled {
|
|
return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled")
|
|
}
|
|
return ctx, nil
|
|
}
|
|
|
|
func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
|
|
ctx := req.Context()
|
|
if !a.config.Security.ManualLinkingEnabled {
|
|
return nil, notFoundError(ErrorCodeManualLinkingDisabled, "Manual linking is disabled")
|
|
}
|
|
return ctx, nil
|
|
}
|
|
|
|
func (a *API) databaseCleanup(cleanup models.Cleaner) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
wrappedResp := chimiddleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
|
next.ServeHTTP(wrappedResp, r)
|
|
switch r.Method {
|
|
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
|
|
if (wrappedResp.Status() / 100) != 2 {
|
|
// don't do any cleanups for non-2xx responses
|
|
return
|
|
}
|
|
// continue
|
|
default:
|
|
return
|
|
}
|
|
|
|
db := a.db.WithContext(r.Context())
|
|
log := observability.GetLogEntry(r).Entry
|
|
|
|
affectedRows, err := cleanup.Clean(db)
|
|
if err != nil {
|
|
log.WithError(err).WithField("affected_rows", affectedRows).Warn("database cleanup failed")
|
|
} else if affectedRows > 0 {
|
|
log.WithField("affected_rows", affectedRows).Debug("cleaned up expired or stale rows")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// timeoutResponseWriter is an http.ResponseWriter that records the status
// code, headers and body written by a handler instead of sending them. The
// recorded response is replayed onto the real writer by finallyWrite, which is
// meant to be called only if serving completes before the context has
// exceeded its deadline.
type timeoutResponseWriter struct {
	// Mutex is taken by every method of the writer before it touches the
	// fields below.
	sync.Mutex

	// header is the live header map handed out by Header.
	header http.Header
	// wroteHeader is set once a status code has been recorded.
	wroteHeader bool
	// snapHeader is a snapshot of header taken at the time the status code
	// was recorded.
	snapHeader http.Header
	// statusCode is the recorded response status.
	statusCode int
	// buf accumulates the response body.
	buf bytes.Buffer
}
|
|
|
|
func (t *timeoutResponseWriter) Header() http.Header {
|
|
t.Lock()
|
|
defer t.Unlock()
|
|
|
|
return t.header
|
|
}
|
|
|
|
func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) {
|
|
t.Lock()
|
|
defer t.Unlock()
|
|
|
|
if !t.wroteHeader {
|
|
t.writeHeaderLocked(http.StatusOK)
|
|
}
|
|
|
|
return t.buf.Write(bytes)
|
|
}
|
|
|
|
func (t *timeoutResponseWriter) WriteHeader(statusCode int) {
|
|
t.Lock()
|
|
defer t.Unlock()
|
|
|
|
t.writeHeaderLocked(statusCode)
|
|
}
|
|
|
|
func (t *timeoutResponseWriter) writeHeaderLocked(statusCode int) {
|
|
if t.wroteHeader {
|
|
// ignore multiple calls to WriteHeader
|
|
// once WriteHeader has been called once, a snapshot of the header map is taken
|
|
// and saved in snapHeader to be used in finallyWrite
|
|
return
|
|
}
|
|
|
|
t.statusCode = statusCode
|
|
t.wroteHeader = true
|
|
t.snapHeader = t.header.Clone()
|
|
}
|
|
|
|
func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) {
|
|
t.Lock()
|
|
defer t.Unlock()
|
|
|
|
dst := w.Header()
|
|
for k, vv := range t.snapHeader {
|
|
dst[k] = vv
|
|
}
|
|
|
|
if !t.wroteHeader {
|
|
t.statusCode = http.StatusOK
|
|
}
|
|
|
|
w.WriteHeader(t.statusCode)
|
|
if _, err := w.Write(t.buf.Bytes()); err != nil {
|
|
logrus.WithError(err).Warn("Write failed")
|
|
}
|
|
}
|
|
|
|
func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
|
defer cancel()
|
|
|
|
timeoutWriter := &timeoutResponseWriter{
|
|
header: make(http.Header),
|
|
}
|
|
|
|
panicChan := make(chan any, 1)
|
|
serverDone := make(chan struct{})
|
|
go func() {
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
panicChan <- p
|
|
}
|
|
}()
|
|
|
|
next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
|
|
close(serverDone)
|
|
}()
|
|
|
|
select {
|
|
case p := <-panicChan:
|
|
panic(p)
|
|
|
|
case <-serverDone:
|
|
timeoutWriter.finallyWrite(w)
|
|
|
|
case <-ctx.Done():
|
|
err := ctx.Err()
|
|
|
|
if err == context.DeadlineExceeded {
|
|
httpError := &HTTPError{
|
|
HTTPStatus: http.StatusGatewayTimeout,
|
|
ErrorCode: ErrorCodeRequestTimeout,
|
|
Message: "Processing this request timed out, please retry after a moment.",
|
|
}
|
|
|
|
httpError = httpError.WithInternalError(err)
|
|
|
|
HandleResponseError(httpError, w, r)
|
|
} else {
|
|
// unrecognized context error, so we should wait for the server to finish
|
|
// and write out the response
|
|
<-serverDone
|
|
|
|
timeoutWriter.finallyWrite(w)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|