refactor sync command and add integration tests

This commit is contained in:
c9s 2022-06-05 01:01:59 +08:00
parent 425f8674d2
commit 39fcf1a51b
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
3 changed files with 65 additions and 58 deletions

View File

@ -194,27 +194,27 @@ var BacktestCmd = &cobra.Command{
sourceExchanges[exName] = publicExchange sourceExchanges[exName] = publicExchange
} }
if wantSync { var syncFromTime time.Time
var syncFromTime time.Time
// override the sync from time if the option is given // user can override the sync from time if the option is given
if len(syncFromDateStr) > 0 { if len(syncFromDateStr) > 0 {
syncFromTime, err = time.Parse(types.DateFormat, syncFromDateStr) syncFromTime, err = time.Parse(types.DateFormat, syncFromDateStr)
if err != nil { if err != nil {
return err return err
}
if syncFromTime.After(startTime) {
return fmt.Errorf("sync-from time %s can not be latter than the backtest start time %s", syncFromTime, startTime)
}
} else {
// we need at least 1 month backward data for EMA and last prices
syncFromTime = startTime.AddDate(0, -1, 0)
log.Infof("adjusted sync start time %s to %s for backward market data", startTime, syncFromTime)
} }
if syncFromTime.After(startTime) {
return fmt.Errorf("sync-from time %s can not be latter than the backtest start time %s", syncFromTime, startTime)
}
} else {
// we need at least 1 month backward data for EMA and last prices
syncFromTime = startTime.AddDate(0, -1, 0)
log.Infof("adjusted sync start time %s to %s for backward market data", startTime, syncFromTime)
}
if wantSync {
log.Infof("starting synchronization: %v", userConfig.Backtest.Symbols) log.Infof("starting synchronization: %v", userConfig.Backtest.Symbols)
if err := sync(ctx, userConfig, backtestService, sourceExchanges, syncFromTime); err != nil { if err := sync(ctx, userConfig, backtestService, sourceExchanges, syncFromTime, endTime); err != nil {
return err return err
} }
log.Info("synchronization done") log.Info("synchronization done")
@ -650,9 +650,8 @@ func toExchangeSources(sessions map[string]*bbgo.ExchangeSession, extraIntervals
return exchangeSources, nil return exchangeSources, nil
} }
func sync(ctx context.Context, userConfig *bbgo.Config, backtestService *service.BacktestService, sourceExchanges map[types.ExchangeName]types.Exchange, syncFromTime time.Time) error { func sync(ctx context.Context, userConfig *bbgo.Config, backtestService *service.BacktestService, sourceExchanges map[types.ExchangeName]types.Exchange, syncFrom, syncTo time.Time) error {
for _, symbol := range userConfig.Backtest.Symbols { for _, symbol := range userConfig.Backtest.Symbols {
for _, sourceExchange := range sourceExchanges { for _, sourceExchange := range sourceExchanges {
exCustom, ok := sourceExchange.(types.CustomIntervalProvider) exCustom, ok := sourceExchange.(types.CustomIntervalProvider)
@ -663,11 +662,7 @@ func sync(ctx context.Context, userConfig *bbgo.Config, backtestService *service
supportIntervals = types.SupportedIntervals supportIntervals = types.SupportedIntervals
} }
now := time.Now()
for interval := range supportIntervals { for interval := range supportIntervals {
// if err := s.SyncKLineByInterval(ctx, exchange, symbol, interval, startTime, endTime); err != nil {
// return err
// }
firstKLine, err := backtestService.QueryFirstKLine(sourceExchange.Name(), symbol, interval) firstKLine, err := backtestService.QueryFirstKLine(sourceExchange.Name(), symbol, interval)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to query backtest kline") return errors.Wrapf(err, "failed to query backtest kline")
@ -676,13 +671,13 @@ func sync(ctx context.Context, userConfig *bbgo.Config, backtestService *service
// if we don't have klines before the start time endpoint, the back-test will fail. // if we don't have klines before the start time endpoint, the back-test will fail.
// because the last price will be missing. // because the last price will be missing.
if firstKLine != nil { if firstKLine != nil {
log.Debugf("found existing kline data using partial sync...")
if err := backtestService.SyncExist(ctx, sourceExchange, symbol, syncFromTime, now, interval); err != nil { if err := backtestService.SyncPartial(ctx, sourceExchange, symbol, interval, syncFrom, syncTo); err != nil {
return err return err
} }
} else { } else {
log.Debugf("starting a fresh kline data sync...") log.Debugf("starting a fresh kline data sync...")
if err := backtestService.Sync(ctx, sourceExchange, symbol, syncFromTime, now, interval); err != nil { if err := backtestService.Sync(ctx, sourceExchange, symbol, interval, syncFrom, syncTo); err != nil {
return err return err
} }
} }

View File

@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
@ -114,8 +115,7 @@ func (s *BacktestService) Verify(symbols []string, startTime time.Time, endTime
return nil return nil
} }
func (s *BacktestService) Sync(ctx context.Context, exchange types.Exchange, symbol string, func (s *BacktestService) Sync(ctx context.Context, exchange types.Exchange, symbol string, interval types.Interval, startTime, endTime time.Time) error {
startTime time.Time, endTime time.Time, interval types.Interval) error {
return s.SyncKLineByInterval(ctx, exchange, symbol, interval, startTime, endTime) return s.SyncKLineByInterval(ctx, exchange, symbol, interval, startTime, endTime)
} }
@ -317,33 +317,6 @@ func (s *BacktestService) _deleteDuplicatedKLine(k types.KLine) error {
return err return err
} }
func (s *BacktestService) SyncExist(ctx context.Context, exchange types.Exchange, symbol string,
fromTime time.Time, endTime time.Time, interval types.Interval) error {
klineC, errC := s.QueryKLinesCh(fromTime, endTime, exchange, []string{symbol}, []types.Interval{interval})
nowStartTime := fromTime
for k := range klineC {
if nowStartTime.Unix() < k.StartTime.Unix() {
log.Infof("syncing %s interval %s syncing %s ~ %s ", symbol, interval, nowStartTime, k.EndTime)
if err := s.Sync(ctx, exchange, symbol, nowStartTime, k.EndTime.Time().Add(-1*interval.Duration()), interval); err != nil {
log.WithError(err).Errorf("sync error")
}
}
nowStartTime = k.StartTime.Time().Add(interval.Duration())
}
if nowStartTime.Unix() < endTime.Unix() && nowStartTime.Unix() < time.Now().Unix() {
if err := s.Sync(ctx, exchange, symbol, nowStartTime, endTime, interval); err != nil {
log.WithError(err).Errorf("sync error")
}
}
if err := <-errC; err != nil {
return err
}
return nil
}
type TimeRange struct { type TimeRange struct {
Start time.Time Start time.Time
End time.Time End time.Time
@ -356,13 +329,23 @@ type TimeRange struct {
// iterate the []TimeRange slice to sync data. // iterate the []TimeRange slice to sync data.
func (s *BacktestService) SyncPartial(ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, since, until time.Time) error { func (s *BacktestService) SyncPartial(ctx context.Context, ex types.Exchange, symbol string, interval types.Interval, since, until time.Time) error {
t1, t2, err := s.QueryExistingDataRange(ctx, ex, symbol, interval, since, until) t1, t2, err := s.QueryExistingDataRange(ctx, ex, symbol, interval, since, until)
if err != nil && err != sql.ErrNoRows {
return err
}
if err == sql.ErrNoRows {
// fallback to fresh sync
return s.Sync(ctx, ex, symbol, interval, since, until)
}
log.Debugf("found existing kline data, now using partial sync...")
timeRanges, err := s.FindMissingTimeRanges(ctx, ex, symbol, interval, t1.Time(), t2.Time())
if err != nil { if err != nil {
return err return err
} }
timeRanges, err := s.FindMissingTimeRanges(ctx, ex, symbol, interval, t1.Time(), t2.Time()) if len(timeRanges) > 0 {
if err != nil { log.Infof("found missing time ranges: %v", timeRanges)
return err
} }
// there are few cases: // there are few cases:
@ -440,6 +423,7 @@ func (s *BacktestService) QueryExistingDataRange(ctx context.Context, ex types.E
var t1, t2 types.Time var t1, t2 types.Time
row := s.DB.QueryRowContext(ctx, sql, args...) row := s.DB.QueryRowContext(ctx, sql, args...)
if err := row.Scan(&t1, &t2); err != nil { if err := row.Scan(&t1, &t2); err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"
@ -13,6 +14,33 @@ import (
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
func TestBacktestService_QueryExistingDataRange(t *testing.T) {
db, err := prepareDB(t)
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx := context.Background()
dbx := sqlx.NewDb(db.DB, "sqlite3")
ex, err := exchange.NewPublic(types.ExchangeBinance)
assert.NoError(t, err)
service := &BacktestService{DB: dbx}
symbol := "BTCUSDT"
now := time.Now()
startTime1 := now.AddDate(0, 0, -7).Truncate(time.Hour)
endTime1 := now.AddDate(0, 0, -6).Truncate(time.Hour)
// empty range
t1, t2, err := service.QueryExistingDataRange(ctx, ex, symbol, types.Interval1h, startTime1, endTime1)
assert.Error(t, sql.ErrNoRows, err)
assert.Nil(t, t1)
assert.Nil(t, t2)
}
func TestBacktestService_SyncPartial(t *testing.T) { func TestBacktestService_SyncPartial(t *testing.T) {
db, err := prepareDB(t) db, err := prepareDB(t)
if err != nil { if err != nil {