supabase-cli/internal/debug/postgres.go

166 lines
3.2 KiB
Go

package debug
import (
"context"
"encoding/json"
"errors"
"io"
"log"
"net"
"os"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgx/v4"
"google.golang.org/grpc/test/bufconn"
)
// Proxy sits between a Postgres client and the real server so that the
// protocol messages travelling in both directions can be logged.
type Proxy struct {
// dialContext opens the connection to the real server.
dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// errChan is handed to the forwarding goroutines, which report the error
// that stopped them on it.
errChan chan error
}
// NewProxy returns a Proxy that reaches the server through a zero-value
// net.Dialer and reports forwarding errors on a channel with room for one
// pending error.
func NewProxy() Proxy {
	var proxy Proxy
	proxy.errChan = make(chan error, 1)
	proxy.dialContext = new(net.Dialer).DialContext
	return proxy
}
// SetupPGX reroutes the connections opened with config through a logging
// Proxy. The dial function already present on config is kept as the way to
// reach the server, and config.TLSConfig is cleared (presumably because the
// proxy has to read the protocol messages in plain text).
func SetupPGX(config *pgx.ConnConfig) {
	config.TLSConfig = nil

	proxy := &Proxy{
		errChan:     make(chan error, 1),
		dialContext: config.DialFunc,
	}
	config.DialFunc = proxy.DialFunc
}
// DialFunc dials the real server at addr and returns an in-memory connection
// for the client to use instead. A background goroutine relays every protocol
// message between the two connections and logs it on the way.
//
// The returned error is either the failure to reach the server or the failure
// to establish the in-memory connection (for example because ctx was
// cancelled). In both cases no goroutine or connection is left behind.
func (p *Proxy) DialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
	serverConn, err := p.dialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}

	const bufSize = 1024 * 1024
	ln := bufconn.Listen(bufSize)

	go func() {
		// Established bufconn connections do not depend on the listener, so
		// it can be released as soon as this goroutine is done with it.
		defer ln.Close()
		defer serverConn.Close()

		clientConn, err := ln.Accept()
		if err != nil {
			// The listener was closed because the in-memory dial below
			// failed; there is nothing to proxy.
			return
		}
		defer clientConn.Close()

		// One channel per connection, so that concurrent connections made
		// through the same Proxy never consume each other's errors. It is
		// buffered for both forwarders so that neither blocks on send after
		// this goroutine has returned.
		errChan := make(chan error, 2)
		backend := NewBackend(clientConn)
		frontend := NewFrontend(serverConn)
		go backend.forward(frontend, errChan)
		go frontend.forward(backend, errChan)

		// The session is over as soon as either direction stops. Returning
		// closes both connections, which in turn unblocks the forwarder that
		// is still reading.
		// Since pgx closes connection first, every EOF is seen as unexpected
		if err := <-errChan; err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
			panic(err)
		}
	}()

	clientConn, err := ln.DialContext(ctx)
	if err != nil {
		// Closing the listener wakes up Accept so that the goroutine above
		// exits and closes serverConn instead of leaking it.
		ln.Close()
		return nil, err
	}
	return clientConn, nil
}
// Backend couples a pgproto3.Backend with the logger used to print the
// messages that are relayed to it.
type Backend struct {
*pgproto3.Backend
logger *log.Logger
}
// NewBackend builds a Backend that reads from and writes to clientConn. Its
// logger writes to stderr with the "PG Recv: " prefix.
func NewBackend(clientConn net.Conn) Backend {
	reader := pgproto3.NewChunkReader(clientConn)
	return Backend{
		Backend: pgproto3.NewBackend(reader, clientConn),
		logger:  log.New(os.Stderr, "PG Recv: ", log.LstdFlags|log.Lmsgprefix),
	}
}
// forward relays messages received by b to frontend, starting with the
// startup message, and logs each one as JSON through frontend's logger before
// sending it. It runs until receiving, encoding or sending fails; that error
// is delivered on errChan and the method returns.
func (b *Backend) forward(frontend Frontend, errChan chan error) {
	// relay logs a single message and passes it on to frontend.
	relay := func(msg pgproto3.FrontendMessage) error {
		encoded, err := json.Marshal(msg)
		if err != nil {
			return err
		}
		frontend.logger.Println(string(encoded))
		return frontend.Send(msg)
	}

	// The first message is read with ReceiveStartupMessage; all later ones
	// are read with Receive.
	msg, err := b.ReceiveStartupMessage()
	for err == nil {
		if err = relay(msg); err != nil {
			break
		}
		msg, err = b.Receive()
	}
	errChan <- err
}
// Frontend couples a pgproto3.Frontend with the logger used to print the
// messages that are relayed to it.
type Frontend struct {
*pgproto3.Frontend
logger *log.Logger
}
// NewFrontend builds a Frontend that reads from and writes to serverConn. Its
// logger writes to stderr with the "PG Send: " prefix.
func NewFrontend(serverConn net.Conn) Frontend {
	reader := pgproto3.NewChunkReader(serverConn)
	return Frontend{
		Frontend: pgproto3.NewFrontend(reader, serverConn),
		logger:   log.New(os.Stderr, "PG Send: ", log.LstdFlags|log.Lmsgprefix),
	}
}
// forward relays messages received by f to backend and logs each one as JSON
// through backend's logger before sending it. It runs until receiving,
// encoding or sending fails; that error is delivered on errChan and the
// method returns.
func (f *Frontend) forward(backend Backend, errChan chan error) {
	// relayNext handles exactly one message and reports what went wrong, if
	// anything.
	relayNext := func() error {
		msg, err := f.Receive()
		if err != nil {
			return err
		}
		encoded, err := json.Marshal(msg)
		if err != nil {
			return err
		}
		backend.logger.Println(string(encoded))

		if _, isAuth := msg.(pgproto3.AuthenticationResponseMessage); isAuth {
			// Set the authentication type so the next backend.Receive() will
			// properly decode the appropriate 'p' message.
			if err := backend.SetAuthType(f.GetAuthType()); err != nil {
				return err
			}
		}
		return backend.Send(msg)
	}

	for {
		if err := relayNext(); err != nil {
			errChan <- err
			return
		}
	}
}