From c6ce223a13bd3c968306a83bd21682d9aebcc97c Mon Sep 17 00:00:00 2001 From: c9s Date: Tue, 6 Dec 2022 13:16:12 +0800 Subject: [PATCH] all: refactor backtest functions so that we can run backtest in test --- pkg/backtest/exchange.go | 3 +- pkg/backtest/utils.go | 57 +++++++++++++++++++++++++ pkg/cmd/backtest.go | 56 +++---------------------- pkg/strategy/grid2/strategy_test.go | 64 ++++++++++++++++++++++++----- 4 files changed, 116 insertions(+), 64 deletions(-) create mode 100644 pkg/backtest/utils.go diff --git a/pkg/backtest/exchange.go b/pkg/backtest/exchange.go index 0f4d1d44a..5faa41289 100644 --- a/pkg/backtest/exchange.go +++ b/pkg/backtest/exchange.go @@ -358,8 +358,7 @@ func (e *Exchange) SubscribeMarketData(startTime, endTime time.Time, requiredInt intervals = append(intervals, interval) } - log.Infof("using symbols: %v and intervals: %v for back-testing", symbols, intervals) - log.Infof("querying klines from database...") + log.Infof("querying klines from database with exchange: %v symbols: %v and intervals: %v for back-testing", e.Name(), symbols, intervals) klineC, errC := e.srv.QueryKLinesCh(startTime, endTime, e, symbols, intervals) go func() { if err := <-errC; err != nil { diff --git a/pkg/backtest/utils.go b/pkg/backtest/utils.go new file mode 100644 index 000000000..870f36556 --- /dev/null +++ b/pkg/backtest/utils.go @@ -0,0 +1,57 @@ +package backtest + +import ( + "time" + + "github.com/sirupsen/logrus" + + "github.com/c9s/bbgo/pkg/bbgo" + "github.com/c9s/bbgo/pkg/types" +) + +func CollectSubscriptionIntervals(environ *bbgo.Environment) (allKLineIntervals map[types.Interval]struct{}, requiredInterval types.Interval, backTestIntervals []types.Interval) { + // default extra back-test intervals + backTestIntervals = []types.Interval{types.Interval1h, types.Interval1d} + // all subscribed intervals + allKLineIntervals = make(map[types.Interval]struct{}) + + for _, interval := range backTestIntervals { + allKLineIntervals[interval] = struct{}{} + } + // default interval is 1m for all exchanges + requiredInterval = types.Interval1m + for _, session := range environ.Sessions() { + for _, sub := range session.Subscriptions { + if sub.Channel == types.KLineChannel { + if sub.Options.Interval.Seconds()%60 > 0 { + // if any subscription interval is less than 60s, then we will use 1s for back-testing + requiredInterval = types.Interval1s + logrus.Warnf("found kline subscription interval less than 60s, modify default backtest interval to 1s") + } + allKLineIntervals[sub.Options.Interval] = struct{}{} + } + } + } + return allKLineIntervals, requiredInterval, backTestIntervals +} + +func InitializeExchangeSources(sessions map[string]*bbgo.ExchangeSession, startTime, endTime time.Time, requiredInterval types.Interval, extraIntervals ...types.Interval) (exchangeSources []*ExchangeDataSource, err error) { + for _, session := range sessions { + backtestEx := session.Exchange.(*Exchange) + + c, err := backtestEx.SubscribeMarketData(startTime, endTime, requiredInterval, extraIntervals...) + if err != nil { + return exchangeSources, err + } + + sessionCopy := session + src := &ExchangeDataSource{ + C: c, + Exchange: backtestEx, + Session: sessionCopy, + } + backtestEx.Src = src + exchangeSources = append(exchangeSources, src) + } + return exchangeSources, nil +} diff --git a/pkg/cmd/backtest.go b/pkg/cmd/backtest.go index 3e2c24a42..b8d688603 100644 --- a/pkg/cmd/backtest.go +++ b/pkg/cmd/backtest.go @@ -11,11 +11,12 @@ import ( "syscall" "time" + "github.com/fatih/color" + "github.com/google/uuid" + "github.com/c9s/bbgo/pkg/cmd/cmdutil" "github.com/c9s/bbgo/pkg/data/tsv" "github.com/c9s/bbgo/pkg/util" - "github.com/fatih/color" - "github.com/google/uuid" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -295,8 +296,8 @@ var BacktestCmd = &cobra.Command{ return err } - allKLineIntervals, requiredInterval, backTestIntervals := collectSubscriptionIntervals(environ) - exchangeSources, err := toExchangeSources(environ.Sessions(), startTime, endTime, requiredInterval, backTestIntervals...) + allKLineIntervals, requiredInterval, backTestIntervals := backtest.CollectSubscriptionIntervals(environ) + exchangeSources, err := backtest.InitializeExchangeSources(environ.Sessions(), startTime, endTime, requiredInterval, backTestIntervals...) if err != nil { return err } @@ -593,32 +594,6 @@ var BacktestCmd = &cobra.Command{ }, } -func collectSubscriptionIntervals(environ *bbgo.Environment) (allKLineIntervals map[types.Interval]struct{}, requiredInterval types.Interval, backTestIntervals []types.Interval) { - // default extra back-test intervals - backTestIntervals = []types.Interval{types.Interval1h, types.Interval1d} - // all subscribed intervals - allKLineIntervals = make(map[types.Interval]struct{}) - - for _, interval := range backTestIntervals { - allKLineIntervals[interval] = struct{}{} - } - // default interval is 1m for all exchanges - requiredInterval = types.Interval1m - for _, session := range environ.Sessions() { - for _, sub := range session.Subscriptions { - if sub.Channel == types.KLineChannel { - if sub.Options.Interval.Seconds()%60 > 0 { - // if any subscription interval is less than 60s, then we will use 1s for back-testing - requiredInterval = types.Interval1s - log.Warnf("found kline subscription interval less than 60s, modify default backtest interval to 1s") - } - allKLineIntervals[sub.Options.Interval] = struct{}{} - } - } - } - return allKLineIntervals, requiredInterval, backTestIntervals -} - func createSymbolReport(userConfig *bbgo.Config, session *bbgo.ExchangeSession, symbol string, trades []types.Trade, intervalProfit *types.IntervalProfitCollector, profitFactor, winningRatio fixedpoint.Value) ( *backtest.SessionSymbolReport, @@ -722,27 +697,6 @@ func confirmation(s string) bool { } } -func toExchangeSources(sessions map[string]*bbgo.ExchangeSession, startTime, endTime time.Time, requiredInterval types.Interval, extraIntervals ...types.Interval) (exchangeSources []*backtest.ExchangeDataSource, err error) { - for _, session := range sessions { - backtestEx := session.Exchange.(*backtest.Exchange) - - c, err := backtestEx.SubscribeMarketData(startTime, endTime, requiredInterval, extraIntervals...) - if err != nil { - return exchangeSources, err - } - - sessionCopy := session - src := &backtest.ExchangeDataSource{ - C: c, - Exchange: backtestEx, - Session: sessionCopy, - } - backtestEx.Src = src - exchangeSources = append(exchangeSources, src) - } - return exchangeSources, nil -} - func sync(ctx context.Context, userConfig *bbgo.Config, backtestService *service.BacktestService, sourceExchanges map[types.ExchangeName]types.Exchange, syncFrom, syncTo time.Time) error { for _, symbol := range userConfig.Backtest.Symbols { for _, sourceExchange := range sourceExchanges { diff --git a/pkg/strategy/grid2/strategy_test.go b/pkg/strategy/grid2/strategy_test.go index f14e867ce..3912e98d5 100644 --- a/pkg/strategy/grid2/strategy_test.go +++ b/pkg/strategy/grid2/strategy_test.go @@ -4,6 +4,7 @@ package grid2 import ( "context" + "os" "testing" "github.com/sirupsen/logrus" @@ -350,12 +351,20 @@ func TestBacktestStrategy(t *testing.T) { GridNum: 100, QuoteInvestment: number(9000.0), } + RunBacktest(t, strategy) +} +func RunBacktest(t *testing.T, strategy bbgo.SingleExchangeStrategy) { // TEMPLATE {{{ start backtest - startTime, err := types.ParseLooseFormatTime("2021-06-01") + const sqliteDbFile = "../../../data/bbgo_test.sqlite3" + const backtestExchangeName = "binance" + const backtestStartTime = "2022-06-01" + const backtestEndTime = "2022-06-30" + + startTime, err := types.ParseLooseFormatTime(backtestStartTime) assert.NoError(t, err) - endTime, err := types.ParseLooseFormatTime("2021-06-30") + endTime, err := types.ParseLooseFormatTime(backtestEndTime) assert.NoError(t, err) backtestConfig := &bbgo.Backtest{ @@ -364,7 +373,7 @@ func TestBacktestStrategy(t *testing.T) { RecordTrades: false, FeeMode: bbgo.BacktestFeeModeToken, Accounts: map[string]bbgo.BacktestAccount{ - "binance": { + backtestExchangeName: { MakerFeeRate: number(0.075 * 0.01), TakerFeeRate: number(0.075 * 0.01), Balances: bbgo.BacktestAccountBalanceMap{ @@ -374,7 +383,7 @@ func TestBacktestStrategy(t *testing.T) { }, }, Symbols: []string{"BTCUSDT"}, - Sessions: []string{"binance"}, + Sessions: []string{backtestExchangeName}, SyncSecKLines: false, } @@ -384,8 +393,14 @@ func TestBacktestStrategy(t *testing.T) { environ := bbgo.NewEnvironment() environ.SetStartTime(startTime.Time()) - err = environ.ConfigureDatabaseDriver(ctx, "sqlite3", "../../../data/bbgo_test.sqlite3") + info, err := os.Stat(sqliteDbFile) assert.NoError(t, err) + t.Logf("sqlite: %+v", info) + + err = environ.ConfigureDatabaseDriver(ctx, "sqlite3", sqliteDbFile) + if !assert.NoError(t, err) { + return + } backtestService := &service.BacktestService{DB: environ.DatabaseService.DB} defer func() { @@ -397,22 +412,24 @@ func TestBacktestStrategy(t *testing.T) { bbgo.SetBackTesting(backtestService) defer bbgo.SetBackTesting(nil) - exName, err := types.ValidExchangeName("binance") + exName, err := types.ValidExchangeName(backtestExchangeName) if !assert.NoError(t, err) { return } + t.Logf("using exchange source: %s", exName) + publicExchange, err := exchange.NewPublic(exName) if !assert.NoError(t, err) { return } - backtestExchange, err := backtest.NewExchange(publicExchange.Name(), publicExchange, backtestService, backtestConfig) + backtestExchange, err := backtest.NewExchange(exName, publicExchange, backtestService, backtestConfig) if !assert.NoError(t, err) { return } - session := environ.AddExchange(exName.String(), backtestExchange) + session := environ.AddExchange(backtestExchangeName, backtestExchange) assert.NotNil(t, session) err = environ.Init(ctx) @@ -430,11 +447,11 @@ func TestBacktestStrategy(t *testing.T) { trader.DisableLogging() } - // TODO: add grid2 to the user config and run backtest userConfig := &bbgo.Config{ + Backtest: backtestConfig, ExchangeStrategies: []bbgo.ExchangeStrategyMount{ { - Mounts: []string{"binance"}, + Mounts: []string{backtestExchangeName}, Strategy: strategy, }, }, @@ -446,7 +463,32 @@ func TestBacktestStrategy(t *testing.T) { err = trader.Run(ctx) assert.NoError(t, err) - // TODO: feed data + allKLineIntervals, requiredInterval, backTestIntervals := backtest.CollectSubscriptionIntervals(environ) + t.Logf("requiredInterval: %s backTestIntervals: %v", requiredInterval, backTestIntervals) + _ = allKLineIntervals + exchangeSources, err := backtest.InitializeExchangeSources(environ.Sessions(), startTime.Time(), endTime.Time(), requiredInterval, backTestIntervals...) + if !assert.NoError(t, err) { + return + } + + doneC := make(chan struct{}) + go func() { + count := 0 + exSource := exchangeSources[0] + for k := range exSource.C { + exSource.Exchange.ConsumeKLine(k, requiredInterval) + count++ + } + + err = exSource.Exchange.CloseMarketData() + assert.NoError(t, err) + + assert.Greater(t, count, 0, "kLines count must be greater than 0, please check your backtest date range and symbol settings") + + close(doneC) + }() + + <-doneC // }}} }