supabase-cli/internal/utils/api.go

261 lines
7.4 KiB
Go

package utils
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"net/http/httptrace"
"net/textproto"
"sync"
"github.com/go-errors/errors"
"github.com/spf13/viper"
"github.com/supabase/cli/internal/utils/cloudflare"
supabase "github.com/supabase/cli/pkg/api"
"github.com/supabase/cli/pkg/cast"
)
// Values accepted by the DNSResolver flag.
const (
	// DNS_GO_NATIVE dials through the transport's own resolver and only
	// retries via DNS-over-HTTPS when that dial times out.
	DNS_GO_NATIVE = "native"
	// DNS_OVER_HTTPS resolves every host through Cloudflare's
	// DNS-over-HTTPS API before dialing.
	DNS_OVER_HTTPS = "https"
)
var (
clientOnce sync.Once
apiClient *supabase.ClientWithResponses
DNSResolver = EnumFlag{
Allowed: []string{DNS_GO_NATIVE, DNS_OVER_HTTPS},
Value: DNS_GO_NATIVE,
}
)
// FallbackLookupIP resolves host to its IP addresses using Cloudflare's
// DNS-over-HTTPS JSON API, for cases where a firewall blocks the native netgo
// resolver.
//
// If host is already an IP literal it is returned unchanged and no request is
// made. Otherwise every A and AAAA answer is returned, in answer order. An
// error is returned when the query fails or yields no A/AAAA record, so on
// success the returned slice is never empty.
func FallbackLookupIP(ctx context.Context, host string) ([]string, error) {
	if net.ParseIP(host) != nil {
		return []string{host}, nil
	}
	// Ref: https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https/make-api-requests/dns-json
	cf := cloudflare.NewCloudflareAPI()
	data, err := cf.DNSQuery(ctx, cloudflare.DNSParams{Name: host})
	if err != nil {
		return nil, err
	}
	// Collect every address record (not just the first); any other record
	// types in the answer section are skipped.
	resolved := make([]string, 0, len(data.Answer))
	for _, answer := range data.Answer {
		if answer.Type == cloudflare.TypeA || answer.Type == cloudflare.TypeAAAA {
			resolved = append(resolved, answer.Data)
		}
	}
	if len(resolved) == 0 {
		return nil, errors.Errorf("failed to locate valid IP for %s; resolves to %#v", host, data.Answer)
	}
	return resolved, nil
}
// ResolveCNAME queries Cloudflare's DNS-over-HTTPS JSON API for the CNAME
// record of host and returns the data of the first CNAME answer.
//
// An error is returned when the query fails, or when the response contains no
// CNAME answer. In the latter case the message includes the answers that were
// returned, rendered as indented JSON when serialization succeeds.
func ResolveCNAME(ctx context.Context, host string) (string, error) {
	// Ref: https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https/make-api-requests/dns-json
	cf := cloudflare.NewCloudflareAPI()
	data, err := cf.DNSQuery(ctx, cloudflare.DNSParams{Name: host, Type: cast.Ptr(cloudflare.TypeCNAME)})
	if err != nil {
		return "", err
	}
	// Return the first CNAME answer; answers of other types are ignored.
	for _, answer := range data.Answer {
		if answer.Type == cloudflare.TypeCNAME {
			return answer.Data, nil
		}
	}
	serialized, err := json.MarshalIndent(data.Answer, "", " ")
	if err != nil {
		// Serialization only feeds the error message, so fall back to the raw struct.
		return "", errors.Errorf("failed to locate appropriate CNAME record for %s; resolves to %+v", host, data.Answer)
	}
	// serialized is a []byte: %+v would print it as a list of byte values,
	// so format it with %s to show the JSON text.
	return "", errors.Errorf("failed to locate appropriate CNAME record for %s; resolves to %s", host, serialized)
}
// WithTraceContext returns a child of ctx carrying an httptrace.ClientTrace.
// Any HTTP request made with the returned context logs each stage of its
// connection lifecycle through the standard log package: DNS lookup, connect,
// TLS handshake, every header field written, request completion, 1xx
// responses and the first response byte. A stage that fails is logged with an
// "Error" label instead of its "Done" message.
func WithTraceContext(ctx context.Context) context.Context {
	var trace httptrace.ClientTrace

	// Name resolution.
	trace.DNSStart = func(info httptrace.DNSStartInfo) {
		log.Printf("DNS Start: %+v\n", info)
	}
	trace.DNSDone = func(info httptrace.DNSDoneInfo) {
		if info.Err != nil {
			log.Println("DNS Error:", info.Err)
			return
		}
		log.Printf("DNS Done: %+v\n", info)
	}

	// Connection establishment.
	trace.ConnectStart = func(network, addr string) {
		log.Println("Connect Start:", network, addr)
	}
	trace.ConnectDone = func(network, addr string, err error) {
		if err != nil {
			log.Println("Connect Error:", network, addr, err)
			return
		}
		log.Println("Connect Done:", network, addr)
	}
	trace.TLSHandshakeStart = func() {
		log.Println("TLS Start")
	}
	trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) {
		if err != nil {
			log.Println("TLS Error:", err)
			return
		}
		log.Printf("TLS Done: %+v\n", state)
	}

	// Request write and response arrival.
	trace.WroteHeaderField = func(key string, value []string) {
		log.Println("Sent Header:", key, value)
	}
	trace.WroteRequest = func(info httptrace.WroteRequestInfo) {
		if info.Err != nil {
			log.Println("Send Error:", info.Err)
			return
		}
		log.Println("Send Done")
	}
	trace.Got1xxResponse = func(code int, header textproto.MIMEHeader) error {
		log.Println("Recv 1xx:", code, header)
		return nil
	}
	trace.GotFirstResponseByte = func() {
		log.Println("Recv First Byte")
	}

	return httptrace.WithClientTrace(ctx, &trace)
}
// DialContextFunc has the same signature as net.Dialer.DialContext and the
// http.Transport.DialContext field, so transport dialers can be wrapped with it.
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
// withFallbackDNS wraps dialContext so that host names can be resolved through
// DNS-over-HTTPS via FallbackLookupIP.
//
// DNSResolver.Value is read once, when the wrapper is built:
//   - DNS_OVER_HTTPS: every dial resolves the host with FallbackLookupIP and
//     dials the resolved addresses in order until one connects.
//   - otherwise: the original dialContext is used, and DNS-over-HTTPS is only
//     attempted when that dial fails with a timeout.
func withFallbackDNS(dialContext DialContextFunc) DialContextFunc {
	dnsOverHttps := func(ctx context.Context, network, address string) (net.Conn, error) {
		host, port, err := net.SplitHostPort(address)
		if err != nil {
			return nil, errors.Errorf("failed to split host port: %w", err)
		}
		ips, err := FallbackLookupIP(ctx, host)
		if err != nil {
			return nil, err
		}
		// Try each resolved address instead of only the first, so a single
		// unreachable record does not fail the whole dial. FallbackLookupIP
		// never returns an empty slice with a nil error, so lastErr is always
		// set when the loop finishes without connecting.
		var lastErr error
		for _, ip := range ips {
			conn, dialErr := dialContext(ctx, network, net.JoinHostPort(ip, port))
			if dialErr == nil {
				return conn, nil
			}
			lastErr = dialErr
			if ctx.Err() != nil {
				// The caller's context is done; further attempts are pointless.
				break
			}
		}
		return nil, errors.Errorf("failed to dial fallback: %w", lastErr)
	}
	if DNSResolver.Value == DNS_OVER_HTTPS {
		return dnsOverHttps
	}
	nativeWithFallback := func(ctx context.Context, network, address string) (net.Conn, error) {
		conn, err := dialContext(ctx, network, address)
		if err == nil {
			return conn, nil
		}
		// Workaround when pure Go DNS resolver fails https://github.com/golang/go/issues/12524
		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
			fallbackConn, fallbackErr := dnsOverHttps(ctx, network, address)
			if fallbackErr == nil {
				return fallbackConn, nil
			}
			// Keep the native error as the wrapped cause, but also report why
			// the fallback failed instead of silently discarding it.
			return nil, errors.Errorf("failed to dial native: %w (DNS-over-HTTPS fallback: %v)", err, fallbackErr)
		}
		return nil, errors.Errorf("failed to dial native: %w", err)
	}
	return nativeWithFallback
}
// GetSupabase returns the shared API client, building it on the first call.
//
// The first call loads the access token, wraps http.DefaultTransport's
// DialContext with withFallbackDNS (when the default transport is an
// *http.Transport), and creates a client for GetSupabaseAPIHost() that sets
// the Authorization and User-Agent headers on every request. Any failure
// during that initialization terminates the process via log.Fatalln.
func GetSupabase() *supabase.ClientWithResponses {
	initClient := func() {
		token, err := LoadAccessToken()
		if err != nil {
			log.Fatalln(err)
		}
		if transport, ok := http.DefaultTransport.(*http.Transport); ok {
			transport.DialContext = withFallbackDNS(transport.DialContext)
		}
		addHeaders := func(_ context.Context, req *http.Request) error {
			req.Header.Set("Authorization", "Bearer "+token)
			req.Header.Set("User-Agent", "SupabaseCLI/"+Version)
			return nil
		}
		client, err := supabase.NewClientWithResponses(
			GetSupabaseAPIHost(),
			supabase.WithRequestEditorFn(addHeaders),
		)
		if err != nil {
			log.Fatalln(err)
		}
		apiClient = client
	}
	clientOnce.Do(initClient)
	return apiClient
}
// Base URLs of the hosted Supabase API.
const (
	// DefaultApiHost is the API URL used when INTERNAL_API_HOST is empty.
	DefaultApiHost = "https://api.supabase.com"
	// DeprecatedApiHost is the former API URL. It is still treated the same
	// as DefaultApiHost when deriving dashboard and project hostnames.
	//
	// Deprecated: use DefaultApiHost instead.
	DeprecatedApiHost = "https://api.supabase.io"
)
// RegionMap maps each region identifier to a human-readable description,
// e.g. "us-east-1" -> "East US (North Virginia)".
var RegionMap = map[string]string{
	// Asia-Pacific
	"ap-northeast-1": "Northeast Asia (Tokyo)",
	"ap-northeast-2": "Northeast Asia (Seoul)",
	"ap-south-1":     "South Asia (Mumbai)",
	"ap-southeast-1": "Southeast Asia (Singapore)",
	"ap-southeast-2": "Oceania (Sydney)",
	// Americas
	"ca-central-1": "Canada (Central)",
	"sa-east-1":    "South America (São Paulo)",
	"us-east-1":    "East US (North Virginia)",
	"us-west-1":    "West US (North California)",
	"us-west-2":    "West US (Oregon)",
	// Europe
	"eu-central-1": "Central EU (Frankfurt)",
	"eu-west-1":    "West EU (Ireland)",
	"eu-west-2":    "West EU (London)",
	"eu-west-3":    "West EU (Paris)",
}
// GetSupabaseAPIHost returns the API base URL: the value of the
// INTERNAL_API_HOST viper key when it is non-empty, DefaultApiHost otherwise.
func GetSupabaseAPIHost() string {
	if override := viper.GetString("INTERNAL_API_HOST"); override != "" {
		return override
	}
	return DefaultApiHost
}
// GetSupabaseDashboardURL returns the dashboard URL that pairs with the
// configured API host:
//   - DefaultApiHost / DeprecatedApiHost -> https://supabase.com/dashboard
//   - https://api.supabase.green         -> https://supabase.green/dashboard
//   - anything else                      -> http://127.0.0.1:8082
func GetSupabaseDashboardURL() string {
	apiHost := GetSupabaseAPIHost()
	if apiHost == DefaultApiHost || apiHost == DeprecatedApiHost {
		return "https://supabase.com/dashboard"
	}
	if apiHost == "https://api.supabase.green" {
		return "https://supabase.green/dashboard"
	}
	// Unrecognized API hosts fall back to a loopback address (presumably a
	// locally running dashboard).
	return "http://127.0.0.1:8082"
}
// GetSupabaseDbHost returns the database hostname for projectRef:
// db.<ref>.supabase.co for DefaultApiHost / DeprecatedApiHost, and
// db.<ref>.supabase.red for every other API host (including
// https://api.supabase.green).
func GetSupabaseDbHost(projectRef string) string {
	// TODO: query projects api for db_host
	domain := "supabase.red"
	if apiHost := GetSupabaseAPIHost(); apiHost == DefaultApiHost || apiHost == DeprecatedApiHost {
		domain = "supabase.co"
	}
	return fmt.Sprintf("db.%s.%s", projectRef, domain)
}
// GetSupabaseHost returns the project hostname for projectRef:
// <ref>.supabase.co for DefaultApiHost / DeprecatedApiHost, and
// <ref>.supabase.red for every other API host (including
// https://api.supabase.green).
func GetSupabaseHost(projectRef string) string {
	domain := "supabase.red"
	if apiHost := GetSupabaseAPIHost(); apiHost == DefaultApiHost || apiHost == DeprecatedApiHost {
		domain = "supabase.co"
	}
	return fmt.Sprintf("%s.%s", projectRef, domain)
}