exchange/ws/ws.go
2024-06-26 00:59:56 +08:00

205 lines
3.9 KiB
Go

package ws
import (
"bytes"
"fmt"
"net/url"
"sync"
"time"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
)
// pongMsg is the exact payload that defaultPongFn recognizes as a pong reply.
var pongMsg = []byte("pong")
type (
	// WSInitFn is run by connect right after each successful dial, before the
	// message loop for that socket is started. A non-nil error closes the
	// freshly dialed socket and makes connect fail with that error.
	WSInitFn func(ws *WSConn) error

	// MessageFn handles one incoming frame that the PongFn did not claim.
	// Returning a non-nil error ends message dispatch for the current socket.
	MessageFn func(message []byte) error

	// PingFn sends a keepalive. The message loop calls it on its 5-second
	// ticker whenever no frame has arrived for more than 5 seconds; a returned
	// error is logged, not propagated.
	PingFn func(ws *WSConn) error

	// PongFn reports whether message is a pong reply. Frames it accepts still
	// count as activity for the keepalive timer but are not passed to MessageFn.
	PongFn func(message []byte) bool
)
// defaultPingFn is the PingFn used by NewWSConn: it sends the literal text
// frame "ping" through WriteText and returns WriteText's error.
func defaultPingFn(ws *WSConn) error {
	const keepalive = "ping"
	return ws.WriteText(keepalive)
}
// defaultPongFn is the PongFn used by NewWSConn: a frame counts as a pong
// only when its bytes are exactly pongMsg.
func defaultPongFn(message []byte) bool {
	isPong := bytes.Equal(message, pongMsg)
	return isPong
}
// WSConn is a websocket client that dials addr, runs initFn after every
// successful dial, forwards incoming frames to messageFn and sends keepalives
// through pingFn. Build it with NewWSConn or NewWSConnWithPingPong: the zero
// value has a nil closeCh, so Close on it would panic.
type WSConn struct {
	addr      string          // URL string parsed and dialed by connect
	ws        *websocket.Conn // result of the most recent dial attempt; nil until one succeeds
	initFn    WSInitFn        // optional; run by connect after each dial, before the loop starts
	messageFn MessageFn       // called for every non-pong frame; must be non-nil (called unchecked)
	pingFn    PingFn          // optional keepalive sender invoked by loop
	pongFn    PongFn          // optional; frames it accepts are not passed to messageFn
	closeCh   chan int        // closed by Close to tell loop/readLoop to stop
	// writeMuetx (sic: "writeMutex") serializes WriteText/WriteMsg on ws.
	// Every method uses this spelling, so it can only be renamed all at once.
	writeMuetx sync.Mutex
	wg         sync.WaitGroup // counts running loop goroutines; Close waits on it
}
// NewWSConnWithPingPong creates a WSConn for addr that uses the given ping and
// pong handlers (either may be nil to disable it) and dials immediately.
//
// The *WSConn is returned even when dialing or initFn fails; err then carries
// the failure from connect.
func NewWSConnWithPingPong(addr string, initFn WSInitFn, messageFn MessageFn, ping PingFn, pong PongFn) (conn *WSConn, err error) {
	conn = &WSConn{
		addr:      addr,
		initFn:    initFn,
		messageFn: messageFn,
		pingFn:    ping,
		pongFn:    pong,
		closeCh:   make(chan int, 1),
	}
	err = conn.connect()
	return conn, err
}
// NewWSConn creates a WSConn for addr with the default text keepalive
// (defaultPingFn sends "ping", defaultPongFn recognizes "pong") and dials
// immediately. Like NewWSConnWithPingPong, it returns the *WSConn together with
// any connect error.
func NewWSConn(addr string, initFn WSInitFn, messageFn MessageFn) (conn *WSConn, err error) {
	return NewWSConnWithPingPong(addr, initFn, messageFn, defaultPingFn, defaultPongFn)
}
// SetPingPongFn replaces the ping and pong handlers; nil disables either one.
//
// NOTE(review): once connect has succeeded, the loop goroutine reads pingFn and
// pongFn without any locking, so calling this on a live connection is a data
// race. Prefer passing the handlers to NewWSConnWithPingPong instead.
func (conn *WSConn) SetPingPongFn(ping PingFn, pong PongFn) {
	conn.pingFn, conn.pongFn = ping, pong
}
// Close signals the background goroutines to stop by closing closeCh, then
// waits for every loop goroutine tracked by wg to return. It always returns nil.
//
// Close is idempotent: previously a second call panicked with "close of closed
// channel". Calls after the first only wait on wg again.
func (conn *WSConn) Close() (err error) {
	// Reuse writeMuetx so concurrent Close calls cannot both observe closeCh
	// as open. Release it before Wait, because the loop may need it (pingFn ->
	// WriteText) in order to finish.
	conn.writeMuetx.Lock()
	select {
	case <-conn.closeCh:
		// Already closed by an earlier Close.
	default:
		close(conn.closeCh)
	}
	conn.writeMuetx.Unlock()

	conn.wg.Wait()
	return nil
}
// WriteText sends value as one websocket text frame while holding writeMuetx.
// If conn.ws is nil (no dial has succeeded) it logs a warning and returns nil.
func (conn *WSConn) WriteText(value string) (err error) {
	conn.writeMuetx.Lock()
	defer conn.writeMuetx.Unlock()

	ws := conn.ws
	if ws == nil {
		log.Warnf("WriteText ignore conn of %s not init", conn.addr)
		return nil
	}
	return ws.WriteMessage(websocket.TextMessage, []byte(value))
}
// WriteMsg writes value through the socket's WriteJSON while holding
// writeMuetx. If conn.ws is nil (no dial has succeeded) it logs a warning and
// returns nil.
func (conn *WSConn) WriteMsg(value interface{}) (err error) {
	conn.writeMuetx.Lock()
	defer conn.writeMuetx.Unlock()

	ws := conn.ws
	if ws == nil {
		log.Warnf("WriteMsg ignore conn of %s not init", conn.addr)
		return nil
	}
	return ws.WriteJSON(value)
}
// connect dials conn.addr and publishes the result as conn.ws. It then runs
// initFn (if set) and starts a loop goroutine for the new socket. If initFn
// fails, the new socket is closed, its error is returned and no loop is
// started.
//
// connect runs both from the constructors and, on reconnect, from the loop
// goroutine while WriteText/WriteMsg may be reading conn.ws. The swap is
// therefore done under writeMuetx.
func (conn *WSConn) connect() (err error) {
	u, err := url.Parse(conn.addr)
	if err != nil {
		return err
	}

	ws, _, dialErr := websocket.DefaultDialer.Dial(u.String(), nil)
	// Publish the dial result even on failure, as before. This stops writers
	// from reusing a socket left over from a previous connection.
	conn.writeMuetx.Lock()
	conn.ws = ws
	conn.writeMuetx.Unlock()
	if dialErr != nil {
		return fmt.Errorf("connect to %s failed: %w", conn.addr, dialErr)
	}

	if conn.initFn != nil {
		if err = conn.initFn(conn); err != nil {
			_ = ws.Close() // best effort; the initFn error is what matters
			return err
		}
	}

	// Register the loop before starting it. Close's wg.Wait then cannot miss
	// it, and during reconnect the count never drops to zero while the old
	// loop (still counted) hands over to this new one.
	conn.wg.Add(1)
	go conn.loop()
	return nil
}

