package api

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/didip/tollbooth/v5"
	"github.com/didip/tollbooth/v5/limiter"
	chimiddleware "github.com/go-chi/chi/v5/middleware"
	jwt "github.com/golang-jwt/jwt/v5"
	"github.com/sirupsen/logrus"
	"github.com/supabase/auth/internal/models"
	"github.com/supabase/auth/internal/observability"
	"github.com/supabase/auth/internal/security"
	"github.com/supabase/auth/internal/utilities"
)

// FunctionHooks maps an event name to the hook strings registered for that
// event.
type FunctionHooks map[string][]string

// AuthMicroserviceClaims are JWT claims carrying the registered claims plus
// the site URL, instance ID (JSON key "id") and function hooks.
type AuthMicroserviceClaims struct {
	jwt.RegisteredClaims
	SiteURL       string        `json:"site_url"`
	InstanceID    string        `json:"id"`
	FunctionHooks FunctionHooks `json:"function_hooks"`
}

// UnmarshalJSON decodes function hooks in either the current format
// ({"event": ["hook", ...]}) or the legacy format ({"event": "hook"}), where
// each legacy hook is wrapped in a one-element slice.
//
// In both formats the decoded hooks replace whatever f previously held. If
// neither format matches, the error from the legacy attempt is returned.
func (f *FunctionHooks) UnmarshalJSON(b []byte) error {
	var hooks map[string][]string
	if err := json.Unmarshal(b, &hooks); err == nil {
		*f = FunctionHooks(hooks)
		return nil
	}

	// Fall back to the legacy single-hook-per-event format.
	var legacy map[string]string
	if err := json.Unmarshal(b, &legacy); err != nil {
		return err
	}

	// Build a fresh map so the legacy path replaces f just like the primary
	// path does, instead of merging into a previously decoded value.
	converted := make(FunctionHooks, len(legacy))
	for event, hook := range legacy {
		converted[event] = []string{hook}
	}
	*f = converted
	return nil
}

// emailRateLimitCounter counts how many times an email rate limit has been
// triggered.
var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered")

// limitHandler returns middleware that rate limits requests with lmt, keyed
// by the value of the configured RateLimitHeader.
//
// Requests pass through unlimited when no header is configured, or when the
// request carries no value for it (the latter is logged as a warning). When
// the limit is exceeded, a too-many-requests error is returned.
func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
	return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
		ctx := req.Context()

		limitHeader := a.config.RateLimitHeader
		if limitHeader == "" {
			return ctx, nil
		}

		key := req.Header.Get(limitHeader)
		if key == "" {
			log := observability.GetLogEntry(req).Entry
			log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied")
			return ctx, nil
		}

		if err := tollbooth.LimitByKeys(lmt, []string{key}); err != nil {
			return ctx, tooManyRequestsError(ErrorCodeOverRequestRateLimit, "Request rate limit reached")
		}
		return ctx, nil
	}
}

// requireAdminCredentials is middleware that requires a bearer token with
// admin access. It parses the token's claims with parseJWTClaims and checks
// them with requireAdmin.
//
// It never returns a nil error without a successful admin check. verifyCaptcha
// relies on this, treating a nil error as "request carries admin credentials".
func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) {
	t, err := a.extractBearerToken(req)
	if err != nil {
		return nil, err
	}
	if t == "" {
		// Previously this path returned (nil, nil), which let the request
		// through without an admin check.
		// NOTE(review): consider a dedicated ErrorCode for this response.
		return nil, &HTTPError{
			HTTPStatus: http.StatusUnauthorized,
			Message:    "This endpoint requires a Bearer token",
		}
	}

	ctx, err := a.parseJWTClaims(t, req)
	if err != nil {
		return nil, err
	}

	return a.requireAdmin(ctx)
}

// requireEmailProvider rejects the request with a bad-request error when the
// email provider is disabled in the configuration.
func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (context.Context, error) {
	ctx := req.Context()

	if !a.config.External.Email.Enabled {
		return nil, badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled")
	}

	return ctx, nil
}

