From a68764b763d89d6d870c68314b7b1aab63aea95d Mon Sep 17 00:00:00 2001 From: edwin Date: Mon, 30 Sep 2024 21:42:41 +0800 Subject: [PATCH] pkg/exchange: merge FeeRatePoller into StreamDataProvider --- pkg/exchange/bybit/exchange.go | 8 ++++---- pkg/exchange/bybit/market_info_poller.go | 16 ++++++++-------- pkg/exchange/bybit/market_info_poller_test.go | 4 ++-- pkg/exchange/bybit/stream.go | 10 +++++----- pkg/exchange/bybit/stream_test.go | 2 +- pkg/exchange/bybit/types_test.go | 8 ++++---- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/pkg/exchange/bybit/exchange.go b/pkg/exchange/bybit/exchange.go index ffd3265c2..629f086db 100644 --- a/pkg/exchange/bybit/exchange.go +++ b/pkg/exchange/bybit/exchange.go @@ -58,7 +58,7 @@ type Exchange struct { // Because the bybit exchange does not provide a fee currency on traditional SPOT accounts, we need to query the marker // fee rate to get the fee currency. // https://bybit-exchange.github.io/docs/v5/enum#spot-fee-currency-instruction - feeRateProvider FeeRatePoller + FeeRatePoller } func New(key, secret string) (*Exchange, error) { @@ -74,7 +74,7 @@ func New(key, secret string) (*Exchange, error) { } if len(key) > 0 && len(secret) > 0 { client.Auth(key, secret) - ex.feeRateProvider = newFeeRatePoller(ex) + ex.FeeRatePoller = newFeeRatePoller(ex) ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("query markets timeout")) defer cancel() @@ -437,7 +437,7 @@ func (e *Exchange) queryTrades(ctx context.Context, req *bybitapi.GetExecutionLi } for _, trade := range res.List { - feeRate, err := pollAndGetFeeRate(ctx, trade.Symbol, e.feeRateProvider, e.marketsInfo) + feeRate, err := pollAndGetFeeRate(ctx, trade.Symbol, e.FeeRatePoller, e.marketsInfo) if err != nil { return nil, fmt.Errorf("failed to get fee rate, err: %v", err) } @@ -607,5 +607,5 @@ func (e *Exchange) GetAllFeeRates(ctx context.Context) (bybitapi.FeeRates, error } func (e *Exchange) NewStream() types.Stream { - return NewStream(e.key, e.secret, e, e.feeRateProvider) + return NewStream(e.key, e.secret, e) } diff --git a/pkg/exchange/bybit/market_info_poller.go b/pkg/exchange/bybit/market_info_poller.go index af7a2e2a2..460d4c6c4 100644 --- a/pkg/exchange/bybit/market_info_poller.go +++ b/pkg/exchange/bybit/market_info_poller.go @@ -22,9 +22,9 @@ var ( ) type FeeRatePoller interface { - Start(ctx context.Context) - Get(symbol string) (SymbolFeeDetail, bool) - Poll(ctx context.Context) error + StartFeeRatePoller(ctx context.Context) + GetFeeRate(symbol string) (SymbolFeeDetail, bool) + PollFeeRate(ctx context.Context) error } type SymbolFeeDetail struct { @@ -53,14 +53,14 @@ func newFeeRatePoller(marketInfoProvider MarketInfoProvider) *feeRatePoller { } } -func (p *feeRatePoller) Start(ctx context.Context) { +func (p *feeRatePoller) StartFeeRatePoller(ctx context.Context) { p.once.Do(func() { p.startLoop(ctx) }) } func (p *feeRatePoller) startLoop(ctx context.Context) { - err := p.Poll(ctx) + err := p.PollFeeRate(ctx) if err != nil { log.WithError(err).Warn("failed to initialize the fee rate, the ticker is scheduled to update it subsequently") } @@ -76,14 +76,14 @@ func (p *feeRatePoller) startLoop(ctx context.Context) { return case <-ticker.C: - if err := p.Poll(ctx); err != nil { + if err := p.PollFeeRate(ctx); err != nil { log.WithError(err).Warn("failed to update fee rate") } } } } -func (p *feeRatePoller) Poll(ctx context.Context) error { +func (p *feeRatePoller) PollFeeRate(ctx context.Context) error { p.mu.Lock() defer p.mu.Unlock() // the poll will be called frequently, so we need to check the last sync time. @@ -105,7 +105,7 @@ func (p *feeRatePoller) Poll(ctx context.Context) error { return nil } -func (p *feeRatePoller) Get(symbol string) (SymbolFeeDetail, bool) { +func (p *feeRatePoller) GetFeeRate(symbol string) (SymbolFeeDetail, bool) { p.mu.Lock() defer p.mu.Unlock() diff --git a/pkg/exchange/bybit/market_info_poller_test.go b/pkg/exchange/bybit/market_info_poller_test.go index 42401bec8..295577f50 100644 --- a/pkg/exchange/bybit/market_info_poller_test.go +++ b/pkg/exchange/bybit/market_info_poller_test.go @@ -154,7 +154,7 @@ func Test_feeRatePoller_Get(t *testing.T) { }, } - res, found := s.Get(symbol) + res, found := s.GetFeeRate(symbol) assert.True(t, found) assert.Equal(t, expFeeDetail, res) }) @@ -165,7 +165,7 @@ func Test_feeRatePoller_Get(t *testing.T) { symbolFeeDetail: map[string]SymbolFeeDetail{}, } - _, found := s.Get(symbol) + _, found := s.GetFeeRate(symbol) assert.False(t, found) }) } diff --git a/pkg/exchange/bybit/stream.go b/pkg/exchange/bybit/stream.go index 39b5c6805..2a7124dfa 100644 --- a/pkg/exchange/bybit/stream.go +++ b/pkg/exchange/bybit/stream.go @@ -51,6 +51,7 @@ type AccountBalanceProvider interface { type StreamDataProvider interface { MarketInfoProvider AccountBalanceProvider + FeeRatePoller } //go:generate callbackgen -type Stream @@ -70,14 +71,13 @@ type Stream struct { tradeEventCallbacks []func(e []TradeEvent) } -func NewStream(key, secret string, userDataProvider StreamDataProvider, poller FeeRatePoller) *Stream { +func NewStream(key, secret string, userDataProvider StreamDataProvider) *Stream { stream := &Stream{ StandardStream: types.NewStandardStream(), // pragma: allowlist nextline secret key: key, secret: secret, streamDataProvider: userDataProvider, - feeRateProvider: poller, } stream.SetEndpointCreator(stream.createEndpoint) @@ -91,7 +91,7 @@ func NewStream(key, secret string, userDataProvider StreamDataProvider, poller F } // get account fee rate - go stream.feeRateProvider.Start(ctx) + go stream.streamDataProvider.StartFeeRatePoller(ctx) stream.marketsInfo, err = stream.streamDataProvider.QueryMarkets(ctx) if err != nil { @@ -440,7 +440,7 @@ func (s *Stream) handleKLineEvent(klineEvent KLineEvent) { } func pollAndGetFeeRate(ctx context.Context, symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) (SymbolFeeDetail, error) { - err := poller.Poll(ctx) + err := poller.PollFeeRate(ctx) if err != nil { return SymbolFeeDetail{}, err } @@ -448,7 +448,7 @@ func pollAndGetFeeRate(ctx context.Context, symbol string, poller FeeRatePoller, } func getFeeRate(symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) SymbolFeeDetail { - feeRate, found := poller.Get(symbol) + feeRate, found := poller.GetFeeRate(symbol) if !found { feeRate = SymbolFeeDetail{ FeeRate: bybitapi.FeeRate{ diff --git a/pkg/exchange/bybit/stream_test.go b/pkg/exchange/bybit/stream_test.go index ba719a7c9..8ed651022 100644 --- a/pkg/exchange/bybit/stream_test.go +++ b/pkg/exchange/bybit/stream_test.go @@ -30,7 +30,7 @@ func getTestClientOrSkip(t *testing.T) *Stream { exchange, err := New(key, secret) assert.NoError(t, err) - return NewStream(key, secret, exchange, newFeeRatePoller(exchange)) + return NewStream(key, secret, exchange) } func TestStream(t *testing.T) { diff --git a/pkg/exchange/bybit/types_test.go b/pkg/exchange/bybit/types_test.go index bf4d97d20..0fa457cae 100644 --- a/pkg/exchange/bybit/types_test.go +++ b/pkg/exchange/bybit/types_test.go @@ -15,7 +15,7 @@ import ( func Test_parseWebSocketEvent(t *testing.T) { t.Run("[public] PingEvent without req id", func(t *testing.T) { - s := NewStream("", "", nil, nil) + s := NewStream("", "", nil) msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","op":"ping"}` raw, err := s.parseWebSocketEvent([]byte(msg)) assert.NoError(t, err) @@ -26,7 +26,7 @@ func Test_parseWebSocketEvent(t *testing.T) { }) t.Run("[public] PingEvent with req id", func(t *testing.T) { - s := NewStream("", "", nil, nil) + s := NewStream("", "", nil) msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","req_id":"b26704da-f5af-44c2-bdf7-935d6739e1a0","op":"ping"}` raw, err := s.parseWebSocketEvent([]byte(msg)) assert.NoError(t, err) @@ -37,7 +37,7 @@ func Test_parseWebSocketEvent(t *testing.T) { }) t.Run("[private] PingEvent without req id", func(t *testing.T) { - s := NewStream("", "", nil, nil) + s := NewStream("", "", nil) msg := `{"op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}` raw, err := s.parseWebSocketEvent([]byte(msg)) assert.NoError(t, err) @@ -48,7 +48,7 @@ func Test_parseWebSocketEvent(t *testing.T) { }) t.Run("[private] PingEvent with req id", func(t *testing.T) { - s := NewStream("", "", nil, nil) + s := NewStream("", "", nil) msg := `{"req_id":"78d36b57-a142-47b7-9143-5843df77d44d","op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}` raw, err := s.parseWebSocketEvent([]byte(msg)) assert.NoError(t, err)