chatai/auth_v2.169.0/internal/api/middleware_test.go

511 lines
14 KiB
Go

package api
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/storage"
)
const (
	// HCaptchaSecret is the hCaptcha secret configured by tests that expect
	// captcha verification to succeed (presumably hCaptcha's dummy test
	// secret — TODO confirm against hCaptcha's docs).
	HCaptchaSecret string = "0x0000000000000000000000000000000000000000"

	// TurnstileCaptchaSecret is the Cloudflare Turnstile secret configured by
	// tests that expect captcha verification to succeed (presumably a
	// Turnstile test secret — TODO confirm against Cloudflare's docs).
	TurnstileCaptchaSecret string = "1x0000000000000000000000000000000AA"

	// CaptchaResponse is the captcha token placed in
	// gotrue_meta_security.captcha_token by the captcha tests, for both
	// providers.
	CaptchaResponse string = "10000000-aaaa-bbbb-cccc-000000000001"
)
// MiddlewareTestSuite groups the middleware tests. API is the instance under
// test and Config is the configuration returned alongside it by
// setupAPIForTest; individual tests mutate Config to exercise different
// settings.
type MiddlewareTestSuite struct {
	suite.Suite

	API    *API
	Config *conf.GlobalConfiguration
}
// TestMiddlewareFunctions is the `go test` entry point for
// MiddlewareTestSuite: it builds a test API, runs every suite method against
// it, and closes the API's database connection once the suite has finished.
func TestMiddlewareFunctions(t *testing.T) {
	api, config, err := setupAPIForTest()
	require.NoError(t, err)
	defer api.db.Close()

	suite.Run(t, &MiddlewareTestSuite{
		API:    api,
		Config: config,
	})
}
// TestVerifyCaptchaValid checks that verifyCaptcha accepts a request, for each
// supported provider, when either a valid captcha token is supplied or the
// request carries a supabase_admin bearer token. On success the middleware
// must leave the request body readable and byte-for-byte unchanged, and must
// return the request's context unchanged.
//
// NOTE(review): the token cases presumably reach the real hCaptcha/Turnstile
// verification endpoints using their test secrets.
func (ts *MiddlewareTestSuite) TestVerifyCaptchaValid() {
	// Restore the captcha configuration afterwards so the enabled flag,
	// provider and secret set here do not leak into other suite tests.
	originalCaptcha := ts.Config.Security.Captcha
	defer func() { ts.Config.Security.Captcha = originalCaptcha }()

	ts.Config.Security.Captcha.Enabled = true

	adminClaims := &AccessTokenClaims{
		Role: "supabase_admin",
	}
	adminJWT, err := jwt.NewWithClaims(jwt.SigningMethodHS256, adminClaims).SignedString([]byte(ts.Config.JWT.Secret))
	require.NoError(ts.T(), err)

	cases := []struct {
		desc            string
		adminJWT        string
		captchaToken    string
		captchaProvider string
	}{
		{"Valid captcha response", "", CaptchaResponse, "hcaptcha"},
		{"Valid captcha response", "", CaptchaResponse, "turnstile"},
		{"Ignore captcha if admin role is present", adminJWT, "", "hcaptcha"},
		{"Ignore captcha if admin role is present", adminJWT, "", "turnstile"},
	}

	for _, c := range cases {
		ts.Run(c.desc+" ("+c.captchaProvider+")", func() {
			ts.Config.Security.Captcha.Provider = c.captchaProvider
			switch c.captchaProvider {
			case "turnstile":
				ts.Config.Security.Captcha.Secret = TurnstileCaptchaSecret
			case "hcaptcha":
				ts.Config.Security.Captcha.Secret = HCaptchaSecret
			}

			// Encode the payload once and keep those bytes as the expected
			// body, rather than re-encoding into a buffer the request has
			// already drained.
			payload, err := json.Marshal(map[string]interface{}{
				"email":    "test@example.com",
				"password": "secret",
				"gotrue_meta_security": map[string]interface{}{
					"captcha_token": c.captchaToken,
				},
			})
			require.NoError(ts.T(), err)

			req := httptest.NewRequest(http.MethodPost, "http://localhost", bytes.NewReader(payload))
			req.Header.Set("Content-Type", "application/json")
			if c.adminJWT != "" {
				req.Header.Set("Authorization", "Bearer "+c.adminJWT)
			}
			beforeCtx := context.Background()
			req = req.WithContext(beforeCtx)
			w := httptest.NewRecorder()

			afterCtx, err := ts.API.verifyCaptcha(w, req)
			require.NoError(ts.T(), err)

			// Downstream handlers must still be able to read the original body.
			body, err := io.ReadAll(req.Body)
			require.NoError(ts.T(), err)
			require.Equal(ts.T(), payload, body)
			require.Equal(ts.T(), beforeCtx, afterCtx)
		})
	}
}
// TestVerifyCaptchaInvalid checks that verifyCaptcha rejects a request with an
// *HTTPError carrying the expected status and message when the configured
// secret is not accepted by the provider.
func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() {
	// Restore the captcha configuration afterwards; each case below replaces
	// it wholesale.
	originalCaptcha := ts.Config.Security.Captcha
	defer func() { ts.Config.Security.Captcha = originalCaptcha }()

	cases := []struct {
		desc         string
		captchaConf  *conf.CaptchaConfiguration
		expectedCode int
		expectedMsg  string
	}{
		{
			"Captcha validation failed",
			&conf.CaptchaConfiguration{
				Enabled:  true,
				Provider: "hcaptcha",
				Secret:   "test",
			},
			http.StatusBadRequest,
			"captcha protection: request disallowed (not-using-dummy-secret)",
		},
		{
			"Captcha validation failed",
			&conf.CaptchaConfiguration{
				Enabled:  true,
				Provider: "turnstile",
				Secret:   "anothertest",
			},
			http.StatusBadRequest,
			"captcha protection: request disallowed (invalid-input-secret)",
		},
	}

	for _, c := range cases {
		ts.Run(c.desc+" ("+c.captchaConf.Provider+")", func() {
			ts.Config.Security.Captcha = *c.captchaConf

			var buffer bytes.Buffer
			require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
				"email":    "test@example.com",
				"password": "secret",
				"gotrue_meta_security": map[string]interface{}{
					"captcha_token": CaptchaResponse,
				},
			}))
			req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
			req.Header.Set("Content-Type", "application/json")
			req = req.WithContext(context.Background())
			w := httptest.NewRecorder()

			_, err := ts.API.verifyCaptcha(w, req)
			// Fail cleanly (instead of panicking on the type assertion) when
			// no error, or an error of another type, is returned.
			require.Error(ts.T(), err)
			httpErr, ok := err.(*HTTPError)
			require.True(ts.T(), ok, "expected *HTTPError, got %T", err)
			require.Equal(ts.T(), c.expectedCode, httpErr.HTTPStatus)
			require.Equal(ts.T(), c.expectedMsg, httpErr.Message)
		})
	}
}
// TestIsValidExternalHost checks which external URL isValidExternalHost stores
// in the request context for combinations of configured Mailer.ExternalHosts,
// request Host and X-Forwarded-* headers. Hosts not in the allow-list fall
// back to API.ExternalURL.
func (ts *MiddlewareTestSuite) TestIsValidExternalHost() {
	cases := []struct {
		desc          string
		externalHosts []string
		requestURL    string
		headers       http.Header
		expectedURL   string
	}{
		{
			desc:        "no defined external hosts, no headers, no absolute request URL",
			requestURL:  "/some-path",
			expectedURL: ts.API.config.API.ExternalURL,
		},
		{
			desc: "no defined external hosts, unauthorized X-Forwarded-Host without any external hosts",
			headers: http.Header{
				"X-Forwarded-Host": []string{
					"external-host.com",
				},
			},
			requestURL:  "/some-path",
			expectedURL: ts.API.config.API.ExternalURL,
		},
		{
			desc:          "defined external hosts, unauthorized X-Forwarded-Host",
			externalHosts: []string{"authorized-host.com"},
			headers: http.Header{
				"X-Forwarded-Proto": []string{"https"},
				"X-Forwarded-Host": []string{
					"external-host.com",
				},
			},
			requestURL:  "/some-path",
			expectedURL: ts.API.config.API.ExternalURL,
		},
		{
			desc:        "no defined external hosts, unauthorized Host",
			requestURL:  "https://external-host.com/some-path",
			expectedURL: ts.API.config.API.ExternalURL,
		},
		{
			desc:          "defined external hosts, unauthorized Host",
			externalHosts: []string{"authorized-host.com"},
			requestURL:    "https://external-host.com/some-path",
			expectedURL:   ts.API.config.API.ExternalURL,
		},
		{
			desc:          "defined external hosts, authorized X-Forwarded-Host",
			externalHosts: []string{"authorized-host.com"},
			headers: http.Header{
				"X-Forwarded-Proto": []string{"http"}, // this should be ignored and default to HTTPS
				"X-Forwarded-Host": []string{
					"authorized-host.com",
				},
			},
			requestURL:  "https://X-Forwarded-Host-takes-precedence.com/some-path",
			expectedURL: "https://authorized-host.com",
		},
		{
			desc:          "defined external hosts, authorized Host",
			externalHosts: []string{"authorized-host.com"},
			requestURL:    "https://authorized-host.com/some-path",
			expectedURL:   "https://authorized-host.com",
		},
		{
			desc:          "defined external hosts, authorized localhost in X-Forwarded-Host with HTTP",
			externalHosts: []string{"localhost"},
			headers: http.Header{
				"X-Forwarded-Proto": []string{"http"},
				"X-Forwarded-Host": []string{
					"localhost",
				},
			},
			requestURL:  "/some-path",
			expectedURL: "http://localhost",
		},
		{
			desc:          "defined external hosts, authorized localhost in Host with HTTP",
			externalHosts: []string{"localhost"},
			requestURL:    "http://localhost:3000/some-path",
			expectedURL:   "http://localhost",
		},
	}

	require.NotEmpty(ts.T(), ts.API.config.API.ExternalURL)

	for _, c := range cases {
		ts.Run(c.desc, func() {
			req := httptest.NewRequest(http.MethodPost, c.requestURL, nil)
			if c.headers != nil {
				// Copy so the middleware cannot mutate the shared case table.
				req.Header = c.headers.Clone()
			}

			// Swap in this case's allow-list and always put the original back,
			// even if the middleware panics.
			originalHosts := ts.API.config.Mailer.ExternalHosts
			ts.API.config.Mailer.ExternalHosts = c.externalHosts
			defer func() { ts.API.config.Mailer.ExternalHosts = originalHosts }()

			w := httptest.NewRecorder()
			ctx, err := ts.API.isValidExternalHost(w, req)
			require.NoError(ts.T(), err)

			externalURL := getExternalHost(ctx)
			require.Equal(ts.T(), c.expectedURL, externalURL.String())
		})
	}
}
// TestRequireSAMLEnabled checks that requireSAMLEnabled returns the
// SAML-disabled not-found error when SAML is off and no error when it is on.
func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() {
	// Restore the flag afterwards; the last case would otherwise leave SAML
	// enabled for every later suite test.
	originalEnabled := ts.Config.SAML.Enabled
	defer func() { ts.Config.SAML.Enabled = originalEnabled }()

	cases := []struct {
		desc        string
		isEnabled   bool
		expectedErr error
	}{
		{
			desc:        "SAML not enabled",
			isEnabled:   false,
			expectedErr: notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled"),
		},
		{
			desc:        "SAML enabled",
			isEnabled:   true,
			expectedErr: nil,
		},
	}

	for _, c := range cases {
		ts.Run(c.desc, func() {
			ts.Config.SAML.Enabled = c.isEnabled
			req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
			w := httptest.NewRecorder()
			_, err := ts.API.requireSAMLEnabled(w, req)
			require.Equal(ts.T(), c.expectedErr, err)
		})
	}
}
// TestFunctionHooksUnmarshalJSON checks that a FunctionHooks entry given
// either as a single string or as a list of strings decodes to the same
// value, and that an entry given as a JSON object is rejected.
func TestFunctionHooksUnmarshalJSON(t *testing.T) {
	cases := []struct {
		in string
		ok bool
	}{
		{in: `{ "signup" : "identity-signup" }`, ok: true},
		{in: `{ "signup" : ["identity-signup"] }`, ok: true},
		{in: `{ "signup" : {"foo" : "bar"} }`, ok: false},
	}

	want := FunctionHooks{"signup": {"identity-signup"}}
	for _, tc := range cases {
		t.Run(tc.in, func(t *testing.T) {
			var got FunctionHooks
			err := json.Unmarshal([]byte(tc.in), &got)
			if !tc.ok {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, want, got)
		})
	}
}
// TestTimeoutMiddleware checks that when the wrapped handler runs past the
// timeout, timeoutMiddleware responds with 504 Gateway Timeout and a JSON
// error body whose error_code is ErrorCodeRequestTimeout.
func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() {
	// Pass the timeout straight to the middleware rather than writing it into
	// the shared suite config, where a 5 microsecond budget would leak into
	// later tests.
	const timeout = 5 * time.Microsecond

	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
	w := httptest.NewRecorder()

	// release unblocks the slow handler once this test returns. This matters
	// if the middleware runs the handler in its own goroutine, which would
	// otherwise linger. It is only closed after ServeHTTP has returned, so it
	// cannot influence the middleware's timeout decision.
	release := make(chan struct{})
	defer close(release)

	slowHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		// Block far longer than the timeout to force the middleware to time
		// out. The one-second fallback behaves like a plain sleep if the
		// middleware instead waits for the handler to return.
		select {
		case <-release:
		case <-time.After(time.Second):
		}
		rw.WriteHeader(http.StatusOK)
	})

	timeoutMiddleware(timeout)(slowHandler).ServeHTTP(w, req)
	require.Equal(ts.T(), http.StatusGatewayTimeout, w.Code)

	var data map[string]interface{}
	require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
	require.Equal(ts.T(), ErrorCodeRequestTimeout, data["error_code"])
	require.Equal(ts.T(), float64(504), data["code"])
	require.NotNil(ts.T(), data["msg"])
}
// TestTimeoutResponseWriter checks that a handler wrapped by timeoutMiddleware
// produces the same response as the same handler writing directly to an
// httptest recorder. The handler used issues two redirects in a row, so the
// comparison also covers writes made after the response has already started.
func TestTimeoutResponseWriter(t *testing.T) {
	// timeoutResponseWriter should exhibit the same behavior as http.ResponseWriter.
	doubleRedirect := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// First redirect: starts the response.
		http.Redirect(w, r, "http://localhost:3001/#message=first_message", http.StatusSeeOther)
		// Second redirect: attempts to overwrite the first.
		http.Redirect(w, r, "http://localhost:3001/second", http.StatusSeeOther)
	})
	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)

	viaMiddleware := httptest.NewRecorder()
	timeoutMiddleware(10*time.Second)(doubleRedirect).ServeHTTP(viaMiddleware, req)

	direct := httptest.NewRecorder()
	doubleRedirect.ServeHTTP(direct, req)

	require.Equal(t, direct.Result(), viaMiddleware.Result())
}
// TestLimitHandler checks that limitHandler, with a tollbooth limiter created
// with a limit of 5, lets the first 5 requests from one client (identified by
// the configured rate-limit header) through. It then checks that the 6th
// request is rejected with 429 Too Many Requests.
func (ts *MiddlewareTestSuite) TestLimitHandler() {
	// Restore the header name afterwards so later suite tests are unaffected.
	originalHeader := ts.Config.RateLimitHeader
	defer func() { ts.Config.RateLimitHeader = originalHeader }()
	ts.Config.RateLimitHeader = "X-Rate-Limit"

	lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{
		DefaultExpirationTTL: time.Hour,
	})

	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		// Marshalling a constant string map cannot fail.
		b, _ := json.Marshal(map[string]interface{}{"message": "ok"})
		_, _ = w.Write(b) // best-effort: the assertions below inspect the body
	})
	// Build the wrapped handler once; all requests share the same limiter.
	limited := ts.API.limitHandler(lmt).handler(okHandler)

	// sendRequest issues one request from the same client address.
	sendRequest := func() *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
		req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")
		w := httptest.NewRecorder()
		limited.ServeHTTP(w, req)
		return w
	}

	for i := 0; i < 5; i++ {
		w := sendRequest()
		require.Equal(ts.T(), http.StatusOK, w.Code)
		var data map[string]interface{}
		require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
		require.Equal(ts.T(), "ok", data["message"])
	}

	// 6th request should fail and return a rate limit exceeded error
	w := sendRequest()
	require.Equal(ts.T(), http.StatusTooManyRequests, w.Code)
}
// MockCleanup is a testify mock of the cleanup dependency passed to
// API.databaseCleanup; tests use it to count Clean invocations.
type MockCleanup struct {
	mock.Mock
}

