pkg/exchange: move fee rate calculate method outside

This commit is contained in:
edwin 2024-09-30 21:27:05 +08:00
parent f6e58ded02
commit d2c1ae0642
3 changed files with 48 additions and 36 deletions

View File

@ -59,7 +59,7 @@ type Stream struct {
key, secret string
streamDataProvider StreamDataProvider
feeRateProvider *feeRatePoller
feeRateProvider FeeRatePoller
marketsInfo types.MarketMap
bookEventCallbacks []func(e BookEvent)
@ -70,14 +70,14 @@ type Stream struct {
tradeEventCallbacks []func(e []TradeEvent)
}
func NewStream(key, secret string, userDataProvider StreamDataProvider) *Stream {
func NewStream(key, secret string, userDataProvider StreamDataProvider, poller FeeRatePoller) *Stream {
stream := &Stream{
StandardStream: types.NewStandardStream(),
// pragma: allowlist nextline secret
key: key,
secret: secret,
streamDataProvider: userDataProvider,
feeRateProvider: newFeeRatePoller(userDataProvider),
feeRateProvider: poller,
}
stream.SetEndpointCreator(stream.createEndpoint)
@ -439,37 +439,49 @@ func (s *Stream) handleKLineEvent(klineEvent KLineEvent) {
}
}
func (s *Stream) handleTradeEvent(events []TradeEvent) {
for _, event := range events {
feeRate, found := s.feeRateProvider.Get(event.Symbol)
if !found {
feeRate = symbolFeeDetail{
FeeRate: bybitapi.FeeRate{
Symbol: event.Symbol,
TakerFeeRate: defaultTakerFee,
MakerFeeRate: defaultMakerFee,
},
BaseCoin: "",
QuoteCoin: "",
}
func pollAndGetFeeRate(ctx context.Context, symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) (symbolFeeDetail, error) {
err := poller.Poll(ctx)
if err != nil {
return symbolFeeDetail{}, err
}
return getFeeRate(symbol, poller, marketsInfo), nil
}
if market, ok := s.marketsInfo[event.Symbol]; ok {
feeRate.BaseCoin = market.BaseCurrency
feeRate.QuoteCoin = market.QuoteCurrency
}
if tradeLogLimiter.Allow() {
// The error log level was utilized due to a detected discrepancy in the fee calculations.
log.Errorf("failed to get %s fee rate, use default taker fee %f, maker fee %f, base coin: %s, quote coin: %s",
event.Symbol,
feeRate.TakerFeeRate.Float64(),
feeRate.MakerFeeRate.Float64(),
feeRate.BaseCoin,
feeRate.QuoteCoin,
)
}
func getFeeRate(symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) symbolFeeDetail {
feeRate, found := poller.Get(symbol)
if !found {
feeRate = symbolFeeDetail{
FeeRate: bybitapi.FeeRate{
Symbol: symbol,
TakerFeeRate: defaultTakerFee,
MakerFeeRate: defaultMakerFee,
},
BaseCoin: "",
QuoteCoin: "",
}
if market, ok := marketsInfo[symbol]; ok {
feeRate.BaseCoin = market.BaseCurrency
feeRate.QuoteCoin = market.QuoteCurrency
}
if tradeLogLimiter.Allow() {
// The error log level was utilized due to a detected discrepancy in the fee calculations.
log.Errorf("failed to get %s fee rate, use default taker fee %f, maker fee %f, base coin: %s, quote coin: %s",
symbol,
feeRate.TakerFeeRate.Float64(),
feeRate.MakerFeeRate.Float64(),
feeRate.BaseCoin,
feeRate.QuoteCoin,
)
}
}
return feeRate
}
func (s *Stream) handleTradeEvent(events []TradeEvent) {
for _, event := range events {
feeRate := getFeeRate(event.Symbol, s.feeRateProvider, s.marketsInfo)
gTrade, err := event.toGlobalTrade(feeRate)
if err != nil {
if tradeLogLimiter.Allow() {

View File

@ -30,7 +30,7 @@ func getTestClientOrSkip(t *testing.T) *Stream {
exchange, err := New(key, secret)
assert.NoError(t, err)
return NewStream(key, secret, exchange)
return NewStream(key, secret, exchange, newFeeRatePoller(exchange))
}
func TestStream(t *testing.T) {

View File

@ -15,7 +15,7 @@ import (
func Test_parseWebSocketEvent(t *testing.T) {
t.Run("[public] PingEvent without req id", func(t *testing.T) {
s := NewStream("", "", nil)
s := NewStream("", "", nil, nil)
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","op":"ping"}`
raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err)
@ -26,7 +26,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
})
t.Run("[public] PingEvent with req id", func(t *testing.T) {
s := NewStream("", "", nil)
s := NewStream("", "", nil, nil)
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","req_id":"b26704da-f5af-44c2-bdf7-935d6739e1a0","op":"ping"}`
raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err)
@ -37,7 +37,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
})
t.Run("[private] PingEvent without req id", func(t *testing.T) {
s := NewStream("", "", nil)
s := NewStream("", "", nil, nil)
msg := `{"op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err)
@ -48,7 +48,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
})
t.Run("[private] PingEvent with req id", func(t *testing.T) {
s := NewStream("", "", nil)
s := NewStream("", "", nil, nil)
msg := `{"req_id":"78d36b57-a142-47b7-9143-5843df77d44d","op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err)