pkg/exchange: emit balance snapshot after authed

This commit is contained in:
Edwin 2024-01-17 11:34:15 +08:00
parent 735123b3a2
commit c5d2047605
3 changed files with 45 additions and 17 deletions

View File

@ -375,7 +375,7 @@ func (e *Exchange) CancelOrders(ctx context.Context, orders ...types.Order) erro
} }
func (e *Exchange) NewStream() types.Stream { func (e *Exchange) NewStream() types.Stream {
return NewStream(e.client) return NewStream(e.client, e)
} }
func (e *Exchange) QueryKLines(ctx context.Context, symbol string, interval types.Interval, options types.KLineQueryOptions) ([]types.KLine, error) { func (e *Exchange) QueryKLines(ctx context.Context, symbol string, interval types.Interval, options types.KLineQueryOptions) ([]types.KLine, error) {

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/c9s/bbgo/pkg/exchange/okex/okexapi" "github.com/c9s/bbgo/pkg/exchange/okex/okexapi"
"github.com/c9s/bbgo/pkg/exchange/retry"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
@ -31,7 +32,8 @@ type WebsocketLogin struct {
type Stream struct { type Stream struct {
types.StandardStream types.StandardStream
client *okexapi.RestClient client *okexapi.RestClient
balanceProvider types.ExchangeAccountService
// public callbacks // public callbacks
kLineEventCallbacks []func(candle KLineEvent) kLineEventCallbacks []func(candle KLineEvent)
@ -41,10 +43,11 @@ type Stream struct {
marketTradeEventCallbacks []func(tradeDetail []MarketTradeEvent) marketTradeEventCallbacks []func(tradeDetail []MarketTradeEvent)
} }
func NewStream(client *okexapi.RestClient) *Stream { func NewStream(client *okexapi.RestClient, balanceProvider types.ExchangeAccountService) *Stream {
stream := &Stream{ stream := &Stream{
client: client, client: client,
StandardStream: types.NewStandardStream(), balanceProvider: balanceProvider,
StandardStream: types.NewStandardStream(),
} }
stream.SetParser(parseWebSocketEvent) stream.SetParser(parseWebSocketEvent)
@ -57,7 +60,7 @@ func NewStream(client *okexapi.RestClient) *Stream {
stream.OnMarketTradeEvent(stream.handleMarketTradeEvent) stream.OnMarketTradeEvent(stream.handleMarketTradeEvent)
stream.OnOrderDetailsEvent(stream.handleOrderDetailsEvent) stream.OnOrderDetailsEvent(stream.handleOrderDetailsEvent)
stream.OnConnect(stream.handleConnect) stream.OnConnect(stream.handleConnect)
stream.OnAuth(stream.handleAuth) stream.OnAuth(stream.subscribePrivateChannels(stream.emitBalanceSnapshot))
return stream return stream
} }
@ -151,20 +154,42 @@ func (s *Stream) handleConnect() {
} }
} }
func (s *Stream) handleAuth() { func (s *Stream) subscribePrivateChannels(next func()) func() {
var subs = []WebsocketSubscription{ return func() {
{Channel: ChannelAccount}, var subs = []WebsocketSubscription{
{Channel: "orders", InstrumentType: string(okexapi.InstrumentTypeSpot)}, {Channel: ChannelAccount},
} {Channel: "orders", InstrumentType: string(okexapi.InstrumentTypeSpot)},
}
log.Infof("subscribing private channels: %+v", subs) log.Infof("subscribing private channels: %+v", subs)
err := s.Conn.WriteJSON(WebsocketOp{ err := s.Conn.WriteJSON(WebsocketOp{
Op: "subscribe", Op: "subscribe",
Args: subs, Args: subs,
})
if err != nil {
log.WithError(err).Error("private channel subscribe error")
return
}
next()
}
}
func (s *Stream) emitBalanceSnapshot() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
var balancesMap types.BalanceMap
var err error
err = retry.GeneralBackoff(ctx, func() error {
balancesMap, err = s.balanceProvider.QueryAccountBalances(ctx)
return err
}) })
if err != nil { if err != nil {
log.WithError(err).Error("private channel subscribe error") log.WithError(err).Error("no more attempts to retrieve balances")
return
} }
s.EmitBalanceSnapshot(balancesMap)
} }
func (s *Stream) handleOrderDetailsEvent(orderDetails []okexapi.OrderDetails) { func (s *Stream) handleOrderDetailsEvent(orderDetails []okexapi.OrderDetails) {

View File

@ -25,7 +25,7 @@ func getTestClientOrSkip(t *testing.T) *Stream {
} }
exchange := New(key, secret, passphrase) exchange := New(key, secret, passphrase)
return NewStream(exchange.client) return NewStream(exchange.client, exchange)
} }
func TestStream(t *testing.T) { func TestStream(t *testing.T) {
@ -39,6 +39,9 @@ func TestStream(t *testing.T) {
s.OnBalanceUpdate(func(balances types.BalanceMap) { s.OnBalanceUpdate(func(balances types.BalanceMap) {
t.Log("got snapshot", balances) t.Log("got snapshot", balances)
}) })
s.OnBalanceSnapshot(func(balances types.BalanceMap) {
t.Log("got snapshot", balances)
})
s.OnBookUpdate(func(book types.SliceOrderBook) { s.OnBookUpdate(func(book types.SliceOrderBook) {
t.Log("got update", book) t.Log("got update", book)
}) })