backtest: pull out userDataStream to backtestEx.BindUserData

This commit is contained in:
c9s 2022-07-04 02:27:29 +08:00
parent ecd4df86f9
commit a31f61736a
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
2 changed files with 8 additions and 15 deletions

View File

@ -58,7 +58,7 @@ type Exchange struct {
account *types.Account account *types.Account
config *bbgo.Backtest config *bbgo.Backtest
UserDataStream, MarketDataStream types.StandardStreamEmitter MarketDataStream types.StandardStreamEmitter
trades map[string][]types.Trade trades map[string][]types.Trade
tradesMutex sync.Mutex tradesMutex sync.Mutex
@ -173,10 +173,6 @@ func (e *Exchange) NewStream() types.Stream {
} }
func (e *Exchange) SubmitOrders(ctx context.Context, orders ...types.SubmitOrder) (createdOrders types.OrderSlice, err error) { func (e *Exchange) SubmitOrders(ctx context.Context, orders ...types.SubmitOrder) (createdOrders types.OrderSlice, err error) {
if e.UserDataStream == nil {
return createdOrders, fmt.Errorf("SubmitOrders() should be called after UserDataStream been initialized")
}
for _, order := range orders { for _, order := range orders {
symbol := order.Symbol symbol := order.Symbol
matching, ok := e.matchingBook(symbol) matching, ok := e.matchingBook(symbol)
@ -222,9 +218,6 @@ func (e *Exchange) QueryClosedOrders(ctx context.Context, symbol string, since,
} }
func (e *Exchange) CancelOrders(ctx context.Context, orders ...types.Order) error { func (e *Exchange) CancelOrders(ctx context.Context, orders ...types.Order) error {
if e.UserDataStream == nil {
return fmt.Errorf("CancelOrders should be called after UserDataStream been initialized")
}
for _, order := range orders { for _, order := range orders {
matching, ok := e.matchingBook(order.Symbol) matching, ok := e.matchingBook(order.Symbol)
if !ok { if !ok {
@ -315,16 +308,16 @@ func (e *Exchange) matchingBook(symbol string) (*SimplePriceMatching, bool) {
return m, ok return m, ok
} }
func (e *Exchange) InitMarketData() { func (e *Exchange) BindUserData(userDataStream types.StandardStreamEmitter) {
e.UserDataStream.OnTradeUpdate(func(trade types.Trade) { userDataStream.OnTradeUpdate(func(trade types.Trade) {
e.addTrade(trade) e.addTrade(trade)
}) })
e.matchingBooksMutex.Lock() e.matchingBooksMutex.Lock()
for _, matching := range e.matchingBooks { for _, matching := range e.matchingBooks {
matching.OnTradeUpdate(e.UserDataStream.EmitTradeUpdate) matching.OnTradeUpdate(userDataStream.EmitTradeUpdate)
matching.OnOrderUpdate(e.UserDataStream.EmitOrderUpdate) matching.OnOrderUpdate(userDataStream.EmitOrderUpdate)
matching.OnBalanceUpdate(e.UserDataStream.EmitBalanceUpdate) matching.OnBalanceUpdate(userDataStream.EmitBalanceUpdate)
} }
e.matchingBooksMutex.Unlock() e.matchingBooksMutex.Unlock()
} }

View File

@ -268,9 +268,10 @@ var BacktestCmd = &cobra.Command{
} }
for _, session := range environ.Sessions() { for _, session := range environ.Sessions() {
userDataStream := session.UserDataStream.(types.StandardStreamEmitter)
backtestEx := session.Exchange.(*backtest.Exchange) backtestEx := session.Exchange.(*backtest.Exchange)
backtestEx.UserDataStream = session.UserDataStream.(types.StandardStreamEmitter)
backtestEx.MarketDataStream = session.MarketDataStream.(types.StandardStreamEmitter) backtestEx.MarketDataStream = session.MarketDataStream.(types.StandardStreamEmitter)
backtestEx.BindUserData(userDataStream)
} }
trader := bbgo.NewTrader(environ) trader := bbgo.NewTrader(environ)
@ -649,7 +650,6 @@ func confirmation(s string) bool {
func toExchangeSources(sessions map[string]*bbgo.ExchangeSession, extraIntervals ...types.Interval) (exchangeSources []backtest.ExchangeDataSource, err error) { func toExchangeSources(sessions map[string]*bbgo.ExchangeSession, extraIntervals ...types.Interval) (exchangeSources []backtest.ExchangeDataSource, err error) {
for _, session := range sessions { for _, session := range sessions {
backtestEx := session.Exchange.(*backtest.Exchange) backtestEx := session.Exchange.(*backtest.Exchange)
backtestEx.InitMarketData()
c, err := backtestEx.SubscribeMarketData(extraIntervals...) c, err := backtestEx.SubscribeMarketData(extraIntervals...)
if err != nil { if err != nil {