supabase-cli/pkg/pgtest/mock.go

326 lines
8.9 KiB
Go

package pgtest
import (
"context"
"fmt"
"net"
"reflect"
"testing"
"time"
"github.com/jackc/pgmock"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4"
"github.com/supabase/cli/pkg/pgxv5"
"google.golang.org/grpc/test/bufconn"
)
// MockConn is a scripted, in-memory stand-in for a postgres server used in
// tests. Expected client messages and the canned server responses are queued
// on script, and a pgx client is pointed at the in-memory listener instead of
// a real network socket.
type MockConn struct {
// Client connection created (and cached) by MockClient; nil until then
client *pgx.Conn
// Duplex server listener backed by in-memory buffer
server *bufconn.Listener
// Mock server requests and responses
script pgmock.Script
// Status parameters emitted by postgres on first connect
status map[string]string
// Channel for reporting all server errors. It is closed by the background
// server goroutine when it exits and drained by Close.
errChan chan error
}
// getStartupMessage builds the scripted handshake for a client connecting
// with config. The returned steps expect the client's StartupMessage (user,
// plus database when one is configured), then reply with AuthenticationOk,
// one ParameterStatus per entry in r.status, BackendKeyData and finally
// ReadyForQuery.
//
// As a side effect, r.status["session_authorization"] is set to config.User.
func (r *MockConn) getStartupMessage(config *pgx.ConnConfig) []pgmock.Step {
	// Parameters the client is expected to send in its startup message
	params := map[string]string{"user": config.User}
	if len(config.Database) > 0 {
		params["database"] = config.Database
	}
	// NewWithStatus accepts a nil map; allocate lazily because writing to a
	// nil map panics.
	if r.status == nil {
		r.status = make(map[string]string, 1)
	}
	r.status["session_authorization"] = config.User
	// 2 auth steps + one per status parameter + 2 ready steps
	steps := make([]pgmock.Step, 0, len(r.status)+4)
	// Add auth message
	steps = append(
		steps,
		pgmock.ExpectMessage(&pgproto3.StartupMessage{
			ProtocolVersion: pgproto3.ProtocolVersionNumber,
			Parameters:      params,
		}),
		pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
	)
	// Add status message
	for key, value := range r.status {
		steps = append(steps, pgmock.SendMessage(&pgproto3.ParameterStatus{Name: key, Value: value}))
	}
	// Add ready message
	steps = append(
		steps,
		pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
		pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
	)
	return steps
}
// Intercept rewires a pgx connection config so that it talks to this mock
// instead of a real server.
//
// Dialing is redirected to the in-memory listener (a full duplex net.Conn
// provided by the grpc/test/bufconn package), every host name resolves to
// 127.0.0.1 and TLS is turned off. The startup handshake steps are then
// placed in front of whatever steps are already in the script.
func (r *MockConn) Intercept(config *pgx.ConnConfig) {
	// Route every dial through the in-memory listener
	config.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) {
		return r.server.DialContext(ctx)
	}
	// Skip real DNS resolution
	config.LookupFunc = func(_ context.Context, _ string) ([]string, error) {
		return []string{"127.0.0.1"}, nil
	}
	config.TLSConfig = nil
	// The handshake must run before any queued query steps
	startup := r.getStartupMessage(config)
	r.script.Steps = append(startup, r.script.Steps...)
}
// Query appends an expectation for sql to the script. Each argument is
// encoded with encodeValueArg; arguments that fail to encode (reported with
// OID 0) are left out of the expected parameters. The receiver is returned so
// calls can be chained.
func (r *MockConn) Query(sql string, args ...interface{}) *MockConn {
	var (
		paramOIDs   []uint32
		paramValues [][]byte
	)
	for i := range args {
		encoded, oid := r.encodeValueArg(args[i])
		if oid == 0 {
			// Encoding failed; the error has already been pushed to errChan
			continue
		}
		paramValues = append(paramValues, encoded)
		paramOIDs = append(paramOIDs, oid)
	}
	step := ExpectQuery(sql, paramValues, paramOIDs)
	r.script.Steps = append(r.script.Steps, step)
	return r
}
// encodeValueArg converts v to its wire representation and reports the OID
// of the data type that was used.
//
// A nil v yields a nil value tagged with pgtype.TextArrayOID. Values whose
// data type has the text OID are text encoded; every other type is binary
// encoded. On any failure the error is sent to r.errChan and (nil, 0) is
// returned, so callers treat an OID of 0 as "skip this value".
func (r *MockConn) encodeValueArg(v interface{}) (value []byte, oid uint32) {
	if v == nil {
		return nil, pgtype.TextArrayOID
	}
	dt, found := ci.DataTypeForValue(v)
	if !found {
		r.errChan <- fmt.Errorf("no suitable type for arg: %v", v)
		return nil, 0
	}
	if setErr := dt.Value.Set(v); setErr != nil {
		r.errChan <- fmt.Errorf("failed to set value: %w", setErr)
		return nil, 0
	}
	var (
		encoded []byte
		encErr  error
	)
	if dt.OID == pgtype.TextOID {
		encoded, encErr = dt.Value.(pgtype.TextEncoder).EncodeText(ci, []byte{})
	} else {
		encoded, encErr = dt.Value.(pgtype.BinaryEncoder).EncodeBinary(ci, []byte{})
	}
	if encErr != nil {
		r.errChan <- fmt.Errorf("failed to encode arg: %w", encErr)
		return nil, 0
	}
	return encoded, dt.OID
}
// getDataTypeSize returns the in-memory size in bytes of v when its kind is
// one of the fixed-width numeric kinds (reflect.Int through
// reflect.Complex128). For every other kind, and for a nil v, it returns -1.
func getDataTypeSize(v interface{}) int16 {
	t := reflect.TypeOf(v)
	// reflect.TypeOf(nil) yields a nil Type; calling Kind on it would panic.
	if t == nil {
		return -1
	}
	if k := t.Kind(); k < reflect.Int || k > reflect.Complex128 {
		return -1
	}
	return int16(t.Size())
}
// lastQuery returns the most recently scripted step as an extended query
// step. It panics when the script is empty or when the last step is of a
// different type, so it must only be called after a query has been added.
func (r *MockConn) lastQuery() *extendedQueryStep {
	steps := r.script.Steps
	last := steps[len(steps)-1]
	return last.(*extendedQueryStep)
}
// Reply adds a server reply, using binary or text protocol format, to the
// most recently scripted query.
//
// Each element of rows is either a []interface{} holding the column values in
// order, or a struct whose fields are mapped to columns by
// pgxv5.GetColumnName (fields with an empty column name are skipped). The
// row description is derived from the first row only. With no rows a NoData
// message is sent instead. The reply ends with CommandComplete carrying tag,
// or EmptyQueryResponse when tag is empty.
//
// Problems with a row (unsupported type, value that cannot be described or
// encoded) are reported on r.errChan rather than returned. The receiver is
// returned so calls can be chained.
//
// TODO: support prepared statements when using binary protocol
func (r *MockConn) Reply(tag string, rows ...interface{}) *MockConn {
	q := r.lastQuery()
	// Add field description
	if len(rows) > 0 {
		var desc pgproto3.RowDescription
		first := rows[0]
		// TypeOf returns a nil Type for a nil row, so guard before calling Kind
		t := reflect.TypeOf(first)
		if arr, ok := first.([]interface{}); ok {
			for i, v := range arr {
				name := fmt.Sprintf("c_%02d", i)
				if fd := toFieldDescription(v); fd != nil {
					fd.Name = []byte(name)
					desc.Fields = append(desc.Fields, *fd)
				} else {
					r.errChan <- fmt.Errorf("failed to describe field: %s", name)
				}
			}
		} else if t != nil && t.Kind() == reflect.Struct {
			s := reflect.ValueOf(first)
			for i := 0; i < s.NumField(); i++ {
				name := pgxv5.GetColumnName(t.Field(i))
				if len(name) == 0 {
					continue
				}
				v := s.Field(i).Interface()
				if fd := toFieldDescription(v); fd != nil {
					fd.Name = []byte(name)
					desc.Fields = append(desc.Fields, *fd)
				} else {
					r.errChan <- fmt.Errorf("failed to describe field: %s", name)
				}
			}
		} else {
			kind := reflect.Invalid
			if t != nil {
				kind = t.Kind()
			}
			r.errChan <- fmt.Errorf("reply type must be one of [array, struct], found: %s", kind)
		}
		q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&desc))
	} else {
		// No data is optional, but we add for completeness
		q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&pgproto3.NoData{}))
	}
	// Add row data
	for _, data := range rows {
		var dr pgproto3.DataRow
		t := reflect.TypeOf(data)
		if arr, ok := data.([]interface{}); ok {
			for _, v := range arr {
				if value, oid := r.encodeValueArg(v); oid > 0 {
					dr.Values = append(dr.Values, value)
				}
			}
		} else if t != nil && t.Kind() == reflect.Struct {
			// Reflect over the current row, not rows[0], so that each struct
			// contributes its own field values.
			s := reflect.ValueOf(data)
			for i := 0; i < s.NumField(); i++ {
				if name := pgxv5.GetColumnName(t.Field(i)); len(name) == 0 {
					continue
				}
				v := s.Field(i).Interface()
				if value, oid := r.encodeValueArg(v); oid > 0 {
					dr.Values = append(dr.Values, value)
				}
			}
		} else {
			r.errChan <- fmt.Errorf("invalid reply value: %v", data)
		}
		q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&dr))
	}
	// Add completion message
	var complete pgproto3.BackendMessage
	if tag == "" {
		complete = &pgproto3.EmptyQueryResponse{}
	} else {
		complete = &pgproto3.CommandComplete{CommandTag: []byte(tag)}
	}
	q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(complete))
	return r
}
// toFieldDescription builds a row description entry for the column that
// would hold v. The data type OID and format code come from the shared
// connection info, and the size from getDataTypeSize. The table OID, attribute
// number and type modifier are fixed placeholder values. The caller is
// expected to fill in Name. It returns nil when no data type is registered
// for v.
func toFieldDescription(v interface{}) *pgproto3.FieldDescription {
	dt, found := ci.DataTypeForValue(v)
	if !found {
		return nil
	}
	fd := pgproto3.FieldDescription{
		TableOID:             17131,
		TableAttributeNumber: 1,
		DataTypeOID:          dt.OID,
		DataTypeSize:         getDataTypeSize(v),
		TypeModifier:         -1,
		Format:               ci.ParamFormatCodeForOID(dt.OID),
	}
	return &fd
}
// ReplyError makes the most recently scripted query answer with an
// ErrorResponse of severity ERROR carrying the given SQLSTATE code and
// message. The receiver is returned so calls can be chained.
//
// TODO: simulate a notice reply
func (r *MockConn) ReplyError(code, message string) *MockConn {
	query := r.lastQuery()
	failure := &pgproto3.ErrorResponse{
		Severity:            "ERROR",
		SeverityUnlocalized: "ERROR",
		Code:                code,
		Message:             message,
	}
	query.reply.Steps = append(query.reply.Steps, pgmock.SendMessage(failure))
	return r
}
// Close shuts the mock down and reports everything that went wrong to t.
//
// The cached client (if any) is closed first. Then every error queued on
// errChan is reported. This loop only ends once the background server
// goroutine has closed the channel. Finally the in-memory listener is closed;
// a failure there aborts the test.
func (r *MockConn) Close(t *testing.T) {
	if client := r.client; client != nil {
		if err := client.Close(context.Background()); err != nil {
			t.Errorf("failed to close client: %v", err)
		}
	}
	// Drain until the server goroutine closes the channel
	for {
		err, open := <-r.errChan
		if !open {
			break
		}
		t.Errorf("pgmock error:\n%v", err)
	}
	if err := r.server.Close(); err != nil {
		t.Fatalf("failed to close server: %v", err)
	}
}
// MockClient returns a pgx connection wired to this mock, creating it on the
// first call and returning the cached connection afterwards.
//
// The caller's opts are applied first, followed by Intercept and a two second
// connect timeout. A connection failure is reported on t; the result of the
// connect call is stored and returned either way.
func (r *MockConn) MockClient(t *testing.T, opts ...func(*pgx.ConnConfig)) *pgx.Conn {
	if r.client == nil {
		withTimeout := func(cc *pgx.ConnConfig) {
			cc.ConnectTimeout = time.Second * 2
		}
		opts = append(opts, r.Intercept, withTimeout)
		conn, err := pgxv5.Connect(context.Background(), "", opts...)
		if err != nil {
			t.Errorf("failed to connect: %v", err)
		}
		r.client = conn
	}
	return r.client
}
// NewWithStatus creates a mock connection that reports the given status
// parameters during the startup handshake, and starts its server goroutine.
//
// The goroutine accepts a single connection on the in-memory listener, gives
// it a 450ms deadline so a stuck exchange cannot hang the test, appends a
// final "expect terminate" step and then plays the script. Any failure is
// sent to errChan, which the goroutine closes when it exits.
func NewWithStatus(status map[string]string) *MockConn {
	const (
		bufSize = 1024 * 1024
		timeout = time.Millisecond * 450
	)
	mock := &MockConn{
		server:  bufconn.Listen(bufSize),
		status:  status,
		errChan: make(chan error, 10),
	}
	// Start server in background
	go func() {
		defer close(mock.errChan)
		// Block until a client dials the in-memory listener
		conn, acceptErr := mock.server.Accept()
		if acceptErr != nil {
			mock.errChan <- acceptErr
			return
		}
		defer conn.Close()
		// Prevent server from hanging the test
		if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
			mock.errChan <- err
			return
		}
		// Always expect clients to terminate the request
		mock.script.Steps = append(mock.script.Steps, ExpectTerminate())
		backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
		if err := mock.script.Run(backend); err != nil {
			mock.errChan <- err
		}
	}()
	return mock
}
// NewConn creates a mock connection that advertises a default set of status
// parameters (a postgres 14.3 server with UTF8 encoding, UTC time zone and a
// superuser session).
func NewConn() *MockConn {
	return NewWithStatus(map[string]string{
		"application_name":              "",
		"client_encoding":               "UTF8",
		"DateStyle":                     "ISO, MDY",
		"default_transaction_read_only": "off",
		"in_hot_standby":                "off",
		"integer_datetimes":             "on",
		"IntervalStyle":                 "postgres",
		"is_superuser":                  "on",
		"server_encoding":               "UTF8",
		"server_version":                "14.3 (Debian 14.3-1.pgdg110+1)",
		"standard_conforming_strings":   "on",
		"TimeZone":                      "UTC",
	})
}