mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-26 00:35:15 +00:00
pkg/exchange: get fee rate before connect
This commit is contained in:
parent
509f9ac8ca
commit
affff32599
|
@ -490,7 +490,7 @@ func (e *Exchange) IsSupportedInterval(interval types.Interval) bool {
|
|||
return ok
|
||||
}
|
||||
|
||||
func (e *Exchange) GetFeeRates(ctx context.Context) (bybitapi.FeeRates, error) {
|
||||
func (e *Exchange) GetAllFeeRates(ctx context.Context) (bybitapi.FeeRates, error) {
|
||||
if err := sharedRateLimiter.Wait(ctx); err != nil {
|
||||
return bybitapi.FeeRates{}, fmt.Errorf("query fee rate limiter wait error: %w", err)
|
||||
}
|
||||
|
@ -503,5 +503,5 @@ func (e *Exchange) GetFeeRates(ctx context.Context) (bybitapi.FeeRates, error) {
|
|||
}
|
||||
|
||||
func (e *Exchange) NewStream() types.Stream {
|
||||
return NewStream(e.key, e.secret)
|
||||
return NewStream(e.key, e.secret, e)
|
||||
}
|
||||
|
|
67
pkg/exchange/bybit/mocks/stream.go
Normal file
67
pkg/exchange/bybit/mocks/stream.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/c9s/bbgo/pkg/exchange/bybit (interfaces: MarketInfoProvider)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
bybitapi "github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi"
|
||||
types "github.com/c9s/bbgo/pkg/types"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockMarketInfoProvider is a mock of MarketInfoProvider interface.
|
||||
type MockMarketInfoProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMarketInfoProviderMockRecorder
|
||||
}
|
||||
|
||||
// MockMarketInfoProviderMockRecorder is the mock recorder for MockMarketInfoProvider.
|
||||
type MockMarketInfoProviderMockRecorder struct {
|
||||
mock *MockMarketInfoProvider
|
||||
}
|
||||
|
||||
// NewMockMarketInfoProvider creates a new mock instance.
|
||||
func NewMockMarketInfoProvider(ctrl *gomock.Controller) *MockMarketInfoProvider {
|
||||
mock := &MockMarketInfoProvider{ctrl: ctrl}
|
||||
mock.recorder = &MockMarketInfoProviderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockMarketInfoProvider) EXPECT() *MockMarketInfoProviderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetFeeRates mocks base method.
|
||||
func (m *MockMarketInfoProvider) GetFeeRates(arg0 context.Context) (bybitapi.FeeRates, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetFeeRates", arg0)
|
||||
ret0, _ := ret[0].(bybitapi.FeeRates)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetFeeRates indicates an expected call of GetFeeRates.
|
||||
func (mr *MockMarketInfoProviderMockRecorder) GetFeeRates(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFeeRates", reflect.TypeOf((*MockMarketInfoProvider)(nil).GetFeeRates), arg0)
|
||||
}
|
||||
|
||||
// QueryMarkets mocks base method.
|
||||
func (m *MockMarketInfoProvider) QueryMarkets(arg0 context.Context) (types.MarketMap, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "QueryMarkets", arg0)
|
||||
ret0, _ := ret[0].(types.MarketMap)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// QueryMarkets indicates an expected call of QueryMarkets.
|
||||
func (mr *MockMarketInfoProviderMockRecorder) QueryMarkets(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryMarkets", reflect.TypeOf((*MockMarketInfoProvider)(nil).QueryMarkets), arg0)
|
||||
}
|
|
@ -27,31 +27,43 @@ var (
|
|||
wsAuthRequest = 10 * time.Second
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=mocks/stream.go -package=mocks . MarketInfoProvider
|
||||
type MarketInfoProvider interface {
|
||||
GetFeeRates(ctx context.Context) (bybitapi.FeeRates, error)
|
||||
QueryMarkets(ctx context.Context) (types.MarketMap, error)
|
||||
}
|
||||
|
||||
//go:generate callbackgen -type Stream
|
||||
type Stream struct {
|
||||
key, secret string
|
||||
types.StandardStream
|
||||
|
||||
key, secret string
|
||||
marketProvider MarketInfoProvider
|
||||
// TODO: update the fee rate at 7:00 am UTC; rotation required.
|
||||
symbolFeeDetails map[string]*symbolFeeDetail
|
||||
|
||||
bookEventCallbacks []func(e BookEvent)
|
||||
walletEventCallbacks []func(e []bybitapi.WalletBalances)
|
||||
kLineEventCallbacks []func(e KLineEvent)
|
||||
orderEventCallbacks []func(e []OrderEvent)
|
||||
}
|
||||
|
||||
func NewStream(key, secret string) *Stream {
|
||||
func NewStream(key, secret string, marketProvider MarketInfoProvider) *Stream {
|
||||
stream := &Stream{
|
||||
StandardStream: types.NewStandardStream(),
|
||||
// pragma: allowlist nextline secret
|
||||
key: key,
|
||||
secret: secret,
|
||||
key: key,
|
||||
secret: secret,
|
||||
marketProvider: marketProvider,
|
||||
}
|
||||
|
||||
stream.SetEndpointCreator(stream.createEndpoint)
|
||||
stream.SetParser(stream.parseWebSocketEvent)
|
||||
stream.SetDispatcher(stream.dispatchEvent)
|
||||
stream.SetHeartBeat(stream.ping)
|
||||
|
||||
stream.SetBeforeConnect(stream.getAllFeeRates)
|
||||
stream.OnConnect(stream.handlerConnect)
|
||||
|
||||
stream.OnBookEvent(stream.handleBookEvent)
|
||||
stream.OnKLineEvent(stream.handleKLineEvent)
|
||||
stream.OnWalletEvent(stream.handleWalletEvent)
|
||||
|
@ -307,3 +319,53 @@ func (s *Stream) handleKLineEvent(klineEvent KLineEvent) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
type symbolFeeDetail struct {
|
||||
bybitapi.FeeRate
|
||||
|
||||
BaseCoin string
|
||||
QuoteCoin string
|
||||
}
|
||||
|
||||
// getAllFeeRates retrieves all fee rates from the Bybit API and then fetches markets to ensure the base coin and quote coin
|
||||
// are correct.
|
||||
func (e *Stream) getAllFeeRates(ctx context.Context) error {
|
||||
feeRates, err := e.marketProvider.GetFeeRates(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to call get fee rates: %w", err)
|
||||
}
|
||||
|
||||
symbolMap := map[string]*symbolFeeDetail{}
|
||||
for _, f := range feeRates.List {
|
||||
if _, found := symbolMap[f.Symbol]; !found {
|
||||
symbolMap[f.Symbol] = &symbolFeeDetail{FeeRate: f}
|
||||
}
|
||||
}
|
||||
|
||||
mkts, err := e.marketProvider.QueryMarkets(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get markets: %w", err)
|
||||
}
|
||||
|
||||
// update base coin, quote coin into symbolFeeDetail
|
||||
for _, mkt := range mkts {
|
||||
feeRate, found := symbolMap[mkt.Symbol]
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
|
||||
feeRate.BaseCoin = mkt.BaseCurrency
|
||||
feeRate.QuoteCoin = mkt.QuoteCurrency
|
||||
}
|
||||
|
||||
// remove trading pairs that are not present in spot market.
|
||||
for k, v := range symbolMap {
|
||||
if len(v.BaseCoin) == 0 || len(v.QuoteCoin) == 0 {
|
||||
log.Debugf("related market not found: %s, skipping the associated trade", k)
|
||||
delete(symbolMap, k)
|
||||
}
|
||||
}
|
||||
|
||||
e.symbolFeeDetails = symbolMap
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -8,8 +8,11 @@ import (
|
|||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi"
|
||||
"github.com/c9s/bbgo/pkg/exchange/bybit/mocks"
|
||||
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||||
"github.com/c9s/bbgo/pkg/testutil"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
|
@ -26,7 +29,9 @@ func getTestClientOrSkip(t *testing.T) *Stream {
|
|||
return nil
|
||||
}
|
||||
|
||||
return NewStream(key, secret)
|
||||
exchange, err := New(key, secret)
|
||||
assert.NoError(t, err)
|
||||
return NewStream(key, secret, exchange)
|
||||
}
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
|
@ -312,3 +317,120 @@ func Test_convertSubscription(t *testing.T) {
|
|||
assert.Equal(t, "", res)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStream_getFeeRate(t *testing.T) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
unknownErr := errors.New("unknown err")
|
||||
|
||||
t.Run("succeeds", func(t *testing.T) {
|
||||
mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl)
|
||||
s := &Stream{
|
||||
marketProvider: mockMarketProvider,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
feeRates := bybitapi.FeeRates{
|
||||
List: []bybitapi.FeeRate{
|
||||
{
|
||||
Symbol: "BTCUSDT",
|
||||
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
},
|
||||
{
|
||||
Symbol: "ETHUSDT",
|
||||
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
},
|
||||
{
|
||||
Symbol: "OPTIONCOIN",
|
||||
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mkts := types.MarketMap{
|
||||
"BTCUSDT": types.Market{
|
||||
Symbol: "BTCUSDT",
|
||||
QuoteCurrency: "USDT",
|
||||
BaseCurrency: "BTC",
|
||||
},
|
||||
"ETHUSDT": types.Market{
|
||||
Symbol: "ETHUSDT",
|
||||
QuoteCurrency: "USDT",
|
||||
BaseCurrency: "ETH",
|
||||
},
|
||||
}
|
||||
|
||||
mockMarketProvider.EXPECT().GetFeeRates(ctx).Return(feeRates, nil).Times(1)
|
||||
mockMarketProvider.EXPECT().QueryMarkets(ctx).Return(mkts, nil).Times(1)
|
||||
|
||||
expFeeRates := map[string]*symbolFeeDetail{
|
||||
"BTCUSDT": {
|
||||
FeeRate: feeRates.List[0],
|
||||
BaseCoin: "BTC",
|
||||
QuoteCoin: "USDT",
|
||||
},
|
||||
"ETHUSDT": {
|
||||
FeeRate: feeRates.List[1],
|
||||
BaseCoin: "ETH",
|
||||
QuoteCoin: "USDT",
|
||||
},
|
||||
}
|
||||
err := s.getFeeRate(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expFeeRates, s.symbolFeeDetails)
|
||||
})
|
||||
|
||||
t.Run("failed to query markets", func(t *testing.T) {
|
||||
mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl)
|
||||
s := &Stream{
|
||||
marketProvider: mockMarketProvider,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
feeRates := bybitapi.FeeRates{
|
||||
List: []bybitapi.FeeRate{
|
||||
{
|
||||
Symbol: "BTCUSDT",
|
||||
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
},
|
||||
{
|
||||
Symbol: "ETHUSDT",
|
||||
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
},
|
||||
{
|
||||
Symbol: "OPTIONCOIN",
|
||||
TakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
MakerFeeRate: fixedpoint.NewFromFloat(0.001),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockMarketProvider.EXPECT().GetFeeRates(ctx).Return(feeRates, nil).Times(1)
|
||||
mockMarketProvider.EXPECT().QueryMarkets(ctx).Return(nil, unknownErr).Times(1)
|
||||
|
||||
err := s.getFeeRate(ctx)
|
||||
assert.Equal(t, fmt.Errorf("failed to get markets: %w", unknownErr), err)
|
||||
assert.Equal(t, map[string]*symbolFeeDetail(nil), s.symbolFeeDetails)
|
||||
})
|
||||
|
||||
t.Run("failed to get fee rates", func(t *testing.T) {
|
||||
mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl)
|
||||
s := &Stream{
|
||||
marketProvider: mockMarketProvider,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
mockMarketProvider.EXPECT().GetFeeRates(ctx).Return(bybitapi.FeeRates{}, unknownErr).Times(1)
|
||||
|
||||
err := s.getFeeRate(ctx)
|
||||
assert.Equal(t, fmt.Errorf("failed to call get fee rates: %w", unknownErr), err)
|
||||
assert.Equal(t, map[string]*symbolFeeDetail(nil), s.symbolFeeDetails)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,17 +2,20 @@ package bybit
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi"
|
||||
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
func Test_parseWebSocketEvent(t *testing.T) {
|
||||
t.Run("[public] PingEvent without req id", func(t *testing.T) {
|
||||
s := NewStream("", "")
|
||||
s := NewStream("", "", nil)
|
||||
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","op":"ping"}`
|
||||
raw, err := s.parseWebSocketEvent([]byte(msg))
|
||||
assert.NoError(t, err)
|
||||
|
@ -33,7 +36,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("[public] PingEvent with req id", func(t *testing.T) {
|
||||
s := NewStream("", "")
|
||||
s := NewStream("", "", nil)
|
||||
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","req_id":"b26704da-f5af-44c2-bdf7-935d6739e1a0","op":"ping"}`
|
||||
raw, err := s.parseWebSocketEvent([]byte(msg))
|
||||
assert.NoError(t, err)
|
||||
|
@ -55,7 +58,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("[private] PingEvent without req id", func(t *testing.T) {
|
||||
s := NewStream("", "")
|
||||
s := NewStream("", "", nil)
|
||||
msg := `{"op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
|
||||
raw, err := s.parseWebSocketEvent([]byte(msg))
|
||||
assert.NoError(t, err)
|
||||
|
@ -75,7 +78,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("[private] PingEvent with req id", func(t *testing.T) {
|
||||
s := NewStream("", "")
|
||||
s := NewStream("", "", nil)
|
||||
msg := `{"req_id":"78d36b57-a142-47b7-9143-5843df77d44d","op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
|
||||
raw, err := s.parseWebSocketEvent([]byte(msg))
|
||||
assert.NoError(t, err)
|
||||
|
|
Loading…
Reference in New Issue
Block a user