drop legacy pkg/migration

This commit is contained in:
c9s 2021-01-14 00:00:05 +08:00
parent 2592ae55a2
commit 6491166459
8 changed files with 0 additions and 1171 deletions

View File

@ -1,5 +0,0 @@
package main
func main() {
}

View File

@ -1,285 +0,0 @@
package migration
import (
"database/sql"
"fmt"
)
// SQLDialect abstracts the details of specific SQL dialects
// for goose's few SQL specific statements
type SQLDialect interface {
createVersionTableSQL() string // sql string to create the db version table
insertVersionSQL() string // sql string to insert the initial version table row
deleteVersionSQL() string // sql string to delete version
migrationSQL() string // sql string to retrieve migrations
dbVersionQuery(db *sql.DB) (*sql.Rows, error)
}
var dialect SQLDialect = &PostgresDialect{}
// GetDialect gets the SQLDialect
func GetDialect() SQLDialect {
return dialect
}
// SetDialect sets the SQLDialect
func SetDialect(d string) error {
switch d {
case "postgres":
dialect = &PostgresDialect{}
case "mysql":
dialect = &MySQLDialect{}
case "sqlite3":
dialect = &Sqlite3Dialect{}
case "mssql":
dialect = &SqlServerDialect{}
case "redshift":
dialect = &RedshiftDialect{}
case "tidb":
dialect = &TiDBDialect{}
default:
return fmt.Errorf("%q: unknown dialect", d)
}
return nil
}
////////////////////////////
// Postgres
////////////////////////////
// PostgresDialect struct.
type PostgresDialect struct{}
func (pg PostgresDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id serial NOT NULL,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`, TableName())
}
func (pg PostgresDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName())
}
func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}
return rows, err
}
func (m PostgresDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (pg PostgresDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
}
////////////////////////////
// MySQL
////////////////////////////
// MySQLDialect struct.
type MySQLDialect struct{}
func (m MySQLDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id serial NOT NULL,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`, TableName())
}
func (m MySQLDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
}
func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}
return rows, err
}
func (m MySQLDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (m MySQLDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
}
////////////////////////////
// MSSQL
////////////////////////////
// SqlServerDialect struct.
type SqlServerDialect struct{}
func (m SqlServerDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY,
version_id BIGINT NOT NULL,
is_applied BIT NOT NULL,
tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP
);`, TableName())
}
func (m SqlServerDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", TableName())
}
func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}
return rows, err
}
func (m SqlServerDialect) migrationSQL() string {
const tpl = `
WITH Migrations AS
(
SELECT tstamp, is_applied,
ROW_NUMBER() OVER (ORDER BY tstamp) AS 'RowNumber'
FROM %s
WHERE version_id=@p1
)
SELECT tstamp, is_applied
FROM Migrations
WHERE RowNumber BETWEEN 1 AND 2
ORDER BY tstamp DESC
`
return fmt.Sprintf(tpl, TableName())
}
func (m SqlServerDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", TableName())
}
////////////////////////////
// sqlite3
////////////////////////////
// Sqlite3Dialect struct.
type Sqlite3Dialect struct{}
func (m Sqlite3Dialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version_id INTEGER NOT NULL,
is_applied INTEGER NOT NULL,
tstamp TIMESTAMP DEFAULT (datetime('now'))
);`, TableName())
}
func (m Sqlite3Dialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
}
func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}
return rows, err
}
func (m Sqlite3Dialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (m Sqlite3Dialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
}
////////////////////////////
// Redshift
////////////////////////////
// RedshiftDialect struct.
type RedshiftDialect struct{}
func (rs RedshiftDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id integer NOT NULL identity(1, 1),
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default sysdate,
PRIMARY KEY(id)
);`, TableName())
}
func (rs RedshiftDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName())
}
func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}
return rows, err
}
func (m RedshiftDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (rs RedshiftDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
}
////////////////////////////
// TiDB
////////////////////////////
// TiDBDialect struct.
type TiDBDialect struct{}
func (m TiDBDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`, TableName())
}
func (m TiDBDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
}
func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}
return rows, err
}
func (m TiDBDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (m TiDBDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
}

View File

@ -1,492 +0,0 @@
package migration
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
)
const VERSION = "v2.7.0-rc3"
var (
duplicateCheckOnce sync.Once
minVersion = int64(0)
maxVersion = int64((1 << 63) - 1)
timestampFormat = "20060102150405"
verbose = false
)
// SetVerbose set the goose verbosity mode
func SetVerbose(v bool) {
verbose = v
}
var (
// ErrNoCurrentVersion when a current migration version is not found.
ErrNoCurrentVersion = errors.New("no current version found")
// ErrNoNextVersion when the next migration version is not found.
ErrNoNextVersion = errors.New("no next version found")
// MaxVersion is the maximum allowed version.
MaxVersion int64 = 9223372036854775807 // max(int64)
registeredGoMigrations = map[int64]*Migration{}
)
// MigrationRecord struct.
type MigrationRecord struct {
VersionID int64
TStamp time.Time
IsApplied bool // was this a result of up() or down()
}
// Migration struct.
type Migration struct {
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .sql script
Registered bool
UpFn func(*sql.Tx) error // Up go migration function
DownFn func(*sql.Tx) error // Down go migration function
}
func (m *Migration) String() string {
return fmt.Sprintf(m.Source)
}
// Up runs an up migration.
func (m *Migration) Up(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, true); err != nil {
return err
}
return nil
}
// Down runs a down migration.
func (m *Migration) Down(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, false); err != nil {
return err
}
return nil
}
func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
switch filepath.Ext(m.Source) {
case ".sql":
f, err := os.Open(m.Source)
if err != nil {
return errors.Wrapf(err, "ERROR %v: failed to open SQL migration file", filepath.Base(m.Source))
}
defer f.Close()
statements, useTx, err := parseSQLMigration(f, direction)
if err != nil {
return errors.Wrapf(err, "ERROR %v: failed to parse SQL migration file", filepath.Base(m.Source))
}
if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction); err != nil {
return errors.Wrapf(err, "ERROR %v: failed to run SQL migration", filepath.Base(m.Source))
}
if len(statements) > 0 {
log.Println("OK ", filepath.Base(m.Source))
} else {
log.Println("EMPTY", filepath.Base(m.Source))
}
case ".go":
if !m.Registered {
return errors.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/c9s/goose/tree/master/examples/go-migrations)", m.Source)
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return errors.Wrap(err, "ERROR failed to begin transaction")
}
fn := m.UpFn
if !direction {
fn = m.DownFn
}
if fn != nil {
// Run Go migration function.
if err := fn(tx); err != nil {
tx.Rollback()
return errors.Wrapf(err, "ERROR %v: failed to run Go migration function %T", filepath.Base(m.Source), fn)
}
}
if direction {
if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
tx.Rollback()
return errors.Wrap(err, "ERROR failed to execute transaction")
}
} else {
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil {
tx.Rollback()
return errors.Wrap(err, "ERROR failed to execute transaction")
}
}
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "ERROR failed to commit transaction")
}
if fn != nil {
log.Println("OK ", filepath.Base(m.Source))
} else {
log.Println("EMPTY", filepath.Base(m.Source))
}
return nil
}
return nil
}
// ExtractNumericComponent looks for migration scripts with names in the form:
// XXX_descriptivename.ext where XXX specifies the version number
// and ext specifies the type of migration
func ExtractNumericComponent(name string) (int64, error) {
base := filepath.Base(name)
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
return 0, errors.New("not a recognized migration file type")
}
idx := strings.Index(base, "_")
if idx < 0 {
return 0, errors.New("no separator found")
}
n, e := strconv.ParseInt(base[:idx], 10, 64)
if e == nil && n <= 0 {
return 0, errors.New("migration IDs must be greater than zero")
}
return n, e
}
// Migrations slice.
type Migrations []*Migration
// helpers so we can use pkg sort
func (ms Migrations) Len() int { return len(ms) }
func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
func (ms Migrations) Less(i, j int) bool {
if ms[i].Version == ms[j].Version {
panic(fmt.Sprintf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source))
}
return ms[i].Version < ms[j].Version
}
// Current gets the current migration.
func (ms Migrations) Current(current int64) (*Migration, error) {
for i, migration := range ms {
if migration.Version == current {
return ms[i], nil
}
}
return nil, ErrNoCurrentVersion
}
// Next gets the next migration.
func (ms Migrations) Next(current int64) (*Migration, error) {
for i, migration := range ms {
if migration.Version > current {
return ms[i], nil
}
}
return nil, ErrNoNextVersion
}
// Previous : Get the previous migration.
func (ms Migrations) Previous(current int64) (*Migration, error) {
for i := len(ms) - 1; i >= 0; i-- {
if ms[i].Version < current {
return ms[i], nil
}
}
return nil, ErrNoNextVersion
}
// Last gets the last migration.
func (ms Migrations) Last() (*Migration, error) {
if len(ms) == 0 {
return nil, ErrNoNextVersion
}
return ms[len(ms)-1], nil
}
// Versioned gets versioned migrations.
func (ms Migrations) versioned() (Migrations, error) {
var migrations Migrations
// assume that the user will never have more than 19700101000000 migrations
for _, m := range ms {
// parse version as timestmap
versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
if versionTime.Before(time.Unix(0, 0)) || err != nil {
migrations = append(migrations, m)
}
}
return migrations, nil
}
// Timestamped gets the timestamped migrations.
func (ms Migrations) timestamped() (Migrations, error) {
var migrations Migrations
// assume that the user will never have more than 19700101000000 migrations
for _, m := range ms {
// parse version as timestmap
versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
if err != nil {
// probably not a timestamp
continue
}
if versionTime.After(time.Unix(0, 0)) {
migrations = append(migrations, m)
}
}
return migrations, nil
}
func (ms Migrations) String() string {
str := ""
for _, m := range ms {
str += fmt.Sprintln(m)
}
return str
}
// AddMigration adds a migration.
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigration(filename, up, down)
}
// AddNamedMigration : Add a named migration.
func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
v, _ := ExtractNumericComponent(filename)
migration := &Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
if existing, ok := registeredGoMigrations[v]; ok {
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
}
registeredGoMigrations[v] = migration
}
// CollectMigrationsFromDir returns all the valid looking migration scripts in the
// migrations folder and go func registry, and key them by version.
func CollectMigrationsFromDir(dirpath string, current, target int64) (Migrations, error) {
if _, err := os.Stat(dirpath); os.IsNotExist(err) {
return nil, fmt.Errorf("%s directory does not exists", dirpath)
}
var migrations Migrations
// SQL migration files.
sqlMigrationFiles, err := filepath.Glob(dirpath + "/**.sql")
if err != nil {
return nil, err
}
for _, file := range sqlMigrationFiles {
v, err := ExtractNumericComponent(file)
if err != nil {
return nil, err
}
if versionFilter(v, current, target) {
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file}
migrations = append(migrations, migration)
}
}
// Go migrations registered via goose.AddMigration().
for _, migration := range registeredGoMigrations {
v, err := ExtractNumericComponent(migration.Source)
if err != nil {
return nil, err
}
if versionFilter(v, current, target) {
migrations = append(migrations, migration)
}
}
// Go migration files
goMigrationFiles, err := filepath.Glob(dirpath + "/**.go")
if err != nil {
return nil, err
}
for _, file := range goMigrationFiles {
v, err := ExtractNumericComponent(file)
if err != nil {
continue // Skip any files that don't have version prefix.
}
// Skip migrations already existing migrations registered via goose.AddMigration().
if _, ok := registeredGoMigrations[v]; ok {
continue
}
if versionFilter(v, current, target) {
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file, Registered: false}
migrations = append(migrations, migration)
}
}
migrations = sortAndConnectMigrations(migrations)
return migrations, nil
}
func sortAndConnectMigrations(migrations Migrations) Migrations {
sort.Sort(migrations)
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range migrations {
prev := int64(-1)
if i > 0 {
prev = migrations[i-1].Version
migrations[i-1].Next = m.Version
}
migrations[i].Previous = prev
}
return migrations
}
func versionFilter(v, current, target int64) bool {
if target > current {
return v > current && v <= target
}
if target < current {
return v <= current && v > target
}
return false
}
// Create the db version table
// and insert the initial 0 value into it
func createVersionTable(db *sql.DB) error {
txn, err := db.Begin()
if err != nil {
return err
}
d := GetDialect()
if _, err := txn.Exec(d.createVersionTableSQL()); err != nil {
txn.Rollback()
return err
}
version := 0
applied := true
if _, err := txn.Exec(d.insertVersionSQL(), version, applied); err != nil {
txn.Rollback()
return err
}
return txn.Commit()
}
// Run a migration specified in raw SQL.
//
// Sections of the script can be annotated with a special comment,
// starting with "-- +goose" to specify whether the section should
// be applied during an Up or Down migration
//
// All statements following an Up or Down directive are grouped together
// until another direction directive is found.
func runSQLMigration(ctx context.Context, db *sql.DB, statements []string, useTx bool, v int64, direction bool) error {
if useTx {
// TRANSACTION.
verboseInfo("Begin transaction")
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return errors.Wrap(err, "failed to begin transaction")
}
for _, query := range statements {
verboseInfo("Executing statement: %s\n", clearStatement(query))
if _, err = tx.Exec(query); err != nil {
verboseInfo("Rollback transaction")
tx.Rollback()
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
}
}
if direction {
if _, err := tx.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
verboseInfo("Rollback transaction")
tx.Rollback()
return errors.Wrap(err, "failed to insert new goose version")
}
} else {
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), v); err != nil {
verboseInfo("Rollback transaction")
tx.Rollback()
return errors.Wrap(err, "failed to delete goose version")
}
}
verboseInfo("Commit transaction")
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "failed to commit transaction")
}
return nil
}
// NO TRANSACTION.
for _, query := range statements {
verboseInfo("Executing statement: %s", clearStatement(query))
if _, err := db.Exec(query); err != nil {
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
}
}
if _, err := db.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
return errors.Wrap(err, "failed to insert new goose version")
}
return nil
}
const (
grayColor = "\033[90m"
resetColor = "\033[00m"
)
func verboseInfo(s string, args ...interface{}) {
log.Printf(grayColor+s+resetColor, args...)
}

