From 147b31d81d8e084c74eb19ccb1b6da49c6ef4dd9 Mon Sep 17 00:00:00 2001 From: Edwin Date: Tue, 9 Jan 2024 10:57:33 +0800 Subject: [PATCH] pkg/exchange: refactor account stream --- pkg/exchange/okex/parse.go | 49 +++++-- pkg/exchange/okex/parse_test.go | 182 ++++++++++++++++++++++++++ pkg/exchange/okex/stream.go | 43 +++--- pkg/exchange/okex/stream_callbacks.go | 12 -- pkg/exchange/okex/stream_test.go | 14 ++ 5 files changed, 253 insertions(+), 47 deletions(-) diff --git a/pkg/exchange/okex/parse.go b/pkg/exchange/okex/parse.go index df538b373..193518474 100644 --- a/pkg/exchange/okex/parse.go +++ b/pkg/exchange/okex/parse.go @@ -44,14 +44,12 @@ func parseWebSocketEvent(in []byte) (interface{}, error) { return nil, err } if event.Event != "" { - // TODO: remove fastjson - return event, nil + return &event, nil } switch event.Arg.Channel { case ChannelAccount: - // TODO: remove fastjson - return parseAccount(v) + return parseAccount(event.Data) case ChannelBooks, ChannelBook5: var bookEvent BookEvent @@ -100,10 +98,17 @@ func parseWebSocketEvent(in []byte) (interface{}, error) { return nil, nil } +type WsEventType string + +const ( + WsEventTypeLogin = "login" + WsEventTypeError = "error" +) + type WebSocketEvent struct { - Event string `json:"event"` - Code string `json:"code,omitempty"` - Message string `json:"msg,omitempty"` + Event WsEventType `json:"event"` + Code string `json:"code,omitempty"` + Message string `json:"msg,omitempty"` Arg struct { Channel Channel `json:"channel"` InstId string `json:"instId"` @@ -112,6 +117,28 @@ type WebSocketEvent struct { ActionType ActionType `json:"action"` } +func (w *WebSocketEvent) IsValid() error { + switch w.Event { + case WsEventTypeError: + return fmt.Errorf("websocket request error, code: %s, msg: %s", w.Code, w.Message) + + case WsEventTypeLogin: + // Actually, this code is unnecessary because the events are either `Subscribe` or `Unsubscribe`, But to avoid bugs + // in the exchange, we still check. + if w.Code != "0" || len(w.Message) != 0 { + return fmt.Errorf("websocket request error, code: %s, msg: %s", w.Code, w.Message) + } + return nil + + default: + return fmt.Errorf("unexpected event type: %+v", w) + } +} + +func (w *WebSocketEvent) IsAuthenticated() bool { + return w.Event == WsEventTypeLogin && w.Code == "0" +} + type BookEvent struct { InstrumentID string Symbol string @@ -345,17 +372,15 @@ type KLineEvent struct { Channel Channel } -func parseAccount(v *fastjson.Value) (*okexapi.Account, error) { - data := v.Get("data").MarshalTo(nil) - +func parseAccount(v []byte) (*okexapi.Account, error) { var accounts []okexapi.Account - err := json.Unmarshal(data, &accounts) + err := json.Unmarshal(v, &accounts) if err != nil { return nil, err } if len(accounts) == 0 { - return nil, errors.New("empty account data") + return &okexapi.Account{}, nil } return &accounts[0], nil diff --git a/pkg/exchange/okex/parse_test.go b/pkg/exchange/okex/parse_test.go index 5ebcf1989..55c1d0e89 100644 --- a/pkg/exchange/okex/parse_test.go +++ b/pkg/exchange/okex/parse_test.go @@ -11,6 +11,126 @@ import ( "github.com/c9s/bbgo/pkg/types" ) +func Test_parseWebSocketEvent_accountEvent(t *testing.T) { + t.Run("succeeds", func(t *testing.T) { + in := ` +{ + "arg": { + "channel": "account", + "uid": "77982378738415879" + }, + "data": [ + { + "uTime": "1614846244194", + "totalEq": "91884", + "adjEq": "91884.8502560037982063", + "isoEq": "0", + "ordFroz": "0", + "imr": "0", + "mmr": "0", + "borrowFroz": "", + "notionalUsd": "", + "mgnRatio": "100000", + "details": [{ + "availBal": "", + "availEq": "1", + "ccy": "BTC", + "cashBal": "1", + "uTime": "1617279471503", + "disEq": "50559.01", + "eq": "1", + "eqUsd": "45078", + "fixedBal": "0", + "frozenBal": "0", + "interest": "0", + "isoEq": "0", + "liab": "0", + "maxLoan": "", + "mgnRatio": "", + "notionalLever": "0", + "ordFrozen": "0", + "upl": "0", + "uplLiab": "0", + "crossLiab": "0", + "isoLiab": "0", + "coinUsdPrice": "60000", + "stgyEq":"0", + "spotInUseAmt":"", + "isoUpl":"", + "borrowFroz": "" + }, + { + "availBal": "", + "availEq": "41307", + "ccy": "USDT", + "cashBal": "41307", + "uTime": "1617279471503", + "disEq": "41325", + "eq": "41307", + "eqUsd": "45078", + "fixedBal": "0", + "frozenBal": "0", + "interest": "0", + "isoEq": "0", + "liab": "0", + "maxLoan": "", + "mgnRatio": "", + "notionalLever": "0", + "ordFrozen": "0", + "upl": "0", + "uplLiab": "0", + "crossLiab": "0", + "isoLiab": "0", + "coinUsdPrice": "1.00007", + "stgyEq":"0", + "spotInUseAmt":"", + "isoUpl":"", + "borrowFroz": "" + } + ] + } + ] +} +` + + exp := &okexapi.Account{ + TotalEquityInUSD: fixedpoint.NewFromFloat(91884), + UpdateTime: "1614846244194", + Details: []okexapi.BalanceDetail{ + { + Currency: "BTC", + Available: fixedpoint.NewFromFloat(1), + CashBalance: fixedpoint.NewFromFloat(1), + OrderFrozen: fixedpoint.Zero, + Frozen: fixedpoint.Zero, + Equity: fixedpoint.One, + EquityInUSD: fixedpoint.NewFromFloat(45078), + UpdateTime: types.NewMillisecondTimestampFromInt(1617279471503), + UnrealizedProfitAndLoss: fixedpoint.Zero, + }, + { + Currency: "USDT", + Available: fixedpoint.NewFromFloat(41307), + CashBalance: fixedpoint.NewFromFloat(41307), + OrderFrozen: fixedpoint.Zero, + Frozen: fixedpoint.Zero, + Equity: fixedpoint.NewFromFloat(41307), + EquityInUSD: fixedpoint.NewFromFloat(45078), + UpdateTime: types.NewMillisecondTimestampFromInt(1617279471503), + UnrealizedProfitAndLoss: fixedpoint.Zero, + }, + }, + } + + res, err := parseWebSocketEvent([]byte(in)) + assert.NoError(t, err) + event, ok := res.(*okexapi.Account) + assert.True(t, ok) + assert.Equal(t, exp, event) + }) + +} + func TestParsePriceVolumeOrderSliceJSON(t *testing.T) { t.Run("snapshot", func(t *testing.T) { in := ` @@ -668,3 +788,65 @@ func Test_toGlobalTrade(t *testing.T) { assert.ErrorContains(t, err, "unexpected inst id") }) } + +func TestWebSocketEvent_IsValid(t *testing.T) { + t.Run("op login event", func(t *testing.T) { + input := `{ + "event": "login", + "code": "0", + "msg": "", + "connId": "a4d3ae55" +}` + res, err := parseWebSocketEvent([]byte(input)) + assert.NoError(t, err) + opEvent, ok := res.(*WebSocketEvent) + assert.True(t, ok) + assert.Equal(t, WebSocketEvent{ + Event: WsEventTypeLogin, + Code: "0", + Message: "", + }, *opEvent) + + assert.NoError(t, opEvent.IsValid()) + }) + + t.Run("op error event", func(t *testing.T) { + input := `{ + "event": "error", + "code": "60009", + "msg": "Login failed.", + "connId": "a4d3ae55" +}` + res, err := parseWebSocketEvent([]byte(input)) + assert.NoError(t, err) + opEvent, ok := res.(*WebSocketEvent) + assert.True(t, ok) + assert.Equal(t, WebSocketEvent{ + Event: WsEventTypeError, + Code: "60009", + Message: "Login failed.", + }, *opEvent) + + assert.ErrorContains(t, opEvent.IsValid(), "request error") + }) + + t.Run("unexpected event", func(t *testing.T) { + input := `{ + "event": "test gg", + "code": "60009", + "msg": "unexpected", + "connId": "a4d3ae55" +}` + res, err := parseWebSocketEvent([]byte(input)) + assert.NoError(t, err) + opEvent, ok := res.(*WebSocketEvent) + assert.True(t, ok) + assert.Equal(t, WebSocketEvent{ + Event: "test gg", + Code: "60009", + Message: "unexpected", + }, *opEvent) + + assert.ErrorContains(t, opEvent.IsValid(), "unexpected event type") + }) +} diff --git a/pkg/exchange/okex/stream.go b/pkg/exchange/okex/stream.go index 34d6c853c..1d98e739b 100644 --- a/pkg/exchange/okex/stream.go +++ b/pkg/exchange/okex/stream.go @@ -35,7 +35,6 @@ type Stream struct { // public callbacks kLineEventCallbacks []func(candle KLineEvent) bookEventCallbacks []func(book BookEvent) - eventCallbacks []func(event WebSocketEvent) accountEventCallbacks []func(account okexapi.Account) orderDetailsEventCallbacks []func(orderDetails []okexapi.OrderDetails) marketTradeEventCallbacks []func(tradeDetail []MarketTradeEvent) @@ -56,8 +55,8 @@ func NewStream(client *okexapi.RestClient) *Stream { stream.OnAccountEvent(stream.handleAccountEvent) stream.OnMarketTradeEvent(stream.handleMarketTradeEvent) stream.OnOrderDetailsEvent(stream.handleOrderDetailsEvent) - stream.OnEvent(stream.handleEvent) stream.OnConnect(stream.handleConnect) + stream.OnAuth(stream.handleAuth) return stream } @@ -113,26 +112,19 @@ func (s *Stream) handleConnect() { } } -func (s *Stream) handleEvent(event WebSocketEvent) { - switch event.Event { - case "login": - if event.Code == "0" { - s.EmitAuth() - var subs = []WebsocketSubscription{ - {Channel: "account"}, - {Channel: "orders", InstrumentType: string(okexapi.InstrumentTypeSpot)}, - } +func (s *Stream) handleAuth() { + var subs = []WebsocketSubscription{ + {Channel: ChannelAccount}, + {Channel: "orders", InstrumentType: string(okexapi.InstrumentTypeSpot)}, + } - log.Infof("subscribing private channels: %+v", subs) - err := s.Conn.WriteJSON(WebsocketOp{ - Op: "subscribe", - Args: subs, - }) - - if err != nil { - log.WithError(err).Error("private channel subscribe error") - } - } + log.Infof("subscribing private channels: %+v", subs) + err := s.Conn.WriteJSON(WebsocketOp{ + Op: "subscribe", + Args: subs, + }) + if err != nil { + log.WithError(err).Error("private channel subscribe error") } } @@ -160,7 +152,7 @@ func (s *Stream) handleOrderDetailsEvent(orderDetails []okexapi.OrderDetails) { func (s *Stream) handleAccountEvent(account okexapi.Account) { balances := toGlobalBalance(&account) - s.EmitBalanceSnapshot(balances) + s.EmitBalanceUpdate(balances) } func (s *Stream) handleBookEvent(data BookEvent) { @@ -211,7 +203,12 @@ func (s *Stream) createEndpoint(ctx context.Context) (string, error) { func (s *Stream) dispatchEvent(e interface{}) { switch et := e.(type) { case *WebSocketEvent: - s.EmitEvent(*et) + if err := et.IsValid(); err != nil { + log.Errorf("invalid event: %v", err) + } + if et.IsAuthenticated() { + s.EmitAuth() + } case *BookEvent: // there's "books" for 400 depth and books5 for 5 depth diff --git a/pkg/exchange/okex/stream_callbacks.go b/pkg/exchange/okex/stream_callbacks.go index b735d0985..089da09aa 100644 --- a/pkg/exchange/okex/stream_callbacks.go +++ b/pkg/exchange/okex/stream_callbacks.go @@ -26,16 +26,6 @@ func (s *Stream) EmitBookEvent(book BookEvent) { } } -func (s *Stream) OnEvent(cb func(event WebSocketEvent)) { - s.eventCallbacks = append(s.eventCallbacks, cb) -} - -func (s *Stream) EmitEvent(event WebSocketEvent) { - for _, cb := range s.eventCallbacks { - cb(event) - } -} - func (s *Stream) OnAccountEvent(cb func(account okexapi.Account)) { s.accountEventCallbacks = append(s.accountEventCallbacks, cb) } @@ -71,8 +61,6 @@ type StreamEventHub interface { OnBookEvent(cb func(book BookEvent)) - OnEvent(cb func(event WebSocketEvent)) - OnAccountEvent(cb func(account okexapi.Account)) OnOrderDetailsEvent(cb func(orderDetails []okexapi.OrderDetails)) diff --git a/pkg/exchange/okex/stream_test.go b/pkg/exchange/okex/stream_test.go index b9b758a6a..cf0125ded 100644 --- a/pkg/exchange/okex/stream_test.go +++ b/pkg/exchange/okex/stream_test.go @@ -31,6 +31,20 @@ func TestStream(t *testing.T) { t.Skip() s := getTestClientOrSkip(t) + t.Run("account test", func(t *testing.T) { + err := s.Connect(context.Background()) + assert.NoError(t, err) + + s.OnBalanceUpdate(func(balances types.BalanceMap) { + t.Log("got snapshot", balances) + }) + s.OnBookUpdate(func(book types.SliceOrderBook) { + t.Log("got update", book) + }) + c := make(chan struct{}) + <-c + }) + t.Run("book test", func(t *testing.T) { s.Subscribe(types.BookChannel, "BTCUSDT", types.SubscribeOptions{ Depth: types.DepthLevel50,