205 lines
3.9 KiB
Go
205 lines
3.9 KiB
Go
package ws
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// pongMsg is the payload that defaultPongFn treats as a pong reply.
var pongMsg = []byte("pong")
|
|
|
|
type WSInitFn func(ws *WSConn) error
|
|
type MessageFn func(message []byte) error
|
|
type PingFn func(ws *WSConn) error
|
|
type PongFn func(message []byte) bool
|
|
|
|
func defaultPingFn(ws *WSConn) error {
|
|
return ws.WriteText("ping")
|
|
}
|
|
|
|
func defaultPongFn(message []byte) bool {
|
|
return bytes.Equal(pongMsg, message)
|
|
}
|
|
|
|
type WSConn struct {
|
|
addr string
|
|
ws *websocket.Conn
|
|
initFn WSInitFn
|
|
messageFn MessageFn
|
|
pingFn PingFn
|
|
pongFn PongFn
|
|
closeCh chan int
|
|
writeMuetx sync.Mutex
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
func NewWSConnWithPingPong(addr string, initFn WSInitFn, messageFn MessageFn, ping PingFn, pong PongFn) (conn *WSConn, err error) {
|
|
conn = new(WSConn)
|
|
conn.addr = addr
|
|
conn.initFn = initFn
|
|
conn.messageFn = messageFn
|
|
conn.closeCh = make(chan int, 1)
|
|
conn.pingFn = ping
|
|
conn.pongFn = pong
|
|
err = conn.connect()
|
|
return
|
|
}
|
|
|
|
func NewWSConn(addr string, initFn WSInitFn, messageFn MessageFn) (conn *WSConn, err error) {
|
|
conn = new(WSConn)
|
|
conn.addr = addr
|
|
conn.initFn = initFn
|
|
conn.pingFn = defaultPingFn
|
|
conn.pongFn = defaultPongFn
|
|
conn.messageFn = messageFn
|
|
conn.closeCh = make(chan int, 1)
|
|
err = conn.connect()
|
|
if err != nil {
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
func (conn *WSConn) SetPingPongFn(ping PingFn, pong PongFn) {
|
|
conn.pingFn = ping
|
|
conn.pongFn = pong
|
|
}
|
|
|
|
func (conn *WSConn) Close() (err error) {
|
|
close(conn.closeCh)
|
|
conn.wg.Wait()
|
|
return
|
|
}
|
|
|
|
func (conn *WSConn) WriteText(value string) (err error) {
|
|
conn.writeMuetx.Lock()
|
|
if conn.ws != nil {
|
|
err = conn.ws.WriteMessage(websocket.TextMessage, []byte(value))
|
|
} else {
|
|
log.Warnf("WriteText ignore conn of %s not init", conn.addr)
|
|
}
|
|
conn.writeMuetx.Unlock()
|
|
return
|
|
}
|
|
|
|
func (conn *WSConn) WriteMsg(value interface{}) (err error) {
|
|
conn.writeMuetx.Lock()
|
|
if conn.ws != nil {
|
|
err = conn.ws.WriteJSON(value)
|
|
} else {
|
|
log.Warnf("WriteMsg ignore conn of %s not init", conn.addr)
|
|
}
|
|
conn.writeMuetx.Unlock()
|
|
return
|
|
}
|
|
|
|
func (conn *WSConn) connect() (err error) {
|
|
u, err := url.Parse(conn.addr)
|
|
if err != nil {
|
|
return
|
|
}
|
|
conn.ws, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if err != nil {
|
|
err = fmt.Errorf("connect to %s failed: %w", conn.addr, err)
|
|
return
|
|
}
|
|
if conn.initFn != nil {
|
|
err = conn.initFn(conn)
|
|
if err != nil {
|
|
conn.ws.Close()
|
|
return
|
|
}
|
|
}
|
|
go conn.loop()
|
|
return
|
|
}
|
|
|
|
func (conn *WSConn) loop() {
|
|
ws := conn.ws
|
|
ch := make(chan []byte, 1024)
|
|
needReconn := make(chan bool, 1)
|
|
go conn.readLoop(ws, ch, needReconn)
|
|
var msg []byte
|
|
var err error
|
|
var lastMsgTime time.Time
|
|
ticker := time.NewTicker(time.Second * 5)
|
|
|
|
conn.wg.Add(1)
|
|
defer conn.wg.Done()
|
|
var ok bool
|
|
defer ticker.Stop()
|
|
Out:
|
|
for {
|
|
select {
|
|
case msg, ok = <-ch:
|
|
if !ok {
|
|
break Out
|
|
}
|
|
lastMsgTime = time.Now()
|
|
|
|
if conn.pongFn != nil && conn.pongFn(msg) {
|
|
continue
|
|
}
|
|
err = conn.messageFn(msg)
|
|
if err != nil {
|
|
break Out
|
|
}
|
|
case <-ticker.C:
|
|
dur := time.Since(lastMsgTime)
|
|
if dur > time.Second*5 {
|
|
if conn.pingFn != nil {
|
|
err1 := conn.pingFn(conn)
|
|
if err1 != nil {
|
|
log.Errorf("ws pingFn failed: %s", err1.Error())
|
|
}
|
|
}
|
|
}
|
|
case <-conn.closeCh:
|
|
return
|
|
}
|
|
}
|
|
reConn := <-needReconn
|
|
if reConn {
|
|
for i := 0; i != 100; i++ {
|
|
err = conn.connect()
|
|
if err == nil {
|
|
break
|
|
}
|
|
log.Errorf("ws reconnect %d to failed: %s", i, err.Error())
|
|
time.Sleep(time.Second)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (conn *WSConn) readLoop(ws *websocket.Conn, ch chan []byte, needConn chan bool) {
|
|
defer func() {
|
|
ws.Close()
|
|
close(ch)
|
|
close(needConn)
|
|
}()
|
|
var message []byte
|
|
var err error
|
|
for {
|
|
select {
|
|
case <-conn.closeCh:
|
|
needConn <- false
|
|
return
|
|
default:
|
|
}
|
|
_, message, err = ws.ReadMessage()
|
|
if err != nil {
|
|
log.Printf("%s ws read error: %s", conn.addr, err.Error())
|
|
needConn <- true
|
|
return
|
|
}
|
|
ch <- message
|
|
}
|
|
}
|