diff --git a/pkg/exchange/bybit/stream.go b/pkg/exchange/bybit/stream.go index abdbbab49..5b9ced6eb 100644 --- a/pkg/exchange/bybit/stream.go +++ b/pkg/exchange/bybit/stream.go @@ -73,6 +73,52 @@ func NewStream(key, secret string, marketProvider MarketInfoProvider) *Stream { return stream } +func (s *Stream) syncSubscriptions(opType WsOpType) error { + if opType != WsOpTypeUnsubscribe && opType != WsOpTypeSubscribe { + return fmt.Errorf("unexpected subscription type: %v", opType) + } + + logger := log.WithField("opType", opType) + lens := len(s.Subscriptions) + for begin := 0; begin < lens; begin += spotArgsLimit { + end := begin + spotArgsLimit + if end > lens { + end = lens + } + + topics := []string{} + for _, subscription := range s.Subscriptions[begin:end] { + topic, err := s.convertSubscription(subscription) + if err != nil { + logger.WithError(err).Errorf("convert error, subscription: %+v", subscription) + return err + } + + topics = append(topics, topic) + } + + logger.Infof("%s channels: %+v", opType, topics) + if err := s.Conn.WriteJSON(WebsocketOp{ + Op: opType, + Args: topics, + }); err != nil { + logger.WithError(err).Error("failed to send request") + return err + } + } + + return nil +} + +func (s *Stream) Unsubscribe() { + // errors are handled in the syncSubscriptions, so they are skipped here. + _ = s.syncSubscriptions(WsOpTypeUnsubscribe) + s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) { + // clear the subscriptions + return []types.Subscription{}, nil + }) +} + func (s *Stream) createEndpoint(_ context.Context) (string, error) { var url string if s.PublicOnly { @@ -205,34 +251,8 @@ func (s *Stream) ping(ctx context.Context, conn *websocket.Conn, cancelFunc cont func (s *Stream) handlerConnect() { if s.PublicOnly { - if len(s.Subscriptions) == 0 { - log.Debug("no subscriptions") - return - } - - var topics []string - - for _, subscription := range s.Subscriptions { - topic, err := s.convertSubscription(subscription) - if err != nil { - log.WithError(err).Errorf("subscription convert error") - continue - } - - topics = append(topics, topic) - } - if len(topics) > spotArgsLimit { - log.Debugf("topics exceeds limit: %d, drop of: %v", spotArgsLimit, topics[spotArgsLimit:]) - topics = topics[:spotArgsLimit] - } - log.Infof("subscribing channels: %+v", topics) - if err := s.Conn.WriteJSON(WebsocketOp{ - Op: WsOpTypeSubscribe, - Args: topics, - }); err != nil { - log.WithError(err).Error("failed to send subscription request") - return - } + // errors are handled in the syncSubscriptions, so they are skipped here. + _ = s.syncSubscriptions(WsOpTypeSubscribe) } else { expires := strconv.FormatInt(time.Now().Add(wsAuthRequest).In(time.UTC).UnixMilli(), 10) diff --git a/pkg/exchange/bybit/stream_test.go b/pkg/exchange/bybit/stream_test.go index 04de5d373..12e483765 100644 --- a/pkg/exchange/bybit/stream_test.go +++ b/pkg/exchange/bybit/stream_test.go @@ -7,6 +7,7 @@ import ( "os" "strconv" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -38,6 +39,20 @@ func TestStream(t *testing.T) { t.Skip() s := getTestClientOrSkip(t) + symbols := []string{ + "BTCUSDT", + "ETHUSDT", + "DOTUSDT", + "ADAUSDT", + "AAVEUSDT", + "APTUSDT", + "ATOMUSDT", + "AXSUSDT", + "BNBUSDT", + "SOLUSDT", + "DOGEUSDT", + } + t.Run("Auth test", func(t *testing.T) { s.Connect(context.Background()) c := make(chan struct{}) @@ -62,6 +77,41 @@ func TestStream(t *testing.T) { <-c }) + t.Run("book test on unsubscribe and reconnect", func(t *testing.T) { + for _, symbol := range symbols { + s.Subscribe(types.BookChannel, symbol, types.SubscribeOptions{ + Depth: types.DepthLevel50, + }) + } + + s.SetPublicOnly() + err := s.Connect(context.Background()) + assert.NoError(t, err) + + s.OnBookSnapshot(func(book types.SliceOrderBook) { + t.Log("got snapshot", book) + }) + s.OnBookUpdate(func(book types.SliceOrderBook) { + t.Log("got update", book) + }) + + <-time.After(2 * time.Second) + + s.Unsubscribe() + for _, symbol := range symbols { + s.Subscribe(types.BookChannel, symbol, types.SubscribeOptions{ + Depth: types.DepthLevel50, + }) + } + + <-time.After(2 * time.Second) + + s.Reconnect() + + c := make(chan struct{}) + <-c + }) + t.Run("wallet test", func(t *testing.T) { err := s.Connect(context.Background()) assert.NoError(t, err) diff --git a/pkg/exchange/bybit/types.go b/pkg/exchange/bybit/types.go index 04e63e3ca..c5acd4916 100644 --- a/pkg/exchange/bybit/types.go +++ b/pkg/exchange/bybit/types.go @@ -29,10 +29,11 @@ func (w *WsEvent) IsTopic() bool { type WsOpType string const ( - WsOpTypePing WsOpType = "ping" - WsOpTypePong WsOpType = "pong" - WsOpTypeAuth WsOpType = "auth" - WsOpTypeSubscribe WsOpType = "subscribe" + WsOpTypePing WsOpType = "ping" + WsOpTypePong WsOpType = "pong" + WsOpTypeAuth WsOpType = "auth" + WsOpTypeSubscribe WsOpType = "subscribe" + WsOpTypeUnsubscribe WsOpType = "unsubscribe" ) type WebsocketOp struct { @@ -73,6 +74,15 @@ func (w *WebSocketOpEvent) IsValid() error { return fmt.Errorf("unexpected response result: %+v", w) } return nil + + case WsOpTypeUnsubscribe: + // in the public channel, you can get RetMsg = 'subscribe', but in the private channel, you cannot. + // so, we only verify that success is true. + if !w.Success { + return fmt.Errorf("unexpected response result: %+v", w) + } + return nil + default: return fmt.Errorf("unexpected op type: %+v", w) } diff --git a/pkg/exchange/bybit/types_test.go b/pkg/exchange/bybit/types_test.go index 3601690fb..21fc04f4f 100644 --- a/pkg/exchange/bybit/types_test.go +++ b/pkg/exchange/bybit/types_test.go @@ -189,6 +189,19 @@ func Test_WebSocketEventIsValid(t *testing.T) { assert.NoError(t, w.IsValid()) }) + t.Run("[unsubscribe] valid with public channel", func(t *testing.T) { + expRetMsg := "subscribe" + w := &WebSocketOpEvent{ + Success: true, + RetMsg: expRetMsg, + ReqId: "", + ConnId: "test-conndid", + Op: WsOpTypeUnsubscribe, + Args: nil, + } + assert.NoError(t, w.IsValid()) + }) + t.Run("[subscribe] valid with private channel", func(t *testing.T) { w := &WebSocketOpEvent{ Success: true, @@ -214,6 +227,19 @@ func Test_WebSocketEventIsValid(t *testing.T) { assert.Equal(t, fmt.Errorf("unexpected response result: %+v", w), w.IsValid()) }) + t.Run("[unsubscribe] un-succeeds", func(t *testing.T) { + expRetMsg := "" + w := &WebSocketOpEvent{ + Success: false, + RetMsg: expRetMsg, + ReqId: "", + ConnId: "test-conndid", + Op: WsOpTypeUnsubscribe, + Args: nil, + } + assert.Equal(t, fmt.Errorf("unexpected response result: %+v", w), w.IsValid()) + }) + t.Run("[auth] valid", func(t *testing.T) { w := &WebSocketOpEvent{ Success: true,