package utils

import (
	"context"
	"errors"
	"net"
	"net/http"
	"testing"

	"github.com/h2non/gock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/supabase/cli/internal/testing/apitest"
	"github.com/supabase/cli/internal/utils/cloudflare"
)

// host is the domain name queried against the mocked DNS-over-HTTPS endpoint.
const host = "api.supabase.io"

// TestLookupIP exercises FallbackLookupIP against a mocked CloudFlare
// DNS-over-HTTPS endpoint (https://1.1.1.1/dns-query).
func TestLookupIP(t *testing.T) {
	t.Run("resolves IPv4 with CloudFlare", func(t *testing.T) {
		// Setup http mock
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusOK).
			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
				{Type: cloudflare.TypeA, Data: "127.0.0.1"},
			}})
		// Run test
		ip, err := FallbackLookupIP(context.Background(), host)
		// Validate output
		assert.NoError(t, err)
		assert.ElementsMatch(t, []string{"127.0.0.1"}, ip)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("resolves IPv6 recursively", func(t *testing.T) {
		// Setup http mock: the answer contains a CNAME followed by an AAAA
		// record; only the AAAA address is expected in the result.
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", "api.supabase.com").
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusOK).
			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
				{Type: cloudflare.TypeCNAME, Data: "supabase-api.fly.dev."},
				{Type: cloudflare.TypeAAAA, Data: "2606:2800:220:1:248:1893:25c8:1946"},
			}})
		// Run test
		ip, err := FallbackLookupIP(context.Background(), "api.supabase.com")
		// Validate output
		assert.NoError(t, err)
		assert.ElementsMatch(t, []string{"2606:2800:220:1:248:1893:25c8:1946"}, ip)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("returns immediately if already resolved", func(t *testing.T) {
		// No http mock: an IP literal should not trigger a DNS query.
		// Run test
		ip, err := FallbackLookupIP(context.Background(), "127.0.0.1")
		// Validate output
		assert.NoError(t, err)
		assert.ElementsMatch(t, []string{"127.0.0.1"}, ip)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("empty on network failure", func(t *testing.T) {
		// Setup http mock
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchHeader("accept", "application/dns-json").
			ReplyError(errors.New("network error"))
		// Run test
		ip, err := FallbackLookupIP(context.Background(), host)
		// Validate output
		assert.ErrorContains(t, err, "network error")
		assert.Empty(t, ip)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("empty on service unavailable", func(t *testing.T) {
		// Setup http mock
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusServiceUnavailable)
		// Run test
		ip, err := FallbackLookupIP(context.Background(), host)
		// Validate output
		assert.ErrorContains(t, err, "status 503")
		assert.Empty(t, ip)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("empty on malformed json", func(t *testing.T) {
		// Setup http mock
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusOK).
			JSON("malformed")
		// Run test
		ip, err := FallbackLookupIP(context.Background(), host)
		// Validate output
		assert.ErrorContains(t, err, "invalid character 'm' looking for beginning of value")
		assert.Empty(t, ip)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("empty on no answer", func(t *testing.T) {
		// Setup http mock
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusOK).
			JSON(&cloudflare.DNSResponse{})
		// Run test
		ip, err := FallbackLookupIP(context.Background(), host)
		// Validate output
		assert.ErrorContains(t, err, "failed to locate valid IP for api.supabase.io; resolves to []cloudflare.DNSAnswer(nil)")
		assert.Empty(t, ip)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})
}

// TestResolveCNAME exercises ResolveCNAME, which queries the mocked
// DNS-over-HTTPS endpoint with type=5 (CNAME).
func TestResolveCNAME(t *testing.T) {
	t.Run("resolves CNAMEs with CloudFlare", func(t *testing.T) {
		// Setup http mock
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchParam("type", "5").
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusOK).
			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
				{Type: cloudflare.TypeCNAME, Data: "foobarbaz.supabase.co"},
			}})
		// Run test
		cname, err := ResolveCNAME(context.Background(), host)
		// Validate output
		assert.NoError(t, err)
		assert.Equal(t, "foobarbaz.supabase.co", cname)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("missing CNAMEs return an error", func(t *testing.T) {
		// Setup http mock: the answer section is empty.
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchParam("type", "5").
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusOK).
			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{}})
		// Run test
		cname, err := ResolveCNAME(context.Background(), host)
		// Validate output
		assert.Empty(t, cname)
		assert.ErrorContains(t, err, "failed to locate appropriate CNAME record for api.supabase.io")
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	// Distinct name from the case above: duplicate subtest names are silently
	// suffixed with "#01" by the testing package, making -run filters ambiguous.
	t.Run("non-CNAME answers return an error", func(t *testing.T) {
		// Setup http mock: the answer only contains an A record.
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchParam("type", "5").
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusOK).
			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
				{Type: cloudflare.TypeA, Data: "127.0.0.1"},
			}})
		// Run test
		cname, err := ResolveCNAME(context.Background(), host)
		// Validate output
		assert.Empty(t, cname)
		assert.ErrorContains(t, err, "failed to locate appropriate CNAME record for api.supabase.io")
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})
}

