mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-26 00:35:15 +00:00
pre-parse kline payload so that we don't have to re-parse string everytime
This commit is contained in:
parent
afcde6827f
commit
e37932bace
|
@ -303,19 +303,21 @@ func (e *Exchange) QueryKLines(ctx context.Context, symbol, interval string, opt
|
||||||
}
|
}
|
||||||
|
|
||||||
var kLines []types.KLine
|
var kLines []types.KLine
|
||||||
for _, kline := range resp {
|
for _, k := range resp {
|
||||||
kLines = append(kLines, types.KLine{
|
kLines = append(kLines, types.KLine{
|
||||||
Symbol: symbol,
|
Symbol: symbol,
|
||||||
Interval: interval,
|
Interval: interval,
|
||||||
StartTime: kline.OpenTime,
|
StartTime: time.Unix(0, k.OpenTime*int64(time.Millisecond)),
|
||||||
EndTime: kline.CloseTime,
|
EndTime: time.Unix(0, k.CloseTime*int64(time.Millisecond)),
|
||||||
Open: kline.Open,
|
Open: util.MustParseFloat(k.Open),
|
||||||
Close: kline.Close,
|
Close: util.MustParseFloat(k.Close),
|
||||||
High: kline.High,
|
High: util.MustParseFloat(k.High),
|
||||||
Low: kline.Low,
|
Low: util.MustParseFloat(k.Low),
|
||||||
Volume: kline.Volume,
|
Volume: util.MustParseFloat(k.Volume),
|
||||||
QuoteVolume: kline.QuoteAssetVolume,
|
QuoteVolume: util.MustParseFloat(k.QuoteAssetVolume),
|
||||||
NumberOfTrades: kline.TradeNum,
|
LastTradeID: 0,
|
||||||
|
NumberOfTrades: k.TradeNum,
|
||||||
|
Closed: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return kLines, nil
|
return kLines, nil
|
||||||
|
@ -465,13 +467,12 @@ func (e *Exchange) BatchQueryKLines(ctx context.Context, symbol, interval string
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, kline := range klines {
|
for _, kline := range klines {
|
||||||
t := time.Unix(0, kline.EndTime*int64(time.Millisecond))
|
if kline.EndTime.After(endTime) {
|
||||||
if t.After(endTime) {
|
|
||||||
return allKLines, nil
|
return allKLines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
allKLines = append(allKLines, kline)
|
allKLines = append(allKLines, kline)
|
||||||
startTime = t
|
startTime = kline.EndTime
|
||||||
}
|
}
|
||||||
|
|
||||||
// avoid rate limit
|
// avoid rate limit
|
||||||
|
|
|
@ -4,10 +4,12 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fastjson"
|
||||||
|
|
||||||
"github.com/c9s/bbgo/pkg/bbgo/types"
|
"github.com/c9s/bbgo/pkg/bbgo/types"
|
||||||
"github.com/c9s/bbgo/pkg/util"
|
"github.com/c9s/bbgo/pkg/util"
|
||||||
"github.com/valyala/fastjson"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -240,10 +242,49 @@ func ParseEvent(message string) (interface{}, error) {
|
||||||
return nil, fmt.Errorf("unsupported message: %s", message)
|
return nil, fmt.Errorf("unsupported message: %s", message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KLine uses binance's kline as the standard structure
|
||||||
|
type KLine struct {
|
||||||
|
StartTime int64 `json:"t"`
|
||||||
|
EndTime int64 `json:"T"`
|
||||||
|
|
||||||
|
Symbol string `json:"s"`
|
||||||
|
Interval string `json:"i"`
|
||||||
|
|
||||||
|
Open string `json:"o"`
|
||||||
|
Close string `json:"c"`
|
||||||
|
High string `json:"h"`
|
||||||
|
|
||||||
|
Low string `json:"l"`
|
||||||
|
Volume string `json:"V"` // taker buy base asset volume (like 10 BTC)
|
||||||
|
QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT)
|
||||||
|
|
||||||
|
LastTradeID int `json:"L"`
|
||||||
|
NumberOfTrades int64 `json:"n"`
|
||||||
|
Closed bool `json:"x"`
|
||||||
|
}
|
||||||
|
|
||||||
type KLineEvent struct {
|
type KLineEvent struct {
|
||||||
EventBase
|
EventBase
|
||||||
Symbol string `json:"s"`
|
Symbol string `json:"s"`
|
||||||
KLine *types.KLine `json:"k,omitempty"`
|
KLine KLine `json:"k,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KLine) KLine() types.KLine {
|
||||||
|
return types.KLine{
|
||||||
|
Symbol: k.Symbol,
|
||||||
|
Interval: k.Interval,
|
||||||
|
StartTime: time.Unix(0, k.StartTime*int64(time.Millisecond)),
|
||||||
|
EndTime: time.Unix(0, k.EndTime*int64(time.Millisecond)),
|
||||||
|
Open: util.MustParseFloat(k.Open),
|
||||||
|
Close: util.MustParseFloat(k.Close),
|
||||||
|
High: util.MustParseFloat(k.High),
|
||||||
|
Low: util.MustParseFloat(k.Low),
|
||||||
|
Volume: util.MustParseFloat(k.Volume),
|
||||||
|
QuoteVolume: util.MustParseFloat(k.QuoteVolume),
|
||||||
|
LastTradeID: k.LastTradeID,
|
||||||
|
NumberOfTrades: k.NumberOfTrades,
|
||||||
|
Closed: k.Closed,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -31,8 +31,8 @@ type PrivateStream struct {
|
||||||
connectCallbacks []func(stream *PrivateStream)
|
connectCallbacks []func(stream *PrivateStream)
|
||||||
|
|
||||||
// custom callbacks
|
// custom callbacks
|
||||||
kLineEventCallbacks []func(event *KLineEvent)
|
kLineEventCallbacks []func(e *KLineEvent)
|
||||||
kLineClosedEventCallbacks []func(event *KLineEvent)
|
kLineClosedEventCallbacks []func(e *KLineEvent)
|
||||||
|
|
||||||
balanceUpdateEventCallbacks []func(event *BalanceUpdateEvent)
|
balanceUpdateEventCallbacks []func(event *BalanceUpdateEvent)
|
||||||
outboundAccountInfoEventCallbacks []func(event *OutboundAccountInfoEvent)
|
outboundAccountInfoEventCallbacks []func(event *OutboundAccountInfoEvent)
|
||||||
|
@ -190,7 +190,7 @@ func (s *PrivateStream) read(ctx context.Context, eventC chan interface{}) {
|
||||||
|
|
||||||
if e.KLine.Closed {
|
if e.KLine.Closed {
|
||||||
s.EmitKLineClosedEvent(e)
|
s.EmitKLineClosedEvent(e)
|
||||||
s.EmitKLineClosed(e.KLine)
|
s.EmitKLineClosed(e.KLine.KLine())
|
||||||
}
|
}
|
||||||
|
|
||||||
case *ExecutionReportEvent:
|
case *ExecutionReportEvent:
|
||||||
|
|
63
bbgo/indicator.go
Normal file
63
bbgo/indicator.go
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/c9s/bbgo/pkg/bbgo/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MovingAverageIndicator struct {
|
||||||
|
store *MarketDataStore
|
||||||
|
Period int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMovingAverageIndicator(period int) *MovingAverageIndicator {
|
||||||
|
return &MovingAverageIndicator{
|
||||||
|
Period: period,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *MovingAverageIndicator) handleUpdate(kline types.KLine) {
|
||||||
|
klines, ok := i.store.KLineWindows[ Interval(kline.Interval) ]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(klines) < i.Period {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate ma
|
||||||
|
}
|
||||||
|
|
||||||
|
type IndicatorValue struct {
|
||||||
|
Value float64
|
||||||
|
Time time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculateMovingAverage(klines types.KLineWindow, period int) (values []IndicatorValue) {
|
||||||
|
for idx := range klines[period:] {
|
||||||
|
offset := idx + period
|
||||||
|
sum := klines[offset - period:offset].ReduceClose()
|
||||||
|
values = append(values, IndicatorValue{
|
||||||
|
Time: klines[offset].GetEndTime(),
|
||||||
|
Value: math.Round(sum / float64(period)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
func (i *MovingAverageIndicator) SubscribeStore(store *MarketDataStore) {
|
||||||
|
i.store = store
|
||||||
|
|
||||||
|
// register kline update callback
|
||||||
|
store.OnUpdate(i.handleUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
17
bbgo/indicator_test.go
Normal file
17
bbgo/indicator_test.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/c9s/bbgo/pkg/bbgo/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCalculateMovingAverage(t *testing.T) {
|
||||||
|
klines := types.KLineWindow{
|
||||||
|
{
|
||||||
|
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_ = klines
|
||||||
|
}
|
|
@ -3,7 +3,6 @@ package bbgo
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
@ -54,7 +53,7 @@ func (trader *KLineRegressionTrader) RunStrategy(ctx context.Context, strategy S
|
||||||
|
|
||||||
fmt.Print(".")
|
fmt.Print(".")
|
||||||
|
|
||||||
standardStream.EmitKLineClosed(&kline)
|
standardStream.EmitKLineClosed(kline)
|
||||||
|
|
||||||
for _, order := range trader.pendingOrders {
|
for _, order := range trader.pendingOrders {
|
||||||
switch order.Side {
|
switch order.Side {
|
||||||
|
@ -117,7 +116,7 @@ func (trader *KLineRegressionTrader) RunStrategy(ctx context.Context, strategy S
|
||||||
Side: string(order.Side),
|
Side: string(order.Side),
|
||||||
IsBuyer: order.Side == types.SideTypeBuy,
|
IsBuyer: order.Side == types.SideTypeBuy,
|
||||||
IsMaker: false,
|
IsMaker: false,
|
||||||
Time: time.Unix(0, kline.EndTime*int64(time.Millisecond)),
|
Time: kline.EndTime,
|
||||||
Symbol: trader.Context.Symbol,
|
Symbol: trader.Context.Symbol,
|
||||||
Fee: fee,
|
Fee: fee,
|
||||||
FeeCurrency: feeCurrency,
|
FeeCurrency: feeCurrency,
|
||||||
|
|
17
bbgo/marketdatastore_callbacks.go
Normal file
17
bbgo/marketdatastore_callbacks.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
// Code generated by "callbackgen -type MarketDataStore"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/c9s/bbgo/pkg/bbgo/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (store *MarketDataStore) OnUpdate(cb KLineCallback) {
|
||||||
|
store.updateCallbacks = append(store.updateCallbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *MarketDataStore) EmitUpdate(kline types.KLine) {
|
||||||
|
for _, cb := range store.updateCallbacks {
|
||||||
|
cb(kline)
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,18 +11,18 @@ var Interval5m = Interval("5m")
|
||||||
var Interval1h = Interval("1h")
|
var Interval1h = Interval("1h")
|
||||||
var Interval1d = Interval("1d")
|
var Interval1d = Interval("1d")
|
||||||
|
|
||||||
type MarketDataStore struct {
|
type KLineCallback func(kline types.KLine)
|
||||||
// MaxChangeKLines stores the max change kline per interval
|
|
||||||
MaxChangeKLines map[Interval]types.KLine `json:"-"`
|
|
||||||
|
|
||||||
|
//go:generate callbackgen -type MarketDataStore
|
||||||
|
type MarketDataStore struct {
|
||||||
// KLineWindows stores all loaded klines per interval
|
// KLineWindows stores all loaded klines per interval
|
||||||
KLineWindows map[Interval]types.KLineWindow `json:"-"`
|
KLineWindows map[Interval]types.KLineWindow `json:"-"`
|
||||||
|
|
||||||
|
updateCallbacks []KLineCallback
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMarketDataStore() *MarketDataStore {
|
func NewMarketDataStore() *MarketDataStore {
|
||||||
return &MarketDataStore{
|
return &MarketDataStore{
|
||||||
MaxChangeKLines: make(map[Interval]types.KLine),
|
|
||||||
|
|
||||||
// KLineWindows stores all loaded klines per interval
|
// KLineWindows stores all loaded klines per interval
|
||||||
KLineWindows: make(map[Interval]types.KLineWindow),
|
KLineWindows: make(map[Interval]types.KLineWindow),
|
||||||
}
|
}
|
||||||
|
@ -32,19 +32,14 @@ func (store *MarketDataStore) BindPrivateStream(stream *types.StandardPrivateStr
|
||||||
stream.OnKLineClosed(store.handleKLineClosed)
|
stream.OnKLineClosed(store.handleKLineClosed)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *MarketDataStore) handleKLineClosed(kline *types.KLine) {
|
func (store *MarketDataStore) handleKLineClosed(kline types.KLine) {
|
||||||
store.AddKLine(*kline)
|
store.AddKLine(kline)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *MarketDataStore) AddKLine(kline types.KLine) {
|
func (store *MarketDataStore) AddKLine(kline types.KLine) {
|
||||||
var interval = Interval(kline.Interval)
|
var interval = Interval(kline.Interval)
|
||||||
|
|
||||||
var window = store.KLineWindows[interval]
|
var window = store.KLineWindows[interval]
|
||||||
window.Add(kline)
|
window.Add(kline)
|
||||||
|
|
||||||
if _, ok := store.MaxChangeKLines[interval] ; ok {
|
store.EmitUpdate(kline)
|
||||||
if kline.GetMaxChange() > store.MaxChangeKLines[interval].GetMaxChange() {
|
|
||||||
store.MaxChangeKLines[interval] = kline
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -263,9 +263,9 @@ func (trader *Trader) RunStrategy(ctx context.Context, strategy Strategy) (chan
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
stream.OnKLineEvent(func(e *binance.KLineEvent) {
|
stream.OnKLineClosed(func(kline types.KLine) {
|
||||||
trader.ProfitAndLossCalculator.SetCurrentPrice(e.KLine.GetClose())
|
trader.ProfitAndLossCalculator.SetCurrentPrice(kline.Close)
|
||||||
trader.Context.SetCurrentPrice(e.KLine.GetClose())
|
trader.Context.SetCurrentPrice(kline.Close)
|
||||||
})
|
})
|
||||||
|
|
||||||
var eventC = make(chan interface{}, 20)
|
var eventC = make(chan interface{}, 20)
|
||||||
|
|
|
@ -39,23 +39,30 @@ type KLineQueryOptions struct {
|
||||||
|
|
||||||
// KLine uses binance's kline as the standard structure
|
// KLine uses binance's kline as the standard structure
|
||||||
type KLine struct {
|
type KLine struct {
|
||||||
StartTime int64 `json:"t"`
|
StartTime time.Time
|
||||||
EndTime int64 `json:"T"`
|
EndTime time.Time
|
||||||
|
|
||||||
Symbol string `json:"s"`
|
Symbol string
|
||||||
Interval string `json:"i"`
|
Interval string
|
||||||
|
|
||||||
Open string `json:"o"`
|
Open float64
|
||||||
Close string `json:"c"`
|
Close float64
|
||||||
High string `json:"h"`
|
High float64
|
||||||
|
Low float64
|
||||||
|
Volume float64
|
||||||
|
QuoteVolume float64
|
||||||
|
|
||||||
Low string `json:"l"`
|
LastTradeID int
|
||||||
Volume string `json:"V"` // taker buy base asset volume (like 10 BTC)
|
NumberOfTrades int64
|
||||||
QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT)
|
Closed bool
|
||||||
|
}
|
||||||
|
|
||||||
LastTradeID int `json:"L"`
|
func (k KLine) GetStartTime() time.Time {
|
||||||
NumberOfTrades int64 `json:"n"`
|
return k.StartTime
|
||||||
Closed bool `json:"x"`
|
}
|
||||||
|
|
||||||
|
func (k KLine) GetEndTime() time.Time {
|
||||||
|
return k.EndTime
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) GetInterval() string {
|
func (k KLine) GetInterval() string {
|
||||||
|
@ -63,21 +70,21 @@ func (k KLine) GetInterval() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) Mid() float64 {
|
func (k KLine) Mid() float64 {
|
||||||
return (k.GetHigh() + k.GetLow()) / 2
|
return (k.High + k.Low) / 2
|
||||||
}
|
}
|
||||||
|
|
||||||
// green candle with open and close near high price
|
// green candle with open and close near high price
|
||||||
func (k KLine) BounceUp() bool {
|
func (k KLine) BounceUp() bool {
|
||||||
mid := k.Mid()
|
mid := k.Mid()
|
||||||
trend := k.GetTrend()
|
trend := k.GetTrend()
|
||||||
return trend > 0 && k.GetOpen() > mid && k.GetClose() > mid
|
return trend > 0 && k.Open > mid && k.Close > mid
|
||||||
}
|
}
|
||||||
|
|
||||||
// red candle with open and close near low price
|
// red candle with open and close near low price
|
||||||
func (k KLine) BounceDown() bool {
|
func (k KLine) BounceDown() bool {
|
||||||
mid := k.Mid()
|
mid := k.Mid()
|
||||||
trend := k.GetTrend()
|
trend := k.GetTrend()
|
||||||
return trend > 0 && k.GetOpen() < mid && k.GetClose() < mid
|
return trend > 0 && k.Open < mid && k.Close < mid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) GetTrend() int {
|
func (k KLine) GetTrend() int {
|
||||||
|
@ -93,19 +100,19 @@ func (k KLine) GetTrend() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) GetHigh() float64 {
|
func (k KLine) GetHigh() float64 {
|
||||||
return util.MustParseFloat(k.High)
|
return k.High
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) GetLow() float64 {
|
func (k KLine) GetLow() float64 {
|
||||||
return util.MustParseFloat(k.Low)
|
return k.Low
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) GetOpen() float64 {
|
func (k KLine) GetOpen() float64 {
|
||||||
return util.MustParseFloat(k.Open)
|
return k.Open
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) GetClose() float64 {
|
func (k KLine) GetClose() float64 {
|
||||||
return util.MustParseFloat(k.Close)
|
return k.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) GetMaxChange() float64 {
|
func (k KLine) GetMaxChange() float64 {
|
||||||
|
@ -134,11 +141,11 @@ func (k KLine) GetLowerShadowRatio() float64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) GetLowerShadowHeight() float64 {
|
func (k KLine) GetLowerShadowHeight() float64 {
|
||||||
low := k.GetLow()
|
low := k.Low
|
||||||
if k.GetOpen() < k.GetClose() {
|
if k.Open < k.Close {
|
||||||
return k.GetOpen() - low
|
return k.Open - low
|
||||||
}
|
}
|
||||||
return k.GetClose() - low
|
return k.Close - low
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetBody returns the height of the candle real body
|
// GetBody returns the height of the candle real body
|
||||||
|
@ -147,19 +154,11 @@ func (k KLine) GetBody() float64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) GetChange() float64 {
|
func (k KLine) GetChange() float64 {
|
||||||
return k.GetClose() - k.GetOpen()
|
return k.Close - k.Open
|
||||||
}
|
|
||||||
|
|
||||||
func (k KLine) GetStartTime() time.Time {
|
|
||||||
return time.Unix(0, k.StartTime*int64(time.Millisecond))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k KLine) GetEndTime() time.Time {
|
|
||||||
return time.Unix(0, k.EndTime*int64(time.Millisecond))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) String() string {
|
func (k KLine) String() string {
|
||||||
return fmt.Sprintf("%s %s Open: % 14s Close: % 14s High: % 14s Low: % 14s Volume: % 15s Change: % 11f Max Change: % 11f", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.GetMaxChange())
|
return fmt.Sprintf("%s %s Open: %.8f Close: %.8f High: %.8f Low: %.8f Volume: %.8f Change: %.4f Max Change: %.4f", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.GetMaxChange())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k KLine) Color() string {
|
func (k KLine) Color() string {
|
||||||
|
@ -176,10 +175,10 @@ func (k KLine) SlackAttachment() slack.Attachment {
|
||||||
Text: fmt.Sprintf("*%s* KLine %s", k.Symbol, k.Interval),
|
Text: fmt.Sprintf("*%s* KLine %s", k.Symbol, k.Interval),
|
||||||
Color: k.Color(),
|
Color: k.Color(),
|
||||||
Fields: []slack.AttachmentField{
|
Fields: []slack.AttachmentField{
|
||||||
{Title: "Open", Value: k.Open, Short: true},
|
{Title: "Open", Value: util.FormatFloat(k.Open, 2), Short: true},
|
||||||
{Title: "High", Value: k.High, Short: true},
|
{Title: "High", Value: util.FormatFloat(k.High, 2), Short: true},
|
||||||
{Title: "Low", Value: k.Low, Short: true},
|
{Title: "Low", Value: util.FormatFloat(k.Low, 2), Short: true},
|
||||||
{Title: "Close", Value: k.Close, Short: true},
|
{Title: "Close", Value: util.FormatFloat(k.Close, 2), Short: true},
|
||||||
{Title: "Mid", Value: util.FormatFloat(k.Mid(), 2), Short: true},
|
{Title: "Mid", Value: util.FormatFloat(k.Mid(), 2), Short: true},
|
||||||
{Title: "Change", Value: util.FormatFloat(k.GetChange(), 2), Short: true},
|
{Title: "Change", Value: util.FormatFloat(k.GetChange(), 2), Short: true},
|
||||||
{Title: "Max Change", Value: util.FormatFloat(k.GetMaxChange(), 2), Short: true},
|
{Title: "Max Change", Value: util.FormatFloat(k.GetMaxChange(), 2), Short: true},
|
||||||
|
@ -206,6 +205,17 @@ func (k KLine) SlackAttachment() slack.Attachment {
|
||||||
|
|
||||||
type KLineWindow []KLine
|
type KLineWindow []KLine
|
||||||
|
|
||||||
|
// ReduceClose reduces the closed prices
|
||||||
|
func (k KLineWindow) ReduceClose() float64 {
|
||||||
|
s := 0.0
|
||||||
|
|
||||||
|
for _, kline := range k {
|
||||||
|
s += kline.GetClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func (k KLineWindow) Len() int {
|
func (k KLineWindow) Len() int {
|
||||||
return len(k)
|
return len(k)
|
||||||
}
|
}
|
||||||
|
@ -425,7 +435,7 @@ func (k KLineWindow) SlackAttachment() slack.Attachment {
|
||||||
Short: true,
|
Short: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Footer: fmt.Sprintf("Since %s til %s", first.GetStartTime(), end.GetEndTime()),
|
Footer: fmt.Sprintf("Since %s til %s", first.StartTime, end.EndTime),
|
||||||
FooterIcon: "",
|
FooterIcon: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,8 +7,8 @@ import ("testing"
|
||||||
|
|
||||||
func TestKLineWindow_Tail(t *testing.T) {
|
func TestKLineWindow_Tail(t *testing.T) {
|
||||||
var win = KLineWindow{
|
var win = KLineWindow{
|
||||||
{ Open: "11600.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
{ Open: 11600.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||||
{ Open: "11600.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
{ Open: 11600.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||||
}
|
}
|
||||||
|
|
||||||
var win2 = win.Tail(1)
|
var win2 = win.Tail(1)
|
||||||
|
@ -23,22 +23,22 @@ func TestKLineWindow_Tail(t *testing.T) {
|
||||||
|
|
||||||
func TestKLineWindow_Truncate(t *testing.T) {
|
func TestKLineWindow_Truncate(t *testing.T) {
|
||||||
var win = KLineWindow{
|
var win = KLineWindow{
|
||||||
{ Open: "11600.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
{ Open: 11600.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||||
{ Open: "11601.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
{ Open: 11601.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||||
{ Open: "11602.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
{ Open: 11602.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||||
{ Open: "11603.0", Close: "11600.0", High: "11600.0", Low: "11600.0"},
|
{ Open: 11603.0, Close: 11600.0, High: 11600.0, Low: 11600.0},
|
||||||
}
|
}
|
||||||
|
|
||||||
win.Truncate(5)
|
win.Truncate(5)
|
||||||
assert.Len(t, win, 4)
|
assert.Len(t, win, 4)
|
||||||
assert.Equal(t, "11603.0", win.Last().Open)
|
assert.Equal(t, 11603.0, win.Last().Open)
|
||||||
|
|
||||||
win.Truncate(3)
|
win.Truncate(3)
|
||||||
assert.Len(t, win, 3)
|
assert.Len(t, win, 3)
|
||||||
assert.Equal(t, "11603.0", win.Last().Open)
|
assert.Equal(t, 11603.0, win.Last().Open)
|
||||||
|
|
||||||
|
|
||||||
win.Truncate(1)
|
win.Truncate(1)
|
||||||
assert.Len(t, win, 1)
|
assert.Len(t, win, 1)
|
||||||
assert.Equal(t, "11603.0", win.Last().Open)
|
assert.Equal(t, 11603.0, win.Last().Open)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,9 +2,7 @@
|
||||||
|
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import ()
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (stream *StandardPrivateStream) OnTrade(cb func(trade *Trade)) {
|
func (stream *StandardPrivateStream) OnTrade(cb func(trade *Trade)) {
|
||||||
stream.tradeCallbacks = append(stream.tradeCallbacks, cb)
|
stream.tradeCallbacks = append(stream.tradeCallbacks, cb)
|
||||||
|
@ -16,25 +14,6 @@ func (stream *StandardPrivateStream) EmitTrade(trade *Trade) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *StandardPrivateStream) RemoveOnTrade(needle func(trade *Trade)) (found bool) {
|
|
||||||
|
|
||||||
var newcallbacks []func(trade *Trade)
|
|
||||||
var fp = reflect.ValueOf(needle).Pointer()
|
|
||||||
for _, cb := range stream.tradeCallbacks {
|
|
||||||
if fp == reflect.ValueOf(cb).Pointer() {
|
|
||||||
found = true
|
|
||||||
} else {
|
|
||||||
newcallbacks = append(newcallbacks, cb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
stream.tradeCallbacks = newcallbacks
|
|
||||||
}
|
|
||||||
|
|
||||||
return found
|
|
||||||
}
|
|
||||||
|
|
||||||
func (stream *StandardPrivateStream) OnBalanceSnapshot(cb func(balanceSnapshot map[string]Balance)) {
|
func (stream *StandardPrivateStream) OnBalanceSnapshot(cb func(balanceSnapshot map[string]Balance)) {
|
||||||
stream.balanceSnapshotCallbacks = append(stream.balanceSnapshotCallbacks, cb)
|
stream.balanceSnapshotCallbacks = append(stream.balanceSnapshotCallbacks, cb)
|
||||||
}
|
}
|
||||||
|
@ -45,58 +24,20 @@ func (stream *StandardPrivateStream) EmitBalanceSnapshot(balanceSnapshot map[str
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *StandardPrivateStream) RemoveOnBalanceSnapshot(needle func(balanceSnapshot map[string]Balance)) (found bool) {
|
func (stream *StandardPrivateStream) OnKLineClosed(cb func(kline KLine)) {
|
||||||
|
|
||||||
var newcallbacks []func(balanceSnapshot map[string]Balance)
|
|
||||||
var fp = reflect.ValueOf(needle).Pointer()
|
|
||||||
for _, cb := range stream.balanceSnapshotCallbacks {
|
|
||||||
if fp == reflect.ValueOf(cb).Pointer() {
|
|
||||||
found = true
|
|
||||||
} else {
|
|
||||||
newcallbacks = append(newcallbacks, cb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
stream.balanceSnapshotCallbacks = newcallbacks
|
|
||||||
}
|
|
||||||
|
|
||||||
return found
|
|
||||||
}
|
|
||||||
|
|
||||||
func (stream *StandardPrivateStream) OnKLineClosed(cb func(kline *KLine)) {
|
|
||||||
stream.kLineClosedCallbacks = append(stream.kLineClosedCallbacks, cb)
|
stream.kLineClosedCallbacks = append(stream.kLineClosedCallbacks, cb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *StandardPrivateStream) EmitKLineClosed(kline *KLine) {
|
func (stream *StandardPrivateStream) EmitKLineClosed(kline KLine) {
|
||||||
for _, cb := range stream.kLineClosedCallbacks {
|
for _, cb := range stream.kLineClosedCallbacks {
|
||||||
cb(kline)
|
cb(kline)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *StandardPrivateStream) RemoveOnKLineClosed(needle func(kline *KLine)) (found bool) {
|
|
||||||
|
|
||||||
var newcallbacks []func(kline *KLine)
|
|
||||||
var fp = reflect.ValueOf(needle).Pointer()
|
|
||||||
for _, cb := range stream.kLineClosedCallbacks {
|
|
||||||
if fp == reflect.ValueOf(cb).Pointer() {
|
|
||||||
found = true
|
|
||||||
} else {
|
|
||||||
newcallbacks = append(newcallbacks, cb)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
stream.kLineClosedCallbacks = newcallbacks
|
|
||||||
}
|
|
||||||
|
|
||||||
return found
|
|
||||||
}
|
|
||||||
|
|
||||||
type StandardPrivateStreamEventHub interface {
|
type StandardPrivateStreamEventHub interface {
|
||||||
OnTrade(cb func(trade *Trade))
|
OnTrade(cb func(trade *Trade))
|
||||||
|
|
||||||
OnBalanceSnapshot(cb func(balanceSnapshot map[string]Balance))
|
OnBalanceSnapshot(cb func(balanceSnapshot map[string]Balance))
|
||||||
|
|
||||||
OnKLineClosed(cb func(kline *KLine))
|
OnKLineClosed(cb func(kline KLine))
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ type StandardPrivateStream struct {
|
||||||
|
|
||||||
tradeCallbacks []func(trade *Trade)
|
tradeCallbacks []func(trade *Trade)
|
||||||
balanceSnapshotCallbacks []func(balanceSnapshot map[string]Balance)
|
balanceSnapshotCallbacks []func(balanceSnapshot map[string]Balance)
|
||||||
kLineClosedCallbacks []func(kline *KLine)
|
kLineClosedCallbacks []func(kline KLine)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *StandardPrivateStream) Subscribe(channel string, symbol string, options SubscribeOptions) {
|
func (stream *StandardPrivateStream) Subscribe(channel string, symbol string, options SubscribeOptions) {
|
||||||
|
|
|
@ -1,78 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
DefaultMinBackoff = 5 * time.Second
|
|
||||||
DefaultMaxBackoff = time.Minute
|
|
||||||
|
|
||||||
DefaultBackoffFactor = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
type Backoff struct {
|
|
||||||
attempt int64
|
|
||||||
Factor float64
|
|
||||||
cur, Min, Max time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// Duration returns the duration for the current attempt before incrementing
|
|
||||||
// the attempt counter. See ForAttempt.
|
|
||||||
func (b *Backoff) Duration() time.Duration {
|
|
||||||
d := b.calculate(b.attempt)
|
|
||||||
b.attempt++
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Backoff) calculate(attempt int64) time.Duration {
|
|
||||||
min := b.Min
|
|
||||||
if min <= 0 {
|
|
||||||
min = DefaultMinBackoff
|
|
||||||
}
|
|
||||||
max := b.Max
|
|
||||||
if max <= 0 {
|
|
||||||
max = DefaultMaxBackoff
|
|
||||||
}
|
|
||||||
if min >= max {
|
|
||||||
return max
|
|
||||||
}
|
|
||||||
factor := b.Factor
|
|
||||||
if factor <= 0 {
|
|
||||||
factor = DefaultBackoffFactor
|
|
||||||
}
|
|
||||||
cur := b.cur
|
|
||||||
if cur < min {
|
|
||||||
cur = min
|
|
||||||
} else if cur > max {
|
|
||||||
cur = max
|
|
||||||
}
|
|
||||||
|
|
||||||
//calculate this duration
|
|
||||||
next := cur
|
|
||||||
if attempt > 0 {
|
|
||||||
next = time.Duration(float64(cur) * factor)
|
|
||||||
}
|
|
||||||
|
|
||||||
if next < cur {
|
|
||||||
// overflow
|
|
||||||
next = max
|
|
||||||
} else if next <= min {
|
|
||||||
next = min
|
|
||||||
} else if next >= max {
|
|
||||||
next = max
|
|
||||||
}
|
|
||||||
b.cur = next
|
|
||||||
return next
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset restarts the current attempt counter at zero.
|
|
||||||
func (b *Backoff) Reset() {
|
|
||||||
b.attempt = 0
|
|
||||||
b.cur = b.Min
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt returns the current attempt counter value.
|
|
||||||
func (b *Backoff) Attempt() int64 {
|
|
||||||
return b.attempt
|
|
||||||
}
|
|
|
@ -1,172 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBackoff(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: 200 * time.Millisecond,
|
|
||||||
Max: 100 * time.Second,
|
|
||||||
Factor: 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
equals(t, b.Duration(), 200*time.Millisecond)
|
|
||||||
equals(t, b.Duration(), 400*time.Millisecond)
|
|
||||||
equals(t, b.Duration(), 800*time.Millisecond)
|
|
||||||
b.Reset()
|
|
||||||
equals(t, b.Duration(), 200*time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReachMaxBackoff(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: 2 * time.Second,
|
|
||||||
Max: 100 * time.Second,
|
|
||||||
Factor: 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
equals(t, b.Duration(), 2*time.Second)
|
|
||||||
equals(t, b.Duration(), 20*time.Second)
|
|
||||||
equals(t, b.Duration(), b.Max)
|
|
||||||
b.Reset()
|
|
||||||
equals(t, b.Duration(), 2*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAttempt(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: 2 * time.Second,
|
|
||||||
Max: 100 * time.Second,
|
|
||||||
Factor: 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
equals(t, b.Attempt(), int64(0))
|
|
||||||
equals(t, b.Duration(), 2*time.Second)
|
|
||||||
equals(t, b.Attempt(), int64(1))
|
|
||||||
equals(t, b.Duration(), 20*time.Second)
|
|
||||||
equals(t, b.Attempt(), int64(2))
|
|
||||||
equals(t, b.Duration(), b.Max)
|
|
||||||
equals(t, b.Attempt(), int64(3))
|
|
||||||
b.Reset()
|
|
||||||
equals(t, b.Attempt(), int64(0))
|
|
||||||
equals(t, b.Duration(), 2*time.Second)
|
|
||||||
equals(t, b.Attempt(), int64(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAttemptWithZeroValue(t *testing.T) {
|
|
||||||
b := &Backoff{}
|
|
||||||
|
|
||||||
cur := DefaultMinBackoff
|
|
||||||
equals(t, b.Attempt(), int64(0))
|
|
||||||
equals(t, b.Duration(), cur)
|
|
||||||
for i := 1; i < 10; i++ {
|
|
||||||
equals(t, b.Attempt(), int64(i))
|
|
||||||
cur *= DefaultBackoffFactor
|
|
||||||
if cur >= DefaultMaxBackoff {
|
|
||||||
equals(t, b.Duration(), DefaultMaxBackoff)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
equals(t, b.Duration(), cur)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.Reset()
|
|
||||||
equals(t, b.Attempt(), int64(0))
|
|
||||||
equals(t, b.Duration(), DefaultMinBackoff)
|
|
||||||
equals(t, b.Attempt(), int64(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNegOrZeroMin(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: 0,
|
|
||||||
}
|
|
||||||
equals(t, b.Duration(), DefaultMinBackoff)
|
|
||||||
|
|
||||||
b2 := &Backoff{
|
|
||||||
Min: -1 * time.Second,
|
|
||||||
}
|
|
||||||
equals(t, b2.Duration(), DefaultMinBackoff)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNegOrZeroMax(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: DefaultMaxBackoff,
|
|
||||||
Max: 0,
|
|
||||||
}
|
|
||||||
equals(t, b.Duration(), DefaultMaxBackoff)
|
|
||||||
|
|
||||||
b2 := &Backoff{
|
|
||||||
Min: DefaultMaxBackoff,
|
|
||||||
Max: -1 * time.Second,
|
|
||||||
}
|
|
||||||
equals(t, b2.Duration(), DefaultMaxBackoff)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMinLargerThanMax(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: 500 * time.Second,
|
|
||||||
Max: 100 * time.Second,
|
|
||||||
Factor: 1,
|
|
||||||
}
|
|
||||||
equals(t, b.Duration(), b.Max)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFakeCur(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: 100 * time.Second,
|
|
||||||
Max: 500 * time.Second,
|
|
||||||
Factor: 1,
|
|
||||||
cur: 1000 * time.Second,
|
|
||||||
}
|
|
||||||
equals(t, b.Duration(), b.Max)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLargeFactor(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: 1 * time.Second,
|
|
||||||
Max: 10 * time.Second,
|
|
||||||
Factor: math.MaxFloat64,
|
|
||||||
}
|
|
||||||
equals(t, b.Duration(), b.Min)
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
equals(t, b.Duration(), b.Max)
|
|
||||||
}
|
|
||||||
b.Reset()
|
|
||||||
equals(t, b.Duration(), b.Min)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLargeMinMax(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: time.Duration(math.MaxInt64 - 1),
|
|
||||||
Max: time.Duration(math.MaxInt64),
|
|
||||||
Factor: math.MaxFloat64,
|
|
||||||
}
|
|
||||||
equals(t, b.Duration(), b.Min)
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
equals(t, b.Duration(), b.Max)
|
|
||||||
}
|
|
||||||
b.Reset()
|
|
||||||
equals(t, b.Duration(), b.Min)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManyAttempts(t *testing.T) {
|
|
||||||
b := &Backoff{
|
|
||||||
Min: 10 * time.Second,
|
|
||||||
Max: time.Duration(math.MaxInt64),
|
|
||||||
Factor: 1000,
|
|
||||||
}
|
|
||||||
for i := 0; i < 10000; i++ {
|
|
||||||
b.Duration()
|
|
||||||
}
|
|
||||||
equals(t, b.Duration(), b.Max)
|
|
||||||
b.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
func equals(t *testing.T, v1, v2 interface{}) {
|
|
||||||
if !reflect.DeepEqual(v1, v2) {
|
|
||||||
assert.Failf(t, "not equal", "Got %v, Expecting %v", v1, v2)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,368 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
|
|
||||||
"github.com/c9s/bbgo/pkg/log"
|
|
||||||
"github.com/c9s/bbgo/pkg/util"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DefaultMessageBufferSize = 128
|
|
||||||
|
|
||||||
const DefaultWriteTimeout = 30 * time.Second
|
|
||||||
const DefaultReadTimeout = 30 * time.Second
|
|
||||||
|
|
||||||
var ErrReconnectContextDone = errors.New("reconnect canceled due to context done.")
|
|
||||||
var ErrReconnectFailed = errors.New("failed to reconnect")
|
|
||||||
var ErrConnectionLost = errors.New("connection lost")
|
|
||||||
|
|
||||||
var MaxReconnectRate = rate.Limit(1 / DefaultMinBackoff.Seconds())
|
|
||||||
|
|
||||||
// WebSocketClient allows to connect and receive stream data
|
|
||||||
type WebSocketClient struct {
|
|
||||||
// Url is the websocket connection location, start with ws:// or wss://
|
|
||||||
Url string
|
|
||||||
|
|
||||||
// conn is the current websocket connection, please note the connection
|
|
||||||
// object can be replaced with a new connection object when the connection
|
|
||||||
// is unexpected closed.
|
|
||||||
conn *websocket.Conn
|
|
||||||
|
|
||||||
// Dialer is used for creating the websocket connection
|
|
||||||
Dialer *websocket.Dialer
|
|
||||||
|
|
||||||
// requestHeader is used for the Dial function call. Some credential can be
|
|
||||||
// stored in the http request header for authentication
|
|
||||||
requestHeader http.Header
|
|
||||||
|
|
||||||
// messages is a read-only channel, received messages will be sent to this
|
|
||||||
// channel.
|
|
||||||
messages chan Message
|
|
||||||
|
|
||||||
readTimeout time.Duration
|
|
||||||
|
|
||||||
writeTimeout time.Duration
|
|
||||||
|
|
||||||
onConnect []func(c Client)
|
|
||||||
|
|
||||||
onDisconnect []func(c Client)
|
|
||||||
|
|
||||||
// cancel is mapped to the ctx context object
|
|
||||||
cancel func()
|
|
||||||
|
|
||||||
readerClosed chan struct{}
|
|
||||||
|
|
||||||
connected bool
|
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
|
|
||||||
reconnectCh chan struct{}
|
|
||||||
|
|
||||||
backoff Backoff
|
|
||||||
|
|
||||||
limiter *rate.Limiter
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
// websocket.BinaryMessage or websocket.TextMessage
|
|
||||||
Type int
|
|
||||||
Body []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) Messages() <-chan Message {
|
|
||||||
return c.messages
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) SetReadTimeout(timeout time.Duration) {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
c.readTimeout = timeout
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) SetWriteTimeout(timeout time.Duration) {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
c.writeTimeout = timeout
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) OnConnect(f func(c Client)) {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
c.onConnect = append(c.onConnect, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) OnDisconnect(f func(c Client)) {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
c.onDisconnect = append(c.onDisconnect, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) WriteTextMessage(message []byte) error {
|
|
||||||
return c.WriteMessage(websocket.TextMessage, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) WriteBinaryMessage(message []byte) error {
|
|
||||||
return c.WriteMessage(websocket.BinaryMessage, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) WriteMessage(messageType int, data []byte) error {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
if !c.connected {
|
|
||||||
return ErrConnectionLost
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.conn.WriteMessage(messageType, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) readMessages() error {
|
|
||||||
c.mu.Lock()
|
|
||||||
if !c.connected {
|
|
||||||
c.mu.Unlock()
|
|
||||||
return ErrConnectionLost
|
|
||||||
}
|
|
||||||
timeout := c.readTimeout
|
|
||||||
conn := c.conn
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msgtype, message, err := conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.messages <- Message{msgtype, message}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// listen starts a goroutine for reading message and tries to re-connect to the
|
|
||||||
// server when the reader returns error
|
|
||||||
//
|
|
||||||
// Please note we should always break the reader loop if there is any error
|
|
||||||
// returned from the server.
|
|
||||||
func (c *WebSocketClient) listen(ctx context.Context) {
|
|
||||||
// The life time of both channels "readerClosed" and "reconnectCh" is bound to one connection.
|
|
||||||
// Each channel should be created before loop starts and be closed after loop ends.
|
|
||||||
// "readerClosed" is used to inform "Close()" reader loop ends.
|
|
||||||
// "reconnectCh" is used to centralize reconnection logics in this reader loop.
|
|
||||||
c.mu.Lock()
|
|
||||||
c.readerClosed = make(chan struct{})
|
|
||||||
c.reconnectCh = make(chan struct{}, 1)
|
|
||||||
c.mu.Unlock()
|
|
||||||
defer func() {
|
|
||||||
c.mu.Lock()
|
|
||||||
close(c.readerClosed)
|
|
||||||
close(c.reconnectCh)
|
|
||||||
c.reconnectCh = nil
|
|
||||||
c.mu.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
|
|
||||||
case <-c.reconnectCh:
|
|
||||||
// it could be i/o timeout for network disconnection
|
|
||||||
// or it could be invoked from outside.
|
|
||||||
c.SetDisconnected()
|
|
||||||
var maxTries = 1
|
|
||||||
if _, response, err := c.reconnect(ctx, maxTries); err != nil {
|
|
||||||
if err == ErrReconnectContextDone {
|
|
||||||
log.Debugf("[websocket] context canceled. stop reconnecting.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Warnf("[websocket] failed to reconnect after %d tries!! error: %v response: %v", maxTries, err, response)
|
|
||||||
c.Reconnect()
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
if err := c.readMessages(); err != nil {
|
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
|
|
||||||
log.Warnf("[websocket] unexpected close error reconnecting: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warnf("[websocket] failed to read message. error: %+v", err)
|
|
||||||
c.Reconnect()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reconnect triggers reconnection logics
|
|
||||||
func (c *WebSocketClient) Reconnect() {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
select {
|
|
||||||
// c.reconnectCh is a buffered channel with cap=1.
|
|
||||||
// At most one reconnect signal could be processed.
|
|
||||||
case c.reconnectCh <- struct{}{}:
|
|
||||||
default:
|
|
||||||
// Entering here means it is already reconnecting.
|
|
||||||
// Drop the current reconnect signal.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close gracefully shuts down the reader and the connection
|
|
||||||
// ctx is the context used for shutdown process.
|
|
||||||
func (c *WebSocketClient) Close() (err error) {
|
|
||||||
c.mu.Lock()
|
|
||||||
// leave the listen goroutine before we close the connection
|
|
||||||
// checking nil is to handle calling "Close" before "Connect" is called
|
|
||||||
if c.cancel != nil {
|
|
||||||
c.cancel()
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
c.SetDisconnected()
|
|
||||||
|
|
||||||
// wait for the reader func to be closed
|
|
||||||
if c.readerClosed != nil {
|
|
||||||
<-c.readerClosed
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// reconnect tries to create a new connection from the existing dialer
|
|
||||||
func (c *WebSocketClient) reconnect(ctx context.Context, maxTries int) (*websocket.Conn, *http.Response, error) {
|
|
||||||
log.Debugf("[websocket] start reconnecting to %q", c.Url)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, nil, ErrReconnectContextDone
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if s := util.ShouldDelay(c.limiter, DefaultMinBackoff); s > 0 {
|
|
||||||
log.Warn("[websocket] reconnect too frequently. Sleep for ", s)
|
|
||||||
time.Sleep(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warnf("[websocket] reconnecting x %d to %q", c.backoff.Attempt()+1, c.Url)
|
|
||||||
conn, resp, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
|
||||||
if err != nil {
|
|
||||||
dur := c.backoff.Duration()
|
|
||||||
log.Warnf("failed to dial %s: %v, response: %+v. Wait for %v", c.Url, err, resp, dur)
|
|
||||||
time.Sleep(dur)
|
|
||||||
return nil, nil, ErrReconnectFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("[websocket] reconnected to %q", c.Url)
|
|
||||||
// Reset backoff value if connected.
|
|
||||||
c.backoff.Reset()
|
|
||||||
c.setConn(conn)
|
|
||||||
c.setPingHandler(conn)
|
|
||||||
|
|
||||||
return conn, resp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conn returns the current active connection instance
|
|
||||||
func (c *WebSocketClient) Conn() (conn *websocket.Conn) {
|
|
||||||
c.mu.Lock()
|
|
||||||
conn = c.conn
|
|
||||||
c.mu.Unlock()
|
|
||||||
return conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) setConn(conn *websocket.Conn) {
|
|
||||||
// Disconnect old connection before replacing with new one.
|
|
||||||
c.SetDisconnected()
|
|
||||||
|
|
||||||
c.mu.Lock()
|
|
||||||
c.conn = conn
|
|
||||||
c.connected = true
|
|
||||||
c.mu.Unlock()
|
|
||||||
for _, f := range c.onConnect {
|
|
||||||
go f(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) setPingHandler(conn *websocket.Conn) {
|
|
||||||
conn.SetPingHandler(func(message string) error {
|
|
||||||
if err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
return conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) SetDisconnected() {
|
|
||||||
c.mu.Lock()
|
|
||||||
closed := false
|
|
||||||
if c.conn != nil {
|
|
||||||
closed = true
|
|
||||||
c.conn.Close()
|
|
||||||
}
|
|
||||||
c.connected = false
|
|
||||||
c.conn = nil
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
if closed {
|
|
||||||
// Only call disconnect callbacks when a connection is closed
|
|
||||||
for _, f := range c.onDisconnect {
|
|
||||||
go f(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) IsConnected() (ret bool) {
|
|
||||||
c.mu.Lock()
|
|
||||||
ret = c.connected
|
|
||||||
c.mu.Unlock()
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WebSocketClient) Connect(basectx context.Context) error {
|
|
||||||
// maintain a context by the client it self, so that we can manually shutdown the connection
|
|
||||||
ctx, cancel := context.WithCancel(basectx)
|
|
||||||
c.cancel = cancel
|
|
||||||
|
|
||||||
conn, _, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
|
||||||
if err == nil {
|
|
||||||
// setup connection only when connected
|
|
||||||
c.setConn(conn)
|
|
||||||
c.setPingHandler(conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1) if connection is built up, start listening for messages.
|
|
||||||
// 2) if connection is NOT ready, start reconnecting infinitely.
|
|
||||||
go c.listen(ctx)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(url string, requestHeader http.Header) *WebSocketClient {
|
|
||||||
return NewWithDialer(url, websocket.DefaultDialer, requestHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWithDialer(url string, d *websocket.Dialer, requestHeader http.Header) *WebSocketClient {
|
|
||||||
limiter, err := util.NewValidLimiter(MaxReconnectRate, 1)
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Panic("Invalid rate limiter")
|
|
||||||
}
|
|
||||||
return &WebSocketClient{
|
|
||||||
Url: url,
|
|
||||||
Dialer: d,
|
|
||||||
requestHeader: requestHeader,
|
|
||||||
readTimeout: DefaultReadTimeout,
|
|
||||||
writeTimeout: DefaultWriteTimeout,
|
|
||||||
messages: make(chan Message, DefaultMessageBufferSize),
|
|
||||||
limiter: limiter,
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,202 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"***REMOVED***/pkg/log"
|
|
||||||
"***REMOVED***/pkg/testing/testutil"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
func messageTicker(t *testing.T, ctx context.Context, conn *websocket.Conn) {
|
|
||||||
ticker := time.NewTicker(time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case tt := <-ticker.C:
|
|
||||||
err := conn.WriteMessage(websocket.TextMessage, []byte(tt.String()))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func messageReader(t *testing.T, ctx context.Context, conn *websocket.Conn) {
|
|
||||||
for {
|
|
||||||
_, message, err := conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
||||||
log.Errorf("unexpected closed error: %v", err)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
t.Logf("message: %v", message)
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReconnect(t *testing.T) {
|
|
||||||
basectx := context.Background()
|
|
||||||
|
|
||||||
wsHandler := func(ctx context.Context) func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var upgrader = websocket.Upgrader{
|
|
||||||
ReadBufferSize: 1024,
|
|
||||||
WriteBufferSize: 1024,
|
|
||||||
}
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("upgrade error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go messageTicker(t, ctx, conn)
|
|
||||||
messageReader(t, ctx, conn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
serverctx, cancelserver := context.WithCancel(basectx)
|
|
||||||
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", wsHandler(serverctx))
|
|
||||||
go func() {
|
|
||||||
err := server.ListenAndServe()
|
|
||||||
assert.Equal(t, http.ErrServerClosed, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, port, err := net.SplitHostPort(server.Addr)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
log.Debugf("waiting for server port ready")
|
|
||||||
testutil.WaitForPort(t, server.Addr, 3)
|
|
||||||
|
|
||||||
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
|
|
||||||
log.Infof("url: %s", u.String())
|
|
||||||
|
|
||||||
client := New(u.String(), nil)
|
|
||||||
|
|
||||||
// start the message reader
|
|
||||||
clientctx, cancelClient := context.WithCancel(basectx)
|
|
||||||
defer cancelClient()
|
|
||||||
|
|
||||||
err = client.Connect(clientctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, client)
|
|
||||||
client.SetReadTimeout(2 * time.Second)
|
|
||||||
|
|
||||||
// read one message
|
|
||||||
log.Debugf("waiting for message...")
|
|
||||||
msg := <-client.messages
|
|
||||||
log.Debugf("received message: %v", msg)
|
|
||||||
|
|
||||||
triggerReconnectManyTimes := func() {
|
|
||||||
// Forcedly reconnect multiple times
|
|
||||||
for i := 0; i < 10; i = i + 1 {
|
|
||||||
client.Reconnect()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("reconnecting...")
|
|
||||||
triggerReconnectManyTimes()
|
|
||||||
msg = <-client.messages
|
|
||||||
log.Debugf("received message: %v", msg)
|
|
||||||
|
|
||||||
log.Debugf("shutting down server...")
|
|
||||||
err = server.Shutdown(serverctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
cancelserver()
|
|
||||||
server.Close()
|
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
triggerReconnectManyTimes()
|
|
||||||
|
|
||||||
log.Debugf("restarting server...")
|
|
||||||
server2ctx, cancelserver2 := context.WithCancel(basectx)
|
|
||||||
server2 := testutil.NewWebSocketServerWithAddressFunc(t, server.Addr, "/ws", wsHandler(server2ctx))
|
|
||||||
go func() {
|
|
||||||
err := server2.ListenAndServe()
|
|
||||||
assert.Equal(t, http.ErrServerClosed, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
triggerReconnectManyTimes()
|
|
||||||
|
|
||||||
log.Debugf("waiting for server2 port ready")
|
|
||||||
testutil.WaitForPort(t, server.Addr, 3)
|
|
||||||
|
|
||||||
log.Debugf("waiting for message from server2...")
|
|
||||||
msg = <-client.messages
|
|
||||||
log.Debugf("received message: %v", msg)
|
|
||||||
|
|
||||||
err = client.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
err = server2.Shutdown(server2ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
cancelserver2()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnect(t *testing.T) {
|
|
||||||
basectx := context.Background()
|
|
||||||
ctx, cancel := context.WithCancel(basectx)
|
|
||||||
|
|
||||||
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var upgrader = websocket.Upgrader{
|
|
||||||
ReadBufferSize: 1024,
|
|
||||||
WriteBufferSize: 1024,
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("upgrade error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go messageTicker(t, ctx, conn)
|
|
||||||
messageReader(t, ctx, conn)
|
|
||||||
})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
err := server.ListenAndServe()
|
|
||||||
assert.Equal(t, http.ErrServerClosed, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, port, err := net.SplitHostPort(server.Addr)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
testutil.WaitForPort(t, server.Addr, 3)
|
|
||||||
|
|
||||||
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
|
|
||||||
t.Logf("url: %s", u.String())
|
|
||||||
|
|
||||||
client := New(u.String(), nil)
|
|
||||||
|
|
||||||
err = client.Connect(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, client)
|
|
||||||
client.SetReadTimeout(2 * time.Second)
|
|
||||||
|
|
||||||
// read one message
|
|
||||||
t.Logf("waiting for message...")
|
|
||||||
msg := <-client.messages
|
|
||||||
t.Logf("received message: %v", msg)
|
|
||||||
|
|
||||||
// recreate server at the same address
|
|
||||||
client.Close()
|
|
||||||
|
|
||||||
t.Logf("shutting down server...")
|
|
||||||
cancel()
|
|
||||||
err = server.Shutdown(basectx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,21 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:generate mockery -name=Client
|
|
||||||
|
|
||||||
type Client interface {
|
|
||||||
SetWriteTimeout(time.Duration)
|
|
||||||
SetReadTimeout(time.Duration)
|
|
||||||
OnConnect(func(c Client))
|
|
||||||
OnDisconnect(func(c Client))
|
|
||||||
Connect(context.Context) error
|
|
||||||
Reconnect()
|
|
||||||
Close() error
|
|
||||||
IsConnected() bool
|
|
||||||
WriteJSON(interface{}) error
|
|
||||||
Messages() <-chan Message
|
|
||||||
}
|
|
|
@ -1,55 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WriteJSON writes the JSON encoding of v as a message.
|
|
||||||
//
|
|
||||||
// See the documentation for encoding/json Marshal for details about the
|
|
||||||
// conversion of Go values to JSON.
|
|
||||||
func (c *WebSocketClient) WriteJSON(v interface{}) error {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
if !c.connected {
|
|
||||||
return ErrConnectionLost
|
|
||||||
}
|
|
||||||
w, err := c.conn.NextWriter(websocket.TextMessage)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err1 := json.NewEncoder(w).Encode(v)
|
|
||||||
err2 := w.Close()
|
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadJSON reads the next JSON-encoded message from the connection and stores
|
|
||||||
// it in the value pointed to by v.
|
|
||||||
//
|
|
||||||
// See the documentation for the encoding/json Unmarshal function for details
|
|
||||||
// about the conversion of JSON to a Go value.
|
|
||||||
func (c *WebSocketClient) ReadJSON(v interface{}) error {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
if !c.connected {
|
|
||||||
return ErrConnectionLost
|
|
||||||
}
|
|
||||||
_, r, err := c.conn.NextReader()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = json.NewDecoder(r).Decode(v)
|
|
||||||
if err == io.EOF {
|
|
||||||
// One value is expected in the message.
|
|
||||||
err = io.ErrUnexpectedEOF
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -1,110 +0,0 @@
|
||||||
// Code generated by mockery v1.0.0. DO NOT EDIT.
|
|
||||||
|
|
||||||
package mocks
|
|
||||||
|
|
||||||
import context "context"
|
|
||||||
import mock "github.com/stretchr/testify/mock"
|
|
||||||
import time "time"
|
|
||||||
import websocket "***REMOVED***/pkg/net/http/websocket"
|
|
||||||
|
|
||||||
// Client is an autogenerated mock type for the Client type
|
|
||||||
type Client struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close provides a mock function with given fields:
|
|
||||||
func (_m *Client) Close() error {
|
|
||||||
ret := _m.Called()
|
|
||||||
|
|
||||||
var r0 error
|
|
||||||
if rf, ok := ret.Get(0).(func() error); ok {
|
|
||||||
r0 = rf()
|
|
||||||
} else {
|
|
||||||
r0 = ret.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect provides a mock function with given fields: _a0
|
|
||||||
func (_m *Client) Connect(_a0 context.Context) error {
|
|
||||||
ret := _m.Called(_a0)
|
|
||||||
|
|
||||||
var r0 error
|
|
||||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
|
||||||
r0 = rf(_a0)
|
|
||||||
} else {
|
|
||||||
r0 = ret.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r0
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsConnected provides a mock function with given fields:
|
|
||||||
func (_m *Client) IsConnected() bool {
|
|
||||||
ret := _m.Called()
|
|
||||||
|
|
||||||
var r0 bool
|
|
||||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
|
||||||
r0 = rf()
|
|
||||||
} else {
|
|
||||||
r0 = ret.Get(0).(bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Messages provides a mock function with given fields:
|
|
||||||
func (_m *Client) Messages() <-chan websocket.Message {
|
|
||||||
ret := _m.Called()
|
|
||||||
|
|
||||||
var r0 <-chan websocket.Message
|
|
||||||
if rf, ok := ret.Get(0).(func() <-chan websocket.Message); ok {
|
|
||||||
r0 = rf()
|
|
||||||
} else {
|
|
||||||
if ret.Get(0) != nil {
|
|
||||||
r0 = ret.Get(0).(<-chan websocket.Message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r0
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnConnect provides a mock function with given fields: _a0
|
|
||||||
func (_m *Client) OnConnect(_a0 func(websocket.Client)) {
|
|
||||||
_m.Called(_a0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnDisconnect provides a mock function with given fields: _a0
|
|
||||||
func (_m *Client) OnDisconnect(_a0 func(websocket.Client)) {
|
|
||||||
_m.Called(_a0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reconnect provides a mock function with given fields:
|
|
||||||
func (_m *Client) Reconnect() {
|
|
||||||
_m.Called()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetReadTimeout provides a mock function with given fields: _a0
|
|
||||||
func (_m *Client) SetReadTimeout(_a0 time.Duration) {
|
|
||||||
_m.Called(_a0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWriteTimeout provides a mock function with given fields: _a0
|
|
||||||
func (_m *Client) SetWriteTimeout(_a0 time.Duration) {
|
|
||||||
_m.Called(_a0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteJSON provides a mock function with given fields: _a0
|
|
||||||
func (_m *Client) WriteJSON(_a0 interface{}) error {
|
|
||||||
ret := _m.Called(_a0)
|
|
||||||
|
|
||||||
var r0 error
|
|
||||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
|
||||||
r0 = rf(_a0)
|
|
||||||
} else {
|
|
||||||
r0 = ret.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r0
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user