mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-22 06:53:52 +00:00
update
This commit is contained in:
commit
3fe812264f
40
bbgo/event.go
Normal file
40
bbgo/event.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package bbgo
|
||||
|
||||
/*
|
||||
|
||||
kline
|
||||
|
||||
{
|
||||
"e": "kline", // KLineEvent type
|
||||
"E": 123456789, // KLineEvent time
|
||||
"s": "BNBBTC", // Symbol
|
||||
"k": {
|
||||
"t": 123400000, // Kline start time
|
||||
"T": 123460000, // Kline close time
|
||||
"s": "BNBBTC", // Symbol
|
||||
"i": "1m", // Interval
|
||||
"f": 100, // First trade ID
|
||||
"L": 200, // Last trade ID
|
||||
"o": "0.0010", // Open price
|
||||
"c": "0.0020", // Close price
|
||||
"h": "0.0025", // High price
|
||||
"l": "0.0015", // Low price
|
||||
"v": "1000", // Base asset volume
|
||||
"n": 100, // Number of trades
|
||||
"x": false, // Is this kline closed?
|
||||
"q": "1.0000", // Quote asset volume
|
||||
"V": "500", // Taker buy base asset volume
|
||||
"Q": "0.500", // Taker buy quote asset volume
|
||||
"B": "123456" // Ignore
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
*/
|
||||
type EventBase struct {
|
||||
Event string `json:"e"` // event
|
||||
Time int64 `json:"E"`
|
||||
}
|
||||
|
137
bbgo/kline.go
Normal file
137
bbgo/kline.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package bbgo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type KLineEvent struct {
|
||||
EventBase
|
||||
Symbol string `json:"s"`
|
||||
KLine *KLine `json:"k,omitempty"`
|
||||
}
|
||||
|
||||
type KLine struct {
|
||||
StartTime int64 `json:"t"`
|
||||
EndTime int64 `json:"T"`
|
||||
|
||||
Symbol string `json:"s"`
|
||||
Interval string `json:"i"`
|
||||
|
||||
Open string `json:"o"`
|
||||
Close string `json:"c"`
|
||||
High string `json:"h"`
|
||||
Low string `json:"l"`
|
||||
Volume string `json:"V"` // taker buy base asset volume (like 10 BTC)
|
||||
QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT)
|
||||
|
||||
LastTradeID int `json:"L"`
|
||||
NumberOfTrades int `json:"n"`
|
||||
Closed bool `json:"x"`
|
||||
}
|
||||
|
||||
func (k KLine) GetTrend() int {
|
||||
o := k.GetOpen()
|
||||
c := k.GetClose()
|
||||
|
||||
if c > o {
|
||||
return 1
|
||||
} else if c < o {
|
||||
return -1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (k KLine) GetHigh() float64 {
|
||||
return MustParseFloat(k.High)
|
||||
}
|
||||
|
||||
func (k KLine) GetLow() float64 {
|
||||
return MustParseFloat(k.Low)
|
||||
}
|
||||
|
||||
func (k KLine) GetOpen() float64 {
|
||||
return MustParseFloat(k.Open)
|
||||
}
|
||||
|
||||
func (k KLine) GetClose() float64 {
|
||||
return MustParseFloat(k.Close)
|
||||
}
|
||||
|
||||
func (k KLine) GetMaxChange() float64 {
|
||||
return k.GetHigh() - k.GetLow()
|
||||
}
|
||||
|
||||
func (k KLine) GetThickness() float64 {
|
||||
return k.GetChange() / k.GetMaxChange()
|
||||
}
|
||||
|
||||
func (k KLine) GetChange() float64 {
|
||||
return k.GetClose() - k.GetOpen()
|
||||
}
|
||||
|
||||
func (k KLine) String() string {
|
||||
return fmt.Sprintf("%s %s Open: % 14s Close: % 14s High: % 14s Low: % 14s Volume: % 13s Change: % 13f %s", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.Interval)
|
||||
}
|
||||
|
||||
type KLineWindow []KLine
|
||||
|
||||
func (w KLineWindow) Len() int {
|
||||
return len(w)
|
||||
}
|
||||
|
||||
func (w KLineWindow) GetOpen() float64 {
|
||||
return w[0].GetOpen()
|
||||
}
|
||||
|
||||
func (w KLineWindow) GetClose() float64 {
|
||||
end := len(w) - 1
|
||||
return w[end].GetClose()
|
||||
}
|
||||
|
||||
func (w KLineWindow) GetHigh() float64 {
|
||||
high := w.GetOpen()
|
||||
for _, line := range w {
|
||||
val := line.GetHigh()
|
||||
if val > high {
|
||||
high = val
|
||||
}
|
||||
}
|
||||
return high
|
||||
}
|
||||
|
||||
func (w KLineWindow) GetLow() float64 {
|
||||
low := w.GetOpen()
|
||||
for _, line := range w {
|
||||
val := line.GetHigh()
|
||||
if val < low {
|
||||
low = val
|
||||
}
|
||||
}
|
||||
return low
|
||||
}
|
||||
|
||||
func (w KLineWindow) GetChange() float64 {
|
||||
return w.GetClose() - w.GetOpen()
|
||||
}
|
||||
|
||||
func (w KLineWindow) GetMaxChange() float64 {
|
||||
return w.GetHigh() - w.GetLow()
|
||||
}
|
||||
|
||||
func (w *KLineWindow) Add(line KLine) {
|
||||
*w = append(*w, line)
|
||||
}
|
||||
|
||||
func (w *KLineWindow) Truncate(size int) {
|
||||
if len(*w) <= size {
|
||||
return
|
||||
}
|
||||
|
||||
end := len(*w) - 1
|
||||
start := end - size
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
*w = (*w)[end-5 : end]
|
||||
}
|
||||
|
12
bbgo/parse.go
Normal file
12
bbgo/parse.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package bbgo
|
||||
|
||||
import "strconv"
|
||||
|
||||
func MustParseFloat(s string) float64 {
|
||||
v, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
196
bbgo/parser.go
Normal file
196
bbgo/parser.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package bbgo
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/valyala/fastjson"
|
||||
)
|
||||
|
||||
/*
|
||||
|
||||
executionReport
|
||||
|
||||
{
|
||||
"e": "executionReport", // KLineEvent type
|
||||
"E": 1499405658658, // KLineEvent time
|
||||
"s": "ETHBTC", // Symbol
|
||||
"c": "mUvoqJxFIILMdfAW5iGSOW", // Client order ID
|
||||
"S": "BUY", // Side
|
||||
"o": "LIMIT", // Order type
|
||||
"f": "GTC", // Time in force
|
||||
"q": "1.00000000", // Order quantity
|
||||
"p": "0.10264410", // Order price
|
||||
"P": "0.00000000", // Stop price
|
||||
"F": "0.00000000", // Iceberg quantity
|
||||
"g": -1, // OrderListId
|
||||
"C": null, // Original client order ID; This is the ID of the order being canceled
|
||||
"x": "NEW", // Current execution type
|
||||
"X": "NEW", // Current order status
|
||||
"r": "NONE", // Order reject reason; will be an error code.
|
||||
"i": 4293153, // Order ID
|
||||
"l": "0.00000000", // Last executed quantity
|
||||
"z": "0.00000000", // Cumulative filled quantity
|
||||
"L": "0.00000000", // Last executed price
|
||||
"n": "0", // Commission amount
|
||||
"N": null, // Commission asset
|
||||
"T": 1499405658657, // Transaction time
|
||||
"t": -1, // Trade ID
|
||||
"I": 8641984, // Ignore
|
||||
"w": true, // Is the order on the book?
|
||||
"m": false, // Is this trade the maker side?
|
||||
"M": false, // Ignore
|
||||
"O": 1499405658657, // Order creation time
|
||||
"Z": "0.00000000", // Cumulative quote asset transacted quantity
|
||||
"Y": "0.00000000", // Last quote asset transacted quantity (i.e. lastPrice * lastQty)
|
||||
"Q": "0.00000000" // Quote Order Qty
|
||||
}
|
||||
*/
|
||||
type ExecutionReportEvent struct {
|
||||
EventBase
|
||||
|
||||
Symbol string `json:"s"`
|
||||
ClientOrderID string `json:"c"`
|
||||
Side string `json:"S"`
|
||||
OrderType string `json:"o"`
|
||||
TimeInForce string `json:"f"`
|
||||
|
||||
Quantity string `json:"q"`
|
||||
Price string `json:"p"`
|
||||
StopPrice string `json:"P"`
|
||||
|
||||
CurrentExecutionType string `json:"x"`
|
||||
CurrentOrderStatus string `json:"X"`
|
||||
|
||||
OrderID int `json:"i"`
|
||||
|
||||
LastExecutedQuantity string `json:"l"`
|
||||
CumulativeFilledQuantity string `json:"z"`
|
||||
LastExecutedPrice string `json:"L"`
|
||||
|
||||
OrderCreationTime int `json:"O"`
|
||||
}
|
||||
|
||||
/*
|
||||
balanceUpdate
|
||||
|
||||
{
|
||||
"e": "balanceUpdate", //KLineEvent Type
|
||||
"E": 1573200697110, //KLineEvent Time
|
||||
"a": "BTC", //Asset
|
||||
"d": "100.00000000", //Balance Delta
|
||||
"T": 1573200697068 //Clear Time
|
||||
}
|
||||
*/
|
||||
type BalanceUpdateEvent struct {
|
||||
EventBase
|
||||
|
||||
Asset string `json:"a"`
|
||||
Delta string `json:"d"`
|
||||
ClearTime int64 `json:"T"`
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
outboundAccountInfo
|
||||
|
||||
{
|
||||
"e": "outboundAccountInfo", // KLineEvent type
|
||||
"E": 1499405658849, // KLineEvent time
|
||||
"m": 0, // Maker commission rate (bips)
|
||||
"t": 0, // Taker commission rate (bips)
|
||||
"b": 0, // Buyer commission rate (bips)
|
||||
"s": 0, // Seller commission rate (bips)
|
||||
"T": true, // Can trade?
|
||||
"W": true, // Can withdraw?
|
||||
"D": true, // Can deposit?
|
||||
"u": 1499405658848, // Time of last account update
|
||||
"B": [ // Balances array
|
||||
{
|
||||
"a": "LTC", // Asset
|
||||
"f": "17366.18538083", // Free amount
|
||||
"l": "0.00000000" // Locked amount
|
||||
},
|
||||
{
|
||||
"a": "BTC",
|
||||
"f": "10537.85314051",
|
||||
"l": "2.19464093"
|
||||
},
|
||||
{
|
||||
"a": "ETH",
|
||||
"f": "17902.35190619",
|
||||
"l": "0.00000000"
|
||||
},
|
||||
{
|
||||
"a": "BNC",
|
||||
"f": "1114503.29769312",
|
||||
"l": "0.00000000"
|
||||
},
|
||||
{
|
||||
"a": "NEO",
|
||||
"f": "0.00000000",
|
||||
"l": "0.00000000"
|
||||
}
|
||||
],
|
||||
"P": [ // Account Permissions
|
||||
"SPOT"
|
||||
]
|
||||
}
|
||||
|
||||
*/
|
||||
type Balance struct {
|
||||
Asset string `json:"a"`
|
||||
Free string `json:"f"`
|
||||
Locked string `json:"l"`
|
||||
}
|
||||
type OutboundAccountInfoEvent struct {
|
||||
EventBase
|
||||
|
||||
MakerCommissionRate int `json:"m"`
|
||||
TakerCommissionRate int `json:"t"`
|
||||
BuyerCommissionRate int `json:"b"`
|
||||
SellerCommissionRate int `json:"s"`
|
||||
|
||||
CanTrade bool `json:"T"`
|
||||
CanWithdraw bool `json:"W"`
|
||||
CanDeposit bool `json:"D"`
|
||||
|
||||
LastAccountUpdateTime int `json:"u"`
|
||||
|
||||
Balances []Balance `json:"B,omitempty"`
|
||||
Permissions []string `json:"P,omitempty"`
|
||||
}
|
||||
|
||||
func ParseEvent(message string) (interface{}, error) {
|
||||
val, err := fastjson.Parse(message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
eventType := string(val.GetStringBytes("e"))
|
||||
|
||||
switch eventType {
|
||||
case "kline":
|
||||
var event KLineEvent
|
||||
err := json.Unmarshal([]byte(message), &event)
|
||||
return &event, err
|
||||
|
||||
case "outboundAccountInfo", "outboundAccountPosition":
|
||||
var event OutboundAccountInfoEvent
|
||||
err := json.Unmarshal([]byte(message), &event)
|
||||
return &event, err
|
||||
|
||||
case "balanceUpdate":
|
||||
var event BalanceUpdateEvent
|
||||
err := json.Unmarshal([]byte(message), &event)
|
||||
return &event, err
|
||||
|
||||
case "executionReport":
|
||||
var event ExecutionReportEvent
|
||||
err := json.Unmarshal([]byte(message), &event)
|
||||
return &event, err
|
||||
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported message: %s", message)
|
||||
}
|
||||
|
66
bbgo/pnl.go
Normal file
66
bbgo/pnl.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package bbgo
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
|
||||
func CalculateAverageCost(trades []Trade) (averageCost float64) {
|
||||
var totalCost = 0.0
|
||||
var totalQuantity = 0.0
|
||||
for _, t := range trades {
|
||||
if t.IsBuyer {
|
||||
totalCost += t.Price * t.Volume
|
||||
totalQuantity += t.Volume
|
||||
} else {
|
||||
totalCost -= t.Price * t.Volume
|
||||
totalQuantity -= t.Volume
|
||||
}
|
||||
}
|
||||
|
||||
averageCost = totalCost / totalQuantity
|
||||
return
|
||||
}
|
||||
|
||||
func CalculateCostAndProfit(trades []Trade, currentPrice float64) (averageBidPrice, stock, profit, fee float64) {
|
||||
var bidVolume = 0.0
|
||||
var bidAmount = 0.0
|
||||
var bidFee = 0.0
|
||||
for _, t := range trades {
|
||||
if t.IsBuyer {
|
||||
bidVolume += t.Volume
|
||||
bidAmount += t.Price * t.Volume
|
||||
switch t.FeeCurrency {
|
||||
case "BTC":
|
||||
bidFee += t.Price * t.Fee
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("average bid price = (total amount %f + total fee %f) / volume %f", bidAmount, bidFee, bidVolume)
|
||||
averageBidPrice = (bidAmount + bidFee) / bidVolume
|
||||
|
||||
var feeRate = 0.001
|
||||
var askVolume = 0.0
|
||||
var askFee = 0.0
|
||||
for _, t := range trades {
|
||||
if !t.IsBuyer {
|
||||
profit += (t.Price - averageBidPrice) * t.Volume
|
||||
askVolume += t.Volume
|
||||
switch t.FeeCurrency {
|
||||
case "USDT":
|
||||
askFee += t.Fee
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
profit -= askFee
|
||||
|
||||
stock = bidVolume - askVolume
|
||||
futureFee := 0.0
|
||||
if stock > 0 {
|
||||
stockfee := currentPrice * feeRate * stock
|
||||
profit += (currentPrice-averageBidPrice)*stock - stockfee
|
||||
futureFee += stockfee
|
||||
}
|
||||
|
||||
fee = bidFee + askFee + futureFee
|
||||
return
|
||||
}
|
92
bbgo/trade.go
Normal file
92
bbgo/trade.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package bbgo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/adshao/go-binance"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Trade struct {
|
||||
ID int64
|
||||
Price float64
|
||||
Volume float64
|
||||
IsBuyer bool
|
||||
IsMaker bool
|
||||
Time time.Time
|
||||
|
||||
Fee float64
|
||||
FeeCurrency string
|
||||
}
|
||||
|
||||
func QueryTrades(ctx context.Context, client *binance.Client, market string, startTime time.Time) (trades []Trade, err error) {
|
||||
var lastTradeID int64 = 0
|
||||
for {
|
||||
req := client.NewListTradesService().
|
||||
Limit(1000).
|
||||
Symbol(market).
|
||||
StartTime(startTime.UnixNano() / 1000000)
|
||||
|
||||
if lastTradeID > 0 {
|
||||
req.FromID(lastTradeID)
|
||||
}
|
||||
|
||||
bnTrades, err := req.Do(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(bnTrades) <= 1 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, t := range bnTrades {
|
||||
// skip trade ID that is the same
|
||||
if t.ID == lastTradeID {
|
||||
continue
|
||||
}
|
||||
|
||||
var side string
|
||||
if t.IsBuyer {
|
||||
side = "BUY"
|
||||
} else {
|
||||
side = "SELL"
|
||||
}
|
||||
|
||||
// trade time
|
||||
tt := time.Unix(0, t.Time*1000000)
|
||||
log.Infof("trade: %d %4s Price: % 13s Volume: % 13s %s", t.ID, side, t.Price, t.Quantity, tt)
|
||||
|
||||
price, err := strconv.ParseFloat(t.Price, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
quantity, err := strconv.ParseFloat(t.Quantity, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fee, err := strconv.ParseFloat(t.Commission, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
trades = append(trades, Trade{
|
||||
ID: t.ID,
|
||||
Price: price,
|
||||
Volume: quantity,
|
||||
IsBuyer: t.IsBuyer,
|
||||
IsMaker: t.IsMaker,
|
||||
Fee: fee,
|
||||
FeeCurrency: t.CommissionAsset,
|
||||
Time: tt,
|
||||
})
|
||||
|
||||
lastTradeID = t.ID
|
||||
}
|
||||
}
|
||||
|
||||
return trades, nil
|
||||
}
|
39
util/math.go
Normal file
39
util/math.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const MaxDigits = 18 // MAX_INT64 ~ 9 * 10^18
|
||||
|
||||
var Pow10Table = [MaxDigits + 1]int64{
|
||||
1, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18,
|
||||
}
|
||||
|
||||
func Pow10(n int64) int64 {
|
||||
if n < 0 || n > MaxDigits {
|
||||
return 0
|
||||
}
|
||||
return Pow10Table[n]
|
||||
}
|
||||
|
||||
var NegPow10Table = [MaxDigits + 1]float64{
|
||||
1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18,
|
||||
}
|
||||
|
||||
func NegPow10(n int64) float64 {
|
||||
if n < 0 || n > MaxDigits {
|
||||
return 0.0
|
||||
}
|
||||
return NegPow10Table[n]
|
||||
}
|
||||
|
||||
func Float64ToStr(input float64) string {
|
||||
return strconv.FormatFloat(input, 'f', -1, 64)
|
||||
}
|
||||
|
||||
func Float64ToInt64(input float64) int64 {
|
||||
// eliminate rounding error for IEEE754 floating points
|
||||
return int64(math.Round(input))
|
||||
}
|
19
util/rate_limit.go
Normal file
19
util/rate_limit.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func ShouldDelay(l *rate.Limiter, minInterval time.Duration) time.Duration {
|
||||
return l.Reserve().Delay()
|
||||
}
|
||||
|
||||
func NewValidLimiter(r rate.Limit, b int) (*rate.Limiter, error) {
|
||||
if b <= 0 || r <= 0 {
|
||||
return nil, fmt.Errorf("Bad rate limit config. Insufficient tokens. (rate=%f, b=%d)", r, b)
|
||||
}
|
||||
return rate.NewLimiter(r, b), nil
|
||||
}
|
43
util/rate_limit_test.go
Normal file
43
util/rate_limit_test.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func TestNewValidRateLimiter(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
r rate.Limit
|
||||
b int
|
||||
hasError bool
|
||||
}{
|
||||
{"valid limiter", 0.1, 1, false},
|
||||
{"zero rate", 0, 1, true},
|
||||
{"zero burst", 0.1, 0, true},
|
||||
{"both zero", 0, 0, true},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
limiter, err := NewValidLimiter(c.r, c.b)
|
||||
assert.Equal(t, c.hasError, err != nil)
|
||||
if !c.hasError {
|
||||
assert.NotNil(t, limiter)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldDelay(t *testing.T) {
|
||||
minInterval := time.Second * 3
|
||||
maxRate := rate.Limit(1 / minInterval.Seconds())
|
||||
limiter := rate.NewLimiter(maxRate, 1)
|
||||
assert.Equal(t, time.Duration(0), ShouldDelay(limiter, minInterval))
|
||||
for i := 0; i < 100; i++ {
|
||||
assert.True(t, ShouldDelay(limiter, minInterval) > 0)
|
||||
}
|
||||
}
|
69
util/retry.go
Normal file
69
util/retry.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
InfiniteRetry = 0
|
||||
)
|
||||
|
||||
type RetryPredicator func(e error) bool
|
||||
|
||||
// Retry retrys the passed functin for "attempts" times, if passed function return error. Setting attemps to zero means keep retrying.
|
||||
func Retry(ctx context.Context, attempts int, duration time.Duration, fnToRetry func() error, errHandler func(error), predicators ...RetryPredicator) (err error) {
|
||||
infinite := false
|
||||
if attempts == InfiniteRetry {
|
||||
infinite = true
|
||||
}
|
||||
|
||||
for attempts > 0 || infinite {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
errMsg := "return for context done"
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errMsg)
|
||||
} else {
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
default:
|
||||
if err = fnToRetry(); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !needRetry(err, predicators) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = errors.Wrapf(err, "failed in retry: countdown: %v", attempts)
|
||||
|
||||
if errHandler != nil {
|
||||
errHandler(err)
|
||||
}
|
||||
|
||||
if !infinite {
|
||||
attempts--
|
||||
}
|
||||
|
||||
time.Sleep(duration)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func needRetry(err error, predicators []RetryPredicator) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// If no predicators specified, we will retry for all errors
|
||||
if len(predicators) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return predicators[0](err)
|
||||
}
|
106
util/retry_test.go
Normal file
106
util/retry_test.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func addAndCheck(a *int, target int) error {
|
||||
if *a++; *a == target {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("a is not %v. It is %v\n", target, *a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry(t *testing.T) {
|
||||
type test struct {
|
||||
input int
|
||||
targetNum int
|
||||
ans int
|
||||
ansErr error
|
||||
}
|
||||
tests := []test{
|
||||
{input: 0, targetNum: 3, ans: 3, ansErr: nil},
|
||||
{input: 0, targetNum: 10, ans: 3, ansErr: errors.New("failed in retry")},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
errHandled := false
|
||||
|
||||
err := Retry(context.Background(), 3, 1*time.Second, func() error {
|
||||
return addAndCheck(&tc.input, tc.targetNum)
|
||||
}, func(e error) { errHandled = true })
|
||||
|
||||
assert.Equal(t, true, errHandled)
|
||||
if tc.ansErr == nil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Contains(t, err.Error(), tc.ansErr.Error())
|
||||
}
|
||||
assert.Equal(t, tc.ans, tc.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryWithPredicator(t *testing.T) {
|
||||
type test struct {
|
||||
count int
|
||||
f func() error
|
||||
errHandler func(error)
|
||||
predicator RetryPredicator
|
||||
ansCount int
|
||||
ansErr error
|
||||
}
|
||||
knownErr := errors.New("Duplicate entry '1-389837488-1' for key 'UNI_Trade'")
|
||||
unknownErr := errors.New("Some Error")
|
||||
tests := []test{
|
||||
{
|
||||
predicator: func(err error) bool {
|
||||
return !strings.Contains(err.Error(), "Duplicate entry")
|
||||
},
|
||||
f: func() error { return knownErr },
|
||||
ansCount: 1,
|
||||
ansErr: knownErr,
|
||||
},
|
||||
{
|
||||
predicator: func(err error) bool {
|
||||
return !strings.Contains(err.Error(), "Duplicate entry")
|
||||
},
|
||||
f: func() error { return unknownErr },
|
||||
ansCount: 3,
|
||||
ansErr: unknownErr,
|
||||
},
|
||||
}
|
||||
attempts := 3
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for _, tc := range tests {
|
||||
err := Retry(ctx, attempts, 100*time.Millisecond, func() error {
|
||||
tc.count++
|
||||
return tc.f()
|
||||
}, tc.errHandler, tc.predicator)
|
||||
|
||||
assert.Equal(t, tc.ansCount, tc.count)
|
||||
assert.EqualError(t, errors.Cause(err), tc.ansErr.Error(), "should be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryCtxCancel(t *testing.T) {
|
||||
result := int(0)
|
||||
target := int(3)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := Retry(ctx, 5, 1*time.Second, func() error { return addAndCheck(&result, target) }, func(error) {})
|
||||
assert.Error(t, err)
|
||||
fmt.Println("Error:", err.Error())
|
||||
assert.Equal(t, int(0), result)
|
||||
}
|
78
websocket/backoff.go
Normal file
78
websocket/backoff.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMinBackoff = 5 * time.Second
|
||||
DefaultMaxBackoff = time.Minute
|
||||
|
||||
DefaultBackoffFactor = 2
|
||||
)
|
||||
|
||||
type Backoff struct {
|
||||
attempt int64
|
||||
Factor float64
|
||||
cur, Min, Max time.Duration
|
||||
}
|
||||
|
||||
// Duration returns the duration for the current attempt before incrementing
|
||||
// the attempt counter. See ForAttempt.
|
||||
func (b *Backoff) Duration() time.Duration {
|
||||
d := b.calculate(b.attempt)
|
||||
b.attempt++
|
||||
return d
|
||||
}
|
||||
|
||||
func (b *Backoff) calculate(attempt int64) time.Duration {
|
||||
min := b.Min
|
||||
if min <= 0 {
|
||||
min = DefaultMinBackoff
|
||||
}
|
||||
max := b.Max
|
||||
if max <= 0 {
|
||||
max = DefaultMaxBackoff
|
||||
}
|
||||
if min >= max {
|
||||
return max
|
||||
}
|
||||
factor := b.Factor
|
||||
if factor <= 0 {
|
||||
factor = DefaultBackoffFactor
|
||||
}
|
||||
cur := b.cur
|
||||
if cur < min {
|
||||
cur = min
|
||||
} else if cur > max {
|
||||
cur = max
|
||||
}
|
||||
|
||||
//calculate this duration
|
||||
next := cur
|
||||
if attempt > 0 {
|
||||
next = time.Duration(float64(cur) * factor)
|
||||
}
|
||||
|
||||
if next < cur {
|
||||
// overflow
|
||||
next = max
|
||||
} else if next <= min {
|
||||
next = min
|
||||
} else if next >= max {
|
||||
next = max
|
||||
}
|
||||
b.cur = next
|
||||
return next
|
||||
}
|
||||
|
||||
// Reset restarts the current attempt counter at zero.
|
||||
func (b *Backoff) Reset() {
|
||||
b.attempt = 0
|
||||
b.cur = b.Min
|
||||
}
|
||||
|
||||
// Attempt returns the current attempt counter value.
|
||||
func (b *Backoff) Attempt() int64 {
|
||||
return b.attempt
|
||||
}
|
172
websocket/backoff_test.go
Normal file
172
websocket/backoff_test.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBackoff(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 200 * time.Millisecond,
|
||||
Max: 100 * time.Second,
|
||||
Factor: 2,
|
||||
}
|
||||
|
||||
equals(t, b.Duration(), 200*time.Millisecond)
|
||||
equals(t, b.Duration(), 400*time.Millisecond)
|
||||
equals(t, b.Duration(), 800*time.Millisecond)
|
||||
b.Reset()
|
||||
equals(t, b.Duration(), 200*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestReachMaxBackoff(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 2 * time.Second,
|
||||
Max: 100 * time.Second,
|
||||
Factor: 10,
|
||||
}
|
||||
|
||||
equals(t, b.Duration(), 2*time.Second)
|
||||
equals(t, b.Duration(), 20*time.Second)
|
||||
equals(t, b.Duration(), b.Max)
|
||||
b.Reset()
|
||||
equals(t, b.Duration(), 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAttempt(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 2 * time.Second,
|
||||
Max: 100 * time.Second,
|
||||
Factor: 10,
|
||||
}
|
||||
|
||||
equals(t, b.Attempt(), int64(0))
|
||||
equals(t, b.Duration(), 2*time.Second)
|
||||
equals(t, b.Attempt(), int64(1))
|
||||
equals(t, b.Duration(), 20*time.Second)
|
||||
equals(t, b.Attempt(), int64(2))
|
||||
equals(t, b.Duration(), b.Max)
|
||||
equals(t, b.Attempt(), int64(3))
|
||||
b.Reset()
|
||||
equals(t, b.Attempt(), int64(0))
|
||||
equals(t, b.Duration(), 2*time.Second)
|
||||
equals(t, b.Attempt(), int64(1))
|
||||
}
|
||||
|
||||
func TestAttemptWithZeroValue(t *testing.T) {
|
||||
b := &Backoff{}
|
||||
|
||||
cur := DefaultMinBackoff
|
||||
equals(t, b.Attempt(), int64(0))
|
||||
equals(t, b.Duration(), cur)
|
||||
for i := 1; i < 10; i++ {
|
||||
equals(t, b.Attempt(), int64(i))
|
||||
cur *= DefaultBackoffFactor
|
||||
if cur >= DefaultMaxBackoff {
|
||||
equals(t, b.Duration(), DefaultMaxBackoff)
|
||||
break
|
||||
}
|
||||
equals(t, b.Duration(), cur)
|
||||
}
|
||||
|
||||
b.Reset()
|
||||
equals(t, b.Attempt(), int64(0))
|
||||
equals(t, b.Duration(), DefaultMinBackoff)
|
||||
equals(t, b.Attempt(), int64(1))
|
||||
}
|
||||
|
||||
func TestNegOrZeroMin(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 0,
|
||||
}
|
||||
equals(t, b.Duration(), DefaultMinBackoff)
|
||||
|
||||
b2 := &Backoff{
|
||||
Min: -1 * time.Second,
|
||||
}
|
||||
equals(t, b2.Duration(), DefaultMinBackoff)
|
||||
}
|
||||
|
||||
func TestNegOrZeroMax(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: DefaultMaxBackoff,
|
||||
Max: 0,
|
||||
}
|
||||
equals(t, b.Duration(), DefaultMaxBackoff)
|
||||
|
||||
b2 := &Backoff{
|
||||
Min: DefaultMaxBackoff,
|
||||
Max: -1 * time.Second,
|
||||
}
|
||||
equals(t, b2.Duration(), DefaultMaxBackoff)
|
||||
}
|
||||
|
||||
func TestMinLargerThanMax(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 500 * time.Second,
|
||||
Max: 100 * time.Second,
|
||||
Factor: 1,
|
||||
}
|
||||
equals(t, b.Duration(), b.Max)
|
||||
}
|
||||
|
||||
func TestFakeCur(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 100 * time.Second,
|
||||
Max: 500 * time.Second,
|
||||
Factor: 1,
|
||||
cur: 1000 * time.Second,
|
||||
}
|
||||
equals(t, b.Duration(), b.Max)
|
||||
}
|
||||
|
||||
func TestLargeFactor(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 1 * time.Second,
|
||||
Max: 10 * time.Second,
|
||||
Factor: math.MaxFloat64,
|
||||
}
|
||||
equals(t, b.Duration(), b.Min)
|
||||
for i := 0; i < 100; i++ {
|
||||
equals(t, b.Duration(), b.Max)
|
||||
}
|
||||
b.Reset()
|
||||
equals(t, b.Duration(), b.Min)
|
||||
}
|
||||
|
||||
func TestLargeMinMax(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: time.Duration(math.MaxInt64 - 1),
|
||||
Max: time.Duration(math.MaxInt64),
|
||||
Factor: math.MaxFloat64,
|
||||
}
|
||||
equals(t, b.Duration(), b.Min)
|
||||
for i := 0; i < 100; i++ {
|
||||
equals(t, b.Duration(), b.Max)
|
||||
}
|
||||
b.Reset()
|
||||
equals(t, b.Duration(), b.Min)
|
||||
}
|
||||
|
||||
func TestManyAttempts(t *testing.T) {
|
||||
b := &Backoff{
|
||||
Min: 10 * time.Second,
|
||||
Max: time.Duration(math.MaxInt64),
|
||||
Factor: 1000,
|
||||
}
|
||||
for i := 0; i < 10000; i++ {
|
||||
b.Duration()
|
||||
}
|
||||
equals(t, b.Duration(), b.Max)
|
||||
b.Reset()
|
||||
}
|
||||
|
||||
func equals(t *testing.T, v1, v2 interface{}) {
|
||||
if !reflect.DeepEqual(v1, v2) {
|
||||
assert.Failf(t, "not equal", "Got %v, Expecting %v", v1, v2)
|
||||
}
|
||||
}
|
368
websocket/client.go
Normal file
368
websocket/client.go
Normal file
|
@ -0,0 +1,368 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/log"
|
||||
"github.com/c9s/bbgo/pkg/util"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const DefaultMessageBufferSize = 128
|
||||
|
||||
const DefaultWriteTimeout = 30 * time.Second
|
||||
const DefaultReadTimeout = 30 * time.Second
|
||||
|
||||
var ErrReconnectContextDone = errors.New("reconnect canceled due to context done.")
|
||||
var ErrReconnectFailed = errors.New("failed to reconnect")
|
||||
var ErrConnectionLost = errors.New("connection lost")
|
||||
|
||||
var MaxReconnectRate = rate.Limit(1 / DefaultMinBackoff.Seconds())
|
||||
|
||||
// WebSocketClient allows to connect and receive stream data
|
||||
type WebSocketClient struct {
|
||||
// Url is the websocket connection location, start with ws:// or wss://
|
||||
Url string
|
||||
|
||||
// conn is the current websocket connection, please note the connection
|
||||
// object can be replaced with a new connection object when the connection
|
||||
// is unexpected closed.
|
||||
conn *websocket.Conn
|
||||
|
||||
// Dialer is used for creating the websocket connection
|
||||
Dialer *websocket.Dialer
|
||||
|
||||
// requestHeader is used for the Dial function call. Some credential can be
|
||||
// stored in the http request header for authentication
|
||||
requestHeader http.Header
|
||||
|
||||
// messages is a read-only channel, received messages will be sent to this
|
||||
// channel.
|
||||
messages chan Message
|
||||
|
||||
readTimeout time.Duration
|
||||
|
||||
writeTimeout time.Duration
|
||||
|
||||
onConnect []func(c Client)
|
||||
|
||||
onDisconnect []func(c Client)
|
||||
|
||||
// cancel is mapped to the ctx context object
|
||||
cancel func()
|
||||
|
||||
readerClosed chan struct{}
|
||||
|
||||
connected bool
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
reconnectCh chan struct{}
|
||||
|
||||
backoff Backoff
|
||||
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
// websocket.BinaryMessage or websocket.TextMessage
|
||||
Type int
|
||||
Body []byte
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) Messages() <-chan Message {
|
||||
return c.messages
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) SetReadTimeout(timeout time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readTimeout = timeout
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) SetWriteTimeout(timeout time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeTimeout = timeout
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) OnConnect(f func(c Client)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.onConnect = append(c.onConnect, f)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) OnDisconnect(f func(c Client)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.onDisconnect = append(c.onDisconnect, f)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) WriteTextMessage(message []byte) error {
|
||||
return c.WriteMessage(websocket.TextMessage, message)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) WriteBinaryMessage(message []byte) error {
|
||||
return c.WriteMessage(websocket.BinaryMessage, message)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) WriteMessage(messageType int, data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return ErrConnectionLost
|
||||
}
|
||||
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) readMessages() error {
|
||||
c.mu.Lock()
|
||||
if !c.connected {
|
||||
c.mu.Unlock()
|
||||
return ErrConnectionLost
|
||||
}
|
||||
timeout := c.readTimeout
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
msgtype, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.messages <- Message{msgtype, message}
|
||||
return nil
|
||||
}
|
||||
|
||||
// listen starts a goroutine for reading message and tries to re-connect to the
|
||||
// server when the reader returns error
|
||||
//
|
||||
// Please note we should always break the reader loop if there is any error
|
||||
// returned from the server.
|
||||
func (c *WebSocketClient) listen(ctx context.Context) {
|
||||
// The life time of both channels "readerClosed" and "reconnectCh" is bound to one connection.
|
||||
// Each channel should be created before loop starts and be closed after loop ends.
|
||||
// "readerClosed" is used to inform "Close()" reader loop ends.
|
||||
// "reconnectCh" is used to centralize reconnection logics in this reader loop.
|
||||
c.mu.Lock()
|
||||
c.readerClosed = make(chan struct{})
|
||||
c.reconnectCh = make(chan struct{}, 1)
|
||||
c.mu.Unlock()
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
close(c.readerClosed)
|
||||
close(c.reconnectCh)
|
||||
c.reconnectCh = nil
|
||||
c.mu.Unlock()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case <-c.reconnectCh:
|
||||
// it could be i/o timeout for network disconnection
|
||||
// or it could be invoked from outside.
|
||||
c.SetDisconnected()
|
||||
var maxTries = 1
|
||||
if _, response, err := c.reconnect(ctx, maxTries); err != nil {
|
||||
if err == ErrReconnectContextDone {
|
||||
log.Debugf("[websocket] context canceled. stop reconnecting.")
|
||||
return
|
||||
}
|
||||
log.Warnf("[websocket] failed to reconnect after %d tries!! error: %v response: %v", maxTries, err, response)
|
||||
c.Reconnect()
|
||||
}
|
||||
|
||||
default:
|
||||
if err := c.readMessages(); err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
|
||||
log.Warnf("[websocket] unexpected close error reconnecting: %v", err)
|
||||
}
|
||||
|
||||
log.Warnf("[websocket] failed to read message. error: %+v", err)
|
||||
c.Reconnect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reconnect triggers reconnection logics
|
||||
func (c *WebSocketClient) Reconnect() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
select {
|
||||
// c.reconnectCh is a buffered channel with cap=1.
|
||||
// At most one reconnect signal could be processed.
|
||||
case c.reconnectCh <- struct{}{}:
|
||||
default:
|
||||
// Entering here means it is already reconnecting.
|
||||
// Drop the current reconnect signal.
|
||||
}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the reader and the connection
|
||||
// ctx is the context used for shutdown process.
|
||||
func (c *WebSocketClient) Close() (err error) {
|
||||
c.mu.Lock()
|
||||
// leave the listen goroutine before we close the connection
|
||||
// checking nil is to handle calling "Close" before "Connect" is called
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.SetDisconnected()
|
||||
|
||||
// wait for the reader func to be closed
|
||||
if c.readerClosed != nil {
|
||||
<-c.readerClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// reconnect tries to create a new connection from the existing dialer
|
||||
func (c *WebSocketClient) reconnect(ctx context.Context, maxTries int) (*websocket.Conn, *http.Response, error) {
|
||||
log.Debugf("[websocket] start reconnecting to %q", c.Url)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ErrReconnectContextDone
|
||||
default:
|
||||
}
|
||||
|
||||
if s := util.ShouldDelay(c.limiter, DefaultMinBackoff); s > 0 {
|
||||
log.Warn("[websocket] reconnect too frequently. Sleep for ", s)
|
||||
time.Sleep(s)
|
||||
}
|
||||
|
||||
log.Warnf("[websocket] reconnecting x %d to %q", c.backoff.Attempt()+1, c.Url)
|
||||
conn, resp, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
||||
if err != nil {
|
||||
dur := c.backoff.Duration()
|
||||
log.Warnf("failed to dial %s: %v, response: %+v. Wait for %v", c.Url, err, resp, dur)
|
||||
time.Sleep(dur)
|
||||
return nil, nil, ErrReconnectFailed
|
||||
}
|
||||
|
||||
log.Infof("[websocket] reconnected to %q", c.Url)
|
||||
// Reset backoff value if connected.
|
||||
c.backoff.Reset()
|
||||
c.setConn(conn)
|
||||
c.setPingHandler(conn)
|
||||
|
||||
return conn, resp, err
|
||||
}
|
||||
|
||||
// Conn returns the current active connection instance
|
||||
func (c *WebSocketClient) Conn() (conn *websocket.Conn) {
|
||||
c.mu.Lock()
|
||||
conn = c.conn
|
||||
c.mu.Unlock()
|
||||
return conn
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) setConn(conn *websocket.Conn) {
|
||||
// Disconnect old connection before replacing with new one.
|
||||
c.SetDisconnected()
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.connected = true
|
||||
c.mu.Unlock()
|
||||
for _, f := range c.onConnect {
|
||||
go f(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) setPingHandler(conn *websocket.Conn) {
|
||||
conn.SetPingHandler(func(message string) error {
|
||||
if err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
|
||||
return err
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) SetDisconnected() {
|
||||
c.mu.Lock()
|
||||
closed := false
|
||||
if c.conn != nil {
|
||||
closed = true
|
||||
c.conn.Close()
|
||||
}
|
||||
c.connected = false
|
||||
c.conn = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
// Only call disconnect callbacks when a connection is closed
|
||||
for _, f := range c.onDisconnect {
|
||||
go f(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) IsConnected() (ret bool) {
|
||||
c.mu.Lock()
|
||||
ret = c.connected
|
||||
c.mu.Unlock()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (c *WebSocketClient) Connect(basectx context.Context) error {
|
||||
// maintain a context by the client it self, so that we can manually shutdown the connection
|
||||
ctx, cancel := context.WithCancel(basectx)
|
||||
c.cancel = cancel
|
||||
|
||||
conn, _, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
||||
if err == nil {
|
||||
// setup connection only when connected
|
||||
c.setConn(conn)
|
||||
c.setPingHandler(conn)
|
||||
}
|
||||
|
||||
// 1) if connection is built up, start listening for messages.
|
||||
// 2) if connection is NOT ready, start reconnecting infinitely.
|
||||
go c.listen(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func New(url string, requestHeader http.Header) *WebSocketClient {
|
||||
return NewWithDialer(url, websocket.DefaultDialer, requestHeader)
|
||||
}
|
||||
|
||||
func NewWithDialer(url string, d *websocket.Dialer, requestHeader http.Header) *WebSocketClient {
|
||||
limiter, err := util.NewValidLimiter(MaxReconnectRate, 1)
|
||||
if err != nil {
|
||||
log.WithError(err).Panic("Invalid rate limiter")
|
||||
}
|
||||
return &WebSocketClient{
|
||||
Url: url,
|
||||
Dialer: d,
|
||||
requestHeader: requestHeader,
|
||||
readTimeout: DefaultReadTimeout,
|
||||
writeTimeout: DefaultWriteTimeout,
|
||||
messages: make(chan Message, DefaultMessageBufferSize),
|
||||
limiter: limiter,
|
||||
}
|
||||
}
|
202
websocket/client_test.go
Normal file
202
websocket/client_test.go
Normal file
|
@ -0,0 +1,202 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"***REMOVED***/pkg/log"
|
||||
"***REMOVED***/pkg/testing/testutil"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func messageTicker(t *testing.T, ctx context.Context, conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case tt := <-ticker.C:
|
||||
err := conn.WriteMessage(websocket.TextMessage, []byte(tt.String()))
|
||||
assert.NoError(t, err)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func messageReader(t *testing.T, ctx context.Context, conn *websocket.Conn) {
|
||||
for {
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Errorf("unexpected closed error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
t.Logf("message: %v", message)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconnect(t *testing.T) {
|
||||
basectx := context.Background()
|
||||
|
||||
wsHandler := func(ctx context.Context) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
go messageTicker(t, ctx, conn)
|
||||
messageReader(t, ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
serverctx, cancelserver := context.WithCancel(basectx)
|
||||
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", wsHandler(serverctx))
|
||||
go func() {
|
||||
err := server.ListenAndServe()
|
||||
assert.Equal(t, http.ErrServerClosed, err)
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(server.Addr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
log.Debugf("waiting for server port ready")
|
||||
testutil.WaitForPort(t, server.Addr, 3)
|
||||
|
||||
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
|
||||
log.Infof("url: %s", u.String())
|
||||
|
||||
client := New(u.String(), nil)
|
||||
|
||||
// start the message reader
|
||||
clientctx, cancelClient := context.WithCancel(basectx)
|
||||
defer cancelClient()
|
||||
|
||||
err = client.Connect(clientctx)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
client.SetReadTimeout(2 * time.Second)
|
||||
|
||||
// read one message
|
||||
log.Debugf("waiting for message...")
|
||||
msg := <-client.messages
|
||||
log.Debugf("received message: %v", msg)
|
||||
|
||||
triggerReconnectManyTimes := func() {
|
||||
// Forcedly reconnect multiple times
|
||||
for i := 0; i < 10; i = i + 1 {
|
||||
client.Reconnect()
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("reconnecting...")
|
||||
triggerReconnectManyTimes()
|
||||
msg = <-client.messages
|
||||
log.Debugf("received message: %v", msg)
|
||||
|
||||
log.Debugf("shutting down server...")
|
||||
err = server.Shutdown(serverctx)
|
||||
assert.NoError(t, err)
|
||||
cancelserver()
|
||||
server.Close()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
triggerReconnectManyTimes()
|
||||
|
||||
log.Debugf("restarting server...")
|
||||
server2ctx, cancelserver2 := context.WithCancel(basectx)
|
||||
server2 := testutil.NewWebSocketServerWithAddressFunc(t, server.Addr, "/ws", wsHandler(server2ctx))
|
||||
go func() {
|
||||
err := server2.ListenAndServe()
|
||||
assert.Equal(t, http.ErrServerClosed, err)
|
||||
}()
|
||||
|
||||
triggerReconnectManyTimes()
|
||||
|
||||
log.Debugf("waiting for server2 port ready")
|
||||
testutil.WaitForPort(t, server.Addr, 3)
|
||||
|
||||
log.Debugf("waiting for message from server2...")
|
||||
msg = <-client.messages
|
||||
log.Debugf("received message: %v", msg)
|
||||
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = server2.Shutdown(server2ctx)
|
||||
assert.NoError(t, err)
|
||||
cancelserver2()
|
||||
}
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
basectx := context.Background()
|
||||
ctx, cancel := context.WithCancel(basectx)
|
||||
|
||||
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", func(w http.ResponseWriter, r *http.Request) {
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go messageTicker(t, ctx, conn)
|
||||
messageReader(t, ctx, conn)
|
||||
})
|
||||
|
||||
go func() {
|
||||
err := server.ListenAndServe()
|
||||
assert.Equal(t, http.ErrServerClosed, err)
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(server.Addr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testutil.WaitForPort(t, server.Addr, 3)
|
||||
|
||||
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
|
||||
t.Logf("url: %s", u.String())
|
||||
|
||||
client := New(u.String(), nil)
|
||||
|
||||
err = client.Connect(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
client.SetReadTimeout(2 * time.Second)
|
||||
|
||||
// read one message
|
||||
t.Logf("waiting for message...")
|
||||
msg := <-client.messages
|
||||
t.Logf("received message: %v", msg)
|
||||
|
||||
// recreate server at the same address
|
||||
client.Close()
|
||||
|
||||
t.Logf("shutting down server...")
|
||||
cancel()
|
||||
err = server.Shutdown(basectx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
}
|
21
websocket/interface.go
Normal file
21
websocket/interface.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:generate mockery -name=Client
|
||||
|
||||
type Client interface {
|
||||
SetWriteTimeout(time.Duration)
|
||||
SetReadTimeout(time.Duration)
|
||||
OnConnect(func(c Client))
|
||||
OnDisconnect(func(c Client))
|
||||
Connect(context.Context) error
|
||||
Reconnect()
|
||||
Close() error
|
||||
IsConnected() bool
|
||||
WriteJSON(interface{}) error
|
||||
Messages() <-chan Message
|
||||
}
|
55
websocket/json.go
Normal file
55
websocket/json.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WriteJSON writes the JSON encoding of v as a message.
|
||||
//
|
||||
// See the documentation for encoding/json Marshal for details about the
|
||||
// conversion of Go values to JSON.
|
||||
func (c *WebSocketClient) WriteJSON(v interface{}) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return ErrConnectionLost
|
||||
}
|
||||
w, err := c.conn.NextWriter(websocket.TextMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err1 := json.NewEncoder(w).Encode(v)
|
||||
err2 := w.Close()
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
// ReadJSON reads the next JSON-encoded message from the connection and stores
|
||||
// it in the value pointed to by v.
|
||||
//
|
||||
// See the documentation for the encoding/json Unmarshal function for details
|
||||
// about the conversion of JSON to a Go value.
|
||||
func (c *WebSocketClient) ReadJSON(v interface{}) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return ErrConnectionLost
|
||||
}
|
||||
_, r, err := c.conn.NextReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.NewDecoder(r).Decode(v)
|
||||
if err == io.EOF {
|
||||
// One value is expected in the message.
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
110
websocket/mocks/Client.go
Normal file
110
websocket/mocks/Client.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
// Code generated by mockery v1.0.0. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import context "context"
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
import time "time"
|
||||
import websocket "***REMOVED***/pkg/net/http/websocket"
|
||||
|
||||
// Client is an autogenerated mock type for the Client type
|
||||
type Client struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *Client) Close() error {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Connect provides a mock function with given fields: _a0
|
||||
func (_m *Client) Connect(_a0 context.Context) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// IsConnected provides a mock function with given fields:
|
||||
func (_m *Client) IsConnected() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Messages provides a mock function with given fields:
|
||||
func (_m *Client) Messages() <-chan websocket.Message {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 <-chan websocket.Message
|
||||
if rf, ok := ret.Get(0).(func() <-chan websocket.Message); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(<-chan websocket.Message)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// OnConnect provides a mock function with given fields: _a0
|
||||
func (_m *Client) OnConnect(_a0 func(websocket.Client)) {
|
||||
_m.Called(_a0)
|
||||
}
|
||||
|
||||
// OnDisconnect provides a mock function with given fields: _a0
|
||||
func (_m *Client) OnDisconnect(_a0 func(websocket.Client)) {
|
||||
_m.Called(_a0)
|
||||
}
|
||||
|
||||
// Reconnect provides a mock function with given fields:
|
||||
func (_m *Client) Reconnect() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// SetReadTimeout provides a mock function with given fields: _a0
|
||||
func (_m *Client) SetReadTimeout(_a0 time.Duration) {
|
||||
_m.Called(_a0)
|
||||
}
|
||||
|
||||
// SetWriteTimeout provides a mock function with given fields: _a0
|
||||
func (_m *Client) SetWriteTimeout(_a0 time.Duration) {
|
||||
_m.Called(_a0)
|
||||
}
|
||||
|
||||
// WriteJSON provides a mock function with given fields: _a0
|
||||
func (_m *Client) WriteJSON(_a0 interface{}) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
Loading…
Reference in New Issue
Block a user