diff --git a/pkg/exchange/okex/parse.go b/pkg/exchange/okex/parse.go index 193518474..c09e21c15 100644 --- a/pkg/exchange/okex/parse.go +++ b/pkg/exchange/okex/parse.go @@ -101,8 +101,10 @@ func parseWebSocketEvent(in []byte) (interface{}, error) { type WsEventType string const ( - WsEventTypeLogin = "login" - WsEventTypeError = "error" + WsEventTypeLogin = "login" + WsEventTypeError = "error" + WsEventTypeSubscribe = "subscribe" + WsEventTypeUnsubscribe = "unsubscribe" ) type WebSocketEvent struct { @@ -122,6 +124,9 @@ func (w *WebSocketEvent) IsValid() error { case WsEventTypeError: return fmt.Errorf("websocket request error, code: %s, msg: %s", w.Code, w.Message) + case WsEventTypeSubscribe, WsEventTypeUnsubscribe: + return nil + case WsEventTypeLogin: // Actually, this code is unnecessary because the events are either `Subscribe` or `Unsubscribe`, But to avoid bugs // in the exchange, we still check. diff --git a/pkg/exchange/okex/stream.go b/pkg/exchange/okex/stream.go index c27fedbfa..96b955c73 100644 --- a/pkg/exchange/okex/stream.go +++ b/pkg/exchange/okex/stream.go @@ -2,6 +2,7 @@ package okex import ( "context" + "fmt" "golang.org/x/time/rate" "strconv" "time" @@ -15,7 +16,7 @@ var ( ) type WebsocketOp struct { - Op string `json:"op"` + Op WsEventType `json:"op"` Args interface{} `json:"args"` } @@ -60,6 +61,44 @@ func NewStream(client *okexapi.RestClient) *Stream { return stream } +func (s *Stream) syncSubscriptions(opType WsEventType) error { + if opType != WsEventTypeUnsubscribe && opType != WsEventTypeSubscribe { + return fmt.Errorf("unexpected subscription type: %v", opType) + } + + logger := log.WithField("opType", opType) + var topics []WebsocketSubscription + for _, subscription := range s.Subscriptions { + topic, err := convertSubscription(subscription) + if err != nil { + logger.WithError(err).Errorf("convert error, subscription: %+v", subscription) + return err + } + + topics = append(topics, topic) + } + + logger.Infof("%s channels: %+v", opType, topics) + if err := s.Conn.WriteJSON(WebsocketOp{ + Op: opType, + Args: topics, + }); err != nil { + logger.WithError(err).Error("failed to send request") + return err + } + + return nil +} + +func (s *Stream) Unsubscribe() { + // errors are handled in the syncSubscriptions, so they are skipped here. + _ = s.syncSubscriptions(WsEventTypeUnsubscribe) + s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) { + // clear the subscriptions + return []types.Subscription{}, nil + }) +} + func (s *Stream) handleConnect() { if s.PublicOnly { var subs []WebsocketSubscription diff --git a/pkg/exchange/okex/stream_test.go b/pkg/exchange/okex/stream_test.go index cf0125ded..a832767cf 100644 --- a/pkg/exchange/okex/stream_test.go +++ b/pkg/exchange/okex/stream_test.go @@ -5,6 +5,7 @@ import ( "os" "strconv" "testing" + "time" "github.com/stretchr/testify/assert" @@ -93,4 +94,50 @@ func TestStream(t *testing.T) { c := make(chan struct{}) <-c }) + + t.Run("Subscribe/Unsubscribe test", func(t *testing.T) { + s.Subscribe(types.BookChannel, "BTCUSDT", types.SubscribeOptions{ + Depth: types.DepthLevel50, + }) + s.SetPublicOnly() + err := s.Connect(context.Background()) + assert.NoError(t, err) + + s.OnBookSnapshot(func(book types.SliceOrderBook) { + t.Log("got snapshot", book) + }) + s.OnBookUpdate(func(book types.SliceOrderBook) { + t.Log("got update", book) + }) + + <-time.After(5 * time.Second) + + s.Unsubscribe() + c := make(chan struct{}) + <-c + }) + + t.Run("Resubscribe test", func(t *testing.T) { + s.Subscribe(types.BookChannel, "BTCUSDT", types.SubscribeOptions{ + Depth: types.DepthLevel50, + }) + s.SetPublicOnly() + err := s.Connect(context.Background()) + assert.NoError(t, err) + + s.OnBookSnapshot(func(book types.SliceOrderBook) { + t.Log("got snapshot", book) + }) + s.OnBookUpdate(func(book types.SliceOrderBook) { + t.Log("got update", book) + }) + + <-time.After(5 * time.Second) + + s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) { + return old, nil + }) + c := make(chan struct{}) + <-c + }) }