mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-25 16:25:16 +00:00
369 lines
9.1 KiB
Go
369 lines
9.1 KiB
Go
package websocket
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
|
|
"github.com/c9s/bbgo/pkg/log"
|
|
"github.com/c9s/bbgo/pkg/util"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
const DefaultMessageBufferSize = 128
|
|
|
|
const DefaultWriteTimeout = 30 * time.Second
|
|
const DefaultReadTimeout = 30 * time.Second
|
|
|
|
var ErrReconnectContextDone = errors.New("reconnect canceled due to context done.")
|
|
var ErrReconnectFailed = errors.New("failed to reconnect")
|
|
var ErrConnectionLost = errors.New("connection lost")
|
|
|
|
var MaxReconnectRate = rate.Limit(1 / DefaultMinBackoff.Seconds())
|
|
|
|
// WebSocketClient allows to connect and receive stream data
|
|
type WebSocketClient struct {
|
|
// Url is the websocket connection location, start with ws:// or wss://
|
|
Url string
|
|
|
|
// conn is the current websocket connection, please note the connection
|
|
// object can be replaced with a new connection object when the connection
|
|
// is unexpected closed.
|
|
conn *websocket.Conn
|
|
|
|
// Dialer is used for creating the websocket connection
|
|
Dialer *websocket.Dialer
|
|
|
|
// requestHeader is used for the Dial function call. Some credential can be
|
|
// stored in the http request header for authentication
|
|
requestHeader http.Header
|
|
|
|
// messages is a read-only channel, received messages will be sent to this
|
|
// channel.
|
|
messages chan Message
|
|
|
|
readTimeout time.Duration
|
|
|
|
writeTimeout time.Duration
|
|
|
|
onConnect []func(c Client)
|
|
|
|
onDisconnect []func(c Client)
|
|
|
|
// cancel is mapped to the ctx context object
|
|
cancel func()
|
|
|
|
readerClosed chan struct{}
|
|
|
|
connected bool
|
|
|
|
mu sync.Mutex
|
|
|
|
reconnectCh chan struct{}
|
|
|
|
backoff Backoff
|
|
|
|
limiter *rate.Limiter
|
|
}
|
|
|
|
// Message is a single frame read from the server.
//
// Fields are declared in the order Type, Body.
type Message struct {
	Type int    // frame type: websocket.BinaryMessage or websocket.TextMessage
	Body []byte // raw frame payload
}
|
|
|
|
func (c *WebSocketClient) Messages() <-chan Message {
|
|
return c.messages
|
|
}
|
|
|
|
func (c *WebSocketClient) SetReadTimeout(timeout time.Duration) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.readTimeout = timeout
|
|
}
|
|
|
|
func (c *WebSocketClient) SetWriteTimeout(timeout time.Duration) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.writeTimeout = timeout
|
|
}
|
|
|
|
func (c *WebSocketClient) OnConnect(f func(c Client)) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.onConnect = append(c.onConnect, f)
|
|
}
|
|
|
|
func (c *WebSocketClient) OnDisconnect(f func(c Client)) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.onDisconnect = append(c.onDisconnect, f)
|
|
}
|
|
|
|
func (c *WebSocketClient) WriteTextMessage(message []byte) error {
|
|
return c.WriteMessage(websocket.TextMessage, message)
|
|
}
|
|
|
|
func (c *WebSocketClient) WriteBinaryMessage(message []byte) error {
|
|
return c.WriteMessage(websocket.BinaryMessage, message)
|
|
}
|
|
|
|
func (c *WebSocketClient) WriteMessage(messageType int, data []byte) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if !c.connected {
|
|
return ErrConnectionLost
|
|
}
|
|
|
|
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
|
|
return err
|
|
}
|
|
return c.conn.WriteMessage(messageType, data)
|
|
}
|
|
|
|
func (c *WebSocketClient) readMessages() error {
|
|
c.mu.Lock()
|
|
if !c.connected {
|
|
c.mu.Unlock()
|
|
return ErrConnectionLost
|
|
}
|
|
timeout := c.readTimeout
|
|
conn := c.conn
|
|
c.mu.Unlock()
|
|
|
|
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
|
|
return err
|
|
}
|
|
msgtype, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.messages <- Message{msgtype, message}
|
|
return nil
|
|
}
|
|
|
|
// listen starts a goroutine for reading message and tries to re-connect to the
|
|
// server when the reader returns error
|
|
//
|
|
// Please note we should always break the reader loop if there is any error
|
|
// returned from the server.
|
|
func (c *WebSocketClient) listen(ctx context.Context) {
|
|
// The life time of both channels "readerClosed" and "reconnectCh" is bound to one connection.
|
|
// Each channel should be created before loop starts and be closed after loop ends.
|
|
// "readerClosed" is used to inform "Close()" reader loop ends.
|
|
// "reconnectCh" is used to centralize reconnection logics in this reader loop.
|
|
c.mu.Lock()
|
|
c.readerClosed = make(chan struct{})
|
|
c.reconnectCh = make(chan struct{}, 1)
|
|
c.mu.Unlock()
|
|
defer func() {
|
|
c.mu.Lock()
|
|
close(c.readerClosed)
|
|
close(c.reconnectCh)
|
|
c.reconnectCh = nil
|
|
c.mu.Unlock()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
return
|
|
|
|
case <-c.reconnectCh:
|
|
// it could be i/o timeout for network disconnection
|
|
// or it could be invoked from outside.
|
|
c.SetDisconnected()
|
|
var maxTries = 1
|
|
if _, response, err := c.reconnect(ctx, maxTries); err != nil {
|
|
if err == ErrReconnectContextDone {
|
|
log.Debugf("[websocket] context canceled. stop reconnecting.")
|
|
return
|
|
}
|
|
log.Warnf("[websocket] failed to reconnect after %d tries!! error: %v response: %v", maxTries, err, response)
|
|
c.Reconnect()
|
|
}
|
|
|
|
default:
|
|
if err := c.readMessages(); err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
|
|
log.Warnf("[websocket] unexpected close error reconnecting: %v", err)
|
|
}
|
|
|
|
log.Warnf("[websocket] failed to read message. error: %+v", err)
|
|
c.Reconnect()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Reconnect triggers reconnection logics
|
|
func (c *WebSocketClient) Reconnect() {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
select {
|
|
// c.reconnectCh is a buffered channel with cap=1.
|
|
// At most one reconnect signal could be processed.
|
|
case c.reconnectCh <- struct{}{}:
|
|
default:
|
|
// Entering here means it is already reconnecting.
|
|
// Drop the current reconnect signal.
|
|
}
|
|
}
|
|
|
|
// Close gracefully shuts down the reader and the connection
|
|
// ctx is the context used for shutdown process.
|
|
func (c *WebSocketClient) Close() (err error) {
|
|
c.mu.Lock()
|
|
// leave the listen goroutine before we close the connection
|
|
// checking nil is to handle calling "Close" before "Connect" is called
|
|
if c.cancel != nil {
|
|
c.cancel()
|
|
}
|
|
c.mu.Unlock()
|
|
c.SetDisconnected()
|
|
|
|
// wait for the reader func to be closed
|
|
if c.readerClosed != nil {
|
|
<-c.readerClosed
|
|
}
|
|
return err
|
|
}
|
|
|
|
// reconnect tries to create a new connection from the existing dialer
|
|
func (c *WebSocketClient) reconnect(ctx context.Context, maxTries int) (*websocket.Conn, *http.Response, error) {
|
|
log.Debugf("[websocket] start reconnecting to %q", c.Url)
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, nil, ErrReconnectContextDone
|
|
default:
|
|
}
|
|
|
|
if s := util.ShouldDelay(c.limiter, DefaultMinBackoff); s > 0 {
|
|
log.Warn("[websocket] reconnect too frequently. Sleep for ", s)
|
|
time.Sleep(s)
|
|
}
|
|
|
|
log.Warnf("[websocket] reconnecting x %d to %q", c.backoff.Attempt()+1, c.Url)
|
|
conn, resp, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
|
if err != nil {
|
|
dur := c.backoff.Duration()
|
|
log.Warnf("failed to dial %s: %v, response: %+v. Wait for %v", c.Url, err, resp, dur)
|
|
time.Sleep(dur)
|
|
return nil, nil, ErrReconnectFailed
|
|
}
|
|
|
|
log.Infof("[websocket] reconnected to %q", c.Url)
|
|
// Reset backoff value if connected.
|
|
c.backoff.Reset()
|
|
c.setConn(conn)
|
|
c.setPingHandler(conn)
|
|
|
|
return conn, resp, err
|
|
}
|
|
|
|
// Conn returns the current active connection instance
|
|
func (c *WebSocketClient) Conn() (conn *websocket.Conn) {
|
|
c.mu.Lock()
|
|
conn = c.conn
|
|
c.mu.Unlock()
|
|
return conn
|
|
}
|
|
|
|
func (c *WebSocketClient) setConn(conn *websocket.Conn) {
|
|
// Disconnect old connection before replacing with new one.
|
|
c.SetDisconnected()
|
|
|
|
c.mu.Lock()
|
|
c.conn = conn
|
|
c.connected = true
|
|
c.mu.Unlock()
|
|
for _, f := range c.onConnect {
|
|
go f(c)
|
|
}
|
|
}
|
|
|
|
func (c *WebSocketClient) setPingHandler(conn *websocket.Conn) {
|
|
conn.SetPingHandler(func(message string) error {
|
|
if err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
|
|
return err
|
|
}
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
|
})
|
|
}
|
|
|
|
func (c *WebSocketClient) SetDisconnected() {
|
|
c.mu.Lock()
|
|
closed := false
|
|
if c.conn != nil {
|
|
closed = true
|
|
c.conn.Close()
|
|
}
|
|
c.connected = false
|
|
c.conn = nil
|
|
c.mu.Unlock()
|
|
|
|
if closed {
|
|
// Only call disconnect callbacks when a connection is closed
|
|
for _, f := range c.onDisconnect {
|
|
go f(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *WebSocketClient) IsConnected() (ret bool) {
|
|
c.mu.Lock()
|
|
ret = c.connected
|
|
c.mu.Unlock()
|
|
return ret
|
|
}
|
|
|
|
func (c *WebSocketClient) Connect(basectx context.Context) error {
|
|
// maintain a context by the client it self, so that we can manually shutdown the connection
|
|
ctx, cancel := context.WithCancel(basectx)
|
|
c.cancel = cancel
|
|
|
|
conn, _, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
|
if err == nil {
|
|
// setup connection only when connected
|
|
c.setConn(conn)
|
|
c.setPingHandler(conn)
|
|
}
|
|
|
|
// 1) if connection is built up, start listening for messages.
|
|
// 2) if connection is NOT ready, start reconnecting infinitely.
|
|
go c.listen(ctx)
|
|
|
|
return err
|
|
}
|
|
|
|
func New(url string, requestHeader http.Header) *WebSocketClient {
|
|
return NewWithDialer(url, websocket.DefaultDialer, requestHeader)
|
|
}
|
|
|
|
func NewWithDialer(url string, d *websocket.Dialer, requestHeader http.Header) *WebSocketClient {
|
|
limiter, err := util.NewValidLimiter(MaxReconnectRate, 1)
|
|
if err != nil {
|
|
log.WithError(err).Panic("Invalid rate limiter")
|
|
}
|
|
return &WebSocketClient{
|
|
Url: url,
|
|
Dialer: d,
|
|
requestHeader: requestHeader,
|
|
readTimeout: DefaultReadTimeout,
|
|
writeTimeout: DefaultWriteTimeout,
|
|
messages: make(chan Message, DefaultMessageBufferSize),
|
|
limiter: limiter,
|
|
}
|
|
}
|