pkg/exchange: move fee rate calculate method outside

This commit is contained in:
edwin 2024-09-30 21:27:05 +08:00
parent f6e58ded02
commit d2c1ae0642
3 changed files with 48 additions and 36 deletions

View File

@ -59,7 +59,7 @@ type Stream struct {
key, secret string key, secret string
streamDataProvider StreamDataProvider streamDataProvider StreamDataProvider
feeRateProvider *feeRatePoller feeRateProvider FeeRatePoller
marketsInfo types.MarketMap marketsInfo types.MarketMap
bookEventCallbacks []func(e BookEvent) bookEventCallbacks []func(e BookEvent)
@ -70,14 +70,14 @@ type Stream struct {
tradeEventCallbacks []func(e []TradeEvent) tradeEventCallbacks []func(e []TradeEvent)
} }
func NewStream(key, secret string, userDataProvider StreamDataProvider) *Stream { func NewStream(key, secret string, userDataProvider StreamDataProvider, poller FeeRatePoller) *Stream {
stream := &Stream{ stream := &Stream{
StandardStream: types.NewStandardStream(), StandardStream: types.NewStandardStream(),
// pragma: allowlist nextline secret // pragma: allowlist nextline secret
key: key, key: key,
secret: secret, secret: secret,
streamDataProvider: userDataProvider, streamDataProvider: userDataProvider,
feeRateProvider: newFeeRatePoller(userDataProvider), feeRateProvider: poller,
} }
stream.SetEndpointCreator(stream.createEndpoint) stream.SetEndpointCreator(stream.createEndpoint)
@ -439,13 +439,20 @@ func (s *Stream) handleKLineEvent(klineEvent KLineEvent) {
} }
} }
func (s *Stream) handleTradeEvent(events []TradeEvent) { func pollAndGetFeeRate(ctx context.Context, symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) (symbolFeeDetail, error) {
for _, event := range events { err := poller.Poll(ctx)
feeRate, found := s.feeRateProvider.Get(event.Symbol) if err != nil {
return symbolFeeDetail{}, err
}
return getFeeRate(symbol, poller, marketsInfo), nil
}
func getFeeRate(symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) symbolFeeDetail {
feeRate, found := poller.Get(symbol)
if !found { if !found {
feeRate = symbolFeeDetail{ feeRate = symbolFeeDetail{
FeeRate: bybitapi.FeeRate{ FeeRate: bybitapi.FeeRate{
Symbol: event.Symbol, Symbol: symbol,
TakerFeeRate: defaultTakerFee, TakerFeeRate: defaultTakerFee,
MakerFeeRate: defaultMakerFee, MakerFeeRate: defaultMakerFee,
}, },
@ -453,7 +460,7 @@ func (s *Stream) handleTradeEvent(events []TradeEvent) {
QuoteCoin: "", QuoteCoin: "",
} }
if market, ok := s.marketsInfo[event.Symbol]; ok { if market, ok := marketsInfo[symbol]; ok {
feeRate.BaseCoin = market.BaseCurrency feeRate.BaseCoin = market.BaseCurrency
feeRate.QuoteCoin = market.QuoteCurrency feeRate.QuoteCoin = market.QuoteCurrency
} }
@ -461,7 +468,7 @@ func (s *Stream) handleTradeEvent(events []TradeEvent) {
if tradeLogLimiter.Allow() { if tradeLogLimiter.Allow() {
// The error log level was utilized due to a detected discrepancy in the fee calculations. // The error log level was utilized due to a detected discrepancy in the fee calculations.
log.Errorf("failed to get %s fee rate, use default taker fee %f, maker fee %f, base coin: %s, quote coin: %s", log.Errorf("failed to get %s fee rate, use default taker fee %f, maker fee %f, base coin: %s, quote coin: %s",
event.Symbol, symbol,
feeRate.TakerFeeRate.Float64(), feeRate.TakerFeeRate.Float64(),
feeRate.MakerFeeRate.Float64(), feeRate.MakerFeeRate.Float64(),
feeRate.BaseCoin, feeRate.BaseCoin,
@ -469,7 +476,12 @@ func (s *Stream) handleTradeEvent(events []TradeEvent) {
) )
} }
} }
return feeRate
}
func (s *Stream) handleTradeEvent(events []TradeEvent) {
for _, event := range events {
feeRate := getFeeRate(event.Symbol, s.feeRateProvider, s.marketsInfo)
gTrade, err := event.toGlobalTrade(feeRate) gTrade, err := event.toGlobalTrade(feeRate)
if err != nil { if err != nil {
if tradeLogLimiter.Allow() { if tradeLogLimiter.Allow() {

View File

@ -30,7 +30,7 @@ func getTestClientOrSkip(t *testing.T) *Stream {
exchange, err := New(key, secret) exchange, err := New(key, secret)
assert.NoError(t, err) assert.NoError(t, err)
return NewStream(key, secret, exchange) return NewStream(key, secret, exchange, newFeeRatePoller(exchange))
} }
func TestStream(t *testing.T) { func TestStream(t *testing.T) {

View File

@ -15,7 +15,7 @@ import (
func Test_parseWebSocketEvent(t *testing.T) { func Test_parseWebSocketEvent(t *testing.T) {
t.Run("[public] PingEvent without req id", func(t *testing.T) { t.Run("[public] PingEvent without req id", func(t *testing.T) {
s := NewStream("", "", nil) s := NewStream("", "", nil, nil)
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","op":"ping"}` msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","op":"ping"}`
raw, err := s.parseWebSocketEvent([]byte(msg)) raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err) assert.NoError(t, err)
@ -26,7 +26,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
}) })
t.Run("[public] PingEvent with req id", func(t *testing.T) { t.Run("[public] PingEvent with req id", func(t *testing.T) {
s := NewStream("", "", nil) s := NewStream("", "", nil, nil)
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","req_id":"b26704da-f5af-44c2-bdf7-935d6739e1a0","op":"ping"}` msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","req_id":"b26704da-f5af-44c2-bdf7-935d6739e1a0","op":"ping"}`
raw, err := s.parseWebSocketEvent([]byte(msg)) raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err) assert.NoError(t, err)
@ -37,7 +37,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
}) })
t.Run("[private] PingEvent without req id", func(t *testing.T) { t.Run("[private] PingEvent without req id", func(t *testing.T) {
s := NewStream("", "", nil) s := NewStream("", "", nil, nil)
msg := `{"op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}` msg := `{"op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
raw, err := s.parseWebSocketEvent([]byte(msg)) raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err) assert.NoError(t, err)
@ -48,7 +48,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
}) })
t.Run("[private] PingEvent with req id", func(t *testing.T) { t.Run("[private] PingEvent with req id", func(t *testing.T) {
s := NewStream("", "", nil) s := NewStream("", "", nil, nil)
msg := `{"req_id":"78d36b57-a142-47b7-9143-5843df77d44d","op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}` msg := `{"req_id":"78d36b57-a142-47b7-9143-5843df77d44d","op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
raw, err := s.parseWebSocketEvent([]byte(msg)) raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err) assert.NoError(t, err)