package pgtest import ( "reflect" "github.com/go-errors/errors" "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgtype" ) var ci = pgtype.NewConnInfo() type extendedQueryStep struct { sql string params [][]byte oids []uint32 reply pgmock.Script } func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error { msg, err := getFrontendMessage(backend) if err != nil { return err } // Handle prepared statements, name can be dynamic: lrupsc_5_0 if m, ok := msg.(*pgproto3.Parse); ok { want := &pgproto3.Parse{Name: m.Name, Query: e.sql, ParameterOIDs: m.ParameterOIDs} if !reflect.DeepEqual(m, want) { return errors.Errorf("expected => %#v\nactual => %#v", want, m) } // Anonymous ps falls through if m.Name != "" { script := pgmock.Script{Steps: []pgmock.Step{ pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: m.Name}), pgmock.ExpectMessage(&pgproto3.Sync{}), pgmock.SendMessage(&pgproto3.ParseComplete{}), pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: e.oids}), // Postgres responds pgproto3.RowDescription but it's optional for pgx pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), }} if err := script.Run(backend); err != nil { return err } } // Expect bind command next msg, err = backend.Receive() if err != nil { return err } } if m, ok := msg.(*pgproto3.Bind); ok { var codes []int16 for _, oid := range e.oids { codes = append(codes, ci.ParamFormatCodeForOID(oid)) } want := &pgproto3.Bind{ ParameterFormatCodes: codes, Parameters: e.params, ResultFormatCodes: []int16{}, DestinationPortal: m.DestinationPortal, PreparedStatement: m.PreparedStatement, } if !reflect.DeepEqual(m, want) { return errors.Errorf("expected => %#v\nactual => %#v", want, msg) } e.reply.Steps = append([]pgmock.Step{ pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'P'}), pgmock.ExpectMessage(&pgproto3.Execute{}), pgmock.SendMessage(&pgproto3.ParseComplete{}), pgmock.SendMessage(&pgproto3.BindComplete{}), }, e.reply.Steps...) return e.reply.Run(backend) } // Handle simple query want := &pgproto3.Query{String: e.sql} if m, ok := msg.(*pgproto3.Query); ok && reflect.DeepEqual(m, want) { e.reply.Steps = append(e.reply.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) return e.reply.Run(backend) } return errors.Errorf("expected => %#v\nactual => %#v", want, msg) } // Expects a SQL query in any form: simple, prepared, or anonymous. func ExpectQuery(sql string, params [][]byte, oids []uint32) pgmock.Step { return &extendedQueryStep{sql: sql, params: params, oids: oids} } type terminateStep struct{} func (e *terminateStep) Step(backend *pgproto3.Backend) error { msg, err := getFrontendMessage(backend) if err != nil { return err } // Handle simple query if _, ok := msg.(*pgproto3.Terminate); ok { return nil } return errors.Errorf("expected => %#v\nactual => %#v", &pgproto3.Terminate{}, msg) } func ExpectTerminate() pgmock.Step { return &terminateStep{} } func getFrontendMessage(backend *pgproto3.Backend) (pgproto3.FrontendMessage, error) { msg, err := backend.Receive() if err != nil { return nil, err } // Sync signals end of batch statements if _, ok := msg.(*pgproto3.Sync); ok { reply := pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}) if err := reply.Step(backend); err != nil { return nil, err } msg, err = backend.Receive() if err != nil { return nil, err } } return msg, nil }