pkg/exchange: emit balance snapshot

This commit is contained in:
Edwin 2023-09-26 16:05:52 +08:00
parent 1600277ac3
commit 70884538bc
3 changed files with 84 additions and 33 deletions

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/c9s/bbgo/pkg/exchange/bybit (interfaces: MarketInfoProvider) // Source: github.com/c9s/bbgo/pkg/exchange/bybit (interfaces: StreamDataProvider)
// Package mocks is a generated GoMock package. // Package mocks is a generated GoMock package.
package mocks package mocks
@ -13,31 +13,31 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
// MockMarketInfoProvider is a mock of MarketInfoProvider interface. // MockStreamDataProvider is a mock of StreamDataProvider interface.
type MockMarketInfoProvider struct { type MockStreamDataProvider struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockMarketInfoProviderMockRecorder recorder *MockStreamDataProviderMockRecorder
} }
// MockMarketInfoProviderMockRecorder is the mock recorder for MockMarketInfoProvider. // MockStreamDataProviderMockRecorder is the mock recorder for MockStreamDataProvider.
type MockMarketInfoProviderMockRecorder struct { type MockStreamDataProviderMockRecorder struct {
mock *MockMarketInfoProvider mock *MockStreamDataProvider
} }
// NewMockMarketInfoProvider creates a new mock instance. // NewMockStreamDataProvider creates a new mock instance.
func NewMockMarketInfoProvider(ctrl *gomock.Controller) *MockMarketInfoProvider { func NewMockStreamDataProvider(ctrl *gomock.Controller) *MockStreamDataProvider {
mock := &MockMarketInfoProvider{ctrl: ctrl} mock := &MockStreamDataProvider{ctrl: ctrl}
mock.recorder = &MockMarketInfoProviderMockRecorder{mock} mock.recorder = &MockStreamDataProviderMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMarketInfoProvider) EXPECT() *MockMarketInfoProviderMockRecorder { func (m *MockStreamDataProvider) EXPECT() *MockStreamDataProviderMockRecorder {
return m.recorder return m.recorder
} }
// GetAllFeeRates mocks base method. // GetAllFeeRates mocks base method.
func (m *MockMarketInfoProvider) GetAllFeeRates(arg0 context.Context) (bybitapi.FeeRates, error) { func (m *MockStreamDataProvider) GetAllFeeRates(arg0 context.Context) (bybitapi.FeeRates, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllFeeRates", arg0) ret := m.ctrl.Call(m, "GetAllFeeRates", arg0)
ret0, _ := ret[0].(bybitapi.FeeRates) ret0, _ := ret[0].(bybitapi.FeeRates)
@ -46,13 +46,28 @@ func (m *MockMarketInfoProvider) GetAllFeeRates(arg0 context.Context) (bybitapi.
} }
// GetAllFeeRates indicates an expected call of GetAllFeeRates. // GetAllFeeRates indicates an expected call of GetAllFeeRates.
func (mr *MockMarketInfoProviderMockRecorder) GetAllFeeRates(arg0 interface{}) *gomock.Call { func (mr *MockStreamDataProviderMockRecorder) GetAllFeeRates(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllFeeRates", reflect.TypeOf((*MockMarketInfoProvider)(nil).GetAllFeeRates), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllFeeRates", reflect.TypeOf((*MockStreamDataProvider)(nil).GetAllFeeRates), arg0)
}
// QueryAccountBalances mocks base method.
func (m *MockStreamDataProvider) QueryAccountBalances(arg0 context.Context) (types.BalanceMap, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "QueryAccountBalances", arg0)
ret0, _ := ret[0].(types.BalanceMap)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryAccountBalances indicates an expected call of QueryAccountBalances.
func (mr *MockStreamDataProviderMockRecorder) QueryAccountBalances(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryAccountBalances", reflect.TypeOf((*MockStreamDataProvider)(nil).QueryAccountBalances), arg0)
} }
// QueryMarkets mocks base method. // QueryMarkets mocks base method.
func (m *MockMarketInfoProvider) QueryMarkets(arg0 context.Context) (types.MarketMap, error) { func (m *MockStreamDataProvider) QueryMarkets(arg0 context.Context) (types.MarketMap, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "QueryMarkets", arg0) ret := m.ctrl.Call(m, "QueryMarkets", arg0)
ret0, _ := ret[0].(types.MarketMap) ret0, _ := ret[0].(types.MarketMap)
@ -61,7 +76,7 @@ func (m *MockMarketInfoProvider) QueryMarkets(arg0 context.Context) (types.Marke
} }
// QueryMarkets indicates an expected call of QueryMarkets. // QueryMarkets indicates an expected call of QueryMarkets.
func (mr *MockMarketInfoProviderMockRecorder) QueryMarkets(arg0 interface{}) *gomock.Call { func (mr *MockStreamDataProviderMockRecorder) QueryMarkets(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryMarkets", reflect.TypeOf((*MockMarketInfoProvider)(nil).QueryMarkets), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryMarkets", reflect.TypeOf((*MockStreamDataProvider)(nil).QueryMarkets), arg0)
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi" "github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
"github.com/c9s/bbgo/pkg/util"
) )
const ( const (
@ -27,18 +28,29 @@ var (
wsAuthRequest = 10 * time.Second wsAuthRequest = 10 * time.Second
) )
//go:generate mockgen -destination=mocks/stream.go -package=mocks . MarketInfoProvider // MarketInfoProvider calculates trade fees since trading fees are not supported by streaming.
type MarketInfoProvider interface { type MarketInfoProvider interface {
GetAllFeeRates(ctx context.Context) (bybitapi.FeeRates, error) GetAllFeeRates(ctx context.Context) (bybitapi.FeeRates, error)
QueryMarkets(ctx context.Context) (types.MarketMap, error) QueryMarkets(ctx context.Context) (types.MarketMap, error)
} }
// AccountBalanceProvider provides a function to query all balances at streaming connected and emit balance snapshot.
type AccountBalanceProvider interface {
QueryAccountBalances(ctx context.Context) (types.BalanceMap, error)
}
//go:generate mockgen -destination=mocks/stream.go -package=mocks . StreamDataProvider
type StreamDataProvider interface {
MarketInfoProvider
AccountBalanceProvider
}
//go:generate callbackgen -type Stream //go:generate callbackgen -type Stream
type Stream struct { type Stream struct {
types.StandardStream types.StandardStream
key, secret string key, secret string
marketProvider MarketInfoProvider streamDataProvider StreamDataProvider
// TODO: update the fee rate at 7:00 am UTC; rotation required. // TODO: update the fee rate at 7:00 am UTC; rotation required.
symbolFeeDetails map[string]*symbolFeeDetail symbolFeeDetails map[string]*symbolFeeDetail
@ -50,13 +62,13 @@ type Stream struct {
tradeEventCallbacks []func(e []TradeEvent) tradeEventCallbacks []func(e []TradeEvent)
} }
func NewStream(key, secret string, marketProvider MarketInfoProvider) *Stream { func NewStream(key, secret string, userDataProvider StreamDataProvider) *Stream {
stream := &Stream{ stream := &Stream{
StandardStream: types.NewStandardStream(), StandardStream: types.NewStandardStream(),
// pragma: allowlist nextline secret // pragma: allowlist nextline secret
key: key, key: key,
secret: secret, secret: secret,
marketProvider: marketProvider, streamDataProvider: userDataProvider,
} }
stream.SetEndpointCreator(stream.createEndpoint) stream.SetEndpointCreator(stream.createEndpoint)
@ -65,6 +77,7 @@ func NewStream(key, secret string, marketProvider MarketInfoProvider) *Stream {
stream.SetHeartBeat(stream.ping) stream.SetHeartBeat(stream.ping)
stream.SetBeforeConnect(stream.getAllFeeRates) stream.SetBeforeConnect(stream.getAllFeeRates)
stream.OnConnect(stream.handlerConnect) stream.OnConnect(stream.handlerConnect)
stream.OnAuth(stream.handleAuthEvent)
stream.OnBookEvent(stream.handleBookEvent) stream.OnBookEvent(stream.handleBookEvent)
stream.OnMarketTradeEvent(stream.handleMarketTradeEvent) stream.OnMarketTradeEvent(stream.handleMarketTradeEvent)
@ -326,6 +339,26 @@ func (s *Stream) convertSubscription(sub types.Subscription) (string, error) {
return "", fmt.Errorf("unsupported stream channel: %s", sub.Channel) return "", fmt.Errorf("unsupported stream channel: %s", sub.Channel)
} }
func (s *Stream) handleAuthEvent() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
var balnacesMap types.BalanceMap
var err error
err = util.Retry(ctx, 10, 300*time.Millisecond, func() error {
balnacesMap, err = s.streamDataProvider.QueryAccountBalances(ctx)
return err
}, func(err error) {
log.WithError(err).Error("failed to call query account balances")
})
if err != nil {
log.WithError(err).Error("no more attempts to retrieve balances")
return
}
s.EmitBalanceSnapshot(balnacesMap)
}
func (s *Stream) handleBookEvent(e BookEvent) { func (s *Stream) handleBookEvent(e BookEvent) {
orderBook := e.OrderBook() orderBook := e.OrderBook()
switch { switch {
@ -417,7 +450,7 @@ type symbolFeeDetail struct {
// getAllFeeRates retrieves all fee rates from the Bybit API and then fetches markets to ensure the base coin and quote coin // getAllFeeRates retrieves all fee rates from the Bybit API and then fetches markets to ensure the base coin and quote coin
// are correct. // are correct.
func (e *Stream) getAllFeeRates(ctx context.Context) error { func (e *Stream) getAllFeeRates(ctx context.Context) error {
feeRates, err := e.marketProvider.GetAllFeeRates(ctx) feeRates, err := e.streamDataProvider.GetAllFeeRates(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to call get fee rates: %w", err) return fmt.Errorf("failed to call get fee rates: %w", err)
} }
@ -429,7 +462,7 @@ func (e *Stream) getAllFeeRates(ctx context.Context) error {
} }
} }
mkts, err := e.marketProvider.QueryMarkets(ctx) mkts, err := e.streamDataProvider.QueryMarkets(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to get markets: %w", err) return fmt.Errorf("failed to get markets: %w", err)
} }

View File

@ -54,6 +54,9 @@ func TestStream(t *testing.T) {
} }
t.Run("Auth test", func(t *testing.T) { t.Run("Auth test", func(t *testing.T) {
s.OnBalanceSnapshot(func(balances types.BalanceMap) {
t.Log("got balance snapshot", balances)
})
s.Connect(context.Background()) s.Connect(context.Background())
c := make(chan struct{}) c := make(chan struct{})
<-c <-c
@ -450,9 +453,9 @@ func TestStream_getFeeRate(t *testing.T) {
unknownErr := errors.New("unknown err") unknownErr := errors.New("unknown err")
t.Run("succeeds", func(t *testing.T) { t.Run("succeeds", func(t *testing.T) {
mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl) mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
s := &Stream{ s := &Stream{
marketProvider: mockMarketProvider, streamDataProvider: mockMarketProvider,
} }
ctx := context.Background() ctx := context.Background()
@ -510,9 +513,9 @@ func TestStream_getFeeRate(t *testing.T) {
}) })
t.Run("failed to query markets", func(t *testing.T) { t.Run("failed to query markets", func(t *testing.T) {
mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl) mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
s := &Stream{ s := &Stream{
marketProvider: mockMarketProvider, streamDataProvider: mockMarketProvider,
} }
ctx := context.Background() ctx := context.Background()
@ -545,9 +548,9 @@ func TestStream_getFeeRate(t *testing.T) {
}) })
t.Run("failed to get fee rates", func(t *testing.T) { t.Run("failed to get fee rates", func(t *testing.T) {
mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl) mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
s := &Stream{ s := &Stream{
marketProvider: mockMarketProvider, streamDataProvider: mockMarketProvider,
} }
ctx := context.Background() ctx := context.Background()