View File

@ -1,5 +0,0 @@
package migration
//go:generate gopackmigration -dir ../../migrations

View File

@ -1,13 +0,0 @@
package migration
import "regexp"
var (
matchSQLComments = regexp.MustCompile(`(?m)^--.*$[\r\n]*`)
matchEmptyEOL = regexp.MustCompile(`(?m)^$[\r\n]*`) // TODO: Duplicate
)
func clearStatement(s string) string {
s = matchSQLComments.ReplaceAllString(s, ``)
return matchEmptyEOL.ReplaceAllString(s, ``)
}

View File

@ -1,212 +0,0 @@
package migration
import (
"bufio"
"bytes"
"io"
"regexp"
"strings"
"sync"
"github.com/pkg/errors"
)
type parserState int
const (
start parserState = iota // 0
upStatement // 1
upStatementBegin // 2
upStatementEnd // 3
downStatement // 4
downStatementBegin // 5
downStatementEnd // 6
)
type stateMachine parserState
func (s *stateMachine) Get() parserState {
return parserState(*s)
}
func (s *stateMachine) Set(new parserState) {
*s = stateMachine(new)
}
const scanBufSize = 4 * 1024 * 1024
var matchEmptyLines = regexp.MustCompile(`^\s*$`)
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, scanBufSize)
},
}
// Split given SQL script into individual statements and return
// SQL statements for given direction (up=true, down=false).
//
// The base case is to simply split on semicolons, as these
// naturally terminate a statement.
//
// However, more complex cases like pl/pgsql can have semicolons
// within a statement. For these cases, we provide the explicit annotations
// 'StatementBegin' and 'StatementEnd' to allow the script to
// tell us to ignore semicolons.
func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool, err error) {
var buf bytes.Buffer
scanBuf := bufferPool.Get().([]byte)
defer bufferPool.Put(scanBuf)
scanner := bufio.NewScanner(r)
scanner.Buffer(scanBuf, scanBufSize)
stateMachine := stateMachine(start)
useTx = true
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "--") {
cmd := strings.TrimSpace(strings.TrimPrefix(line, "--"))
switch cmd {
case "+goose Up":
switch stateMachine.Get() {
case start:
stateMachine.Set(upStatement)
default:
return nil, false, errors.Errorf("duplicate '-- +goose Up' annotations; stateMachine=%v, see https://github.com/c9s/goose#sql-migrations", stateMachine)
}
continue
case "+goose Down":
switch stateMachine.Get() {
case upStatement, upStatementEnd:
stateMachine.Set(downStatement)
default:
return nil, false, errors.Errorf("must start with '-- +goose Up' annotation, stateMachine=%v, see https://github.com/c9s/goose#sql-migrations", stateMachine)
}
continue
case "+goose StatementBegin":
switch stateMachine.Get() {
case upStatement, upStatementEnd:
stateMachine.Set(upStatementBegin)
case downStatement, downStatementEnd:
stateMachine.Set(downStatementBegin)
default:
return nil, false, errors.Errorf("'-- +goose StatementBegin' must be defined after '-- +goose Up' or '-- +goose Down' annotation, stateMachine=%v, see https://github.com/c9s/goose#sql-migrations", stateMachine)
}
continue
case "+goose StatementEnd":
switch stateMachine.Get() {
case upStatementBegin:
stateMachine.Set(upStatementEnd)
case downStatementBegin:
stateMachine.Set(downStatementEnd)
default:
return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/c9s/goose#sql-migrations")
}
case "+goose NO TRANSACTION":
useTx = false
continue
default:
// Ignore comments.
continue
}
}
// Ignore empty lines.
if matchEmptyLines.MatchString(line) {
continue
}
// Write SQL line to a buffer.
if _, err := buf.WriteString(line + "\n"); err != nil {
return nil, false, errors.Wrap(err, "failed to write to buf")
}
// Read SQL body one by line, if we're in the right direction.
//
// 1) basic query with semicolon; 2) psql statement
//
// Export statement once we hit end of statement.
switch stateMachine.Get() {
case upStatement, upStatementBegin, upStatementEnd:
if !direction /*down*/ {
buf.Reset()
continue
}
case downStatement, downStatementBegin, downStatementEnd:
if direction /*up*/ {
buf.Reset()
continue
}
default:
return nil, false, errors.Errorf("failed to parse migration: unexpected state %q on line %q, see https://github.com/c9s/goose#sql-migrations", stateMachine, line)
}
switch stateMachine.Get() {
case upStatement:
if endsWithSemicolon(line) {
stmts = append(stmts, buf.String())
buf.Reset()
}
case downStatement:
if endsWithSemicolon(line) {
stmts = append(stmts, buf.String())
buf.Reset()
}
case upStatementEnd:
stmts = append(stmts, buf.String())
buf.Reset()
stateMachine.Set(upStatement)
case downStatementEnd:
stmts = append(stmts, buf.String())
buf.Reset()
stateMachine.Set(downStatement)
}
}
if err := scanner.Err(); err != nil {
return nil, false, errors.Wrap(err, "failed to scan migration")
}
// EOF
switch stateMachine.Get() {
case start:
return nil, false, errors.New("failed to parse migration: must start with '-- +goose Up' annotation, see https://github.com/c9s/goose#sql-migrations")
case upStatementBegin, downStatementBegin:
return nil, false, errors.New("failed to parse migration: missing '-- +goose StatementEnd' annotation")
}
if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
return nil, false, errors.Errorf("failed to parse migration: state %q, direction: %v: unexpected unfinished SQL query: %q: missing semicolon?", stateMachine, direction, bufferRemaining)
}
return stmts, useTx, nil
}
// Checks the line to see if the line has a statement-ending semicolon
// or if the line contains a double-dash comment.
func endsWithSemicolon(line string) bool {
scanBuf := bufferPool.Get().([]byte)
defer bufferPool.Put(scanBuf)
prev := ""
scanner := bufio.NewScanner(strings.NewReader(line))
scanner.Buffer(scanBuf, scanBufSize)
scanner.Split(bufio.ScanWords)
for scanner.Scan() {
word := scanner.Text()
if strings.HasPrefix(word, "--") {
break
}
prev = word
}
return strings.HasSuffix(prev, ";")
}

