pkg/exchange: to periodically fetch the fee rate

This commit is contained in:
Edwin 2023-11-06 22:17:29 +08:00
parent e773bb0e52
commit 82ac8f184f
4 changed files with 328 additions and 184 deletions

View File

@ -0,0 +1,137 @@
package bybit
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi"
"github.com/c9s/bbgo/pkg/util"
)
const (
// To maintain aligned fee rates, it's important to update fees frequently.
feeRatePollingPeriod = time.Minute
)
type symbolFeeDetail struct {
bybitapi.FeeRate
BaseCoin string
QuoteCoin string
}
// feeRatePoller pulls the specified market data from bbgo QueryMarkets.
type feeRatePoller struct {
mu sync.Mutex
once sync.Once
client MarketInfoProvider
symbolFeeDetail map[string]symbolFeeDetail
}
func newFeeRatePoller(marketInfoProvider MarketInfoProvider) *feeRatePoller {
return &feeRatePoller{
client: marketInfoProvider,
symbolFeeDetail: map[string]symbolFeeDetail{},
}
}
func (p *feeRatePoller) Start(ctx context.Context) {
p.once.Do(func() {
p.startLoop(ctx)
})
}
func (p *feeRatePoller) startLoop(ctx context.Context) {
ticker := time.NewTicker(feeRatePollingPeriod)
defer ticker.Stop()
// Make sure the first poll should succeed by retrying with a shorter period.
_ = util.Retry(ctx, util.InfiniteRetry, 30*time.Second,
func() error { return p.poll(ctx) },
func(e error) { log.WithError(e).Warn("failed to update fee rate") })
for {
select {
case <-ctx.Done():
if err := ctx.Err(); !errors.Is(err, context.Canceled) {
log.WithError(err).Error("context done with error")
}
return
case <-ticker.C:
if err := p.poll(ctx); err != nil {
log.WithError(err).Warn("failed to update fee rate")
}
}
}
}
func (p *feeRatePoller) poll(ctx context.Context) error {
symbolFeeRate, err := p.getAllFeeRates(ctx)
if err != nil {
return err
}
p.mu.Lock()
p.symbolFeeDetail = symbolFeeRate
p.mu.Unlock()
return nil
}
func (p *feeRatePoller) Get(symbol string) (symbolFeeDetail, error) {
p.mu.Lock()
defer p.mu.Unlock()
fee, ok := p.symbolFeeDetail[symbol]
if !ok {
return symbolFeeDetail{}, fmt.Errorf("%s fee rate not found", symbol)
}
return fee, nil
}
func (e *feeRatePoller) getAllFeeRates(ctx context.Context) (map[string]symbolFeeDetail, error) {
feeRates, err := e.client.GetAllFeeRates(ctx)
if err != nil {
return nil, fmt.Errorf("failed to call get fee rates: %w", err)
}
symbolMap := map[string]symbolFeeDetail{}
for _, f := range feeRates.List {
if _, found := symbolMap[f.Symbol]; !found {
symbolMap[f.Symbol] = symbolFeeDetail{FeeRate: f}
}
}
mkts, err := e.client.QueryMarkets(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get markets: %w", err)
}
// update base coin, quote coin into symbolFeeDetail
for _, mkt := range mkts {
feeRate, found := symbolMap[mkt.Symbol]
if !found {
continue
}
feeRate.BaseCoin = mkt.BaseCurrency
feeRate.QuoteCoin = mkt.QuoteCurrency
symbolMap[mkt.Symbol] = feeRate
}
// remove trading pairs that are not present in spot market.
for k, v := range symbolMap {
if len(v.BaseCoin) == 0 || len(v.QuoteCoin) == 0 {
log.Debugf("related market not found: %s, skipping the associated trade", k)
delete(symbolMap, k)
}
}
return symbolMap, nil
}

View File

