mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 01:01:56 +00:00
all: refactor backtest functions so that we can run backtest in test
This commit is contained in:
parent
846695e632
commit
c6ce223a13
|
@ -358,8 +358,7 @@ func (e *Exchange) SubscribeMarketData(startTime, endTime time.Time, requiredInt
|
|||
intervals = append(intervals, interval)
|
||||
}
|
||||
|
||||
log.Infof("using symbols: %v and intervals: %v for back-testing", symbols, intervals)
|
||||
log.Infof("querying klines from database...")
|
||||
log.Infof("querying klines from database with exchange: %v symbols: %v and intervals: %v for back-testing", e.Name(), symbols, intervals)
|
||||
klineC, errC := e.srv.QueryKLinesCh(startTime, endTime, e, symbols, intervals)
|
||||
go func() {
|
||||
if err := <-errC; err != nil {
|
||||
|
|
57
pkg/backtest/utils.go
Normal file
57
pkg/backtest/utils.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package backtest
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/bbgo"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
func CollectSubscriptionIntervals(environ *bbgo.Environment) (allKLineIntervals map[types.Interval]struct{}, requiredInterval types.Interval, backTestIntervals []types.Interval) {
|
||||
// default extra back-test intervals
|
||||
backTestIntervals = []types.Interval{types.Interval1h, types.Interval1d}
|
||||
// all subscribed intervals
|
||||
allKLineIntervals = make(map[types.Interval]struct{})
|
||||
|
||||
for _, interval := range backTestIntervals {
|
||||
allKLineIntervals[interval] = struct{}{}
|
||||
}
|
||||
// default interval is 1m for all exchanges
|
||||
requiredInterval = types.Interval1m
|
||||
for _, session := range environ.Sessions() {
|
||||
for _, sub := range session.Subscriptions {
|
||||
if sub.Channel == types.KLineChannel {
|
||||
if sub.Options.Interval.Seconds()%60 > 0 {
|
||||
// if any subscription interval is less than 60s, then we will use 1s for back-testing
|
||||
requiredInterval = types.Interval1s
|
||||
logrus.Warnf("found kline subscription interval less than 60s, modify default backtest interval to 1s")
|
||||
}
|
||||
allKLineIntervals[sub.Options.Interval] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return allKLineIntervals, requiredInterval, backTestIntervals
|
||||
}
|
||||
|
||||
func InitializeExchangeSources(sessions map[string]*bbgo.ExchangeSession, startTime, endTime time.Time, requiredInterval types.Interval, extraIntervals ...types.Interval) (exchangeSources []*ExchangeDataSource, err error) {
|
||||
for _, session := range sessions {
|
||||
backtestEx := session.Exchange.(*Exchange)
|
||||
|
||||
c, err := backtestEx.SubscribeMarketData(startTime, endTime, requiredInterval, extraIntervals...)
|
||||
if err != nil {
|
||||
return exchangeSources, err
|
||||
}
|
||||
|
||||
sessionCopy := session
|
||||
src := &ExchangeDataSource{
|
||||
C: c,
|
||||
Exchange: backtestEx,
|
||||
Session: sessionCopy,
|
||||
}
|
||||
backtestEx.Src = src
|
||||
exchangeSources = append(exchangeSources, src)
|
||||
}
|
||||
return exchangeSources, nil
|
||||
}
|
|
@ -11,11 +11,12 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/cmd/cmdutil"
|
||||
"github.com/c9s/bbgo/pkg/data/tsv"
|
||||
"github.com/c9s/bbgo/pkg/util"
|
||||
"github.com/fatih/color"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
@ -295,8 +296,8 @@ var BacktestCmd = &cobra.Command{
|
|||
return err
|
||||
}
|
||||
|
||||
allKLineIntervals, requiredInterval, backTestIntervals := collectSubscriptionIntervals(environ)
|
||||
exchangeSources, err := toExchangeSources(environ.Sessions(), startTime, endTime, requiredInterval, backTestIntervals...)
|
||||
allKLineIntervals, requiredInterval, backTestIntervals := backtest.CollectSubscriptionIntervals(environ)
|
||||
exchangeSources, err := backtest.InitializeExchangeSources(environ.Sessions(), startTime, endTime, requiredInterval, backTestIntervals...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -593,32 +594,6 @@ var BacktestCmd = &cobra.Command{
|
|||
},
|
||||
}
|
||||
|
||||
func collectSubscriptionIntervals(environ *bbgo.Environment) (allKLineIntervals map[types.Interval]struct{}, requiredInterval types.Interval, backTestIntervals []types.Interval) {
|
||||
// default extra back-test intervals
|
||||
backTestIntervals = []types.Interval{types.Interval1h, types.Interval1d}
|
||||
// all subscribed intervals
|
||||
allKLineIntervals = make(map[types.Interval]struct{})
|
||||
|
||||
for _, interval := range backTestIntervals {
|
||||
allKLineIntervals[interval] = struct{}{}
|
||||
}
|
||||
// default interval is 1m for all exchanges
|
||||
requiredInterval = types.Interval1m
|
||||
for _, session := range environ.Sessions() {
|
||||
for _, sub := range session.Subscriptions {
|
||||
if sub.Channel == types.KLineChannel {
|
||||
if sub.Options.Interval.Seconds()%60 > 0 {
|
||||
// if any subscription interval is less than 60s, then we will use 1s for back-testing
|
||||
requiredInterval = types.Interval1s
|
||||
log.Warnf("found kline subscription interval less than 60s, modify default backtest interval to 1s")
|
||||
}
|
||||
allKLineIntervals[sub.Options.Interval] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return allKLineIntervals, requiredInterval, backTestIntervals
|
||||
}
|
||||
|
||||
func createSymbolReport(userConfig *bbgo.Config, session *bbgo.ExchangeSession, symbol string, trades []types.Trade, intervalProfit *types.IntervalProfitCollector,
|
||||
profitFactor, winningRatio fixedpoint.Value) (
|
||||
*backtest.SessionSymbolReport,
|
||||
|
@ -722,27 +697,6 @@ func confirmation(s string) bool {
|
|||
}
|
||||
}
|
||||
|
||||
func toExchangeSources(sessions map[string]*bbgo.ExchangeSession, startTime, endTime time.Time, requiredInterval types.Interval, extraIntervals ...types.Interval) (exchangeSources []*backtest.ExchangeDataSource, err error) {
|
||||
for _, session := range sessions {
|
||||
backtestEx := session.Exchange.(*backtest.Exchange)
|
||||
|
||||
c, err := backtestEx.SubscribeMarketData(startTime, endTime, requiredInterval, extraIntervals...)
|
||||
if err != nil {
|
||||
return exchangeSources, err
|
||||
}
|
||||
|
||||
sessionCopy := session
|
||||
src := &backtest.ExchangeDataSource{
|
||||
C: c,
|
||||
Exchange: backtestEx,
|
||||
Session: sessionCopy,
|
||||
}
|
||||
backtestEx.Src = src
|
||||
exchangeSources = append(exchangeSources, src)
|
||||
}
|
||||
return exchangeSources, nil
|
||||
}
|
||||
|
||||
func sync(ctx context.Context, userConfig *bbgo.Config, backtestService *service.BacktestService, sourceExchanges map[types.ExchangeName]types.Exchange, syncFrom, syncTo time.Time) error {
|
||||
for _, symbol := range userConfig.Backtest.Symbols {
|
||||
for _, sourceExchange := range sourceExchanges {
|
||||
|
|
|
@ -4,6 +4,7 @@ package grid2
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -350,12 +351,20 @@ func TestBacktestStrategy(t *testing.T) {
|
|||
GridNum: 100,
|
||||
QuoteInvestment: number(9000.0),
|
||||
}
|
||||
RunBacktest(t, strategy)
|
||||
}
|
||||
|
||||
func RunBacktest(t *testing.T, strategy bbgo.SingleExchangeStrategy) {
|
||||
// TEMPLATE {{{ start backtest
|
||||
startTime, err := types.ParseLooseFormatTime("2021-06-01")
|
||||
const sqliteDbFile = "../../../data/bbgo_test.sqlite3"
|
||||
const backtestExchangeName = "binance"
|
||||
const backtestStartTime = "2022-06-01"
|
||||
const backtestEndTime = "2022-06-30"
|
||||
|
||||
startTime, err := types.ParseLooseFormatTime(backtestStartTime)
|
||||
assert.NoError(t, err)
|
||||
|
||||
endTime, err := types.ParseLooseFormatTime("2021-06-30")
|
||||
endTime, err := types.ParseLooseFormatTime(backtestEndTime)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backtestConfig := &bbgo.Backtest{
|
||||
|
@ -364,7 +373,7 @@ func TestBacktestStrategy(t *testing.T) {
|
|||
RecordTrades: false,
|
||||
FeeMode: bbgo.BacktestFeeModeToken,
|
||||
Accounts: map[string]bbgo.BacktestAccount{
|
||||
"binance": {
|
||||
backtestExchangeName: {
|
||||
MakerFeeRate: number(0.075 * 0.01),
|
||||
TakerFeeRate: number(0.075 * 0.01),
|
||||
Balances: bbgo.BacktestAccountBalanceMap{
|
||||
|
@ -374,7 +383,7 @@ func TestBacktestStrategy(t *testing.T) {
|
|||
},
|
||||
},
|
||||
Symbols: []string{"BTCUSDT"},
|
||||
Sessions: []string{"binance"},
|
||||
Sessions: []string{backtestExchangeName},
|
||||
SyncSecKLines: false,
|
||||
}
|
||||
|
||||
|
@ -384,8 +393,14 @@ func TestBacktestStrategy(t *testing.T) {
|
|||
environ := bbgo.NewEnvironment()
|
||||
environ.SetStartTime(startTime.Time())
|
||||
|
||||
err = environ.ConfigureDatabaseDriver(ctx, "sqlite3", "../../../data/bbgo_test.sqlite3")
|
||||
info, err := os.Stat(sqliteDbFile)
|
||||
assert.NoError(t, err)
|
||||
t.Logf("sqlite: %+v", info)
|
||||
|
||||
err = environ.ConfigureDatabaseDriver(ctx, "sqlite3", sqliteDbFile)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
backtestService := &service.BacktestService{DB: environ.DatabaseService.DB}
|
||||
defer func() {
|
||||
|
@ -397,22 +412,24 @@ func TestBacktestStrategy(t *testing.T) {
|
|||
bbgo.SetBackTesting(backtestService)
|
||||
defer bbgo.SetBackTesting(nil)
|
||||
|
||||
exName, err := types.ValidExchangeName("binance")
|
||||
exName, err := types.ValidExchangeName(backtestExchangeName)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("using exchange source: %s", exName)
|
||||
|
||||
publicExchange, err := exchange.NewPublic(exName)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
backtestExchange, err := backtest.NewExchange(publicExchange.Name(), publicExchange, backtestService, backtestConfig)
|
||||
backtestExchange, err := backtest.NewExchange(exName, publicExchange, backtestService, backtestConfig)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
session := environ.AddExchange(exName.String(), backtestExchange)
|
||||
session := environ.AddExchange(backtestExchangeName, backtestExchange)
|
||||
assert.NotNil(t, session)
|
||||
|
||||
err = environ.Init(ctx)
|
||||
|
@ -430,11 +447,11 @@ func TestBacktestStrategy(t *testing.T) {
|
|||
trader.DisableLogging()
|
||||
}
|
||||
|
||||
// TODO: add grid2 to the user config and run backtest
|
||||
userConfig := &bbgo.Config{
|
||||
Backtest: backtestConfig,
|
||||
ExchangeStrategies: []bbgo.ExchangeStrategyMount{
|
||||
{
|
||||
Mounts: []string{"binance"},
|
||||
Mounts: []string{backtestExchangeName},
|
||||
Strategy: strategy,
|
||||
},
|
||||
},
|
||||
|
@ -446,7 +463,32 @@ func TestBacktestStrategy(t *testing.T) {
|
|||
err = trader.Run(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// TODO: feed data
|
||||
allKLineIntervals, requiredInterval, backTestIntervals := backtest.CollectSubscriptionIntervals(environ)
|
||||
t.Logf("requiredInterval: %s backTestIntervals: %v", requiredInterval, backTestIntervals)
|
||||
|
||||
_ = allKLineIntervals
|
||||
exchangeSources, err := backtest.InitializeExchangeSources(environ.Sessions(), startTime.Time(), endTime.Time(), requiredInterval, backTestIntervals...)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
doneC := make(chan struct{})
|
||||
go func() {
|
||||
count := 0
|
||||
exSource := exchangeSources[0]
|
||||
for k := range exSource.C {
|
||||
exSource.Exchange.ConsumeKLine(k, requiredInterval)
|
||||
count++
|
||||
}
|
||||
|
||||
err = exSource.Exchange.CloseMarketData()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Greater(t, count, 0, "kLines count must be greater than 0, please check your backtest date range and symbol settings")
|
||||
|
||||
close(doneC)
|
||||
}()
|
||||
|
||||
<-doneC
|
||||
// }}}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user