// verifyCaptcha validates the captcha token in the request body with the
// configured provider.
//
// Validation is skipped when captcha protection is disabled, when the request
// carries admin credentials, or when isIgnoreCaptchaRoute matches. A failed
// verification call is a 500; a rejected token is a 400 listing the
// provider's error codes.
func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.Context, error) {
	ctx := req.Context()
	config := a.config

	if !config.Security.Captcha.Enabled {
		return ctx, nil
	}

	if _, err := a.requireAdminCredentials(w, req); err == nil {
		// Skip captcha validation if the authorization header carries an
		// admin role.
		return ctx, nil
	}

	if isIgnoreCaptchaRoute(req) {
		return ctx, nil
	}

	body := &security.GotrueRequest{}
	if err := retrieveRequestParams(req, body); err != nil {
		return nil, err
	}

	verificationResult, err := security.VerifyRequest(body, utilities.GetIPAddress(req), strings.TrimSpace(config.Security.Captcha.Secret), config.Security.Captcha.Provider)
	if err != nil {
		return nil, internalServerError("captcha verification process failed").WithInternalError(err)
	}

	if !verificationResult.Success {
		return nil, badRequestError(ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", "))
	}

	return ctx, nil
}

// isIgnoreCaptchaRoute reports whether captcha validation should be skipped:
// requests to /token use a captcha only for grant_type=password (the other
// grant types, e.g. id_token, refresh_token and pkce, are exempt).
func isIgnoreCaptchaRoute(req *http.Request) bool {
	return req.URL.Path == "/token" && req.FormValue("grant_type") != "password"
}

// isValidExternalHost resolves the external URL for the request and stores it
// in the context with withExternalHost.
//
// When GOTRUE_MAILER_EXTERNAL_HOSTS is configured, the X-Forwarded-Host header
// (or, if that header is absent, the request URL's host) is used if it appears
// in the allow list. The scheme is https, except that http is allowed for
// "localhost". Otherwise the configured API external URL is used, and a
// provided-but-unlisted host is logged.
//
// NOTE(review): req.URL.Hostname() is the host from the request URL, which
// for server requests is usually empty unless the client sent an
// absolute-form target; the Host header itself lives in req.Host and is not
// consulted here. Confirm this is intended.
func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (context.Context, error) {
	ctx := req.Context()
	config := a.config

	xForwardedHost := req.Header.Get("X-Forwarded-Host")
	xForwardedProto := req.Header.Get("X-Forwarded-Proto")
	reqHost := req.URL.Hostname()

	if len(config.Mailer.ExternalHosts) > 0 {
		// This server accepts multiple external hosts. X-Forwarded-Host takes
		// precedence over the request URL's host.
		candidate := reqHost
		if xForwardedHost != "" {
			candidate = xForwardedHost
		}

		hostname := ""
		if candidate != "" {
			for _, host := range config.Mailer.ExternalHosts {
				if host == candidate {
					hostname = host
					break
				}
			}
		}

		if hostname != "" {
			protocol := "https"
			// Plain HTTP is honored only when the accepted hostname is localhost.
			if hostname == "localhost" && (xForwardedProto == "http" || req.URL.Scheme == "http") {
				protocol = "http"
			}

			externalHostURL, err := url.ParseRequestURI(fmt.Sprintf("%s://%s", protocol, hostname))
			if err != nil {
				return ctx, err
			}

			return withExternalHost(ctx, externalHostURL), nil
		}
	}

	if xForwardedHost != "" || reqHost != "" {
		// A host was provided with the request but is not on the allow list,
		// so log it. On the Supabase platform the X-Forwarded-Host header and
		// full request URL are likely sanitized before they reach the server.
		fields := make(logrus.Fields)

		if xForwardedHost != "" {
			fields["x_forwarded_host"] = xForwardedHost
		}

		if xForwardedProto != "" {
			fields["x_forwarded_proto"] = xForwardedProto
		}

		if reqHost != "" {
			fields["request_url_host"] = reqHost

			if req.URL.Scheme != "" {
				fields["request_url_scheme"] = req.URL.Scheme
			}
		}

		logrus.WithFields(fields).Info("Request received external host in X-Forwarded-Host or Host headers, but the values have not been added to GOTRUE_MAILER_EXTERNAL_HOSTS and will not be used. To suppress this message add the host, or sanitize the headers before the request reaches Auth.")
	}

	// Either the provided hosts are not allow-listed, or the server is not
	// configured for multiple hosts: fall back to the configured external URL.
	externalHostURL, err := url.ParseRequestURI(config.API.ExternalURL)
	if err != nil {
		return ctx, err
	}

	return withExternalHost(ctx, externalHostURL), nil
}