// Clean records the call and returns the values configured with
// On("Clean", ...).Return(count, err). When the expectation was set up
// without Return, it falls back to (0, nil).
func (m *MockCleanup) Clean(db *storage.Connection) (int, error) {
	args := m.Called(db)
	if len(args) == 0 {
		return 0, nil
	}
	return args.Int(0), args.Error(1)
}
// TestDatabaseCleanup checks that the databaseCleanup middleware invokes the
// cleaner exactly once after a successful (2xx) POST. It also checks that the
// cleaner is skipped for GET requests and for 3xx/4xx/5xx responses, and that
// the wrapped handler's status code is passed through in every case.
func (ts *MiddlewareTestSuite) TestDatabaseCleanup() {
	// testHandler responds with the given status and a small JSON body.
	testHandler := func(statusCode int) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(statusCode)
			// Marshalling a constant string map cannot fail.
			b, _ := json.Marshal(map[string]interface{}{"message": "ok"})
			_, _ = w.Write(b) // best-effort: only the status code is asserted
		}
	}

	cases := []struct {
		desc          string
		statusCode    int
		method        string
		expectCleanup bool
	}{
		{
			desc:          "Run cleanup successfully",
			statusCode:    http.StatusOK,
			method:        http.MethodPost,
			expectCleanup: true,
		},
		{
			desc:          "Skip cleanup if GET",
			statusCode:    http.StatusOK,
			method:        http.MethodGet,
			expectCleanup: false,
		},
		{
			desc:          "Skip cleanup if 3xx",
			statusCode:    http.StatusSeeOther,
			method:        http.MethodPost,
			expectCleanup: false,
		},
		{
			desc:          "Skip cleanup if 4xx",
			statusCode:    http.StatusBadRequest,
			method:        http.MethodPost,
			expectCleanup: false,
		},
		{
			desc:          "Skip cleanup if 5xx",
			statusCode:    http.StatusInternalServerError,
			method:        http.MethodPost,
			expectCleanup: false,
		},
	}

	for _, c := range cases {
		ts.Run(c.desc, func() {
			// A fresh mock per case lets every case assert its own call count.
			mockCleanup := new(MockCleanup)
			mockCleanup.On("Clean", mock.Anything).Return(0, nil)

			req := httptest.NewRequest(c.method, "http://localhost", nil)
			w := httptest.NewRecorder()
			ts.API.databaseCleanup(mockCleanup)(testHandler(c.statusCode)).ServeHTTP(w, req)
			require.Equal(ts.T(), c.statusCode, w.Code)

			expectedCalls := 0
			if c.expectCleanup {
				expectedCalls = 1
			}
			mockCleanup.AssertNumberOfCalls(ts.T(), "Clean", expectedCalls)
		})
	}
}