chatdesk-ui/auth_v2.169.0/internal/api/hooks.go

406 lines
12 KiB
Go

package api
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"net"
"net/http"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/sirupsen/logrus"
standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/hooks"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/storage"
)
// Tunables for invoking HTTP hooks.
const (
	// DefaultHTTPHookTimeout is applied by runHTTPHook both as the
	// http.Client timeout and as the context deadline for the invocation.
	DefaultHTTPHookTimeout = 5 * time.Second

	// DefaultHTTPHookRetries bounds the number of attempts runHTTPHook makes.
	DefaultHTTPHookRetries = 3

	// HTTPHookBackoffDuration is how long runHTTPHook sleeps after a failed
	// request before the next attempt.
	HTTPHookBackoffDuration = 2 * time.Second

	// PayloadLimit is the largest hook response body, in bytes, that
	// runHTTPHook accepts (200KB).
	PayloadLimit = 200 << 10
)
// runPostgresHook calls the Postgres function named by hookConfig.HookName
// with the JSON-encoded input as its single argument and decodes the
// function's result into output.
//
// When tx is non-nil the function is invoked on it; otherwise a new
// transaction is opened on a.db bound to ctx. The raw response bytes are
// returned together with any error from decoding them into output. A failure
// to marshal input or to run one of the queries yields a nil response.
func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
	db := a.db.WithContext(ctx)

	request, err := json.Marshal(input)
	if err != nil {
		// Report the failure to the caller rather than panicking, the same
		// way runHTTPHook treats a marshaling error.
		return nil, err
	}

	var response []byte
	invokeHookFunc := func(tx *storage.Connection) error {
		// We rely on Postgres timeouts to ensure the function doesn't overrun
		if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil {
			return terr
		}

		// NOTE(review): HookName is interpolated into the SQL text; presumably
		// it is validated when the configuration is loaded — confirm.
		if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", hookConfig.HookName), request).First(&response); terr != nil {
			return terr
		}

		// reset the timeout
		if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil {
			return terr
		}

		return nil
	}

	if tx != nil {
		if err := invokeHookFunc(tx); err != nil {
			return nil, err
		}
	} else {
		if err := db.Transaction(invokeHookFunc); err != nil {
			return nil, err
		}
	}

	if err := json.Unmarshal(response, output); err != nil {
		return response, err
	}

	return response, nil
}
// runHTTPHook POSTs the JSON-encoded input to the endpoint in hookConfig.URI
// and returns the raw response body.
//
// Each attempt is signed with hookConfig.HTTPHookSecrets and carries fresh
// webhook-id, webhook-timestamp and webhook-signature headers. At most
// DefaultHTTPHookRetries attempts are made, all sharing one deadline of
// DefaultHTTPHookTimeout:
//
//   - a request error caused by the context deadline fails immediately;
//   - any other request error is retried after HTTPHookBackoffDuration, and
//     fails with ErrorCodeHookTimeoutAfterRetry on the last attempt;
//   - a 429/503 response with a non-empty retry-after header is retried
//     immediately;
//   - a 200/202/204 response must be application/json and no larger than
//     PayloadLimit bytes.
//
// A non-nil error is always returned when no attempt produced a body.
func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input any) ([]byte, error) {
	ctx := r.Context()
	client := http.Client{
		Timeout: DefaultHTTPHookTimeout,
	}
	// The deadline covers every attempt, backoff sleeps included.
	ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout)
	defer cancel()

	log := observability.GetLogEntry(r).Entry
	requestURL := hookConfig.URI
	hookLog := log.WithFields(logrus.Fields{
		"component": "auth_hook",
		"url":       requestURL,
	})

	inputPayload, err := json.Marshal(input)
	if err != nil {
		return nil, err
	}

	// handleResponse interprets a single hook response. It reports the body on
	// success, retry=true when the hook asked to be called again, or an error.
	// The response body is closed before it returns, so retried attempts do
	// not keep earlier responses open until runHTTPHook itself returns.
	handleResponse := func(rsp *http.Response) ([]byte, bool, error) {
		defer rsp.Body.Close()

		switch rsp.StatusCode {
		case http.StatusOK, http.StatusNoContent, http.StatusAccepted:
			// Header.Get is case insensitive
			contentType := rsp.Header.Get("Content-Type")
			if contentType == "" {
				return nil, false, badRequestError(ErrorCodeHookPayloadInvalidContentType, "Invalid Content-Type: Missing Content-Type header")
			}

			mediaType, _, err := mime.ParseMediaType(contentType)
			if err != nil {
				return nil, false, badRequestError(ErrorCodeHookPayloadInvalidContentType, fmt.Sprintf("Invalid Content-Type header: %s", err.Error()))
			}
			if mediaType != "application/json" {
				return nil, false, badRequestError(ErrorCodeHookPayloadInvalidContentType, "Invalid JSON response. Received content-type: "+contentType)
			}

			limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit}
			body, err := io.ReadAll(&limitedReader)
			if err != nil {
				return nil, false, err
			}
			if limitedReader.N <= 0 {
				// check if the response body still has excess bytes to be read
				if n, _ := rsp.Body.Read(make([]byte, 1)); n > 0 {
					return nil, false, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size exceeded size limit of %d bytes", PayloadLimit))
				}
			}
			return body, false, nil

		case http.StatusTooManyRequests, http.StatusServiceUnavailable:
			// Check for truthy values to allow for flexibility to switch to time duration
			if rsp.Header.Get("retry-after") != "" {
				return nil, true, nil
			}
			return nil, false, internalServerError("Service currently unavailable due to hook")

		case http.StatusBadRequest:
			return nil, false, internalServerError("Invalid payload sent to hook")

		case http.StatusUnauthorized:
			return nil, false, internalServerError("Hook requires authorization token")

		default:
			return nil, false, internalServerError("Unexpected status code returned from hook: %d", rsp.StatusCode)
		}
	}

	for i := 0; i < DefaultHTTPHookRetries; i++ {
		if i == 0 {
			hookLog.Debugf("invocation attempt: %d", i)
		} else {
			hookLog.Infof("invocation attempt: %d", i)
		}

		msgID := uuid.Must(uuid.NewV4())
		currentTime := time.Now()
		signatureList, err := generateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload)
		if err != nil {
			return nil, err
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload))
		if err != nil {
			return nil, internalServerError("Failed to make request object").WithInternalError(err)
		}

		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("webhook-id", msgID.String())
		req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix()))
		req.Header.Set("webhook-signature", strings.Join(signatureList, ", "))
		// By default, Go Client sets encoding to gzip, which does not carry a content length header.
		req.Header.Set("Accept-Encoding", "identity")

		rsp, err := client.Do(req)
		if err != nil {
			if errors.Is(err, context.DeadlineExceeded) {
				return nil, unprocessableEntityError(ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds()))
			}
			if i == DefaultHTTPHookRetries-1 {
				// Out of attempts: fail explicitly instead of sleeping and
				// falling out of the loop without an error.
				return nil, unprocessableEntityError(ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries")
			}

			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Timeout() {
				hookLog.Errorf("Request timed out for attempt %d with err %s", i, err)
			} else {
				hookLog.Errorf("Request failed for attempt %d with err %s", i, err)
			}
			time.Sleep(HTTPHookBackoffDuration)
			continue
		}

		body, retry, err := handleResponse(rsp)
		if retry {
			continue
		}
		return body, err
	}

	// Only reached when every attempt was answered with a retry-after
	// response; report that rather than returning a nil body with no error.
	return nil, internalServerError("Service currently unavailable due to hook")
}
// invokeHook runs the hook matching the dynamic type of input and decodes the
// hook's response into output.
//
// output must be a pointer to the output type paired with input (for example
// *hooks.SendSMSOutput for *hooks.SendSMSInput); a mismatch panics. conn can
// be nil, in which case a new transaction is opened for Postgres hooks. If
// calling invokeHook within a transaction, always pass the current
// transaction, as pool-exhaustion deadlocks are very easy to trigger.
//
// A response for which IsError reports true is returned as an *HTTPError
// carrying the hook's message and HTTP code (500 when the code is zero). An
// input of any other type runs no hook and yields nil.
func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, output any) error {
	switch input.(type) {
	case *hooks.SendSMSInput:
		out, ok := output.(*hooks.SendSMSOutput)
		if !ok {
			panic("output should be *hooks.SendSMSOutput")
		}

		raw, err := a.runHook(r, conn, a.config.Hook.SendSMS, input, output)
		if err != nil {
			return err
		}
		if uerr := json.Unmarshal(raw, out); uerr != nil {
			return internalServerError("Error unmarshaling Send SMS output.").WithInternalError(uerr)
		}
		if !out.IsError() {
			return nil
		}

		status := out.HookError.HTTPCode
		if status == 0 {
			status = http.StatusInternalServerError
		}
		hookErr := &HTTPError{
			HTTPStatus: status,
			Message:    out.HookError.Message,
		}
		return hookErr.WithInternalError(&out.HookError)

	case *hooks.SendEmailInput:
		out, ok := output.(*hooks.SendEmailOutput)
		if !ok {
			panic("output should be *hooks.SendEmailOutput")
		}

		raw, err := a.runHook(r, conn, a.config.Hook.SendEmail, input, output)
		if err != nil {
			return err
		}
		if uerr := json.Unmarshal(raw, out); uerr != nil {
			return internalServerError("Error unmarshaling Send Email output.").WithInternalError(uerr)
		}
		if !out.IsError() {
			return nil
		}

		status := out.HookError.HTTPCode
		if status == 0 {
			status = http.StatusInternalServerError
		}
		hookErr := &HTTPError{
			HTTPStatus: status,
			Message:    out.HookError.Message,
		}
		return hookErr.WithInternalError(&out.HookError)

	case *hooks.MFAVerificationAttemptInput:
		out, ok := output.(*hooks.MFAVerificationAttemptOutput)
		if !ok {
			panic("output should be *hooks.MFAVerificationAttemptOutput")
		}

		raw, err := a.runHook(r, conn, a.config.Hook.MFAVerificationAttempt, input, output)
		if err != nil {
			return err
		}
		if uerr := json.Unmarshal(raw, out); uerr != nil {
			return internalServerError("Error unmarshaling MFA Verification Attempt output.").WithInternalError(uerr)
		}
		if !out.IsError() {
			return nil
		}

		status := out.HookError.HTTPCode
		if status == 0 {
			status = http.StatusInternalServerError
		}
		hookErr := &HTTPError{
			HTTPStatus: status,
			Message:    out.HookError.Message,
		}
		return hookErr.WithInternalError(&out.HookError)

	case *hooks.PasswordVerificationAttemptInput:
		out, ok := output.(*hooks.PasswordVerificationAttemptOutput)
		if !ok {
			panic("output should be *hooks.PasswordVerificationAttemptOutput")
		}

		raw, err := a.runHook(r, conn, a.config.Hook.PasswordVerificationAttempt, input, output)
		if err != nil {
			return err
		}
		if uerr := json.Unmarshal(raw, out); uerr != nil {
			return internalServerError("Error unmarshaling Password Verification Attempt output.").WithInternalError(uerr)
		}
		if !out.IsError() {
			return nil
		}

		status := out.HookError.HTTPCode
		if status == 0 {
			status = http.StatusInternalServerError
		}
		hookErr := &HTTPError{
			HTTPStatus: status,
			Message:    out.HookError.Message,
		}
		return hookErr.WithInternalError(&out.HookError)

	case *hooks.CustomAccessTokenInput:
		out, ok := output.(*hooks.CustomAccessTokenOutput)
		if !ok {
			panic("output should be *hooks.CustomAccessTokenOutput")
		}

		raw, err := a.runHook(r, conn, a.config.Hook.CustomAccessToken, input, output)
		if err != nil {
			return err
		}
		if uerr := json.Unmarshal(raw, out); uerr != nil {
			return internalServerError("Error unmarshaling Custom Access Token output.").WithInternalError(uerr)
		}

		if out.IsError() {
			status := out.HookError.HTTPCode
			if status == 0 {
				status = http.StatusInternalServerError
			}
			hookErr := &HTTPError{
				HTTPStatus: status,
				Message:    out.HookError.Message,
			}
			return hookErr.WithInternalError(&out.HookError)
		}

		// The hook itself succeeded; the claims it returned must still pass
		// validation before they are accepted.
		if cerr := validateTokenClaims(out.Claims); cerr != nil {
			status := out.HookError.HTTPCode
			if status == 0 {
				status = http.StatusInternalServerError
			}
			return &HTTPError{
				HTTPStatus: status,
				Message:    cerr.Error(),
			}
		}
		return nil
	}

	return nil
}
// runHook dispatches a hook invocation on the scheme prefix of hookConfig.URI:
// "http:" and "https:" go to runHTTPHook, "pg-functions:" goes to
// runPostgresHook (which receives conn and output). Any other URI is rejected
// with an error before anything is executed.
//
// The outcome and the elapsed time in microseconds are logged on the request's
// log entry. On failure the underlying error is wrapped in an internal server
// error and a nil response is returned; on success the raw response bytes are
// returned.
func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
	reqCtx := r.Context()
	entry := observability.GetLogEntry(r)
	uri := hookConfig.URI
	startedAt := time.Now()

	var (
		payload []byte
		hookErr error
	)
	if strings.HasPrefix(uri, "http:") || strings.HasPrefix(uri, "https:") {
		payload, hookErr = a.runHTTPHook(r, hookConfig, input)
	} else if strings.HasPrefix(uri, "pg-functions:") {
		payload, hookErr = a.runPostgresHook(reqCtx, conn, hookConfig, input, output)
	} else {
		return nil, fmt.Errorf("unsupported protocol: %q only postgres hooks and HTTPS functions are supported at the moment", hookConfig.URI)
	}

	elapsed := time.Since(startedAt)
	// fieldsFor builds the structured log fields shared by both outcomes.
	fieldsFor := func(succeeded bool) logrus.Fields {
		return logrus.Fields{
			"action":   "run_hook",
			"hook":     hookConfig.URI,
			"success":  succeeded,
			"duration": elapsed.Microseconds(),
		}
	}

	if hookErr != nil {
		entry.Entry.WithFields(fieldsFor(false)).WithError(hookErr).Warn("Hook errored out")
		return nil, internalServerError("Error running hook URI: %v", hookConfig.URI).WithInternalError(hookErr)
	}

	entry.Entry.WithFields(fieldsFor(true)).WithError(hookErr).Info("Hook ran successfully")
	return payload, nil
}
// generateSignatures signs inputPayload once per entry in secrets and returns
// the signatures in the same order as the secrets.
//
// Every secret must start with the "v1," prefix, which is stripped before the
// remainder is handed to the standard-webhooks signer together with msgID and
// currentTime. The first secret without the prefix, or the first signer
// failure, aborts the whole call with an error and no signatures.
func generateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) {
	const symmetricPrefix = "v1,"
	// TODO(joel): Handle asymmetric case once library has been upgraded

	var signatures []string
	for _, secret := range secrets {
		if !strings.HasPrefix(secret, symmetricPrefix) {
			return nil, errors.New("invalid signature format")
		}

		signer, err := standardwebhooks.NewWebhook(strings.TrimPrefix(secret, symmetricPrefix))
		if err != nil {
			return nil, err
		}

		sig, err := signer.Sign(msgID.String(), currentTime, inputPayload)
		if err != nil {
			return nil, err
		}
		signatures = append(signatures, sig)
	}
	return signatures, nil
}