supabase-cli/pkg/pgtest/step.go

133 lines
3.5 KiB
Go

package pgtest
import (
"reflect"
"github.com/go-errors/errors"
"github.com/jackc/pgmock"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgtype"
)
// ci is used to derive the expected Bind parameter format code for each
// parameter OID.
var (
	ci = pgtype.NewConnInfo()
)

// extendedQueryStep is a pgmock.Step that expects the client to run sql and
// then plays back reply.
type extendedQueryStep struct {
	sql    string        // query text the client must send
	params [][]byte      // expected Bind parameter values
	oids   []uint32      // parameter OIDs sent in ParameterDescription and used to pick Bind format codes
	reply  pgmock.Script // steps run once the query has been accepted
}
// Step reads the next frontend message and checks that it runs e.sql.
// It accepts:
//
//   - Parse of a named statement, followed by Describe and Sync, then Bind;
//   - Parse of an unnamed statement, followed directly by Bind;
//   - Bind on its own (its statement text is not checked here);
//   - a simple Query whose text equals e.sql.
//
// A Bind must carry e.params with format codes derived from e.oids. It must
// be followed by Describe (portal) and Execute. The step then sends
// ParseComplete, BindComplete and the configured reply. A simple Query gets
// the reply followed by ReadyForQuery.
//
// The reply script is copied before protocol messages are added, so e is not
// mutated. Running the same step again therefore produces the same exchange.
func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error {
	msg, err := getFrontendMessage(backend)
	if err != nil {
		return err
	}
	// Handle prepared statements, name can be dynamic: lrupsc_5_0
	if m, ok := msg.(*pgproto3.Parse); ok {
		// Only the query text is checked; name and parameter OIDs are taken
		// from the actual message.
		want := &pgproto3.Parse{Name: m.Name, Query: e.sql, ParameterOIDs: m.ParameterOIDs}
		if !reflect.DeepEqual(m, want) {
			return errors.Errorf("expected => %#v\nactual => %#v", want, m)
		}
		// Named statements are described and synced before binding;
		// anonymous statements fall through straight to Bind.
		if m.Name != "" {
			script := pgmock.Script{Steps: []pgmock.Step{
				pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: m.Name}),
				pgmock.ExpectMessage(&pgproto3.Sync{}),
				pgmock.SendMessage(&pgproto3.ParseComplete{}),
				pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: e.oids}),
				// Postgres responds pgproto3.RowDescription but it's optional for pgx
				pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
			}}
			if err := script.Run(backend); err != nil {
				return err
			}
		}
		// Expect bind command next
		msg, err = backend.Receive()
		if err != nil {
			return err
		}
	}
	if m, ok := msg.(*pgproto3.Bind); ok {
		// codes stays nil when e.oids is empty: reflect.DeepEqual treats nil
		// and empty slices as different, so this must not become []int16{}.
		var codes []int16
		for _, oid := range e.oids {
			codes = append(codes, ci.ParamFormatCodeForOID(oid))
		}
		want := &pgproto3.Bind{
			ParameterFormatCodes: codes,
			Parameters:           e.params,
			ResultFormatCodes:    []int16{},
			DestinationPortal:    m.DestinationPortal,
			PreparedStatement:    m.PreparedStatement,
		}
		if !reflect.DeepEqual(m, want) {
			return errors.Errorf("expected => %#v\nactual => %#v", want, m)
		}
		// Build a fresh step list instead of prepending to e.reply.Steps,
		// so the receiver is left untouched.
		script := e.reply
		script.Steps = append([]pgmock.Step{
			pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'P'}),
			pgmock.ExpectMessage(&pgproto3.Execute{}),
			pgmock.SendMessage(&pgproto3.ParseComplete{}),
			pgmock.SendMessage(&pgproto3.BindComplete{}),
		}, e.reply.Steps...)
		return script.Run(backend)
	}
	// Handle simple query
	want := &pgproto3.Query{String: e.sql}
	if m, ok := msg.(*pgproto3.Query); ok && reflect.DeepEqual(m, want) {
		// Copy into a new slice so the append never writes into the backing
		// array of e.reply.Steps.
		steps := make([]pgmock.Step, 0, len(e.reply.Steps)+1)
		steps = append(steps, e.reply.Steps...)
		steps = append(steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
		script := e.reply
		script.Steps = steps
		return script.Run(backend)
	}
	return errors.Errorf("expected => %#v\nactual => %#v", want, msg)
}
// ExpectQuery returns a step that expects the client to run sql in any form.
// The form may be a simple query, a named prepared statement, or an anonymous
// (unnamed) statement. params are the expected Bind parameter values. oids are
// the parameter type OIDs reported back to the client and used to derive the
// expected format codes.
func ExpectQuery(sql string, params [][]byte, oids []uint32) pgmock.Step {
	step := &extendedQueryStep{
		sql:    sql,
		params: params,
		oids:   oids,
	}
	return step
}
// terminateStep is a pgmock.Step that expects the client to close the
// session with a Terminate message.
type terminateStep struct{}

// Step reads the next frontend message, answering a leading Sync first. It
// fails unless that message is Terminate.
func (t *terminateStep) Step(backend *pgproto3.Backend) error {
	msg, err := getFrontendMessage(backend)
	if err != nil {
		return err
	}
	if _, isTerminate := msg.(*pgproto3.Terminate); !isTerminate {
		return errors.Errorf("expected => %#v\nactual => %#v", &pgproto3.Terminate{}, msg)
	}
	return nil
}

// ExpectTerminate returns a step that expects the client to send Terminate.
func ExpectTerminate() pgmock.Step {
	return new(terminateStep)
}
// getFrontendMessage receives the next message from the client. If that
// message is a Sync, it is answered with ReadyForQuery (TxStatus 'I') and the
// message that follows is returned instead. Only one leading Sync is consumed
// this way.
func getFrontendMessage(backend *pgproto3.Backend) (pgproto3.FrontendMessage, error) {
	first, err := backend.Receive()
	if err != nil {
		return nil, err
	}
	if _, isSync := first.(*pgproto3.Sync); !isSync {
		return first, nil
	}
	// Sync signals end of batch statements, so acknowledge it before
	// handing the following message to the caller.
	ready := pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})
	if err := ready.Step(backend); err != nil {
		return nil, err
	}
	next, err := backend.Receive()
	if err != nil {
		return nil, err
	}
	return next, nil
}