backtest: refactor exchange field, clean up startTime and endTime deps

This commit is contained in:
c9s 2022-07-04 02:34:46 +08:00
parent 8fc17f9c0b
commit 82f9fc139c
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
2 changed files with 10 additions and 18 deletions

View File

@ -50,10 +50,10 @@ var log = logrus.WithField("cmd", "backtest")
var ErrUnimplemented = errors.New("unimplemented method")
type Exchange struct {
sourceName types.ExchangeName
publicExchange types.Exchange
srv *service.BacktestService
startTime, endTime time.Time
sourceName types.ExchangeName
publicExchange types.Exchange
srv *service.BacktestService
startTime time.Time
account *types.Account
config *bbgo.Backtest
@ -80,14 +80,7 @@ func NewExchange(sourceName types.ExchangeName, sourceExchange types.Exchange, s
return nil, err
}
var startTime, endTime time.Time
startTime = config.StartTime.Time()
if config.EndTime != nil {
endTime = config.EndTime.Time()
} else {
endTime = time.Now()
}
startTime := config.StartTime.Time()
configAccount := config.GetAccount(sourceName.String())
account := &types.Account{
@ -107,7 +100,6 @@ func NewExchange(sourceName types.ExchangeName, sourceExchange types.Exchange, s
config: config,
account: account,
startTime: startTime,
endTime: endTime,
closedOrders: make(map[string][]types.Order),
trades: make(map[string][]types.Trade),
}
@ -322,7 +314,7 @@ func (e *Exchange) BindUserData(userDataStream types.StandardStreamEmitter) {
e.matchingBooksMutex.Unlock()
}
func (e *Exchange) SubscribeMarketData(extraIntervals ...types.Interval) (chan types.KLine, error) {
func (e *Exchange) SubscribeMarketData(startTime, endTime time.Time, extraIntervals ...types.Interval) (chan types.KLine, error) {
log.Infof("collecting backtest configurations...")
loadedSymbols := map[string]struct{}{}
@ -361,7 +353,7 @@ func (e *Exchange) SubscribeMarketData(extraIntervals ...types.Interval) (chan t
log.Infof("using symbols: %v and intervals: %v for back-testing", symbols, intervals)
log.Infof("querying klines from database...")
klineC, errC := e.srv.QueryKLinesCh(e.startTime, e.endTime, e, symbols, intervals)
klineC, errC := e.srv.QueryKLinesCh(startTime, endTime, e, symbols, intervals)
go func() {
if err := <-errC; err != nil {
log.WithError(err).Error("backtest data feed error")

View File

@ -288,7 +288,7 @@ var BacktestCmd = &cobra.Command{
}
backTestIntervals := []types.Interval{types.Interval1h, types.Interval1d}
exchangeSources, err := toExchangeSources(environ.Sessions(), backTestIntervals...)
exchangeSources, err := toExchangeSources(environ.Sessions(), startTime, endTime, backTestIntervals...)
if err != nil {
return err
}
@ -647,11 +647,11 @@ func confirmation(s string) bool {
}
}
func toExchangeSources(sessions map[string]*bbgo.ExchangeSession, extraIntervals ...types.Interval) (exchangeSources []backtest.ExchangeDataSource, err error) {
func toExchangeSources(sessions map[string]*bbgo.ExchangeSession, startTime, endTime time.Time, extraIntervals ...types.Interval) (exchangeSources []backtest.ExchangeDataSource, err error) {
for _, session := range sessions {
backtestEx := session.Exchange.(*backtest.Exchange)
c, err := backtestEx.SubscribeMarketData(extraIntervals...)
c, err := backtestEx.SubscribeMarketData(startTime, endTime, extraIntervals...)
if err != nil {
return exchangeSources, err
}