fix trade collector race condition and infinite iterate

This commit is contained in:
c9s 2023-07-19 17:33:12 +08:00
parent 93d10eba5a
commit 1dae711d33
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
5 changed files with 40 additions and 45 deletions

View File

@ -43,7 +43,7 @@ func Sync(ctx context.Context, obj interface{}) {
}
func loadPersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error {
return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error {
return dynamic.IterateFieldsByTag(obj, "persistence", true, func(tag string, field reflect.StructField, value reflect.Value) error {
log.Debugf("[loadPersistenceFields] loading value into field %v, tag = %s, original value = %v", field, tag, value)
newValueInf := dynamic.NewTypeValueInterface(value.Type())
@ -71,7 +71,7 @@ func loadPersistenceFields(obj interface{}, id string, persistence service.Persi
}
func storePersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error {
return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error {
return dynamic.IterateFieldsByTag(obj, "persistence", true, func(tag string, ft reflect.StructField, fv reflect.Value) error {
log.Debugf("[storePersistenceFields] storing value from field %v, tag = %s, original value = %v", ft, tag, fv)
inf := fv.Interface()

View File

@ -107,12 +107,6 @@ func (c *TradeCollector) Recover(ctx context.Context, ex types.ExchangeTradeHist
return nil
}
func (c *TradeCollector) setDone(key types.TradeKey) {
c.mu.Lock()
c.doneTrades[key] = struct{}{}
c.mu.Unlock()
}
// Process filters the received trades and see if there are orders matching the trades
// if we have the order in the order store, then the trade will be considered for the position.
// profit will also be calculated.
@ -120,48 +114,47 @@ func (c *TradeCollector) Process() bool {
logrus.Debugf("TradeCollector.Process()")
positionChanged := false
var trades []types.Trade
// if it's already done, remove the trade from the trade store
c.mu.Lock()
c.tradeStore.Filter(func(trade types.Trade) bool {
key := trade.Key()
c.mu.Lock()
// if it's already done, remove the trade from the trade store
// remove done trades
if _, done := c.doneTrades[key]; done {
c.mu.Unlock()
return true
}
if c.position != nil {
if c.orderStore.Exists(trade.OrderID) {
var p types.Profit
profit, netProfit, madeProfit := c.position.AddTrade(trade)
if madeProfit {
p = c.position.NewProfit(trade, profit, netProfit)
}
c.doneTrades[key] = struct{}{}
c.mu.Unlock()
c.EmitTrade(trade, profit, netProfit)
if !p.Profit.IsZero() {
c.EmitProfit(trade, &p)
}
positionChanged = true
return true
}
} else {
if c.orderStore.Exists(trade.OrderID) {
c.doneTrades[key] = struct{}{}
c.mu.Unlock()
c.EmitTrade(trade, fixedpoint.Zero, fixedpoint.Zero)
return true
}
// if it's the trade we're looking for, add it to the list and mark it as done
if c.orderStore.Exists(trade.OrderID) {
trades = append(trades, trade)
c.doneTrades[key] = struct{}{}
return true
}
return false
})
c.mu.Unlock()
for _, trade := range trades {
var p types.Profit
if c.position != nil {
profit, netProfit, madeProfit := c.position.AddTrade(trade)
if madeProfit {
p = c.position.NewProfit(trade, profit, netProfit)
}
positionChanged = true
c.EmitTrade(trade, profit, netProfit)
} else {
c.EmitTrade(trade, fixedpoint.Zero, fixedpoint.Zero)
}
if !p.Profit.IsZero() {
c.EmitProfit(trade, &p)
}
}
if positionChanged && c.position != nil {
c.EmitPositionUpdate(c.position)

View File

@ -60,6 +60,7 @@ func (s *TradeStore) Clear() {
type TradeFilter func(trade types.Trade) bool
// Filter filters the trades by a given TradeFilter function
func (s *TradeStore) Filter(filter TradeFilter) {
s.Lock()
var trades = make(map[uint64]types.Trade)
@ -72,6 +73,7 @@ func (s *TradeStore) Filter(filter TradeFilter) {
s.Unlock()
}
// GetOrderTrades finds the trades match order id matches to the given order
func (s *TradeStore) GetOrderTrades(o types.Order) (trades []types.Trade) {
s.Lock()
for _, t := range s.trades {

View File

@ -56,7 +56,7 @@ func isStructPtr(tpe reflect.Type) bool {
return tpe.Kind() == reflect.Ptr && tpe.Elem().Kind() == reflect.Struct
}
func IterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error {
func IterateFieldsByTag(obj interface{}, tagName string, children bool, cb StructFieldIterator) error {
sv := reflect.ValueOf(obj)
st := reflect.TypeOf(obj)
@ -86,9 +86,9 @@ func IterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator)
continue
}
if isStructPtr(ft.Type) && !fv.IsNil() {
if children && isStructPtr(ft.Type) && !fv.IsNil() {
// recursive iterate the struct field
if err := IterateFieldsByTag(fv.Interface(), tagName, cb); err != nil {
if err := IterateFieldsByTag(fv.Interface(), tagName, false, cb); err != nil {
return fmt.Errorf("unable to iterate struct fields over the type %v: %v", ft, err)
}
}

View File

@ -75,7 +75,7 @@ func TestIterateFieldsByTag(t *testing.T) {
collectedTags := []string{}
cnt := 0
err := IterateFieldsByTag(&a, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error {
err := IterateFieldsByTag(&a, "persistence", false, func(tag string, ft reflect.StructField, fv reflect.Value) error {
cnt++
collectedTags = append(collectedTags, tag)
return nil
@ -101,7 +101,7 @@ func TestIterateFieldsByTag(t *testing.T) {
collectedTags := []string{}
cnt := 0
err := IterateFieldsByTag(&a, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error {
err := IterateFieldsByTag(&a, "persistence", false, func(tag string, ft reflect.StructField, fv reflect.Value) error {
cnt++
collectedTags = append(collectedTags, tag)
return nil
@ -119,7 +119,7 @@ func TestIterateFieldsByTag(t *testing.T) {
collectedTags := []string{}
cnt := 0
err := IterateFieldsByTag(a, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error {
err := IterateFieldsByTag(a, "persistence", false, func(tag string, ft reflect.StructField, fv reflect.Value) error {
cnt++
collectedTags = append(collectedTags, tag)
return nil