// loop drives one socket. It starts readLoop for conn.ws, dispatches frames to
// pongFn/messageFn, pings when the socket has been quiet for more than 5
// seconds and reconnects when readLoop asks for it. connect must call
// conn.wg.Add(1) before starting loop; loop calls Done when it returns.
func (conn *WSConn) loop() {
	defer conn.wg.Done()

	ws := conn.ws
	ch := make(chan []byte, 1024)
	// Buffered so readLoop's single final send never blocks.
	needReconn := make(chan bool, 1)
	go conn.readLoop(ws, ch, needReconn)

	ticker := time.NewTicker(time.Second * 5)
	defer ticker.Stop()
	var lastMsgTime time.Time

Out:
	for {
		select {
		case msg, ok := <-ch:
			if !ok {
				break Out
			}
			lastMsgTime = time.Now()
			if conn.pongFn != nil && conn.pongFn(msg) {
				continue
			}
			if err := conn.messageFn(msg); err != nil {
				log.Errorf("%s ws message handler failed: %s", conn.addr, err.Error())
				// Closing the socket makes readLoop's blocking ReadMessage fail,
				// so readLoop exits and reports. Previously it kept reading into
				// ch, nobody consumed it, and both goroutines deadlocked once
				// the buffer filled.
				_ = ws.Close()
				break Out
			}
		case <-ticker.C:
			if time.Since(lastMsgTime) > time.Second*5 && conn.pingFn != nil {
				if err := conn.pingFn(conn); err != nil {
					log.Errorf("ws pingFn failed: %s", err.Error())
				}
			}
		case <-conn.closeCh:
			// Close was called. Release the socket so readLoop is not left
			// blocked in ReadMessage; gorilla allows Close concurrently with reads.
			_ = ws.Close()
			return
		}
	}

	// Drain whatever readLoop still delivers so it cannot block on a full ch.
	// readLoop closes ch after sending its verdict on needReconn.
	for range ch {
	}
	if reConn := <-needReconn; !reConn {
		return
	}
	conn.reconnect()
}

// reconnect retries connect up to 100 times, one second apart, and stops early
// once Close has been called. On success, connect has already started the
// replacement loop.
func (conn *WSConn) reconnect() {
	for i := 0; i != 100; i++ {
		select {
		case <-conn.closeCh:
			return
		default:
		}
		err := conn.connect()
		if err == nil {
			return
		}
		log.Errorf("ws reconnect %d to failed: %s", i, err.Error())
		select {
		case <-conn.closeCh:
			return
		case <-time.After(time.Second):
		}
	}
}
// readLoop reads frames from ws and forwards them on ch until a read fails or
// closeCh is closed. Before returning it sends exactly one value on needConn:
// true means the caller should reconnect, false means the WSConn is closing.
// needConn must therefore have room for one value. On exit it closes ws, ch
// and needConn.
func (conn *WSConn) readLoop(ws *websocket.Conn, ch chan []byte, needConn chan bool) {
	defer func() {
		ws.Close()
		close(ch)
		close(needConn)
	}()
	for {
		select {
		case <-conn.closeCh:
			needConn <- false
			return
		default:
		}

		_, message, err := ws.ReadMessage()
		if err != nil {
			select {
			case <-conn.closeCh:
				// The read failed because the WSConn is shutting down (its
				// socket was closed). Do not ask for a reconnect.
				needConn <- false
			default:
				log.Printf("%s ws read error: %s", conn.addr, err.Error())
				needConn <- true
			}
			return
		}

		// Do not block forever on a full ch once the consumer has stopped.
		select {
		case ch <- message:
		case <-conn.closeCh:
			needConn <- false
			return
		}
	}
}