pkg/exchange: merge FeeRatePoller into StreamDataProvider

This commit is contained in:
edwin 2024-09-30 21:42:41 +08:00
parent 7f1e1a3a51
commit a68764b763
6 changed files with 24 additions and 24 deletions

View File

@ -58,7 +58,7 @@ type Exchange struct {
// Because the bybit exchange does not provide a fee currency on traditional SPOT accounts, we need to query the marker // Because the bybit exchange does not provide a fee currency on traditional SPOT accounts, we need to query the marker
// fee rate to get the fee currency. // fee rate to get the fee currency.
// https://bybit-exchange.github.io/docs/v5/enum#spot-fee-currency-instruction // https://bybit-exchange.github.io/docs/v5/enum#spot-fee-currency-instruction
feeRateProvider FeeRatePoller FeeRatePoller
} }
func New(key, secret string) (*Exchange, error) { func New(key, secret string) (*Exchange, error) {
@ -74,7 +74,7 @@ func New(key, secret string) (*Exchange, error) {
} }
if len(key) > 0 && len(secret) > 0 { if len(key) > 0 && len(secret) > 0 {
client.Auth(key, secret) client.Auth(key, secret)
ex.feeRateProvider = newFeeRatePoller(ex) ex.FeeRatePoller = newFeeRatePoller(ex)
ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("query markets timeout")) ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("query markets timeout"))
defer cancel() defer cancel()
@ -437,7 +437,7 @@ func (e *Exchange) queryTrades(ctx context.Context, req *bybitapi.GetExecutionLi
} }
for _, trade := range res.List { for _, trade := range res.List {
feeRate, err := pollAndGetFeeRate(ctx, trade.Symbol, e.feeRateProvider, e.marketsInfo) feeRate, err := pollAndGetFeeRate(ctx, trade.Symbol, e.FeeRatePoller, e.marketsInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get fee rate, err: %v", err) return nil, fmt.Errorf("failed to get fee rate, err: %v", err)
} }
@ -607,5 +607,5 @@ func (e *Exchange) GetAllFeeRates(ctx context.Context) (bybitapi.FeeRates, error
} }
func (e *Exchange) NewStream() types.Stream { func (e *Exchange) NewStream() types.Stream {
return NewStream(e.key, e.secret, e, e.feeRateProvider) return NewStream(e.key, e.secret, e)
} }

View File

@ -22,9 +22,9 @@ var (
) )
type FeeRatePoller interface { type FeeRatePoller interface {
Start(ctx context.Context) StartFeeRatePoller(ctx context.Context)
Get(symbol string) (SymbolFeeDetail, bool) GetFeeRate(symbol string) (SymbolFeeDetail, bool)
Poll(ctx context.Context) error PollFeeRate(ctx context.Context) error
} }
type SymbolFeeDetail struct { type SymbolFeeDetail struct {
@ -53,14 +53,14 @@ func newFeeRatePoller(marketInfoProvider MarketInfoProvider) *feeRatePoller {
} }
} }
func (p *feeRatePoller) Start(ctx context.Context) { func (p *feeRatePoller) StartFeeRatePoller(ctx context.Context) {
p.once.Do(func() { p.once.Do(func() {
p.startLoop(ctx) p.startLoop(ctx)
}) })
} }
func (p *feeRatePoller) startLoop(ctx context.Context) { func (p *feeRatePoller) startLoop(ctx context.Context) {
err := p.Poll(ctx) err := p.PollFeeRate(ctx)
if err != nil { if err != nil {
log.WithError(err).Warn("failed to initialize the fee rate, the ticker is scheduled to update it subsequently") log.WithError(err).Warn("failed to initialize the fee rate, the ticker is scheduled to update it subsequently")
} }
@ -76,14 +76,14 @@ func (p *feeRatePoller) startLoop(ctx context.Context) {
return return
case <-ticker.C: case <-ticker.C:
if err := p.Poll(ctx); err != nil { if err := p.PollFeeRate(ctx); err != nil {
log.WithError(err).Warn("failed to update fee rate") log.WithError(err).Warn("failed to update fee rate")
} }
} }
} }
} }
func (p *feeRatePoller) Poll(ctx context.Context) error { func (p *feeRatePoller) PollFeeRate(ctx context.Context) error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
// the poll will be called frequently, so we need to check the last sync time. // the poll will be called frequently, so we need to check the last sync time.
@ -105,7 +105,7 @@ func (p *feeRatePoller) Poll(ctx context.Context) error {
return nil return nil
} }
func (p *feeRatePoller) Get(symbol string) (SymbolFeeDetail, bool) { func (p *feeRatePoller) GetFeeRate(symbol string) (SymbolFeeDetail, bool) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()

View File

@ -154,7 +154,7 @@ func Test_feeRatePoller_Get(t *testing.T) {
}, },
} }
res, found := s.Get(symbol) res, found := s.GetFeeRate(symbol)
assert.True(t, found) assert.True(t, found)
assert.Equal(t, expFeeDetail, res) assert.Equal(t, expFeeDetail, res)
}) })
@ -165,7 +165,7 @@ func Test_feeRatePoller_Get(t *testing.T) {
symbolFeeDetail: map[string]SymbolFeeDetail{}, symbolFeeDetail: map[string]SymbolFeeDetail{},
} }
_, found := s.Get(symbol) _, found := s.GetFeeRate(symbol)
assert.False(t, found) assert.False(t, found)
}) })
} }

