pkg/exchange: support kline subscriptions

This commit is contained in:
Edwin 2024-01-30 12:17:50 +08:00
parent e3208b34fa
commit f0ad014837
3 changed files with 92 additions and 29 deletions

View File

@ -47,19 +47,15 @@ func (s *KLineStream) handleConnect() {
subs = append(subs, sub) subs = append(subs, sub)
} }
if len(subs) == 0 { subscribe(s.Conn, subs)
return }
}
log.Infof("subscribing channels: %+v", subs) func (s *KLineStream) Connect(ctx context.Context) error {
err := s.Conn.WriteJSON(WebsocketOp{ if len(s.StandardStream.Subscriptions) == 0 {
Op: "subscribe", log.Info("no subscriptions in kline")
Args: subs, return nil
})
if err != nil {
log.WithError(err).Error("subscribe error")
} }
return s.StandardStream.Connect(ctx)
} }
func (s *KLineStream) handleKLineEvent(k KLineEvent) { func (s *KLineStream) handleKLineEvent(k KLineEvent) {
@ -85,3 +81,14 @@ func (s *KLineStream) dispatchEvent(e interface{}) {
s.EmitKLineEvent(*et) s.EmitKLineEvent(*et)
} }
} }
func (s *KLineStream) Unsubscribe() {
// errors are handled in the syncSubscriptions, so they are skipped here.
if len(s.StandardStream.Subscriptions) != 0 {
_ = syncSubscriptions(s.StandardStream.Conn, s.StandardStream.Subscriptions, WsEventTypeUnsubscribe)
}
s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) {
// clear the subscriptions
return []types.Subscription{}, nil
})
}

View File

