pkg/exchange: support kline subscription on stream

This commit is contained in:
Edwin 2023-10-31 21:53:13 +08:00
parent 7c1060c208
commit 102b662f7c
5 changed files with 451 additions and 11 deletions

View File

@ -5,17 +5,11 @@ import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/gorilla/websocket"
"strings"
"github.com/c9s/bbgo/pkg/exchange/bitget/bitgetapi"
"github.com/c9s/bbgo/pkg/types"
"github.com/gorilla/websocket"
)
const (
// Client should keep ping the server in every 30 seconds. Server will close the connections which has no ping over
// 120 seconds(even when the client is still receiving data from the server)
pingInterval = 30 * time.Second
)
var (
@ -29,11 +23,15 @@ type Stream struct {
bookEventCallbacks []func(o BookEvent)
marketTradeEventCallbacks []func(o MarketTradeEvent)
KLineEventCallbacks []func(o KLineEvent)
lastCandle map[string]types.KLine
}
func NewStream() *Stream {
stream := &Stream{
StandardStream: types.NewStandardStream(),
lastCandle: map[string]types.KLine{},
}
stream.SetEndpointCreator(stream.createEndpoint)
@ -44,6 +42,7 @@ func NewStream() *Stream {
stream.OnBookEvent(stream.handleBookEvent)
stream.OnMarketTradeEvent(stream.handleMaretTradeEvent)
stream.OnKLineEvent(stream.handleKLineEvent)
return stream
}
@ -108,6 +107,9 @@ func (s *Stream) dispatchEvent(event interface{}) {
case *MarketTradeEvent:
s.EmitMarketTradeEvent(*e)
case *KLineEvent:
s.EmitKLineEvent(*e)
case []byte:
// We only handle the 'pong' case. Others are unexpected.
if !bytes.Equal(e, pongBytes) {
@ -171,6 +173,15 @@ func convertSubscription(sub types.Subscription) (WsArg, error) {
case types.MarketTradeChannel:
arg.Channel = ChannelTrade
return arg, nil
case types.KLineChannel:
interval, found := toLocalInterval[sub.Options.Interval]
if !found {
return WsArg{}, fmt.Errorf("interval %s not supported on KLine subscription", sub.Options.Interval)
}
arg.Channel = ChannelType(interval)
return arg, nil
}
return arg, fmt.Errorf("unsupported stream channel: %s", sub.Channel)
@ -200,7 +211,8 @@ func parseEvent(in []byte) (interface{}, error) {
return &event, nil
}
switch event.Arg.Channel {
ch := event.Arg.Channel
switch ch {
case ChannelOrderBook, ChannelOrderBook5, ChannelOrderBook15:
var book BookEvent
err = json.Unmarshal(event.Data, &book.Events)
@ -222,9 +234,26 @@ func parseEvent(in []byte) (interface{}, error) {
trade.actionType = event.Action
trade.instId = event.Arg.InstId
return &trade, nil
default:
// handle the `KLine` case here to avoid complicating the code structure.
if strings.HasPrefix(string(ch), "candle") {
var kline KLineEvent
err = json.Unmarshal(event.Data, &kline.Events)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal data into KLineEvent, Arg: %+v Data: %s, err: %w", event.Arg, string(event.Data), err)
}
kline.actionType = event.Action
kline.channel = ch
kline.instId = event.Arg.InstId
return &kline, nil
}
// return an error for any other case
return nil, fmt.Errorf("unhandled websocket event: %+v", string(in))
}
}
func (s *Stream) handleMaretTradeEvent(m MarketTradeEvent) {
@ -242,3 +271,28 @@ func (s *Stream) handleMaretTradeEvent(m MarketTradeEvent) {
s.EmitMarketTrade(globalTrade)
}
}
func (s *Stream) handleKLineEvent(k KLineEvent) {
if k.actionType == ActionTypeSnapshot {
// we don't support snapshot event
return
}
interval, found := toGlobalInterval[string(k.channel)]
if !found {
log.Errorf("unexpected interval %s on KLine subscription", k.channel)
return
}
for _, kline := range k.Events {
last, ok := s.lastCandle[k.CacheKey()]
if ok && kline.StartTime.Time().After(last.StartTime.Time()) {
last.Closed = true
s.EmitKLineClosed(last)
}
kLine := kline.ToGlobal(interval, k.instId)
s.EmitKLine(kLine)
s.lastCandle[k.CacheKey()] = kLine
}
}

View File

@ -23,3 +23,13 @@ func (s *Stream) EmitMarketTradeEvent(o MarketTradeEvent) {
cb(o)
}
}
func (s *Stream) OnKLineEvent(cb func(o KLineEvent)) {
s.KLineEventCallbacks = append(s.KLineEventCallbacks, cb)
}
func (s *Stream) EmitKLineEvent(o KLineEvent) {
for _, cb := range s.KLineEventCallbacks {
cb(o)
}
}

View File

@ -106,6 +106,22 @@ func TestStream(t *testing.T) {
<-c
})
t.Run("kline test", func(t *testing.T) {
s.Subscribe(types.KLineChannel, "BTCUSDT", types.SubscribeOptions{Interval: types.Interval1w})
s.SetPublicOnly()
err := s.Connect(context.Background())
assert.NoError(t, err)
s.OnKLine(func(kline types.KLine) {
t.Log("got update", kline)
})
s.OnKLineClosed(func(kline types.KLine) {
t.Log("got closed update", kline)
})
c := make(chan struct{})
<-c
})
}
func TestStream_parseWebSocketEvent(t *testing.T) {
@ -453,6 +469,174 @@ func Test_parseWebSocketEvent_MarketTrade(t *testing.T) {
})
}
func Test_parseWebSocketEvent_KLine(t *testing.T) {
t.Run("KLine event", func(t *testing.T) {
input := `{
"action":"%s",
"arg":{
"instType":"sp",
"channel":"candle5m",
"instId":"BTCUSDT"
},
"data":[
["1698744600000","34361.49","34458.98","34355.53","34416.41","99.6631"]
],
"ts":1697697791670
}`
eventFn := func(in string, actionType ActionType) {
res, err := parseWebSocketEvent([]byte(in))
assert.NoError(t, err)
kline, ok := res.(*KLineEvent)
assert.True(t, ok)
assert.Equal(t, KLineEvent{
channel: "candle5m",
Events: KLineSlice{
{
StartTime: types.NewMillisecondTimestampFromInt(1698744600000),
OpenPrice: fixedpoint.NewFromFloat(34361.49),
HighestPrice: fixedpoint.NewFromFloat(34458.98),
LowestPrice: fixedpoint.NewFromFloat(34355.53),
ClosePrice: fixedpoint.NewFromFloat(34416.41),
Volume: fixedpoint.NewFromFloat(99.6631),
},
},
actionType: actionType,
instId: "BTCUSDT",
}, *kline)
}
t.Run("snapshot type", func(t *testing.T) {
snapshotInput := fmt.Sprintf(input, ActionTypeSnapshot)
eventFn(snapshotInput, ActionTypeSnapshot)
})
t.Run("update type", func(t *testing.T) {
snapshotInput := fmt.Sprintf(input, ActionTypeUpdate)
eventFn(snapshotInput, ActionTypeUpdate)
})
})
t.Run("Unexpected length of kline", func(t *testing.T) {
input := `{
"action":"%s",
"arg":{
"instType":"sp",
"channel":"candle5m",
"instId":"BTCUSDT"
},
"data":[
["1698744600000","34361.45","34458.98","34355.53","34416.41","99.6631", "123456"]
],
"ts":1697697791670
}`
_, err := parseWebSocketEvent([]byte(input))
assert.ErrorContains(t, err, "unexpected kline length")
})
t.Run("Unexpected timestamp", func(t *testing.T) {
input := `{
"action":"%s",
"arg":{
"instType":"sp",
"channel":"candle5m",
"instId":"BTCUSDT"
},
"data":[
["timestamp","34361.49","34458.98","34355.53","34416.41","99.6631"]
],
"ts":1697697791670
}`
_, err := parseWebSocketEvent([]byte(input))
assert.ErrorContains(t, err, "timestamp")
})
t.Run("Unexpected open price", func(t *testing.T) {
input := `{
"action":"%s",
"arg":{
"instType":"sp",
"channel":"candle5m",
"instId":"BTCUSDT"
},
"data":[
["1698744600000","1p","34458.98","34355.53","34416.41","99.6631"]
],
"ts":1697697791670
}`
_, err := parseWebSocketEvent([]byte(input))
assert.ErrorContains(t, err, "open price")
})
t.Run("Unexpected highest price", func(t *testing.T) {
input := `{
"action":"%s",
"arg":{
"instType":"sp",
"channel":"candle5m",
"instId":"BTCUSDT"
},
"data":[
["1698744600000","34361.45","3p","34355.53","34416.41","99.6631"]
],
"ts":1697697791670
}`
_, err := parseWebSocketEvent([]byte(input))
assert.ErrorContains(t, err, "highest price")
})
t.Run("Unexpected lowest price", func(t *testing.T) {
input := `{
"action":"%s",
"arg":{
"instType":"sp",
"channel":"candle5m",
"instId":"BTCUSDT"
},
"data":[
["1698744600000","34361.45","34458.98","1p","34416.41","99.6631"]
],
"ts":1697697791670
}`
_, err := parseWebSocketEvent([]byte(input))
assert.ErrorContains(t, err, "lowest price")
})
t.Run("Unexpected close price", func(t *testing.T) {
input := `{
"action":"%s",
"arg":{
"instType":"sp",
"channel":"candle5m",
"instId":"BTCUSDT"
},
"data":[
["1698744600000","34361.45","34458.98","34355.53","1c","99.6631"]
],
"ts":1697697791670
}`
_, err := parseWebSocketEvent([]byte(input))
assert.ErrorContains(t, err, "close price")
})
t.Run("Unexpected volume", func(t *testing.T) {
input := `{
"action":"%s",
"arg":{
"instType":"sp",
"channel":"candle5m",
"instId":"BTCUSDT"
},
"data":[
["1698744600000","34361.45","34458.98","34355.53","34416.41", "1v"]
],
"ts":1697697791670
}`
_, err := parseWebSocketEvent([]byte(input))
assert.ErrorContains(t, err, "volume")
})
}
func Test_convertSubscription(t *testing.T) {
t.Run("BookChannel.ChannelOrderBook5", func(t *testing.T) {
res, err := convertSubscription(types.Subscription{
@ -512,4 +696,21 @@ func Test_convertSubscription(t *testing.T) {
InstId: "BTCUSDT",
}, res)
})
t.Run("CandleChannel", func(t *testing.T) {
for gInterval, localInterval := range toLocalInterval {
res, err := convertSubscription(types.Subscription{
Symbol: "BTCUSDT",
Channel: types.KLineChannel,
Options: types.SubscribeOptions{
Interval: gInterval,
},
})
assert.NoError(t, err)
assert.Equal(t, WsArg{
InstType: instSp,
Channel: ChannelType(localInterval),
InstId: "BTCUSDT",
}, res)
}
})
}

View File

@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
@ -260,3 +261,134 @@ type MarketTradeEvent struct {
actionType ActionType
instId string
}
var (
toLocalInterval = map[types.Interval]string{
types.Interval1m: "candle1m",
types.Interval5m: "candle5m",
types.Interval15m: "candle15m",
types.Interval30m: "candle30m",
types.Interval1h: "candle1H",
types.Interval4h: "candle4H",
types.Interval12h: "candle12H",
types.Interval1d: "candle1D",
types.Interval1w: "candle1W",
}
toGlobalInterval = map[string]types.Interval{
"candle1m": types.Interval1m,
"candle5m": types.Interval5m,
"candle15m": types.Interval15m,
"candle30m": types.Interval30m,
"candle1H": types.Interval1h,
"candle4H": types.Interval4h,
"candle12H": types.Interval12h,
"candle1D": types.Interval1d,
"candle1W": types.Interval1w,
}
)
type KLine struct {
StartTime types.MillisecondTimestamp
OpenPrice fixedpoint.Value
HighestPrice fixedpoint.Value
LowestPrice fixedpoint.Value
ClosePrice fixedpoint.Value
Volume fixedpoint.Value
}
func (k KLine) ToGlobal(interval types.Interval, symbol string) types.KLine {
startTime := k.StartTime.Time()
return types.KLine{
Exchange: types.ExchangeBitget,
Symbol: symbol,
StartTime: types.Time(startTime),
EndTime: types.Time(startTime.Add(interval.Duration() - time.Millisecond)),
Interval: interval,
Open: k.OpenPrice,
Close: k.ClosePrice,
High: k.HighestPrice,
Low: k.LowestPrice,
Volume: k.Volume,
QuoteVolume: fixedpoint.Zero, // not supported
TakerBuyBaseAssetVolume: fixedpoint.Zero, // not supported
TakerBuyQuoteAssetVolume: fixedpoint.Zero, // not supported
LastTradeID: 0, // not supported
NumberOfTrades: 0, // not supported
Closed: false,
}
}
type KLineSlice []KLine
func (m *KLineSlice) UnmarshalJSON(b []byte) error {
if m == nil {
return errors.New("nil pointer of kline slice")
}
s, err := parseKLineSliceJSON(b)
if err != nil {
return err
}
*m = s
return nil
}
// parseKLineSliceJSON tries to parse a 2 dimensional string array into a KLineSlice
//
// [
//
// ["1597026383085", "8533.02", "8553.74", "8527.17", "8548.26", "45247"]
// ]
func parseKLineSliceJSON(in []byte) (slice KLineSlice, err error) {
var rawKLines [][]json.RawMessage
err = json.Unmarshal(in, &rawKLines)
if err != nil {
return slice, err
}
for _, raw := range rawKLines {
if len(raw) != 6 {
return nil, fmt.Errorf("unexpected kline length: %d, data: %q", len(raw), raw)
}
var kline KLine
if err = json.Unmarshal(raw[0], &kline.StartTime); err != nil {
return nil, fmt.Errorf("failed to unmarshal into timestamp: %q", raw[0])
}
if err = json.Unmarshal(raw[1], &kline.OpenPrice); err != nil {
return nil, fmt.Errorf("failed to unmarshal into open price: %q", raw[1])
}
if err = json.Unmarshal(raw[2], &kline.HighestPrice); err != nil {
return nil, fmt.Errorf("failed to unmarshal into highest price: %q", raw[2])
}
if err = json.Unmarshal(raw[3], &kline.LowestPrice); err != nil {
return nil, fmt.Errorf("failed to unmarshal into lowest price: %q", raw[3])
}
if err = json.Unmarshal(raw[4], &kline.ClosePrice); err != nil {
return nil, fmt.Errorf("failed to unmarshal into close price: %q", raw[4])
}
if err = json.Unmarshal(raw[5], &kline.Volume); err != nil {
return nil, fmt.Errorf("failed to unmarshal into volume: %q", raw[5])
}
slice = append(slice, kline)
}
return slice, nil
}
type KLineEvent struct {
Events KLineSlice
// internal use
actionType ActionType
channel ChannelType
instId string
}
func (k KLineEvent) CacheKey() string {
// e.q: candle5m.BTCUSDT
return fmt.Sprintf("%s.%s", k.channel, k.instId)
}

View File

@ -0,0 +1,43 @@
package bitget
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
)
func TestKLine_ToGlobal(t *testing.T) {
startTime := int64(1698744600000)
interval := types.Interval1m
k := KLine{
StartTime: types.NewMillisecondTimestampFromInt(startTime),
OpenPrice: fixedpoint.NewFromFloat(34361.49),
HighestPrice: fixedpoint.NewFromFloat(34458.98),
LowestPrice: fixedpoint.NewFromFloat(34355.53),
ClosePrice: fixedpoint.NewFromFloat(34416.41),
Volume: fixedpoint.NewFromFloat(99.6631),
}
assert.Equal(t, types.KLine{
Exchange: types.ExchangeBitget,
Symbol: "BTCUSDT",
StartTime: types.Time(types.NewMillisecondTimestampFromInt(startTime).Time()),
EndTime: types.Time(types.NewMillisecondTimestampFromInt(startTime).Time().Add(interval.Duration() - time.Millisecond)),
Interval: interval,
Open: fixedpoint.NewFromFloat(34361.49),
Close: fixedpoint.NewFromFloat(34416.41),
High: fixedpoint.NewFromFloat(34458.98),
Low: fixedpoint.NewFromFloat(34355.53),
Volume: fixedpoint.NewFromFloat(99.6631),
QuoteVolume: fixedpoint.Zero,
TakerBuyBaseAssetVolume: fixedpoint.Zero,
TakerBuyQuoteAssetVolume: fixedpoint.Zero,
LastTradeID: 0,
NumberOfTrades: 0,
Closed: false,
}, k.ToGlobal(interval, "BTCUSDT"))
}