diff --git a/pkg/strategy/common/fee_budget.go b/pkg/strategy/common/fee_budget.go new file mode 100644 index 000000000..57fd93256 --- /dev/null +++ b/pkg/strategy/common/fee_budget.go @@ -0,0 +1,85 @@ +package common + +import ( + "time" + + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" + log "github.com/sirupsen/logrus" +) + +type FeeBudget struct { + DailyFeeBudgets map[string]fixedpoint.Value `json:"dailyFeeBudgets,omitempty"` + State *State `persistence:"state"` +} + +func (f *FeeBudget) Initialize() { + if f.State == nil { + f.State = &State{} + f.State.Reset() + } + + if f.State.IsOver24Hours() { + log.Warn("[FeeBudget] state is over 24 hours, resetting to zero") + f.State.Reset() + } +} + +func (f *FeeBudget) IsBudgetAllowed() bool { + if f.DailyFeeBudgets == nil { + return true + } + + if f.State.AccumulatedFees == nil { + return true + } + + for asset, budget := range f.DailyFeeBudgets { + if fee, ok := f.State.AccumulatedFees[asset]; ok { + if fee.Compare(budget) >= 0 { + log.Warnf("[FeeBudget] accumulative fee %s exceeded the fee budget %s, skipping...", fee.String(), budget.String()) + return false + } + } + } + + return true +} + +func (f *FeeBudget) HandleTradeUpdate(trade types.Trade) { + log.Infof("[FeeBudget] received trade %s", trade.String()) + + if f.State.IsOver24Hours() { + f.State.Reset() + } + + // safe check + if f.State.AccumulatedFees == nil { + f.State.AccumulatedFees = make(map[string]fixedpoint.Value) + } + + f.State.AccumulatedFees[trade.FeeCurrency] = f.State.AccumulatedFees[trade.FeeCurrency].Add(trade.Fee) + f.State.AccumulatedVolume = f.State.AccumulatedVolume.Add(trade.Quantity) + log.Infof("[FeeBudget] accumulated fee: %s %s", f.State.AccumulatedFees[trade.FeeCurrency].String(), trade.FeeCurrency) +} + +type State struct { + AccumulatedFeeStartedAt time.Time `json:"accumulatedFeeStartedAt,omitempty"` + AccumulatedFees map[string]fixedpoint.Value `json:"accumulatedFees,omitempty"` + AccumulatedVolume fixedpoint.Value `json:"accumulatedVolume,omitempty"` +} + +func (s *State) IsOver24Hours() bool { + return time.Since(s.AccumulatedFeeStartedAt) >= 24*time.Hour +} + +func (s *State) Reset() { + t := time.Now() + dateTime := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) + + log.Infof("[State] resetting accumulated started time to: %s", dateTime) + + s.AccumulatedFeeStartedAt = dateTime + s.AccumulatedFees = make(map[string]fixedpoint.Value) + s.AccumulatedVolume = fixedpoint.Zero +} diff --git a/pkg/strategy/common/fee_budget_test.go b/pkg/strategy/common/fee_budget_test.go new file mode 100644 index 000000000..7ceb972b4 --- /dev/null +++ b/pkg/strategy/common/fee_budget_test.go @@ -0,0 +1,52 @@ +package common + +import ( + "testing" + + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" + "github.com/stretchr/testify/assert" +) + +func TestFeeBudget(t *testing.T) { + cases := []struct { + budgets map[string]fixedpoint.Value + trades []types.Trade + expected bool + }{ + { + budgets: map[string]fixedpoint.Value{ + "MAX": fixedpoint.NewFromFloat(0.5), + }, + trades: []types.Trade{ + {FeeCurrency: "MAX", Fee: fixedpoint.NewFromFloat(0.1)}, + {FeeCurrency: "USDT", Fee: fixedpoint.NewFromFloat(10.0)}, + }, + expected: true, + }, + { + budgets: map[string]fixedpoint.Value{ + "MAX": fixedpoint.NewFromFloat(0.5), + }, + trades: []types.Trade{ + {FeeCurrency: "MAX", Fee: fixedpoint.NewFromFloat(0.1)}, + {FeeCurrency: "MAX", Fee: fixedpoint.NewFromFloat(0.5)}, + {FeeCurrency: "USDT", Fee: fixedpoint.NewFromFloat(10.0)}, + }, + expected: false, + }, + } + + for _, c := range cases { + feeBudget := FeeBudget{ + DailyFeeBudgets: c.budgets, + } + feeBudget.Initialize() + + for _, trade := range c.trades { + feeBudget.HandleTradeUpdate(trade) + } + + assert.Equal(t, c.expected, feeBudget.IsBudgetAllowed()) + } +} diff --git a/pkg/strategy/xgap/strategy.go b/pkg/strategy/xgap/strategy.go index 43caf97a5..2470bf71c 100644 --- a/pkg/strategy/xgap/strategy.go +++ b/pkg/strategy/xgap/strategy.go @@ -37,29 +37,9 @@ func (s *Strategy) InstanceID() string { return fmt.Sprintf("%s:%s", ID, s.Symbol) } -type State struct { - AccumulatedFeeStartedAt time.Time `json:"accumulatedFeeStartedAt,omitempty"` - AccumulatedFees map[string]fixedpoint.Value `json:"accumulatedFees,omitempty"` - AccumulatedVolume fixedpoint.Value `json:"accumulatedVolume,omitempty"` -} - -func (s *State) IsOver24Hours() bool { - return time.Since(s.AccumulatedFeeStartedAt) >= 24*time.Hour -} - -func (s *State) Reset() { - t := time.Now() - dateTime := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) - - log.Infof("resetting accumulated started time to: %s", dateTime) - - s.AccumulatedFeeStartedAt = dateTime - s.AccumulatedFees = make(map[string]fixedpoint.Value) - s.AccumulatedVolume = fixedpoint.Zero -} - type Strategy struct { *common.Strategy + *common.FeeBudget Environment *bbgo.Environment @@ -70,18 +50,15 @@ type Strategy struct { Quantity fixedpoint.Value `json:"quantity"` DryRun bool `json:"dryRun"` - DailyFeeBudgets map[string]fixedpoint.Value `json:"dailyFeeBudgets,omitempty"` - DailyMaxVolume fixedpoint.Value `json:"dailyMaxVolume,omitempty"` - DailyTargetVolume fixedpoint.Value `json:"dailyTargetVolume,omitempty"` - UpdateInterval types.Duration `json:"updateInterval"` - SimulateVolume bool `json:"simulateVolume"` - SimulatePrice bool `json:"simulatePrice"` + DailyMaxVolume fixedpoint.Value `json:"dailyMaxVolume,omitempty"` + DailyTargetVolume fixedpoint.Value `json:"dailyTargetVolume,omitempty"` + UpdateInterval types.Duration `json:"updateInterval"` + SimulateVolume bool `json:"simulateVolume"` + SimulatePrice bool `json:"simulatePrice"` sourceSession, tradingSession *bbgo.ExchangeSession sourceMarket, tradingMarket types.Market - State *State `persistence:"state"` - mu sync.Mutex lastSourceKLine, lastTradingKLine types.KLine sourceBook, tradingBook *types.StreamOrderBook @@ -93,6 +70,10 @@ func (s *Strategy) Initialize() error { if s.Strategy == nil { s.Strategy = &common.Strategy{} } + + if s.FeeBudget == nil { + s.FeeBudget = &common.FeeBudget{} + } return nil } @@ -107,48 +88,6 @@ func (s *Strategy) Defaults() error { return nil } -func (s *Strategy) isBudgetAllowed() bool { - if s.DailyFeeBudgets == nil { - return true - } - - if s.State.AccumulatedFees == nil { - return true - } - - for asset, budget := range s.DailyFeeBudgets { - if fee, ok := s.State.AccumulatedFees[asset]; ok { - if fee.Compare(budget) >= 0 { - log.Warnf("accumulative fee %s exceeded the fee budget %s, skipping...", fee.String(), budget.String()) - return false - } - } - } - - return true -} - -func (s *Strategy) handleTradeUpdate(trade types.Trade) { - log.Infof("received trade %s", trade.String()) - - if trade.Symbol != s.Symbol { - return - } - - if s.State.IsOver24Hours() { - s.State.Reset() - } - - // safe check - if s.State.AccumulatedFees == nil { - s.State.AccumulatedFees = make(map[string]fixedpoint.Value) - } - - s.State.AccumulatedFees[trade.FeeCurrency] = s.State.AccumulatedFees[trade.FeeCurrency].Add(trade.Fee) - s.State.AccumulatedVolume = s.State.AccumulatedVolume.Add(trade.Quantity) - log.Infof("accumulated fee: %s %s", s.State.AccumulatedFees[trade.FeeCurrency].String(), trade.FeeCurrency) -} - func (s *Strategy) CrossSubscribe(sessions map[string]*bbgo.ExchangeSession) { sourceSession, ok := sessions[s.SourceExchange] if !ok { @@ -191,19 +130,10 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se } s.Strategy.Initialize(ctx, s.Environment, tradingSession, s.tradingMarket, ID, s.InstanceID()) + s.FeeBudget.Initialize() s.stopC = make(chan struct{}) - if s.State == nil { - s.State = &State{} - s.State.Reset() - } - - if s.State.IsOver24Hours() { - log.Warn("state is over 24 hours, resetting to zero") - s.State.Reset() - } - bbgo.OnShutdown(ctx, func(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() close(s.stopC) @@ -230,7 +160,12 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se s.tradingBook = types.NewStreamBook(s.Symbol) s.tradingBook.BindStream(s.tradingSession.MarketDataStream) - s.tradingSession.UserDataStream.OnTradeUpdate(s.handleTradeUpdate) + s.tradingSession.UserDataStream.OnTradeUpdate(func(trade types.Trade) { + if trade.Symbol != s.Symbol { + return + } + s.FeeBudget.HandleTradeUpdate(trade) + }) go func() { ticker := time.NewTicker( @@ -247,7 +182,7 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se return case <-ticker.C: - if !s.isBudgetAllowed() { + if !s.IsBudgetAllowed() { continue }