@ -3,10 +3,12 @@ package okex
import ( import (
"context" "context"
"fmt" "fmt"
"golang.org/x/time/rate"
"strconv" "strconv"
"time" "time"
"github.com/gorilla/websocket"
"golang.org/x/time/rate"
"github.com/c9s/bbgo/pkg/exchange/okex/okexapi" "github.com/c9s/bbgo/pkg/exchange/okex/okexapi"
"github.com/c9s/bbgo/pkg/exchange/retry" "github.com/c9s/bbgo/pkg/exchange/retry"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
@ -67,18 +69,20 @@ func NewStream(client *okexapi.RestClient, balanceProvider types.ExchangeAccount
stream.OnOrderTradesEvent(stream.handleOrderDetailsEvent) stream.OnOrderTradesEvent(stream.handleOrderDetailsEvent)
stream.OnConnect(stream.handleConnect) stream.OnConnect(stream.handleConnect)
stream.OnAuth(stream.subscribePrivateChannels(stream.emitBalanceSnapshot)) stream.OnAuth(stream.subscribePrivateChannels(stream.emitBalanceSnapshot))
stream.kLineStream.OnKLineClosed(stream.EmitKLineClosed)
stream.kLineStream.OnKLine(stream.EmitKLine)
return stream return stream
} }
func (s *Stream) syncSubscriptions(opType WsEventType) error { func syncSubscriptions(conn *websocket.Conn, subscriptions []types.Subscription, opType WsEventType) error {
if opType != WsEventTypeUnsubscribe && opType != WsEventTypeSubscribe { if opType != WsEventTypeUnsubscribe && opType != WsEventTypeSubscribe {
return fmt.Errorf("unexpected subscription type: %v", opType) return fmt.Errorf("unexpected subscription type: %v", opType)
} }
logger := log.WithField("opType", opType) logger := log.WithField("opType", opType)
var topics []WebsocketSubscription var topics []WebsocketSubscription
for _, subscription := range s.Subscriptions { for _, subscription := range subscriptions {
topic, err := convertSubscription(subscription) topic, err := convertSubscription(subscription)
if err != nil { if err != nil {
logger.WithError(err).Errorf("convert error, subscription: %+v", subscription) logger.WithError(err).Errorf("convert error, subscription: %+v", subscription)
@ -89,7 +93,7 @@ func (s *Stream) syncSubscriptions(opType WsEventType) error {
} }
logger.Infof("%s channels: %+v", opType, topics) logger.Infof("%s channels: %+v", opType, topics)
if err := s.Conn.WriteJSON(WebsocketOp{ if err := conn.WriteJSON(WebsocketOp{
Op: opType, Op: opType,
Args: topics, Args: topics,
}); err != nil { }); err != nil {
@ -102,11 +106,47 @@ func (s *Stream) syncSubscriptions(opType WsEventType) error {
func (s *Stream) Unsubscribe() { func (s *Stream) Unsubscribe() {
// errors are handled in the syncSubscriptions, so they are skipped here. // errors are handled in the syncSubscriptions, so they are skipped here.
_ = s.syncSubscriptions(WsEventTypeUnsubscribe) _ = syncSubscriptions(s.StandardStream.Conn, s.StandardStream.Subscriptions, WsEventTypeUnsubscribe)
s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) { s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) {
// clear the subscriptions // clear the subscriptions
return []types.Subscription{}, nil return []types.Subscription{}, nil
}) })
s.kLineStream.Unsubscribe()
}
func (s *Stream) Connect(ctx context.Context) error {
if err := s.StandardStream.Connect(ctx); err != nil {
return err
}
if err := s.kLineStream.Connect(ctx); err != nil {
return err
}
return nil
}
func (s *Stream) Subscribe(channel types.Channel, symbol string, options types.SubscribeOptions) {
if channel == types.KLineChannel {
s.kLineStream.Subscribe(channel, symbol, options)
} else {
s.StandardStream.Subscribe(channel, symbol, options)
}
}
func subscribe(conn *websocket.Conn, subs []WebsocketSubscription) {
if len(subs) == 0 {
return
}
log.Infof("subscribing channels: %+v", subs)
err := conn.WriteJSON(WebsocketOp{
Op: "subscribe",
Args: subs,
})
if err != nil {
log.WithError(err).Error("subscribe error")
}
} }
func (s *Stream) handleConnect() { func (s *Stream) handleConnect() {
@ -121,19 +161,7 @@ func (s *Stream) handleConnect() {
subs = append(subs, sub) subs = append(subs, sub)
} }
if len(subs) == 0 { subscribe(s.StandardStream.Conn, subs)
return
}
log.Infof("subscribing channels: %+v", subs)
err := s.Conn.WriteJSON(WebsocketOp{
Op: "subscribe",
Args: subs,
})
if err != nil {
log.WithError(err).Error("subscribe error")
}
} else { } else {
// login as private channel // login as private channel
// sign example: // sign example:

View File

@ -67,6 +67,34 @@ func TestStream(t *testing.T) {
<-c <-c
}) })
t.Run("book && kline test", func(t *testing.T) {
s.Subscribe(types.BookChannel, "BTCUSDT", types.SubscribeOptions{
Depth: types.DepthLevel400,
})
s.Subscribe(types.KLineChannel, "BTCUSDT", types.SubscribeOptions{
Interval: types.Interval1m,
})
s.SetPublicOnly()
err := s.Connect(context.Background())
assert.NoError(t, err)
s.OnBookSnapshot(func(book types.SliceOrderBook) {
t.Log("got snapshot", book)
})
s.OnBookUpdate(func(book types.SliceOrderBook) {
t.Log("got update", book)
})
s.OnKLine(func(kline types.KLine) {
t.Log("kline", kline)
})
s.OnKLineClosed(func(kline types.KLine) {
t.Log("kline closed", kline)
})
c := make(chan struct{})
<-c
})
t.Run("market trade test", func(t *testing.T) { t.Run("market trade test", func(t *testing.T) {
s.Subscribe(types.MarketTradeChannel, "BTCUSDT", types.SubscribeOptions{}) s.Subscribe(types.MarketTradeChannel, "BTCUSDT", types.SubscribeOptions{})
s.SetPublicOnly() s.SetPublicOnly()