From 1892d0332655985c930e9341e937e51231bdc7f9 Mon Sep 17 00:00:00 2001 From: c9s Date: Thu, 21 Jan 2021 15:10:40 +0800 Subject: [PATCH] make session trades map thread safe --- pkg/bbgo/environment.go | 5 ++--- pkg/bbgo/redis_persistence_test.go | 7 +++---- pkg/bbgo/reporter.go | 2 +- pkg/bbgo/session.go | 4 ++-- pkg/cmd/backtest.go | 2 +- pkg/types/trade.go | 12 ++++++------ 6 files changed, 15 insertions(+), 17 deletions(-) diff --git a/pkg/bbgo/environment.go b/pkg/bbgo/environment.go index 278c1738a..f60ccb689 100644 --- a/pkg/bbgo/environment.go +++ b/pkg/bbgo/environment.go @@ -216,9 +216,9 @@ func (environ *Environment) Init(ctx context.Context) (err error) { log.Infof("symbol %s: %d trades loaded", symbol, len(trades)) } - session.Trades[symbol] = trades + session.Trades[symbol] = &types.TradeSlice{Trades: trades} session.Stream.OnTradeUpdate(func(trade types.Trade) { - session.Trades[trade.Symbol] = append(session.Trades[trade.Symbol], trade) + session.Trades[symbol].Append(trade) }) session.lastPrices[symbol] = 0.0 @@ -271,7 +271,6 @@ func (environ *Environment) Init(ctx context.Context) (err error) { session.lastPrices[kline.Symbol] = kline.Close }) - // feed klines into the market data store if environ.startTime == emptyTime { environ.startTime = time.Now() diff --git a/pkg/bbgo/redis_persistence_test.go b/pkg/bbgo/redis_persistence_test.go index af43d1586..4c8b0679b 100644 --- a/pkg/bbgo/redis_persistence_test.go +++ b/pkg/bbgo/redis_persistence_test.go @@ -10,9 +10,9 @@ import ( func TestRedisPersistentService(t *testing.T) { redisService := NewRedisPersistenceService(&RedisPersistenceConfig{ - Host: "127.0.0.1", - Port: "6379", - DB: 0, + Host: "127.0.0.1", + Port: "6379", + DB: 0, }) assert.NotNil(t, redisService) @@ -65,4 +65,3 @@ func TestMemoryService(t *testing.T) { assert.Equal(t, i, j) }) } - diff --git a/pkg/bbgo/reporter.go b/pkg/bbgo/reporter.go index 8035a4ecf..bd6018095 100644 --- a/pkg/bbgo/reporter.go +++ b/pkg/bbgo/reporter.go @@ -74,7 +74,7 @@ func (reporter *AverageCostPnLReporter) Run() { } for _, symbol := range reporter.Symbols { - report := calculator.Calculate(symbol, session.Trades[symbol], session.lastPrices[symbol]) + report := calculator.Calculate(symbol, session.Trades[symbol].Copy(), session.lastPrices[symbol]) report.Print() } } diff --git a/pkg/bbgo/session.go b/pkg/bbgo/session.go index ad77fdfab..bdbcc1a96 100644 --- a/pkg/bbgo/session.go +++ b/pkg/bbgo/session.go @@ -116,7 +116,7 @@ type ExchangeSession struct { // Trades collects the executed trades from the exchange // map: symbol -> []trade - Trades map[string][]types.Trade + Trades map[string]*types.TradeSlice // marketDataStores contains the market data store of each market marketDataStores map[string]*MarketDataStore @@ -150,7 +150,7 @@ func NewExchangeSession(name string, exchange types.Exchange) *ExchangeSession { Stream: exchange.NewStream(), Subscriptions: make(map[types.Subscription]types.Subscription), Account: &types.Account{}, - Trades: make(map[string][]types.Trade), + Trades: make(map[string]*types.TradeSlice), markets: make(map[string]types.Market), startPrices: make(map[string]float64), diff --git a/pkg/cmd/backtest.go b/pkg/cmd/backtest.go index 4f616edcb..741144108 100644 --- a/pkg/cmd/backtest.go +++ b/pkg/cmd/backtest.go @@ -255,7 +255,7 @@ var BacktestCmd = &cobra.Command{ return fmt.Errorf("last price not found: %s", symbol) } - report := calculator.Calculate(symbol, trades, lastPrice) + report := calculator.Calculate(symbol, trades.Trades, lastPrice) report.Print() initBalances := userConfig.Backtest.Account.Balances.BalanceMap() diff --git a/pkg/types/trade.go b/pkg/types/trade.go index 6bb7735f0..56e82e409 100644 --- a/pkg/types/trade.go +++ b/pkg/types/trade.go @@ -17,14 +17,14 @@ func init() { } type TradeSlice struct { - mu sync.Mutex - Items []Trade + mu sync.Mutex + Trades []Trade } -func (s *TradeSlice) Slice() []Trade { +func (s *TradeSlice) Copy() []Trade { s.mu.Lock() - slice := make([]Trade, len(s.Items), len(s.Items)) - copy(slice, s.Items) + slice := make([]Trade, len(s.Trades), len(s.Trades)) + copy(slice, s.Trades) s.mu.Unlock() return slice @@ -32,7 +32,7 @@ func (s *TradeSlice) Slice() []Trade { func (s *TradeSlice) Append(t Trade) { s.mu.Lock() - s.Items = append(s.Items, t) + s.Trades = append(s.Trades, t) s.mu.Unlock() }