From d3bc37f45e8d0387c769e976b643f8da29c32cee Mon Sep 17 00:00:00 2001 From: "chiahung.lin" Date: Tue, 9 Jan 2024 16:01:10 +0800 Subject: [PATCH] use CommonCallback and pull PersistenceTTL out --- pkg/strategy/dca2/open_position.go | 6 +- pkg/strategy/dca2/profit_stats.go | 107 +++++++----------------- pkg/strategy/dca2/recover.go | 8 +- pkg/strategy/dca2/state.go | 2 +- pkg/strategy/dca2/strategy.go | 49 ++++++++++- pkg/strategy/dca2/strategy_callbacks.go | 30 ------- pkg/types/callbacks.go | 37 ++++++++ 7 files changed, 123 insertions(+), 116 deletions(-) create mode 100644 pkg/types/callbacks.go diff --git a/pkg/strategy/dca2/open_position.go b/pkg/strategy/dca2/open_position.go index f40d17b94..617a06c44 100644 --- a/pkg/strategy/dca2/open_position.go +++ b/pkg/strategy/dca2/open_position.go @@ -61,7 +61,7 @@ func generateOpenPositionOrders(market types.Market, quoteInvestment, price, pri prices = append(prices, price) } - notional, orderNum := calculateNotionalAndNum(market, quoteInvestment, prices) + notional, orderNum := calculateNotionalAndNumOrders(market, quoteInvestment, prices) if orderNum == 0 { return nil, fmt.Errorf("failed to calculate notional and num of open position orders, price: %s, quote investment: %s", price, quoteInvestment) } @@ -87,9 +87,9 @@ func generateOpenPositionOrders(market types.Market, quoteInvestment, price, pri return submitOrders, nil } -// calculateNotionalAndNum calculates the notional and num of open position orders +// calculateNotionalAndNumOrders calculates the notional and num of open position orders // DCA2 is notional-based, every order has the same notional -func calculateNotionalAndNum(market types.Market, quoteInvestment fixedpoint.Value, prices []fixedpoint.Value) (fixedpoint.Value, int) { +func calculateNotionalAndNumOrders(market types.Market, quoteInvestment fixedpoint.Value, prices []fixedpoint.Value) (fixedpoint.Value, int) { for num := len(prices); num > 0; num-- { notional := quoteInvestment.Div(fixedpoint.NewFromInt(int64(num))) if notional.Compare(market.MinNotional) < 0 { diff --git a/pkg/strategy/dca2/profit_stats.go b/pkg/strategy/dca2/profit_stats.go index 5ceb21b7a..a468e65cd 100644 --- a/pkg/strategy/dca2/profit_stats.go +++ b/pkg/strategy/dca2/profit_stats.go @@ -1,9 +1,7 @@ package dca2 import ( - "context" "fmt" - "strconv" "strings" "time" @@ -11,6 +9,21 @@ import ( "github.com/c9s/bbgo/pkg/types" ) +type PersistenceTTL struct { + ttl time.Duration +} + +func (p *PersistenceTTL) SetTTL(ttl time.Duration) { + if ttl.Nanoseconds() <= 0 { + return + } + p.ttl = ttl +} + +func (p *PersistenceTTL) Expiration() time.Duration { + return p.ttl +} + type ProfitStats struct { Symbol string `json:"symbol"` Market types.Market `json:"market,omitempty"` @@ -19,13 +32,12 @@ type ProfitStats struct { Round int64 `json:"round,omitempty"` QuoteInvestment fixedpoint.Value `json:"quoteInvestment,omitempty"` - RoundProfit fixedpoint.Value `json:"roundProfit,omitempty"` - RoundFee map[string]fixedpoint.Value `json:"roundFee,omitempty"` - TotalProfit fixedpoint.Value `json:"totalProfit,omitempty"` - TotalFee map[string]fixedpoint.Value `json:"totalFee,omitempty"` + CurrentRoundProfit fixedpoint.Value `json:"currentRoundProfit,omitempty"` + CurrentRoundFee map[string]fixedpoint.Value `json:"currentRoundFee,omitempty"` + TotalProfit fixedpoint.Value `json:"totalProfit,omitempty"` + TotalFee map[string]fixedpoint.Value `json:"totalFee,omitempty"` - // ttl is the ttl to keep in persistence - ttl time.Duration + PersistenceTTL } func newProfitStats(market types.Market, quoteInvestment fixedpoint.Value) *ProfitStats { @@ -34,31 +46,20 @@ func newProfitStats(market types.Market, quoteInvestment fixedpoint.Value) *Prof Market: market, Round: 0, QuoteInvestment: quoteInvestment, - RoundFee: make(map[string]fixedpoint.Value), + CurrentRoundFee: make(map[string]fixedpoint.Value), TotalFee: make(map[string]fixedpoint.Value), } } -func (s *ProfitStats) SetTTL(ttl time.Duration) { - if ttl.Nanoseconds() <= 0 { - return - } - s.ttl = ttl -} - -func (s *ProfitStats) Expiration() time.Duration { - return s.ttl -} - func (s *ProfitStats) AddTrade(trade types.Trade) { - if s.RoundFee == nil { - s.RoundFee = make(map[string]fixedpoint.Value) + if s.CurrentRoundFee == nil { + s.CurrentRoundFee = make(map[string]fixedpoint.Value) } - if fee, ok := s.RoundFee[trade.FeeCurrency]; ok { - s.RoundFee[trade.FeeCurrency] = fee.Add(trade.Fee) + if fee, ok := s.CurrentRoundFee[trade.FeeCurrency]; ok { + s.CurrentRoundFee[trade.FeeCurrency] = fee.Add(trade.Fee) } else { - s.RoundFee[trade.FeeCurrency] = trade.Fee + s.CurrentRoundFee[trade.FeeCurrency] = trade.Fee } if s.TotalFee == nil { @@ -76,63 +77,19 @@ func (s *ProfitStats) AddTrade(trade types.Trade) { quoteQuantity = quoteQuantity.Neg() } - s.RoundProfit = s.RoundProfit.Add(quoteQuantity) + s.CurrentRoundProfit = s.CurrentRoundProfit.Add(quoteQuantity) s.TotalProfit = s.TotalProfit.Add(quoteQuantity) if s.Market.QuoteCurrency == trade.FeeCurrency { - s.RoundProfit.Sub(trade.Fee) + s.CurrentRoundProfit.Sub(trade.Fee) s.TotalProfit.Sub(trade.Fee) } } func (s *ProfitStats) NewRound() { s.Round++ - s.RoundProfit = fixedpoint.Zero - s.RoundFee = make(map[string]fixedpoint.Value) -} - -func (s *ProfitStats) CalculateProfitOfRound(ctx context.Context, exchange types.Exchange) error { - historyService, ok := exchange.(types.ExchangeTradeHistoryService) - if !ok { - return fmt.Errorf("exchange %s doesn't support ExchangeTradeHistoryService", exchange.Name()) - } - - queryService, ok := exchange.(types.ExchangeOrderQueryService) - if !ok { - return fmt.Errorf("exchange %s doesn't support ExchangeOrderQueryService", exchange.Name()) - } - - // query the orders of this round - orders, err := historyService.QueryClosedOrders(ctx, s.Symbol, time.Time{}, time.Time{}, s.FromOrderID) - if err != nil { - return err - } - - // query the trades of this round - for _, order := range orders { - if order.ExecutedQuantity.Sign() == 0 { - // skip no trade orders - continue - } - - trades, err := queryService.QueryOrderTrades(ctx, types.OrderQuery{ - Symbol: order.Symbol, - OrderID: strconv.FormatUint(order.OrderID, 10), - }) - - if err != nil { - return err - } - - for _, trade := range trades { - s.AddTrade(trade) - } - } - - s.FromOrderID = s.FromOrderID + 1 - s.QuoteInvestment = s.QuoteInvestment.Add(s.RoundProfit) - - return nil + s.CurrentRoundProfit = fixedpoint.Zero + s.CurrentRoundFee = make(map[string]fixedpoint.Value) } func (s *ProfitStats) String() string { @@ -141,9 +98,9 @@ func (s *ProfitStats) String() string { sb.WriteString(fmt.Sprintf("Round: %d\n", s.Round)) sb.WriteString(fmt.Sprintf("From Order ID: %d\n", s.FromOrderID)) sb.WriteString(fmt.Sprintf("Quote Investment: %s\n", s.QuoteInvestment)) - sb.WriteString(fmt.Sprintf("Round Profit: %s\n", s.RoundProfit)) + sb.WriteString(fmt.Sprintf("Current Round Profit: %s\n", s.CurrentRoundProfit)) sb.WriteString(fmt.Sprintf("Total Profit: %s\n", s.TotalProfit)) - for currency, fee := range s.RoundFee { + for currency, fee := range s.CurrentRoundFee { sb.WriteString(fmt.Sprintf("FEE (%s): %s\n", currency, fee)) } sb.WriteString("[------------------ Profit Stats ------------------]\n") diff --git a/pkg/strategy/dca2/recover.go b/pkg/strategy/dca2/recover.go index fd5de6fbf..b96474b07 100644 --- a/pkg/strategy/dca2/recover.go +++ b/pkg/strategy/dca2/recover.go @@ -57,7 +57,7 @@ func (s *Strategy) recover(ctx context.Context) error { } // recover profit stats - recoverProfitStats(ctx, s.ProfitStats, s.Session.Exchange) + recoverProfitStats(ctx, s) // recover startTimeOfNextRound startTimeOfNextRound := recoverStartTimeOfNextRound(ctx, currentRound, s.CoolDownInterval) @@ -189,12 +189,12 @@ func recoverPosition(ctx context.Context, position *types.Position, queryService return nil } -func recoverProfitStats(ctx context.Context, profitStats *ProfitStats, exchange types.Exchange) error { - if profitStats == nil { +func recoverProfitStats(ctx context.Context, strategy *Strategy) error { + if strategy.ProfitStats == nil { return fmt.Errorf("profit stats is nil, please check it") } - profitStats.CalculateProfitOfRound(ctx, exchange) + strategy.CalculateProfitOfCurrentRound(ctx) return nil } diff --git a/pkg/strategy/dca2/state.go b/pkg/strategy/dca2/state.go index 53ee3386c..38190d2d2 100644 --- a/pkg/strategy/dca2/state.go +++ b/pkg/strategy/dca2/state.go @@ -193,7 +193,7 @@ func (s *Strategy) runTakeProfitReady(ctx context.Context, next State) { s.logger.Info("[State] TakeProfitReady - start reseting position and calculate quote investment for next round") // calculate profit stats - s.ProfitStats.CalculateProfitOfRound(ctx, s.Session.Exchange) + s.CalculateProfitOfCurrentRound(ctx) bbgo.Sync(ctx, s) s.EmitProfit(s.ProfitStats) diff --git a/pkg/strategy/dca2/strategy.go b/pkg/strategy/dca2/strategy.go index 3fec4a6c9..d355b5f1f 100644 --- a/pkg/strategy/dca2/strategy.go +++ b/pkg/strategy/dca2/strategy.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math" + "strconv" "sync" "time" @@ -68,11 +69,9 @@ type Strategy struct { state State // callbacks - readyCallbacks []func() + types.CommonCallback positionCallbacks []func(*types.Position) profitCallbacks []func(*ProfitStats) - closedCallbacks []func() - errorCallbacks []func(error) } func (s *Strategy) ID() string { @@ -278,3 +277,47 @@ func (s *Strategy) CleanUp(ctx context.Context) error { bbgo.Sync(ctx, s) return err } + +func (s *Strategy) CalculateProfitOfCurrentRound(ctx context.Context) error { + historyService, ok := s.Session.Exchange.(types.ExchangeTradeHistoryService) + if !ok { + return fmt.Errorf("exchange %s doesn't support ExchangeTradeHistoryService", s.Session.Exchange.Name()) + } + + queryService, ok := s.Session.Exchange.(types.ExchangeOrderQueryService) + if !ok { + return fmt.Errorf("exchange %s doesn't support ExchangeOrderQueryService", s.Session.Exchange.Name()) + } + + // query the orders of this round + orders, err := historyService.QueryClosedOrders(ctx, s.Symbol, time.Time{}, time.Time{}, s.ProfitStats.FromOrderID) + if err != nil { + return err + } + + // query the trades of this round + for _, order := range orders { + if order.ExecutedQuantity.Sign() == 0 { + // skip no trade orders + continue + } + + trades, err := queryService.QueryOrderTrades(ctx, types.OrderQuery{ + Symbol: order.Symbol, + OrderID: strconv.FormatUint(order.OrderID, 10), + }) + + if err != nil { + return err + } + + for _, trade := range trades { + s.ProfitStats.AddTrade(trade) + } + } + + s.ProfitStats.FromOrderID = s.ProfitStats.FromOrderID + 1 + s.ProfitStats.QuoteInvestment = s.ProfitStats.QuoteInvestment.Add(s.ProfitStats.CurrentRoundProfit) + + return nil +} diff --git a/pkg/strategy/dca2/strategy_callbacks.go b/pkg/strategy/dca2/strategy_callbacks.go index 64781b4b2..febebd52e 100644 --- a/pkg/strategy/dca2/strategy_callbacks.go +++ b/pkg/strategy/dca2/strategy_callbacks.go @@ -6,16 +6,6 @@ import ( "github.com/c9s/bbgo/pkg/types" ) -func (s *Strategy) OnReady(cb func()) { - s.readyCallbacks = append(s.readyCallbacks, cb) -} - -func (s *Strategy) EmitReady() { - for _, cb := range s.readyCallbacks { - cb() - } -} - func (s *Strategy) OnPosition(cb func(*types.Position)) { s.positionCallbacks = append(s.positionCallbacks, cb) } @@ -35,23 +25,3 @@ func (s *Strategy) EmitProfit(profitStats *ProfitStats) { cb(profitStats) } } - -func (s *Strategy) OnClosed(cb func()) { - s.closedCallbacks = append(s.closedCallbacks, cb) -} - -func (s *Strategy) EmitClosed() { - for _, cb := range s.closedCallbacks { - cb() - } -} - -func (s *Strategy) OnError(cb func(err error)) { - s.errorCallbacks = append(s.errorCallbacks, cb) -} - -func (s *Strategy) EmitError(err error) { - for _, cb := range s.errorCallbacks { - cb(err) - } -} diff --git a/pkg/types/callbacks.go b/pkg/types/callbacks.go new file mode 100644 index 000000000..01a4af82b --- /dev/null +++ b/pkg/types/callbacks.go @@ -0,0 +1,37 @@ +package types + +type CommonCallback struct { + readyCallbacks []func() + closedCallbacks []func() + errorCallbacks []func(error) +} + +func (c *CommonCallback) OnReady(cb func()) { + c.readyCallbacks = append(c.readyCallbacks, cb) +} + +func (c *CommonCallback) EmitReady() { + for _, cb := range c.readyCallbacks { + cb() + } +} + +func (c *CommonCallback) OnClosed(cb func()) { + c.closedCallbacks = append(c.closedCallbacks, cb) +} + +func (c *CommonCallback) EmitClosed() { + for _, cb := range c.closedCallbacks { + cb() + } +} + +func (c *CommonCallback) OnError(cb func(err error)) { + c.errorCallbacks = append(c.errorCallbacks, cb) +} + +func (c *CommonCallback) EmitError(err error) { + for _, cb := range c.errorCallbacks { + cb(err) + } +}