Move ReconnectC to the StandardStream

This commit is contained in:
c9s 2021-05-27 14:42:14 +08:00
parent 8c50ce725c
commit 18045bb1e7
5 changed files with 105 additions and 49 deletions

View File

@ -42,11 +42,10 @@ type Stream struct {
types.StandardStream types.StandardStream
Client *binance.Client Client *binance.Client
ListenKey string ListenKey string
Conn *websocket.Conn Conn *websocket.Conn
connLock sync.Mutex ConnLock sync.Mutex
reconnectC chan struct{}
connCtx context.Context connCtx context.Context
connCancel context.CancelFunc connCancel context.CancelFunc
@ -68,9 +67,11 @@ type Stream struct {
func NewStream(client *binance.Client) *Stream { func NewStream(client *binance.Client) *Stream {
stream := &Stream{ stream := &Stream{
StandardStream: types.StandardStream{
ReconnectC: make(chan struct{}, 1),
},
Client: client, Client: client,
depthFrames: make(map[string]*DepthFrame), depthFrames: make(map[string]*DepthFrame),
reconnectC: make(chan struct{}, 1),
} }
stream.OnDepthEvent(func(e *DepthEvent) { stream.OnDepthEvent(func(e *DepthEvent) {
@ -274,13 +275,6 @@ func (s *Stream) keepaliveListenKey(ctx context.Context, listenKey string) error
return s.Client.NewKeepaliveUserStreamService().ListenKey(listenKey).Do(ctx) return s.Client.NewKeepaliveUserStreamService().ListenKey(listenKey).Do(ctx)
} }
func (s *Stream) emitReconnect() {
select {
case s.reconnectC <- struct{}{}:
default:
}
}
func (s *Stream) Connect(ctx context.Context) error { func (s *Stream) Connect(ctx context.Context) error {
err := s.connect(ctx) err := s.connect(ctx)
if err != nil { if err != nil {
@ -300,7 +294,7 @@ func (s *Stream) reconnector(ctx context.Context) {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-s.reconnectC: case <-s.ReconnectC:
// ensure the previous context is cancelled // ensure the previous context is cancelled
if s.connCancel != nil { if s.connCancel != nil {
s.connCancel() s.connCancel()
@ -311,7 +305,7 @@ func (s *Stream) reconnector(ctx context.Context) {
if err := s.connect(ctx); err != nil { if err := s.connect(ctx); err != nil {
log.WithError(err).Errorf("connect error, try to reconnect again...") log.WithError(err).Errorf("connect error, try to reconnect again...")
s.emitReconnect() s.Reconnect()
} }
} }
} }
@ -319,7 +313,7 @@ func (s *Stream) reconnector(ctx context.Context) {
func (s *Stream) connect(ctx context.Context) error { func (s *Stream) connect(ctx context.Context) error {
// should only start one connection one time, so we lock the mutex // should only start one connection one time, so we lock the mutex
s.connLock.Lock() s.ConnLock.Lock()
// create a new context // create a new context
s.connCtx, s.connCancel = context.WithCancel(ctx) s.connCtx, s.connCancel = context.WithCancel(ctx)
@ -332,7 +326,7 @@ func (s *Stream) connect(ctx context.Context) error {
listenKey, err := s.fetchListenKey(ctx) listenKey, err := s.fetchListenKey(ctx)
if err != nil { if err != nil {
s.connCancel() s.connCancel()
s.connLock.Unlock() s.ConnLock.Unlock()
return err return err
} }
@ -346,14 +340,14 @@ func (s *Stream) connect(ctx context.Context) error {
conn, err := s.dial(s.ListenKey) conn, err := s.dial(s.ListenKey)
if err != nil { if err != nil {
s.connCancel() s.connCancel()
s.connLock.Unlock() s.ConnLock.Unlock()
return err return err
} }
log.Infof("websocket connected") log.Infof("websocket connected")
s.Conn = conn s.Conn = conn
s.connLock.Unlock() s.ConnLock.Unlock()
s.EmitConnect() s.EmitConnect()
@ -374,12 +368,12 @@ func (s *Stream) ping(ctx context.Context) {
return return
case <-pingTicker.C: case <-pingTicker.C:
s.connLock.Lock() s.ConnLock.Lock()
if err := s.Conn.WriteControl(websocket.PingMessage, []byte("hb"), time.Now().Add(3*time.Second)); err != nil { if err := s.Conn.WriteControl(websocket.PingMessage, []byte("hb"), time.Now().Add(3*time.Second)); err != nil {
log.WithError(err).Error("ping error", err) log.WithError(err).Error("ping error", err)
s.emitReconnect() s.Reconnect()
} }
s.connLock.Unlock() s.ConnLock.Unlock()
} }
} }
} }
@ -405,7 +399,7 @@ func (s *Stream) listenKeyKeepAlive(ctx context.Context, listenKey string) {
case <-keepAliveTicker.C: case <-keepAliveTicker.C:
if err := s.keepaliveListenKey(ctx, listenKey); err != nil { if err := s.keepaliveListenKey(ctx, listenKey); err != nil {
log.WithError(err).Errorf("listen key keep-alive error: %v key: %s", err, maskListenKey(listenKey)) log.WithError(err).Errorf("listen key keep-alive error: %v key: %s", err, maskListenKey(listenKey))
s.emitReconnect() s.Reconnect()
return return
} }
@ -428,13 +422,13 @@ func (s *Stream) read(ctx context.Context) {
return return
default: default:
s.connLock.Lock() s.ConnLock.Lock()
if err := s.Conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := s.Conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.WithError(err).Errorf("set read deadline error: %s", err.Error()) log.WithError(err).Errorf("set read deadline error: %s", err.Error())
} }
mt, message, err := s.Conn.ReadMessage() mt, message, err := s.Conn.ReadMessage()
s.connLock.Unlock() s.ConnLock.Unlock()
if err != nil { if err != nil {
// if it's a network timeout error, we should re-connect // if it's a network timeout error, we should re-connect
@ -448,17 +442,17 @@ func (s *Stream) read(ctx context.Context) {
// for unexpected close error, we should re-connect // for unexpected close error, we should re-connect
// emit reconnect to start a new connection // emit reconnect to start a new connection
s.emitReconnect() s.Reconnect()
return return
case net.Error: case net.Error:
log.WithError(err).Error("network error") log.WithError(err).Error("network error")
s.emitReconnect() s.Reconnect()
return return
default: default:
log.WithError(err).Error("unexpected connection error") log.WithError(err).Error("unexpected connection error")
s.emitReconnect() s.Reconnect()
return return
} }
} }
@ -538,9 +532,9 @@ func (s *Stream) Close() error {
s.connCancel() s.connCancel()
} }
s.connLock.Lock() s.ConnLock.Lock()
err := s.Conn.Close() err := s.Conn.Close()
s.connLock.Unlock() s.ConnLock.Unlock()
return err return err
} }

View File

@ -1,6 +1,7 @@
package okex package okex
import ( import (
"fmt"
"strings" "strings"
"github.com/c9s/bbgo/pkg/exchange/okex/okexapi" "github.com/c9s/bbgo/pkg/exchange/okex/okexapi"
@ -48,3 +49,29 @@ func toGlobalBalance(balanceSummaries []okexapi.BalanceSummary) types.BalanceMap
} }
return balanceMap return balanceMap
} }
type WebsocketSubscription struct {
Channel string `json:"channel"`
InstrumentID string `json:"instId"`
}
func convertSubscription(s types.Subscription) (WebsocketSubscription, error) {
// binance uses lower case symbol name,
// for kline, it's "<symbol>@kline_<interval>"
// for depth, it's "<symbol>@depth OR <symbol>@depth@100ms"
switch s.Channel {
case types.KLineChannel:
return WebsocketSubscription{
Channel: "candle" + s.Options.Interval,
InstrumentID: toLocalSymbol(s.Symbol),
}, nil
case types.BookChannel:
return WebsocketSubscription{
Channel: "books",
InstrumentID: toLocalSymbol(s.Symbol),
}, nil
}
return WebsocketSubscription{}, fmt.Errorf("unsupported public stream channel %s", s.Channel)
}

View File

@ -0,0 +1 @@
package okex

View File

@ -11,16 +11,18 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
type WebsocketOp struct {
Op string `json:"op"`
Args interface{} `json:"args"`
}
//go:generate callbackgen -type Stream -interface //go:generate callbackgen -type Stream -interface
type Stream struct { type Stream struct {
types.StandardStream types.StandardStream
Client *okexapi.RestClient Client *okexapi.RestClient
ListenKey string
Conn *websocket.Conn Conn *websocket.Conn
connLock sync.Mutex connLock sync.Mutex
reconnectC chan struct{}
connCtx context.Context connCtx context.Context
connCancel context.CancelFunc connCancel context.CancelFunc
@ -31,9 +33,39 @@ type Stream struct {
func NewStream(client *okexapi.RestClient) *Stream { func NewStream(client *okexapi.RestClient) *Stream {
stream := &Stream{ stream := &Stream{
Client: client, Client: client,
reconnectC: make(chan struct{}, 1), StandardStream: types.StandardStream{
ReconnectC: make(chan struct{}, 1),
},
} }
stream.OnConnect(func() {
var subs []WebsocketSubscription
for _, subscription := range stream.Subscriptions {
sub, err := convertSubscription(subscription)
if err != nil {
log.WithError(err).Errorf("subscription convert error")
continue
}
subs = append(subs, sub)
}
if len(subs) == 0 {
return
}
log.Infof("subscribing channels: %+v", subs)
err := stream.Conn.WriteJSON(WebsocketOp{
Op: "subscribe",
Args: subs,
})
if err != nil {
log.WithError(err).Error("subscribe error")
}
})
return stream return stream
} }
@ -58,20 +90,13 @@ func (s *Stream) Connect(ctx context.Context) error {
return nil return nil
} }
func (s *Stream) emitReconnect() {
select {
case s.reconnectC <- struct{}{}:
default:
}
}
func (s *Stream) reconnector(ctx context.Context) { func (s *Stream) reconnector(ctx context.Context) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-s.reconnectC: case <-s.ReconnectC:
// ensure the previous context is cancelled // ensure the previous context is cancelled
if s.connCancel != nil { if s.connCancel != nil {
s.connCancel() s.connCancel()
@ -82,7 +107,7 @@ func (s *Stream) reconnector(ctx context.Context) {
if err := s.connect(ctx); err != nil { if err := s.connect(ctx); err != nil {
log.WithError(err).Errorf("connect error, try to reconnect again...") log.WithError(err).Errorf("connect error, try to reconnect again...")
s.emitReconnect() s.Reconnect()
} }
} }
} }
@ -175,17 +200,17 @@ func (s *Stream) read(ctx context.Context) {
// for unexpected close error, we should re-connect // for unexpected close error, we should re-connect
// emit reconnect to start a new connection // emit reconnect to start a new connection
s.emitReconnect() s.Reconnect()
return return
case net.Error: case net.Error:
log.WithError(err).Error("network error") log.WithError(err).Error("network error")
s.emitReconnect() s.Reconnect()
return return
default: default:
log.WithError(err).Error("unexpected connection error") log.WithError(err).Error("unexpected connection error")
s.emitReconnect() s.Reconnect()
return return
} }
} }
@ -195,7 +220,7 @@ func (s *Stream) read(ctx context.Context) {
continue continue
} }
log.Debug(string(message)) log.Infof(string(message))
} }
} }
} }
@ -215,7 +240,7 @@ func (s *Stream) ping(ctx context.Context) {
s.connLock.Lock() s.connLock.Lock()
if err := s.Conn.WriteControl(websocket.PingMessage, []byte("hb"), time.Now().Add(3*time.Second)); err != nil { if err := s.Conn.WriteControl(websocket.PingMessage, []byte("hb"), time.Now().Add(3*time.Second)); err != nil {
log.WithError(err).Error("ping error", err) log.WithError(err).Error("ping error", err)
s.emitReconnect() s.Reconnect()
} }
s.connLock.Unlock() s.connLock.Unlock()
} }

View File

@ -21,6 +21,8 @@ var KLineChannel = Channel("kline")
//go:generate callbackgen -type StandardStream -interface //go:generate callbackgen -type StandardStream -interface
type StandardStream struct { type StandardStream struct {
ReconnectC chan struct{}
Subscriptions []Subscription Subscriptions []Subscription
startCallbacks []func() startCallbacks []func()
@ -57,6 +59,13 @@ func (stream *StandardStream) Subscribe(channel Channel, symbol string, options
}) })
} }
func (stream *StandardStream) Reconnect() {
select {
case stream.ReconnectC <- struct{}{}:
default:
}
}
// SubscribeOptions provides the standard stream options // SubscribeOptions provides the standard stream options
type SubscribeOptions struct { type SubscribeOptions struct {
Interval string `json:"interval,omitempty"` Interval string `json:"interval,omitempty"`