View File

@ -51,6 +51,7 @@ type AccountBalanceProvider interface {
type StreamDataProvider interface { type StreamDataProvider interface {
MarketInfoProvider MarketInfoProvider
AccountBalanceProvider AccountBalanceProvider
FeeRatePoller
} }
//go:generate callbackgen -type Stream //go:generate callbackgen -type Stream
@ -70,14 +71,13 @@ type Stream struct {
tradeEventCallbacks []func(e []TradeEvent) tradeEventCallbacks []func(e []TradeEvent)
} }
func NewStream(key, secret string, userDataProvider StreamDataProvider, poller FeeRatePoller) *Stream { func NewStream(key, secret string, userDataProvider StreamDataProvider) *Stream {
stream := &Stream{ stream := &Stream{
StandardStream: types.NewStandardStream(), StandardStream: types.NewStandardStream(),
// pragma: allowlist nextline secret // pragma: allowlist nextline secret
key: key, key: key,
secret: secret, secret: secret,
streamDataProvider: userDataProvider, streamDataProvider: userDataProvider,
feeRateProvider: poller,
} }
stream.SetEndpointCreator(stream.createEndpoint) stream.SetEndpointCreator(stream.createEndpoint)
@ -91,7 +91,7 @@ func NewStream(key, secret string, userDataProvider StreamDataProvider, poller F
} }
// get account fee rate // get account fee rate
go stream.feeRateProvider.Start(ctx) go stream.streamDataProvider.StartFeeRatePoller(ctx)
stream.marketsInfo, err = stream.streamDataProvider.QueryMarkets(ctx) stream.marketsInfo, err = stream.streamDataProvider.QueryMarkets(ctx)
if err != nil { if err != nil {
@ -440,7 +440,7 @@ func (s *Stream) handleKLineEvent(klineEvent KLineEvent) {
} }
func pollAndGetFeeRate(ctx context.Context, symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) (SymbolFeeDetail, error) { func pollAndGetFeeRate(ctx context.Context, symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) (SymbolFeeDetail, error) {
err := poller.Poll(ctx) err := poller.PollFeeRate(ctx)
if err != nil { if err != nil {
return SymbolFeeDetail{}, err return SymbolFeeDetail{}, err
} }
@ -448,7 +448,7 @@ func pollAndGetFeeRate(ctx context.Context, symbol string, poller FeeRatePoller,
} }
func getFeeRate(symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) SymbolFeeDetail { func getFeeRate(symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) SymbolFeeDetail {
feeRate, found := poller.Get(symbol) feeRate, found := poller.GetFeeRate(symbol)
if !found { if !found {
feeRate = SymbolFeeDetail{ feeRate = SymbolFeeDetail{
FeeRate: bybitapi.FeeRate{ FeeRate: bybitapi.FeeRate{

View File

@ -30,7 +30,7 @@ func getTestClientOrSkip(t *testing.T) *Stream {
exchange, err := New(key, secret) exchange, err := New(key, secret)
assert.NoError(t, err) assert.NoError(t, err)
return NewStream(key, secret, exchange, newFeeRatePoller(exchange)) return NewStream(key, secret, exchange)
} }
func TestStream(t *testing.T) { func TestStream(t *testing.T) {

View File

@ -15,7 +15,7 @@ import (
func Test_parseWebSocketEvent(t *testing.T) { func Test_parseWebSocketEvent(t *testing.T) {
t.Run("[public] PingEvent without req id", func(t *testing.T) { t.Run("[public] PingEvent without req id", func(t *testing.T) {
s := NewStream("", "", nil, nil) s := NewStream("", "", nil)
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","op":"ping"}` msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","op":"ping"}`
raw, err := s.parseWebSocketEvent([]byte(msg)) raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err) assert.NoError(t, err)
@ -26,7 +26,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
}) })
t.Run("[public] PingEvent with req id", func(t *testing.T) { t.Run("[public] PingEvent with req id", func(t *testing.T) {
s := NewStream("", "", nil, nil) s := NewStream("", "", nil)
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","req_id":"b26704da-f5af-44c2-bdf7-935d6739e1a0","op":"ping"}` msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","req_id":"b26704da-f5af-44c2-bdf7-935d6739e1a0","op":"ping"}`
raw, err := s.parseWebSocketEvent([]byte(msg)) raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err) assert.NoError(t, err)
@ -37,7 +37,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
}) })
t.Run("[private] PingEvent without req id", func(t *testing.T) { t.Run("[private] PingEvent without req id", func(t *testing.T) {
s := NewStream("", "", nil, nil) s := NewStream("", "", nil)
msg := `{"op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}` msg := `{"op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
raw, err := s.parseWebSocketEvent([]byte(msg)) raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err) assert.NoError(t, err)
@ -48,7 +48,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
}) })
t.Run("[private] PingEvent with req id", func(t *testing.T) { t.Run("[private] PingEvent with req id", func(t *testing.T) {
s := NewStream("", "", nil, nil) s := NewStream("", "", nil)
msg := `{"req_id":"78d36b57-a142-47b7-9143-5843df77d44d","op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}` msg := `{"req_id":"78d36b57-a142-47b7-9143-5843df77d44d","op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
raw, err := s.parseWebSocketEvent([]byte(msg)) raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err) assert.NoError(t, err)