improve bbgo db migration process

This commit is contained in:
c9s 2024-01-19 00:34:02 +08:00
parent 3e233627be
commit 611b2a9247
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
3 changed files with 72 additions and 23 deletions

View File

@ -22,7 +22,9 @@ type BacktestService struct {
DB *sqlx.DB DB *sqlx.DB
} }
func (s *BacktestService) SyncKLineByInterval(ctx context.Context, exchange types.Exchange, symbol string, interval types.Interval, startTime, endTime time.Time) error { func (s *BacktestService) SyncKLineByInterval(
ctx context.Context, exchange types.Exchange, symbol string, interval types.Interval, startTime, endTime time.Time,
) error {
_, isFutures, isIsolated, isolatedSymbol := exchange2.GetSessionAttributes(exchange) _, isFutures, isIsolated, isolatedSymbol := exchange2.GetSessionAttributes(exchange)
// override symbol if isolatedSymbol is not empty // override symbol if isolatedSymbol is not empty
@ -127,7 +129,9 @@ func (s *BacktestService) Verify(sourceExchange types.Exchange, symbols []string
return nil return nil
} }
func (s *BacktestService) SyncFresh(ctx context.Context, exchange types.Exchange, symbol string, interval types.Interval, startTime, endTime time.Time) error { func (s *BacktestService) SyncFresh(
ctx context.Context, exchange types.Exchange, symbol string, interval types.Interval, startTime, endTime time.Time,
) error {
log.Infof("starting fresh sync %s %s %s: %s <=> %s", exchange.Name(), symbol, interval, startTime, endTime) log.Infof("starting fresh sync %s %s %s: %s <=> %s", exchange.Name(), symbol, interval, startTime, endTime)
startTime = startTime.Truncate(time.Minute).Add(-2 * time.Second) startTime = startTime.Truncate(time.Minute).Add(-2 * time.Second)
endTime = endTime.Truncate(time.Minute).Add(2 * time.Second) endTime = endTime.Truncate(time.Minute).Add(2 * time.Second)
@ -135,7 +139,9 @@ func (s *BacktestService) SyncFresh(ctx context.Context, exchange types.Exchange
} }
// QueryKLine queries the klines from the database // QueryKLine queries the klines from the database
func (s *BacktestService) QueryKLine(ex types.Exchange, symbol string, interval types.Interval, orderBy string, limit int) (*types.KLine, error) { func (s *BacktestService) QueryKLine(
ex types.Exchange, symbol string, interval types.Interval, orderBy string, limit int,
) (*types.KLine, error) {
log.Infof("querying last kline exchange = %s AND symbol = %s AND interval = %s", ex, symbol, interval) log.Infof("querying last kline exchange = %s AND symbol = %s AND interval = %s", ex, symbol, interval)
tableName := targetKlineTable(ex) tableName := targetKlineTable(ex)
@ -166,7 +172,9 @@ func (s *BacktestService) QueryKLine(ex types.Exchange, symbol string, interval
} }
// QueryKLinesForward is used for querying klines to back-testing // QueryKLinesForward is used for querying klines to back-testing
func (s *BacktestService) QueryKLinesForward(exchange types.Exchange, symbol string, interval types.Interval, startTime time.Time, limit int) ([]types.KLine, error) { func (s *BacktestService) QueryKLinesForward(
exchange types.Exchange, symbol string, interval types.Interval, startTime time.Time, limit int,
) ([]types.KLine, error) {
tableName := targetKlineTable(exchange) tableName := targetKlineTable(exchange)
sql := "SELECT * FROM `binance_klines` WHERE `end_time` >= :start_time AND `symbol` = :symbol AND `interval` = :interval and exchange = :exchange ORDER BY end_time ASC LIMIT :limit" sql := "SELECT * FROM `binance_klines` WHERE `end_time` >= :start_time AND `symbol` = :symbol AND `interval` = :interval and exchange = :exchange ORDER BY end_time ASC LIMIT :limit"
sql = strings.ReplaceAll(sql, "binance_klines", tableName) sql = strings.ReplaceAll(sql, "binance_klines", tableName)
@ -185,7 +193,9 @@ func (s *BacktestService) QueryKLinesForward(exchange types.Exchange, symbol str
return s.scanRows(rows) return s.scanRows(rows)
} }
func (s *BacktestService) QueryKLinesBackward(exchange types.Exchange, symbol string, interval types.Interval, endTime time.Time, limit int) ([]types.KLine, error) { func (s *BacktestService) QueryKLinesBackward(
exchange types.Exchange, symbol string, interval types.Interval, endTime time.Time, limit int,
) ([]types.KLine, error) {
tableName := targetKlineTable(exchange) tableName := targetKlineTable(exchange)
sql := "SELECT * FROM `binance_klines` WHERE `end_time` <= :end_time and exchange = :exchange AND `symbol` = :symbol AND `interval` = :interval ORDER BY end_time DESC LIMIT :limit" sql := "SELECT * FROM `binance_klines` WHERE `end_time` <= :end_time and exchange = :exchange AND `symbol` = :symbol AND `interval` = :interval ORDER BY end_time DESC LIMIT :limit"
@ -206,7 +216,9 @@ func (s *BacktestService) QueryKLinesBackward(exchange types.Exchange, symbol st
return s.scanRows(rows) return s.scanRows(rows)
} }
func (s *BacktestService) QueryKLinesCh(since, until time.Time, exchange types.Exchange, symbols []string, intervals []types.Interval) (chan types.KLine, chan error) { func (s *BacktestService) QueryKLinesCh(
since, until time.Time, exchange types.Exchange, symbols []string, intervals []types.Interval,
) (chan types.KLine, chan error) {
if len(symbols) == 0 { if len(symbols) == 0 {
return returnError(errors.Errorf("symbols is empty when querying kline, plesae check your strategy setting. ")) return returnError(errors.Errorf("symbols is empty when querying kline, plesae check your strategy setting. "))
} }
@ -361,7 +373,9 @@ func (t *TimeRange) String() string {
return t.Start.String() + " ~ " + t.End.String() return t.Start.String() + " ~ " + t.End.String()
} }
func (s *BacktestService) Sync(ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, since, until time.Time) error { func (s *BacktestService) Sync(
ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, since, until time.Time,
) error {
t1, t2, err := s.QueryExistingDataRange(ctx, ex, symbol, interval, since, until) t1, t2, err := s.QueryExistingDataRange(ctx, ex, symbol, interval, since, until)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
@ -380,7 +394,9 @@ func (s *BacktestService) Sync(ctx context.Context, ex types.Exchange, symbol st
// scan if there is a missing part // scan if there is a missing part
// create a time range slice []TimeRange // create a time range slice []TimeRange
// iterate the []TimeRange slice to sync data. // iterate the []TimeRange slice to sync data.
func (s *BacktestService) SyncPartial(ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, since, until time.Time) error { func (s *BacktestService) SyncPartial(
ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, since, until time.Time,
) error {
log.Infof("starting partial sync %s %s %s: %s <=> %s", ex.Name(), symbol, interval, since, until) log.Infof("starting partial sync %s %s %s: %s <=> %s", ex.Name(), symbol, interval, since, until)
t1, t2, err := s.QueryExistingDataRange(ctx, ex, symbol, interval, since, until) t1, t2, err := s.QueryExistingDataRange(ctx, ex, symbol, interval, since, until)
@ -431,7 +447,9 @@ func (s *BacktestService) SyncPartial(ctx context.Context, ex types.Exchange, sy
// FindMissingTimeRanges returns the missing time ranges, the start/end time represents the existing data time points. // FindMissingTimeRanges returns the missing time ranges, the start/end time represents the existing data time points.
// So when sending kline query to the exchange API, we need to add one second to the start time and minus one second to the end time. // So when sending kline query to the exchange API, we need to add one second to the start time and minus one second to the end time.
func (s *BacktestService) FindMissingTimeRanges(ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, since, until time.Time) ([]TimeRange, error) { func (s *BacktestService) FindMissingTimeRanges(
ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, since, until time.Time,
) ([]TimeRange, error) {
query := s.SelectKLineTimePoints(ex, symbol, interval, since, until) query := s.SelectKLineTimePoints(ex, symbol, interval, since, until)
sql, args, err := query.ToSql() sql, args, err := query.ToSql()
if err != nil { if err != nil {
@ -474,7 +492,9 @@ func (s *BacktestService) FindMissingTimeRanges(ctx context.Context, ex types.Ex
return timeRanges, nil return timeRanges, nil
} }
func (s *BacktestService) QueryExistingDataRange(ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, tArgs ...time.Time) (start, end *types.Time, err error) { func (s *BacktestService) QueryExistingDataRange(
ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, tArgs ...time.Time,
) (start, end *types.Time, err error) {
sel := s.SelectKLineTimeRange(ex, symbol, interval, tArgs...) sel := s.SelectKLineTimeRange(ex, symbol, interval, tArgs...)
sql, args, err := sel.ToSql() sql, args, err := sel.ToSql()
if err != nil { if err != nil {
@ -493,14 +513,16 @@ func (s *BacktestService) QueryExistingDataRange(ctx context.Context, ex types.E
return nil, nil, err return nil, nil, err
} }
if t1 == (types.Time{}) || t2 == (types.Time{}) { if t1.Time().IsZero() || t2.Time().IsZero() {
return nil, nil, nil return nil, nil, nil
} }
return &t1, &t2, nil return &t1, &t2, nil
} }
func (s *BacktestService) SelectKLineTimePoints(ex types.Exchange, symbol string, interval types.Interval, args ...time.Time) sq.SelectBuilder { func (s *BacktestService) SelectKLineTimePoints(
ex types.Exchange, symbol string, interval types.Interval, args ...time.Time,
) sq.SelectBuilder {
conditions := sq.And{ conditions := sq.And{
sq.Eq{"symbol": symbol}, sq.Eq{"symbol": symbol},
sq.Eq{"`interval`": interval.String()}, sq.Eq{"`interval`": interval.String()},
@ -521,7 +543,9 @@ func (s *BacktestService) SelectKLineTimePoints(ex types.Exchange, symbol string
} }
// SelectKLineTimeRange returns the existing klines time range (since < kline.start_time < until) // SelectKLineTimeRange returns the existing klines time range (since < kline.start_time < until)
func (s *BacktestService) SelectKLineTimeRange(ex types.Exchange, symbol string, interval types.Interval, args ...time.Time) sq.SelectBuilder { func (s *BacktestService) SelectKLineTimeRange(
ex types.Exchange, symbol string, interval types.Interval, args ...time.Time,
) sq.SelectBuilder {
conditions := sq.And{ conditions := sq.And{
sq.Eq{"symbol": symbol}, sq.Eq{"symbol": symbol},
sq.Eq{"`interval`": interval.String()}, sq.Eq{"`interval`": interval.String()},
@ -544,7 +568,9 @@ func (s *BacktestService) SelectKLineTimeRange(ex types.Exchange, symbol string,
} }
// TODO: add is_futures column since the klines data is different // TODO: add is_futures column since the klines data is different
func (s *BacktestService) SelectLastKLines(ex types.Exchange, symbol string, interval types.Interval, startTime, endTime time.Time, limit uint64) sq.SelectBuilder { func (s *BacktestService) SelectLastKLines(
ex types.Exchange, symbol string, interval types.Interval, startTime, endTime time.Time, limit uint64,
) sq.SelectBuilder {
tableName := targetKlineTable(ex) tableName := targetKlineTable(ex)
return sq.Select("*"). return sq.Select("*").
From(tableName). From(tableName).

View File

@ -3,10 +3,11 @@ package service
import ( import (
"context" "context"
"github.com/c9s/rockhopper/v2"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/c9s/rockhopper/v2"
mysqlMigrations "github.com/c9s/bbgo/pkg/migrations/mysql" mysqlMigrations "github.com/c9s/bbgo/pkg/migrations/mysql"
sqlite3Migrations "github.com/c9s/bbgo/pkg/migrations/sqlite3" sqlite3Migrations "github.com/c9s/bbgo/pkg/migrations/sqlite3"
) )
@ -72,9 +73,27 @@ func (s *DatabaseService) Upgrade(ctx context.Context) error {
// sqlx.DB is different from sql.DB // sqlx.DB is different from sql.DB
rh := rockhopper.New(s.Driver, dialect, s.DB.DB, rockhopper.TableName) rh := rockhopper.New(s.Driver, dialect, s.DB.DB, rockhopper.TableName)
migrations = migrations.FilterPackage([]string{"main"}).SortAndConnect() if err := rh.Touch(ctx); err != nil {
return err
}
return rockhopper.Align(ctx, rh, 20231123125402, migrations) migrations = migrations.FilterPackage([]string{"main"}).SortAndConnect()
if len(migrations) == 0 {
return nil
}
_, lastAppliedMigration, err := rh.FindLastAppliedMigration(ctx, migrations)
if err != nil {
return err
}
if lastAppliedMigration != nil {
return rockhopper.Up(ctx, rh, lastAppliedMigration.Next, 0)
}
// TODO: use align in the next major version
// return rockhopper.Align(ctx, rh, 20231123125402, migrations)
return rockhopper.Up(ctx, rh, migrations.Head(), 0)
} }
func ReformatMysqlDSN(dsn string) (string, error) { func ReformatMysqlDSN(dsn string) (string, error) {

View File

@ -4,11 +4,14 @@ import (
"context" "context"
"testing" "testing"
"github.com/c9s/rockhopper/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/c9s/rockhopper/v2"
) )
func prepareDB(t *testing.T) (*rockhopper.DB, error) { func prepareDB(t *testing.T) (*rockhopper.DB, error) {
ctx := context.Background()
dialect, err := rockhopper.LoadDialect("sqlite3") dialect, err := rockhopper.LoadDialect("sqlite3")
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return nil, err return nil, err
@ -16,28 +19,29 @@ func prepareDB(t *testing.T) (*rockhopper.DB, error) {
assert.NotNil(t, dialect) assert.NotNil(t, dialect)
db, err := rockhopper.Open("sqlite3", dialect, ":memory:") db, err := rockhopper.Open("sqlite3", dialect, ":memory:", rockhopper.TableName)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return nil, err return nil, err
} }
assert.NotNil(t, db) assert.NotNil(t, db)
_, err = db.CurrentVersion() err = db.Touch(ctx)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return nil, err return nil, err
} }
var loader rockhopper.SqlMigrationLoader var loader = &rockhopper.SqlMigrationLoader{}
migrations, err := loader.Load("../../migrations/sqlite3") migrations, err := loader.Load("../../migrations/sqlite3")
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return nil, err return nil, err
} }
migrations = migrations.Sort().Connect()
assert.NotEmpty(t, migrations) assert.NotEmpty(t, migrations)
ctx := context.Background() err = rockhopper.Up(ctx, db, migrations.Head(), 0)
err = rockhopper.Up(ctx, db, migrations, 0, 0)
assert.NoError(t, err, "should migrate successfully") assert.NoError(t, err, "should migrate successfully")
return db, err return db, err