@ -0,0 +1,173 @@
package bybit
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi"
"github.com/c9s/bbgo/pkg/exchange/bybit/mocks"
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
)
func TestFeeRatePoller_getAllFeeRates(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
unknownErr := errors.New("unknown err")
t.Run("succeeds", func(t *testing.T) {
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
s := &feeRatePoller{
client: mockMarketProvider,
}
ctx := context.Background()
feeRates := bybitapi.FeeRates{
List: []bybitapi.FeeRate{
{
Symbol: "BTCUSDT",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
{
Symbol: "ETHUSDT",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
{
Symbol: "OPTIONCOIN",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
},
}
mkts := types.MarketMap{
"BTCUSDT": types.Market{
Symbol: "BTCUSDT",
QuoteCurrency: "USDT",
BaseCurrency: "BTC",
},
"ETHUSDT": types.Market{
Symbol: "ETHUSDT",
QuoteCurrency: "USDT",
BaseCurrency: "ETH",
},
}
mockMarketProvider.EXPECT().GetAllFeeRates(ctx).Return(feeRates, nil).Times(1)
mockMarketProvider.EXPECT().QueryMarkets(ctx).Return(mkts, nil).Times(1)
expFeeRates := map[string]symbolFeeDetail{
"BTCUSDT": {
FeeRate: feeRates.List[0],
BaseCoin: "BTC",
QuoteCoin: "USDT",
},
"ETHUSDT": {
FeeRate: feeRates.List[1],
BaseCoin: "ETH",
QuoteCoin: "USDT",
},
}
symbolFeeDetails, err := s.getAllFeeRates(ctx)
assert.NoError(t, err)
assert.Equal(t, expFeeRates, symbolFeeDetails)
})
t.Run("failed to query markets", func(t *testing.T) {
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
s := &feeRatePoller{
client: mockMarketProvider,
}
ctx := context.Background()
feeRates := bybitapi.FeeRates{
List: []bybitapi.FeeRate{
{
Symbol: "BTCUSDT",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
{
Symbol: "ETHUSDT",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
{
Symbol: "OPTIONCOIN",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
},
}
mockMarketProvider.EXPECT().GetAllFeeRates(ctx).Return(feeRates, nil).Times(1)
mockMarketProvider.EXPECT().QueryMarkets(ctx).Return(nil, unknownErr).Times(1)
symbolFeeDetails, err := s.getAllFeeRates(ctx)
assert.Equal(t, fmt.Errorf("failed to get markets: %w", unknownErr), err)
assert.Equal(t, map[string]symbolFeeDetail(nil), symbolFeeDetails)
})
t.Run("failed to get fee rates", func(t *testing.T) {
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
s := &feeRatePoller{
client: mockMarketProvider,
}
ctx := context.Background()
mockMarketProvider.EXPECT().GetAllFeeRates(ctx).Return(bybitapi.FeeRates{}, unknownErr).Times(1)
symbolFeeDetails, err := s.getAllFeeRates(ctx)
assert.Equal(t, fmt.Errorf("failed to call get fee rates: %w", unknownErr), err)
assert.Equal(t, map[string]symbolFeeDetail(nil), symbolFeeDetails)
})
}
func Test_feeRatePoller_Get(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
t.Run("succeeds", func(t *testing.T) {
symbol := "BTCUSDT"
expFeeDetail := symbolFeeDetail{
FeeRate: bybitapi.FeeRate{
Symbol: symbol,
TakerFeeRate: fixedpoint.NewFromFloat(0.1),
MakerFeeRate: fixedpoint.NewFromFloat(0.2),
},
BaseCoin: "BTC",
QuoteCoin: "USDT",
}
s := &feeRatePoller{
client: mockMarketProvider,
symbolFeeDetail: map[string]symbolFeeDetail{
symbol: expFeeDetail,
},
}
res, err := s.Get(symbol)
assert.NoError(t, err)
assert.Equal(t, expFeeDetail, res)
})
t.Run("succeeds", func(t *testing.T) {
symbol := "BTCUSDT"
s := &feeRatePoller{
client: mockMarketProvider,
symbolFeeDetail: map[string]symbolFeeDetail{},
}
_, err := s.Get(symbol)
assert.ErrorContains(t, err, symbol)
})
}

View File

@ -47,8 +47,7 @@ type Stream struct {
key, secret string
streamDataProvider StreamDataProvider
// TODO: update the fee rate at 7:00 am UTC; rotation required.
symbolFeeDetails map[string]*symbolFeeDetail
feeRateProvider *feeRatePoller
bookEventCallbacks []func(e BookEvent)
marketTradeEventCallbacks []func(e []MarketTradeEvent)
@ -65,13 +64,17 @@ func NewStream(key, secret string, userDataProvider StreamDataProvider) *Stream
key: key,
secret: secret,
streamDataProvider: userDataProvider,
feeRateProvider: newFeeRatePoller(userDataProvider),
}
stream.SetEndpointCreator(stream.createEndpoint)
stream.SetParser(stream.parseWebSocketEvent)
stream.SetDispatcher(stream.dispatchEvent)
stream.SetHeartBeat(stream.ping)
stream.SetBeforeConnect(stream.getAllFeeRates)
stream.SetBeforeConnect(func(ctx context.Context) error {
go stream.feeRateProvider.Start(ctx)
return nil
})
stream.OnConnect(stream.handlerConnect)
stream.OnAuth(stream.handleAuthEvent)
@ -403,13 +406,13 @@ func (s *Stream) handleKLineEvent(klineEvent KLineEvent) {
func (s *Stream) handleTradeEvent(events []TradeEvent) {
for _, event := range events {
feeRate, found := s.symbolFeeDetails[event.Symbol]
if !found {
log.Warnf("unexpected symbol found, fee rate not supported, symbol: %s", event.Symbol)
feeRate, err := s.feeRateProvider.Get(event.Symbol)
if err != nil {
log.Warnf("failed to get fee rate by symbol: %s", event.Symbol)
continue
}
gTrade, err := event.toGlobalTrade(*feeRate)
gTrade, err := event.toGlobalTrade(feeRate)
if err != nil {
log.WithError(err).Errorf("unable to convert: %+v", event)
continue
@ -417,53 +420,3 @@ func (s *Stream) handleTradeEvent(events []TradeEvent) {
s.StandardStream.EmitTradeUpdate(*gTrade)
}
}
type symbolFeeDetail struct {
bybitapi.FeeRate
BaseCoin string
QuoteCoin string
}
// getAllFeeRates retrieves all fee rates from the Bybit API and then fetches markets to ensure the base coin and quote coin
// are correct.
func (e *Stream) getAllFeeRates(ctx context.Context) error {
feeRates, err := e.streamDataProvider.GetAllFeeRates(ctx)
if err != nil {
return fmt.Errorf("failed to call get fee rates: %w", err)
}
symbolMap := map[string]*symbolFeeDetail{}
for _, f := range feeRates.List {
if _, found := symbolMap[f.Symbol]; !found {
symbolMap[f.Symbol] = &symbolFeeDetail{FeeRate: f}
}
}
mkts, err := e.streamDataProvider.QueryMarkets(ctx)
if err != nil {
return fmt.Errorf("failed to get markets: %w", err)
}
// update base coin, quote coin into symbolFeeDetail
for _, mkt := range mkts {
feeRate, found := symbolMap[mkt.Symbol]
if !found {
continue
}
feeRate.BaseCoin = mkt.BaseCurrency
feeRate.QuoteCoin = mkt.QuoteCurrency
}
// remove trading pairs that are not present in spot market.
for k, v := range symbolMap {
if len(v.BaseCoin) == 0 || len(v.QuoteCoin) == 0 {
log.Debugf("related market not found: %s, skipping the associated trade", k)
delete(symbolMap, k)
}
}
e.symbolFeeDetails = symbolMap
return nil
}

View File

@ -9,11 +9,9 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi"
"github.com/c9s/bbgo/pkg/exchange/bybit/mocks"
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/testutil"
"github.com/c9s/bbgo/pkg/types"
@ -36,7 +34,7 @@ func getTestClientOrSkip(t *testing.T) *Stream {
}
func TestStream(t *testing.T) {
t.Skip()
//t.Skip()
s := getTestClientOrSkip(t)
symbols := []string{
@ -70,12 +68,12 @@ func TestStream(t *testing.T) {
err := s.Connect(context.Background())
assert.NoError(t, err)
s.OnBookSnapshot(func(book types.SliceOrderBook) {
t.Log("got snapshot", book)
})
s.OnBookUpdate(func(book types.SliceOrderBook) {
t.Log("got update", book)
})
//s.OnBookSnapshot(func(book types.SliceOrderBook) {
// t.Log("got snapshot", book)
//})
//s.OnBookUpdate(func(book types.SliceOrderBook) {
// t.Log("got update", book)
//})
c := make(chan struct{})
<-c
})
@ -175,7 +173,7 @@ func TestStream(t *testing.T) {
assert.NoError(t, err)
s.OnTradeUpdate(func(trade types.Trade) {
t.Log("got update", trade)
t.Log("got update", trade.Fee, trade.FeeCurrency, trade)
})
c := make(chan struct{})
<-c
@ -467,120 +465,3 @@ func Test_convertSubscription(t *testing.T) {
assert.Equal(t, genTopic(TopicTypeMarketTrade, "BTCUSDT"), res)
})
}
func TestStream_getFeeRate(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
unknownErr := errors.New("unknown err")
t.Run("succeeds", func(t *testing.T) {
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
s := &Stream{
streamDataProvider: mockMarketProvider,
}
ctx := context.Background()
feeRates := bybitapi.FeeRates{
List: []bybitapi.FeeRate{
{
Symbol: "BTCUSDT",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
{
Symbol: "ETHUSDT",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
{
Symbol: "OPTIONCOIN",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
},
}
mkts := types.MarketMap{
"BTCUSDT": types.Market{
Symbol: "BTCUSDT",
QuoteCurrency: "USDT",
BaseCurrency: "BTC",
},
"ETHUSDT": types.Market{
Symbol: "ETHUSDT",
QuoteCurrency: "USDT",
BaseCurrency: "ETH",
},
}
mockMarketProvider.EXPECT().GetAllFeeRates(ctx).Return(feeRates, nil).Times(1)
mockMarketProvider.EXPECT().QueryMarkets(ctx).Return(mkts, nil).Times(1)
expFeeRates := map[string]*symbolFeeDetail{
"BTCUSDT": {
FeeRate: feeRates.List[0],
BaseCoin: "BTC",
QuoteCoin: "USDT",
},
"ETHUSDT": {
FeeRate: feeRates.List[1],
BaseCoin: "ETH",
QuoteCoin: "USDT",
},
}
err := s.getAllFeeRates(ctx)
assert.NoError(t, err)
assert.Equal(t, expFeeRates, s.symbolFeeDetails)
})
t.Run("failed to query markets", func(t *testing.T) {
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
s := &Stream{
streamDataProvider: mockMarketProvider,
}
ctx := context.Background()
feeRates := bybitapi.FeeRates{
List: []bybitapi.FeeRate{
{
Symbol: "BTCUSDT",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
{
Symbol: "ETHUSDT",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
{
Symbol: "OPTIONCOIN",
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
},
},
}
mockMarketProvider.EXPECT().GetAllFeeRates(ctx).Return(feeRates, nil).Times(1)
mockMarketProvider.EXPECT().QueryMarkets(ctx).Return(nil, unknownErr).Times(1)
err := s.getAllFeeRates(ctx)
assert.Equal(t, fmt.Errorf("failed to get markets: %w", unknownErr), err)
assert.Equal(t, map[string]*symbolFeeDetail(nil), s.symbolFeeDetails)
})
t.Run("failed to get fee rates", func(t *testing.T) {
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
s := &Stream{
streamDataProvider: mockMarketProvider,
}
ctx := context.Background()
mockMarketProvider.EXPECT().GetAllFeeRates(ctx).Return(bybitapi.FeeRates{}, unknownErr).Times(1)
err := s.getAllFeeRates(ctx)
assert.Equal(t, fmt.Errorf("failed to call get fee rates: %w", unknownErr), err)
assert.Equal(t, map[string]*symbolFeeDetail(nil), s.symbolFeeDetails)
})
}