diff --git a/pkg/exchange/max/maxapi/websocket.go b/pkg/exchange/max/maxapi/websocket.go index 19e9acb5b..f005416f2 100644 --- a/pkg/exchange/max/maxapi/websocket.go +++ b/pkg/exchange/max/maxapi/websocket.go @@ -3,11 +3,13 @@ package max import ( "context" "fmt" + "sync" "time" "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/pkg/errors" + log "github.com/sirupsen/logrus" ) var WebSocketURL = "wss://max-stream.maicoin.com/ws" @@ -42,6 +44,7 @@ var UnsubscribeAction = "unsubscribe" type WebSocketService struct { baseURL, key, secret string + mu sync.Mutex conn *websocket.Conn reconnectC chan struct{} @@ -90,7 +93,8 @@ func (s *WebSocketService) Connect(ctx context.Context) error { if err := s.connect(ctx); err != nil { return err } - go s.read(ctx) + + go s.reconnector(ctx) return nil } @@ -107,6 +111,9 @@ func (s *WebSocketService) Auth() error { } func (s *WebSocketService) connect(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + dialer := websocket.DefaultDialer conn, _, err := dialer.DialContext(ctx, s.baseURL, nil) if err != nil { @@ -116,6 +123,8 @@ func (s *WebSocketService) connect(ctx context.Context) error { s.conn = conn s.EmitConnect(conn) + go s.read(ctx) + return nil } @@ -126,7 +135,7 @@ func (s *WebSocketService) emitReconnect() { } } -func (s *WebSocketService) read(ctx context.Context) { +func (s *WebSocketService) reconnector(ctx context.Context) { for { select { case <-ctx.Done(): @@ -137,12 +146,29 @@ func (s *WebSocketService) read(ctx context.Context) { if err := s.connect(ctx); err != nil { s.emitReconnect() } + } + } +} + +func (s *WebSocketService) read(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return default: + s.mu.Lock() mt, msg, err := s.conn.ReadMessage() + s.mu.Unlock() if err != nil { - s.emitReconnect() + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { + // emit reconnect to start a new connection + s.emitReconnect() + return + } + + log.WithError(err).Error("websocket error") continue } @@ -219,9 +245,9 @@ func (s *WebSocketService) Reconnect() { // (Internal public method) func (s *WebSocketService) Subscribe(channel, market string, options SubscribeOptions) { s.AddSubscription(Subscription{ - Channel: channel, - Market: market, - Depth: options.Depth, + Channel: channel, + Market: market, + Depth: options.Depth, Resolution: options.Resolution, }) }