133 lines
3.5 KiB
Go
133 lines
3.5 KiB
Go
package pgtest
|
|
|
|
import (
|
|
"reflect"
|
|
|
|
"github.com/go-errors/errors"
|
|
"github.com/jackc/pgmock"
|
|
"github.com/jackc/pgproto3/v2"
|
|
"github.com/jackc/pgtype"
|
|
)
|
|
|
|
var (
	// ci is the type registry consulted to derive the parameter format code
	// that a Bind message is expected to carry for each parameter OID.
	ci = pgtype.NewConnInfo()
)
|
|
|
|
type extendedQueryStep struct {
|
|
sql string
|
|
params [][]byte
|
|
oids []uint32
|
|
reply pgmock.Script
|
|
}
|
|
|
|
func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error {
|
|
msg, err := getFrontendMessage(backend)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Handle prepared statements, name can be dynamic: lrupsc_5_0
|
|
if m, ok := msg.(*pgproto3.Parse); ok {
|
|
want := &pgproto3.Parse{Name: m.Name, Query: e.sql, ParameterOIDs: m.ParameterOIDs}
|
|
if !reflect.DeepEqual(m, want) {
|
|
return errors.Errorf("expected => %#v\nactual => %#v", want, m)
|
|
}
|
|
// Anonymous ps falls through
|
|
if m.Name != "" {
|
|
script := pgmock.Script{Steps: []pgmock.Step{
|
|
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: m.Name}),
|
|
pgmock.ExpectMessage(&pgproto3.Sync{}),
|
|
pgmock.SendMessage(&pgproto3.ParseComplete{}),
|
|
pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: e.oids}),
|
|
// Postgres responds pgproto3.RowDescription but it's optional for pgx
|
|
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
}}
|
|
if err := script.Run(backend); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// Expect bind command next
|
|
msg, err = backend.Receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if m, ok := msg.(*pgproto3.Bind); ok {
|
|
var codes []int16
|
|
for _, oid := range e.oids {
|
|
codes = append(codes, ci.ParamFormatCodeForOID(oid))
|
|
}
|
|
want := &pgproto3.Bind{
|
|
ParameterFormatCodes: codes,
|
|
Parameters: e.params,
|
|
ResultFormatCodes: []int16{},
|
|
DestinationPortal: m.DestinationPortal,
|
|
PreparedStatement: m.PreparedStatement,
|
|
}
|
|
if !reflect.DeepEqual(m, want) {
|
|
return errors.Errorf("expected => %#v\nactual => %#v", want, msg)
|
|
}
|
|
e.reply.Steps = append([]pgmock.Step{
|
|
pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'P'}),
|
|
pgmock.ExpectMessage(&pgproto3.Execute{}),
|
|
pgmock.SendMessage(&pgproto3.ParseComplete{}),
|
|
pgmock.SendMessage(&pgproto3.BindComplete{}),
|
|
}, e.reply.Steps...)
|
|
return e.reply.Run(backend)
|
|
}
|
|
|
|
// Handle simple query
|
|
want := &pgproto3.Query{String: e.sql}
|
|
if m, ok := msg.(*pgproto3.Query); ok && reflect.DeepEqual(m, want) {
|
|
e.reply.Steps = append(e.reply.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
|
return e.reply.Run(backend)
|
|
}
|
|
|
|
return errors.Errorf("expected => %#v\nactual => %#v", want, msg)
|
|
}
|
|
|
|
// Expects a SQL query in any form: simple, prepared, or anonymous.
|
|
func ExpectQuery(sql string, params [][]byte, oids []uint32) pgmock.Step {
|
|
return &extendedQueryStep{sql: sql, params: params, oids: oids}
|
|
}
|
|
|
|
type terminateStep struct{}
|
|
|
|
func (e *terminateStep) Step(backend *pgproto3.Backend) error {
|
|
msg, err := getFrontendMessage(backend)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Handle simple query
|
|
if _, ok := msg.(*pgproto3.Terminate); ok {
|
|
return nil
|
|
}
|
|
|
|
return errors.Errorf("expected => %#v\nactual => %#v", &pgproto3.Terminate{}, msg)
|
|
}
|
|
|
|
func ExpectTerminate() pgmock.Step {
|
|
return &terminateStep{}
|
|
}
|
|
|
|
func getFrontendMessage(backend *pgproto3.Backend) (pgproto3.FrontendMessage, error) {
|
|
msg, err := backend.Receive()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Sync signals end of batch statements
|
|
if _, ok := msg.(*pgproto3.Sync); ok {
|
|
reply := pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})
|
|
if err := reply.Step(backend); err != nil {
|
|
return nil, err
|
|
}
|
|
msg, err = backend.Receive()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return msg, nil
|
|
}
|