View File

@ -1,68 +0,0 @@
package migration
import (
"context"
"database/sql"
"log"
)
// UpTo migrates up to a specific version.
func UpTo(ctx context.Context, db *sql.DB, dir string, version int64) error {
migrations, err := CollectMigrationsFromDir(dir, minVersion, version)
if err != nil {
return err
}
for {
current, err := GetDBVersion(db)
if err != nil {
return err
}
next, err := migrations.Next(current)
if err != nil {
if err == ErrNoNextVersion {
log.Printf("no migrations to run. current version: %d\n", current)
return nil
}
return err
}
if err = next.Up(ctx, db); err != nil {
return err
}
}
}
// Up applies all available migrations.
func Up(ctx context.Context, db *sql.DB, dir string) error {
return UpTo(ctx, db, dir, maxVersion)
}
// UpByOne migrates up by a single version.
func UpByOne(ctx context.Context, db *sql.DB, dir string) error {
migrations, err := CollectMigrationsFromDir(dir, minVersion, maxVersion)
if err != nil {
return err
}
currentVersion, err := GetDBVersion(db)
if err != nil {
return err
}
next, err := migrations.Next(currentVersion)
if err != nil {
if err == ErrNoNextVersion {
log.Printf("no migrations to run. current version: %d\n", currentVersion)
return nil
}
return err
}
if err = next.Up(ctx, db); err != nil {
return err
}
return nil
}

