all: refactor backtest functions so that we can run backtest in test

This commit is contained in:
c9s 2022-12-06 13:16:12 +08:00
parent 846695e632
commit c6ce223a13
4 changed files with 116 additions and 64 deletions

View File

@ -358,8 +358,7 @@ func (e *Exchange) SubscribeMarketData(startTime, endTime time.Time, requiredInt
intervals = append(intervals, interval) intervals = append(intervals, interval)
} }
log.Infof("using symbols: %v and intervals: %v for back-testing", symbols, intervals) log.Infof("querying klines from database with exchange: %v symbols: %v and intervals: %v for back-testing", e.Name(), symbols, intervals)
log.Infof("querying klines from database...")
klineC, errC := e.srv.QueryKLinesCh(startTime, endTime, e, symbols, intervals) klineC, errC := e.srv.QueryKLinesCh(startTime, endTime, e, symbols, intervals)
go func() { go func() {
if err := <-errC; err != nil { if err := <-errC; err != nil {

57
pkg/backtest/utils.go Normal file
View File

@ -0,0 +1,57 @@
package backtest
import (
"time"
"github.com/sirupsen/logrus"
"github.com/c9s/bbgo/pkg/bbgo"
"github.com/c9s/bbgo/pkg/types"
)
func CollectSubscriptionIntervals(environ *bbgo.Environment) (allKLineIntervals map[types.Interval]struct{}, requiredInterval types.Interval, backTestIntervals []types.Interval) {
// default extra back-test intervals
backTestIntervals = []types.Interval{types.Interval1h, types.Interval1d}
// all subscribed intervals
allKLineIntervals = make(map[types.Interval]struct{})
for _, interval := range backTestIntervals {
allKLineIntervals[interval] = struct{}{}
}
// default interval is 1m for all exchanges
requiredInterval = types.Interval1m
for _, session := range environ.Sessions() {
for _, sub := range session.Subscriptions {
if sub.Channel == types.KLineChannel {
if sub.Options.Interval.Seconds()%60 > 0 {
// if any subscription interval is less than 60s, then we will use 1s for back-testing
requiredInterval = types.Interval1s
logrus.Warnf("found kline subscription interval less than 60s, modify default backtest interval to 1s")
}
allKLineIntervals[sub.Options.Interval] = struct{}{}
}
}
}
return allKLineIntervals, requiredInterval, backTestIntervals
}
func InitializeExchangeSources(sessions map[string]*bbgo.ExchangeSession, startTime, endTime time.Time, requiredInterval types.Interval, extraIntervals ...types.Interval) (exchangeSources []*ExchangeDataSource, err error) {
for _, session := range sessions {
backtestEx := session.Exchange.(*Exchange)
c, err := backtestEx.SubscribeMarketData(startTime, endTime, requiredInterval, extraIntervals...)
if err != nil {
return exchangeSources, err
}
sessionCopy := session
src := &ExchangeDataSource{
C: c,
Exchange: backtestEx,
Session: sessionCopy,
}
backtestEx.Src = src
exchangeSources = append(exchangeSources, src)
}
return exchangeSources, nil
}

View File

@ -11,11 +11,12 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/fatih/color"
"github.com/google/uuid"
"github.com/c9s/bbgo/pkg/cmd/cmdutil" "github.com/c9s/bbgo/pkg/cmd/cmdutil"
"github.com/c9s/bbgo/pkg/data/tsv" "github.com/c9s/bbgo/pkg/data/tsv"
"github.com/c9s/bbgo/pkg/util" "github.com/c9s/bbgo/pkg/util"
"github.com/fatih/color"
"github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -295,8 +296,8 @@ var BacktestCmd = &cobra.Command{
return err return err
} }
allKLineIntervals, requiredInterval, backTestIntervals := collectSubscriptionIntervals(environ) allKLineIntervals, requiredInterval, backTestIntervals := backtest.CollectSubscriptionIntervals(environ)
exchangeSources, err := toExchangeSources(environ.Sessions(), startTime, endTime, requiredInterval, backTestIntervals...) exchangeSources, err := backtest.InitializeExchangeSources(environ.Sessions(), startTime, endTime, requiredInterval, backTestIntervals...)
if err != nil { if err != nil {
return err return err
} }
@ -593,32 +594,6 @@ var BacktestCmd = &cobra.Command{
}, },
} }
func collectSubscriptionIntervals(environ *bbgo.Environment) (allKLineIntervals map[types.Interval]struct{}, requiredInterval types.Interval, backTestIntervals []types.Interval) {
// default extra back-test intervals
backTestIntervals = []types.Interval{types.Interval1h, types.Interval1d}
// all subscribed intervals
allKLineIntervals = make(map[types.Interval]struct{})
for _, interval := range backTestIntervals {
allKLineIntervals[interval] = struct{}{}
}
// default interval is 1m for all exchanges
requiredInterval = types.Interval1m
for _, session := range environ.Sessions() {
for _, sub := range session.Subscriptions {
if sub.Channel == types.KLineChannel {
if sub.Options.Interval.Seconds()%60 > 0 {
// if any subscription interval is less than 60s, then we will use 1s for back-testing
requiredInterval = types.Interval1s
log.Warnf("found kline subscription interval less than 60s, modify default backtest interval to 1s")
}
allKLineIntervals[sub.Options.Interval] = struct{}{}
}
}
}
return allKLineIntervals, requiredInterval, backTestIntervals
}
func createSymbolReport(userConfig *bbgo.Config, session *bbgo.ExchangeSession, symbol string, trades []types.Trade, intervalProfit *types.IntervalProfitCollector, func createSymbolReport(userConfig *bbgo.Config, session *bbgo.ExchangeSession, symbol string, trades []types.Trade, intervalProfit *types.IntervalProfitCollector,
profitFactor, winningRatio fixedpoint.Value) ( profitFactor, winningRatio fixedpoint.Value) (
*backtest.SessionSymbolReport, *backtest.SessionSymbolReport,
@ -722,27 +697,6 @@ func confirmation(s string) bool {
} }
} }
func toExchangeSources(sessions map[string]*bbgo.ExchangeSession, startTime, endTime time.Time, requiredInterval types.Interval, extraIntervals ...types.Interval) (exchangeSources []*backtest.ExchangeDataSource, err error) {
for _, session := range sessions {
backtestEx := session.Exchange.(*backtest.Exchange)
c, err := backtestEx.SubscribeMarketData(startTime, endTime, requiredInterval, extraIntervals...)
if err != nil {
return exchangeSources, err
}
sessionCopy := session
src := &backtest.ExchangeDataSource{
C: c,
Exchange: backtestEx,
Session: sessionCopy,
}
backtestEx.Src = src
exchangeSources = append(exchangeSources, src)
}
return exchangeSources, nil
}
func sync(ctx context.Context, userConfig *bbgo.Config, backtestService *service.BacktestService, sourceExchanges map[types.ExchangeName]types.Exchange, syncFrom, syncTo time.Time) error { func sync(ctx context.Context, userConfig *bbgo.Config, backtestService *service.BacktestService, sourceExchanges map[types.ExchangeName]types.Exchange, syncFrom, syncTo time.Time) error {
for _, symbol := range userConfig.Backtest.Symbols { for _, symbol := range userConfig.Backtest.Symbols {
for _, sourceExchange := range sourceExchanges { for _, sourceExchange := range sourceExchanges {

View File

@ -4,6 +4,7 @@ package grid2
import ( import (
"context" "context"
"os"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -350,12 +351,20 @@ func TestBacktestStrategy(t *testing.T) {
GridNum: 100, GridNum: 100,
QuoteInvestment: number(9000.0), QuoteInvestment: number(9000.0),
} }
RunBacktest(t, strategy)
}
func RunBacktest(t *testing.T, strategy bbgo.SingleExchangeStrategy) {
// TEMPLATE {{{ start backtest // TEMPLATE {{{ start backtest
startTime, err := types.ParseLooseFormatTime("2021-06-01") const sqliteDbFile = "../../../data/bbgo_test.sqlite3"
const backtestExchangeName = "binance"
const backtestStartTime = "2022-06-01"
const backtestEndTime = "2022-06-30"
startTime, err := types.ParseLooseFormatTime(backtestStartTime)
assert.NoError(t, err) assert.NoError(t, err)
endTime, err := types.ParseLooseFormatTime("2021-06-30") endTime, err := types.ParseLooseFormatTime(backtestEndTime)
assert.NoError(t, err) assert.NoError(t, err)
backtestConfig := &bbgo.Backtest{ backtestConfig := &bbgo.Backtest{
@ -364,7 +373,7 @@ func TestBacktestStrategy(t *testing.T) {
RecordTrades: false, RecordTrades: false,
FeeMode: bbgo.BacktestFeeModeToken, FeeMode: bbgo.BacktestFeeModeToken,
Accounts: map[string]bbgo.BacktestAccount{ Accounts: map[string]bbgo.BacktestAccount{
"binance": { backtestExchangeName: {
MakerFeeRate: number(0.075 * 0.01), MakerFeeRate: number(0.075 * 0.01),
TakerFeeRate: number(0.075 * 0.01), TakerFeeRate: number(0.075 * 0.01),
Balances: bbgo.BacktestAccountBalanceMap{ Balances: bbgo.BacktestAccountBalanceMap{
@ -374,7 +383,7 @@ func TestBacktestStrategy(t *testing.T) {
}, },
}, },
Symbols: []string{"BTCUSDT"}, Symbols: []string{"BTCUSDT"},
Sessions: []string{"binance"}, Sessions: []string{backtestExchangeName},
SyncSecKLines: false, SyncSecKLines: false,
} }
@ -384,8 +393,14 @@ func TestBacktestStrategy(t *testing.T) {
environ := bbgo.NewEnvironment() environ := bbgo.NewEnvironment()
environ.SetStartTime(startTime.Time()) environ.SetStartTime(startTime.Time())
err = environ.ConfigureDatabaseDriver(ctx, "sqlite3", "../../../data/bbgo_test.sqlite3") info, err := os.Stat(sqliteDbFile)
assert.NoError(t, err) assert.NoError(t, err)
t.Logf("sqlite: %+v", info)
err = environ.ConfigureDatabaseDriver(ctx, "sqlite3", sqliteDbFile)
if !assert.NoError(t, err) {
return
}
backtestService := &service.BacktestService{DB: environ.DatabaseService.DB} backtestService := &service.BacktestService{DB: environ.DatabaseService.DB}
defer func() { defer func() {
@ -397,22 +412,24 @@ func TestBacktestStrategy(t *testing.T) {
bbgo.SetBackTesting(backtestService) bbgo.SetBackTesting(backtestService)
defer bbgo.SetBackTesting(nil) defer bbgo.SetBackTesting(nil)
exName, err := types.ValidExchangeName("binance") exName, err := types.ValidExchangeName(backtestExchangeName)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
t.Logf("using exchange source: %s", exName)
publicExchange, err := exchange.NewPublic(exName) publicExchange, err := exchange.NewPublic(exName)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
backtestExchange, err := backtest.NewExchange(publicExchange.Name(), publicExchange, backtestService, backtestConfig) backtestExchange, err := backtest.NewExchange(exName, publicExchange, backtestService, backtestConfig)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
session := environ.AddExchange(exName.String(), backtestExchange) session := environ.AddExchange(backtestExchangeName, backtestExchange)
assert.NotNil(t, session) assert.NotNil(t, session)
err = environ.Init(ctx) err = environ.Init(ctx)
@ -430,11 +447,11 @@ func TestBacktestStrategy(t *testing.T) {
trader.DisableLogging() trader.DisableLogging()
} }
// TODO: add grid2 to the user config and run backtest
userConfig := &bbgo.Config{ userConfig := &bbgo.Config{
Backtest: backtestConfig,
ExchangeStrategies: []bbgo.ExchangeStrategyMount{ ExchangeStrategies: []bbgo.ExchangeStrategyMount{
{ {
Mounts: []string{"binance"}, Mounts: []string{backtestExchangeName},
Strategy: strategy, Strategy: strategy,
}, },
}, },
@ -446,7 +463,32 @@ func TestBacktestStrategy(t *testing.T) {
err = trader.Run(ctx) err = trader.Run(ctx)
assert.NoError(t, err) assert.NoError(t, err)
// TODO: feed data allKLineIntervals, requiredInterval, backTestIntervals := backtest.CollectSubscriptionIntervals(environ)
t.Logf("requiredInterval: %s backTestIntervals: %v", requiredInterval, backTestIntervals)
_ = allKLineIntervals
exchangeSources, err := backtest.InitializeExchangeSources(environ.Sessions(), startTime.Time(), endTime.Time(), requiredInterval, backTestIntervals...)
if !assert.NoError(t, err) {
return
}
doneC := make(chan struct{})
go func() {
count := 0
exSource := exchangeSources[0]
for k := range exSource.C {
exSource.Exchange.ConsumeKLine(k, requiredInterval)
count++
}
err = exSource.Exchange.CloseMarketData()
assert.NoError(t, err)
assert.Greater(t, count, 0, "kLines count must be greater than 0, please check your backtest date range and symbol settings")
close(doneC)
}()
<-doneC
// }}} // }}}
} }