From 25388246617a202539721f167bcd18a0f3baa658 Mon Sep 17 00:00:00 2001 From: c9s Date: Thu, 27 May 2021 01:07:25 +0800 Subject: [PATCH] okex: implement basic stream --- pkg/exchange/okex/convert.go | 3 +- pkg/exchange/okex/exchange.go | 3 +- pkg/exchange/okex/stream.go | 223 ++++++++++++++++++++++++++++++++++ 3 files changed, 226 insertions(+), 3 deletions(-) create mode 100644 pkg/exchange/okex/stream.go diff --git a/pkg/exchange/okex/convert.go b/pkg/exchange/okex/convert.go index b9cd7a30e..c7df351f6 100644 --- a/pkg/exchange/okex/convert.go +++ b/pkg/exchange/okex/convert.go @@ -22,7 +22,6 @@ func toLocalSymbol(symbol string) string { return symbol } - func toGlobalTicker(marketTicker okexapi.MarketTicker) *types.Ticker { return &types.Ticker{ Time: marketTicker.Timestamp.Time(), @@ -42,7 +41,7 @@ func toGlobalBalance(balanceSummaries []okexapi.BalanceSummary) types.BalanceMap for _, balanceDetail := range balanceSummary.Details { balanceMap[balanceDetail.Currency] = types.Balance{ Currency: balanceDetail.Currency, - Available: balanceDetail.Available, + Available: balanceDetail.CashBalance, Locked: balanceDetail.Frozen, } } diff --git a/pkg/exchange/okex/exchange.go b/pkg/exchange/okex/exchange.go index 564c1d450..5bfc5658b 100644 --- a/pkg/exchange/okex/exchange.go +++ b/pkg/exchange/okex/exchange.go @@ -31,6 +31,7 @@ func New(key, secret, passphrase string) *Exchange { key: key, secret: secret, passphrase: passphrase, + client: client, } } @@ -162,7 +163,7 @@ func (e *Exchange) CancelOrders(ctx context.Context, orders ...types.Order) erro } func (e *Exchange) NewStream() types.Stream { - panic("implement me") + return NewStream(e.client) } func (e *Exchange) QueryKLines(ctx context.Context, symbol string, interval types.Interval, options types.KLineQueryOptions) ([]types.KLine, error) { diff --git a/pkg/exchange/okex/stream.go b/pkg/exchange/okex/stream.go new file mode 100644 index 000000000..09bec5f45 --- /dev/null +++ b/pkg/exchange/okex/stream.go @@ -0,0 +1,223 @@ +package okex + +import ( + "context" + "net" + "sync" + "time" + + "github.com/c9s/bbgo/pkg/exchange/okex/okexapi" + "github.com/c9s/bbgo/pkg/types" + "github.com/gorilla/websocket" +) + +//go:generate callbackgen -type Stream -interface +type Stream struct { + types.StandardStream + + Client *okexapi.RestClient + ListenKey string + Conn *websocket.Conn + connLock sync.Mutex + reconnectC chan struct{} + + connCtx context.Context + connCancel context.CancelFunc + + publicOnly bool + + klineCallbacks []func() +} + +func NewStream(client *okexapi.RestClient) *Stream { + stream := &Stream{ + Client: client, + reconnectC: make(chan struct{}, 1), + } + return stream +} + +func (s *Stream) SetPublicOnly() { + s.publicOnly = true +} + +func (s *Stream) Close() error { + return nil +} + +func (s *Stream) Connect(ctx context.Context) error { + err := s.connect(ctx) + if err != nil { + return err + } + + // start one re-connector goroutine with the base context + go s.reconnector(ctx) + + s.EmitStart() + return nil +} + +func (s *Stream) emitReconnect() { + select { + case s.reconnectC <- struct{}{}: + default: + } +} + +func (s *Stream) reconnector(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + + case <-s.reconnectC: + // ensure the previous context is cancelled + if s.connCancel != nil { + s.connCancel() + } + + log.Warnf("received reconnect signal, reconnecting...") + time.Sleep(3 * time.Second) + + if err := s.connect(ctx); err != nil { + log.WithError(err).Errorf("connect error, try to reconnect again...") + s.emitReconnect() + } + } + } +} + +func (s *Stream) dial() (*websocket.Conn, error) { + var url string + if s.publicOnly { + url = okexapi.PublicWebSocketURL + } else { + url = okexapi.PrivateWebSocketURL + } + + conn, _, err := websocket.DefaultDialer.Dial(url, nil) + if err != nil { + return nil, err + } + + // use the default ping handler + conn.SetPingHandler(nil) + + return conn, nil +} + +func (s *Stream) connect(ctx context.Context) error { + // should only start one connection one time, so we lock the mutex + s.connLock.Lock() + + // create a new context + s.connCtx, s.connCancel = context.WithCancel(ctx) + + if s.publicOnly { + log.Infof("stream is set to public only mode") + } else { + log.Infof("request listen key for creating user data stream...") + } + + // when in public mode, the listen key is an empty string + conn, err := s.dial() + if err != nil { + s.connCancel() + s.connLock.Unlock() + return err + } + + log.Infof("websocket connected") + + s.Conn = conn + s.connLock.Unlock() + + s.EmitConnect() + + go s.read(s.connCtx) + go s.ping(s.connCtx) + return nil +} + +func (s *Stream) read(ctx context.Context) { + defer func() { + if s.connCancel != nil { + s.connCancel() + } + s.EmitDisconnect() + }() + + for { + select { + + case <-ctx.Done(): + return + + default: + s.connLock.Lock() + if err := s.Conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.WithError(err).Errorf("set read deadline error: %s", err.Error()) + } + + mt, message, err := s.Conn.ReadMessage() + s.connLock.Unlock() + + if err != nil { + // if it's a network timeout error, we should re-connect + switch err := err.(type) { + + // if it's a websocket related error + case *websocket.CloseError: + if err.Code == websocket.CloseNormalClosure { + return + } + + // for unexpected close error, we should re-connect + // emit reconnect to start a new connection + s.emitReconnect() + return + + case net.Error: + log.WithError(err).Error("network error") + s.emitReconnect() + return + + default: + log.WithError(err).Error("unexpected connection error") + s.emitReconnect() + return + } + } + + // skip non-text messages + if mt != websocket.TextMessage { + continue + } + + log.Debug(string(message)) + } + } +} + +func (s *Stream) ping(ctx context.Context) { + pingTicker := time.NewTicker(15 * time.Second) + defer pingTicker.Stop() + + for { + select { + + case <-ctx.Done(): + log.Info("ping worker stopped") + return + + case <-pingTicker.C: + s.connLock.Lock() + if err := s.Conn.WriteControl(websocket.PingMessage, []byte("hb"), time.Now().Add(3*time.Second)); err != nil { + log.WithError(err).Error("ping error", err) + s.emitReconnect() + } + s.connLock.Unlock() + } + } +}