bbgo_origin/websocket/client.go
2020-06-09 10:38:20 +08:00

369 lines
9.1 KiB
Go

package websocket
import (
"context"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/c9s/bbgo/pkg/log"
"github.com/c9s/bbgo/pkg/util"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
// Defaults applied by NewWithDialer.
const (
	// DefaultMessageBufferSize is the capacity of the channel returned by Messages.
	DefaultMessageBufferSize = 128
	// DefaultWriteTimeout is the initial write deadline window used by WriteMessage.
	DefaultWriteTimeout = 30 * time.Second
	// DefaultReadTimeout is the initial read deadline window applied before each read.
	DefaultReadTimeout = 30 * time.Second
)

var (
	// ErrReconnectContextDone is returned by reconnect when the client's context is done.
	ErrReconnectContextDone = errors.New("reconnect canceled due to context done.")
	// ErrReconnectFailed is returned by reconnect when dialing the server fails.
	ErrReconnectFailed = errors.New("failed to reconnect")
	// ErrConnectionLost is returned by operations that need a live connection when none is installed.
	ErrConnectionLost = errors.New("connection lost")
)

// MaxReconnectRate allows at most one reconnect attempt per DefaultMinBackoff.
var MaxReconnectRate = rate.Limit(1 / DefaultMinBackoff.Seconds())
// WebSocketClient connects to a websocket server, forwards every received
// frame to the channel returned by Messages, and re-dials the server whenever
// reading fails or Reconnect is called. Mutable connection state is guarded by mu.
type WebSocketClient struct {
// Url is the websocket connection location, start with ws:// or wss://
Url string
// conn is the current websocket connection, please note the connection
// object can be replaced with a new connection object when the connection
// is unexpected closed. It is nil while disconnected.
conn *websocket.Conn
// Dialer is used for creating the websocket connection
Dialer *websocket.Dialer
// requestHeader is used for the Dial function call. Some credential can be
// stored in the http request header for authentication
requestHeader http.Header
// messages buffers received frames. The reader loop sends on it, and it is
// exposed to callers only as a receive-only channel through Messages().
messages chan Message
// readTimeout is the read deadline window applied before each read and
// refreshed whenever a ping is answered.
readTimeout time.Duration
// writeTimeout is the write deadline window applied to each WriteMessage call.
writeTimeout time.Duration
// onConnect callbacks are each started in a new goroutine when a connection is installed.
onConnect []func(c Client)
// onDisconnect callbacks are each started in a new goroutine when a live connection is closed.
onDisconnect []func(c Client)
// cancel cancels the context handed to the reader loop; it is assigned by
// Connect and invoked by Close.
cancel func()
// readerClosed is created by the reader loop and closed when that loop
// exits; Close waits on it.
readerClosed chan struct{}
// connected reports whether conn is a live connection.
connected bool
// mu guards the connection state above; WriteMessage also holds it for the
// whole write, which serializes writers.
mu sync.Mutex
// reconnectCh is a buffered (cap 1) signal channel owned by the reader loop;
// Reconnect sends on it. It is nil when no reader loop is running.
reconnectCh chan struct{}
// backoff supplies the wait after a failed dial and is reset after a successful one.
backoff Backoff
// limiter throttles reconnect attempts (built from MaxReconnectRate).
limiter *rate.Limiter
}
// Message is one frame received from the server, as delivered on the channel
// returned by Messages.
// NOTE(review): positional literals (Message{type, body}) may depend on the
// field order; keep it stable.
type Message struct {
// Type is the frame type reported by the connection:
// websocket.BinaryMessage or websocket.TextMessage
Type int
// Body is the raw frame payload.
Body []byte
}
// Messages returns the receive-only view of the channel that the reader loop
// delivers incoming frames to.
func (c *WebSocketClient) Messages() <-chan Message {
	var inbox <-chan Message = c.messages
	return inbox
}
// SetReadTimeout changes the read deadline window used before each read and
// when a ping is answered. It takes the client lock.
func (c *WebSocketClient) SetReadTimeout(timeout time.Duration) {
	c.mu.Lock()
	c.readTimeout = timeout
	c.mu.Unlock()
}
// SetWriteTimeout changes the write deadline window applied by WriteMessage.
// It takes the client lock.
func (c *WebSocketClient) SetWriteTimeout(timeout time.Duration) {
	c.mu.Lock()
	c.writeTimeout = timeout
	c.mu.Unlock()
}
// OnConnect registers f to be started in its own goroutine each time a new
// connection is installed.
func (c *WebSocketClient) OnConnect(f func(c Client)) {
	c.mu.Lock()
	c.onConnect = append(c.onConnect, f)
	c.mu.Unlock()
}
// OnDisconnect registers f to be started in its own goroutine each time a
// live connection is closed.
func (c *WebSocketClient) OnDisconnect(f func(c Client)) {
	c.mu.Lock()
	c.onDisconnect = append(c.onDisconnect, f)
	c.mu.Unlock()
}
// WriteTextMessage sends message as a single text frame; see WriteMessage.
func (c *WebSocketClient) WriteTextMessage(message []byte) error {
	const frameType = websocket.TextMessage
	return c.WriteMessage(frameType, message)
}
// WriteBinaryMessage sends message as a single binary frame; see WriteMessage.
func (c *WebSocketClient) WriteBinaryMessage(message []byte) error {
	const frameType = websocket.BinaryMessage
	return c.WriteMessage(frameType, message)
}
// WriteMessage writes one frame of messageType (e.g. websocket.TextMessage or
// websocket.BinaryMessage) to the current connection, bounded by the write
// timeout. It returns ErrConnectionLost when no connection is live.
//
// The client lock is held for the whole write, so concurrent callers are
// serialized (the underlying connection supports only one writer at a time).
func (c *WebSocketClient) WriteMessage(messageType int, data []byte) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		deadline := time.Now().Add(c.writeTimeout)
		if err := c.conn.SetWriteDeadline(deadline); err != nil {
			return err
		}
		return c.conn.WriteMessage(messageType, data)
	}
	return ErrConnectionLost
}
// readMessages performs one blocking read on the current connection and
// forwards the frame to c.messages. It returns ErrConnectionLost when no
// connection is live, otherwise any deadline or read error.
//
// The lock is only held while snapshotting state, not during the read, so
// writers and SetDisconnected are not blocked by a pending read.
// NOTE(review): the send on c.messages blocks while the buffer is full and no
// one consumes it; the reader loop (and hence Close) waits until it drains.
func (c *WebSocketClient) readMessages() error {
	c.mu.Lock()
	live, conn, timeout := c.connected, c.conn, c.readTimeout
	c.mu.Unlock()

	if !live {
		return ErrConnectionLost
	}

	err := conn.SetReadDeadline(time.Now().Add(timeout))
	if err != nil {
		return err
	}

	kind, body, err := conn.ReadMessage()
	if err != nil {
		return err
	}

	c.messages <- Message{Type: kind, Body: body}
	return nil
}
// listen runs the reader loop until ctx is done. It reads frames via
// readMessages and, whenever a read fails or Reconnect is called, drops the
// current connection and dials a new one.
//
// A connection that returned a read error is never read again: the loop
// signals a reconnect instead.
//
// When the loop exits, any connection still installed is closed. This covers
// the parent context being canceled without Close, and a dial that completed
// while Close was running. Otherwise such a connection would stay open with no
// reader and IsConnected would keep reporting true.
func (c *WebSocketClient) listen(ctx context.Context) {
	// The life time of both channels "readerClosed" and "reconnectCh" is bound to one reader loop.
	// Each channel is created before the loop starts and closed after the loop ends.
	// "readerClosed" is used to inform "Close()" that the reader loop has ended.
	// "reconnectCh" is used to centralize reconnection logic in this reader loop.
	readerClosed := make(chan struct{})
	reconnectCh := make(chan struct{}, 1)

	c.mu.Lock()
	c.readerClosed = readerClosed
	c.reconnectCh = reconnectCh
	c.mu.Unlock()

	defer func() {
		// Tear down a connection that may still be open; a no-op (and no
		// disconnect callbacks) when nothing is installed.
		c.SetDisconnected()

		c.mu.Lock()
		close(readerClosed)
		close(reconnectCh)
		// Reconnect sends under mu; a nil channel makes its send a no-op.
		c.reconnectCh = nil
		c.mu.Unlock()
	}()

	for {
		select {
		case <-ctx.Done():
			return

		case <-reconnectCh:
			// it could be i/o timeout for network disconnection
			// or it could be invoked from outside.
			c.SetDisconnected()

			var maxTries = 1
			if _, response, err := c.reconnect(ctx, maxTries); err != nil {
				if err == ErrReconnectContextDone {
					log.Debugf("[websocket] context canceled. stop reconnecting.")
					return
				}
				log.Warnf("[websocket] failed to reconnect after %d tries!! error: %v response: %v", maxTries, err, response)
				// Queue another attempt; reconnect already waited for backoff.
				c.Reconnect()
			}

		default:
			if err := c.readMessages(); err != nil {
				if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
					log.Warnf("[websocket] unexpected close error reconnecting: %v", err)
				}
				log.Warnf("[websocket] failed to read message. error: %+v", err)
				c.Reconnect()
			}
		}
	}
}
// Reconnect asks the reader loop to drop the current connection and dial a
// new one. It never blocks. The request is dropped when one is already queued
// (the channel has capacity 1). It is also dropped when no reader loop is
// running: reconnectCh is then nil, and a send on a nil channel never proceeds.
func (c *WebSocketClient) Reconnect() {
	c.mu.Lock()
	defer c.mu.Unlock()

	signal := struct{}{}
	select {
	case c.reconnectCh <- signal:
	default:
		// Already reconnecting (or no listener): discard this signal.
	}
}
// Close gracefully shuts down the reader loop and the connection. It proceeds
// in three steps:
//  1. cancel the context given to the reader loop,
//  2. close the current connection,
//  3. wait until the reader loop has exited.
//
// It is safe to call before Connect and more than once. The returned error is
// always nil.
//
// NOTE(review): if Close runs before the reader goroutine started by Connect
// has created readerClosed, Close cannot wait for it.
func (c *WebSocketClient) Close() (err error) {
	c.mu.Lock()
	// Leave the listen goroutine before we close the connection. Checking nil
	// handles calling "Close" before "Connect" is called.
	if c.cancel != nil {
		c.cancel()
	}
	// The reader loop assigns readerClosed under mu; snapshot it under the
	// same lock instead of reading the field unguarded (data race).
	readerClosed := c.readerClosed
	c.mu.Unlock()

	c.SetDisconnected()

	// Wait for the reader loop to be closed.
	if readerClosed != nil {
		<-readerClosed
	}
	return nil
}
// reconnect dials c.Url once with the client's dialer. On success it installs
// the new connection (closing any previous one) and its ping handler.
//
// Before dialing it honors the reconnect rate limiter. After a failed dial it
// waits for the next backoff duration. Both waits end early when ctx is done,
// so Close is not held up by a sleeping reconnect.
//
// maxTries is currently unused; exactly one dial is attempted per call.
//
// It returns ErrReconnectContextDone when ctx is done before, during, or
// after the dial. It returns ErrReconnectFailed when the dial fails, together
// with the handshake response if the server sent one.
func (c *WebSocketClient) reconnect(ctx context.Context, maxTries int) (*websocket.Conn, *http.Response, error) {
	log.Debugf("[websocket] start reconnecting to %q", c.Url)

	// sleep waits for d or until ctx is done; it reports whether d elapsed.
	sleep := func(d time.Duration) bool {
		timer := time.NewTimer(d)
		defer timer.Stop()
		select {
		case <-ctx.Done():
			return false
		case <-timer.C:
			return true
		}
	}

	select {
	case <-ctx.Done():
		return nil, nil, ErrReconnectContextDone
	default:
	}

	if s := util.ShouldDelay(c.limiter, DefaultMinBackoff); s > 0 {
		log.Warn("[websocket] reconnect too frequently. Sleep for ", s)
		if !sleep(s) {
			return nil, nil, ErrReconnectContextDone
		}
	}

	log.Warnf("[websocket] reconnecting x %d to %q", c.backoff.Attempt()+1, c.Url)
	conn, resp, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
	if err != nil {
		if ctx.Err() != nil {
			// The dial was aborted by cancellation; there is nothing to back off from.
			return nil, resp, ErrReconnectContextDone
		}
		dur := c.backoff.Duration()
		log.Warnf("failed to dial %s: %v, response: %+v. Wait for %v", c.Url, err, resp, dur)
		if !sleep(dur) {
			return nil, resp, ErrReconnectContextDone
		}
		// Return the handshake response so the caller's log shows it.
		return nil, resp, ErrReconnectFailed
	}

	log.Infof("[websocket] reconnected to %q", c.Url)
	// Reset backoff value if connected.
	c.backoff.Reset()
	c.setConn(conn)
	c.setPingHandler(conn)
	return conn, resp, nil
}
// Conn returns the currently installed connection, or nil while disconnected.
// The reader loop may close or replace it at any time after this returns.
func (c *WebSocketClient) Conn() (conn *websocket.Conn) {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.conn
}
// setConn installs conn as the active connection and marks the client
// connected. Any previous connection is closed first via SetDisconnected
// (which may start the disconnect callbacks). Then each registered connect
// callback is started in its own goroutine.
func (c *WebSocketClient) setConn(conn *websocket.Conn) {
	// Disconnect old connection before replacing with new one.
	c.SetDisconnected()

	c.mu.Lock()
	c.conn = conn
	c.connected = true
	// Snapshot the callback slice under mu. OnConnect appends under the same
	// lock, so ranging over c.onConnect unguarded would be a data race. append
	// never rewrites elements below the old length, so the snapshot is stable.
	callbacks := c.onConnect
	c.mu.Unlock()

	for _, f := range callbacks {
		go f(c)
	}
}
// setPingHandler installs a ping handler on conn. For each ping it:
//  1. replies with a pong that echoes the payload (1s write deadline),
//  2. pushes the read deadline forward by the configured read timeout, so a
//     peer that keeps pinging keeps the connection from timing out.
//
// An error from either step is returned from the handler.
func (c *WebSocketClient) setPingHandler(conn *websocket.Conn) {
	handler := func(payload string) error {
		pongDeadline := time.Now().Add(time.Second)
		if err := conn.WriteControl(websocket.PongMessage, []byte(payload), pongDeadline); err != nil {
			return err
		}

		c.mu.Lock()
		defer c.mu.Unlock()
		return conn.SetReadDeadline(time.Now().Add(c.readTimeout))
	}
	conn.SetPingHandler(handler)
}
// SetDisconnected closes the current connection, if any, and marks the client
// disconnected. Disconnect callbacks are started, each in its own goroutine,
// only when a connection was actually closed. Repeated calls therefore do not
// re-fire them.
//
// The lock is held while closing, so this waits for an in-flight WriteMessage
// to finish.
func (c *WebSocketClient) SetDisconnected() {
	c.mu.Lock()
	closed := c.conn != nil
	if closed {
		// Best-effort: the connection is discarded regardless of the close error.
		_ = c.conn.Close()
	}
	c.connected = false
	c.conn = nil
	// Snapshot the callback slice under mu; OnDisconnect appends under the
	// same lock, so ranging over c.onDisconnect unguarded would be a data race.
	callbacks := c.onDisconnect
	c.mu.Unlock()

	if !closed {
		// Only call disconnect callbacks when a connection is closed.
		return
	}
	for _, f := range callbacks {
		go f(c)
	}
}
// IsConnected reports whether a live connection is currently installed.
func (c *WebSocketClient) IsConnected() (ret bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.connected
}
// Connect dials c.Url and starts the reader loop, which runs until basectx is
// canceled or Close is called.
//
// If the initial dial fails, the dial error is returned, but the reader loop
// is still started. It keeps reconnecting in the background, so call Close to
// stop it.
func (c *WebSocketClient) Connect(basectx context.Context) error {
	// maintain a context by the client itself, so that we can manually shut down the connection
	ctx, cancel := context.WithCancel(basectx)

	// Close reads c.cancel under mu; write it under mu too to avoid a data race.
	c.mu.Lock()
	c.cancel = cancel
	c.mu.Unlock()

	conn, _, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
	if err == nil {
		// setup connection only when connected
		c.setConn(conn)
		c.setPingHandler(conn)
	}

	// 1) if connection is built up, start listening for messages.
	// 2) if connection is NOT ready, the reader loop keeps reconnecting until ctx is done.
	go c.listen(ctx)
	return err
}
// New creates a client for url that dials with websocket.DefaultDialer and
// sends requestHeader with the handshake. It does not connect; call Connect.
func New(url string, requestHeader http.Header) *WebSocketClient {
	dialer := websocket.DefaultDialer
	return NewWithDialer(url, dialer, requestHeader)
}
// NewWithDialer creates a client for url that dials with d and sends
// requestHeader with the handshake. It does not connect; call Connect.
//
// Initial settings:
//   - read and write timeouts are DefaultReadTimeout and DefaultWriteTimeout,
//   - the message channel is buffered to DefaultMessageBufferSize,
//   - reconnect attempts are throttled by a limiter built from MaxReconnectRate.
//
// If the limiter cannot be built, the error is reported at log's Panic level.
func NewWithDialer(url string, d *websocket.Dialer, requestHeader http.Header) *WebSocketClient {
	limiter, err := util.NewValidLimiter(MaxReconnectRate, 1)
	if err != nil {
		log.WithError(err).Panic("Invalid rate limiter")
	}

	client := &WebSocketClient{
		Url:           url,
		Dialer:        d,
		requestHeader: requestHeader,
		messages:      make(chan Message, DefaultMessageBufferSize),
		limiter:       limiter,
	}
	client.readTimeout = DefaultReadTimeout
	client.writeTimeout = DefaultWriteTimeout
	return client
}