bbgo: improve trade collector callbacks

This commit is contained in:
c9s 2021-06-24 19:29:21 +08:00
parent db4fbbc30c
commit ecd2d9ea68
3 changed files with 52 additions and 16 deletions

View File

@ -8,7 +8,8 @@ import (
type TradeStore struct { type TradeStore struct {
// any created trades for tracking trades // any created trades for tracking trades
mu sync.Mutex sync.Mutex
trades map[int64]types.Trade trades map[int64]types.Trade
Symbol string Symbol string
@ -25,15 +26,15 @@ func NewTradeStore(symbol string) *TradeStore {
} }
func (s *TradeStore) Num() (num int) { func (s *TradeStore) Num() (num int) {
s.mu.Lock() s.Lock()
num = len(s.trades) num = len(s.trades)
s.mu.Unlock() s.Unlock()
return num return num
} }
func (s *TradeStore) Trades() (trades []types.Trade) { func (s *TradeStore) Trades() (trades []types.Trade) {
s.mu.Lock() s.Lock()
defer s.mu.Unlock() defer s.Unlock()
for _, o := range s.trades { for _, o := range s.trades {
trades = append(trades, o) trades = append(trades, o)
@ -43,33 +44,47 @@ func (s *TradeStore) Trades() (trades []types.Trade) {
} }
func (s *TradeStore) Exists(oID int64) (ok bool) { func (s *TradeStore) Exists(oID int64) (ok bool) {
s.mu.Lock() s.Lock()
defer s.mu.Unlock() defer s.Unlock()
_, ok = s.trades[oID] _, ok = s.trades[oID]
return ok return ok
} }
func (s *TradeStore) Clear() { func (s *TradeStore) Clear() {
s.mu.Lock() s.Lock()
s.trades = make(map[int64]types.Trade) s.trades = make(map[int64]types.Trade)
s.mu.Unlock() s.Unlock()
}
type TradeFilter func(trade types.Trade) bool
func (s *TradeStore) Filter(filter TradeFilter) {
s.Lock()
var trades = make(map[int64]types.Trade)
for _, trade := range s.trades {
if filter(trade) {
trades[trade.ID] = trade
}
}
s.trades = trades
s.Unlock()
} }
func (s *TradeStore) GetAndClear() (trades []types.Trade) { func (s *TradeStore) GetAndClear() (trades []types.Trade) {
s.mu.Lock() s.Lock()
for _, o := range s.trades { for _, o := range s.trades {
trades = append(trades, o) trades = append(trades, o)
} }
s.trades = make(map[int64]types.Trade) s.trades = make(map[int64]types.Trade)
s.mu.Unlock() s.Unlock()
return trades return trades
} }
func (s *TradeStore) Add(trades ...types.Trade) { func (s *TradeStore) Add(trades ...types.Trade) {
s.mu.Lock() s.Lock()
defer s.mu.Unlock() defer s.Unlock()
for _, trade := range trades { for _, trade := range trades {
s.trades[trade.ID] = trade s.trades[trade.ID] = trade

View File

@ -3,6 +3,7 @@ package bbgo
import ( import (
"context" "context"
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/sigchan" "github.com/c9s/bbgo/pkg/sigchan"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
@ -19,6 +20,7 @@ type TradeCollector struct {
tradeCallbacks []func(trade types.Trade) tradeCallbacks []func(trade types.Trade)
positionUpdateCallbacks []func(position *Position) positionUpdateCallbacks []func(position *Position)
profitCallbacks []func(trade types.Trade, profit, netProfit fixedpoint.Value)
} }
func NewTradeCollector(symbol string, position *Position, orderStore *OrderStore) *TradeCollector { func NewTradeCollector(symbol string, position *Position, orderStore *OrderStore) *TradeCollector {
@ -55,15 +57,23 @@ func (c *TradeCollector) Run(ctx context.Context) {
trades := c.tradeStore.GetAndClear() trades := c.tradeStore.GetAndClear()
for _, trade := range trades { for _, trade := range trades {
if c.orderStore.Exists(trade.OrderID) { if c.orderStore.Exists(trade.OrderID) {
c.position.AddTrade(trade)
c.EmitTrade(trade) c.EmitTrade(trade)
if profit, netProfit, madeProfit := c.position.AddTrade(trade) ; madeProfit {
c.EmitProfit(trade, profit, netProfit)
}
} }
} }
c.EmitPositionUpdate(c.position) c.EmitPositionUpdate(c.position)
case trade := <-c.tradeC: case trade := <-c.tradeC:
if c.orderStore.Exists(trade.OrderID) {
c.EmitTrade(trade)
if profit, netProfit, madeProfit := c.position.AddTrade(trade) ; madeProfit {
c.EmitProfit(trade, profit, netProfit)
}
} else {
c.tradeStore.Add(trade) c.tradeStore.Add(trade)
}
} }
} }
} }

View File

@ -3,6 +3,7 @@
package bbgo package bbgo
import ( import (
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
@ -25,3 +26,13 @@ func (c *TradeCollector) EmitPositionUpdate(position *Position) {
cb(position) cb(position)
} }
} }
func (c *TradeCollector) OnProfit(cb func(trade types.Trade, profit fixedpoint.Value, netProfit fixedpoint.Value)) {
c.profitCallbacks = append(c.profitCallbacks, cb)
}
func (c *TradeCollector) EmitProfit(trade types.Trade, profit fixedpoint.Value, netProfit fixedpoint.Value) {
for _, cb := range c.profitCallbacks {
cb(trade, profit, netProfit)
}
}