mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-26 00:35:15 +00:00
pre-parse kline payload so that we don't have to re-parse string everytime
This commit is contained in:
parent
afcde6827f
commit
e37932bace
|
@ -303,19 +303,21 @@ func (e *Exchange) QueryKLines(ctx context.Context, symbol, interval string, opt
|
|||
}
|
||||
|
||||
var kLines []types.KLine
|
||||
for _, kline := range resp {
|
||||
for _, k := range resp {
|
||||
kLines = append(kLines, types.KLine{
|
||||
Symbol: symbol,
|
||||
Interval: interval,
|
||||
StartTime: kline.OpenTime,
|
||||
EndTime: kline.CloseTime,
|
||||
Open: kline.Open,
|
||||
Close: kline.Close,
|
||||
High: kline.High,
|
||||
Low: kline.Low,
|
||||
Volume: kline.Volume,
|
||||
QuoteVolume: kline.QuoteAssetVolume,
|
||||
NumberOfTrades: kline.TradeNum,
|
||||
StartTime: time.Unix(0, k.OpenTime*int64(time.Millisecond)),
|
||||
EndTime: time.Unix(0, k.CloseTime*int64(time.Millisecond)),
|
||||
Open: util.MustParseFloat(k.Open),
|
||||
Close: util.MustParseFloat(k.Close),
|
||||
High: util.MustParseFloat(k.High),
|
||||
Low: util.MustParseFloat(k.Low),
|
||||
Volume: util.MustParseFloat(k.Volume),
|
||||
QuoteVolume: util.MustParseFloat(k.QuoteAssetVolume),
|
||||
LastTradeID: 0,
|
||||
NumberOfTrades: k.TradeNum,
|
||||
Closed: true,
|
||||
})
|
||||
}
|
||||
return kLines, nil
|
||||
|
@ -465,13 +467,12 @@ func (e *Exchange) BatchQueryKLines(ctx context.Context, symbol, interval string
|
|||
}
|
||||
|
||||
for _, kline := range klines {
|
||||
t := time.Unix(0, kline.EndTime*int64(time.Millisecond))
|
||||
if t.After(endTime) {
|
||||
if kline.EndTime.After(endTime) {
|
||||
return allKLines, nil
|
||||
}
|
||||
|
||||
allKLines = append(allKLines, kline)
|
||||
startTime = t
|
||||
startTime = kline.EndTime
|
||||
}
|
||||
|
||||
// avoid rate limit
|
||||
|
|
|
@ -4,10 +4,12 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fastjson"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/bbgo/types"
|
||||
"github.com/c9s/bbgo/pkg/util"
|
||||
"github.com/valyala/fastjson"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
|
@ -76,9 +78,9 @@ type ExecutionReportEvent struct {
|
|||
TradeID int64 `json:"t"`
|
||||
TransactionTime int64 `json:"T"`
|
||||
|
||||
LastExecutedQuantity string `json:"l"`
|
||||
CumulativeFilledQuantity string `json:"z"`
|
||||
LastExecutedPrice string `json:"L"`
|
||||
LastExecutedQuantity string `json:"l"`
|
||||
CumulativeFilledQuantity string `json:"z"`
|
||||
LastExecutedPrice string `json:"L"`
|
||||
LastQuoteAssetTransactedQuantity string `json:"Y"`
|
||||
|
||||
OrderCreationTime int `json:"O"`
|
||||
|
@ -91,17 +93,17 @@ func (e *ExecutionReportEvent) Trade() (*types.Trade, error) {
|
|||
|
||||
tt := time.Unix(0, e.TransactionTime/1000000)
|
||||
return &types.Trade{
|
||||
ID: e.TradeID,
|
||||
Symbol: e.Symbol,
|
||||
Price: util.MustParseFloat(e.LastExecutedPrice),
|
||||
Quantity: util.MustParseFloat(e.LastExecutedQuantity),
|
||||
ID: e.TradeID,
|
||||
Symbol: e.Symbol,
|
||||
Price: util.MustParseFloat(e.LastExecutedPrice),
|
||||
Quantity: util.MustParseFloat(e.LastExecutedQuantity),
|
||||
QuoteQuantity: util.MustParseFloat(e.LastQuoteAssetTransactedQuantity),
|
||||
Side: e.Side,
|
||||
IsBuyer: e.Side == "BUY",
|
||||
IsMaker: e.IsMaker,
|
||||
Time: tt,
|
||||
Fee: util.MustParseFloat(e.CommissionAmount),
|
||||
FeeCurrency: e.CommissionAsset,
|
||||
Side: e.Side,
|
||||
IsBuyer: e.Side == "BUY",
|
||||
IsMaker: e.IsMaker,
|
||||
Time: tt,
|
||||
Fee: util.MustParseFloat(e.CommissionAmount),
|
||||
FeeCurrency: e.CommissionAsset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -240,10 +242,49 @@ func ParseEvent(message string) (interface{}, error) {
|
|||
return nil, fmt.Errorf("unsupported message: %s", message)
|
||||
}
|
||||
|
||||
// KLine uses binance's kline as the standard structure
|
||||
type KLine struct {
|
||||
StartTime int64 `json:"t"`
|
||||
EndTime int64 `json:"T"`
|
||||
|
||||
Symbol string `json:"s"`
|
||||
Interval string `json:"i"`
|
||||
|
||||
Open string `json:"o"`
|
||||
Close string `json:"c"`
|
||||
High string `json:"h"`
|
||||
|
||||
Low string `json:"l"`
|
||||
Volume string `json:"V"` // taker buy base asset volume (like 10 BTC)
|
||||
QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT)
|
||||
|
||||
LastTradeID int `json:"L"`
|
||||
NumberOfTrades int64 `json:"n"`
|
||||
Closed bool `json:"x"`
|
||||
}
|
||||
|
||||
type KLineEvent struct {
|
||||
EventBase
|
||||
Symbol string `json:"s"`
|
||||
KLine *types.KLine `json:"k,omitempty"`
|
||||
Symbol string `json:"s"`
|
||||
KLine KLine `json:"k,omitempty"`
|
||||
}
|
||||
|
||||
func (k *KLine) KLine() types.KLine {
|
||||
return types.KLine{
|
||||
Symbol: k.Symbol,
|
||||
Interval: k.Interval,
|
||||
StartTime: time.Unix(0, k.StartTime*int64(time.Millisecond)),
|
||||
EndTime: time.Unix(0, k.EndTime*int64(time.Millisecond)),
|
||||
Open: util.MustParseFloat(k.Open),
|
||||
Close: util.MustParseFloat(k.Close),
|
||||
High: util.MustParseFloat(k.High),
|
||||
Low: util.MustParseFloat(k.Low),
|
||||
Volume: util.MustParseFloat(k.Volume),
|
||||
QuoteVolume: util.MustParseFloat(k.QuoteVolume),
|
||||
LastTradeID: k.LastTradeID,
|
||||
NumberOfTrades: k.NumberOfTrades,
|
||||
Closed: k.Closed,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
@ -31,8 +31,8 @@ type PrivateStream struct {
|
|||
connectCallbacks []func(stream *PrivateStream)
|
||||
|
||||
// custom callbacks
|
||||
kLineEventCallbacks []func(event *KLineEvent)
|
||||
kLineClosedEventCallbacks []func(event *KLineEvent)
|
||||
kLineEventCallbacks []func(e *KLineEvent)
|
||||
kLineClosedEventCallbacks []func(e *KLineEvent)
|
||||
|
||||
balanceUpdateEventCallbacks []func(event *BalanceUpdateEvent)
|
||||
outboundAccountInfoEventCallbacks []func(event *OutboundAccountInfoEvent)
|
||||
|
@ -190,7 +190,7 @@ func (s *PrivateStream) read(ctx context.Context, eventC chan interface{}) {
|
|||
|
||||
if e.KLine.Closed {
|
||||
s.EmitKLineClosedEvent(e)
|
||||
s.EmitKLineClosed(e.KLine)
|
||||
s.EmitKLineClosed(e.KLine.KLine())
|
||||
}
|
||||
|
||||
case *ExecutionReportEvent:
|
||||
|
|
63
bbgo/indicator.go
Normal file
63
bbgo/indicator.go
Normal file
|
@ -0,0 +1,63 @@
|
|||
package bbgo
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/bbgo/types"
|
||||
)
|
||||
|
||||
type MovingAverageIndicator struct {
|
||||
store *MarketDataStore
|
||||
Period int
|
||||
}
|
||||
|
||||
func NewMovingAverageIndicator(period int) *MovingAverageIndicator {
|
||||
return &MovingAverageIndicator{
|
||||
Period: period,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *MovingAverageIndicator) handleUpdate(kline types.KLine) {
|
||||
klines, ok := i.store.KLineWindows[ Interval(kline.Interval) ]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if len(klines) < i.Period {
|
||||
return
|
||||
}
|
||||
|
||||
// calculate ma
|
||||
}
|
||||
|
||||
type IndicatorValue struct {
|
||||
Value float64
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
func calculateMovingAverage(klines types.KLineWindow, period int) (values []IndicatorValue) {
|
||||
for idx := range klines[period:] {
|
||||
offset := idx + period
|
||||
sum := klines[offset - period:offset].ReduceClose()
|
||||
values = append(values, IndicatorValue{
|
||||
Time: klines[offset].GetEndTime(),
|
||||
Value: math.Round(sum / float64(period)),
|
||||
})
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
func (i *MovingAverageIndicator) SubscribeStore(store *MarketDataStore) {
|
||||
i.store = store
|
||||
|
||||
// register kline update callback
|
||||
store.OnUpdate(i.handleUpdate)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
17
bbgo/indicator_test.go
Normal file
17
bbgo/indicator_test.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package bbgo
|
||||
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/bbgo/types"
|
||||
)
|
||||
|
||||
func TestCalculateMovingAverage(t *testing.T) {
|
||||
klines := types.KLineWindow{
|
||||
{
|
||||
|
||||
},
|
||||
}
|
||||
_ = klines
|
||||
}
|
|
@ -3,7 +3,6 @@ package bbgo
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
|
@ -54,7 +53,7 @@ func (trader *KLineRegressionTrader) RunStrategy(ctx context.Context, strategy S
|
|||
|
||||
fmt.Print(".")
|
||||
|
||||
standardStream.EmitKLineClosed(&kline)
|
||||
standardStream.EmitKLineClosed(kline)
|
||||
|
||||
for _, order := range trader.pendingOrders {
|
||||
switch order.Side {
|
||||
|
@ -117,7 +116,7 @@ func (trader *KLineRegressionTrader) RunStrategy(ctx context.Context, strategy S
|
|||
Side: string(order.Side),
|
||||
IsBuyer: order.Side == types.SideTypeBuy,
|
||||
IsMaker: false,
|
||||
Time: time.Unix(0, kline.EndTime*int64(time.Millisecond)),
|
||||
Time: kline.EndTime,
|
||||
Symbol: trader.Context.Symbol,
|
||||
Fee: fee,
|
||||
FeeCurrency: feeCurrency,
|
||||
|
|
17
bbgo/marketdatastore_callbacks.go
Normal file
17
bbgo/marketdatastore_callbacks.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
// Code generated by "callbackgen -type MarketDataStore"; DO NOT EDIT.
|
||||
|
||||
package bbgo
|
||||
|
||||
import (
|
||||
"github.com/c9s/bbgo/pkg/bbgo/types"
|
||||
)
|
||||
|
||||
func (store *MarketDataStore) OnUpdate(cb KLineCallback) {
|
||||
store.updateCallbacks = append(store.updateCallbacks, cb)
|
||||
}
|
||||
|
||||
func (store *MarketDataStore) EmitUpdate(kline types.KLine) {
|
||||
for _, cb := range store.updateCallbacks {
|
||||
cb(kline)
|
||||
}
|
||||
}
|
|
@ -11,18 +11,18 @@ var Interval5m = Interval("5m")
|
|||
var Interval1h = Interval("1h")
|
||||
var Interval1d = Interval("1d")
|
||||
|
||||
type MarketDataStore struct {
|
||||
// MaxChangeKLines stores the max change kline per interval
|
||||
MaxChangeKLines map[Interval]types.KLine `json:"-"`
|
||||
type KLineCallback func(kline types.KLine)
|
||||
|
||||
//go:generate callbackgen -type MarketDataStore
|
||||
type MarketDataStore struct {
|
||||
// KLineWindows stores all loaded klines per interval
|
||||
KLineWindows map[Interval]types.KLineWindow `json:"-"`
|
||||
|
||||
updateCallbacks []KLineCallback
|
||||
}
|
||||
|
||||
func NewMarketDataStore() *MarketDataStore {
|
||||
return &MarketDataStore{
|
||||
MaxChangeKLines: make(map[Interval]types.KLine),
|
||||
|
||||
// KLineWindows stores all loaded klines per interval
|
||||
KLineWindows: make(map[Interval]types.KLineWindow),
|
||||
}
|
||||
|
@ -32,19 +32,14 @@ func (store *MarketDataStore) BindPrivateStream(stream *types.StandardPrivateStr
|
|||
stream.OnKLineClosed(store.handleKLineClosed)
|
||||
}
|
||||
|
||||
func (store *MarketDataStore) handleKLineClosed(kline *types.KLine) {
|
||||
store.AddKLine(*kline)
|
||||
func (store *MarketDataStore) handleKLineClosed(kline types.KLine) {
|
||||
store.AddKLine(kline)
|
||||
}
|
||||
|
||||
func (store *MarketDataStore) AddKLine(kline types.KLine) {
|
||||
var interval = Interval(kline.Interval)
|
||||
|
||||
var window = store.KLineWindows[interval]
|
||||
window.Add(kline)
|
||||
|
||||
if _, ok := store.MaxChangeKLines[interval] ; ok {
|
||||
if kline.GetMaxChange() > store.MaxChangeKLines[interval].GetMaxChange() {
|
||||
store.MaxChangeKLines[interval] = kline
|
||||
}
|
||||
}
|
||||
store.EmitUpdate(kline)
|
||||
}
|
||||
|
|
|
@ -263,9 +263,9 @@ func (trader *Trader) RunStrategy(ctx context.Context, strategy Strategy) (chan
|
|||
})
|
||||
})
|
||||
|
||||
stream.OnKLineEvent(func(e *binance.KLineEvent) {
|
||||
trader.ProfitAndLossCalculator.SetCurrentPrice(e.KLine.GetClose())
|
||||
trader.Context.SetCurrentPrice(e.KLine.GetClose())
|
||||
stream.OnKLineClosed(func(kline types.KLine) {
|
||||
trader.ProfitAndLossCalculator.SetCurrentPrice(kline.Close)
|
||||
trader.Context.SetCurrentPrice(kline.Close)
|
||||
})
|
||||
|
||||
var eventC = make(chan interface{}, 20)
|
||||
|
|
|
@ -39,23 +39,30 @@ type KLineQueryOptions struct {
|
|||
|
||||
// KLine uses binance's kline as the standard structure
|
||||
type KLine struct {
|
||||
StartTime int64 `json:"t"`
|
||||
EndTime int64 `json:"T"`
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
|
||||
Symbol string `json:"s"`
|
||||
Interval string `json:"i"`
|
||||
Symbol string
|
||||
Interval string
|
||||
|
||||
Open string `json:"o"`
|
||||
Close string `json:"c"`
|
||||
High string `json:"h"`
|
||||
Open float64
|
||||
Close float64
|
||||
High float64
|
||||
Low float64
|
||||
Volume float64
|
||||
QuoteVolume float64
|
||||
|
||||
Low string `json:"l"`
|
||||
Volume string `json:"V"` // taker buy base asset volume (like 10 BTC)
|
||||
QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT)
|
||||
LastTradeID int
|
||||
NumberOfTrades int64
|
||||
Closed bool
|
||||
}
|
||||
|
||||
LastTradeID int `json:"L"`
|
||||
NumberOfTrades int64 `json:"n"`
|
||||
Closed bool `json:"x"`
|
||||
func (k KLine) GetStartTime() time.Time {
|
||||
return k.StartTime
|
||||
}
|
||||
|
||||
func (k KLine) GetEndTime() time.Time {
|
||||
return k.EndTime
|
||||
}
|
||||
|
||||
func (k KLine) GetInterval() string {
|
||||
|
@ -63,21 +70,21 @@ func (k KLine) GetInterval() string {
|
|||
}
|
||||
|
||||
func (k KLine) Mid() float64 {
|
||||
return (k.GetHigh() + k.GetLow()) / 2
|
||||
return (k.High + k.Low) / 2
|
||||
}
|
||||
|
||||
// green candle with open and close near high price
|
||||
func (k KLine) BounceUp() bool {
|
||||
mid := k.Mid()
|
||||
trend := k.GetTrend()
|
||||
return trend > 0 && k.GetOpen() > mid && k.GetClose() > mid
|
||||
return trend > 0 && k.Open > mid && k.Close > mid
|
||||
}
|
||||
|
||||
// red candle with open and close near low price
|
||||
func (k KLine) BounceDown() bool {
|
||||
mid := k.Mid()
|
||||
trend := k.GetTrend()
|
||||
return trend > 0 && k.GetOpen() < mid && k.GetClose() < mid
|
||||
return trend > 0 && k.Open < mid && k.Close < mid
|
||||
}
|
||||
|
||||
func (k KLine) GetTrend() int {
|
||||
|
@ -93,19 +100,19 @@ func (k KLine) GetTrend() int {
|
|||
}
|
||||
|
||||
func (k KLine) GetHigh() float64 {
|
||||
return util.MustParseFloat(k.High)
|
||||
return k.High
|
||||
}
|
||||
|
||||
func (k KLine) GetLow() float64 {
|
||||
return util.MustParseFloat(k.Low)
|
||||
return k.Low
|
||||
}
|
||||
|
||||
func (k KLine) GetOpen() float64 {
|
||||
return util.MustParseFloat(k.Open)
|
||||
return k.Open
|
||||
}
|
||||
|
||||
func (k KLine) GetClose() float64 {
|
||||
return util.MustParseFloat(k.Close)
|
||||
return k.Close
|
||||
}
|
||||
|
||||
func (k KLine) GetMaxChange() float64 {
|
||||
|
@ -134,11 +141,11 @@ func (k KLine) GetLowerShadowRatio() float64 {
|
|||
}
|
||||
|
||||
func (k KLine) GetLowerShadowHeight() float64 {
|
||||
low := k.GetLow()
|
||||
if k.GetOpen() < k.GetClose() {
|
||||
return k.GetOpen() - low
|
||||
low := k.Low
|
||||
if k.Open < k.Close {
|
||||
return k.Open - low
|
||||
}
|
||||
return k.GetClose() - low
|
||||
return k.Close - low
|
||||
}
|
||||
|
||||
// GetBody returns the height of the candle real body
|
||||
|
@ -147,19 +154,11 @@ func (k KLine) GetBody() float64 {
|
|||
}
|
||||
|
||||
func (k KLine) GetChange() float64 {
|
||||
return k.GetClose() - k.GetOpen()
|
||||
}
|
||||
|
||||
func (k KLine) GetStartTime() time.Time {
|
||||
return time.Unix(0, k.StartTime*int64(time.Millisecond))
|
||||
}
|
||||
|
||||
func (k KLine) GetEndTime() time.Time {
|
||||
return time.Unix(0, k.EndTime*int64(time.Millisecond))
|
||||
return k.Close - k.Open
|
||||
}
|
||||
|
||||
func (k KLine) String() string {
|
||||
return fmt.Sprintf("%s %s Open: % 14s Close: % 14s High: % 14s Low: % 14s Volume: % 15s Change: % 11f Max Change: % 11f", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.GetMaxChange())
|
||||
return fmt.Sprintf("%s %s Open: %.8f Close: %.8f High: %.8f Low: %.8f Volume: %.8f Change: %.4f Max Change: %.4f", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.GetMaxChange())
|
||||
}
|
||||
|
||||
func (k KLine) Color() string {
|
||||
|
@ -176,10 +175,10 @@ func (k KLine) SlackAttachment() slack.Attachment {
|
|||
Text: fmt.Sprintf("*%s* KLine %s", k.Symbol, k.Interval),
|
||||
Color: k.Color(),
|
||||
Fields: []slack.AttachmentField{
|
||||
{Title: "Open", Value: k.Open, Short: true},
|
||||
{Title: "High", Value: k.High, Short: true},
|
||||
{Title: "Low", Value: k.Low, Short: true},
|
||||
{Title: "Close", Value: k.Close, Short: true},
|
||||
{Title: "Open", Value: util.FormatFloat(k.Open, 2), Short: true},
|
||||
{Title: "High", Value: util.FormatFloat(k.High, 2), Short: true},
|
||||
{Title: "Low", Value: util.FormatFloat(k.Low, 2), Short: true},
|
||||
{Title: "Close", Value: util.FormatFloat(k.Close, 2), Short: true},
|
||||
{Title: "Mid", Value: util.FormatFloat(k.Mid(), 2), Short: true},
|
||||
{Title: "Change", Value: util.FormatFloat(k.GetChange(), 2), Short: true},
|
||||
{Title: "Max Change", Value: util.FormatFloat(k.GetMaxChange(), 2), Short: true},
|
||||
|
@ -206,6 +205,17 @@ func (k KLine) SlackAttachment() slack.Attachment {
|
|||
|
||||
type KLineWindow []KLine
|
||||
|
||||
// ReduceClose reduces the closed prices
|
||||
func (k KLineWindow) ReduceClose() float64 {
|
||||
s := 0.0
|
||||
|
||||
for _, kline := range k {
|
||||
s += kline.GetClose()
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (k KLineWindow) Len() int {
|
||||
return len(k)
|
||||
}
|
||||
|
@ -425,7 +435,7 @@ func (k KLineWindow) SlackAttachment() slack.Attachment {
|
|||
Short: true,
|
||||
},
|
||||
},
|
||||
Footer: fmt.Sprintf("Since %s til %s", first.GetStartTime(), end.GetEndTime()),
|
||||
Footer: fmt.Sprintf("Since %s til %s", first.StartTime, end.EndTime),
|
||||
FooterIcon: "",
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,8 +7,8 @@ import ("testing"
|
|||
|
||||
func TestKLineWindow_Tail(t *testing.T) {
|
||||
var win = KLineWindow{
|
||||
{ Open: "11600.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
||||
{ Open: "11600.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
||||
{ Open: 11600.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||
{ Open: 11600.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||
}
|
||||
|
||||
var win2 = win.Tail(1)
|
||||
|
@ -23,22 +23,22 @@ func TestKLineWindow_Tail(t *testing.T) {
|
|||
|
||||
func TestKLineWindow_Truncate(t *testing.T) {
|
||||
var win = KLineWindow{
|
||||
{ Open: "11600.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
||||
{ Open: "11601.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
||||
{ Open: "11602.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
||||
{ Open: "11603.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
||||
{ Open: 11600.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||
{ Open: 11601.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||
{ Open: 11602.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||
{ Open: 11603.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||
}
|
||||
|
||||
win.Truncate(5)
|
||||
assert.Len(t, win, 4)
|
||||
assert.Equal(t, "11603.0", win.Last().Open)
|
||||
assert.Equal(t, 11603.0, win.Last().Open)
|
||||
|
||||
win.Truncate(3)
|
||||
assert.Len(t, win, 3)
|
||||
assert.Equal(t, "11603.0", win.Last().Open)
|
||||
assert.Equal(t, 11603.0, win.Last().Open)
|
||||
|
||||
|
||||
win.Truncate(1)
|
||||
assert.Len(t, win, 1)
|
||||
assert.Equal(t, "11603.0", win.Last().Open)
|
||||
assert.Equal(t, 11603.0, win.Last().Open)
|
||||
}
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
|
||||
package types
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
import ()
|
||||
|
||||
func (stream *StandardPrivateStream) OnTrade(cb func(trade *Trade)) {
|
||||
stream.tradeCallbacks = append(stream.tradeCallbacks, cb)
|
||||
|
@ -16,25 +14,6 @@ func (stream *StandardPrivateStream) EmitTrade(trade *Trade) {
|
|||
}
|
||||
}
|
||||
|
||||
func (stream *StandardPrivateStream) RemoveOnTrade(needle func(trade *Trade)) (found bool) {
|
||||
|
||||
var newcallbacks []func(trade *Trade)
|
||||
var fp = reflect.ValueOf(needle).Pointer()
|
||||
for _, cb := range stream.tradeCallbacks {
|
||||
if fp == reflect.ValueOf(cb).Pointer() {
|
||||
found = true
|
||||
} else {
|
||||
newcallbacks = append(newcallbacks, cb)
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
stream.tradeCallbacks = newcallbacks
|
||||
}
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
func (stream *StandardPrivateStream) OnBalanceSnapshot(cb func(balanceSnapshot map[string]Balance)) {
|
||||
stream.balanceSnapshotCallbacks = append(stream.balanceSnapshotCallbacks, cb)
|
||||
}
|
||||
|
@ -45,58 +24,20 @@ func (stream *StandardPrivateStream) EmitBalanceSnapshot(balanceSnapshot map[str
|
|||
}
|
||||
}
|
||||
|
||||
func (stream *StandardPrivateStream) RemoveOnBalanceSnapshot(needle func(balanceSnapshot map[string]Balance)) (found bool) {
|
||||
|
||||
var newcallbacks []func(balanceSnapshot map[string]Balance)
|
||||
var fp = reflect.ValueOf(needle).Pointer()
|
||||
for _, cb := range stream.balanceSnapshotCallbacks {
|
||||
if fp == reflect.ValueOf(cb).Pointer() {
|
||||
found = true
|
||||
} else {
|
||||
newcallbacks = append(newcallbacks, cb)
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
stream.balanceSnapshotCallbacks = newcallbacks
|
||||
}
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
func (stream *StandardPrivateStream) OnKLineClosed(cb func(kline *KLine)) {
|
||||
func (stream *StandardPrivateStream) OnKLineClosed(cb func(kline KLine)) {
|
||||
stream.kLineClosedCallbacks = append(stream.kLineClosedCallbacks, cb)
|
||||
}
|
||||
|
||||
func (stream *StandardPrivateStream) EmitKLineClosed(kline *KLine) {
|
||||
func (stream *StandardPrivateStream) EmitKLineClosed(kline KLine) {
|
||||
for _, cb := range stream.kLineClosedCallbacks {
|
||||
cb(kline)
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *StandardPrivateStream) RemoveOnKLineClosed(needle func(kline *KLine)) (found bool) {
|
||||
|
||||
var newcallbacks []func(kline *KLine)
|
||||
var fp = reflect.ValueOf(needle).Pointer()
|
||||
for _, cb := range stream.kLineClosedCallbacks {
|
||||
if fp == reflect.ValueOf(cb).Pointer() {
|
||||
found = true
|
||||
} else {
|
||||
newcallbacks = append(newcallbacks, cb)
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
stream.kLineClosedCallbacks = newcallbacks
|
||||
}
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
type StandardPrivateStreamEventHub interface {
|
||||
OnTrade(cb func(trade *Trade))
|
||||
|
||||
OnBalanceSnapshot(cb func(balanceSnapshot map[string]Balance))
|
||||
|
||||
OnKLineClosed(cb func(kline *KLine))
|
||||
OnKLineClosed(cb func(kline KLine))
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ type StandardPrivateStream struct {
|
|||
|
||||
tradeCallbacks []func(trade *Trade)
|
||||
balanceSnapshotCallbacks []func(balanceSnapshot map[string]Balance)
|
||||
kLineClosedCallbacks []func(kline *KLine)
|
||||
kLineClosedCallbacks []func(kline KLine)
|
||||
}
|
||||
|
||||
func (stream *StandardPrivateStream) Subscribe(channel string, symbol string, options SubscribeOptions) {
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMinBackoff = 5 * time.Second
|
||||
DefaultMaxBackoff = time.Minute
|
||||
|
||||
DefaultBackoffFactor = 2
|
||||
)
|
||||
|
||||
type Backoff struct {
|
||||
attempt int64
|
||||
Factor float64
|
||||
cur, Min, Max time.Duration
|
||||
}
|
||||
|
||||
// Duration returns the duration for the current attempt before incrementing
|
||||
// the attempt counter. See ForAttempt.
|
||||
func (b *Backoff) Duration() time.Duration {
|
||||
d := b.calculate(b.attempt)
|
||||
b.attempt++
|
||||
return d
|
||||
}
|
||||
|
||||
func (b *Backoff) calculate(attempt int64) time.Duration {
|
||||
min := b.Min
|
||||
if min <= 0 {
|
||||
min = DefaultMinBackoff
|
||||
}
|
||||
max := b.Max
|
||||
if max <= 0 {
|
||||
max = DefaultMaxBackoff
|
||||
}
|
||||
if min >= max {
|
||||
return max
|
||||
}
|
||||
factor := b.Factor
|
||||
if factor <= 0 {
|
||||
factor = DefaultBackoffFactor
|
||||
}
|
||||
cur := b.cur
|
||||
if cur < min {
|
||||
cur = min
|
||||
} else if cur > max {
|
||||
cur = max
|
||||
}
|
||||
|
||||
//calculate this duration
|
||||
next := cur
|
||||
if attempt > 0 {
|
||||
next = time.Duration(float64(cur) * factor)
|
||||
}
|
||||
|
||||
if next < cur {
|
||||
// overflow
|
||||
next = max
|
||||
} else if next <= min {
|
||||
next = min
|
||||
} else if next >= max {
|
||||
next = max
|
||||
}
|
||||
b.cur = next
|
||||
return next
|
||||
}
|
||||
|
||||
// Reset restarts the current attempt counter at zero.
|
||||
func (b *Backoff) Reset() {
|
||||
b.attempt = 0
|
||||
b.cur = b.Min
|
||||
}
|
||||
|
||||
// Attempt returns the current attempt counter value.
|
||||
func (b *Backoff) Attempt() int64 {
|
||||
return b.attempt
|
||||
}
|
|
@ -1,172 +0,0 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBackoff(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 200 * time.Millisecond,
|
||||
Max: 100 * time.Second,
|
||||
Factor: 2,
|
||||
}
|
||||
|
||||
equals(t, b.Duration(), 200*time.Millisecond)
|
||||
equals(t, b.Duration(), 400*time.Millisecond)
|
||||
equals(t, b.Duration(), 800*time.Millisecond)
|
||||
b.Reset()
|
||||
equals(t, b.Duration(), 200*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestReachMaxBackoff(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 2 * time.Second,
|
||||
Max: 100 * time.Second,
|
||||
Factor: 10,
|
||||
}
|
||||
|
||||
equals(t, b.Duration(), 2*time.Second)
|
||||
equals(t, b.Duration(), 20*time.Second)
|
||||
equals(t, b.Duration(), b.Max)
|
||||
b.Reset()
|
||||
equals(t, b.Duration(), 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAttempt(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 2 * time.Second,
|
||||
Max: 100 * time.Second,
|
||||
Factor: 10,
|
||||
}
|
||||
|
||||
equals(t, b.Attempt(), int64(0))
|
||||
equals(t, b.Duration(), 2*time.Second)
|
||||
equals(t, b.Attempt(), int64(1))
|
||||
equals(t, b.Duration(), 20*time.Second)
|
||||
equals(t, b.Attempt(), int64(2))
|
||||
equals(t, b.Duration(), b.Max)
|
||||
equals(t, b.Attempt(), int64(3))
|
||||
b.Reset()
|
||||
equals(t, b.Attempt(), int64(0))
|
||||
equals(t, b.Duration(), 2*time.Second)
|
||||
equals(t, b.Attempt(), int64(1))
|
||||
}
|
||||
|
||||
func TestAttemptWithZeroValue(t *testing.T) {
|
||||
b := &Backoff{}
|
||||
|
||||
cur := DefaultMinBackoff
|
||||
equals(t, b.Attempt(), int64(0))
|
||||
equals(t, b.Duration(), cur)
|
||||
for i := 1; i < 10; i++ {
|
||||
equals(t, b.Attempt(), int64(i))
|
||||
cur *= DefaultBackoffFactor
|
||||
if cur >= DefaultMaxBackoff {
|
||||
equals(t, b.Duration(), DefaultMaxBackoff)
|
||||
break
|
||||
}
|
||||
equals(t, b.Duration(), cur)
|
||||
}
|
||||
|
||||
b.Reset()
|
||||
equals(t, b.Attempt(), int64(0))
|
||||
equals(t, b.Duration(), DefaultMinBackoff)
|
||||
equals(t, b.Attempt(), int64(1))
|
||||
}
|
||||
|
||||
func TestNegOrZeroMin(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 0,
|
||||
}
|
||||
equals(t, b.Duration(), DefaultMinBackoff)
|
||||
|
||||
b2 := &Backoff{
|
||||
Min: -1 * time.Second,
|
||||
}
|
||||
equals(t, b2.Duration(), DefaultMinBackoff)
|
||||
}
|
||||
|
||||
func TestNegOrZeroMax(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: DefaultMaxBackoff,
|
||||
Max: 0,
|
||||
}
|
||||
equals(t, b.Duration(), DefaultMaxBackoff)
|
||||
|
||||
b2 := &Backoff{
|
||||
Min: DefaultMaxBackoff,
|
||||
Max: -1 * time.Second,
|
||||
}
|
||||
equals(t, b2.Duration(), DefaultMaxBackoff)
|
||||
}
|
||||
|
||||
func TestMinLargerThanMax(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 500 * time.Second,
|
||||
Max: 100 * time.Second,
|
||||
Factor: 1,
|
||||
}
|
||||
equals(t, b.Duration(), b.Max)
|
||||
}
|
||||
|
||||
func TestFakeCur(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 100 * time.Second,
|
||||
Max: 500 * time.Second,
|
||||
Factor: 1,
|
||||
cur: 1000 * time.Second,
|
||||
}
|
||||
equals(t, b.Duration(), b.Max)
|
||||
}
|
||||
|
||||
func TestLargeFactor(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 1 * time.Second,
|
||||
Max: 10 * time.Second,
|
||||
Factor: math.MaxFloat64,
|
||||
}
|
||||
equals(t, b.Duration(), b.Min)
|
||||
for i := 0; i < 100; i++ {
|
||||
equals(t, b.Duration(), b.Max)
|
||||
}
|
||||
b.Reset()
|
||||
equals(t, b.Duration(), b.Min)
|
||||
}
|
||||
|
||||
func TestLargeMinMax(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: time.Duration(math.MaxInt64 - 1),
|
||||
Max: time.Duration(math.MaxInt64),
|
||||
Factor: math.MaxFloat64,
|
||||
}
|
||||
equals(t, b.Duration(), b.Min)
|
||||
for i := 0; i < 100; i++ {
|
||||
equals(t, b.Duration(), b.Max)
|
||||
}
|
||||
b.Reset()
|
||||
equals(t, b.Duration(), b.Min)
|
||||
}
|
||||
|
||||
func TestManyAttempts(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 10 * time.Second,
|
||||
Max: time.Duration(math.MaxInt64),
|
||||
Factor: 1000,
|
||||
}
|
||||
for i := 0; i < 10000; i++ {
|
||||
b.Duration()
|
||||
}
|
||||
equals(t, b.Duration(), b.Max)
|
||||
b.Reset()
|
||||
}
|
||||
|
||||
func equals(t *testing.T, v1, v2 interface{}) {
|
||||
if !reflect.DeepEqual(v1, v2) {
|
||||
assert.Failf(t, "not equal", "Got %v, Expecting %v", v1, v2)
|
||||
}
|
||||
}
|
|
@ -1,368 +0,0 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/log"
|
||||
"github.com/c9s/bbgo/pkg/util"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const DefaultMessageBufferSize = 128
|
||||
|
||||
const DefaultWriteTimeout = 30 * time.Second
|
||||
const DefaultReadTimeout = 30 * time.Second
|
||||
|
||||
var ErrReconnectContextDone = errors.New("reconnect canceled due to context done.")
|
||||
var ErrReconnectFailed = errors.New("failed to reconnect")
|
||||
var ErrConnectionLost = errors.New("connection lost")
|
||||
|
||||
var MaxReconnectRate = rate.Limit(1 / DefaultMinBackoff.Seconds())
|
||||
|
||||
// WebSocketClient allows to connect and receive stream data
|
||||
type WebSocketClient struct {
|
||||
// Url is the websocket connection location, start with ws:// or wss://
|
||||
Url string
|
||||
|
||||
// conn is the current websocket connection, please note the connection
|
||||
// object can be replaced with a new connection object when the connection
|
||||
// is unexpected closed.
|
||||
conn *websocket.Conn
|
||||
|
||||
// Dialer is used for creating the websocket connection
|
||||
Dialer *websocket.Dialer
|
||||
|
||||
// requestHeader is used for the Dial function call. Some credential can be
|
||||
// stored in the http request header for authentication
|
||||
requestHeader http.Header
|
||||
|
||||
// messages is a read-only channel, received messages will be sent to this
|
||||
// channel.
|
||||
messages chan Message
|
||||
|
||||
readTimeout time.Duration
|
||||
|
||||
writeTimeout time.Duration
|
||||
|
||||
onConnect []func(c Client)
|
||||
|
||||
onDisconnect []func(c Client)
|
||||
|
||||
// cancel is mapped to the ctx context object
|
||||
cancel func()
|
||||
|
||||
readerClosed chan struct{}
|
||||
|
||||
connected bool
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
reconnectCh chan struct{}
|
||||
|
||||
backoff Backoff
|
||||
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
// websocket.BinaryMessage or websocket.TextMessage
|
||||
Type int
|
||||
Body []byte
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) Messages() <-chan Message {
|
||||
return c.messages
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) SetReadTimeout(timeout time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readTimeout = timeout
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) SetWriteTimeout(timeout time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeTimeout = timeout
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) OnConnect(f func(c Client)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.onConnect = append(c.onConnect, f)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) OnDisconnect(f func(c Client)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.onDisconnect = append(c.onDisconnect, f)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) WriteTextMessage(message []byte) error {
|
||||
return c.WriteMessage(websocket.TextMessage, message)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) WriteBinaryMessage(message []byte) error {
|
||||
return c.WriteMessage(websocket.BinaryMessage, message)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) WriteMessage(messageType int, data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return ErrConnectionLost
|
||||
}
|
||||
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) readMessages() error {
|
||||
c.mu.Lock()
|
||||
if !c.connected {
|
||||
c.mu.Unlock()
|
||||
return ErrConnectionLost
|
||||
}
|
||||
timeout := c.readTimeout
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
msgtype, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.messages <- Message{msgtype, message}
|
||||
return nil
|
||||
}
|
||||
|
||||
// listen starts a goroutine for reading message and tries to re-connect to the
|
||||
// server when the reader returns error
|
||||
//
|
||||
// Please note we should always break the reader loop if there is any error
|
||||
// returned from the server.
|
||||
func (c *WebSocketClient) listen(ctx context.Context) {
|
||||
// The life time of both channels "readerClosed" and "reconnectCh" is bound to one connection.
|
||||
// Each channel should be created before loop starts and be closed after loop ends.
|
||||
// "readerClosed" is used to inform "Close()" reader loop ends.
|
||||
// "reconnectCh" is used to centralize reconnection logics in this reader loop.
|
||||
c.mu.Lock()
|
||||
c.readerClosed = make(chan struct{})
|
||||
c.reconnectCh = make(chan struct{}, 1)
|
||||
c.mu.Unlock()
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
close(c.readerClosed)
|
||||
close(c.reconnectCh)
|
||||
c.reconnectCh = nil
|
||||
c.mu.Unlock()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case <-c.reconnectCh:
|
||||
// it could be i/o timeout for network disconnection
|
||||
// or it could be invoked from outside.
|
||||
c.SetDisconnected()
|
||||
var maxTries = 1
|
||||
if _, response, err := c.reconnect(ctx, maxTries); err != nil {
|
||||
if err == ErrReconnectContextDone {
|
||||
log.Debugf("[websocket] context canceled. stop reconnecting.")
|
||||
return
|
||||
}
|
||||
log.Warnf("[websocket] failed to reconnect after %d tries!! error: %v response: %v", maxTries, err, response)
|
||||
c.Reconnect()
|
||||
}
|
||||
|
||||
default:
|
||||
if err := c.readMessages(); err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
|
||||
log.Warnf("[websocket] unexpected close error reconnecting: %v", err)
|
||||
}
|
||||
|
||||
log.Warnf("[websocket] failed to read message. error: %+v", err)
|
||||
c.Reconnect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reconnect triggers reconnection logics
|
||||
func (c *WebSocketClient) Reconnect() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
select {
|
||||
// c.reconnectCh is a buffered channel with cap=1.
|
||||
// At most one reconnect signal could be processed.
|
||||
case c.reconnectCh <- struct{}{}:
|
||||
default:
|
||||
// Entering here means it is already reconnecting.
|
||||
// Drop the current reconnect signal.
|
||||
}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the reader and the connection
|
||||
// ctx is the context used for shutdown process.
|
||||
func (c *WebSocketClient) Close() (err error) {
|
||||
c.mu.Lock()
|
||||
// leave the listen goroutine before we close the connection
|
||||
// checking nil is to handle calling "Close" before "Connect" is called
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.SetDisconnected()
|
||||
|
||||
// wait for the reader func to be closed
|
||||
if c.readerClosed != nil {
|
||||
<-c.readerClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// reconnect tries to create a new connection from the existing dialer
|
||||
func (c *WebSocketClient) reconnect(ctx context.Context, maxTries int) (*websocket.Conn, *http.Response, error) {
|
||||
log.Debugf("[websocket] start reconnecting to %q", c.Url)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ErrReconnectContextDone
|
||||
default:
|
||||
}
|
||||
|
||||
if s := util.ShouldDelay(c.limiter, DefaultMinBackoff); s > 0 {
|
||||
log.Warn("[websocket] reconnect too frequently. Sleep for ", s)
|
||||
time.Sleep(s)
|
||||
}
|
||||
|
||||
log.Warnf("[websocket] reconnecting x %d to %q", c.backoff.Attempt()+1, c.Url)
|
||||
conn, resp, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
||||
if err != nil {
|
||||
dur := c.backoff.Duration()
|
||||
log.Warnf("failed to dial %s: %v, response: %+v. Wait for %v", c.Url, err, resp, dur)
|
||||
time.Sleep(dur)
|
||||
return nil, nil, ErrReconnectFailed
|
||||
}
|
||||
|
||||
log.Infof("[websocket] reconnected to %q", c.Url)
|
||||
// Reset backoff value if connected.
|
||||
c.backoff.Reset()
|
||||
c.setConn(conn)
|
||||
c.setPingHandler(conn)
|
||||
|
||||
return conn, resp, err
|
||||
}
|
||||
|
||||
// Conn returns the current active connection instance
|
||||
func (c *WebSocketClient) Conn() (conn *websocket.Conn) {
|
||||
c.mu.Lock()
|
||||
conn = c.conn
|
||||
c.mu.Unlock()
|
||||
return conn
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) setConn(conn *websocket.Conn) {
|
||||
// Disconnect old connection before replacing with new one.
|
||||
c.SetDisconnected()
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.connected = true
|
||||
c.mu.Unlock()
|
||||
for _, f := range c.onConnect {
|
||||
go f(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) setPingHandler(conn *websocket.Conn) {
|
||||
conn.SetPingHandler(func(message string) error {
|
||||
if err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
|
||||
return err
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) SetDisconnected() {
|
||||
c.mu.Lock()
|
||||
closed := false
|
||||
if c.conn != nil {
|
||||
closed = true
|
||||
c.conn.Close()
|
||||
}
|
||||
c.connected = false
|
||||
c.conn = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
// Only call disconnect callbacks when a connection is closed
|
||||
for _, f := range c.onDisconnect {
|
||||
go f(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) IsConnected() (ret bool) {
|
||||
c.mu.Lock()
|
||||
ret = c.connected
|
||||
c.mu.Unlock()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) Connect(basectx context.Context) error {
|
||||
// maintain a context by the client it self, so that we can manually shutdown the connection
|
||||
ctx, cancel := context.WithCancel(basectx)
|
||||
c.cancel = cancel
|
||||
|
||||
conn, _, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
||||
if err == nil {
|
||||
// setup connection only when connected
|
||||
c.setConn(conn)
|
||||
c.setPingHandler(conn)
|
||||
}
|
||||
|
||||
// 1) if connection is built up, start listening for messages.
|
||||
// 2) if connection is NOT ready, start reconnecting infinitely.
|
||||
go c.listen(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func New(url string, requestHeader http.Header) *WebSocketClient {
|
||||
return NewWithDialer(url, websocket.DefaultDialer, requestHeader)
|
||||
}
|
||||
|
||||
func NewWithDialer(url string, d *websocket.Dialer, requestHeader http.Header) *WebSocketClient {
|
||||
limiter, err := util.NewValidLimiter(MaxReconnectRate, 1)
|
||||
if err != nil {
|
||||
log.WithError(err).Panic("Invalid rate limiter")
|
||||
}
|
||||
return &WebSocketClient{
|
||||
Url: url,
|
||||
Dialer: d,
|
||||
requestHeader: requestHeader,
|
||||
readTimeout: DefaultReadTimeout,
|
||||
writeTimeout: DefaultWriteTimeout,
|
||||
messages: make(chan Message, DefaultMessageBufferSize),
|
||||
limiter: limiter,
|
||||
}
|
||||
}
|
|
@ -1,202 +0,0 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"***REMOVED***/pkg/log"
|
||||
"***REMOVED***/pkg/testing/testutil"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func messageTicker(t *testing.T, ctx context.Context, conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case tt := <-ticker.C:
|
||||
err := conn.WriteMessage(websocket.TextMessage, []byte(tt.String()))
|
||||
assert.NoError(t, err)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func messageReader(t *testing.T, ctx context.Context, conn *websocket.Conn) {
|
||||
for {
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Errorf("unexpected closed error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
t.Logf("message: %v", message)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconnect(t *testing.T) {
|
||||
basectx := context.Background()
|
||||
|
||||
wsHandler := func(ctx context.Context) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
go messageTicker(t, ctx, conn)
|
||||
messageReader(t, ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
serverctx, cancelserver := context.WithCancel(basectx)
|
||||
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", wsHandler(serverctx))
|
||||
go func() {
|
||||
err := server.ListenAndServe()
|
||||
assert.Equal(t, http.ErrServerClosed, err)
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(server.Addr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
log.Debugf("waiting for server port ready")
|
||||
testutil.WaitForPort(t, server.Addr, 3)
|
||||
|
||||
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
|
||||
log.Infof("url: %s", u.String())
|
||||
|
||||
client := New(u.String(), nil)
|
||||
|
||||
// start the message reader
|
||||
clientctx, cancelClient := context.WithCancel(basectx)
|
||||
defer cancelClient()
|
||||
|
||||
err = client.Connect(clientctx)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
client.SetReadTimeout(2 * time.Second)
|
||||
|
||||
// read one message
|
||||
log.Debugf("waiting for message...")
|
||||
msg := <-client.messages
|
||||
log.Debugf("received message: %v", msg)
|
||||
|
||||
triggerReconnectManyTimes := func() {
|
||||
// Forcedly reconnect multiple times
|
||||
for i := 0; i < 10; i = i + 1 {
|
||||
client.Reconnect()
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("reconnecting...")
|
||||
triggerReconnectManyTimes()
|
||||
msg = <-client.messages
|
||||
log.Debugf("received message: %v", msg)
|
||||
|
||||
log.Debugf("shutting down server...")
|
||||
err = server.Shutdown(serverctx)
|
||||
assert.NoError(t, err)
|
||||
cancelserver()
|
||||
server.Close()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
triggerReconnectManyTimes()
|
||||
|
||||
log.Debugf("restarting server...")
|
||||
server2ctx, cancelserver2 := context.WithCancel(basectx)
|
||||
server2 := testutil.NewWebSocketServerWithAddressFunc(t, server.Addr, "/ws", wsHandler(server2ctx))
|
||||
go func() {
|
||||
err := server2.ListenAndServe()
|
||||
assert.Equal(t, http.ErrServerClosed, err)
|
||||
}()
|
||||
|
||||
triggerReconnectManyTimes()
|
||||
|
||||
log.Debugf("waiting for server2 port ready")
|
||||
testutil.WaitForPort(t, server.Addr, 3)
|
||||
|
||||
log.Debugf("waiting for message from server2...")
|
||||
msg = <-client.messages
|
||||
log.Debugf("received message: %v", msg)
|
||||
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = server2.Shutdown(server2ctx)
|
||||
assert.NoError(t, err)
|
||||
cancelserver2()
|
||||
}
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
basectx := context.Background()
|
||||
ctx, cancel := context.WithCancel(basectx)
|
||||
|
||||
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", func(w http.ResponseWriter, r *http.Request) {
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go messageTicker(t, ctx, conn)
|
||||
messageReader(t, ctx, conn)
|
||||
})
|
||||
|
||||
go func() {
|
||||
err := server.ListenAndServe()
|
||||
assert.Equal(t, http.ErrServerClosed, err)
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(server.Addr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testutil.WaitForPort(t, server.Addr, 3)
|
||||
|
||||
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
|
||||
t.Logf("url: %s", u.String())
|
||||
|
||||
client := New(u.String(), nil)
|
||||
|
||||
err = client.Connect(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
client.SetReadTimeout(2 * time.Second)
|
||||
|
||||
// read one message
|
||||
t.Logf("waiting for message...")
|
||||
msg := <-client.messages
|
||||
t.Logf("received message: %v", msg)
|
||||
|
||||
// recreate server at the same address
|
||||
client.Close()
|
||||
|
||||
t.Logf("shutting down server...")
|
||||
cancel()
|
||||
err = server.Shutdown(basectx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:generate mockery -name=Client
|
||||
|
||||
type Client interface {
|
||||
SetWriteTimeout(time.Duration)
|
||||
SetReadTimeout(time.Duration)
|
||||
OnConnect(func(c Client))
|
||||
OnDisconnect(func(c Client))
|
||||
Connect(context.Context) error
|
||||
Reconnect()
|
||||
Close() error
|
||||
IsConnected() bool
|
||||
WriteJSON(interface{}) error
|
||||
Messages() <-chan Message
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WriteJSON writes the JSON encoding of v as a message.
|
||||
//
|
||||
// See the documentation for encoding/json Marshal for details about the
|
||||
// conversion of Go values to JSON.
|
||||
func (c *WebSocketClient) WriteJSON(v interface{}) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return ErrConnectionLost
|
||||
}
|
||||
w, err := c.conn.NextWriter(websocket.TextMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err1 := json.NewEncoder(w).Encode(v)
|
||||
err2 := w.Close()
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
// ReadJSON reads the next JSON-encoded message from the connection and stores
|
||||
// it in the value pointed to by v.
|
||||
//
|
||||
// See the documentation for the encoding/json Unmarshal function for details
|
||||
// about the conversion of JSON to a Go value.
|
||||
func (c *WebSocketClient) ReadJSON(v interface{}) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return ErrConnectionLost
|
||||
}
|
||||
_, r, err := c.conn.NextReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.NewDecoder(r).Decode(v)
|
||||
if err == io.EOF {
|
||||
// One value is expected in the message.
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -1,110 +0,0 @@
|
|||
// Code generated by mockery v1.0.0. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import context "context"
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
import time "time"
|
||||
import websocket "***REMOVED***/pkg/net/http/websocket"
|
||||
|
||||
// Client is an autogenerated mock type for the Client type
|
||||
type Client struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *Client) Close() error {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Connect provides a mock function with given fields: _a0
|
||||
func (_m *Client) Connect(_a0 context.Context) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// IsConnected provides a mock function with given fields:
|
||||
func (_m *Client) IsConnected() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Messages provides a mock function with given fields:
|
||||
func (_m *Client) Messages() <-chan websocket.Message {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 <-chan websocket.Message
|
||||
if rf, ok := ret.Get(0).(func() <-chan websocket.Message); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(<-chan websocket.Message)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// OnConnect provides a mock function with given fields: _a0
|
||||
func (_m *Client) OnConnect(_a0 func(websocket.Client)) {
|
||||
_m.Called(_a0)
|
||||
}
|
||||
|
||||
// OnDisconnect provides a mock function with given fields: _a0
|
||||
func (_m *Client) OnDisconnect(_a0 func(websocket.Client)) {
|
||||
_m.Called(_a0)
|
||||
}
|
||||
|
||||
// Reconnect provides a mock function with given fields:
|
||||
func (_m *Client) Reconnect() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// SetReadTimeout provides a mock function with given fields: _a0
|
||||
func (_m *Client) SetReadTimeout(_a0 time.Duration) {
|
||||
_m.Called(_a0)
|
||||
}
|
||||
|
||||
// SetWriteTimeout provides a mock function with given fields: _a0
|
||||
func (_m *Client) SetWriteTimeout(_a0 time.Duration) {
|
||||
_m.Called(_a0)
|
||||
}
|
||||
|
||||
// WriteJSON provides a mock function with given fields: _a0
|
||||
func (_m *Client) WriteJSON(_a0 interface{}) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
Loading…
Reference in New Issue
Block a user