mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 09:11:55 +00:00
pkg/exchange: emit balance snapshot
This commit is contained in:
parent
1600277ac3
commit
70884538bc
|
@ -1,5 +1,5 @@
|
||||||
// Code generated by MockGen. DO NOT EDIT.
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
// Source: github.com/c9s/bbgo/pkg/exchange/bybit (interfaces: MarketInfoProvider)
|
// Source: github.com/c9s/bbgo/pkg/exchange/bybit (interfaces: StreamDataProvider)
|
||||||
|
|
||||||
// Package mocks is a generated GoMock package.
|
// Package mocks is a generated GoMock package.
|
||||||
package mocks
|
package mocks
|
||||||
|
@ -13,31 +13,31 @@ import (
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockMarketInfoProvider is a mock of MarketInfoProvider interface.
|
// MockStreamDataProvider is a mock of StreamDataProvider interface.
|
||||||
type MockMarketInfoProvider struct {
|
type MockStreamDataProvider struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
recorder *MockMarketInfoProviderMockRecorder
|
recorder *MockStreamDataProviderMockRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockMarketInfoProviderMockRecorder is the mock recorder for MockMarketInfoProvider.
|
// MockStreamDataProviderMockRecorder is the mock recorder for MockStreamDataProvider.
|
||||||
type MockMarketInfoProviderMockRecorder struct {
|
type MockStreamDataProviderMockRecorder struct {
|
||||||
mock *MockMarketInfoProvider
|
mock *MockStreamDataProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMockMarketInfoProvider creates a new mock instance.
|
// NewMockStreamDataProvider creates a new mock instance.
|
||||||
func NewMockMarketInfoProvider(ctrl *gomock.Controller) *MockMarketInfoProvider {
|
func NewMockStreamDataProvider(ctrl *gomock.Controller) *MockStreamDataProvider {
|
||||||
mock := &MockMarketInfoProvider{ctrl: ctrl}
|
mock := &MockStreamDataProvider{ctrl: ctrl}
|
||||||
mock.recorder = &MockMarketInfoProviderMockRecorder{mock}
|
mock.recorder = &MockStreamDataProviderMockRecorder{mock}
|
||||||
return mock
|
return mock
|
||||||
}
|
}
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
func (m *MockMarketInfoProvider) EXPECT() *MockMarketInfoProviderMockRecorder {
|
func (m *MockStreamDataProvider) EXPECT() *MockStreamDataProviderMockRecorder {
|
||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllFeeRates mocks base method.
|
// GetAllFeeRates mocks base method.
|
||||||
func (m *MockMarketInfoProvider) GetAllFeeRates(arg0 context.Context) (bybitapi.FeeRates, error) {
|
func (m *MockStreamDataProvider) GetAllFeeRates(arg0 context.Context) (bybitapi.FeeRates, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetAllFeeRates", arg0)
|
ret := m.ctrl.Call(m, "GetAllFeeRates", arg0)
|
||||||
ret0, _ := ret[0].(bybitapi.FeeRates)
|
ret0, _ := ret[0].(bybitapi.FeeRates)
|
||||||
|
@ -46,13 +46,28 @@ func (m *MockMarketInfoProvider) GetAllFeeRates(arg0 context.Context) (bybitapi.
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllFeeRates indicates an expected call of GetAllFeeRates.
|
// GetAllFeeRates indicates an expected call of GetAllFeeRates.
|
||||||
func (mr *MockMarketInfoProviderMockRecorder) GetAllFeeRates(arg0 interface{}) *gomock.Call {
|
func (mr *MockStreamDataProviderMockRecorder) GetAllFeeRates(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllFeeRates", reflect.TypeOf((*MockMarketInfoProvider)(nil).GetAllFeeRates), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllFeeRates", reflect.TypeOf((*MockStreamDataProvider)(nil).GetAllFeeRates), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryAccountBalances mocks base method.
|
||||||
|
func (m *MockStreamDataProvider) QueryAccountBalances(arg0 context.Context) (types.BalanceMap, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "QueryAccountBalances", arg0)
|
||||||
|
ret0, _ := ret[0].(types.BalanceMap)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryAccountBalances indicates an expected call of QueryAccountBalances.
|
||||||
|
func (mr *MockStreamDataProviderMockRecorder) QueryAccountBalances(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryAccountBalances", reflect.TypeOf((*MockStreamDataProvider)(nil).QueryAccountBalances), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryMarkets mocks base method.
|
// QueryMarkets mocks base method.
|
||||||
func (m *MockMarketInfoProvider) QueryMarkets(arg0 context.Context) (types.MarketMap, error) {
|
func (m *MockStreamDataProvider) QueryMarkets(arg0 context.Context) (types.MarketMap, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "QueryMarkets", arg0)
|
ret := m.ctrl.Call(m, "QueryMarkets", arg0)
|
||||||
ret0, _ := ret[0].(types.MarketMap)
|
ret0, _ := ret[0].(types.MarketMap)
|
||||||
|
@ -61,7 +76,7 @@ func (m *MockMarketInfoProvider) QueryMarkets(arg0 context.Context) (types.Marke
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryMarkets indicates an expected call of QueryMarkets.
|
// QueryMarkets indicates an expected call of QueryMarkets.
|
||||||
func (mr *MockMarketInfoProviderMockRecorder) QueryMarkets(arg0 interface{}) *gomock.Call {
|
func (mr *MockStreamDataProviderMockRecorder) QueryMarkets(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryMarkets", reflect.TypeOf((*MockMarketInfoProvider)(nil).QueryMarkets), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryMarkets", reflect.TypeOf((*MockStreamDataProvider)(nil).QueryMarkets), arg0)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi"
|
"github.com/c9s/bbgo/pkg/exchange/bybit/bybitapi"
|
||||||
"github.com/c9s/bbgo/pkg/types"
|
"github.com/c9s/bbgo/pkg/types"
|
||||||
|
"github.com/c9s/bbgo/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -27,18 +28,29 @@ var (
|
||||||
wsAuthRequest = 10 * time.Second
|
wsAuthRequest = 10 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:generate mockgen -destination=mocks/stream.go -package=mocks . MarketInfoProvider
|
// MarketInfoProvider calculates trade fees since trading fees are not supported by streaming.
|
||||||
type MarketInfoProvider interface {
|
type MarketInfoProvider interface {
|
||||||
GetAllFeeRates(ctx context.Context) (bybitapi.FeeRates, error)
|
GetAllFeeRates(ctx context.Context) (bybitapi.FeeRates, error)
|
||||||
QueryMarkets(ctx context.Context) (types.MarketMap, error)
|
QueryMarkets(ctx context.Context) (types.MarketMap, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountBalanceProvider provides a function to query all balances at streaming connected and emit balance snapshot.
|
||||||
|
type AccountBalanceProvider interface {
|
||||||
|
QueryAccountBalances(ctx context.Context) (types.BalanceMap, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:generate mockgen -destination=mocks/stream.go -package=mocks . StreamDataProvider
|
||||||
|
type StreamDataProvider interface {
|
||||||
|
MarketInfoProvider
|
||||||
|
AccountBalanceProvider
|
||||||
|
}
|
||||||
|
|
||||||
//go:generate callbackgen -type Stream
|
//go:generate callbackgen -type Stream
|
||||||
type Stream struct {
|
type Stream struct {
|
||||||
types.StandardStream
|
types.StandardStream
|
||||||
|
|
||||||
key, secret string
|
key, secret string
|
||||||
marketProvider MarketInfoProvider
|
streamDataProvider StreamDataProvider
|
||||||
// TODO: update the fee rate at 7:00 am UTC; rotation required.
|
// TODO: update the fee rate at 7:00 am UTC; rotation required.
|
||||||
symbolFeeDetails map[string]*symbolFeeDetail
|
symbolFeeDetails map[string]*symbolFeeDetail
|
||||||
|
|
||||||
|
@ -50,13 +62,13 @@ type Stream struct {
|
||||||
tradeEventCallbacks []func(e []TradeEvent)
|
tradeEventCallbacks []func(e []TradeEvent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStream(key, secret string, marketProvider MarketInfoProvider) *Stream {
|
func NewStream(key, secret string, userDataProvider StreamDataProvider) *Stream {
|
||||||
stream := &Stream{
|
stream := &Stream{
|
||||||
StandardStream: types.NewStandardStream(),
|
StandardStream: types.NewStandardStream(),
|
||||||
// pragma: allowlist nextline secret
|
// pragma: allowlist nextline secret
|
||||||
key: key,
|
key: key,
|
||||||
secret: secret,
|
secret: secret,
|
||||||
marketProvider: marketProvider,
|
streamDataProvider: userDataProvider,
|
||||||
}
|
}
|
||||||
|
|
||||||
stream.SetEndpointCreator(stream.createEndpoint)
|
stream.SetEndpointCreator(stream.createEndpoint)
|
||||||
|
@ -65,6 +77,7 @@ func NewStream(key, secret string, marketProvider MarketInfoProvider) *Stream {
|
||||||
stream.SetHeartBeat(stream.ping)
|
stream.SetHeartBeat(stream.ping)
|
||||||
stream.SetBeforeConnect(stream.getAllFeeRates)
|
stream.SetBeforeConnect(stream.getAllFeeRates)
|
||||||
stream.OnConnect(stream.handlerConnect)
|
stream.OnConnect(stream.handlerConnect)
|
||||||
|
stream.OnAuth(stream.handleAuthEvent)
|
||||||
|
|
||||||
stream.OnBookEvent(stream.handleBookEvent)
|
stream.OnBookEvent(stream.handleBookEvent)
|
||||||
stream.OnMarketTradeEvent(stream.handleMarketTradeEvent)
|
stream.OnMarketTradeEvent(stream.handleMarketTradeEvent)
|
||||||
|
@ -326,6 +339,26 @@ func (s *Stream) convertSubscription(sub types.Subscription) (string, error) {
|
||||||
return "", fmt.Errorf("unsupported stream channel: %s", sub.Channel)
|
return "", fmt.Errorf("unsupported stream channel: %s", sub.Channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Stream) handleAuthEvent() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var balnacesMap types.BalanceMap
|
||||||
|
var err error
|
||||||
|
err = util.Retry(ctx, 10, 300*time.Millisecond, func() error {
|
||||||
|
balnacesMap, err = s.streamDataProvider.QueryAccountBalances(ctx)
|
||||||
|
return err
|
||||||
|
}, func(err error) {
|
||||||
|
log.WithError(err).Error("failed to call query account balances")
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("no more attempts to retrieve balances")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.EmitBalanceSnapshot(balnacesMap)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Stream) handleBookEvent(e BookEvent) {
|
func (s *Stream) handleBookEvent(e BookEvent) {
|
||||||
orderBook := e.OrderBook()
|
orderBook := e.OrderBook()
|
||||||
switch {
|
switch {
|
||||||
|
@ -417,7 +450,7 @@ type symbolFeeDetail struct {
|
||||||
// getAllFeeRates retrieves all fee rates from the Bybit API and then fetches markets to ensure the base coin and quote coin
|
// getAllFeeRates retrieves all fee rates from the Bybit API and then fetches markets to ensure the base coin and quote coin
|
||||||
// are correct.
|
// are correct.
|
||||||
func (e *Stream) getAllFeeRates(ctx context.Context) error {
|
func (e *Stream) getAllFeeRates(ctx context.Context) error {
|
||||||
feeRates, err := e.marketProvider.GetAllFeeRates(ctx)
|
feeRates, err := e.streamDataProvider.GetAllFeeRates(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to call get fee rates: %w", err)
|
return fmt.Errorf("failed to call get fee rates: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -429,7 +462,7 @@ func (e *Stream) getAllFeeRates(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mkts, err := e.marketProvider.QueryMarkets(ctx)
|
mkts, err := e.streamDataProvider.QueryMarkets(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get markets: %w", err)
|
return fmt.Errorf("failed to get markets: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,6 +54,9 @@ func TestStream(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("Auth test", func(t *testing.T) {
|
t.Run("Auth test", func(t *testing.T) {
|
||||||
|
s.OnBalanceSnapshot(func(balances types.BalanceMap) {
|
||||||
|
t.Log("got balance snapshot", balances)
|
||||||
|
})
|
||||||
s.Connect(context.Background())
|
s.Connect(context.Background())
|
||||||
c := make(chan struct{})
|
c := make(chan struct{})
|
||||||
<-c
|
<-c
|
||||||
|
@ -450,9 +453,9 @@ func TestStream_getFeeRate(t *testing.T) {
|
||||||
unknownErr := errors.New("unknown err")
|
unknownErr := errors.New("unknown err")
|
||||||
|
|
||||||
t.Run("succeeds", func(t *testing.T) {
|
t.Run("succeeds", func(t *testing.T) {
|
||||||
mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl)
|
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
|
||||||
s := &Stream{
|
s := &Stream{
|
||||||
marketProvider: mockMarketProvider,
|
streamDataProvider: mockMarketProvider,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -510,9 +513,9 @@ func TestStream_getFeeRate(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("failed to query markets", func(t *testing.T) {
|
t.Run("failed to query markets", func(t *testing.T) {
|
||||||
mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl)
|
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
|
||||||
s := &Stream{
|
s := &Stream{
|
||||||
marketProvider: mockMarketProvider,
|
streamDataProvider: mockMarketProvider,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -545,9 +548,9 @@ func TestStream_getFeeRate(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("failed to get fee rates", func(t *testing.T) {
|
t.Run("failed to get fee rates", func(t *testing.T) {
|
||||||
mockMarketProvider := mocks.NewMockMarketInfoProvider(mockCtrl)
|
mockMarketProvider := mocks.NewMockStreamDataProvider(mockCtrl)
|
||||||
s := &Stream{
|
s := &Stream{
|
||||||
marketProvider: mockMarketProvider,
|
streamDataProvider: mockMarketProvider,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user