// requireSAMLEnabled responds with a not-found error when SAML is disabled.
func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
	ctx := req.Context()
	if !a.config.SAML.Enabled {
		return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled")
	}
	return ctx, nil
}

// requireManualLinkingEnabled responds with a not-found error when manual
// identity linking is disabled.
func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
	ctx := req.Context()
	if !a.config.Security.ManualLinkingEnabled {
		return nil, notFoundError(ErrorCodeManualLinkingDisabled, "Manual linking is disabled")
	}
	return ctx, nil
}

// databaseCleanup returns middleware that runs cleanup.Clean after a
// successful (2xx) POST, PUT, PATCH or DELETE request has been served.
// Cleanup failures are logged as warnings and never affect the response.
func (a *API) databaseCleanup(cleanup models.Cleaner) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			wrappedResp := chimiddleware.NewWrapResponseWriter(w, r.ProtoMajor)
			next.ServeHTTP(wrappedResp, r)

			switch r.Method {
			case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
				// Mutating request: continue to the status check.
			default:
				return
			}

			status := wrappedResp.Status()
			if status == 0 {
				// The handler never wrote a header or body; net/http sends an
				// implicit 200 OK in that case, so treat it as a success.
				status = http.StatusOK
			}
			if status/100 != 2 {
				// Don't do any cleanups for non-2xx responses.
				return
			}

			db := a.db.WithContext(r.Context())
			log := observability.GetLogEntry(r).Entry

			affectedRows, err := cleanup.Clean(db)
			if err != nil {
				log.WithError(err).WithField("affected_rows", affectedRows).Warn("database cleanup failed")
			} else if affectedRows > 0 {
				log.WithField("affected_rows", affectedRows).Debug("cleaned up expired or stale rows")
			}
		})
	}
}
deadline. type timeoutResponseWriter struct { sync.Mutex header http.Header wroteHeader bool snapHeader http.Header // snapshot of the header at the time WriteHeader was called statusCode int buf bytes.Buffer } func (t *timeoutResponseWriter) Header() http.Header { t.Lock() defer t.Unlock() return t.header } func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) { t.Lock() defer t.Unlock() if !t.wroteHeader { t.writeHeaderLocked(http.StatusOK) } return t.buf.Write(bytes) } func (t *timeoutResponseWriter) WriteHeader(statusCode int) { t.Lock() defer t.Unlock() t.writeHeaderLocked(statusCode) } func (t *timeoutResponseWriter) writeHeaderLocked(statusCode int) { if t.wroteHeader { // ignore multiple calls to WriteHeader // once WriteHeader has been called once, a snapshot of the header map is taken // and saved in snapHeader to be used in finallyWrite return } t.statusCode = statusCode t.wroteHeader = true t.snapHeader = t.header.Clone() } func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) { t.Lock() defer t.Unlock() dst := w.Header() for k, vv := range t.snapHeader { dst[k] = vv } if !t.wroteHeader { t.statusCode = http.StatusOK } w.WriteHeader(t.statusCode) if _, err := w.Write(t.buf.Bytes()); err != nil { logrus.WithError(err).Warn("Write failed") } } func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() timeoutWriter := &timeoutResponseWriter{ header: make(http.Header), } panicChan := make(chan any, 1) serverDone := make(chan struct{}) go func() { defer func() { if p := recover(); p != nil { panicChan <- p } }() next.ServeHTTP(timeoutWriter, r.WithContext(ctx)) close(serverDone) }() select { case p := <-panicChan: panic(p) case <-serverDone: timeoutWriter.finallyWrite(w) case <-ctx.Done(): err := ctx.Err() if err == 
context.DeadlineExceeded { httpError := &HTTPError{ HTTPStatus: http.StatusGatewayTimeout, ErrorCode: ErrorCodeRequestTimeout, Message: "Processing this request timed out, please retry after a moment.", } httpError = httpError.WithInternalError(err) HandleResponseError(httpError, w, r) } else { // unrecognized context error, so we should wait for the server to finish // and write out the response <-serverDone timeoutWriter.finallyWrite(w) } } }) } }