122 lines
3.8 KiB
Go
122 lines
3.8 KiB
Go
package update
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/go-errors/errors"
|
|
"github.com/h2non/gock"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/supabase/cli/internal/testing/apitest"
|
|
"github.com/supabase/cli/internal/utils"
|
|
"github.com/supabase/cli/pkg/api"
|
|
)
|
|
|
|
func TestUpdateRestrictionsCommand(t *testing.T) {
|
|
projectRef := apitest.RandomProjectRef()
|
|
// Setup valid access token
|
|
token := apitest.RandomAccessToken(t)
|
|
t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
|
|
|
|
t.Run("updates v4 and v6 CIDR", func(t *testing.T) {
|
|
// Setup mock api
|
|
defer gock.OffAll()
|
|
gock.New(utils.DefaultApiHost).
|
|
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
|
|
MatchType("json").
|
|
JSON(api.NetworkRestrictionsRequest{
|
|
DbAllowedCidrs: &[]string{"12.3.4.5/32", "1.2.3.1/24"},
|
|
DbAllowedCidrsV6: &[]string{"2001:db8:abcd:0012::0/64"},
|
|
}).
|
|
Reply(http.StatusCreated).
|
|
JSON(api.NetworkRestrictionsResponse{
|
|
Status: api.NetworkRestrictionsResponseStatus("applied"),
|
|
})
|
|
// Run test
|
|
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false)
|
|
// Check error
|
|
assert.NoError(t, err)
|
|
assert.Empty(t, apitest.ListUnmatchedRequests())
|
|
})
|
|
|
|
t.Run("throws error on network failure", func(t *testing.T) {
|
|
errNetwork := errors.New("network error")
|
|
// Setup mock api
|
|
defer gock.OffAll()
|
|
gock.New(utils.DefaultApiHost).
|
|
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
|
|
MatchType("json").
|
|
JSON(api.NetworkRestrictionsRequest{
|
|
DbAllowedCidrs: &[]string{},
|
|
DbAllowedCidrsV6: &[]string{},
|
|
}).
|
|
ReplyError(errNetwork)
|
|
// Run test
|
|
err := Run(context.Background(), projectRef, []string{}, true)
|
|
// Check error
|
|
assert.ErrorIs(t, err, errNetwork)
|
|
assert.Empty(t, apitest.ListUnmatchedRequests())
|
|
})
|
|
|
|
t.Run("throws error on server unavailable", func(t *testing.T) {
|
|
// Setup mock api
|
|
defer gock.OffAll()
|
|
gock.New(utils.DefaultApiHost).
|
|
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
|
|
MatchType("json").
|
|
JSON(api.NetworkRestrictionsRequest{
|
|
DbAllowedCidrs: &[]string{},
|
|
DbAllowedCidrsV6: &[]string{},
|
|
}).
|
|
Reply(http.StatusServiceUnavailable)
|
|
// Run test
|
|
err := Run(context.Background(), projectRef, []string{}, true)
|
|
// Check error
|
|
assert.ErrorContains(t, err, "failed to apply network restrictions:")
|
|
assert.Empty(t, apitest.ListUnmatchedRequests())
|
|
})
|
|
}
|
|
|
|
func TestValidateCIDR(t *testing.T) {
|
|
projectRef := apitest.RandomProjectRef()
|
|
// Setup valid access token
|
|
token := apitest.RandomAccessToken(t)
|
|
t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
|
|
|
|
t.Run("bypasses private subnet checks", func(t *testing.T) {
|
|
// Setup mock api
|
|
defer gock.OffAll()
|
|
gock.New(utils.DefaultApiHost).
|
|
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
|
|
MatchType("json").
|
|
JSON(api.NetworkRestrictionsRequest{
|
|
DbAllowedCidrs: &[]string{"10.0.0.0/8"},
|
|
DbAllowedCidrsV6: &[]string{},
|
|
}).
|
|
Reply(http.StatusCreated).
|
|
JSON(api.NetworkRestrictionsResponse{
|
|
Status: api.NetworkRestrictionsResponseStatus("applied"),
|
|
})
|
|
// Run test
|
|
err := Run(context.Background(), projectRef, []string{"10.0.0.0/8"}, true)
|
|
// Check error
|
|
assert.NoError(t, err)
|
|
assert.Empty(t, apitest.ListUnmatchedRequests())
|
|
})
|
|
|
|
t.Run("throws error on private subnet", func(t *testing.T) {
|
|
// Run test
|
|
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false)
|
|
// Check error
|
|
assert.ErrorContains(t, err, "private IP provided: 10.0.0.0/8")
|
|
})
|
|
|
|
t.Run("throws error on invalid subnet", func(t *testing.T) {
|
|
// Run test
|
|
err := Run(context.Background(), projectRef, []string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false)
|
|
// Check error
|
|
assert.ErrorContains(t, err, "failed to parse IP: 12.3.4.5")
|
|
})
|
|
}
|