From f2765201866b59ef52b2ebc92f602a7c8de002ca Mon Sep 17 00:00:00 2001 From: c9s Date: Tue, 4 Aug 2020 20:04:15 +0800 Subject: [PATCH] improve error messages --- bbgo/stock.go | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/bbgo/stock.go b/bbgo/stock.go index 0dd99593e..fd966c735 100644 --- a/bbgo/stock.go +++ b/bbgo/stock.go @@ -5,6 +5,7 @@ import ( "github.com/c9s/bbgo/pkg/bbgo/types" "math" "strings" + "sync" ) func zero(a float64) bool { @@ -48,18 +49,25 @@ func (slice StockSlice) Quantity() (total float64) { } type StockManager struct { + mu sync.Mutex + Symbol string TradingFeeCurrency string Stocks StockSlice PendingSells StockSlice } -func (m *StockManager) Stock(buy Stock) error { - m.Stocks = append(m.Stocks, buy) +func (m *StockManager) stock(stock Stock) error { + m.mu.Lock() + m.Stocks = append(m.Stocks, stock) + m.mu.Unlock() return m.flushPendingSells() } func (m *StockManager) squash() { + m.mu.Lock() + defer m.mu.Unlock() + var squashed StockSlice for _, stock := range m.Stocks { if !zero(stock.Quantity) { @@ -70,22 +78,27 @@ func (m *StockManager) squash() { } func (m *StockManager) flushPendingSells() error { - if len(m.Stocks) > 0 && len(m.PendingSells) > 0 { + if len(m.Stocks) == 0 || len(m.PendingSells) == 0 { + return nil + } - pendingSells := m.PendingSells - m.PendingSells = nil + pendingSells := m.PendingSells + m.PendingSells = nil - for _, sell := range pendingSells { - if err := m.Consume(sell); err != nil { - return err - } + for _, sell := range pendingSells { + if err := m.consume(sell); err != nil { + return err } } return nil } -func (m *StockManager) Consume(sell Stock) error { +func (m *StockManager) consume(sell Stock) error { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.Stocks) == 0 { m.PendingSells = append(m.PendingSells, sell) return nil @@ -158,12 +171,12 @@ func (m *StockManager) AddTrades(trades []types.Trade) (checkpoints []int, err e } stock := toStock(trade) - if err := m.Stock(stock); err != nil { + if err := m.stock(stock); err != nil { return checkpoints, err } } else { stock := toStock(trade) - if err := m.Consume(stock); err != nil { + if err := m.consume(stock); err != nil { return checkpoints, err } }