mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 09:11:55 +00:00
drop legacy pkg/migration
This commit is contained in:
parent
2592ae55a2
commit
6491166459
|
@ -1,5 +0,0 @@
|
|||
package main
|
||||
|
||||
func main() {
|
||||
|
||||
}
|
|
@ -1,285 +0,0 @@
|
|||
package migration
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SQLDialect abstracts the details of specific SQL dialects
|
||||
// for goose's few SQL specific statements
|
||||
type SQLDialect interface {
|
||||
createVersionTableSQL() string // sql string to create the db version table
|
||||
insertVersionSQL() string // sql string to insert the initial version table row
|
||||
deleteVersionSQL() string // sql string to delete version
|
||||
migrationSQL() string // sql string to retrieve migrations
|
||||
dbVersionQuery(db *sql.DB) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
var dialect SQLDialect = &PostgresDialect{}
|
||||
|
||||
// GetDialect gets the SQLDialect
|
||||
func GetDialect() SQLDialect {
|
||||
return dialect
|
||||
}
|
||||
|
||||
// SetDialect sets the SQLDialect
|
||||
func SetDialect(d string) error {
|
||||
switch d {
|
||||
case "postgres":
|
||||
dialect = &PostgresDialect{}
|
||||
case "mysql":
|
||||
dialect = &MySQLDialect{}
|
||||
case "sqlite3":
|
||||
dialect = &Sqlite3Dialect{}
|
||||
case "mssql":
|
||||
dialect = &SqlServerDialect{}
|
||||
case "redshift":
|
||||
dialect = &RedshiftDialect{}
|
||||
case "tidb":
|
||||
dialect = &TiDBDialect{}
|
||||
default:
|
||||
return fmt.Errorf("%q: unknown dialect", d)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
////////////////////////////
|
||||
// Postgres
|
||||
////////////////////////////
|
||||
|
||||
// PostgresDialect struct.
|
||||
type PostgresDialect struct{}
|
||||
|
||||
func (pg PostgresDialect) createVersionTableSQL() string {
|
||||
return fmt.Sprintf(`CREATE TABLE %s (
|
||||
id serial NOT NULL,
|
||||
version_id bigint NOT NULL,
|
||||
is_applied boolean NOT NULL,
|
||||
tstamp timestamp NULL default now(),
|
||||
PRIMARY KEY(id)
|
||||
);`, TableName())
|
||||
}
|
||||
|
||||
func (pg PostgresDialect) insertVersionSQL() string {
|
||||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName())
|
||||
}
|
||||
|
||||
func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
|
||||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (m PostgresDialect) migrationSQL() string {
|
||||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
|
||||
}
|
||||
|
||||
func (pg PostgresDialect) deleteVersionSQL() string {
|
||||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
|
||||
}
|
||||
|
||||
////////////////////////////
|
||||
// MySQL
|
||||
////////////////////////////
|
||||
|
||||
// MySQLDialect struct.
|
||||
type MySQLDialect struct{}
|
||||
|
||||
func (m MySQLDialect) createVersionTableSQL() string {
|
||||
return fmt.Sprintf(`CREATE TABLE %s (
|
||||
id serial NOT NULL,
|
||||
version_id bigint NOT NULL,
|
||||
is_applied boolean NOT NULL,
|
||||
tstamp timestamp NULL default now(),
|
||||
PRIMARY KEY(id)
|
||||
);`, TableName())
|
||||
}
|
||||
|
||||
func (m MySQLDialect) insertVersionSQL() string {
|
||||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
|
||||
}
|
||||
|
||||
func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
|
||||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (m MySQLDialect) migrationSQL() string {
|
||||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
|
||||
}
|
||||
|
||||
func (m MySQLDialect) deleteVersionSQL() string {
|
||||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
|
||||
}
|
||||
|
||||
////////////////////////////
|
||||
// MSSQL
|
||||
////////////////////////////
|
||||
|
||||
// SqlServerDialect struct.
|
||||
type SqlServerDialect struct{}
|
||||
|
||||
func (m SqlServerDialect) createVersionTableSQL() string {
|
||||
return fmt.Sprintf(`CREATE TABLE %s (
|
||||
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY,
|
||||
version_id BIGINT NOT NULL,
|
||||
is_applied BIT NOT NULL,
|
||||
tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`, TableName())
|
||||
}
|
||||
|
||||
func (m SqlServerDialect) insertVersionSQL() string {
|
||||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", TableName())
|
||||
}
|
||||
|
||||
func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
|
||||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", TableName()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (m SqlServerDialect) migrationSQL() string {
|
||||
const tpl = `
|
||||
WITH Migrations AS
|
||||
(
|
||||
SELECT tstamp, is_applied,
|
||||
ROW_NUMBER() OVER (ORDER BY tstamp) AS 'RowNumber'
|
||||
FROM %s
|
||||
WHERE version_id=@p1
|
||||
)
|
||||
SELECT tstamp, is_applied
|
||||
FROM Migrations
|
||||
WHERE RowNumber BETWEEN 1 AND 2
|
||||
ORDER BY tstamp DESC
|
||||
`
|
||||
return fmt.Sprintf(tpl, TableName())
|
||||
}
|
||||
|
||||
func (m SqlServerDialect) deleteVersionSQL() string {
|
||||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", TableName())
|
||||
}
|
||||
|
||||
////////////////////////////
|
||||
// sqlite3
|
||||
////////////////////////////
|
||||
|
||||
// Sqlite3Dialect struct.
|
||||
type Sqlite3Dialect struct{}
|
||||
|
||||
func (m Sqlite3Dialect) createVersionTableSQL() string {
|
||||
return fmt.Sprintf(`CREATE TABLE %s (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
version_id INTEGER NOT NULL,
|
||||
is_applied INTEGER NOT NULL,
|
||||
tstamp TIMESTAMP DEFAULT (datetime('now'))
|
||||
);`, TableName())
|
||||
}
|
||||
|
||||
func (m Sqlite3Dialect) insertVersionSQL() string {
|
||||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
|
||||
}
|
||||
|
||||
func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
|
||||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (m Sqlite3Dialect) migrationSQL() string {
|
||||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
|
||||
}
|
||||
|
||||
func (m Sqlite3Dialect) deleteVersionSQL() string {
|
||||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
|
||||
}
|
||||
|
||||
////////////////////////////
|
||||
// Redshift
|
||||
////////////////////////////
|
||||
|
||||
// RedshiftDialect struct.
|
||||
type RedshiftDialect struct{}
|
||||
|
||||
func (rs RedshiftDialect) createVersionTableSQL() string {
|
||||
return fmt.Sprintf(`CREATE TABLE %s (
|
||||
id integer NOT NULL identity(1, 1),
|
||||
version_id bigint NOT NULL,
|
||||
is_applied boolean NOT NULL,
|
||||
tstamp timestamp NULL default sysdate,
|
||||
PRIMARY KEY(id)
|
||||
);`, TableName())
|
||||
}
|
||||
|
||||
func (rs RedshiftDialect) insertVersionSQL() string {
|
||||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName())
|
||||
}
|
||||
|
||||
func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
|
||||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (m RedshiftDialect) migrationSQL() string {
|
||||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
|
||||
}
|
||||
|
||||
func (rs RedshiftDialect) deleteVersionSQL() string {
|
||||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
|
||||
}
|
||||
|
||||
////////////////////////////
|
||||
// TiDB
|
||||
////////////////////////////
|
||||
|
||||
// TiDBDialect struct.
|
||||
type TiDBDialect struct{}
|
||||
|
||||
func (m TiDBDialect) createVersionTableSQL() string {
|
||||
return fmt.Sprintf(`CREATE TABLE %s (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE,
|
||||
version_id bigint NOT NULL,
|
||||
is_applied boolean NOT NULL,
|
||||
tstamp timestamp NULL default now(),
|
||||
PRIMARY KEY(id)
|
||||
);`, TableName())
|
||||
}
|
||||
|
||||
func (m TiDBDialect) insertVersionSQL() string {
|
||||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
|
||||
}
|
||||
|
||||
func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
|
||||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (m TiDBDialect) migrationSQL() string {
|
||||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
|
||||
}
|
||||
|
||||
func (m TiDBDialect) deleteVersionSQL() string {
|
||||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
|
||||
}
|
||||
|
|
@ -1,492 +0,0 @@
|
|||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const VERSION = "v2.7.0-rc3"
|
||||
|
||||
var (
|
||||
duplicateCheckOnce sync.Once
|
||||
minVersion = int64(0)
|
||||
maxVersion = int64((1 << 63) - 1)
|
||||
timestampFormat = "20060102150405"
|
||||
verbose = false
|
||||
)
|
||||
|
||||
// SetVerbose set the goose verbosity mode
|
||||
func SetVerbose(v bool) {
|
||||
verbose = v
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrNoCurrentVersion when a current migration version is not found.
|
||||
ErrNoCurrentVersion = errors.New("no current version found")
|
||||
// ErrNoNextVersion when the next migration version is not found.
|
||||
ErrNoNextVersion = errors.New("no next version found")
|
||||
// MaxVersion is the maximum allowed version.
|
||||
MaxVersion int64 = 9223372036854775807 // max(int64)
|
||||
|
||||
registeredGoMigrations = map[int64]*Migration{}
|
||||
)
|
||||
|
||||
// MigrationRecord struct.
|
||||
type MigrationRecord struct {
|
||||
VersionID int64
|
||||
TStamp time.Time
|
||||
IsApplied bool // was this a result of up() or down()
|
||||
}
|
||||
|
||||
// Migration struct.
|
||||
type Migration struct {
|
||||
Version int64
|
||||
Next int64 // next version, or -1 if none
|
||||
Previous int64 // previous version, -1 if none
|
||||
Source string // path to .sql script
|
||||
Registered bool
|
||||
UpFn func(*sql.Tx) error // Up go migration function
|
||||
DownFn func(*sql.Tx) error // Down go migration function
|
||||
}
|
||||
|
||||
func (m *Migration) String() string {
|
||||
return fmt.Sprintf(m.Source)
|
||||
}
|
||||
|
||||
// Up runs an up migration.
|
||||
func (m *Migration) Up(ctx context.Context, db *sql.DB) error {
|
||||
if err := m.run(ctx, db, true); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Down runs a down migration.
|
||||
func (m *Migration) Down(ctx context.Context, db *sql.DB) error {
|
||||
if err := m.run(ctx, db, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
|
||||
switch filepath.Ext(m.Source) {
|
||||
case ".sql":
|
||||
f, err := os.Open(m.Source)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "ERROR %v: failed to open SQL migration file", filepath.Base(m.Source))
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
statements, useTx, err := parseSQLMigration(f, direction)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "ERROR %v: failed to parse SQL migration file", filepath.Base(m.Source))
|
||||
}
|
||||
|
||||
if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction); err != nil {
|
||||
return errors.Wrapf(err, "ERROR %v: failed to run SQL migration", filepath.Base(m.Source))
|
||||
}
|
||||
|
||||
if len(statements) > 0 {
|
||||
log.Println("OK ", filepath.Base(m.Source))
|
||||
} else {
|
||||
log.Println("EMPTY", filepath.Base(m.Source))
|
||||
}
|
||||
|
||||
case ".go":
|
||||
if !m.Registered {
|
||||
return errors.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/c9s/goose/tree/master/examples/go-migrations)", m.Source)
|
||||
}
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "ERROR failed to begin transaction")
|
||||
}
|
||||
|
||||
fn := m.UpFn
|
||||
if !direction {
|
||||
fn = m.DownFn
|
||||
}
|
||||
|
||||
if fn != nil {
|
||||
// Run Go migration function.
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return errors.Wrapf(err, "ERROR %v: failed to run Go migration function %T", filepath.Base(m.Source), fn)
|
||||
}
|
||||
}
|
||||
|
||||
if direction {
|
||||
if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
|
||||
tx.Rollback()
|
||||
return errors.Wrap(err, "ERROR failed to execute transaction")
|
||||
}
|
||||
} else {
|
||||
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil {
|
||||
tx.Rollback()
|
||||
return errors.Wrap(err, "ERROR failed to execute transaction")
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return errors.Wrap(err, "ERROR failed to commit transaction")
|
||||
}
|
||||
|
||||
if fn != nil {
|
||||
log.Println("OK ", filepath.Base(m.Source))
|
||||
} else {
|
||||
log.Println("EMPTY", filepath.Base(m.Source))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtractNumericComponent looks for migration scripts with names in the form:
|
||||
// XXX_descriptivename.ext where XXX specifies the version number
|
||||
// and ext specifies the type of migration
|
||||
func ExtractNumericComponent(name string) (int64, error) {
|
||||
base := filepath.Base(name)
|
||||
|
||||
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
|
||||
return 0, errors.New("not a recognized migration file type")
|
||||
}
|
||||
|
||||
idx := strings.Index(base, "_")
|
||||
if idx < 0 {
|
||||
return 0, errors.New("no separator found")
|
||||
}
|
||||
|
||||
n, e := strconv.ParseInt(base[:idx], 10, 64)
|
||||
if e == nil && n <= 0 {
|
||||
return 0, errors.New("migration IDs must be greater than zero")
|
||||
}
|
||||
|
||||
return n, e
|
||||
}
|
||||
|
||||
// Migrations slice.
|
||||
type Migrations []*Migration
|
||||
|
||||
// helpers so we can use pkg sort
|
||||
func (ms Migrations) Len() int { return len(ms) }
|
||||
func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
|
||||
func (ms Migrations) Less(i, j int) bool {
|
||||
if ms[i].Version == ms[j].Version {
|
||||
panic(fmt.Sprintf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source))
|
||||
}
|
||||
return ms[i].Version < ms[j].Version
|
||||
}
|
||||
|
||||
// Current gets the current migration.
|
||||
func (ms Migrations) Current(current int64) (*Migration, error) {
|
||||
for i, migration := range ms {
|
||||
if migration.Version == current {
|
||||
return ms[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrNoCurrentVersion
|
||||
}
|
||||
|
||||
// Next gets the next migration.
|
||||
func (ms Migrations) Next(current int64) (*Migration, error) {
|
||||
for i, migration := range ms {
|
||||
if migration.Version > current {
|
||||
return ms[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrNoNextVersion
|
||||
}
|
||||
|
||||
// Previous : Get the previous migration.
|
||||
func (ms Migrations) Previous(current int64) (*Migration, error) {
|
||||
for i := len(ms) - 1; i >= 0; i-- {
|
||||
if ms[i].Version < current {
|
||||
return ms[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrNoNextVersion
|
||||
}
|
||||
|
||||
// Last gets the last migration.
|
||||
func (ms Migrations) Last() (*Migration, error) {
|
||||
if len(ms) == 0 {
|
||||
return nil, ErrNoNextVersion
|
||||
}
|
||||
|
||||
return ms[len(ms)-1], nil
|
||||
}
|
||||
|
||||
// Versioned gets versioned migrations.
|
||||
func (ms Migrations) versioned() (Migrations, error) {
|
||||
var migrations Migrations
|
||||
|
||||
// assume that the user will never have more than 19700101000000 migrations
|
||||
for _, m := range ms {
|
||||
// parse version as timestmap
|
||||
versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
|
||||
|
||||
if versionTime.Before(time.Unix(0, 0)) || err != nil {
|
||||
migrations = append(migrations, m)
|
||||
}
|
||||
}
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// Timestamped gets the timestamped migrations.
|
||||
func (ms Migrations) timestamped() (Migrations, error) {
|
||||
var migrations Migrations
|
||||
|
||||
// assume that the user will never have more than 19700101000000 migrations
|
||||
for _, m := range ms {
|
||||
// parse version as timestmap
|
||||
versionTime, err := time.Parse(timestampFormat, fmt.Sprintf("%d", m.Version))
|
||||
if err != nil {
|
||||
// probably not a timestamp
|
||||
continue
|
||||
}
|
||||
|
||||
if versionTime.After(time.Unix(0, 0)) {
|
||||
migrations = append(migrations, m)
|
||||
}
|
||||
}
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
func (ms Migrations) String() string {
|
||||
str := ""
|
||||
for _, m := range ms {
|
||||
str += fmt.Sprintln(m)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// AddMigration adds a migration.
|
||||
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
AddNamedMigration(filename, up, down)
|
||||
}
|
||||
|
||||
// AddNamedMigration : Add a named migration.
|
||||
func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
||||
v, _ := ExtractNumericComponent(filename)
|
||||
migration := &Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
|
||||
|
||||
if existing, ok := registeredGoMigrations[v]; ok {
|
||||
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
|
||||
}
|
||||
|
||||
registeredGoMigrations[v] = migration
|
||||
}
|
||||
|
||||
// CollectMigrationsFromDir returns all the valid looking migration scripts in the
|
||||
// migrations folder and go func registry, and key them by version.
|
||||
func CollectMigrationsFromDir(dirpath string, current, target int64) (Migrations, error) {
|
||||
if _, err := os.Stat(dirpath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("%s directory does not exists", dirpath)
|
||||
}
|
||||
|
||||
var migrations Migrations
|
||||
|
||||
// SQL migration files.
|
||||
sqlMigrationFiles, err := filepath.Glob(dirpath + "/**.sql")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, file := range sqlMigrationFiles {
|
||||
v, err := ExtractNumericComponent(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if versionFilter(v, current, target) {
|
||||
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file}
|
||||
migrations = append(migrations, migration)
|
||||
}
|
||||
}
|
||||
|
||||
// Go migrations registered via goose.AddMigration().
|
||||
for _, migration := range registeredGoMigrations {
|
||||
v, err := ExtractNumericComponent(migration.Source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if versionFilter(v, current, target) {
|
||||
migrations = append(migrations, migration)
|
||||
}
|
||||
}
|
||||
|
||||
// Go migration files
|
||||
goMigrationFiles, err := filepath.Glob(dirpath + "/**.go")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, file := range goMigrationFiles {
|
||||
v, err := ExtractNumericComponent(file)
|
||||
if err != nil {
|
||||
continue // Skip any files that don't have version prefix.
|
||||
}
|
||||
|
||||
// Skip migrations already existing migrations registered via goose.AddMigration().
|
||||
if _, ok := registeredGoMigrations[v]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if versionFilter(v, current, target) {
|
||||
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file, Registered: false}
|
||||
migrations = append(migrations, migration)
|
||||
}
|
||||
}
|
||||
|
||||
migrations = sortAndConnectMigrations(migrations)
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
func sortAndConnectMigrations(migrations Migrations) Migrations {
|
||||
sort.Sort(migrations)
|
||||
|
||||
// now that we're sorted in the appropriate direction,
|
||||
// populate next and previous for each migration
|
||||
for i, m := range migrations {
|
||||
prev := int64(-1)
|
||||
if i > 0 {
|
||||
prev = migrations[i-1].Version
|
||||
migrations[i-1].Next = m.Version
|
||||
}
|
||||
migrations[i].Previous = prev
|
||||
}
|
||||
|
||||
return migrations
|
||||
}
|
||||
|
||||
func versionFilter(v, current, target int64) bool {
|
||||
|
||||
if target > current {
|
||||
return v > current && v <= target
|
||||
}
|
||||
|
||||
if target < current {
|
||||
return v <= current && v > target
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
// Create the db version table
|
||||
// and insert the initial 0 value into it
|
||||
func createVersionTable(db *sql.DB) error {
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d := GetDialect()
|
||||
|
||||
if _, err := txn.Exec(d.createVersionTableSQL()); err != nil {
|
||||
txn.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
version := 0
|
||||
applied := true
|
||||
if _, err := txn.Exec(d.insertVersionSQL(), version, applied); err != nil {
|
||||
txn.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return txn.Commit()
|
||||
}
|
||||
|
||||
|
||||
// Run a migration specified in raw SQL.
|
||||
//
|
||||
// Sections of the script can be annotated with a special comment,
|
||||
// starting with "-- +goose" to specify whether the section should
|
||||
// be applied during an Up or Down migration
|
||||
//
|
||||
// All statements following an Up or Down directive are grouped together
|
||||
// until another direction directive is found.
|
||||
func runSQLMigration(ctx context.Context, db *sql.DB, statements []string, useTx bool, v int64, direction bool) error {
|
||||
if useTx {
|
||||
// TRANSACTION.
|
||||
|
||||
verboseInfo("Begin transaction")
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to begin transaction")
|
||||
}
|
||||
|
||||
for _, query := range statements {
|
||||
verboseInfo("Executing statement: %s\n", clearStatement(query))
|
||||
if _, err = tx.Exec(query); err != nil {
|
||||
verboseInfo("Rollback transaction")
|
||||
tx.Rollback()
|
||||
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
|
||||
}
|
||||
}
|
||||
|
||||
if direction {
|
||||
if _, err := tx.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
|
||||
verboseInfo("Rollback transaction")
|
||||
tx.Rollback()
|
||||
return errors.Wrap(err, "failed to insert new goose version")
|
||||
}
|
||||
} else {
|
||||
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), v); err != nil {
|
||||
verboseInfo("Rollback transaction")
|
||||
tx.Rollback()
|
||||
return errors.Wrap(err, "failed to delete goose version")
|
||||
}
|
||||
}
|
||||
|
||||
verboseInfo("Commit transaction")
|
||||
if err := tx.Commit(); err != nil {
|
||||
return errors.Wrap(err, "failed to commit transaction")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NO TRANSACTION.
|
||||
for _, query := range statements {
|
||||
verboseInfo("Executing statement: %s", clearStatement(query))
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
|
||||
return errors.Wrap(err, "failed to insert new goose version")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
grayColor = "\033[90m"
|
||||
resetColor = "\033[00m"
|
||||
)
|
||||
|
||||
func verboseInfo(s string, args ...interface{}) {
|
||||
log.Printf(grayColor+s+resetColor, args...)
|
||||
}
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
package migration
|
||||
|
||||
//go:generate gopackmigration -dir ../../migrations
|
||||
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
package migration
|
||||
|
||||
import "regexp"
|
||||
|
||||
var (
|
||||
matchSQLComments = regexp.MustCompile(`(?m)^--.*$[\r\n]*`)
|
||||
matchEmptyEOL = regexp.MustCompile(`(?m)^$[\r\n]*`) // TODO: Duplicate
|
||||
)
|
||||
|
||||
func clearStatement(s string) string {
|
||||
s = matchSQLComments.ReplaceAllString(s, ``)
|
||||
return matchEmptyEOL.ReplaceAllString(s, ``)
|
||||
}
|
|
@ -1,212 +0,0 @@
|
|||
package migration
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type parserState int
|
||||
|
||||
const (
|
||||
start parserState = iota // 0
|
||||
upStatement // 1
|
||||
upStatementBegin // 2
|
||||
upStatementEnd // 3
|
||||
downStatement // 4
|
||||
downStatementBegin // 5
|
||||
downStatementEnd // 6
|
||||
)
|
||||
|
||||
type stateMachine parserState
|
||||
|
||||
func (s *stateMachine) Get() parserState {
|
||||
return parserState(*s)
|
||||
}
|
||||
func (s *stateMachine) Set(new parserState) {
|
||||
*s = stateMachine(new)
|
||||
}
|
||||
|
||||
const scanBufSize = 4 * 1024 * 1024
|
||||
|
||||
var matchEmptyLines = regexp.MustCompile(`^\s*$`)
|
||||
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, scanBufSize)
|
||||
},
|
||||
}
|
||||
|
||||
// Split given SQL script into individual statements and return
|
||||
// SQL statements for given direction (up=true, down=false).
|
||||
//
|
||||
// The base case is to simply split on semicolons, as these
|
||||
// naturally terminate a statement.
|
||||
//
|
||||
// However, more complex cases like pl/pgsql can have semicolons
|
||||
// within a statement. For these cases, we provide the explicit annotations
|
||||
// 'StatementBegin' and 'StatementEnd' to allow the script to
|
||||
// tell us to ignore semicolons.
|
||||
func parseSQLMigration(r io.Reader, direction bool) (stmts []string, useTx bool, err error) {
|
||||
var buf bytes.Buffer
|
||||
scanBuf := bufferPool.Get().([]byte)
|
||||
defer bufferPool.Put(scanBuf)
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Buffer(scanBuf, scanBufSize)
|
||||
|
||||
stateMachine := stateMachine(start)
|
||||
useTx = true
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "--") {
|
||||
cmd := strings.TrimSpace(strings.TrimPrefix(line, "--"))
|
||||
|
||||
switch cmd {
|
||||
case "+goose Up":
|
||||
switch stateMachine.Get() {
|
||||
case start:
|
||||
stateMachine.Set(upStatement)
|
||||
default:
|
||||
return nil, false, errors.Errorf("duplicate '-- +goose Up' annotations; stateMachine=%v, see https://github.com/c9s/goose#sql-migrations", stateMachine)
|
||||
}
|
||||
continue
|
||||
|
||||
case "+goose Down":
|
||||
switch stateMachine.Get() {
|
||||
case upStatement, upStatementEnd:
|
||||
stateMachine.Set(downStatement)
|
||||
default:
|
||||
return nil, false, errors.Errorf("must start with '-- +goose Up' annotation, stateMachine=%v, see https://github.com/c9s/goose#sql-migrations", stateMachine)
|
||||
}
|
||||
continue
|
||||
|
||||
case "+goose StatementBegin":
|
||||
switch stateMachine.Get() {
|
||||
case upStatement, upStatementEnd:
|
||||
stateMachine.Set(upStatementBegin)
|
||||
case downStatement, downStatementEnd:
|
||||
stateMachine.Set(downStatementBegin)
|
||||
default:
|
||||
return nil, false, errors.Errorf("'-- +goose StatementBegin' must be defined after '-- +goose Up' or '-- +goose Down' annotation, stateMachine=%v, see https://github.com/c9s/goose#sql-migrations", stateMachine)
|
||||
}
|
||||
continue
|
||||
|
||||
case "+goose StatementEnd":
|
||||
switch stateMachine.Get() {
|
||||
case upStatementBegin:
|
||||
stateMachine.Set(upStatementEnd)
|
||||
case downStatementBegin:
|
||||
stateMachine.Set(downStatementEnd)
|
||||
default:
|
||||
return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/c9s/goose#sql-migrations")
|
||||
}
|
||||
|
||||
case "+goose NO TRANSACTION":
|
||||
useTx = false
|
||||
continue
|
||||
|
||||
default:
|
||||
// Ignore comments.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Ignore empty lines.
|
||||
if matchEmptyLines.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Write SQL line to a buffer.
|
||||
if _, err := buf.WriteString(line + "\n"); err != nil {
|
||||
return nil, false, errors.Wrap(err, "failed to write to buf")
|
||||
}
|
||||
|
||||
// Read SQL body one by line, if we're in the right direction.
|
||||
//
|
||||
// 1) basic query with semicolon; 2) psql statement
|
||||
//
|
||||
// Export statement once we hit end of statement.
|
||||
switch stateMachine.Get() {
|
||||
case upStatement, upStatementBegin, upStatementEnd:
|
||||
if !direction /*down*/ {
|
||||
buf.Reset()
|
||||
continue
|
||||
}
|
||||
case downStatement, downStatementBegin, downStatementEnd:
|
||||
if direction /*up*/ {
|
||||
buf.Reset()
|
||||
continue
|
||||
}
|
||||
default:
|
||||
return nil, false, errors.Errorf("failed to parse migration: unexpected state %q on line %q, see https://github.com/c9s/goose#sql-migrations", stateMachine, line)
|
||||
}
|
||||
|
||||
switch stateMachine.Get() {
|
||||
case upStatement:
|
||||
if endsWithSemicolon(line) {
|
||||
stmts = append(stmts, buf.String())
|
||||
buf.Reset()
|
||||
}
|
||||
case downStatement:
|
||||
if endsWithSemicolon(line) {
|
||||
stmts = append(stmts, buf.String())
|
||||
buf.Reset()
|
||||
}
|
||||
case upStatementEnd:
|
||||
stmts = append(stmts, buf.String())
|
||||
buf.Reset()
|
||||
stateMachine.Set(upStatement)
|
||||
case downStatementEnd:
|
||||
stmts = append(stmts, buf.String())
|
||||
buf.Reset()
|
||||
stateMachine.Set(downStatement)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, false, errors.Wrap(err, "failed to scan migration")
|
||||
}
|
||||
// EOF
|
||||
|
||||
switch stateMachine.Get() {
|
||||
case start:
|
||||
return nil, false, errors.New("failed to parse migration: must start with '-- +goose Up' annotation, see https://github.com/c9s/goose#sql-migrations")
|
||||
case upStatementBegin, downStatementBegin:
|
||||
return nil, false, errors.New("failed to parse migration: missing '-- +goose StatementEnd' annotation")
|
||||
}
|
||||
|
||||
if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
|
||||
return nil, false, errors.Errorf("failed to parse migration: state %q, direction: %v: unexpected unfinished SQL query: %q: missing semicolon?", stateMachine, direction, bufferRemaining)
|
||||
}
|
||||
|
||||
return stmts, useTx, nil
|
||||
}
|
||||
|
||||
// Checks the line to see if the line has a statement-ending semicolon
|
||||
// or if the line contains a double-dash comment.
|
||||
func endsWithSemicolon(line string) bool {
|
||||
scanBuf := bufferPool.Get().([]byte)
|
||||
defer bufferPool.Put(scanBuf)
|
||||
|
||||
prev := ""
|
||||
scanner := bufio.NewScanner(strings.NewReader(line))
|
||||
scanner.Buffer(scanBuf, scanBufSize)
|
||||
scanner.Split(bufio.ScanWords)
|
||||
|
||||
for scanner.Scan() {
|
||||
word := scanner.Text()
|
||||
if strings.HasPrefix(word, "--") {
|
||||
break
|
||||
}
|
||||
prev = word
|
||||
}
|
||||
|
||||
return strings.HasSuffix(prev, ";")
|
||||
}
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log"
|
||||
)
|
||||
|
||||
// UpTo migrates up to a specific version.
|
||||
func UpTo(ctx context.Context, db *sql.DB, dir string, version int64) error {
|
||||
migrations, err := CollectMigrationsFromDir(dir, minVersion, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
current, err := GetDBVersion(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
next, err := migrations.Next(current)
|
||||
if err != nil {
|
||||
if err == ErrNoNextVersion {
|
||||
log.Printf("no migrations to run. current version: %d\n", current)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err = next.Up(ctx, db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Up applies all available migrations.
|
||||
func Up(ctx context.Context, db *sql.DB, dir string) error {
|
||||
return UpTo(ctx, db, dir, maxVersion)
|
||||
}
|
||||
|
||||
// UpByOne migrates up by a single version.
|
||||
func UpByOne(ctx context.Context, db *sql.DB, dir string) error {
|
||||
migrations, err := CollectMigrationsFromDir(dir, minVersion, maxVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
currentVersion, err := GetDBVersion(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
next, err := migrations.Next(currentVersion)
|
||||
if err != nil {
|
||||
if err == ErrNoNextVersion {
|
||||
log.Printf("no migrations to run. current version: %d\n", currentVersion)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err = next.Up(ctx, db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,91 +0,0 @@
|
|||
package migration
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"log"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Version prints the current version of the database.
|
||||
func Version(db *sql.DB, dir string) error {
|
||||
current, err := GetDBVersion(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("cmd: version %v\n", current)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error.
|
||||
func GetDBVersion(db *sql.DB) (int64, error) {
|
||||
version, err := EnsureDBVersion(db)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// EnsureDBVersion retrieves the current version for this DB.
|
||||
// Create and initialize the DB version table if it doesn't exist.
|
||||
func EnsureDBVersion(db *sql.DB) (int64, error) {
|
||||
rows, err := GetDialect().dbVersionQuery(db)
|
||||
if err != nil {
|
||||
return 0, createVersionTable(db)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// The most recent record for each migration specifies
|
||||
// whether it has been applied or rolled back.
|
||||
// The first version we find that has been applied is the current version.
|
||||
|
||||
toSkip := make([]int64, 0)
|
||||
|
||||
for rows.Next() {
|
||||
var row MigrationRecord
|
||||
if err = rows.Scan(&row.VersionID, &row.IsApplied); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to scan row")
|
||||
}
|
||||
|
||||
// have we already marked this version to be skipped?
|
||||
skip := false
|
||||
for _, v := range toSkip {
|
||||
if v == row.VersionID {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
|
||||
// if version has been applied we're done
|
||||
if row.IsApplied {
|
||||
return row.VersionID, nil
|
||||
}
|
||||
|
||||
// latest version of migration has not been applied.
|
||||
toSkip = append(toSkip, row.VersionID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get next row")
|
||||
}
|
||||
|
||||
return 0, ErrNoNextVersion
|
||||
}
|
||||
|
||||
|
||||
var tableName = "goose_db_version"
|
||||
|
||||
// TableName returns goose db version table name
|
||||
func TableName() string {
|
||||
return tableName
|
||||
}
|
||||
|
||||
// SetTableName set goose db version table name
|
||||
func SetTableName(n string) {
|
||||
tableName = n
|
||||
}
|
Loading…
Reference in New Issue
Block a user