pkg/exchange: support unsubscribe and resubscribe

This commit is contained in:
Edwin 2024-01-10 13:56:17 +08:00
parent 08fa1e9e13
commit 1dedd32f42
3 changed files with 94 additions and 3 deletions

View File

@ -101,8 +101,10 @@ func parseWebSocketEvent(in []byte) (interface{}, error) {
type WsEventType string
const (
WsEventTypeLogin = "login"
WsEventTypeError = "error"
WsEventTypeLogin = "login"
WsEventTypeError = "error"
WsEventTypeSubscribe = "subscribe"
WsEventTypeUnsubscribe = "unsubscribe"
)
type WebSocketEvent struct {
@ -122,6 +124,9 @@ func (w *WebSocketEvent) IsValid() error {
case WsEventTypeError:
return fmt.Errorf("websocket request error, code: %s, msg: %s", w.Code, w.Message)
case WsEventTypeSubscribe, WsEventTypeUnsubscribe:
return nil
case WsEventTypeLogin:
// Actually, this code is unnecessary because the events are either `Subscribe` or `Unsubscribe`, But to avoid bugs
// in the exchange, we still check.

View File

@ -2,6 +2,7 @@ package okex
import (
"context"
"fmt"
"golang.org/x/time/rate"
"strconv"
"time"
@ -15,7 +16,7 @@ var (
)
type WebsocketOp struct {
Op string `json:"op"`
Op WsEventType `json:"op"`
Args interface{} `json:"args"`
}
@ -60,6 +61,44 @@ func NewStream(client *okexapi.RestClient) *Stream {
return stream
}
func (s *Stream) syncSubscriptions(opType WsEventType) error {
if opType != WsEventTypeUnsubscribe && opType != WsEventTypeSubscribe {
return fmt.Errorf("unexpected subscription type: %v", opType)
}
logger := log.WithField("opType", opType)
var topics []WebsocketSubscription
for _, subscription := range s.Subscriptions {
topic, err := convertSubscription(subscription)
if err != nil {
logger.WithError(err).Errorf("convert error, subscription: %+v", subscription)
return err
}
topics = append(topics, topic)
}
logger.Infof("%s channels: %+v", opType, topics)
if err := s.Conn.WriteJSON(WebsocketOp{
Op: opType,
Args: topics,
}); err != nil {
logger.WithError(err).Error("failed to send request")
return err
}
return nil
}
func (s *Stream) Unsubscribe() {
// errors are handled in the syncSubscriptions, so they are skipped here.
_ = s.syncSubscriptions(WsEventTypeUnsubscribe)
s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) {
// clear the subscriptions
return []types.Subscription{}, nil
})
}
func (s *Stream) handleConnect() {
if s.PublicOnly {
var subs []WebsocketSubscription

View File

@ -5,6 +5,7 @@ import (
"os"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -93,4 +94,50 @@ func TestStream(t *testing.T) {
c := make(chan struct{})
<-c
})
t.Run("Subscribe/Unsubscribe test", func(t *testing.T) {
s.Subscribe(types.BookChannel, "BTCUSDT", types.SubscribeOptions{
Depth: types.DepthLevel50,
})
s.SetPublicOnly()
err := s.Connect(context.Background())
assert.NoError(t, err)
s.OnBookSnapshot(func(book types.SliceOrderBook) {
t.Log("got snapshot", book)
})
s.OnBookUpdate(func(book types.SliceOrderBook) {
t.Log("got update", book)
})
<-time.After(5 * time.Second)
s.Unsubscribe()
c := make(chan struct{})
<-c
})
t.Run("Resubscribe test", func(t *testing.T) {
s.Subscribe(types.BookChannel, "BTCUSDT", types.SubscribeOptions{
Depth: types.DepthLevel50,
})
s.SetPublicOnly()
err := s.Connect(context.Background())
assert.NoError(t, err)
s.OnBookSnapshot(func(book types.SliceOrderBook) {
t.Log("got snapshot", book)
})
s.OnBookUpdate(func(book types.SliceOrderBook) {
t.Log("got update", book)
})
<-time.After(5 * time.Second)
s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) {
return old, nil
})
c := make(chan struct{})
<-c
})
}