133 lines
3.2 KiB
Go
133 lines
3.2 KiB
Go
// Backported from pgx/v5: https://github.com/jackc/pgx/blob/v5.3.1/rows.go#L408
|
|
package pgxv5
|
|
|
|
import (
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/go-errors/errors"
|
|
"github.com/jackc/pgproto3/v2"
|
|
"github.com/jackc/pgx/v4"
|
|
)
|
|
|
|
func CollectStrings(rows pgx.Rows) ([]string, error) {
|
|
defer rows.Close()
|
|
result := []string{}
|
|
for rows.Next() {
|
|
var version string
|
|
if err := rows.Scan(&version); err != nil {
|
|
return nil, errors.Errorf("failed to scan rows: %w", err)
|
|
}
|
|
result = append(result, version)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, errors.Errorf("failed to parse rows: %w", err)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
|
|
func CollectRows[T any](rows pgx.Rows) ([]T, error) {
|
|
defer rows.Close()
|
|
|
|
slice := []T{}
|
|
|
|
for rows.Next() {
|
|
var value T
|
|
if err := ScanRowToStruct(rows, &value); err != nil {
|
|
return nil, err
|
|
}
|
|
slice = append(slice, value)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, errors.Errorf("failed to collect rows: %w", err)
|
|
}
|
|
|
|
return slice, nil
|
|
}
|
|
|
|
func ScanRowToStruct(rows pgx.Rows, dst any) error {
|
|
dstValue := reflect.ValueOf(dst)
|
|
if dstValue.Kind() != reflect.Ptr {
|
|
return errors.Errorf("dst not a pointer")
|
|
}
|
|
|
|
dstElemValue := dstValue.Elem()
|
|
scanTargets, err := appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for i, t := range scanTargets {
|
|
if t == nil {
|
|
return errors.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
|
|
}
|
|
}
|
|
|
|
if err := rows.Scan(scanTargets...); err != nil {
|
|
return errors.Errorf("failed to scan targets: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// structTagKey is the struct tag key that GetColumnName looks up to find the
// column name for a struct field.
const structTagKey = "db"
|
|
|
|
func fieldPosByName(fldDescs []pgproto3.FieldDescription, field string) (i int) {
|
|
i = -1
|
|
for i, desc := range fldDescs {
|
|
if strings.EqualFold(string(desc.Name), field) {
|
|
return i
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgproto3.FieldDescription) ([]any, error) {
|
|
var err error
|
|
dstElemType := dstElemValue.Type()
|
|
|
|
if scanTargets == nil {
|
|
scanTargets = make([]any, len(fldDescs))
|
|
}
|
|
|
|
for i := 0; i < dstElemType.NumField(); i++ {
|
|
sf := dstElemType.Field(i)
|
|
if sf.PkgPath != "" && !sf.Anonymous {
|
|
// Field is unexported, skip it.
|
|
continue
|
|
}
|
|
// Handle anoymous struct embedding, but do not try to handle embedded pointers.
|
|
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
|
|
scanTargets, err = appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
colName := GetColumnName(sf)
|
|
if len(colName) == 0 {
|
|
// Field is ignored, skip it.
|
|
continue
|
|
}
|
|
fpos := fieldPosByName(fldDescs, colName)
|
|
if fpos == -1 || fpos >= len(scanTargets) {
|
|
return nil, errors.Errorf("cannot find field %s in returned row", colName)
|
|
}
|
|
scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
|
|
}
|
|
}
|
|
|
|
return scanTargets, err
|
|
}
|
|
|
|
func GetColumnName(sf reflect.StructField) string {
|
|
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
|
|
if !dbTagPresent {
|
|
return sf.Name
|
|
}
|
|
if dbTag = strings.Split(dbTag, ",")[0]; dbTag != "-" {
|
|
return dbTag
|
|
}
|
|
return ""
|
|
}
|