use CommonCallback and pull PersistenceTTL out

This commit is contained in:
chiahung.lin 2024-01-09 16:01:10 +08:00
parent 21e87079b5
commit d3bc37f45e
7 changed files with 123 additions and 116 deletions

View File

@ -61,7 +61,7 @@ func generateOpenPositionOrders(market types.Market, quoteInvestment, price, pri
prices = append(prices, price)
}
notional, orderNum := calculateNotionalAndNum(market, quoteInvestment, prices)
notional, orderNum := calculateNotionalAndNumOrders(market, quoteInvestment, prices)
if orderNum == 0 {
return nil, fmt.Errorf("failed to calculate notional and num of open position orders, price: %s, quote investment: %s", price, quoteInvestment)
}
@ -87,9 +87,9 @@ func generateOpenPositionOrders(market types.Market, quoteInvestment, price, pri
return submitOrders, nil
}
// calculateNotionalAndNum calculates the notional and num of open position orders
// calculateNotionalAndNumOrders calculates the notional and num of open position orders
// DCA2 is notional-based, every order has the same notional
func calculateNotionalAndNum(market types.Market, quoteInvestment fixedpoint.Value, prices []fixedpoint.Value) (fixedpoint.Value, int) {
func calculateNotionalAndNumOrders(market types.Market, quoteInvestment fixedpoint.Value, prices []fixedpoint.Value) (fixedpoint.Value, int) {
for num := len(prices); num > 0; num-- {
notional := quoteInvestment.Div(fixedpoint.NewFromInt(int64(num)))
if notional.Compare(market.MinNotional) < 0 {

View File

@ -1,9 +1,7 @@
package dca2
import (
"context"
"fmt"
"strconv"
"strings"
"time"
@ -11,6 +9,21 @@ import (
"github.com/c9s/bbgo/pkg/types"
)
type PersistenceTTL struct {
ttl time.Duration
}
func (p *PersistenceTTL) SetTTL(ttl time.Duration) {
if ttl.Nanoseconds() <= 0 {
return
}
p.ttl = ttl
}
func (p *PersistenceTTL) Expiration() time.Duration {
return p.ttl
}
type ProfitStats struct {
Symbol string `json:"symbol"`
Market types.Market `json:"market,omitempty"`
@ -19,13 +32,12 @@ type ProfitStats struct {
Round int64 `json:"round,omitempty"`
QuoteInvestment fixedpoint.Value `json:"quoteInvestment,omitempty"`
RoundProfit fixedpoint.Value `json:"roundProfit,omitempty"`
RoundFee map[string]fixedpoint.Value `json:"roundFee,omitempty"`
TotalProfit fixedpoint.Value `json:"totalProfit,omitempty"`
TotalFee map[string]fixedpoint.Value `json:"totalFee,omitempty"`
CurrentRoundProfit fixedpoint.Value `json:"currentRoundProfit,omitempty"`
CurrentRoundFee map[string]fixedpoint.Value `json:"currentRoundFee,omitempty"`
TotalProfit fixedpoint.Value `json:"totalProfit,omitempty"`
TotalFee map[string]fixedpoint.Value `json:"totalFee,omitempty"`
// ttl is the ttl to keep in persistence
ttl time.Duration
PersistenceTTL
}
func newProfitStats(market types.Market, quoteInvestment fixedpoint.Value) *ProfitStats {
@ -34,31 +46,20 @@ func newProfitStats(market types.Market, quoteInvestment fixedpoint.Value) *Prof
Market: market,
Round: 0,
QuoteInvestment: quoteInvestment,
RoundFee: make(map[string]fixedpoint.Value),
CurrentRoundFee: make(map[string]fixedpoint.Value),
TotalFee: make(map[string]fixedpoint.Value),
}
}
func (s *ProfitStats) SetTTL(ttl time.Duration) {
if ttl.Nanoseconds() <= 0 {
return
}
s.ttl = ttl
}
func (s *ProfitStats) Expiration() time.Duration {
return s.ttl
}
func (s *ProfitStats) AddTrade(trade types.Trade) {
if s.RoundFee == nil {
s.RoundFee = make(map[string]fixedpoint.Value)
if s.CurrentRoundFee == nil {
s.CurrentRoundFee = make(map[string]fixedpoint.Value)
}
if fee, ok := s.RoundFee[trade.FeeCurrency]; ok {
s.RoundFee[trade.FeeCurrency] = fee.Add(trade.Fee)
if fee, ok := s.CurrentRoundFee[trade.FeeCurrency]; ok {
s.CurrentRoundFee[trade.FeeCurrency] = fee.Add(trade.Fee)
} else {
s.RoundFee[trade.FeeCurrency] = trade.Fee
s.CurrentRoundFee[trade.FeeCurrency] = trade.Fee
}
if s.TotalFee == nil {
@ -76,63 +77,19 @@ func (s *ProfitStats) AddTrade(trade types.Trade) {
quoteQuantity = quoteQuantity.Neg()
}
s.RoundProfit = s.RoundProfit.Add(quoteQuantity)
s.CurrentRoundProfit = s.CurrentRoundProfit.Add(quoteQuantity)
s.TotalProfit = s.TotalProfit.Add(quoteQuantity)
if s.Market.QuoteCurrency == trade.FeeCurrency {
s.RoundProfit.Sub(trade.Fee)
s.CurrentRoundProfit.Sub(trade.Fee)
s.TotalProfit.Sub(trade.Fee)
}
}
func (s *ProfitStats) NewRound() {
s.Round++
s.RoundProfit = fixedpoint.Zero
s.RoundFee = make(map[string]fixedpoint.Value)
}
func (s *ProfitStats) CalculateProfitOfRound(ctx context.Context, exchange types.Exchange) error {
historyService, ok := exchange.(types.ExchangeTradeHistoryService)
if !ok {
return fmt.Errorf("exchange %s doesn't support ExchangeTradeHistoryService", exchange.Name())
}
queryService, ok := exchange.(types.ExchangeOrderQueryService)
if !ok {
return fmt.Errorf("exchange %s doesn't support ExchangeOrderQueryService", exchange.Name())
}
// query the orders of this round
orders, err := historyService.QueryClosedOrders(ctx, s.Symbol, time.Time{}, time.Time{}, s.FromOrderID)
if err != nil {
return err
}
// query the trades of this round
for _, order := range orders {
if order.ExecutedQuantity.Sign() == 0 {
// skip no trade orders
continue
}
trades, err := queryService.QueryOrderTrades(ctx, types.OrderQuery{
Symbol: order.Symbol,
OrderID: strconv.FormatUint(order.OrderID, 10),
})
if err != nil {
return err
}
for _, trade := range trades {
s.AddTrade(trade)
}
}
s.FromOrderID = s.FromOrderID + 1
s.QuoteInvestment = s.QuoteInvestment.Add(s.RoundProfit)
return nil
s.CurrentRoundProfit = fixedpoint.Zero
s.CurrentRoundFee = make(map[string]fixedpoint.Value)
}
func (s *ProfitStats) String() string {
@ -141,9 +98,9 @@ func (s *ProfitStats) String() string {
sb.WriteString(fmt.Sprintf("Round: %d\n", s.Round))
sb.WriteString(fmt.Sprintf("From Order ID: %d\n", s.FromOrderID))
sb.WriteString(fmt.Sprintf("Quote Investment: %s\n", s.QuoteInvestment))
sb.WriteString(fmt.Sprintf("Round Profit: %s\n", s.RoundProfit))
sb.WriteString(fmt.Sprintf("Current Round Profit: %s\n", s.CurrentRoundProfit))
sb.WriteString(fmt.Sprintf("Total Profit: %s\n", s.TotalProfit))
for currency, fee := range s.RoundFee {
for currency, fee := range s.CurrentRoundFee {
sb.WriteString(fmt.Sprintf("FEE (%s): %s\n", currency, fee))
}
sb.WriteString("[------------------ Profit Stats ------------------]\n")

View File

@ -57,7 +57,7 @@ func (s *Strategy) recover(ctx context.Context) error {
}
// recover profit stats
recoverProfitStats(ctx, s.ProfitStats, s.Session.Exchange)
recoverProfitStats(ctx, s)
// recover startTimeOfNextRound
startTimeOfNextRound := recoverStartTimeOfNextRound(ctx, currentRound, s.CoolDownInterval)
@ -189,12 +189,12 @@ func recoverPosition(ctx context.Context, position *types.Position, queryService
return nil
}
func recoverProfitStats(ctx context.Context, profitStats *ProfitStats, exchange types.Exchange) error {
if profitStats == nil {
func recoverProfitStats(ctx context.Context, strategy *Strategy) error {
if strategy.ProfitStats == nil {
return fmt.Errorf("profit stats is nil, please check it")
}
profitStats.CalculateProfitOfRound(ctx, exchange)
strategy.CalculateProfitOfCurrentRound(ctx)
return nil
}

View File

@ -193,7 +193,7 @@ func (s *Strategy) runTakeProfitReady(ctx context.Context, next State) {
s.logger.Info("[State] TakeProfitReady - start reseting position and calculate quote investment for next round")
// calculate profit stats
s.ProfitStats.CalculateProfitOfRound(ctx, s.Session.Exchange)
s.CalculateProfitOfCurrentRound(ctx)
bbgo.Sync(ctx, s)
s.EmitProfit(s.ProfitStats)

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"math"
"strconv"
"sync"
"time"
@ -68,11 +69,9 @@ type Strategy struct {
state State
// callbacks
readyCallbacks []func()
types.CommonCallback
positionCallbacks []func(*types.Position)
profitCallbacks []func(*ProfitStats)
closedCallbacks []func()
errorCallbacks []func(error)
}
func (s *Strategy) ID() string {
@ -278,3 +277,47 @@ func (s *Strategy) CleanUp(ctx context.Context) error {
bbgo.Sync(ctx, s)
return err
}
func (s *Strategy) CalculateProfitOfCurrentRound(ctx context.Context) error {
historyService, ok := s.Session.Exchange.(types.ExchangeTradeHistoryService)
if !ok {
return fmt.Errorf("exchange %s doesn't support ExchangeTradeHistoryService", s.Session.Exchange.Name())
}
queryService, ok := s.Session.Exchange.(types.ExchangeOrderQueryService)
if !ok {
return fmt.Errorf("exchange %s doesn't support ExchangeOrderQueryService", s.Session.Exchange.Name())
}
// query the orders of this round
orders, err := historyService.QueryClosedOrders(ctx, s.Symbol, time.Time{}, time.Time{}, s.ProfitStats.FromOrderID)
if err != nil {
return err
}
// query the trades of this round
for _, order := range orders {
if order.ExecutedQuantity.Sign() == 0 {
// skip no trade orders
continue
}
trades, err := queryService.QueryOrderTrades(ctx, types.OrderQuery{
Symbol: order.Symbol,
OrderID: strconv.FormatUint(order.OrderID, 10),
})
if err != nil {
return err
}
for _, trade := range trades {
s.ProfitStats.AddTrade(trade)
}
}
s.ProfitStats.FromOrderID = s.ProfitStats.FromOrderID + 1
s.ProfitStats.QuoteInvestment = s.ProfitStats.QuoteInvestment.Add(s.ProfitStats.CurrentRoundProfit)
return nil
}

View File

@ -6,16 +6,6 @@ import (
"github.com/c9s/bbgo/pkg/types"
)
func (s *Strategy) OnReady(cb func()) {
s.readyCallbacks = append(s.readyCallbacks, cb)
}
func (s *Strategy) EmitReady() {
for _, cb := range s.readyCallbacks {
cb()
}
}
func (s *Strategy) OnPosition(cb func(*types.Position)) {
s.positionCallbacks = append(s.positionCallbacks, cb)
}
@ -35,23 +25,3 @@ func (s *Strategy) EmitProfit(profitStats *ProfitStats) {
cb(profitStats)
}
}
func (s *Strategy) OnClosed(cb func()) {
s.closedCallbacks = append(s.closedCallbacks, cb)
}
func (s *Strategy) EmitClosed() {
for _, cb := range s.closedCallbacks {
cb()
}
}
func (s *Strategy) OnError(cb func(err error)) {
s.errorCallbacks = append(s.errorCallbacks, cb)
}
func (s *Strategy) EmitError(err error) {
for _, cb := range s.errorCallbacks {
cb(err)
}
}

37
pkg/types/callbacks.go Normal file
View File

@ -0,0 +1,37 @@
package types
type CommonCallback struct {
readyCallbacks []func()
closedCallbacks []func()
errorCallbacks []func(error)
}
func (c *CommonCallback) OnReady(cb func()) {
c.readyCallbacks = append(c.readyCallbacks, cb)
}
func (c *CommonCallback) EmitReady() {
for _, cb := range c.readyCallbacks {
cb()
}
}
func (c *CommonCallback) OnClosed(cb func()) {
c.closedCallbacks = append(c.closedCallbacks, cb)
}
func (c *CommonCallback) EmitClosed() {
for _, cb := range c.closedCallbacks {
cb()
}
}
func (c *CommonCallback) OnError(cb func(err error)) {
c.errorCallbacks = append(c.errorCallbacks, cb)
}
func (c *CommonCallback) EmitError(err error) {
for _, cb := range c.errorCallbacks {
cb(err)
}
}