diff --git a/pkg/exchange/okex/kline_stream.go b/pkg/exchange/okex/kline_stream.go index e9b2ef356..b52727b25 100644 --- a/pkg/exchange/okex/kline_stream.go +++ b/pkg/exchange/okex/kline_stream.go @@ -47,19 +47,15 @@ func (s *KLineStream) handleConnect() { subs = append(subs, sub) } - if len(subs) == 0 { - return - } + subscribe(s.Conn, subs) +} - log.Infof("subscribing channels: %+v", subs) - err := s.Conn.WriteJSON(WebsocketOp{ - Op: "subscribe", - Args: subs, - }) - - if err != nil { - log.WithError(err).Error("subscribe error") +func (s *KLineStream) Connect(ctx context.Context) error { + if len(s.StandardStream.Subscriptions) == 0 { + log.Info("no subscriptions in kline") + return nil } + return s.StandardStream.Connect(ctx) } func (s *KLineStream) handleKLineEvent(k KLineEvent) { @@ -85,3 +81,14 @@ func (s *KLineStream) dispatchEvent(e interface{}) { s.EmitKLineEvent(*et) } } + +func (s *KLineStream) Unsubscribe() { + // errors are handled in the syncSubscriptions, so they are skipped here. + if len(s.StandardStream.Subscriptions) != 0 { + _ = syncSubscriptions(s.StandardStream.Conn, s.StandardStream.Subscriptions, WsEventTypeUnsubscribe) + } + s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) { + // clear the subscriptions + return []types.Subscription{}, nil + }) +} diff --git a/pkg/exchange/okex/stream.go b/pkg/exchange/okex/stream.go index 7d602c668..106d59454 100644 --- a/pkg/exchange/okex/stream.go +++ b/pkg/exchange/okex/stream.go @@ -3,10 +3,12 @@ package okex import ( "context" "fmt" - "golang.org/x/time/rate" "strconv" "time" + "github.com/gorilla/websocket" + "golang.org/x/time/rate" + "github.com/c9s/bbgo/pkg/exchange/okex/okexapi" "github.com/c9s/bbgo/pkg/exchange/retry" "github.com/c9s/bbgo/pkg/types" @@ -67,18 +69,20 @@ func NewStream(client *okexapi.RestClient, balanceProvider types.ExchangeAccount stream.OnOrderTradesEvent(stream.handleOrderDetailsEvent) stream.OnConnect(stream.handleConnect) stream.OnAuth(stream.subscribePrivateChannels(stream.emitBalanceSnapshot)) + stream.kLineStream.OnKLineClosed(stream.EmitKLineClosed) + stream.kLineStream.OnKLine(stream.EmitKLine) return stream } -func (s *Stream) syncSubscriptions(opType WsEventType) error { +func syncSubscriptions(conn *websocket.Conn, subscriptions []types.Subscription, opType WsEventType) error { if opType != WsEventTypeUnsubscribe && opType != WsEventTypeSubscribe { return fmt.Errorf("unexpected subscription type: %v", opType) } logger := log.WithField("opType", opType) var topics []WebsocketSubscription - for _, subscription := range s.Subscriptions { + for _, subscription := range subscriptions { topic, err := convertSubscription(subscription) if err != nil { logger.WithError(err).Errorf("convert error, subscription: %+v", subscription) @@ -89,7 +93,7 @@ func (s *Stream) syncSubscriptions(opType WsEventType) error { } logger.Infof("%s channels: %+v", opType, topics) - if err := s.Conn.WriteJSON(WebsocketOp{ + if err := conn.WriteJSON(WebsocketOp{ Op: opType, Args: topics, }); err != nil { @@ -102,11 +106,47 @@ func (s *Stream) syncSubscriptions(opType WsEventType) error { func (s *Stream) Unsubscribe() { // errors are handled in the syncSubscriptions, so they are skipped here. - _ = s.syncSubscriptions(WsEventTypeUnsubscribe) + _ = syncSubscriptions(s.StandardStream.Conn, s.StandardStream.Subscriptions, WsEventTypeUnsubscribe) s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) { // clear the subscriptions return []types.Subscription{}, nil }) + + s.kLineStream.Unsubscribe() +} + +func (s *Stream) Connect(ctx context.Context) error { + if err := s.StandardStream.Connect(ctx); err != nil { + return err + } + if err := s.kLineStream.Connect(ctx); err != nil { + return err + } + return nil +} + +func (s *Stream) Subscribe(channel types.Channel, symbol string, options types.SubscribeOptions) { + if channel == types.KLineChannel { + s.kLineStream.Subscribe(channel, symbol, options) + } else { + s.StandardStream.Subscribe(channel, symbol, options) + } +} + +func subscribe(conn *websocket.Conn, subs []WebsocketSubscription) { + if len(subs) == 0 { + return + } + + log.Infof("subscribing channels: %+v", subs) + err := conn.WriteJSON(WebsocketOp{ + Op: "subscribe", + Args: subs, + }) + + if err != nil { + log.WithError(err).Error("subscribe error") + } } func (s *Stream) handleConnect() { @@ -121,19 +161,7 @@ func (s *Stream) handleConnect() { subs = append(subs, sub) } - if len(subs) == 0 { - return - } - - log.Infof("subscribing channels: %+v", subs) - err := s.Conn.WriteJSON(WebsocketOp{ - Op: "subscribe", - Args: subs, - }) - - if err != nil { - log.WithError(err).Error("subscribe error") - } + subscribe(s.StandardStream.Conn, subs) } else { // login as private channel // sign example: diff --git a/pkg/exchange/okex/stream_test.go b/pkg/exchange/okex/stream_test.go index 42a50d03a..4ebfc911d 100644 --- a/pkg/exchange/okex/stream_test.go +++ b/pkg/exchange/okex/stream_test.go @@ -67,6 +67,34 @@ func TestStream(t *testing.T) { <-c }) + t.Run("book && kline test", func(t *testing.T) { + s.Subscribe(types.BookChannel, "BTCUSDT", types.SubscribeOptions{ + Depth: types.DepthLevel400, + }) + s.Subscribe(types.KLineChannel, "BTCUSDT", types.SubscribeOptions{ + Interval: types.Interval1m, + }) + s.SetPublicOnly() + err := s.Connect(context.Background()) + assert.NoError(t, err) + + s.OnBookSnapshot(func(book types.SliceOrderBook) { + t.Log("got snapshot", book) + }) + s.OnBookUpdate(func(book types.SliceOrderBook) { + t.Log("got update", book) + }) + s.OnKLine(func(kline types.KLine) { + t.Log("kline", kline) + }) + s.OnKLineClosed(func(kline types.KLine) { + t.Log("kline closed", kline) + }) + + c := make(chan struct{}) + <-c + }) + t.Run("market trade test", func(t *testing.T) { s.Subscribe(types.MarketTradeChannel, "BTCUSDT", types.SubscribeOptions{}) s.SetPublicOnly()