// MockDialer is a testify mock whose DialContext method can stand in for a
// dial function passed to withFallbackDNS.
type MockDialer struct {
	mock.Mock
}

// DialContext records the call and returns the configured values. A nil first
// return value is passed back as an untyped nil net.Conn, so callers never see
// a non-nil interface wrapping a nil pointer.
func (m *MockDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	args := m.Called(ctx, network, address)
	if conn, ok := args.Get(0).(net.Conn); ok {
		return conn, args.Error(1)
	}
	return nil, args.Error(1)
}

// TestFallbackDNS exercises withFallbackDNS under the DoH and native
// resolver settings selected via the package-level DNSResolver.
func TestFallbackDNS(t *testing.T) {
	// Subtests mutate the package-level DNSResolver; restore the original
	// value afterwards so other tests in this package are not affected.
	original := DNSResolver.Value
	t.Cleanup(func() {
		DNSResolver.Value = original
	})

	errNetwork := errors.New("network error")
	errDNS := &net.DNSError{
		IsTimeout: true,
	}

	t.Run("overrides DialContext with DoH", func(t *testing.T) {
		DNSResolver.Value = DNS_OVER_HTTPS
		// Setup mock dialer: only the resolved IP address should be dialed.
		dialer := MockDialer{}
		dialer.On("DialContext", mock.Anything, mock.Anything, "127.0.0.1:80").
			Return(nil, errNetwork)
		wrapped := withFallbackDNS(dialer.DialContext)
		// Setup http mock
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusOK).
			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
				{Type: cloudflare.TypeA, Data: "127.0.0.1"},
			}})
		// Run test
		conn, err := wrapped(context.Background(), "udp", host+":80")
		// Check error
		assert.ErrorIs(t, err, errNetwork)
		assert.Nil(t, conn)
		dialer.AssertExpectations(t)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("native with DoH fallback", func(t *testing.T) {
		DNSResolver.Value = DNS_GO_NATIVE
		// Setup mock dialer: the hostname dial fails with a DNS error, which
		// should trigger a DoH lookup and a second dial to the resolved IP.
		dialer := MockDialer{}
		dialer.On("DialContext", mock.Anything, mock.Anything, host+":80").
			Return(nil, errDNS)
		dialer.On("DialContext", mock.Anything, mock.Anything, "127.0.0.1:80").
			Return(nil, nil)
		wrapped := withFallbackDNS(dialer.DialContext)
		// Setup http mock
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchHeader("accept", "application/dns-json").
			Reply(http.StatusOK).
			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
				{Type: cloudflare.TypeA, Data: "127.0.0.1"},
			}})
		// Run test
		conn, err := wrapped(context.Background(), "udp", host+":80")
		// Check error
		assert.NoError(t, err)
		assert.Nil(t, conn)
		dialer.AssertExpectations(t)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("throws error on malformed address", func(t *testing.T) {
		DNSResolver.Value = DNS_OVER_HTTPS
		// Setup mock dialer with no expectations: it must not be dialed.
		dialer := MockDialer{}
		wrapped := withFallbackDNS(dialer.DialContext)
		// Run test
		conn, err := wrapped(context.Background(), "udp", "bad?url")
		// Check error
		assert.ErrorContains(t, err, "missing port in address")
		assert.Nil(t, conn)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})

	t.Run("throws error on fallback failure", func(t *testing.T) {
		DNSResolver.Value = DNS_GO_NATIVE
		// Setup mock dialer
		dialer := MockDialer{}
		dialer.On("DialContext", mock.Anything, mock.Anything, host+":80").
			Return(nil, errDNS)
		wrapped := withFallbackDNS(dialer.DialContext)
		// Setup http mock: the DoH fallback itself fails.
		defer gock.OffAll()
		gock.New("https://1.1.1.1").
			Get("/dns-query").
			MatchParam("name", host).
			MatchHeader("accept", "application/dns-json").
			ReplyError(errNetwork)
		// Run test
		conn, err := wrapped(context.Background(), "udp", host+":80")
		// Check error: the original DNS error is surfaced to the caller.
		assert.ErrorIs(t, err, errDNS)
		assert.Nil(t, conn)
		dialer.AssertExpectations(t)
		assert.Empty(t, apitest.ListUnmatchedRequests())
	})
}