406 lines
12 KiB
Go
406 lines
12 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofrs/uuid"
|
|
"github.com/sirupsen/logrus"
|
|
standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go"
|
|
|
|
"github.com/supabase/auth/internal/conf"
|
|
"github.com/supabase/auth/internal/hooks"
|
|
"github.com/supabase/auth/internal/observability"
|
|
"github.com/supabase/auth/internal/storage"
|
|
)
|
|
|
|
const (
|
|
DefaultHTTPHookTimeout = 5 * time.Second
|
|
DefaultHTTPHookRetries = 3
|
|
HTTPHookBackoffDuration = 2 * time.Second
|
|
PayloadLimit = 200 * 1024 // 200KB
|
|
)
|
|
|
|
func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
|
|
db := a.db.WithContext(ctx)
|
|
|
|
request, err := json.Marshal(input)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
var response []byte
|
|
invokeHookFunc := func(tx *storage.Connection) error {
|
|
// We rely on Postgres timeouts to ensure the function doesn't overrun
|
|
if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil {
|
|
return terr
|
|
}
|
|
|
|
if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", hookConfig.HookName), request).First(&response); terr != nil {
|
|
return terr
|
|
}
|
|
|
|
// reset the timeout
|
|
if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil {
|
|
return terr
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
if tx != nil {
|
|
if err := invokeHookFunc(tx); err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
if err := db.Transaction(invokeHookFunc); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if err := json.Unmarshal(response, output); err != nil {
|
|
return response, err
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input any) ([]byte, error) {
|
|
ctx := r.Context()
|
|
client := http.Client{
|
|
Timeout: DefaultHTTPHookTimeout,
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout)
|
|
defer cancel()
|
|
|
|
log := observability.GetLogEntry(r).Entry
|
|
requestURL := hookConfig.URI
|
|
hookLog := log.WithFields(logrus.Fields{
|
|
"component": "auth_hook",
|
|
"url": requestURL,
|
|
})
|
|
|
|
inputPayload, err := json.Marshal(input)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := 0; i < DefaultHTTPHookRetries; i++ {
|
|
if i == 0 {
|
|
hookLog.Debugf("invocation attempt: %d", i)
|
|
} else {
|
|
hookLog.Infof("invocation attempt: %d", i)
|
|
}
|
|
msgID := uuid.Must(uuid.NewV4())
|
|
currentTime := time.Now()
|
|
signatureList, err := generateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload))
|
|
if err != nil {
|
|
panic("Failed to make request object")
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("webhook-id", msgID.String())
|
|
req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix()))
|
|
req.Header.Set("webhook-signature", strings.Join(signatureList, ", "))
|
|
// By default, Go Client sets encoding to gzip, which does not carry a content length header.
|
|
req.Header.Set("Accept-Encoding", "identity")
|
|
|
|
rsp, err := client.Do(req)
|
|
if err != nil && errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, unprocessableEntityError(ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds()))
|
|
|
|
} else if err != nil {
|
|
if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 {
|
|
hookLog.Errorf("Request timed out for attempt %d with err %s", i, err)
|
|
time.Sleep(HTTPHookBackoffDuration)
|
|
continue
|
|
} else if i == DefaultHTTPHookRetries-1 {
|
|
return nil, unprocessableEntityError(ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries")
|
|
} else {
|
|
return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err)
|
|
}
|
|
}
|
|
|
|
defer rsp.Body.Close()
|
|
|
|
switch rsp.StatusCode {
|
|
case http.StatusOK, http.StatusNoContent, http.StatusAccepted:
|
|
// Header.Get is case insensitive
|
|
contentType := rsp.Header.Get("Content-Type")
|
|
if contentType == "" {
|
|
return nil, badRequestError(ErrorCodeHookPayloadInvalidContentType, "Invalid Content-Type: Missing Content-Type header")
|
|
}
|
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
|
if err != nil {
|
|
return nil, badRequestError(ErrorCodeHookPayloadInvalidContentType, fmt.Sprintf("Invalid Content-Type header: %s", err.Error()))
|
|
}
|
|
if mediaType != "application/json" {
|
|
return nil, badRequestError(ErrorCodeHookPayloadInvalidContentType, "Invalid JSON response. Received content-type: "+contentType)
|
|
}
|
|
if rsp.Body == nil {
|
|
return nil, nil
|
|
}
|
|
limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit}
|
|
body, err := io.ReadAll(&limitedReader)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if limitedReader.N <= 0 {
|
|
// check if the response body still has excess bytes to be read
|
|
if n, _ := rsp.Body.Read(make([]byte, 1)); n > 0 {
|
|
return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size exceeded size limit of %d bytes", PayloadLimit))
|
|
}
|
|
}
|
|
return body, nil
|
|
case http.StatusTooManyRequests, http.StatusServiceUnavailable:
|
|
retryAfterHeader := rsp.Header.Get("retry-after")
|
|
// Check for truthy values to allow for flexibility to switch to time duration
|
|
if retryAfterHeader != "" {
|
|
continue
|
|
}
|
|
return nil, internalServerError("Service currently unavailable due to hook")
|
|
case http.StatusBadRequest:
|
|
return nil, internalServerError("Invalid payload sent to hook")
|
|
case http.StatusUnauthorized:
|
|
return nil, internalServerError("Hook requires authorization token")
|
|
default:
|
|
return nil, internalServerError("Unexpected status code returned from hook: %d", rsp.StatusCode)
|
|
}
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// invokePostgresHook invokes the hook code. conn can be nil, in which case a new
|
|
// transaction is opened. If calling invokeHook within a transaction, always
|
|
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
|
|
// trigger.
|
|
func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, output any) error {
|
|
var err error
|
|
var response []byte
|
|
|
|
switch input.(type) {
|
|
case *hooks.SendSMSInput:
|
|
hookOutput, ok := output.(*hooks.SendSMSOutput)
|
|
if !ok {
|
|
panic("output should be *hooks.SendSMSOutput")
|
|
}
|
|
if response, err = a.runHook(r, conn, a.config.Hook.SendSMS, input, output); err != nil {
|
|
return err
|
|
}
|
|
if err := json.Unmarshal(response, hookOutput); err != nil {
|
|
return internalServerError("Error unmarshaling Send SMS output.").WithInternalError(err)
|
|
}
|
|
if hookOutput.IsError() {
|
|
httpCode := hookOutput.HookError.HTTPCode
|
|
|
|
if httpCode == 0 {
|
|
httpCode = http.StatusInternalServerError
|
|
}
|
|
httpError := &HTTPError{
|
|
HTTPStatus: httpCode,
|
|
Message: hookOutput.HookError.Message,
|
|
}
|
|
return httpError.WithInternalError(&hookOutput.HookError)
|
|
}
|
|
return nil
|
|
case *hooks.SendEmailInput:
|
|
hookOutput, ok := output.(*hooks.SendEmailOutput)
|
|
if !ok {
|
|
panic("output should be *hooks.SendEmailOutput")
|
|
}
|
|
if response, err = a.runHook(r, conn, a.config.Hook.SendEmail, input, output); err != nil {
|
|
return err
|
|
}
|
|
if err := json.Unmarshal(response, hookOutput); err != nil {
|
|
return internalServerError("Error unmarshaling Send Email output.").WithInternalError(err)
|
|
}
|
|
if hookOutput.IsError() {
|
|
httpCode := hookOutput.HookError.HTTPCode
|
|
|
|
if httpCode == 0 {
|
|
httpCode = http.StatusInternalServerError
|
|
}
|
|
|
|
httpError := &HTTPError{
|
|
HTTPStatus: httpCode,
|
|
Message: hookOutput.HookError.Message,
|
|
}
|
|
|
|
return httpError.WithInternalError(&hookOutput.HookError)
|
|
}
|
|
return nil
|
|
case *hooks.MFAVerificationAttemptInput:
|
|
hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
|
|
if !ok {
|
|
panic("output should be *hooks.MFAVerificationAttemptOutput")
|
|
}
|
|
if response, err = a.runHook(r, conn, a.config.Hook.MFAVerificationAttempt, input, output); err != nil {
|
|
return err
|
|
}
|
|
if err := json.Unmarshal(response, hookOutput); err != nil {
|
|
return internalServerError("Error unmarshaling MFA Verification Attempt output.").WithInternalError(err)
|
|
}
|
|
if hookOutput.IsError() {
|
|
httpCode := hookOutput.HookError.HTTPCode
|
|
|
|
if httpCode == 0 {
|
|
httpCode = http.StatusInternalServerError
|
|
}
|
|
|
|
httpError := &HTTPError{
|
|
HTTPStatus: httpCode,
|
|
Message: hookOutput.HookError.Message,
|
|
}
|
|
|
|
return httpError.WithInternalError(&hookOutput.HookError)
|
|
}
|
|
return nil
|
|
case *hooks.PasswordVerificationAttemptInput:
|
|
hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput)
|
|
if !ok {
|
|
panic("output should be *hooks.PasswordVerificationAttemptOutput")
|
|
}
|
|
|
|
if response, err = a.runHook(r, conn, a.config.Hook.PasswordVerificationAttempt, input, output); err != nil {
|
|
return err
|
|
}
|
|
if err := json.Unmarshal(response, hookOutput); err != nil {
|
|
return internalServerError("Error unmarshaling Password Verification Attempt output.").WithInternalError(err)
|
|
}
|
|
if hookOutput.IsError() {
|
|
httpCode := hookOutput.HookError.HTTPCode
|
|
|
|
if httpCode == 0 {
|
|
httpCode = http.StatusInternalServerError
|
|
}
|
|
|
|
httpError := &HTTPError{
|
|
HTTPStatus: httpCode,
|
|
Message: hookOutput.HookError.Message,
|
|
}
|
|
|
|
return httpError.WithInternalError(&hookOutput.HookError)
|
|
}
|
|
|
|
return nil
|
|
case *hooks.CustomAccessTokenInput:
|
|
hookOutput, ok := output.(*hooks.CustomAccessTokenOutput)
|
|
if !ok {
|
|
panic("output should be *hooks.CustomAccessTokenOutput")
|
|
}
|
|
if response, err = a.runHook(r, conn, a.config.Hook.CustomAccessToken, input, output); err != nil {
|
|
return err
|
|
}
|
|
if err := json.Unmarshal(response, hookOutput); err != nil {
|
|
return internalServerError("Error unmarshaling Custom Access Token output.").WithInternalError(err)
|
|
}
|
|
|
|
if hookOutput.IsError() {
|
|
httpCode := hookOutput.HookError.HTTPCode
|
|
|
|
if httpCode == 0 {
|
|
httpCode = http.StatusInternalServerError
|
|
}
|
|
|
|
httpError := &HTTPError{
|
|
HTTPStatus: httpCode,
|
|
Message: hookOutput.HookError.Message,
|
|
}
|
|
|
|
return httpError.WithInternalError(&hookOutput.HookError)
|
|
}
|
|
if err := validateTokenClaims(hookOutput.Claims); err != nil {
|
|
httpCode := hookOutput.HookError.HTTPCode
|
|
|
|
if httpCode == 0 {
|
|
httpCode = http.StatusInternalServerError
|
|
}
|
|
httpError := &HTTPError{
|
|
HTTPStatus: httpCode,
|
|
Message: err.Error(),
|
|
}
|
|
|
|
return httpError
|
|
}
|
|
return nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
|
|
ctx := r.Context()
|
|
|
|
logEntry := observability.GetLogEntry(r)
|
|
hookStart := time.Now()
|
|
|
|
var response []byte
|
|
var err error
|
|
|
|
switch {
|
|
case strings.HasPrefix(hookConfig.URI, "http:") || strings.HasPrefix(hookConfig.URI, "https:"):
|
|
response, err = a.runHTTPHook(r, hookConfig, input)
|
|
case strings.HasPrefix(hookConfig.URI, "pg-functions:"):
|
|
response, err = a.runPostgresHook(ctx, conn, hookConfig, input, output)
|
|
default:
|
|
return nil, fmt.Errorf("unsupported protocol: %q only postgres hooks and HTTPS functions are supported at the moment", hookConfig.URI)
|
|
}
|
|
|
|
duration := time.Since(hookStart)
|
|
|
|
if err != nil {
|
|
logEntry.Entry.WithFields(logrus.Fields{
|
|
"action": "run_hook",
|
|
"hook": hookConfig.URI,
|
|
"success": false,
|
|
"duration": duration.Microseconds(),
|
|
}).WithError(err).Warn("Hook errored out")
|
|
|
|
return nil, internalServerError("Error running hook URI: %v", hookConfig.URI).WithInternalError(err)
|
|
}
|
|
|
|
logEntry.Entry.WithFields(logrus.Fields{
|
|
"action": "run_hook",
|
|
"hook": hookConfig.URI,
|
|
"success": true,
|
|
"duration": duration.Microseconds(),
|
|
}).WithError(err).Info("Hook ran successfully")
|
|
|
|
return response, nil
|
|
}
|
|
|
|
func generateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) {
|
|
SymmetricSignaturePrefix := "v1,"
|
|
// TODO(joel): Handle asymmetric case once library has been upgraded
|
|
var signatureList []string
|
|
for _, secret := range secrets {
|
|
if strings.HasPrefix(secret, SymmetricSignaturePrefix) {
|
|
trimmedSecret := strings.TrimPrefix(secret, SymmetricSignaturePrefix)
|
|
wh, err := standardwebhooks.NewWebhook(trimmedSecret)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
signature, err := wh.Sign(msgID.String(), currentTime, inputPayload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
signatureList = append(signatureList, signature)
|
|
} else {
|
|
return nil, errors.New("invalid signature format")
|
|
}
|
|
}
|
|
return signatureList, nil
|
|
}
|