Merge pull request #1304 from bailantaotao/edwin/support-unsubscribe

FEATURE: [bybit] support unsubscribe
This commit is contained in:
bailantaotao 2023-09-08 18:22:53 +08:00 committed by GitHub
commit 439f45bdf9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 138 additions and 32 deletions

View File

@ -73,6 +73,52 @@ func NewStream(key, secret string, marketProvider MarketInfoProvider) *Stream {
return stream return stream
} }
func (s *Stream) syncSubscriptions(opType WsOpType) error {
if opType != WsOpTypeUnsubscribe && opType != WsOpTypeSubscribe {
return fmt.Errorf("unexpected subscription type: %v", opType)
}
logger := log.WithField("opType", opType)
lens := len(s.Subscriptions)
for begin := 0; begin < lens; begin += spotArgsLimit {
end := begin + spotArgsLimit
if end > lens {
end = lens
}
topics := []string{}
for _, subscription := range s.Subscriptions[begin:end] {
topic, err := s.convertSubscription(subscription)
if err != nil {
logger.WithError(err).Errorf("convert error, subscription: %+v", subscription)
return err
}
topics = append(topics, topic)
}
logger.Infof("%s channels: %+v", opType, topics)
if err := s.Conn.WriteJSON(WebsocketOp{
Op: opType,
Args: topics,
}); err != nil {
logger.WithError(err).Error("failed to send request")
return err
}
}
return nil
}
func (s *Stream) Unsubscribe() {
// errors are handled in the syncSubscriptions, so they are skipped here.
_ = s.syncSubscriptions(WsOpTypeUnsubscribe)
s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) {
// clear the subscriptions
return []types.Subscription{}, nil
})
}
func (s *Stream) createEndpoint(_ context.Context) (string, error) { func (s *Stream) createEndpoint(_ context.Context) (string, error) {
var url string var url string
if s.PublicOnly { if s.PublicOnly {
@ -205,34 +251,8 @@ func (s *Stream) ping(ctx context.Context, conn *websocket.Conn, cancelFunc cont
func (s *Stream) handlerConnect() { func (s *Stream) handlerConnect() {
if s.PublicOnly { if s.PublicOnly {
if len(s.Subscriptions) == 0 { // errors are handled in the syncSubscriptions, so they are skipped here.
log.Debug("no subscriptions") _ = s.syncSubscriptions(WsOpTypeSubscribe)
return
}
var topics []string
for _, subscription := range s.Subscriptions {
topic, err := s.convertSubscription(subscription)
if err != nil {
log.WithError(err).Errorf("subscription convert error")
continue
}
topics = append(topics, topic)
}
if len(topics) > spotArgsLimit {
log.Debugf("topics exceeds limit: %d, drop of: %v", spotArgsLimit, topics[spotArgsLimit:])
topics = topics[:spotArgsLimit]
}
log.Infof("subscribing channels: %+v", topics)
if err := s.Conn.WriteJSON(WebsocketOp{
Op: WsOpTypeSubscribe,
Args: topics,
}); err != nil {
log.WithError(err).Error("failed to send subscription request")
return
}
} else { } else {
expires := strconv.FormatInt(time.Now().Add(wsAuthRequest).In(time.UTC).UnixMilli(), 10) expires := strconv.FormatInt(time.Now().Add(wsAuthRequest).In(time.UTC).UnixMilli(), 10)

View File

@ -7,6 +7,7 @@ import (
"os" "os"
"strconv" "strconv"
"testing" "testing"
"time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -38,6 +39,20 @@ func TestStream(t *testing.T) {
t.Skip() t.Skip()
s := getTestClientOrSkip(t) s := getTestClientOrSkip(t)
symbols := []string{
"BTCUSDT",
"ETHUSDT",
"DOTUSDT",
"ADAUSDT",
"AAVEUSDT",
"APTUSDT",
"ATOMUSDT",
"AXSUSDT",
"BNBUSDT",
"SOLUSDT",
"DOGEUSDT",
}
t.Run("Auth test", func(t *testing.T) { t.Run("Auth test", func(t *testing.T) {
s.Connect(context.Background()) s.Connect(context.Background())
c := make(chan struct{}) c := make(chan struct{})
@ -62,6 +77,41 @@ func TestStream(t *testing.T) {
<-c <-c
}) })
t.Run("book test on unsubscribe and reconnect", func(t *testing.T) {
for _, symbol := range symbols {
s.Subscribe(types.BookChannel, symbol, types.SubscribeOptions{
Depth: types.DepthLevel50,
})
}
s.SetPublicOnly()
err := s.Connect(context.Background())
assert.NoError(t, err)
s.OnBookSnapshot(func(book types.SliceOrderBook) {
t.Log("got snapshot", book)
})
s.OnBookUpdate(func(book types.SliceOrderBook) {
t.Log("got update", book)
})
<-time.After(2 * time.Second)
s.Unsubscribe()
for _, symbol := range symbols {
s.Subscribe(types.BookChannel, symbol, types.SubscribeOptions{
Depth: types.DepthLevel50,
})
}
<-time.After(2 * time.Second)
s.Reconnect()
c := make(chan struct{})
<-c
})
t.Run("wallet test", func(t *testing.T) { t.Run("wallet test", func(t *testing.T) {
err := s.Connect(context.Background()) err := s.Connect(context.Background())
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -29,10 +29,11 @@ func (w *WsEvent) IsTopic() bool {
type WsOpType string type WsOpType string
const ( const (
WsOpTypePing WsOpType = "ping" WsOpTypePing WsOpType = "ping"
WsOpTypePong WsOpType = "pong" WsOpTypePong WsOpType = "pong"
WsOpTypeAuth WsOpType = "auth" WsOpTypeAuth WsOpType = "auth"
WsOpTypeSubscribe WsOpType = "subscribe" WsOpTypeSubscribe WsOpType = "subscribe"
WsOpTypeUnsubscribe WsOpType = "unsubscribe"
) )
type WebsocketOp struct { type WebsocketOp struct {
@ -73,6 +74,15 @@ func (w *WebSocketOpEvent) IsValid() error {
return fmt.Errorf("unexpected response result: %+v", w) return fmt.Errorf("unexpected response result: %+v", w)
} }
return nil return nil
case WsOpTypeUnsubscribe:
// in the public channel, you can get RetMsg = 'subscribe', but in the private channel, you cannot.
// so, we only verify that success is true.
if !w.Success {
return fmt.Errorf("unexpected response result: %+v", w)
}
return nil
default: default:
return fmt.Errorf("unexpected op type: %+v", w) return fmt.Errorf("unexpected op type: %+v", w)
} }

View File

@ -189,6 +189,19 @@ func Test_WebSocketEventIsValid(t *testing.T) {
assert.NoError(t, w.IsValid()) assert.NoError(t, w.IsValid())
}) })
t.Run("[unsubscribe] valid with public channel", func(t *testing.T) {
expRetMsg := "subscribe"
w := &WebSocketOpEvent{
Success: true,
RetMsg: expRetMsg,
ReqId: "",
ConnId: "test-conndid",
Op: WsOpTypeUnsubscribe,
Args: nil,
}
assert.NoError(t, w.IsValid())
})
t.Run("[subscribe] valid with private channel", func(t *testing.T) { t.Run("[subscribe] valid with private channel", func(t *testing.T) {
w := &WebSocketOpEvent{ w := &WebSocketOpEvent{
Success: true, Success: true,
@ -214,6 +227,19 @@ func Test_WebSocketEventIsValid(t *testing.T) {
assert.Equal(t, fmt.Errorf("unexpected response result: %+v", w), w.IsValid()) assert.Equal(t, fmt.Errorf("unexpected response result: %+v", w), w.IsValid())
}) })
t.Run("[unsubscribe] un-succeeds", func(t *testing.T) {
expRetMsg := ""
w := &WebSocketOpEvent{
Success: false,
RetMsg: expRetMsg,
ReqId: "",
ConnId: "test-conndid",
Op: WsOpTypeUnsubscribe,
Args: nil,
}
assert.Equal(t, fmt.Errorf("unexpected response result: %+v", w), w.IsValid())
})
t.Run("[auth] valid", func(t *testing.T) { t.Run("[auth] valid", func(t *testing.T) {
w := &WebSocketOpEvent{ w := &WebSocketOpEvent{
Success: true, Success: true,