// supabase-cli/pkg/migration/file.go

package migration

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"io"
	"io/fs"
	"path/filepath"
	"regexp"
	"strings"

	"github.com/go-errors/errors"
	"github.com/jackc/pgconn"
	"github.com/jackc/pgtype"
	"github.com/jackc/pgx/v4"
	"github.com/spf13/viper"
	"github.com/supabase/cli/pkg/parser"
)

type MigrationFile struct {
	Version    string
	Name       string
	Statements []string
}

var migrateFilePattern = regexp.MustCompile(`^([0-9]+)_(.*)\.sql$`)
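
// NewMigrationFromFile loads the migration at path from fsys, splitting its
// contents into individual SQL statements and parsing the version and name
// from the file name. A minimal usage sketch (the directory and file name
// below are hypothetical, not part of this package):
//
//	fsys := os.DirFS("supabase/migrations")
//	m, err := NewMigrationFromFile("20240101000000_create_users.sql", fsys)
//	if err != nil {
//		return err
//	}
//	// m.Version == "20240101000000", m.Name == "create_users"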
func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) {
	lines, err := parseFile(path, fsys)
	if err != nil {
		return nil, err
	}
	file := MigrationFile{Statements: lines}
	// Parse version and name from the file name
	filename := filepath.Base(path)
	matches := migrateFilePattern.FindStringSubmatch(filename)
	if len(matches) > 2 {
		file.Version = matches[1]
		file.Name = matches[2]
	}
	return &file, nil
}
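
// parseFile opens path on fsys and splits its contents into trimmed SQL
// statements, growing the scanner buffer to the file size when no explicit
// limit is configured.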
func parseFile(path string, fsys fs.FS) ([]string, error) {
	sql, err := fsys.Open(path)
	if err != nil {
		return nil, errors.Errorf("failed to open migration file: %w", err)
	}
	defer sql.Close()
	// Unless explicitly specified, use the file length as the max buffer size
	if !viper.IsSet("SCANNER_BUFFER_SIZE") {
		if fi, err := sql.Stat(); err == nil {
			if size := int(fi.Size()); size > parser.MaxScannerCapacity {
				parser.MaxScannerCapacity = size
			}
		}
	}
	return parser.SplitAndTrim(sql)
}
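
// NewMigrationFromReader splits sql into statements without any version
// metadata, e.g. for scripts piped in from stdin. A minimal sketch (the SQL
// text is illustrative):
//
//	m, err := NewMigrationFromReader(strings.NewReader("create table t (id int);"))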
func NewMigrationFromReader(sql io.Reader) (*MigrationFile, error) {
	lines, err := parser.SplitAndTrim(sql)
	if err != nil {
		return nil, err
	}
	return &MigrationFile{Statements: lines}, nil
}
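
// ExecBatch executes all statements, plus the migration history insert when a
// version is present, as one implicitly transactional batch. A hedged sketch
// of a call site (assumes conn is an open *pgx.Conn):
//
//	if err := m.ExecBatch(ctx, conn); err != nil {
//		return err // names the failing statement, with a caret at the error position
//	}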
func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
	// Batch migration commands, without using statement cache
	batch := &pgconn.Batch{}
	for _, line := range m.Statements {
		batch.ExecParams(line, nil, nil, nil, nil)
	}
	// Insert into migration history
	if len(m.Version) > 0 {
		if err := m.insertVersionSQL(conn, batch); err != nil {
			return err
		}
	}
	// ExecBatch is implicitly transactional
	if result, err := conn.PgConn().ExecBatch(ctx, batch).ReadAll(); err != nil {
		// Defaults to printing the last statement on error
		stat := INSERT_MIGRATION_VERSION
		i := len(result)
		if i < len(m.Statements) {
			stat = m.Statements[i]
		}
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) {
			stat = markError(stat, int(pgErr.Position))
		}
		return errors.Errorf("%w\nAt statement %d:\n%s", err, i, stat)
	}
	return nil
}
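
// markError draws a caret under the 1-based error position reported by
// Postgres. For example (a made-up statement and position):
//
//	markError("select $$;", 8)
//
// returns:
//
//	select $$;
//	       ^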
func markError(stat string, pos int) string {
	lines := strings.Split(stat, "\n")
	for j, r := range lines {
		// Skip past lines that end before the error position
		if c := len(r); pos > c {
			pos -= c + 1
			continue
		}
		// Show a caret below the error position
		if pos > 0 {
			caret := append(bytes.Repeat([]byte{' '}, pos-1), '^')
			lines = append(lines[:j+1], string(caret))
		}
		break
	}
	return strings.Join(lines, "\n")
}
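
// insertVersionSQL appends the history insert to batch, encoding the executed
// statements as a Postgres text array in the wire format preferred by conn.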
func (m *MigrationFile) insertVersionSQL(conn *pgx.Conn, batch *pgconn.Batch) error {
	value := pgtype.TextArray{}
	if err := value.Set(m.Statements); err != nil {
		return errors.Errorf("failed to set text array: %w", err)
	}
	ci := conn.ConnInfo()
	var err error
	var encoded []byte
	var valueFormat int16
	if conn.Config().PreferSimpleProtocol {
		encoded, err = value.EncodeText(ci, encoded)
		valueFormat = pgtype.TextFormatCode
	} else {
		encoded, err = value.EncodeBinary(ci, encoded)
		valueFormat = pgtype.BinaryFormatCode
	}
	if err != nil {
		return errors.Errorf("failed to encode statements: %w", err)
	}
	batch.ExecParams(
		INSERT_MIGRATION_VERSION,
		[][]byte{[]byte(m.Version), []byte(m.Name), encoded},
		[]uint32{pgtype.TextOID, pgtype.TextOID, pgtype.TextArrayOID},
		[]int16{pgtype.TextFormatCode, pgtype.TextFormatCode, valueFormat},
		nil,
	)
	return nil
}

type SeedFile struct {
	Path  string
	Hash  string
	Dirty bool `db:"-"`
}
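
// NewSeedFile computes a SHA-256 digest of the seed script at path so that
// unchanged files can be skipped on later runs. A minimal sketch (the path
// below is hypothetical):
//
//	seed, err := NewSeedFile("supabase/seed.sql", os.DirFS("."))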
func NewSeedFile(path string, fsys fs.FS) (*SeedFile, error) {
	sql, err := fsys.Open(path)
	if err != nil {
		return nil, errors.Errorf("failed to open seed file: %w", err)
	}
	defer sql.Close()
	hash := sha256.New()
	if _, err := io.Copy(hash, sql); err != nil {
		return nil, errors.Errorf("failed to hash file: %w", err)
	}
	digest := hex.EncodeToString(hash.Sum(nil))
	return &SeedFile{Path: path, Hash: digest}, nil
}
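
// ExecBatchWithCache runs the seed statements through pgx's prepared statement
// cache and upserts the file path and hash into the seed history. A hedged
// sketch of a call site (assumes conn is an open *pgx.Conn and seed came from
// NewSeedFile):
//
//	if err := seed.ExecBatchWithCache(ctx, conn, os.DirFS(".")); err != nil {
//		return err
//	}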
func (m *SeedFile) ExecBatchWithCache(ctx context.Context, conn *pgx.Conn, fsys fs.FS) error {
	// Parse each file individually to reduce memory usage
	lines, err := parseFile(m.Path, fsys)
	if err != nil {
		return err
	}
	// Data statements don't mutate schemas, safe to use statement cache
	batch := pgx.Batch{}
	if !m.Dirty {
		for _, line := range lines {
			batch.Queue(line)
		}
	}
	batch.Queue(UPSERT_SEED_FILE, m.Path, m.Hash)
	// No need to track version here because there are no schema changes
	if err := conn.SendBatch(ctx, &batch).Close(); err != nil {
		return errors.Errorf("failed to send batch: %w", err)
	}
	return nil
}