From 70884538bc7dc09899d37c2284fc881d21886037 Mon Sep 17 00:00:00 2001 From: Edwin Date: Tue, 26 Sep 2023 16:05:52 +0800 Subject: [PATCH] pkg/exchange: emit balance snapshot --- pkg/exchange/bybit/mocks/stream.go | 51 +++++++++++++++++++----------- pkg/exchange/bybit/stream.go | 51 ++++++++++++++++++++++++------ pkg/exchange/bybit/stream_test.go | 15 +++++---- 3 files changed, 84 insertions(+), 33 deletions(-) diff --git a/pkg/exchange/bybit/mocks/stream.go b/pkg/exchange/bybit/mocks/stream.go index 9d20878e8..6a3c9d876 100644 --- a/pkg/exchange/bybit/mocks/stream.go +++ b/pkg/exchange/bybit/mocks/stream.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/c9s/bbgo/pkg/exchange/bybit (interfaces: MarketInfoProvider) +// Source: github.com/c9s/bbgo/pkg/exchange/bybit (interfaces: StreamDataProvider) // Package mocks is a generated GoMock package. package mocks @@ -13,31 +13,31 @@ import ( gomock "github.com/golang/mock/gomock" ) -// MockMarketInfoProvider is a mock of MarketInfoProvider interface. -type MockMarketInfoProvider struct { +// MockStreamDataProvider is a mock of StreamDataProvider interface. +type MockStreamDataProvider struct { ctrl *gomock.Controller - recorder *MockMarketInfoProviderMockRecorder + recorder *MockStreamDataProviderMockRecorder } -// MockMarketInfoProviderMockRecorder is the mock recorder for MockMarketInfoProvider. -type MockMarketInfoProviderMockRecorder struct { - mock *MockMarketInfoProvider +// MockStreamDataProviderMockRecorder is the mock recorder for MockStreamDataProvider. +type MockStreamDataProviderMockRecorder struct { + mock *MockStreamDataProvider } -// NewMockMarketInfoProvider creates a new mock instance. -func NewMockMarketInfoProvider(ctrl *gomock.Controller) *MockMarketInfoProvider { - mock := &MockMarketInfoProvider{ctrl: ctrl} - mock.recorder = &MockMarketInfoProviderMockRecorder{mock} +// NewMockStreamDataProvider creates a new mock instance. +func NewMockStreamDataProvider(ctrl *gomock.Controller) *MockStreamDataProvider { + mock := &MockStreamDataProvider{ctrl: ctrl} + mock.recorder = &MockStreamDataProviderMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMarketInfoProvider) EXPECT() *MockMarketInfoProviderMockRecorder { +func (m *MockStreamDataProvider) EXPECT() *MockStreamDataProviderMockRecorder { return m.recorder } // GetAllFeeRates mocks base method. -func (m *MockMarketInfoProvider) GetAllFeeRates(arg0 context.Context) (bybitapi.FeeRates, error) { +func (m *MockStreamDataProvider) GetAllFeeRates(arg0 context.Context) (bybitapi.FeeRates, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAllFeeRates", arg0) ret0, _ := ret[0].(bybitapi.FeeRates) @@ -46,13 +46,28 @@ func (m *MockMarketInfoProvider) GetAllFeeRates(arg0 context.Context) (bybitapi. } // GetAllFeeRates indicates an expected call of GetAllFeeRates. -func (mr *MockMarketInfoProviderMockRecorder) GetAllFeeRates(arg0 interface{}) *gomock.Call { +func (mr *MockStreamDataProviderMockRecorder) GetAllFeeRates(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllFeeRates", reflect.TypeOf((*MockMarketInfoProvider)(nil).GetAllFeeRates), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllFeeRates", reflect.TypeOf((*MockStreamDataProvider)(nil).GetAllFeeRates), arg0) +} + +// QueryAccountBalances mocks base method. +func (m *MockStreamDataProvider) QueryAccountBalances(arg0 context.Context) (types.BalanceMap, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryAccountBalances", arg0) + ret0, _ := ret[0].(types.BalanceMap) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryAccountBalances indicates an expected call of QueryAccountBalances. +func (mr *MockStreamDataProviderMockRecorder) QueryAccountBalances(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryAccountBalances", reflect.TypeOf((*MockStreamDataProvider)(nil).QueryAccountBalances), arg0) } // QueryMarkets mocks base method. -func (m *MockMarketInfoProvider) QueryMarkets(arg0 context.Context) (types.MarketMap, error) { +func (m *MockStreamDataProvider) QueryMarkets(arg0 context.Context) (types.MarketMap, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "QueryMarkets", arg0) ret0, _ := ret[0].(types.MarketMap) @@ -61,7 +76,7 @@ func (m *MockMarketInfoProvider) QueryMarkets(arg0 context.Context) (types.Marke } // QueryMarkets indicates an expected call of QueryMarkets. -func (mr *MockMarketInfoProviderMockRecorder) QueryMarkets(arg0 interface{}) *gomock.Call { +func (mr *MockStreamDataProviderMockRecorder) QueryMarkets(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryMarkets", reflect.TypeOf((*MockMarketInfoProvider)(nil).QueryMarkets), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryMarkets", reflect.TypeOf((*MockStreamDataProvider)(nil).QueryMarkets), arg0) } diff --git a/pkg/exchange/bybit/stream.go b/pkg/exchange/bybit/stream.go index ef756a757..eab0c83e7 100644 --- a/pkg/exchange/bybit/stream.go +++ b/pkg/exchange/bybit/stream.go @@ -11,6 +11,7 @@ import ( "github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi" "github.com/c9s/bbgo/pkg/types" + "github.com/c9s/bbgo/pkg/util" ) const ( @@ -27,18 +28,29 @@ var ( wsAuthRequest = 10 * time.Second ) -//go:generate mockgen -destination=mocks/stream.go -package=mocks . MarketInfoProvider +// MarketInfoProvider calculates trade fees since trading fees are not supported by streaming. type MarketInfoProvider interface { GetAllFeeRates(ctx context.Context) (bybitapi.FeeRates, error) QueryMarkets(ctx context.Context) (types.MarketMap, error) } +// AccountBalanceProvider provides a function to query all balances at streaming connected and emit balance snapshot. +type AccountBalanceProvider interface { + QueryAccountBalances(ctx context.Context) (types.BalanceMap, error) +} + +//go:generate mockgen -destination=mocks/stream.go -package=mocks . StreamDataProvider +type StreamDataProvider interface { + MarketInfoProvider + AccountBalanceProvider +} + //go:generate callbackgen -type Stream type Stream struct { types.StandardStream - key, secret string - marketProvider MarketInfoProvider + key, secret string + streamDataProvider StreamDataProvider // TODO: update the fee rate at 7:00 am UTC; rotation required. symbolFeeDetails map[string]*symbolFeeDetail @@ -50,13 +62,13 @@ type Stream struct { tradeEventCallbacks []func(e []TradeEvent) } -func NewStream(key, secret string, marketProvider MarketInfoProvider) *Stream { +func NewStream(key, secret string, userDataProvider StreamDataProvider) *Stream { stream := &Stream{ StandardStream: types.NewStandardStream(), // pragma: allowlist nextline secret - key: key, - secret: secret, - marketProvider: marketProvider, + key: key, + secret: secret, + streamDataProvider: userDataProvider, } stream.SetEndpointCreator(stream.createEndpoint) @@ -65,6 +77,7 @@ func NewStream(key, secret string, marketProvider MarketInfoProvider) *Stream { stream.SetHeartBeat(stream.ping) stream.SetBeforeConnect(stream.getAllFeeRates) stream.OnConnect(stream.handlerConnect) + stream.OnAuth(stream.handleAuthEvent) stream.OnBookEvent(stream.handleBookEvent) stream.OnMarketTradeEvent(stream.handleMarketTradeEvent) @@ -326,6 +339,26 @@ func (s *Stream) convertSubscription(sub types.Subscription) (string, error) { return "", fmt.Errorf("unsupported stream channel: %s", sub.Channel) } +func (s *Stream) handleAuthEvent() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + var balnacesMap types.BalanceMap + var err error + err = util.Retry(ctx, 10, 300*time.Millisecond, func() error { + balnacesMap, err = s.streamDataProvider.QueryAccountBalances(ctx) + return err + }, func(err error) { + log.WithError(err).Error("failed to call query account balances") + }) + if err != nil { + log.WithError(err).Error("no more attempts to retrieve balances") + return + } + + s.EmitBalanceSnapshot(balnacesMap) +} + func (s *Stream) handleBookEvent(e BookEvent) { orderBook := e.OrderBook() switch { @@ -417,7 +450,7 @@ type symbolFeeDetail struct { // getAllFeeRates retrieves all fee rates from the Bybit API and then fetches markets to ensure the base coin and quote coin // are correct. func (e *Stream) getAllFeeRates(ctx context.Context) error { - feeRates, err := e.marketProvider.GetAllFeeRates(ctx) + feeRates, err := e.streamDataProvider.GetAllFeeRates(ctx) if err != nil { return fmt.Errorf("failed to call get fee rates: %w", err) } @@ -429,7 +462,7 @@ func (e *Stream) getAllFeeRates(ctx context.Context) error { } } - mkts, err := e.marketProvider.QueryMarkets(ctx) + mkts, err := e.streamDataProvider.QueryMarkets(ctx) if err != nil { return fmt.Errorf("failed to get markets: %w", err) } diff --git a/pkg/exchange/bybit/stream_test.go b/pkg/exchange/bybit/stream_test.go index bcf3572a6..8e322a439 100644 --- a/pkg/exchange/bybit/stream_test.go +++ b/pkg/exchange/bybit/stream_test.go @@ -54,6 +54,9 @@ func TestStream(t *testing.T) { } t.Run("Auth test", func(t *testing.T) { + s.OnBalanceSnapshot(func(balances types.BalanceMap) { + t.Log("got balance snapshot", balances) + }) s.Connect(context.Background()) c := make(chan struct{}) <-c @@ -450,9 +453,9 @@ func TestStream_getFeeRate(t *testing.T) { unknownErr := errors.New("unknown err") t.Run("succeeds", func(t *testing.T) { - mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl) + mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl) s := &Stream{ - marketProvider: mockMarketProvider, + streamDataProvider: mockMarketProvider, } ctx := context.Background() @@ -510,9 +513,9 @@ func TestStream_getFeeRate(t *testing.T) { }) t.Run("failed to query markets", func(t *testing.T) { - mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl) + mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl) s := &Stream{ - marketProvider: mockMarketProvider, + streamDataProvider: mockMarketProvider, } ctx := context.Background() @@ -545,9 +548,9 @@ func TestStream_getFeeRate(t *testing.T) { }) t.Run("failed to get fee rates", func(t *testing.T) { - mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl) + mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl) s := &Stream{ - marketProvider: mockMarketProvider, + streamDataProvider: mockMarketProvider, } ctx := context.Background()