pre-parse kline payload so that we don't have to re-parse string everytime

This commit is contained in:
c9s 2020-09-16 12:28:15 +08:00
parent afcde6827f
commit e37932bace
20 changed files with 248 additions and 1170 deletions

View File

@ -303,19 +303,21 @@ func (e *Exchange) QueryKLines(ctx context.Context, symbol, interval string, opt
}
var kLines []types.KLine
for _, kline := range resp {
for _, k := range resp {
kLines = append(kLines, types.KLine{
Symbol: symbol,
Interval: interval,
StartTime: kline.OpenTime,
EndTime: kline.CloseTime,
Open: kline.Open,
Close: kline.Close,
High: kline.High,
Low: kline.Low,
Volume: kline.Volume,
QuoteVolume: kline.QuoteAssetVolume,
NumberOfTrades: kline.TradeNum,
StartTime: time.Unix(0, k.OpenTime*int64(time.Millisecond)),
EndTime: time.Unix(0, k.CloseTime*int64(time.Millisecond)),
Open: util.MustParseFloat(k.Open),
Close: util.MustParseFloat(k.Close),
High: util.MustParseFloat(k.High),
Low: util.MustParseFloat(k.Low),
Volume: util.MustParseFloat(k.Volume),
QuoteVolume: util.MustParseFloat(k.QuoteAssetVolume),
LastTradeID: 0,
NumberOfTrades: k.TradeNum,
Closed: true,
})
}
return kLines, nil
@ -465,13 +467,12 @@ func (e *Exchange) BatchQueryKLines(ctx context.Context, symbol, interval string
}
for _, kline := range klines {
t := time.Unix(0, kline.EndTime*int64(time.Millisecond))
if t.After(endTime) {
if kline.EndTime.After(endTime) {
return allKLines, nil
}
allKLines = append(allKLines, kline)
startTime = t
startTime = kline.EndTime
}
// avoid rate limit

View File

@ -4,10 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/valyala/fastjson"
"github.com/c9s/bbgo/pkg/bbgo/types"
"github.com/c9s/bbgo/pkg/util"
"github.com/valyala/fastjson"
"time"
)
/*
@ -76,9 +78,9 @@ type ExecutionReportEvent struct {
TradeID int64 `json:"t"`
TransactionTime int64 `json:"T"`
LastExecutedQuantity string `json:"l"`
CumulativeFilledQuantity string `json:"z"`
LastExecutedPrice string `json:"L"`
LastExecutedQuantity string `json:"l"`
CumulativeFilledQuantity string `json:"z"`
LastExecutedPrice string `json:"L"`
LastQuoteAssetTransactedQuantity string `json:"Y"`
OrderCreationTime int `json:"O"`
@ -91,17 +93,17 @@ func (e *ExecutionReportEvent) Trade() (*types.Trade, error) {
tt := time.Unix(0, e.TransactionTime/1000000)
return &types.Trade{
ID: e.TradeID,
Symbol: e.Symbol,
Price: util.MustParseFloat(e.LastExecutedPrice),
Quantity: util.MustParseFloat(e.LastExecutedQuantity),
ID: e.TradeID,
Symbol: e.Symbol,
Price: util.MustParseFloat(e.LastExecutedPrice),
Quantity: util.MustParseFloat(e.LastExecutedQuantity),
QuoteQuantity: util.MustParseFloat(e.LastQuoteAssetTransactedQuantity),
Side: e.Side,
IsBuyer: e.Side == "BUY",
IsMaker: e.IsMaker,
Time: tt,
Fee: util.MustParseFloat(e.CommissionAmount),
FeeCurrency: e.CommissionAsset,
Side: e.Side,
IsBuyer: e.Side == "BUY",
IsMaker: e.IsMaker,
Time: tt,
Fee: util.MustParseFloat(e.CommissionAmount),
FeeCurrency: e.CommissionAsset,
}, nil
}
@ -240,10 +242,49 @@ func ParseEvent(message string) (interface{}, error) {
return nil, fmt.Errorf("unsupported message: %s", message)
}
// KLine uses binance's kline as the standard structure
type KLine struct {
StartTime int64 `json:"t"`
EndTime int64 `json:"T"`
Symbol string `json:"s"`
Interval string `json:"i"`
Open string `json:"o"`
Close string `json:"c"`
High string `json:"h"`
Low string `json:"l"`
Volume string `json:"V"` // taker buy base asset volume (like 10 BTC)
QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT)
LastTradeID int `json:"L"`
NumberOfTrades int64 `json:"n"`
Closed bool `json:"x"`
}
type KLineEvent struct {
EventBase
Symbol string `json:"s"`
KLine *types.KLine `json:"k,omitempty"`
Symbol string `json:"s"`
KLine KLine `json:"k,omitempty"`
}
func (k *KLine) KLine() types.KLine {
return types.KLine{
Symbol: k.Symbol,
Interval: k.Interval,
StartTime: time.Unix(0, k.StartTime*int64(time.Millisecond)),
EndTime: time.Unix(0, k.EndTime*int64(time.Millisecond)),
Open: util.MustParseFloat(k.Open),
Close: util.MustParseFloat(k.Close),
High: util.MustParseFloat(k.High),
Low: util.MustParseFloat(k.Low),
Volume: util.MustParseFloat(k.Volume),
QuoteVolume: util.MustParseFloat(k.QuoteVolume),
LastTradeID: k.LastTradeID,
NumberOfTrades: k.NumberOfTrades,
Closed: k.Closed,
}
}
/*

View File

@ -31,8 +31,8 @@ type PrivateStream struct {
connectCallbacks []func(stream *PrivateStream)
// custom callbacks
kLineEventCallbacks []func(event *KLineEvent)
kLineClosedEventCallbacks []func(event *KLineEvent)
kLineEventCallbacks []func(e *KLineEvent)
kLineClosedEventCallbacks []func(e *KLineEvent)
balanceUpdateEventCallbacks []func(event *BalanceUpdateEvent)
outboundAccountInfoEventCallbacks []func(event *OutboundAccountInfoEvent)
@ -190,7 +190,7 @@ func (s *PrivateStream) read(ctx context.Context, eventC chan interface{}) {
if e.KLine.Closed {
s.EmitKLineClosedEvent(e)
s.EmitKLineClosed(e.KLine)
s.EmitKLineClosed(e.KLine.KLine())
}
case *ExecutionReportEvent:

63
bbgo/indicator.go Normal file
View File

@ -0,0 +1,63 @@
package bbgo
import (
"math"
"time"
"github.com/c9s/bbgo/pkg/bbgo/types"
)
type MovingAverageIndicator struct {
store *MarketDataStore
Period int
}
func NewMovingAverageIndicator(period int) *MovingAverageIndicator {
return &MovingAverageIndicator{
Period: period,
}
}
func (i *MovingAverageIndicator) handleUpdate(kline types.KLine) {
klines, ok := i.store.KLineWindows[ Interval(kline.Interval) ]
if !ok {
return
}
if len(klines) < i.Period {
return
}
// calculate ma
}
type IndicatorValue struct {
Value float64
Time time.Time
}
func calculateMovingAverage(klines types.KLineWindow, period int) (values []IndicatorValue) {
for idx := range klines[period:] {
offset := idx + period
sum := klines[offset - period:offset].ReduceClose()
values = append(values, IndicatorValue{
Time: klines[offset].GetEndTime(),
Value: math.Round(sum / float64(period)),
})
}
return values
}
func (i *MovingAverageIndicator) SubscribeStore(store *MarketDataStore) {
i.store = store
// register kline update callback
store.OnUpdate(i.handleUpdate)
}

17
bbgo/indicator_test.go Normal file
View File

@ -0,0 +1,17 @@
package bbgo
import (
"testing"
"github.com/c9s/bbgo/pkg/bbgo/types"
)
func TestCalculateMovingAverage(t *testing.T) {
klines := types.KLineWindow{
{
},
}
_ = klines
}

View File

@ -3,7 +3,6 @@ package bbgo
import (
"context"
"fmt"
"time"
"github.com/sirupsen/logrus"
@ -54,7 +53,7 @@ func (trader *KLineRegressionTrader) RunStrategy(ctx context.Context, strategy S
fmt.Print(".")
standardStream.EmitKLineClosed(&kline)
standardStream.EmitKLineClosed(kline)
for _, order := range trader.pendingOrders {
switch order.Side {
@ -117,7 +116,7 @@ func (trader *KLineRegressionTrader) RunStrategy(ctx context.Context, strategy S
Side: string(order.Side),
IsBuyer: order.Side == types.SideTypeBuy,
IsMaker: false,
Time: time.Unix(0, kline.EndTime*int64(time.Millisecond)),
Time: kline.EndTime,
Symbol: trader.Context.Symbol,
Fee: fee,
FeeCurrency: feeCurrency,

View File

@ -0,0 +1,17 @@
// Code generated by "callbackgen -type MarketDataStore"; DO NOT EDIT.
package bbgo
import (
"github.com/c9s/bbgo/pkg/bbgo/types"
)
func (store *MarketDataStore) OnUpdate(cb KLineCallback) {
store.updateCallbacks = append(store.updateCallbacks, cb)
}
func (store *MarketDataStore) EmitUpdate(kline types.KLine) {
for _, cb := range store.updateCallbacks {
cb(kline)
}
}

View File

@ -11,18 +11,18 @@ var Interval5m = Interval("5m")
var Interval1h = Interval("1h")
var Interval1d = Interval("1d")
type MarketDataStore struct {
// MaxChangeKLines stores the max change kline per interval
MaxChangeKLines map[Interval]types.KLine `json:"-"`
type KLineCallback func(kline types.KLine)
//go:generate callbackgen -type MarketDataStore
type MarketDataStore struct {
// KLineWindows stores all loaded klines per interval
KLineWindows map[Interval]types.KLineWindow `json:"-"`
updateCallbacks []KLineCallback
}
func NewMarketDataStore() *MarketDataStore {
return &MarketDataStore{
MaxChangeKLines: make(map[Interval]types.KLine),
// KLineWindows stores all loaded klines per interval
KLineWindows: make(map[Interval]types.KLineWindow),
}
@ -32,19 +32,14 @@ func (store *MarketDataStore) BindPrivateStream(stream *types.StandardPrivateStr
stream.OnKLineClosed(store.handleKLineClosed)
}
func (store *MarketDataStore) handleKLineClosed(kline *types.KLine) {
store.AddKLine(*kline)
func (store *MarketDataStore) handleKLineClosed(kline types.KLine) {
store.AddKLine(kline)
}
func (store *MarketDataStore) AddKLine(kline types.KLine) {
var interval = Interval(kline.Interval)
var window = store.KLineWindows[interval]
window.Add(kline)
if _, ok := store.MaxChangeKLines[interval] ; ok {
if kline.GetMaxChange() > store.MaxChangeKLines[interval].GetMaxChange() {
store.MaxChangeKLines[interval] = kline
}
}
store.EmitUpdate(kline)
}

View File

@ -263,9 +263,9 @@ func (trader *Trader) RunStrategy(ctx context.Context, strategy Strategy) (chan
})
})
stream.OnKLineEvent(func(e *binance.KLineEvent) {
trader.ProfitAndLossCalculator.SetCurrentPrice(e.KLine.GetClose())
trader.Context.SetCurrentPrice(e.KLine.GetClose())
stream.OnKLineClosed(func(kline types.KLine) {
trader.ProfitAndLossCalculator.SetCurrentPrice(kline.Close)
trader.Context.SetCurrentPrice(kline.Close)
})
var eventC = make(chan interface{}, 20)

View File

@ -39,23 +39,30 @@ type KLineQueryOptions struct {
// KLine uses binance's kline as the standard structure
type KLine struct {
StartTime int64 `json:"t"`
EndTime int64 `json:"T"`
StartTime time.Time
EndTime time.Time
Symbol string `json:"s"`
Interval string `json:"i"`
Symbol string
Interval string
Open string `json:"o"`
Close string `json:"c"`
High string `json:"h"`
Open float64
Close float64
High float64
Low float64
Volume float64
QuoteVolume float64
Low string `json:"l"`
Volume string `json:"V"` // taker buy base asset volume (like 10 BTC)
QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT)
LastTradeID int
NumberOfTrades int64
Closed bool
}
LastTradeID int `json:"L"`
NumberOfTrades int64 `json:"n"`
Closed bool `json:"x"`
func (k KLine) GetStartTime() time.Time {
return k.StartTime
}
func (k KLine) GetEndTime() time.Time {
return k.EndTime
}
func (k KLine) GetInterval() string {
@ -63,21 +70,21 @@ func (k KLine) GetInterval() string {
}
func (k KLine) Mid() float64 {
return (k.GetHigh() + k.GetLow()) / 2
return (k.High + k.Low) / 2
}
// green candle with open and close near high price
func (k KLine) BounceUp() bool {
mid := k.Mid()
trend := k.GetTrend()
return trend > 0 && k.GetOpen() > mid && k.GetClose() > mid
return trend > 0 && k.Open > mid && k.Close > mid
}
// red candle with open and close near low price
func (k KLine) BounceDown() bool {
mid := k.Mid()
trend := k.GetTrend()
return trend > 0 && k.GetOpen() < mid && k.GetClose() < mid
return trend > 0 && k.Open < mid && k.Close < mid
}
func (k KLine) GetTrend() int {
@ -93,19 +100,19 @@ func (k KLine) GetTrend() int {
}
func (k KLine) GetHigh() float64 {
return util.MustParseFloat(k.High)
return k.High
}
func (k KLine) GetLow() float64 {
return util.MustParseFloat(k.Low)
return k.Low
}
func (k KLine) GetOpen() float64 {
return util.MustParseFloat(k.Open)
return k.Open
}
func (k KLine) GetClose() float64 {
return util.MustParseFloat(k.Close)
return k.Close
}
func (k KLine) GetMaxChange() float64 {
@ -134,11 +141,11 @@ func (k KLine) GetLowerShadowRatio() float64 {
}
func (k KLine) GetLowerShadowHeight() float64 {
low := k.GetLow()
if k.GetOpen() < k.GetClose() {
return k.GetOpen() - low
low := k.Low
if k.Open < k.Close {
return k.Open - low
}
return k.GetClose() - low
return k.Close - low
}
// GetBody returns the height of the candle real body
@ -147,19 +154,11 @@ func (k KLine) GetBody() float64 {
}
func (k KLine) GetChange() float64 {
return k.GetClose() - k.GetOpen()
}
func (k KLine) GetStartTime() time.Time {
return time.Unix(0, k.StartTime*int64(time.Millisecond))
}
func (k KLine) GetEndTime() time.Time {
return time.Unix(0, k.EndTime*int64(time.Millisecond))
return k.Close - k.Open
}
func (k KLine) String() string {
return fmt.Sprintf("%s %s Open: % 14s Close: % 14s High: % 14s Low: % 14s Volume: % 15s Change: % 11f Max Change: % 11f", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.GetMaxChange())
return fmt.Sprintf("%s %s Open: %.8f Close: %.8f High: %.8f Low: %.8f Volume: %.8f Change: %.4f Max Change: %.4f", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.GetMaxChange())
}
func (k KLine) Color() string {
@ -176,10 +175,10 @@ func (k KLine) SlackAttachment() slack.Attachment {
Text: fmt.Sprintf("*%s* KLine %s", k.Symbol, k.Interval),
Color: k.Color(),
Fields: []slack.AttachmentField{
{Title: "Open", Value: k.Open, Short: true},
{Title: "High", Value: k.High, Short: true},
{Title: "Low", Value: k.Low, Short: true},
{Title: "Close", Value: k.Close, Short: true},
{Title: "Open", Value: util.FormatFloat(k.Open, 2), Short: true},
{Title: "High", Value: util.FormatFloat(k.High, 2), Short: true},
{Title: "Low", Value: util.FormatFloat(k.Low, 2), Short: true},
{Title: "Close", Value: util.FormatFloat(k.Close, 2), Short: true},
{Title: "Mid", Value: util.FormatFloat(k.Mid(), 2), Short: true},
{Title: "Change", Value: util.FormatFloat(k.GetChange(), 2), Short: true},
{Title: "Max Change", Value: util.FormatFloat(k.GetMaxChange(), 2), Short: true},
@ -206,6 +205,17 @@ func (k KLine) SlackAttachment() slack.Attachment {
type KLineWindow []KLine
// ReduceClose reduces the closed prices
func (k KLineWindow) ReduceClose() float64 {
s := 0.0
for _, kline := range k {
s += kline.GetClose()
}
return s
}
func (k KLineWindow) Len() int {
return len(k)
}
@ -425,7 +435,7 @@ func (k KLineWindow) SlackAttachment() slack.Attachment {
Short: true,
},
},
Footer: fmt.Sprintf("Since %s til %s", first.GetStartTime(), end.GetEndTime()),
Footer: fmt.Sprintf("Since %s til %s", first.StartTime, end.EndTime),
FooterIcon: "",
}
}

View File

@ -7,8 +7,8 @@ import ("testing"
func TestKLineWindow_Tail(t *testing.T) {
var win = KLineWindow{
{ Open: "11600.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
{ Open: "11600.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
{ Open: 11600.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
{ Open: 11600.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
}
var win2 = win.Tail(1)
@ -23,22 +23,22 @@ func TestKLineWindow_Tail(t *testing.T) {
func TestKLineWindow_Truncate(t *testing.T) {
var win = KLineWindow{
{ Open: "11600.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
{ Open: "11601.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
{ Open: "11602.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
{ Open: "11603.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
{ Open: 11600.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
{ Open: 11601.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
{ Open: 11602.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
{ Open: 11603.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
}
win.Truncate(5)
assert.Len(t, win, 4)
assert.Equal(t, "11603.0", win.Last().Open)
assert.Equal(t, 11603.0, win.Last().Open)
win.Truncate(3)
assert.Len(t, win, 3)
assert.Equal(t, "11603.0", win.Last().Open)
assert.Equal(t, 11603.0, win.Last().Open)
win.Truncate(1)
assert.Len(t, win, 1)
assert.Equal(t, "11603.0", win.Last().Open)
assert.Equal(t, 11603.0, win.Last().Open)
}

View File

@ -2,9 +2,7 @@
package types
import (
"reflect"
)
import ()
func (stream *StandardPrivateStream) OnTrade(cb func(trade *Trade)) {
stream.tradeCallbacks = append(stream.tradeCallbacks, cb)
@ -16,25 +14,6 @@ func (stream *StandardPrivateStream) EmitTrade(trade *Trade) {
}
}
func (stream *StandardPrivateStream) RemoveOnTrade(needle func(trade *Trade)) (found bool) {
var newcallbacks []func(trade *Trade)
var fp = reflect.ValueOf(needle).Pointer()
for _, cb := range stream.tradeCallbacks {
if fp == reflect.ValueOf(cb).Pointer() {
found = true
} else {
newcallbacks = append(newcallbacks, cb)
}
}
if found {
stream.tradeCallbacks = newcallbacks
}
return found
}
func (stream *StandardPrivateStream) OnBalanceSnapshot(cb func(balanceSnapshot map[string]Balance)) {
stream.balanceSnapshotCallbacks = append(stream.balanceSnapshotCallbacks, cb)
}
@ -45,58 +24,20 @@ func (stream *StandardPrivateStream) EmitBalanceSnapshot(balanceSnapshot map[str
}
}
func (stream *StandardPrivateStream) RemoveOnBalanceSnapshot(needle func(balanceSnapshot map[string]Balance)) (found bool) {
var newcallbacks []func(balanceSnapshot map[string]Balance)
var fp = reflect.ValueOf(needle).Pointer()
for _, cb := range stream.balanceSnapshotCallbacks {
if fp == reflect.ValueOf(cb).Pointer() {
found = true
} else {
newcallbacks = append(newcallbacks, cb)
}
}
if found {
stream.balanceSnapshotCallbacks = newcallbacks
}
return found
}
func (stream *StandardPrivateStream) OnKLineClosed(cb func(kline *KLine)) {
func (stream *StandardPrivateStream) OnKLineClosed(cb func(kline KLine)) {
stream.kLineClosedCallbacks = append(stream.kLineClosedCallbacks, cb)
}
func (stream *StandardPrivateStream) EmitKLineClosed(kline *KLine) {
func (stream *StandardPrivateStream) EmitKLineClosed(kline KLine) {
for _, cb := range stream.kLineClosedCallbacks {
cb(kline)
}
}
func (stream *StandardPrivateStream) RemoveOnKLineClosed(needle func(kline *KLine)) (found bool) {
var newcallbacks []func(kline *KLine)
var fp = reflect.ValueOf(needle).Pointer()
for _, cb := range stream.kLineClosedCallbacks {
if fp == reflect.ValueOf(cb).Pointer() {
found = true
} else {
newcallbacks = append(newcallbacks, cb)
}
}
if found {
stream.kLineClosedCallbacks = newcallbacks
}
return found
}
type StandardPrivateStreamEventHub interface {
OnTrade(cb func(trade *Trade))
OnBalanceSnapshot(cb func(balanceSnapshot map[string]Balance))
OnKLineClosed(cb func(kline *KLine))
OnKLineClosed(cb func(kline KLine))
}

View File

@ -11,7 +11,7 @@ type StandardPrivateStream struct {
tradeCallbacks []func(trade *Trade)
balanceSnapshotCallbacks []func(balanceSnapshot map[string]Balance)
kLineClosedCallbacks []func(kline *KLine)
kLineClosedCallbacks []func(kline KLine)
}
func (stream *StandardPrivateStream) Subscribe(channel string, symbol string, options SubscribeOptions) {

View File

@ -1,78 +0,0 @@
package websocket
import (
"time"
)
const (
DefaultMinBackoff = 5 * time.Second
DefaultMaxBackoff = time.Minute
DefaultBackoffFactor = 2
)
type Backoff struct {
attempt int64
Factor float64
cur, Min, Max time.Duration
}
// Duration returns the duration for the current attempt before incrementing
// the attempt counter. See ForAttempt.
func (b *Backoff) Duration() time.Duration {
d := b.calculate(b.attempt)
b.attempt++
return d
}
func (b *Backoff) calculate(attempt int64) time.Duration {
min := b.Min
if min <= 0 {
min = DefaultMinBackoff
}
max := b.Max
if max <= 0 {
max = DefaultMaxBackoff
}
if min >= max {
return max
}
factor := b.Factor
if factor <= 0 {
factor = DefaultBackoffFactor
}
cur := b.cur
if cur < min {
cur = min
} else if cur > max {
cur = max
}
//calculate this duration
next := cur
if attempt > 0 {
next = time.Duration(float64(cur) * factor)
}
if next < cur {
// overflow
next = max
} else if next <= min {
next = min
} else if next >= max {
next = max
}
b.cur = next
return next
}
// Reset restarts the current attempt counter at zero.
func (b *Backoff) Reset() {
b.attempt = 0
b.cur = b.Min
}
// Attempt returns the current attempt counter value.
func (b *Backoff) Attempt() int64 {
return b.attempt
}

View File

@ -1,172 +0,0 @@
package websocket
import (
"math"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestBackoff(t *testing.T) {
b := &Backoff{
Min: 200 * time.Millisecond,
Max: 100 * time.Second,
Factor: 2,
}
equals(t, b.Duration(), 200*time.Millisecond)
equals(t, b.Duration(), 400*time.Millisecond)
equals(t, b.Duration(), 800*time.Millisecond)
b.Reset()
equals(t, b.Duration(), 200*time.Millisecond)
}
func TestReachMaxBackoff(t *testing.T) {
b := &Backoff{
Min: 2 * time.Second,
Max: 100 * time.Second,
Factor: 10,
}
equals(t, b.Duration(), 2*time.Second)
equals(t, b.Duration(), 20*time.Second)
equals(t, b.Duration(), b.Max)
b.Reset()
equals(t, b.Duration(), 2*time.Second)
}
func TestAttempt(t *testing.T) {
b := &Backoff{
Min: 2 * time.Second,
Max: 100 * time.Second,
Factor: 10,
}
equals(t, b.Attempt(), int64(0))
equals(t, b.Duration(), 2*time.Second)
equals(t, b.Attempt(), int64(1))
equals(t, b.Duration(), 20*time.Second)
equals(t, b.Attempt(), int64(2))
equals(t, b.Duration(), b.Max)
equals(t, b.Attempt(), int64(3))
b.Reset()
equals(t, b.Attempt(), int64(0))
equals(t, b.Duration(), 2*time.Second)
equals(t, b.Attempt(), int64(1))
}
func TestAttemptWithZeroValue(t *testing.T) {
b := &Backoff{}
cur := DefaultMinBackoff
equals(t, b.Attempt(), int64(0))
equals(t, b.Duration(), cur)
for i := 1; i < 10; i++ {
equals(t, b.Attempt(), int64(i))
cur *= DefaultBackoffFactor
if cur >= DefaultMaxBackoff {
equals(t, b.Duration(), DefaultMaxBackoff)
break
}
equals(t, b.Duration(), cur)
}
b.Reset()
equals(t, b.Attempt(), int64(0))
equals(t, b.Duration(), DefaultMinBackoff)
equals(t, b.Attempt(), int64(1))
}
func TestNegOrZeroMin(t *testing.T) {
b := &Backoff{
Min: 0,
}
equals(t, b.Duration(), DefaultMinBackoff)
b2 := &Backoff{
Min: -1 * time.Second,
}
equals(t, b2.Duration(), DefaultMinBackoff)
}
func TestNegOrZeroMax(t *testing.T) {
b := &Backoff{
Min: DefaultMaxBackoff,
Max: 0,
}
equals(t, b.Duration(), DefaultMaxBackoff)
b2 := &Backoff{
Min: DefaultMaxBackoff,
Max: -1 * time.Second,
}
equals(t, b2.Duration(), DefaultMaxBackoff)
}
func TestMinLargerThanMax(t *testing.T) {
b := &Backoff{
Min: 500 * time.Second,
Max: 100 * time.Second,
Factor: 1,
}
equals(t, b.Duration(), b.Max)
}
func TestFakeCur(t *testing.T) {
b := &Backoff{
Min: 100 * time.Second,
Max: 500 * time.Second,
Factor: 1,
cur: 1000 * time.Second,
}
equals(t, b.Duration(), b.Max)
}
func TestLargeFactor(t *testing.T) {
b := &Backoff{
Min: 1 * time.Second,
Max: 10 * time.Second,
Factor: math.MaxFloat64,
}
equals(t, b.Duration(), b.Min)
for i := 0; i < 100; i++ {
equals(t, b.Duration(), b.Max)
}
b.Reset()
equals(t, b.Duration(), b.Min)
}
func TestLargeMinMax(t *testing.T) {
b := &Backoff{
Min: time.Duration(math.MaxInt64 - 1),
Max: time.Duration(math.MaxInt64),
Factor: math.MaxFloat64,
}
equals(t, b.Duration(), b.Min)
for i := 0; i < 100; i++ {
equals(t, b.Duration(), b.Max)
}
b.Reset()
equals(t, b.Duration(), b.Min)
}
func TestManyAttempts(t *testing.T) {
b := &Backoff{
Min: 10 * time.Second,
Max: time.Duration(math.MaxInt64),
Factor: 1000,
}
for i := 0; i < 10000; i++ {
b.Duration()
}
equals(t, b.Duration(), b.Max)
b.Reset()
}
func equals(t *testing.T, v1, v2 interface{}) {
if !reflect.DeepEqual(v1, v2) {
assert.Failf(t, "not equal", "Got %v, Expecting %v", v1, v2)
}
}

View File

@ -1,368 +0,0 @@
package websocket
import (
"context"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/c9s/bbgo/pkg/log"
"github.com/c9s/bbgo/pkg/util"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
const DefaultMessageBufferSize = 128
const DefaultWriteTimeout = 30 * time.Second
const DefaultReadTimeout = 30 * time.Second
var ErrReconnectContextDone = errors.New("reconnect canceled due to context done.")
var ErrReconnectFailed = errors.New("failed to reconnect")
var ErrConnectionLost = errors.New("connection lost")
var MaxReconnectRate = rate.Limit(1 / DefaultMinBackoff.Seconds())
// WebSocketClient allows to connect and receive stream data
type WebSocketClient struct {
// Url is the websocket connection location, start with ws:// or wss://
Url string
// conn is the current websocket connection, please note the connection
// object can be replaced with a new connection object when the connection
// is unexpected closed.
conn *websocket.Conn
// Dialer is used for creating the websocket connection
Dialer *websocket.Dialer
// requestHeader is used for the Dial function call. Some credential can be
// stored in the http request header for authentication
requestHeader http.Header
// messages is a read-only channel, received messages will be sent to this
// channel.
messages chan Message
readTimeout time.Duration
writeTimeout time.Duration
onConnect []func(c Client)
onDisconnect []func(c Client)
// cancel is mapped to the ctx context object
cancel func()
readerClosed chan struct{}
connected bool
mu sync.Mutex
reconnectCh chan struct{}
backoff Backoff
limiter *rate.Limiter
}
type Message struct {
// websocket.BinaryMessage or websocket.TextMessage
Type int
Body []byte
}
func (c *WebSocketClient) Messages() <-chan Message {
return c.messages
}
func (c *WebSocketClient) SetReadTimeout(timeout time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.readTimeout = timeout
}
func (c *WebSocketClient) SetWriteTimeout(timeout time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.writeTimeout = timeout
}
func (c *WebSocketClient) OnConnect(f func(c Client)) {
c.mu.Lock()
defer c.mu.Unlock()
c.onConnect = append(c.onConnect, f)
}
func (c *WebSocketClient) OnDisconnect(f func(c Client)) {
c.mu.Lock()
defer c.mu.Unlock()
c.onDisconnect = append(c.onDisconnect, f)
}
func (c *WebSocketClient) WriteTextMessage(message []byte) error {
return c.WriteMessage(websocket.TextMessage, message)
}
func (c *WebSocketClient) WriteBinaryMessage(message []byte) error {
return c.WriteMessage(websocket.BinaryMessage, message)
}
func (c *WebSocketClient) WriteMessage(messageType int, data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return ErrConnectionLost
}
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
return err
}
return c.conn.WriteMessage(messageType, data)
}
func (c *WebSocketClient) readMessages() error {
c.mu.Lock()
if !c.connected {
c.mu.Unlock()
return ErrConnectionLost
}
timeout := c.readTimeout
conn := c.conn
c.mu.Unlock()
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return err
}
msgtype, message, err := conn.ReadMessage()
if err != nil {
return err
}
c.messages <- Message{msgtype, message}
return nil
}
// listen starts a goroutine for reading message and tries to re-connect to the
// server when the reader returns error
//
// Please note we should always break the reader loop if there is any error
// returned from the server.
func (c *WebSocketClient) listen(ctx context.Context) {
// The life time of both channels "readerClosed" and "reconnectCh" is bound to one connection.
// Each channel should be created before loop starts and be closed after loop ends.
// "readerClosed" is used to inform "Close()" reader loop ends.
// "reconnectCh" is used to centralize reconnection logics in this reader loop.
c.mu.Lock()
c.readerClosed = make(chan struct{})
c.reconnectCh = make(chan struct{}, 1)
c.mu.Unlock()
defer func() {
c.mu.Lock()
close(c.readerClosed)
close(c.reconnectCh)
c.reconnectCh = nil
c.mu.Unlock()
}()
for {
select {
case <-ctx.Done():
return
case <-c.reconnectCh:
// it could be i/o timeout for network disconnection
// or it could be invoked from outside.
c.SetDisconnected()
var maxTries = 1
if _, response, err := c.reconnect(ctx, maxTries); err != nil {
if err == ErrReconnectContextDone {
log.Debugf("[websocket] context canceled. stop reconnecting.")
return
}
log.Warnf("[websocket] failed to reconnect after %d tries!! error: %v response: %v", maxTries, err, response)
c.Reconnect()
}
default:
if err := c.readMessages(); err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
log.Warnf("[websocket] unexpected close error reconnecting: %v", err)
}
log.Warnf("[websocket] failed to read message. error: %+v", err)
c.Reconnect()
}
}
}
}
// Reconnect triggers reconnection logics
func (c *WebSocketClient) Reconnect() {
c.mu.Lock()
defer c.mu.Unlock()
select {
// c.reconnectCh is a buffered channel with cap=1.
// At most one reconnect signal could be processed.
case c.reconnectCh <- struct{}{}:
default:
// Entering here means it is already reconnecting.
// Drop the current reconnect signal.
}
}
// Close gracefully shuts down the reader and the connection
// ctx is the context used for shutdown process.
func (c *WebSocketClient) Close() (err error) {
c.mu.Lock()
// leave the listen goroutine before we close the connection
// checking nil is to handle calling "Close" before "Connect" is called
if c.cancel != nil {
c.cancel()
}
c.mu.Unlock()
c.SetDisconnected()
// wait for the reader func to be closed
if c.readerClosed != nil {
<-c.readerClosed
}
return err
}
// reconnect tries to create a new connection from the existing dialer
func (c *WebSocketClient) reconnect(ctx context.Context, maxTries int) (*websocket.Conn, *http.Response, error) {
log.Debugf("[websocket] start reconnecting to %q", c.Url)
select {
case <-ctx.Done():
return nil, nil, ErrReconnectContextDone
default:
}
if s := util.ShouldDelay(c.limiter, DefaultMinBackoff); s > 0 {
log.Warn("[websocket] reconnect too frequently. Sleep for ", s)
time.Sleep(s)
}
log.Warnf("[websocket] reconnecting x %d to %q", c.backoff.Attempt()+1, c.Url)
conn, resp, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
if err != nil {
dur := c.backoff.Duration()
log.Warnf("failed to dial %s: %v, response: %+v. Wait for %v", c.Url, err, resp, dur)
time.Sleep(dur)
return nil, nil, ErrReconnectFailed
}
log.Infof("[websocket] reconnected to %q", c.Url)
// Reset backoff value if connected.
c.backoff.Reset()
c.setConn(conn)
c.setPingHandler(conn)
return conn, resp, err
}
// Conn returns the current active connection instance
func (c *WebSocketClient) Conn() (conn *websocket.Conn) {
c.mu.Lock()
conn = c.conn
c.mu.Unlock()
return conn
}
func (c *WebSocketClient) setConn(conn *websocket.Conn) {
// Disconnect old connection before replacing with new one.
c.SetDisconnected()
c.mu.Lock()
c.conn = conn
c.connected = true
c.mu.Unlock()
for _, f := range c.onConnect {
go f(c)
}
}
func (c *WebSocketClient) setPingHandler(conn *websocket.Conn) {
conn.SetPingHandler(func(message string) error {
if err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
return err
}
c.mu.Lock()
defer c.mu.Unlock()
return conn.SetReadDeadline(time.Now().Add(c.readTimeout))
})
}
func (c *WebSocketClient) SetDisconnected() {
c.mu.Lock()
closed := false
if c.conn != nil {
closed = true
c.conn.Close()
}
c.connected = false
c.conn = nil
c.mu.Unlock()
if closed {
// Only call disconnect callbacks when a connection is closed
for _, f := range c.onDisconnect {
go f(c)
}
}
}
func (c *WebSocketClient) IsConnected() (ret bool) {
c.mu.Lock()
ret = c.connected
c.mu.Unlock()
return ret
}
func (c *WebSocketClient) Connect(basectx context.Context) error {
// maintain a context by the client it self, so that we can manually shutdown the connection
ctx, cancel := context.WithCancel(basectx)
c.cancel = cancel
conn, _, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
if err == nil {
// setup connection only when connected
c.setConn(conn)
c.setPingHandler(conn)
}
// 1) if connection is built up, start listening for messages.
// 2) if connection is NOT ready, start reconnecting infinitely.
go c.listen(ctx)
return err
}
func New(url string, requestHeader http.Header) *WebSocketClient {
return NewWithDialer(url, websocket.DefaultDialer, requestHeader)
}
func NewWithDialer(url string, d *websocket.Dialer, requestHeader http.Header) *WebSocketClient {
limiter, err := util.NewValidLimiter(MaxReconnectRate, 1)
if err != nil {
log.WithError(err).Panic("Invalid rate limiter")
}
return &WebSocketClient{
Url: url,
Dialer: d,
requestHeader: requestHeader,
readTimeout: DefaultReadTimeout,
writeTimeout: DefaultWriteTimeout,
messages: make(chan Message, DefaultMessageBufferSize),
limiter: limiter,
}
}

View File

@ -1,202 +0,0 @@
package websocket
import (
"context"
"net"
"net/http"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"***REMOVED***/pkg/log"
"***REMOVED***/pkg/testing/testutil"
"github.com/gorilla/websocket"
)
func messageTicker(t *testing.T, ctx context.Context, conn *websocket.Conn) {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case tt := <-ticker.C:
err := conn.WriteMessage(websocket.TextMessage, []byte(tt.String()))
assert.NoError(t, err)
case <-ctx.Done():
return
}
}
}
func messageReader(t *testing.T, ctx context.Context, conn *websocket.Conn) {
for {
_, message, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Errorf("unexpected closed error: %v", err)
}
break
}
t.Logf("message: %v", message)
select {
case <-ctx.Done():
return
default:
}
}
}
func TestReconnect(t *testing.T) {
basectx := context.Background()
wsHandler := func(ctx context.Context) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
go messageTicker(t, ctx, conn)
messageReader(t, ctx, conn)
}
}
serverctx, cancelserver := context.WithCancel(basectx)
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", wsHandler(serverctx))
go func() {
err := server.ListenAndServe()
assert.Equal(t, http.ErrServerClosed, err)
}()
_, port, err := net.SplitHostPort(server.Addr)
assert.NoError(t, err)
log.Debugf("waiting for server port ready")
testutil.WaitForPort(t, server.Addr, 3)
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
log.Infof("url: %s", u.String())
client := New(u.String(), nil)
// start the message reader
clientctx, cancelClient := context.WithCancel(basectx)
defer cancelClient()
err = client.Connect(clientctx)
assert.NoError(t, err)
assert.NotNil(t, client)
client.SetReadTimeout(2 * time.Second)
// read one message
log.Debugf("waiting for message...")
msg := <-client.messages
log.Debugf("received message: %v", msg)
triggerReconnectManyTimes := func() {
// Forcedly reconnect multiple times
for i := 0; i < 10; i = i + 1 {
client.Reconnect()
}
}
log.Debugf("reconnecting...")
triggerReconnectManyTimes()
msg = <-client.messages
log.Debugf("received message: %v", msg)
log.Debugf("shutting down server...")
err = server.Shutdown(serverctx)
assert.NoError(t, err)
cancelserver()
server.Close()
time.Sleep(500 * time.Millisecond)
triggerReconnectManyTimes()
log.Debugf("restarting server...")
server2ctx, cancelserver2 := context.WithCancel(basectx)
server2 := testutil.NewWebSocketServerWithAddressFunc(t, server.Addr, "/ws", wsHandler(server2ctx))
go func() {
err := server2.ListenAndServe()
assert.Equal(t, http.ErrServerClosed, err)
}()
triggerReconnectManyTimes()
log.Debugf("waiting for server2 port ready")
testutil.WaitForPort(t, server.Addr, 3)
log.Debugf("waiting for message from server2...")
msg = <-client.messages
log.Debugf("received message: %v", msg)
err = client.Close()
assert.NoError(t, err)
err = server2.Shutdown(server2ctx)
assert.NoError(t, err)
cancelserver2()
}
func TestConnect(t *testing.T) {
basectx := context.Background()
ctx, cancel := context.WithCancel(basectx)
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", func(w http.ResponseWriter, r *http.Request) {
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
go messageTicker(t, ctx, conn)
messageReader(t, ctx, conn)
})
go func() {
err := server.ListenAndServe()
assert.Equal(t, http.ErrServerClosed, err)
}()
_, port, err := net.SplitHostPort(server.Addr)
assert.NoError(t, err)
testutil.WaitForPort(t, server.Addr, 3)
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
t.Logf("url: %s", u.String())
client := New(u.String(), nil)
err = client.Connect(ctx)
assert.NoError(t, err)
assert.NotNil(t, client)
client.SetReadTimeout(2 * time.Second)
// read one message
t.Logf("waiting for message...")
msg := <-client.messages
t.Logf("received message: %v", msg)
// recreate server at the same address
client.Close()
t.Logf("shutting down server...")
cancel()
err = server.Shutdown(basectx)
assert.NoError(t, err)
}

