ws: implement base websocket client

This commit is contained in:
ycdesu 2021-02-27 16:46:43 +08:00
parent 03d7290e03
commit bf97af34f3
2 changed files with 145 additions and 0 deletions

98
pkg/service/websocket.go Normal file
View File

@ -0,0 +1,98 @@
package service
import (
"context"
"sync"
"time"
"github.com/gorilla/websocket"
)
//go:generate callbackgen -type WebsocketClientBase
type WebsocketClientBase struct {
baseURL string
// mu protects conn
mu sync.Mutex
conn *websocket.Conn
reconnectC chan struct{}
reconnectDuration time.Duration
connectedCallbacks []func(conn *websocket.Conn)
disconnectedCallbacks []func(conn *websocket.Conn)
messageCallbacks []func(message []byte)
errorCallbacks []func(err error)
}
func NewWebsocketClientBase(baseURL string, reconnectDuration time.Duration) *WebsocketClientBase {
return &WebsocketClientBase{
baseURL: baseURL,
reconnectC: make(chan struct{}, 1),
reconnectDuration: reconnectDuration,
}
}
func (s *WebsocketClientBase) Listen(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-s.reconnectC:
time.Sleep(s.reconnectDuration)
if err := s.connect(ctx); err != nil {
s.reconnect()
}
default:
conn := s.Conn()
mt, msg, err := conn.ReadMessage()
if err != nil {
s.reconnect()
continue
}
if mt != websocket.TextMessage {
continue
}
s.EmitMessage(msg)
}
}
}
func (s *WebsocketClientBase) Connect(ctx context.Context) error {
if err := s.connect(ctx); err != nil {
return err
}
go s.Listen(ctx)
return nil
}
func (s *WebsocketClientBase) reconnect() {
select {
case s.reconnectC <- struct{}{}:
default:
}
}
func (s *WebsocketClientBase) connect(ctx context.Context) error {
dialer := websocket.DefaultDialer
conn, _, err := dialer.DialContext(ctx, s.baseURL, nil)
if err != nil {
return err
}
s.mu.Lock()
s.conn = conn
s.mu.Unlock()
s.EmitConnected(conn)
return nil
}
func (s *WebsocketClientBase) Conn() *websocket.Conn {
s.mu.Lock()
defer s.mu.Unlock()
return s.conn
}

View File

@ -0,0 +1,47 @@
// Code generated by "callbackgen -type WebsocketClientBase"; DO NOT EDIT.
package service
import (
"github.com/gorilla/websocket"
)
func (s *WebsocketClientBase) OnConnected(cb func(conn *websocket.Conn)) {
s.connectedCallbacks = append(s.connectedCallbacks, cb)
}
func (s *WebsocketClientBase) EmitConnected(conn *websocket.Conn) {
for _, cb := range s.connectedCallbacks {
cb(conn)
}
}
func (s *WebsocketClientBase) OnDisconnected(cb func(conn *websocket.Conn)) {
s.disconnectedCallbacks = append(s.disconnectedCallbacks, cb)
}
func (s *WebsocketClientBase) EmitDisconnected(conn *websocket.Conn) {
for _, cb := range s.disconnectedCallbacks {
cb(conn)
}
}
func (s *WebsocketClientBase) OnMessage(cb func(message []byte)) {
s.messageCallbacks = append(s.messageCallbacks, cb)
}
func (s *WebsocketClientBase) EmitMessage(message []byte) {
for _, cb := range s.messageCallbacks {
cb(message)
}
}
func (s *WebsocketClientBase) OnError(cb func(err error)) {
s.errorCallbacks = append(s.errorCallbacks, cb)
}
func (s *WebsocketClientBase) EmitError(err error) {
for _, cb := range s.errorCallbacks {
cb(err)
}
}