mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 09:11:55 +00:00
Merge pull request #1644 from c9s/narumi/fee-budget
REFACTOR: Extract and move FeeBudget from xgap
This commit is contained in:
commit
396ee68170
92
pkg/strategy/common/fee_budget.go
Normal file
92
pkg/strategy/common/fee_budget.go
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||||||
|
"github.com/c9s/bbgo/pkg/types"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FeeBudget struct {
|
||||||
|
DailyFeeBudgets map[string]fixedpoint.Value `json:"dailyFeeBudgets,omitempty"`
|
||||||
|
State *State `persistence:"state"`
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FeeBudget) Initialize() {
|
||||||
|
if f.State == nil {
|
||||||
|
f.State = &State{}
|
||||||
|
f.State.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.State.IsOver24Hours() {
|
||||||
|
log.Warn("[FeeBudget] state is over 24 hours, resetting to zero")
|
||||||
|
f.State.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FeeBudget) IsBudgetAllowed() bool {
|
||||||
|
if f.DailyFeeBudgets == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.State.AccumulatedFees == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.State.IsOver24Hours() {
|
||||||
|
f.State.Reset()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for asset, budget := range f.DailyFeeBudgets {
|
||||||
|
if fee, ok := f.State.AccumulatedFees[asset]; ok {
|
||||||
|
if fee.Compare(budget) >= 0 {
|
||||||
|
log.Warnf("[FeeBudget] accumulative fee %s exceeded the fee budget %s, skipping...", fee.String(), budget.String())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FeeBudget) HandleTradeUpdate(trade types.Trade) {
|
||||||
|
log.Infof("[FeeBudget] received trade %s", trade.String())
|
||||||
|
|
||||||
|
if f.State.IsOver24Hours() {
|
||||||
|
f.State.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
// safe check
|
||||||
|
if f.State.AccumulatedFees == nil {
|
||||||
|
f.mu.Lock()
|
||||||
|
f.State.AccumulatedFees = make(map[string]fixedpoint.Value)
|
||||||
|
f.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.State.AccumulatedFees[trade.FeeCurrency] = f.State.AccumulatedFees[trade.FeeCurrency].Add(trade.Fee)
|
||||||
|
log.Infof("[FeeBudget] accumulated fee: %s %s", f.State.AccumulatedFees[trade.FeeCurrency].String(), trade.FeeCurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
type State struct {
|
||||||
|
AccumulatedFeeStartedAt time.Time `json:"accumulatedFeeStartedAt,omitempty"`
|
||||||
|
AccumulatedFees map[string]fixedpoint.Value `json:"accumulatedFees,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) IsOver24Hours() bool {
|
||||||
|
return time.Since(s.AccumulatedFeeStartedAt) >= 24*time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) Reset() {
|
||||||
|
t := time.Now()
|
||||||
|
dateTime := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||||
|
|
||||||
|
log.Infof("[State] resetting accumulated started time to: %s", dateTime)
|
||||||
|
|
||||||
|
s.AccumulatedFeeStartedAt = dateTime
|
||||||
|
s.AccumulatedFees = make(map[string]fixedpoint.Value)
|
||||||
|
}
|
56
pkg/strategy/common/fee_budget_test.go
Normal file
56
pkg/strategy/common/fee_budget_test.go
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||||||
|
"github.com/c9s/bbgo/pkg/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFeeBudget(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
budgets map[string]fixedpoint.Value
|
||||||
|
trades []types.Trade
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
budgets: map[string]fixedpoint.Value{
|
||||||
|
"MAX": fixedpoint.NewFromFloat(0.5),
|
||||||
|
},
|
||||||
|
trades: []types.Trade{
|
||||||
|
{FeeCurrency: "MAX", Fee: fixedpoint.NewFromFloat(0.1)},
|
||||||
|
{FeeCurrency: "USDT", Fee: fixedpoint.NewFromFloat(10.0)},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
budgets: map[string]fixedpoint.Value{
|
||||||
|
"MAX": fixedpoint.NewFromFloat(0.5),
|
||||||
|
},
|
||||||
|
trades: []types.Trade{
|
||||||
|
{FeeCurrency: "MAX", Fee: fixedpoint.NewFromFloat(0.1)},
|
||||||
|
{FeeCurrency: "MAX", Fee: fixedpoint.NewFromFloat(0.5)},
|
||||||
|
{FeeCurrency: "USDT", Fee: fixedpoint.NewFromFloat(10.0)},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
feeBudget := FeeBudget{
|
||||||
|
DailyFeeBudgets: c.budgets,
|
||||||
|
}
|
||||||
|
feeBudget.Initialize()
|
||||||
|
|
||||||
|
for _, trade := range c.trades {
|
||||||
|
feeBudget.HandleTradeUpdate(trade)
|
||||||
|
}
|
||||||
|
assert.Equal(t, c.expected, feeBudget.IsBudgetAllowed())
|
||||||
|
|
||||||
|
// test reset
|
||||||
|
feeBudget.State.AccumulatedFeeStartedAt = feeBudget.State.AccumulatedFeeStartedAt.Add(-24 * time.Hour)
|
||||||
|
assert.True(t, feeBudget.IsBudgetAllowed())
|
||||||
|
}
|
||||||
|
}
|
|
@ -24,6 +24,7 @@ func init() {
|
||||||
|
|
||||||
type Strategy struct {
|
type Strategy struct {
|
||||||
*common.Strategy
|
*common.Strategy
|
||||||
|
*common.FeeBudget
|
||||||
|
|
||||||
Environment *bbgo.Environment
|
Environment *bbgo.Environment
|
||||||
Market types.Market
|
Market types.Market
|
||||||
|
@ -45,6 +46,10 @@ func (s *Strategy) Initialize() error {
|
||||||
if s.Strategy == nil {
|
if s.Strategy == nil {
|
||||||
s.Strategy = &common.Strategy{}
|
s.Strategy = &common.Strategy{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.FeeBudget == nil {
|
||||||
|
s.FeeBudget = &common.FeeBudget{}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,11 +76,25 @@ func (s *Strategy) Subscribe(session *bbgo.ExchangeSession) {}
|
||||||
|
|
||||||
func (s *Strategy) Run(ctx context.Context, _ bbgo.OrderExecutor, session *bbgo.ExchangeSession) error {
|
func (s *Strategy) Run(ctx context.Context, _ bbgo.OrderExecutor, session *bbgo.ExchangeSession) error {
|
||||||
s.Strategy.Initialize(ctx, s.Environment, session, s.Market, s.ID(), s.InstanceID())
|
s.Strategy.Initialize(ctx, s.Environment, session, s.Market, s.ID(), s.InstanceID())
|
||||||
|
s.FeeBudget.Initialize()
|
||||||
|
|
||||||
session.UserDataStream.OnStart(func() {
|
session.UserDataStream.OnStart(func() {
|
||||||
if s.OnStart {
|
if !s.OnStart {
|
||||||
s.placeOrder()
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !s.FeeBudget.IsBudgetAllowed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.placeOrder(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
session.UserDataStream.OnTradeUpdate(func(trade types.Trade) {
|
||||||
|
if trade.Symbol != s.Symbol {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.FeeBudget.HandleTradeUpdate(trade)
|
||||||
})
|
})
|
||||||
|
|
||||||
// the shutdown handler, you can cancel all orders
|
// the shutdown handler, you can cancel all orders
|
||||||
|
@ -86,15 +105,19 @@ func (s *Strategy) Run(ctx context.Context, _ bbgo.OrderExecutor, session *bbgo.
|
||||||
})
|
})
|
||||||
|
|
||||||
s.cron = cron.New()
|
s.cron = cron.New()
|
||||||
s.cron.AddFunc(s.Schedule, s.placeOrder)
|
s.cron.AddFunc(s.Schedule, func() {
|
||||||
|
if !s.FeeBudget.IsBudgetAllowed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.placeOrder(ctx)
|
||||||
|
})
|
||||||
s.cron.Start()
|
s.cron.Start()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Strategy) placeOrder() {
|
func (s *Strategy) placeOrder(ctx context.Context) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
baseBalance, ok := s.Session.GetAccount().Balance(s.Market.BaseCurrency)
|
baseBalance, ok := s.Session.GetAccount().Balance(s.Market.BaseCurrency)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("base balance not found")
|
log.Errorf("base balance not found")
|
||||||
|
|
|
@ -37,29 +37,9 @@ func (s *Strategy) InstanceID() string {
|
||||||
return fmt.Sprintf("%s:%s", ID, s.Symbol)
|
return fmt.Sprintf("%s:%s", ID, s.Symbol)
|
||||||
}
|
}
|
||||||
|
|
||||||
type State struct {
|
|
||||||
AccumulatedFeeStartedAt time.Time `json:"accumulatedFeeStartedAt,omitempty"`
|
|
||||||
AccumulatedFees map[string]fixedpoint.Value `json:"accumulatedFees,omitempty"`
|
|
||||||
AccumulatedVolume fixedpoint.Value `json:"accumulatedVolume,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) IsOver24Hours() bool {
|
|
||||||
return time.Since(s.AccumulatedFeeStartedAt) >= 24*time.Hour
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) Reset() {
|
|
||||||
t := time.Now()
|
|
||||||
dateTime := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
|
||||||
|
|
||||||
log.Infof("resetting accumulated started time to: %s", dateTime)
|
|
||||||
|
|
||||||
s.AccumulatedFeeStartedAt = dateTime
|
|
||||||
s.AccumulatedFees = make(map[string]fixedpoint.Value)
|
|
||||||
s.AccumulatedVolume = fixedpoint.Zero
|
|
||||||
}
|
|
||||||
|
|
||||||
type Strategy struct {
|
type Strategy struct {
|
||||||
*common.Strategy
|
*common.Strategy
|
||||||
|
*common.FeeBudget
|
||||||
|
|
||||||
Environment *bbgo.Environment
|
Environment *bbgo.Environment
|
||||||
|
|
||||||
|
@ -70,7 +50,6 @@ type Strategy struct {
|
||||||
Quantity fixedpoint.Value `json:"quantity"`
|
Quantity fixedpoint.Value `json:"quantity"`
|
||||||
DryRun bool `json:"dryRun"`
|
DryRun bool `json:"dryRun"`
|
||||||
|
|
||||||
DailyFeeBudgets map[string]fixedpoint.Value `json:"dailyFeeBudgets,omitempty"`
|
|
||||||
DailyMaxVolume fixedpoint.Value `json:"dailyMaxVolume,omitempty"`
|
DailyMaxVolume fixedpoint.Value `json:"dailyMaxVolume,omitempty"`
|
||||||
DailyTargetVolume fixedpoint.Value `json:"dailyTargetVolume,omitempty"`
|
DailyTargetVolume fixedpoint.Value `json:"dailyTargetVolume,omitempty"`
|
||||||
UpdateInterval types.Duration `json:"updateInterval"`
|
UpdateInterval types.Duration `json:"updateInterval"`
|
||||||
|
@ -80,8 +59,6 @@ type Strategy struct {
|
||||||
sourceSession, tradingSession *bbgo.ExchangeSession
|
sourceSession, tradingSession *bbgo.ExchangeSession
|
||||||
sourceMarket, tradingMarket types.Market
|
sourceMarket, tradingMarket types.Market
|
||||||
|
|
||||||
State *State `persistence:"state"`
|
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
lastSourceKLine, lastTradingKLine types.KLine
|
lastSourceKLine, lastTradingKLine types.KLine
|
||||||
sourceBook, tradingBook *types.StreamOrderBook
|
sourceBook, tradingBook *types.StreamOrderBook
|
||||||
|
@ -93,6 +70,10 @@ func (s *Strategy) Initialize() error {
|
||||||
if s.Strategy == nil {
|
if s.Strategy == nil {
|
||||||
s.Strategy = &common.Strategy{}
|
s.Strategy = &common.Strategy{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.FeeBudget == nil {
|
||||||
|
s.FeeBudget = &common.FeeBudget{}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,48 +88,6 @@ func (s *Strategy) Defaults() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Strategy) isBudgetAllowed() bool {
|
|
||||||
if s.DailyFeeBudgets == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.State.AccumulatedFees == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
for asset, budget := range s.DailyFeeBudgets {
|
|
||||||
if fee, ok := s.State.AccumulatedFees[asset]; ok {
|
|
||||||
if fee.Compare(budget) >= 0 {
|
|
||||||
log.Warnf("accumulative fee %s exceeded the fee budget %s, skipping...", fee.String(), budget.String())
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Strategy) handleTradeUpdate(trade types.Trade) {
|
|
||||||
log.Infof("received trade %s", trade.String())
|
|
||||||
|
|
||||||
if trade.Symbol != s.Symbol {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.State.IsOver24Hours() {
|
|
||||||
s.State.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
// safe check
|
|
||||||
if s.State.AccumulatedFees == nil {
|
|
||||||
s.State.AccumulatedFees = make(map[string]fixedpoint.Value)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.State.AccumulatedFees[trade.FeeCurrency] = s.State.AccumulatedFees[trade.FeeCurrency].Add(trade.Fee)
|
|
||||||
s.State.AccumulatedVolume = s.State.AccumulatedVolume.Add(trade.Quantity)
|
|
||||||
log.Infof("accumulated fee: %s %s", s.State.AccumulatedFees[trade.FeeCurrency].String(), trade.FeeCurrency)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Strategy) CrossSubscribe(sessions map[string]*bbgo.ExchangeSession) {
|
func (s *Strategy) CrossSubscribe(sessions map[string]*bbgo.ExchangeSession) {
|
||||||
sourceSession, ok := sessions[s.SourceExchange]
|
sourceSession, ok := sessions[s.SourceExchange]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -191,19 +130,10 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Strategy.Initialize(ctx, s.Environment, tradingSession, s.tradingMarket, ID, s.InstanceID())
|
s.Strategy.Initialize(ctx, s.Environment, tradingSession, s.tradingMarket, ID, s.InstanceID())
|
||||||
|
s.FeeBudget.Initialize()
|
||||||
|
|
||||||
s.stopC = make(chan struct{})
|
s.stopC = make(chan struct{})
|
||||||
|
|
||||||
if s.State == nil {
|
|
||||||
s.State = &State{}
|
|
||||||
s.State.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.State.IsOver24Hours() {
|
|
||||||
log.Warn("state is over 24 hours, resetting to zero")
|
|
||||||
s.State.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
bbgo.OnShutdown(ctx, func(ctx context.Context, wg *sync.WaitGroup) {
|
bbgo.OnShutdown(ctx, func(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
close(s.stopC)
|
close(s.stopC)
|
||||||
|
@ -230,7 +160,12 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se
|
||||||
s.tradingBook = types.NewStreamBook(s.Symbol)
|
s.tradingBook = types.NewStreamBook(s.Symbol)
|
||||||
s.tradingBook.BindStream(s.tradingSession.MarketDataStream)
|
s.tradingBook.BindStream(s.tradingSession.MarketDataStream)
|
||||||
|
|
||||||
s.tradingSession.UserDataStream.OnTradeUpdate(s.handleTradeUpdate)
|
s.tradingSession.UserDataStream.OnTradeUpdate(func(trade types.Trade) {
|
||||||
|
if trade.Symbol != s.Symbol {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.FeeBudget.HandleTradeUpdate(trade)
|
||||||
|
})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(
|
ticker := time.NewTicker(
|
||||||
|
@ -247,7 +182,7 @@ func (s *Strategy) CrossRun(ctx context.Context, _ bbgo.OrderExecutionRouter, se
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if !s.isBudgetAllowed() {
|
if !s.IsBudgetAllowed() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user