View File

@ -1,21 +0,0 @@
package websocket
import (
"context"
"time"
)
//go:generate mockery -name=Client
type Client interface {
SetWriteTimeout(time.Duration)
SetReadTimeout(time.Duration)
OnConnect(func(c Client))
OnDisconnect(func(c Client))
Connect(context.Context) error
Reconnect()
Close() error
IsConnected() bool
WriteJSON(interface{}) error
Messages() <-chan Message
}

View File

@ -1,55 +0,0 @@
package websocket
import (
"encoding/json"
"io"
"github.com/gorilla/websocket"
)
// WriteJSON writes the JSON encoding of v as a message.
//
// See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON.
func (c *WebSocketClient) WriteJSON(v interface{}) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return ErrConnectionLost
}
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return err
}
err1 := json.NewEncoder(w).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value.
func (c *WebSocketClient) ReadJSON(v interface{}) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return ErrConnectionLost
}
_, r, err := c.conn.NextReader()
if err != nil {
return err
}
err = json.NewDecoder(r).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}

View File

@ -1,110 +0,0 @@
// Code generated by mockery v1.0.0. DO NOT EDIT.
package mocks
import context "context"
import mock "github.com/stretchr/testify/mock"
import time "time"
import websocket "***REMOVED***/pkg/net/http/websocket"
// Client is an autogenerated mock type for the Client type
type Client struct {
mock.Mock
}
// Close provides a mock function with given fields:
func (_m *Client) Close() error {
ret := _m.Called()
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// Connect provides a mock function with given fields: _a0
func (_m *Client) Connect(_a0 context.Context) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}
// IsConnected provides a mock function with given fields:
func (_m *Client) IsConnected() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// Messages provides a mock function with given fields:
func (_m *Client) Messages() <-chan websocket.Message {
ret := _m.Called()
var r0 <-chan websocket.Message
if rf, ok := ret.Get(0).(func() <-chan websocket.Message); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan websocket.Message)
}
}
return r0
}
// OnConnect provides a mock function with given fields: _a0
func (_m *Client) OnConnect(_a0 func(websocket.Client)) {
_m.Called(_a0)
}
// OnDisconnect provides a mock function with given fields: _a0
func (_m *Client) OnDisconnect(_a0 func(websocket.Client)) {
_m.Called(_a0)
}
// Reconnect provides a mock function with given fields:
func (_m *Client) Reconnect() {
_m.Called()
}
// SetReadTimeout provides a mock function with given fields: _a0
func (_m *Client) SetReadTimeout(_a0 time.Duration) {
_m.Called(_a0)
}
// SetWriteTimeout provides a mock function with given fields: _a0
func (_m *Client) SetWriteTimeout(_a0 time.Duration) {
_m.Called(_a0)
}
// WriteJSON provides a mock function with given fields: _a0
func (_m *Client) WriteJSON(_a0 interface{}) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}