mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-26 00:35:15 +00:00
ws: implement base websocket client
This commit is contained in:
parent
03d7290e03
commit
bf97af34f3
98
pkg/service/websocket.go
Normal file
98
pkg/service/websocket.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
//go:generate callbackgen -type WebsocketClientBase
|
||||
type WebsocketClientBase struct {
|
||||
baseURL string
|
||||
|
||||
// mu protects conn
|
||||
mu sync.Mutex
|
||||
conn *websocket.Conn
|
||||
reconnectC chan struct{}
|
||||
reconnectDuration time.Duration
|
||||
|
||||
connectedCallbacks []func(conn *websocket.Conn)
|
||||
disconnectedCallbacks []func(conn *websocket.Conn)
|
||||
messageCallbacks []func(message []byte)
|
||||
errorCallbacks []func(err error)
|
||||
}
|
||||
|
||||
func NewWebsocketClientBase(baseURL string, reconnectDuration time.Duration) *WebsocketClientBase {
|
||||
return &WebsocketClientBase{
|
||||
baseURL: baseURL,
|
||||
reconnectC: make(chan struct{}, 1),
|
||||
reconnectDuration: reconnectDuration,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-s.reconnectC:
|
||||
time.Sleep(s.reconnectDuration)
|
||||
if err := s.connect(ctx); err != nil {
|
||||
s.reconnect()
|
||||
}
|
||||
default:
|
||||
conn := s.Conn()
|
||||
mt, msg, err := conn.ReadMessage()
|
||||
|
||||
if err != nil {
|
||||
s.reconnect()
|
||||
continue
|
||||
}
|
||||
|
||||
if mt != websocket.TextMessage {
|
||||
continue
|
||||
}
|
||||
|
||||
s.EmitMessage(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) Connect(ctx context.Context) error {
|
||||
if err := s.connect(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
go s.Listen(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) reconnect() {
|
||||
select {
|
||||
case s.reconnectC <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) connect(ctx context.Context) error {
|
||||
dialer := websocket.DefaultDialer
|
||||
conn, _, err := dialer.DialContext(ctx, s.baseURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.conn = conn
|
||||
s.mu.Unlock()
|
||||
|
||||
s.EmitConnected(conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) Conn() *websocket.Conn {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.conn
|
||||
}
|
47
pkg/service/websocketclientbase_callbacks.go
Normal file
47
pkg/service/websocketclientbase_callbacks.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
// Code generated by "callbackgen -type WebsocketClientBase"; DO NOT EDIT.
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func (s *WebsocketClientBase) OnConnected(cb func(conn *websocket.Conn)) {
|
||||
s.connectedCallbacks = append(s.connectedCallbacks, cb)
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) EmitConnected(conn *websocket.Conn) {
|
||||
for _, cb := range s.connectedCallbacks {
|
||||
cb(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) OnDisconnected(cb func(conn *websocket.Conn)) {
|
||||
s.disconnectedCallbacks = append(s.disconnectedCallbacks, cb)
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) EmitDisconnected(conn *websocket.Conn) {
|
||||
for _, cb := range s.disconnectedCallbacks {
|
||||
cb(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) OnMessage(cb func(message []byte)) {
|
||||
s.messageCallbacks = append(s.messageCallbacks, cb)
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) EmitMessage(message []byte) {
|
||||
for _, cb := range s.messageCallbacks {
|
||||
cb(message)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) OnError(cb func(err error)) {
|
||||
s.errorCallbacks = append(s.errorCallbacks, cb)
|
||||
}
|
||||
|
||||
func (s *WebsocketClientBase) EmitError(err error) {
|
||||
for _, cb := range s.errorCallbacks {
|
||||
cb(err)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user