mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-23 07:15:15 +00:00
99 lines
1.8 KiB
Go
99 lines
1.8 KiB
Go
|
package service
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/gorilla/websocket"
|
||
|
)
|
||
|
|
||
|
//go:generate callbackgen -type WebsocketClientBase
|
||
|
type WebsocketClientBase struct {
|
||
|
baseURL string
|
||
|
|
||
|
// mu protects conn
|
||
|
mu sync.Mutex
|
||
|
conn *websocket.Conn
|
||
|
reconnectC chan struct{}
|
||
|
reconnectDuration time.Duration
|
||
|
|
||
|
connectedCallbacks []func(conn *websocket.Conn)
|
||
|
disconnectedCallbacks []func(conn *websocket.Conn)
|
||
|
messageCallbacks []func(message []byte)
|
||
|
errorCallbacks []func(err error)
|
||
|
}
|
||
|
|
||
|
func NewWebsocketClientBase(baseURL string, reconnectDuration time.Duration) *WebsocketClientBase {
|
||
|
return &WebsocketClientBase{
|
||
|
baseURL: baseURL,
|
||
|
reconnectC: make(chan struct{}, 1),
|
||
|
reconnectDuration: reconnectDuration,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *WebsocketClientBase) Listen(ctx context.Context) {
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return
|
||
|
case <-s.reconnectC:
|
||
|
time.Sleep(s.reconnectDuration)
|
||
|
if err := s.connect(ctx); err != nil {
|
||
|
s.reconnect()
|
||
|
}
|
||
|
default:
|
||
|
conn := s.Conn()
|
||
|
mt, msg, err := conn.ReadMessage()
|
||
|
|
||
|
if err != nil {
|
||
|
s.reconnect()
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
if mt != websocket.TextMessage {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
s.EmitMessage(msg)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *WebsocketClientBase) Connect(ctx context.Context) error {
|
||
|
if err := s.connect(ctx); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
go s.Listen(ctx)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *WebsocketClientBase) reconnect() {
|
||
|
select {
|
||
|
case s.reconnectC <- struct{}{}:
|
||
|
default:
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *WebsocketClientBase) connect(ctx context.Context) error {
|
||
|
dialer := websocket.DefaultDialer
|
||
|
conn, _, err := dialer.DialContext(ctx, s.baseURL, nil)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
s.mu.Lock()
|
||
|
s.conn = conn
|
||
|
s.mu.Unlock()
|
||
|
|
||
|
s.EmitConnected(conn)
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *WebsocketClientBase) Conn() *websocket.Conn {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
return s.conn
|
||
|
}
|