From 1dae711d332e5ddac1230f83d924abc4ad3ece60 Mon Sep 17 00:00:00 2001 From: c9s Date: Wed, 19 Jul 2023 17:33:12 +0800 Subject: [PATCH] fix trade collector race condition and infinite iterate --- pkg/bbgo/persistence.go | 4 +-- pkg/core/tradecollector.go | 67 +++++++++++++++++-------------------- pkg/core/tradestore.go | 2 ++ pkg/dynamic/iterate.go | 6 ++-- pkg/dynamic/iterate_test.go | 6 ++-- 5 files changed, 40 insertions(+), 45 deletions(-) diff --git a/pkg/bbgo/persistence.go b/pkg/bbgo/persistence.go index 34b8205b5..04e13bdcb 100644 --- a/pkg/bbgo/persistence.go +++ b/pkg/bbgo/persistence.go @@ -43,7 +43,7 @@ func Sync(ctx context.Context, obj interface{}) { } func loadPersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error { - return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error { + return dynamic.IterateFieldsByTag(obj, "persistence", true, func(tag string, field reflect.StructField, value reflect.Value) error { log.Debugf("[loadPersistenceFields] loading value into field %v, tag = %s, original value = %v", field, tag, value) newValueInf := dynamic.NewTypeValueInterface(value.Type()) @@ -71,7 +71,7 @@ func loadPersistenceFields(obj interface{}, id string, persistence service.Persi } func storePersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error { - return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error { + return dynamic.IterateFieldsByTag(obj, "persistence", true, func(tag string, ft reflect.StructField, fv reflect.Value) error { log.Debugf("[storePersistenceFields] storing value from field %v, tag = %s, original value = %v", ft, tag, fv) inf := fv.Interface() diff --git a/pkg/core/tradecollector.go b/pkg/core/tradecollector.go index 4bf1ddaa6..e844d35d5 100644 --- a/pkg/core/tradecollector.go +++ b/pkg/core/tradecollector.go @@ -107,12 +107,6 @@ func (c *TradeCollector) Recover(ctx context.Context, ex types.ExchangeTradeHist return nil } -func (c *TradeCollector) setDone(key types.TradeKey) { - c.mu.Lock() - c.doneTrades[key] = struct{}{} - c.mu.Unlock() -} - // Process filters the received trades and see if there are orders matching the trades // if we have the order in the order store, then the trade will be considered for the position. // profit will also be calculated. @@ -120,48 +114,47 @@ func (c *TradeCollector) Process() bool { logrus.Debugf("TradeCollector.Process()") positionChanged := false + var trades []types.Trade + + // if it's already done, remove the trade from the trade store + c.mu.Lock() c.tradeStore.Filter(func(trade types.Trade) bool { key := trade.Key() - c.mu.Lock() - - // if it's already done, remove the trade from the trade store + // remove done trades if _, done := c.doneTrades[key]; done { - c.mu.Unlock() return true } - if c.position != nil { - if c.orderStore.Exists(trade.OrderID) { - var p types.Profit - profit, netProfit, madeProfit := c.position.AddTrade(trade) - if madeProfit { - p = c.position.NewProfit(trade, profit, netProfit) - } - - c.doneTrades[key] = struct{}{} - c.mu.Unlock() - - c.EmitTrade(trade, profit, netProfit) - if !p.Profit.IsZero() { - c.EmitProfit(trade, &p) - } - - positionChanged = true - return true - } - - } else { - if c.orderStore.Exists(trade.OrderID) { - c.doneTrades[key] = struct{}{} - c.mu.Unlock() - c.EmitTrade(trade, fixedpoint.Zero, fixedpoint.Zero) - return true - } + // if it's the trade we're looking for, add it to the list and mark it as done + if c.orderStore.Exists(trade.OrderID) { + trades = append(trades, trade) + c.doneTrades[key] = struct{}{} + return true } return false }) + c.mu.Unlock() + + for _, trade := range trades { + var p types.Profit + if c.position != nil { + profit, netProfit, madeProfit := c.position.AddTrade(trade) + if madeProfit { + p = c.position.NewProfit(trade, profit, netProfit) + } + positionChanged = true + + c.EmitTrade(trade, profit, netProfit) + } else { + c.EmitTrade(trade, fixedpoint.Zero, fixedpoint.Zero) + } + + if !p.Profit.IsZero() { + c.EmitProfit(trade, &p) + } + } if positionChanged && c.position != nil { c.EmitPositionUpdate(c.position) diff --git a/pkg/core/tradestore.go b/pkg/core/tradestore.go index 98ce886e2..485f820dc 100644 --- a/pkg/core/tradestore.go +++ b/pkg/core/tradestore.go @@ -60,6 +60,7 @@ func (s *TradeStore) Clear() { type TradeFilter func(trade types.Trade) bool +// Filter filters the trades by a given TradeFilter function func (s *TradeStore) Filter(filter TradeFilter) { s.Lock() var trades = make(map[uint64]types.Trade) @@ -72,6 +73,7 @@ func (s *TradeStore) Filter(filter TradeFilter) { s.Unlock() } +// GetOrderTrades finds the trades match order id matches to the given order func (s *TradeStore) GetOrderTrades(o types.Order) (trades []types.Trade) { s.Lock() for _, t := range s.trades { diff --git a/pkg/dynamic/iterate.go b/pkg/dynamic/iterate.go index 8063698ba..9c932d605 100644 --- a/pkg/dynamic/iterate.go +++ b/pkg/dynamic/iterate.go @@ -56,7 +56,7 @@ func isStructPtr(tpe reflect.Type) bool { return tpe.Kind() == reflect.Ptr && tpe.Elem().Kind() == reflect.Struct } -func IterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error { +func IterateFieldsByTag(obj interface{}, tagName string, children bool, cb StructFieldIterator) error { sv := reflect.ValueOf(obj) st := reflect.TypeOf(obj) @@ -86,9 +86,9 @@ func IterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) continue } - if isStructPtr(ft.Type) && !fv.IsNil() { + if children && isStructPtr(ft.Type) && !fv.IsNil() { // recursive iterate the struct field - if err := IterateFieldsByTag(fv.Interface(), tagName, cb); err != nil { + if err := IterateFieldsByTag(fv.Interface(), tagName, false, cb); err != nil { return fmt.Errorf("unable to iterate struct fields over the type %v: %v", ft, err) } } diff --git a/pkg/dynamic/iterate_test.go b/pkg/dynamic/iterate_test.go index 11a42d1c0..2fb502b67 100644 --- a/pkg/dynamic/iterate_test.go +++ b/pkg/dynamic/iterate_test.go @@ -75,7 +75,7 @@ func TestIterateFieldsByTag(t *testing.T) { collectedTags := []string{} cnt := 0 - err := IterateFieldsByTag(&a, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error { + err := IterateFieldsByTag(&a, "persistence", false, func(tag string, ft reflect.StructField, fv reflect.Value) error { cnt++ collectedTags = append(collectedTags, tag) return nil @@ -101,7 +101,7 @@ func TestIterateFieldsByTag(t *testing.T) { collectedTags := []string{} cnt := 0 - err := IterateFieldsByTag(&a, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error { + err := IterateFieldsByTag(&a, "persistence", false, func(tag string, ft reflect.StructField, fv reflect.Value) error { cnt++ collectedTags = append(collectedTags, tag) return nil @@ -119,7 +119,7 @@ func TestIterateFieldsByTag(t *testing.T) { collectedTags := []string{} cnt := 0 - err := IterateFieldsByTag(a, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error { + err := IterateFieldsByTag(a, "persistence", false, func(tag string, ft reflect.StructField, fv reflect.Value) error { cnt++ collectedTags = append(collectedTags, tag) return nil