View File

@ -1,91 +0,0 @@
package migration
import (
"database/sql"
"log"
"github.com/pkg/errors"
)
// Version prints the current version of the database.
func Version(db *sql.DB, dir string) error {
current, err := GetDBVersion(db)
if err != nil {
return err
}
log.Printf("cmd: version %v\n", current)
return nil
}
// GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error.
func GetDBVersion(db *sql.DB) (int64, error) {
version, err := EnsureDBVersion(db)
if err != nil {
return -1, err
}
return version, nil
}
// EnsureDBVersion retrieves the current version for this DB.
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) {
rows, err := GetDialect().dbVersionQuery(db)
if err != nil {
return 0, createVersionTable(db)
}
defer rows.Close()
// The most recent record for each migration specifies
// whether it has been applied or rolled back.
// The first version we find that has been applied is the current version.
toSkip := make([]int64, 0)
for rows.Next() {
var row MigrationRecord
if err = rows.Scan(&row.VersionID, &row.IsApplied); err != nil {
return 0, errors.Wrap(err, "failed to scan row")
}
// have we already marked this version to be skipped?
skip := false
for _, v := range toSkip {
if v == row.VersionID {
skip = true
break
}
}
if skip {
continue
}
// if version has been applied we're done
if row.IsApplied {
return row.VersionID, nil
}
// latest version of migration has not been applied.
toSkip = append(toSkip, row.VersionID)
}
if err := rows.Err(); err != nil {
return 0, errors.Wrap(err, "failed to get next row")
}
return 0, ErrNoNextVersion
}
var tableName = "goose_db_version"
// TableName returns goose db version table name
func TableName() string {
return tableName
}
// SetTableName set goose db version table name
func SetTableName(n string) {
tableName = n
}