package utils import ( "context" "fmt" "io" "net" "net/url" "os" "strings" "time" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/spf13/viper" "github.com/supabase/cli/internal/debug" "github.com/supabase/cli/pkg/pgxv5" ) func ToPostgresURL(config pgconn.Config) string { timeoutSecond := int64(config.ConnectTimeout.Seconds()) if timeoutSecond == 0 { timeoutSecond = 10 } queryParams := fmt.Sprintf("connect_timeout=%d", timeoutSecond) for k, v := range config.RuntimeParams { queryParams += fmt.Sprintf("&%s=%s", k, url.QueryEscape(v)) } // IPv6 address must be wrapped in square brackets host := config.Host if ip := net.ParseIP(host); ip != nil && ip.To4() == nil { host = fmt.Sprintf("[%s]", host) } return fmt.Sprintf( "postgresql://%s@%s:%d/%s?%s", url.UserPassword(config.User, config.Password), host, config.Port, url.PathEscape(config.Database), queryParams, ) } func GetPoolerConfig(projectRef string) *pgconn.Config { logger := GetDebugLogger() if len(Config.Db.Pooler.ConnectionString) == 0 { fmt.Fprintln(logger, "Pooler URL is not configured") return nil } // Remove password from pooler connection string because the placeholder text // [YOUR-PASSWORD] messes up pgconn.ParseConfig. The password must be percent // escaped so we cannot simply call strings.Replace with actual password. poolerUrl := strings.ReplaceAll(Config.Db.Pooler.ConnectionString, "[YOUR-PASSWORD]", "") poolerConfig, err := pgconn.ParseConfig(poolerUrl) if err != nil { fmt.Fprintln(logger, "Failed to parse pooler URL:", poolerUrl) return nil } if poolerConfig.RuntimeParams == nil { poolerConfig.RuntimeParams = make(map[string]string) } // Verify that the pooler username matches the database host being connected to if _, ref, found := strings.Cut(poolerConfig.User, "."); !found { for _, option := range strings.Split(poolerConfig.RuntimeParams["options"], ",") { key, value, found := strings.Cut(option, "=") if found && key == "reference" && value != projectRef { fmt.Fprintln(logger, "Pooler options does not match project ref:", projectRef) return nil } } } else if projectRef != ref { fmt.Fprintln(logger, "Pooler username does not match project ref:", projectRef) return nil } // There is a risk of MITM attack if we simply trust the hostname specified in pooler URL. if !isSupabaseDomain(poolerConfig.Host) { fmt.Fprintln(logger, "Pooler hostname does not belong to Supabase domain:", poolerConfig.Host) return nil } fmt.Fprintln(logger, "Using connection pooler:", poolerUrl) // Supavisor transaction mode does not support prepared statement poolerConfig.Port = 5432 return poolerConfig } func isSupabaseDomain(host string) bool { switch GetSupabaseAPIHost() { case "https://api.supabase.green": return strings.HasSuffix(host, ".supabase.green") default: return strings.HasSuffix(host, ".supabase.com") } } // Connnect to local Postgres with optimised settings. The caller is responsible for closing the connection returned. func ConnectLocalPostgres(ctx context.Context, config pgconn.Config, options ...func(*pgx.ConnConfig)) (*pgx.Conn, error) { if len(config.Host) == 0 { config.Host = Config.Hostname } if config.Port == 0 { config.Port = Config.Db.Port } if len(config.User) == 0 { config.User = "postgres" } if len(config.Password) == 0 { config.Password = Config.Db.Password } if len(config.Database) == 0 { config.Database = "postgres" } if config.ConnectTimeout == 0 { config.ConnectTimeout = 2 * time.Second } return ConnectByUrl(ctx, ToPostgresURL(config), options...) } func ConnectByUrl(ctx context.Context, url string, options ...func(*pgx.ConnConfig)) (*pgx.Conn, error) { if viper.GetBool("DEBUG") { options = append(options, debug.SetupPGX) } return pgxv5.Connect(ctx, url, options...) } func ConnectByConfigStream(ctx context.Context, config pgconn.Config, w io.Writer, options ...func(*pgx.ConnConfig)) (*pgx.Conn, error) { if IsLocalDatabase(config) { fmt.Fprintln(w, "Connecting to local database...") return ConnectLocalPostgres(ctx, config, options...) } fmt.Fprintln(w, "Connecting to remote database...") opts := append(options, func(cc *pgx.ConnConfig) { if DNSResolver.Value == DNS_OVER_HTTPS { cc.LookupFunc = FallbackLookupIP } }) return ConnectByUrl(ctx, ToPostgresURL(config), opts...) } func ConnectByConfig(ctx context.Context, config pgconn.Config, options ...func(*pgx.ConnConfig)) (*pgx.Conn, error) { return ConnectByConfigStream(ctx, config, os.Stderr, options...) } func IsLocalDatabase(config pgconn.Config) bool { return config.Host == Config.Hostname && config.Port == Config.Db.Port }