From ecd2d9ea68aa0adbdb18f2640f685a6d7ba74fb9 Mon Sep 17 00:00:00 2001 From: c9s Date: Thu, 24 Jun 2021 19:29:21 +0800 Subject: [PATCH] bbgo: improve trade collector callbacks --- pkg/bbgo/trade_store.go | 41 +++++++++++++++++++--------- pkg/bbgo/tradecollector.go | 16 +++++++++-- pkg/bbgo/tradecollector_callbacks.go | 11 ++++++++ 3 files changed, 52 insertions(+), 16 deletions(-) diff --git a/pkg/bbgo/trade_store.go b/pkg/bbgo/trade_store.go index cc0ae7e73..0de18f45f 100644 --- a/pkg/bbgo/trade_store.go +++ b/pkg/bbgo/trade_store.go @@ -8,7 +8,8 @@ import ( type TradeStore struct { // any created trades for tracking trades - mu sync.Mutex + sync.Mutex + trades map[int64]types.Trade Symbol string @@ -25,15 +26,15 @@ func NewTradeStore(symbol string) *TradeStore { } func (s *TradeStore) Num() (num int) { - s.mu.Lock() + s.Lock() num = len(s.trades) - s.mu.Unlock() + s.Unlock() return num } func (s *TradeStore) Trades() (trades []types.Trade) { - s.mu.Lock() - defer s.mu.Unlock() + s.Lock() + defer s.Unlock() for _, o := range s.trades { trades = append(trades, o) @@ -43,33 +44,47 @@ func (s *TradeStore) Trades() (trades []types.Trade) { } func (s *TradeStore) Exists(oID int64) (ok bool) { - s.mu.Lock() - defer s.mu.Unlock() + s.Lock() + defer s.Unlock() _, ok = s.trades[oID] return ok } func (s *TradeStore) Clear() { - s.mu.Lock() + s.Lock() s.trades = make(map[int64]types.Trade) - s.mu.Unlock() + s.Unlock() +} + +type TradeFilter func(trade types.Trade) bool + +func (s *TradeStore) Filter(filter TradeFilter) { + s.Lock() + var trades = make(map[int64]types.Trade) + for _, trade := range s.trades { + if filter(trade) { + trades[trade.ID] = trade + } + } + s.trades = trades + s.Unlock() } func (s *TradeStore) GetAndClear() (trades []types.Trade) { - s.mu.Lock() + s.Lock() for _, o := range s.trades { trades = append(trades, o) } s.trades = make(map[int64]types.Trade) - s.mu.Unlock() + s.Unlock() return trades } func (s *TradeStore) Add(trades ...types.Trade) { - s.mu.Lock() - defer s.mu.Unlock() + s.Lock() + defer s.Unlock() for _, trade := range trades { s.trades[trade.ID] = trade diff --git a/pkg/bbgo/tradecollector.go b/pkg/bbgo/tradecollector.go index e4327e5c9..24966255f 100644 --- a/pkg/bbgo/tradecollector.go +++ b/pkg/bbgo/tradecollector.go @@ -3,6 +3,7 @@ package bbgo import ( "context" + "github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/sigchan" "github.com/c9s/bbgo/pkg/types" ) @@ -19,6 +20,7 @@ type TradeCollector struct { tradeCallbacks []func(trade types.Trade) positionUpdateCallbacks []func(position *Position) + profitCallbacks []func(trade types.Trade, profit, netProfit fixedpoint.Value) } func NewTradeCollector(symbol string, position *Position, orderStore *OrderStore) *TradeCollector { @@ -55,15 +57,23 @@ func (c *TradeCollector) Run(ctx context.Context) { trades := c.tradeStore.GetAndClear() for _, trade := range trades { if c.orderStore.Exists(trade.OrderID) { - c.position.AddTrade(trade) c.EmitTrade(trade) + if profit, netProfit, madeProfit := c.position.AddTrade(trade) ; madeProfit { + c.EmitProfit(trade, profit, netProfit) + } } } c.EmitPositionUpdate(c.position) case trade := <-c.tradeC: - c.tradeStore.Add(trade) - + if c.orderStore.Exists(trade.OrderID) { + c.EmitTrade(trade) + if profit, netProfit, madeProfit := c.position.AddTrade(trade) ; madeProfit { + c.EmitProfit(trade, profit, netProfit) + } + } else { + c.tradeStore.Add(trade) + } } } } diff --git a/pkg/bbgo/tradecollector_callbacks.go b/pkg/bbgo/tradecollector_callbacks.go index 73e5863c0..0a3d4f5af 100644 --- a/pkg/bbgo/tradecollector_callbacks.go +++ b/pkg/bbgo/tradecollector_callbacks.go @@ -3,6 +3,7 @@ package bbgo import ( + "github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/types" ) @@ -25,3 +26,13 @@ func (c *TradeCollector) EmitPositionUpdate(position *Position) { cb(position) } } + +func (c *TradeCollector) OnProfit(cb func(trade types.Trade, profit fixedpoint.Value, netProfit fixedpoint.Value)) { + c.profitCallbacks = append(c.profitCallbacks, cb) +} + +func (c *TradeCollector) EmitProfit(trade types.Trade, profit fixedpoint.Value, netProfit fixedpoint.Value) { + for _, cb := range c.profitCallbacks { + cb(trade, profit, netProfit) + } +}