mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-26 08:45:16 +00:00
update
This commit is contained in:
commit
3fe812264f
40
bbgo/event.go
Normal file
40
bbgo/event.go
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
kline
|
||||||
|
|
||||||
|
{
|
||||||
|
"e": "kline", // KLineEvent type
|
||||||
|
"E": 123456789, // KLineEvent time
|
||||||
|
"s": "BNBBTC", // Symbol
|
||||||
|
"k": {
|
||||||
|
"t": 123400000, // Kline start time
|
||||||
|
"T": 123460000, // Kline close time
|
||||||
|
"s": "BNBBTC", // Symbol
|
||||||
|
"i": "1m", // Interval
|
||||||
|
"f": 100, // First trade ID
|
||||||
|
"L": 200, // Last trade ID
|
||||||
|
"o": "0.0010", // Open price
|
||||||
|
"c": "0.0020", // Close price
|
||||||
|
"h": "0.0025", // High price
|
||||||
|
"l": "0.0015", // Low price
|
||||||
|
"v": "1000", // Base asset volume
|
||||||
|
"n": 100, // Number of trades
|
||||||
|
"x": false, // Is this kline closed?
|
||||||
|
"q": "1.0000", // Quote asset volume
|
||||||
|
"V": "500", // Taker buy base asset volume
|
||||||
|
"Q": "0.500", // Taker buy quote asset volume
|
||||||
|
"B": "123456" // Ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
*/
|
||||||
|
type EventBase struct {
|
||||||
|
Event string `json:"e"` // event
|
||||||
|
Time int64 `json:"E"`
|
||||||
|
}
|
||||||
|
|
137
bbgo/kline.go
Normal file
137
bbgo/kline.go
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KLineEvent struct {
|
||||||
|
EventBase
|
||||||
|
Symbol string `json:"s"`
|
||||||
|
KLine *KLine `json:"k,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KLine struct {
|
||||||
|
StartTime int64 `json:"t"`
|
||||||
|
EndTime int64 `json:"T"`
|
||||||
|
|
||||||
|
Symbol string `json:"s"`
|
||||||
|
Interval string `json:"i"`
|
||||||
|
|
||||||
|
Open string `json:"o"`
|
||||||
|
Close string `json:"c"`
|
||||||
|
High string `json:"h"`
|
||||||
|
Low string `json:"l"`
|
||||||
|
Volume string `json:"V"` // taker buy base asset volume (like 10 BTC)
|
||||||
|
QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT)
|
||||||
|
|
||||||
|
LastTradeID int `json:"L"`
|
||||||
|
NumberOfTrades int `json:"n"`
|
||||||
|
Closed bool `json:"x"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k KLine) GetTrend() int {
|
||||||
|
o := k.GetOpen()
|
||||||
|
c := k.GetClose()
|
||||||
|
|
||||||
|
if c > o {
|
||||||
|
return 1
|
||||||
|
} else if c < o {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k KLine) GetHigh() float64 {
|
||||||
|
return MustParseFloat(k.High)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k KLine) GetLow() float64 {
|
||||||
|
return MustParseFloat(k.Low)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k KLine) GetOpen() float64 {
|
||||||
|
return MustParseFloat(k.Open)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k KLine) GetClose() float64 {
|
||||||
|
return MustParseFloat(k.Close)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k KLine) GetMaxChange() float64 {
|
||||||
|
return k.GetHigh() - k.GetLow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k KLine) GetThickness() float64 {
|
||||||
|
return k.GetChange() / k.GetMaxChange()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k KLine) GetChange() float64 {
|
||||||
|
return k.GetClose() - k.GetOpen()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k KLine) String() string {
|
||||||
|
return fmt.Sprintf("%s %s Open: % 14s Close: % 14s High: % 14s Low: % 14s Volume: % 13s Change: % 13f %s", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.Interval)
|
||||||
|
}
|
||||||
|
|
||||||
|
type KLineWindow []KLine
|
||||||
|
|
||||||
|
func (w KLineWindow) Len() int {
|
||||||
|
return len(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w KLineWindow) GetOpen() float64 {
|
||||||
|
return w[0].GetOpen()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w KLineWindow) GetClose() float64 {
|
||||||
|
end := len(w) - 1
|
||||||
|
return w[end].GetClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w KLineWindow) GetHigh() float64 {
|
||||||
|
high := w.GetOpen()
|
||||||
|
for _, line := range w {
|
||||||
|
val := line.GetHigh()
|
||||||
|
if val > high {
|
||||||
|
high = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return high
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w KLineWindow) GetLow() float64 {
|
||||||
|
low := w.GetOpen()
|
||||||
|
for _, line := range w {
|
||||||
|
val := line.GetHigh()
|
||||||
|
if val < low {
|
||||||
|
low = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return low
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w KLineWindow) GetChange() float64 {
|
||||||
|
return w.GetClose() - w.GetOpen()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w KLineWindow) GetMaxChange() float64 {
|
||||||
|
return w.GetHigh() - w.GetLow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *KLineWindow) Add(line KLine) {
|
||||||
|
*w = append(*w, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *KLineWindow) Truncate(size int) {
|
||||||
|
if len(*w) <= size {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
end := len(*w) - 1
|
||||||
|
start := end - size
|
||||||
|
if start < 0 {
|
||||||
|
start = 0
|
||||||
|
}
|
||||||
|
*w = (*w)[end-5 : end]
|
||||||
|
}
|
||||||
|
|
12
bbgo/parse.go
Normal file
12
bbgo/parse.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
func MustParseFloat(s string) float64 {
|
||||||
|
v, err := strconv.ParseFloat(s, 64)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
196
bbgo/parser.go
Normal file
196
bbgo/parser.go
Normal file
|
@ -0,0 +1,196 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/valyala/fastjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
executionReport
|
||||||
|
|
||||||
|
{
|
||||||
|
"e": "executionReport", // KLineEvent type
|
||||||
|
"E": 1499405658658, // KLineEvent time
|
||||||
|
"s": "ETHBTC", // Symbol
|
||||||
|
"c": "mUvoqJxFIILMdfAW5iGSOW", // Client order ID
|
||||||
|
"S": "BUY", // Side
|
||||||
|
"o": "LIMIT", // Order type
|
||||||
|
"f": "GTC", // Time in force
|
||||||
|
"q": "1.00000000", // Order quantity
|
||||||
|
"p": "0.10264410", // Order price
|
||||||
|
"P": "0.00000000", // Stop price
|
||||||
|
"F": "0.00000000", // Iceberg quantity
|
||||||
|
"g": -1, // OrderListId
|
||||||
|
"C": null, // Original client order ID; This is the ID of the order being canceled
|
||||||
|
"x": "NEW", // Current execution type
|
||||||
|
"X": "NEW", // Current order status
|
||||||
|
"r": "NONE", // Order reject reason; will be an error code.
|
||||||
|
"i": 4293153, // Order ID
|
||||||
|
"l": "0.00000000", // Last executed quantity
|
||||||
|
"z": "0.00000000", // Cumulative filled quantity
|
||||||
|
"L": "0.00000000", // Last executed price
|
||||||
|
"n": "0", // Commission amount
|
||||||
|
"N": null, // Commission asset
|
||||||
|
"T": 1499405658657, // Transaction time
|
||||||
|
"t": -1, // Trade ID
|
||||||
|
"I": 8641984, // Ignore
|
||||||
|
"w": true, // Is the order on the book?
|
||||||
|
"m": false, // Is this trade the maker side?
|
||||||
|
"M": false, // Ignore
|
||||||
|
"O": 1499405658657, // Order creation time
|
||||||
|
"Z": "0.00000000", // Cumulative quote asset transacted quantity
|
||||||
|
"Y": "0.00000000", // Last quote asset transacted quantity (i.e. lastPrice * lastQty)
|
||||||
|
"Q": "0.00000000" // Quote Order Qty
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
type ExecutionReportEvent struct {
|
||||||
|
EventBase
|
||||||
|
|
||||||
|
Symbol string `json:"s"`
|
||||||
|
ClientOrderID string `json:"c"`
|
||||||
|
Side string `json:"S"`
|
||||||
|
OrderType string `json:"o"`
|
||||||
|
TimeInForce string `json:"f"`
|
||||||
|
|
||||||
|
Quantity string `json:"q"`
|
||||||
|
Price string `json:"p"`
|
||||||
|
StopPrice string `json:"P"`
|
||||||
|
|
||||||
|
CurrentExecutionType string `json:"x"`
|
||||||
|
CurrentOrderStatus string `json:"X"`
|
||||||
|
|
||||||
|
OrderID int `json:"i"`
|
||||||
|
|
||||||
|
LastExecutedQuantity string `json:"l"`
|
||||||
|
CumulativeFilledQuantity string `json:"z"`
|
||||||
|
LastExecutedPrice string `json:"L"`
|
||||||
|
|
||||||
|
OrderCreationTime int `json:"O"`
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
balanceUpdate
|
||||||
|
|
||||||
|
{
|
||||||
|
"e": "balanceUpdate", //KLineEvent Type
|
||||||
|
"E": 1573200697110, //KLineEvent Time
|
||||||
|
"a": "BTC", //Asset
|
||||||
|
"d": "100.00000000", //Balance Delta
|
||||||
|
"T": 1573200697068 //Clear Time
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
type BalanceUpdateEvent struct {
|
||||||
|
EventBase
|
||||||
|
|
||||||
|
Asset string `json:"a"`
|
||||||
|
Delta string `json:"d"`
|
||||||
|
ClearTime int64 `json:"T"`
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
outboundAccountInfo
|
||||||
|
|
||||||
|
{
|
||||||
|
"e": "outboundAccountInfo", // KLineEvent type
|
||||||
|
"E": 1499405658849, // KLineEvent time
|
||||||
|
"m": 0, // Maker commission rate (bips)
|
||||||
|
"t": 0, // Taker commission rate (bips)
|
||||||
|
"b": 0, // Buyer commission rate (bips)
|
||||||
|
"s": 0, // Seller commission rate (bips)
|
||||||
|
"T": true, // Can trade?
|
||||||
|
"W": true, // Can withdraw?
|
||||||
|
"D": true, // Can deposit?
|
||||||
|
"u": 1499405658848, // Time of last account update
|
||||||
|
"B": [ // Balances array
|
||||||
|
{
|
||||||
|
"a": "LTC", // Asset
|
||||||
|
"f": "17366.18538083", // Free amount
|
||||||
|
"l": "0.00000000" // Locked amount
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"a": "BTC",
|
||||||
|
"f": "10537.85314051",
|
||||||
|
"l": "2.19464093"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"a": "ETH",
|
||||||
|
"f": "17902.35190619",
|
||||||
|
"l": "0.00000000"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"a": "BNC",
|
||||||
|
"f": "1114503.29769312",
|
||||||
|
"l": "0.00000000"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"a": "NEO",
|
||||||
|
"f": "0.00000000",
|
||||||
|
"l": "0.00000000"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"P": [ // Account Permissions
|
||||||
|
"SPOT"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
type Balance struct {
|
||||||
|
Asset string `json:"a"`
|
||||||
|
Free string `json:"f"`
|
||||||
|
Locked string `json:"l"`
|
||||||
|
}
|
||||||
|
type OutboundAccountInfoEvent struct {
|
||||||
|
EventBase
|
||||||
|
|
||||||
|
MakerCommissionRate int `json:"m"`
|
||||||
|
TakerCommissionRate int `json:"t"`
|
||||||
|
BuyerCommissionRate int `json:"b"`
|
||||||
|
SellerCommissionRate int `json:"s"`
|
||||||
|
|
||||||
|
CanTrade bool `json:"T"`
|
||||||
|
CanWithdraw bool `json:"W"`
|
||||||
|
CanDeposit bool `json:"D"`
|
||||||
|
|
||||||
|
LastAccountUpdateTime int `json:"u"`
|
||||||
|
|
||||||
|
Balances []Balance `json:"B,omitempty"`
|
||||||
|
Permissions []string `json:"P,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseEvent(message string) (interface{}, error) {
|
||||||
|
val, err := fastjson.Parse(message)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
eventType := string(val.GetStringBytes("e"))
|
||||||
|
|
||||||
|
switch eventType {
|
||||||
|
case "kline":
|
||||||
|
var event KLineEvent
|
||||||
|
err := json.Unmarshal([]byte(message), &event)
|
||||||
|
return &event, err
|
||||||
|
|
||||||
|
case "outboundAccountInfo", "outboundAccountPosition":
|
||||||
|
var event OutboundAccountInfoEvent
|
||||||
|
err := json.Unmarshal([]byte(message), &event)
|
||||||
|
return &event, err
|
||||||
|
|
||||||
|
case "balanceUpdate":
|
||||||
|
var event BalanceUpdateEvent
|
||||||
|
err := json.Unmarshal([]byte(message), &event)
|
||||||
|
return &event, err
|
||||||
|
|
||||||
|
case "executionReport":
|
||||||
|
var event ExecutionReportEvent
|
||||||
|
err := json.Unmarshal([]byte(message), &event)
|
||||||
|
return &event, err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unsupported message: %s", message)
|
||||||
|
}
|
||||||
|
|
66
bbgo/pnl.go
Normal file
66
bbgo/pnl.go
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
import log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
func CalculateAverageCost(trades []Trade) (averageCost float64) {
|
||||||
|
var totalCost = 0.0
|
||||||
|
var totalQuantity = 0.0
|
||||||
|
for _, t := range trades {
|
||||||
|
if t.IsBuyer {
|
||||||
|
totalCost += t.Price * t.Volume
|
||||||
|
totalQuantity += t.Volume
|
||||||
|
} else {
|
||||||
|
totalCost -= t.Price * t.Volume
|
||||||
|
totalQuantity -= t.Volume
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
averageCost = totalCost / totalQuantity
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func CalculateCostAndProfit(trades []Trade, currentPrice float64) (averageBidPrice, stock, profit, fee float64) {
|
||||||
|
var bidVolume = 0.0
|
||||||
|
var bidAmount = 0.0
|
||||||
|
var bidFee = 0.0
|
||||||
|
for _, t := range trades {
|
||||||
|
if t.IsBuyer {
|
||||||
|
bidVolume += t.Volume
|
||||||
|
bidAmount += t.Price * t.Volume
|
||||||
|
switch t.FeeCurrency {
|
||||||
|
case "BTC":
|
||||||
|
bidFee += t.Price * t.Fee
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("average bid price = (total amount %f + total fee %f) / volume %f", bidAmount, bidFee, bidVolume)
|
||||||
|
averageBidPrice = (bidAmount + bidFee) / bidVolume
|
||||||
|
|
||||||
|
var feeRate = 0.001
|
||||||
|
var askVolume = 0.0
|
||||||
|
var askFee = 0.0
|
||||||
|
for _, t := range trades {
|
||||||
|
if !t.IsBuyer {
|
||||||
|
profit += (t.Price - averageBidPrice) * t.Volume
|
||||||
|
askVolume += t.Volume
|
||||||
|
switch t.FeeCurrency {
|
||||||
|
case "USDT":
|
||||||
|
askFee += t.Fee
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
profit -= askFee
|
||||||
|
|
||||||
|
stock = bidVolume - askVolume
|
||||||
|
futureFee := 0.0
|
||||||
|
if stock > 0 {
|
||||||
|
stockfee := currentPrice * feeRate * stock
|
||||||
|
profit += (currentPrice-averageBidPrice)*stock - stockfee
|
||||||
|
futureFee += stockfee
|
||||||
|
}
|
||||||
|
|
||||||
|
fee = bidFee + askFee + futureFee
|
||||||
|
return
|
||||||
|
}
|
92
bbgo/trade.go
Normal file
92
bbgo/trade.go
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/adshao/go-binance"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Trade struct {
|
||||||
|
ID int64
|
||||||
|
Price float64
|
||||||
|
Volume float64
|
||||||
|
IsBuyer bool
|
||||||
|
IsMaker bool
|
||||||
|
Time time.Time
|
||||||
|
|
||||||
|
Fee float64
|
||||||
|
FeeCurrency string
|
||||||
|
}
|
||||||
|
|
||||||
|
func QueryTrades(ctx context.Context, client *binance.Client, market string, startTime time.Time) (trades []Trade, err error) {
|
||||||
|
var lastTradeID int64 = 0
|
||||||
|
for {
|
||||||
|
req := client.NewListTradesService().
|
||||||
|
Limit(1000).
|
||||||
|
Symbol(market).
|
||||||
|
StartTime(startTime.UnixNano() / 1000000)
|
||||||
|
|
||||||
|
if lastTradeID > 0 {
|
||||||
|
req.FromID(lastTradeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
bnTrades, err := req.Do(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(bnTrades) <= 1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range bnTrades {
|
||||||
|
// skip trade ID that is the same
|
||||||
|
if t.ID == lastTradeID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var side string
|
||||||
|
if t.IsBuyer {
|
||||||
|
side = "BUY"
|
||||||
|
} else {
|
||||||
|
side = "SELL"
|
||||||
|
}
|
||||||
|
|
||||||
|
// trade time
|
||||||
|
tt := time.Unix(0, t.Time*1000000)
|
||||||
|
log.Infof("trade: %d %4s Price: % 13s Volume: % 13s %s", t.ID, side, t.Price, t.Quantity, tt)
|
||||||
|
|
||||||
|
price, err := strconv.ParseFloat(t.Price, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
quantity, err := strconv.ParseFloat(t.Quantity, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fee, err := strconv.ParseFloat(t.Commission, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
trades = append(trades, Trade{
|
||||||
|
ID: t.ID,
|
||||||
|
Price: price,
|
||||||
|
Volume: quantity,
|
||||||
|
IsBuyer: t.IsBuyer,
|
||||||
|
IsMaker: t.IsMaker,
|
||||||
|
Fee: fee,
|
||||||
|
FeeCurrency: t.CommissionAsset,
|
||||||
|
Time: tt,
|
||||||
|
})
|
||||||
|
|
||||||
|
lastTradeID = t.ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return trades, nil
|
||||||
|
}
|
39
util/math.go
Normal file
39
util/math.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MaxDigits = 18 // MAX_INT64 ~ 9 * 10^18
|
||||||
|
|
||||||
|
var Pow10Table = [MaxDigits + 1]int64{
|
||||||
|
1, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18,
|
||||||
|
}
|
||||||
|
|
||||||
|
func Pow10(n int64) int64 {
|
||||||
|
if n < 0 || n > MaxDigits {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return Pow10Table[n]
|
||||||
|
}
|
||||||
|
|
||||||
|
var NegPow10Table = [MaxDigits + 1]float64{
|
||||||
|
1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18,
|
||||||
|
}
|
||||||
|
|
||||||
|
func NegPow10(n int64) float64 {
|
||||||
|
if n < 0 || n > MaxDigits {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
return NegPow10Table[n]
|
||||||
|
}
|
||||||
|
|
||||||
|
func Float64ToStr(input float64) string {
|
||||||
|
return strconv.FormatFloat(input, 'f', -1, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Float64ToInt64(input float64) int64 {
|
||||||
|
// eliminate rounding error for IEEE754 floating points
|
||||||
|
return int64(math.Round(input))
|
||||||
|
}
|
19
util/rate_limit.go
Normal file
19
util/rate_limit.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ShouldDelay(l *rate.Limiter, minInterval time.Duration) time.Duration {
|
||||||
|
return l.Reserve().Delay()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewValidLimiter(r rate.Limit, b int) (*rate.Limiter, error) {
|
||||||
|
if b <= 0 || r <= 0 {
|
||||||
|
return nil, fmt.Errorf("Bad rate limit config. Insufficient tokens. (rate=%f, b=%d)", r, b)
|
||||||
|
}
|
||||||
|
return rate.NewLimiter(r, b), nil
|
||||||
|
}
|
43
util/rate_limit_test.go
Normal file
43
util/rate_limit_test.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewValidRateLimiter(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
r rate.Limit
|
||||||
|
b int
|
||||||
|
hasError bool
|
||||||
|
}{
|
||||||
|
{"valid limiter", 0.1, 1, false},
|
||||||
|
{"zero rate", 0, 1, true},
|
||||||
|
{"zero burst", 0.1, 0, true},
|
||||||
|
{"both zero", 0, 0, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
limiter, err := NewValidLimiter(c.r, c.b)
|
||||||
|
assert.Equal(t, c.hasError, err != nil)
|
||||||
|
if !c.hasError {
|
||||||
|
assert.NotNil(t, limiter)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldDelay(t *testing.T) {
|
||||||
|
minInterval := time.Second * 3
|
||||||
|
maxRate := rate.Limit(1 / minInterval.Seconds())
|
||||||
|
limiter := rate.NewLimiter(maxRate, 1)
|
||||||
|
assert.Equal(t, time.Duration(0), ShouldDelay(limiter, minInterval))
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
assert.True(t, ShouldDelay(limiter, minInterval) > 0)
|
||||||
|
}
|
||||||
|
}
|
69
util/retry.go
Normal file
69
util/retry.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
InfiniteRetry = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type RetryPredicator func(e error) bool
|
||||||
|
|
||||||
|
// Retry retrys the passed functin for "attempts" times, if passed function return error. Setting attemps to zero means keep retrying.
|
||||||
|
func Retry(ctx context.Context, attempts int, duration time.Duration, fnToRetry func() error, errHandler func(error), predicators ...RetryPredicator) (err error) {
|
||||||
|
infinite := false
|
||||||
|
if attempts == InfiniteRetry {
|
||||||
|
infinite = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for attempts > 0 || infinite {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
errMsg := "return for context done"
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, errMsg)
|
||||||
|
} else {
|
||||||
|
return errors.New(errMsg)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if err = fnToRetry(); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !needRetry(err, predicators) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = errors.Wrapf(err, "failed in retry: countdown: %v", attempts)
|
||||||
|
|
||||||
|
if errHandler != nil {
|
||||||
|
errHandler(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !infinite {
|
||||||
|
attempts--
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(duration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func needRetry(err error, predicators []RetryPredicator) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no predicators specified, we will retry for all errors
|
||||||
|
if len(predicators) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return predicators[0](err)
|
||||||
|
}
|
106
util/retry_test.go
Normal file
106
util/retry_test.go
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func addAndCheck(a *int, target int) error {
|
||||||
|
if *a++; *a == target {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("a is not %v. It is %v\n", target, *a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetry(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
input int
|
||||||
|
targetNum int
|
||||||
|
ans int
|
||||||
|
ansErr error
|
||||||
|
}
|
||||||
|
tests := []test{
|
||||||
|
{input: 0, targetNum: 3, ans: 3, ansErr: nil},
|
||||||
|
{input: 0, targetNum: 10, ans: 3, ansErr: errors.New("failed in retry")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
errHandled := false
|
||||||
|
|
||||||
|
err := Retry(context.Background(), 3, 1*time.Second, func() error {
|
||||||
|
return addAndCheck(&tc.input, tc.targetNum)
|
||||||
|
}, func(e error) { errHandled = true })
|
||||||
|
|
||||||
|
assert.Equal(t, true, errHandled)
|
||||||
|
if tc.ansErr == nil {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.Contains(t, err.Error(), tc.ansErr.Error())
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.ans, tc.input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryWithPredicator(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
count int
|
||||||
|
f func() error
|
||||||
|
errHandler func(error)
|
||||||
|
predicator RetryPredicator
|
||||||
|
ansCount int
|
||||||
|
ansErr error
|
||||||
|
}
|
||||||
|
knownErr := errors.New("Duplicate entry '1-389837488-1' for key 'UNI_Trade'")
|
||||||
|
unknownErr := errors.New("Some Error")
|
||||||
|
tests := []test{
|
||||||
|
{
|
||||||
|
predicator: func(err error) bool {
|
||||||
|
return !strings.Contains(err.Error(), "Duplicate entry")
|
||||||
|
},
|
||||||
|
f: func() error { return knownErr },
|
||||||
|
ansCount: 1,
|
||||||
|
ansErr: knownErr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
predicator: func(err error) bool {
|
||||||
|
return !strings.Contains(err.Error(), "Duplicate entry")
|
||||||
|
},
|
||||||
|
f: func() error { return unknownErr },
|
||||||
|
ansCount: 3,
|
||||||
|
ansErr: unknownErr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
attempts := 3
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
err := Retry(ctx, attempts, 100*time.Millisecond, func() error {
|
||||||
|
tc.count++
|
||||||
|
return tc.f()
|
||||||
|
}, tc.errHandler, tc.predicator)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.ansCount, tc.count)
|
||||||
|
assert.EqualError(t, errors.Cause(err), tc.ansErr.Error(), "should be equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryCtxCancel(t *testing.T) {
|
||||||
|
result := int(0)
|
||||||
|
target := int(3)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := Retry(ctx, 5, 1*time.Second, func() error { return addAndCheck(&result, target) }, func(error) {})
|
||||||
|
assert.Error(t, err)
|
||||||
|
fmt.Println("Error:", err.Error())
|
||||||
|
assert.Equal(t, int(0), result)
|
||||||
|
}
|
78
websocket/backoff.go
Normal file
78
websocket/backoff.go
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultMinBackoff = 5 * time.Second
|
||||||
|
DefaultMaxBackoff = time.Minute
|
||||||
|
|
||||||
|
DefaultBackoffFactor = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type Backoff struct {
|
||||||
|
attempt int64
|
||||||
|
Factor float64
|
||||||
|
cur, Min, Max time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Duration returns the duration for the current attempt before incrementing
|
||||||
|
// the attempt counter. See ForAttempt.
|
||||||
|
func (b *Backoff) Duration() time.Duration {
|
||||||
|
d := b.calculate(b.attempt)
|
||||||
|
b.attempt++
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Backoff) calculate(attempt int64) time.Duration {
|
||||||
|
min := b.Min
|
||||||
|
if min <= 0 {
|
||||||
|
min = DefaultMinBackoff
|
||||||
|
}
|
||||||
|
max := b.Max
|
||||||
|
if max <= 0 {
|
||||||
|
max = DefaultMaxBackoff
|
||||||
|
}
|
||||||
|
if min >= max {
|
||||||
|
return max
|
||||||
|
}
|
||||||
|
factor := b.Factor
|
||||||
|
if factor <= 0 {
|
||||||
|
factor = DefaultBackoffFactor
|
||||||
|
}
|
||||||
|
cur := b.cur
|
||||||
|
if cur < min {
|
||||||
|
cur = min
|
||||||
|
} else if cur > max {
|
||||||
|
cur = max
|
||||||
|
}
|
||||||
|
|
||||||
|
//calculate this duration
|
||||||
|
next := cur
|
||||||
|
if attempt > 0 {
|
||||||
|
next = time.Duration(float64(cur) * factor)
|
||||||
|
}
|
||||||
|
|
||||||
|
if next < cur {
|
||||||
|
// overflow
|
||||||
|
next = max
|
||||||
|
} else if next <= min {
|
||||||
|
next = min
|
||||||
|
} else if next >= max {
|
||||||
|
next = max
|
||||||
|
}
|
||||||
|
b.cur = next
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset restarts the current attempt counter at zero.
|
||||||
|
func (b *Backoff) Reset() {
|
||||||
|
b.attempt = 0
|
||||||
|
b.cur = b.Min
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt returns the current attempt counter value.
|
||||||
|
func (b *Backoff) Attempt() int64 {
|
||||||
|
return b.attempt
|
||||||
|
}
|
172
websocket/backoff_test.go
Normal file
172
websocket/backoff_test.go
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackoff(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: 200 * time.Millisecond,
|
||||||
|
Max: 100 * time.Second,
|
||||||
|
Factor: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
equals(t, b.Duration(), 200*time.Millisecond)
|
||||||
|
equals(t, b.Duration(), 400*time.Millisecond)
|
||||||
|
equals(t, b.Duration(), 800*time.Millisecond)
|
||||||
|
b.Reset()
|
||||||
|
equals(t, b.Duration(), 200*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReachMaxBackoff(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: 2 * time.Second,
|
||||||
|
Max: 100 * time.Second,
|
||||||
|
Factor: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
equals(t, b.Duration(), 2*time.Second)
|
||||||
|
equals(t, b.Duration(), 20*time.Second)
|
||||||
|
equals(t, b.Duration(), b.Max)
|
||||||
|
b.Reset()
|
||||||
|
equals(t, b.Duration(), 2*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAttempt(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: 2 * time.Second,
|
||||||
|
Max: 100 * time.Second,
|
||||||
|
Factor: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
equals(t, b.Attempt(), int64(0))
|
||||||
|
equals(t, b.Duration(), 2*time.Second)
|
||||||
|
equals(t, b.Attempt(), int64(1))
|
||||||
|
equals(t, b.Duration(), 20*time.Second)
|
||||||
|
equals(t, b.Attempt(), int64(2))
|
||||||
|
equals(t, b.Duration(), b.Max)
|
||||||
|
equals(t, b.Attempt(), int64(3))
|
||||||
|
b.Reset()
|
||||||
|
equals(t, b.Attempt(), int64(0))
|
||||||
|
equals(t, b.Duration(), 2*time.Second)
|
||||||
|
equals(t, b.Attempt(), int64(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAttemptWithZeroValue(t *testing.T) {
|
||||||
|
b := &Backoff{}
|
||||||
|
|
||||||
|
cur := DefaultMinBackoff
|
||||||
|
equals(t, b.Attempt(), int64(0))
|
||||||
|
equals(t, b.Duration(), cur)
|
||||||
|
for i := 1; i < 10; i++ {
|
||||||
|
equals(t, b.Attempt(), int64(i))
|
||||||
|
cur *= DefaultBackoffFactor
|
||||||
|
if cur >= DefaultMaxBackoff {
|
||||||
|
equals(t, b.Duration(), DefaultMaxBackoff)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
equals(t, b.Duration(), cur)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Reset()
|
||||||
|
equals(t, b.Attempt(), int64(0))
|
||||||
|
equals(t, b.Duration(), DefaultMinBackoff)
|
||||||
|
equals(t, b.Attempt(), int64(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNegOrZeroMin(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: 0,
|
||||||
|
}
|
||||||
|
equals(t, b.Duration(), DefaultMinBackoff)
|
||||||
|
|
||||||
|
b2 := &Backoff{
|
||||||
|
Min: -1 * time.Second,
|
||||||
|
}
|
||||||
|
equals(t, b2.Duration(), DefaultMinBackoff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNegOrZeroMax(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: DefaultMaxBackoff,
|
||||||
|
Max: 0,
|
||||||
|
}
|
||||||
|
equals(t, b.Duration(), DefaultMaxBackoff)
|
||||||
|
|
||||||
|
b2 := &Backoff{
|
||||||
|
Min: DefaultMaxBackoff,
|
||||||
|
Max: -1 * time.Second,
|
||||||
|
}
|
||||||
|
equals(t, b2.Duration(), DefaultMaxBackoff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMinLargerThanMax(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: 500 * time.Second,
|
||||||
|
Max: 100 * time.Second,
|
||||||
|
Factor: 1,
|
||||||
|
}
|
||||||
|
equals(t, b.Duration(), b.Max)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFakeCur(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: 100 * time.Second,
|
||||||
|
Max: 500 * time.Second,
|
||||||
|
Factor: 1,
|
||||||
|
cur: 1000 * time.Second,
|
||||||
|
}
|
||||||
|
equals(t, b.Duration(), b.Max)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLargeFactor(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: 1 * time.Second,
|
||||||
|
Max: 10 * time.Second,
|
||||||
|
Factor: math.MaxFloat64,
|
||||||
|
}
|
||||||
|
equals(t, b.Duration(), b.Min)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
equals(t, b.Duration(), b.Max)
|
||||||
|
}
|
||||||
|
b.Reset()
|
||||||
|
equals(t, b.Duration(), b.Min)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLargeMinMax(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: time.Duration(math.MaxInt64 - 1),
|
||||||
|
Max: time.Duration(math.MaxInt64),
|
||||||
|
Factor: math.MaxFloat64,
|
||||||
|
}
|
||||||
|
equals(t, b.Duration(), b.Min)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
equals(t, b.Duration(), b.Max)
|
||||||
|
}
|
||||||
|
b.Reset()
|
||||||
|
equals(t, b.Duration(), b.Min)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManyAttempts(t *testing.T) {
|
||||||
|
b := &Backoff{
|
||||||
|
Min: 10 * time.Second,
|
||||||
|
Max: time.Duration(math.MaxInt64),
|
||||||
|
Factor: 1000,
|
||||||
|
}
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
b.Duration()
|
||||||
|
}
|
||||||
|
equals(t, b.Duration(), b.Max)
|
||||||
|
b.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
func equals(t *testing.T, v1, v2 interface{}) {
|
||||||
|
if !reflect.DeepEqual(v1, v2) {
|
||||||
|
assert.Failf(t, "not equal", "Got %v, Expecting %v", v1, v2)
|
||||||
|
}
|
||||||
|
}
|
368
websocket/client.go
Normal file
368
websocket/client.go
Normal file
|
@ -0,0 +1,368 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
|
"github.com/c9s/bbgo/pkg/log"
|
||||||
|
"github.com/c9s/bbgo/pkg/util"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DefaultMessageBufferSize = 128
|
||||||
|
|
||||||
|
const DefaultWriteTimeout = 30 * time.Second
|
||||||
|
const DefaultReadTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
var ErrReconnectContextDone = errors.New("reconnect canceled due to context done.")
|
||||||
|
var ErrReconnectFailed = errors.New("failed to reconnect")
|
||||||
|
var ErrConnectionLost = errors.New("connection lost")
|
||||||
|
|
||||||
|
var MaxReconnectRate = rate.Limit(1 / DefaultMinBackoff.Seconds())
|
||||||
|
|
||||||
|
// WebSocketClient allows to connect and receive stream data
|
||||||
|
type WebSocketClient struct {
|
||||||
|
// Url is the websocket connection location, start with ws:// or wss://
|
||||||
|
Url string
|
||||||
|
|
||||||
|
// conn is the current websocket connection, please note the connection
|
||||||
|
// object can be replaced with a new connection object when the connection
|
||||||
|
// is unexpected closed.
|
||||||
|
conn *websocket.Conn
|
||||||
|
|
||||||
|
// Dialer is used for creating the websocket connection
|
||||||
|
Dialer *websocket.Dialer
|
||||||
|
|
||||||
|
// requestHeader is used for the Dial function call. Some credential can be
|
||||||
|
// stored in the http request header for authentication
|
||||||
|
requestHeader http.Header
|
||||||
|
|
||||||
|
// messages is a read-only channel, received messages will be sent to this
|
||||||
|
// channel.
|
||||||
|
messages chan Message
|
||||||
|
|
||||||
|
readTimeout time.Duration
|
||||||
|
|
||||||
|
writeTimeout time.Duration
|
||||||
|
|
||||||
|
onConnect []func(c Client)
|
||||||
|
|
||||||
|
onDisconnect []func(c Client)
|
||||||
|
|
||||||
|
// cancel is mapped to the ctx context object
|
||||||
|
cancel func()
|
||||||
|
|
||||||
|
readerClosed chan struct{}
|
||||||
|
|
||||||
|
connected bool
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
reconnectCh chan struct{}
|
||||||
|
|
||||||
|
backoff Backoff
|
||||||
|
|
||||||
|
limiter *rate.Limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
// websocket.BinaryMessage or websocket.TextMessage
|
||||||
|
Type int
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) Messages() <-chan Message {
|
||||||
|
return c.messages
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) SetReadTimeout(timeout time.Duration) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.readTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) SetWriteTimeout(timeout time.Duration) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.writeTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) OnConnect(f func(c Client)) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.onConnect = append(c.onConnect, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) OnDisconnect(f func(c Client)) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.onDisconnect = append(c.onDisconnect, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) WriteTextMessage(message []byte) error {
|
||||||
|
return c.WriteMessage(websocket.TextMessage, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) WriteBinaryMessage(message []byte) error {
|
||||||
|
return c.WriteMessage(websocket.BinaryMessage, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) WriteMessage(messageType int, data []byte) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if !c.connected {
|
||||||
|
return ErrConnectionLost
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.conn.WriteMessage(messageType, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) readMessages() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
if !c.connected {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return ErrConnectionLost
|
||||||
|
}
|
||||||
|
timeout := c.readTimeout
|
||||||
|
conn := c.conn
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
msgtype, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.messages <- Message{msgtype, message}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// listen starts a goroutine for reading message and tries to re-connect to the
|
||||||
|
// server when the reader returns error
|
||||||
|
//
|
||||||
|
// Please note we should always break the reader loop if there is any error
|
||||||
|
// returned from the server.
|
||||||
|
func (c *WebSocketClient) listen(ctx context.Context) {
|
||||||
|
// The life time of both channels "readerClosed" and "reconnectCh" is bound to one connection.
|
||||||
|
// Each channel should be created before loop starts and be closed after loop ends.
|
||||||
|
// "readerClosed" is used to inform "Close()" reader loop ends.
|
||||||
|
// "reconnectCh" is used to centralize reconnection logics in this reader loop.
|
||||||
|
c.mu.Lock()
|
||||||
|
c.readerClosed = make(chan struct{})
|
||||||
|
c.reconnectCh = make(chan struct{}, 1)
|
||||||
|
c.mu.Unlock()
|
||||||
|
defer func() {
|
||||||
|
c.mu.Lock()
|
||||||
|
close(c.readerClosed)
|
||||||
|
close(c.reconnectCh)
|
||||||
|
c.reconnectCh = nil
|
||||||
|
c.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-c.reconnectCh:
|
||||||
|
// it could be i/o timeout for network disconnection
|
||||||
|
// or it could be invoked from outside.
|
||||||
|
c.SetDisconnected()
|
||||||
|
var maxTries = 1
|
||||||
|
if _, response, err := c.reconnect(ctx, maxTries); err != nil {
|
||||||
|
if err == ErrReconnectContextDone {
|
||||||
|
log.Debugf("[websocket] context canceled. stop reconnecting.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Warnf("[websocket] failed to reconnect after %d tries!! error: %v response: %v", maxTries, err, response)
|
||||||
|
c.Reconnect()
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
if err := c.readMessages(); err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
|
||||||
|
log.Warnf("[websocket] unexpected close error reconnecting: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("[websocket] failed to read message. error: %+v", err)
|
||||||
|
c.Reconnect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconnect triggers reconnection logics
|
||||||
|
func (c *WebSocketClient) Reconnect() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
// c.reconnectCh is a buffered channel with cap=1.
|
||||||
|
// At most one reconnect signal could be processed.
|
||||||
|
case c.reconnectCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// Entering here means it is already reconnecting.
|
||||||
|
// Drop the current reconnect signal.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close gracefully shuts down the reader and the connection
|
||||||
|
// ctx is the context used for shutdown process.
|
||||||
|
func (c *WebSocketClient) Close() (err error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
// leave the listen goroutine before we close the connection
|
||||||
|
// checking nil is to handle calling "Close" before "Connect" is called
|
||||||
|
if c.cancel != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
c.SetDisconnected()
|
||||||
|
|
||||||
|
// wait for the reader func to be closed
|
||||||
|
if c.readerClosed != nil {
|
||||||
|
<-c.readerClosed
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconnect tries to create a new connection from the existing dialer
|
||||||
|
func (c *WebSocketClient) reconnect(ctx context.Context, maxTries int) (*websocket.Conn, *http.Response, error) {
|
||||||
|
log.Debugf("[websocket] start reconnecting to %q", c.Url)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, nil, ErrReconnectContextDone
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if s := util.ShouldDelay(c.limiter, DefaultMinBackoff); s > 0 {
|
||||||
|
log.Warn("[websocket] reconnect too frequently. Sleep for ", s)
|
||||||
|
time.Sleep(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("[websocket] reconnecting x %d to %q", c.backoff.Attempt()+1, c.Url)
|
||||||
|
conn, resp, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
||||||
|
if err != nil {
|
||||||
|
dur := c.backoff.Duration()
|
||||||
|
log.Warnf("failed to dial %s: %v, response: %+v. Wait for %v", c.Url, err, resp, dur)
|
||||||
|
time.Sleep(dur)
|
||||||
|
return nil, nil, ErrReconnectFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("[websocket] reconnected to %q", c.Url)
|
||||||
|
// Reset backoff value if connected.
|
||||||
|
c.backoff.Reset()
|
||||||
|
c.setConn(conn)
|
||||||
|
c.setPingHandler(conn)
|
||||||
|
|
||||||
|
return conn, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn returns the current active connection instance
|
||||||
|
func (c *WebSocketClient) Conn() (conn *websocket.Conn) {
|
||||||
|
c.mu.Lock()
|
||||||
|
conn = c.conn
|
||||||
|
c.mu.Unlock()
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) setConn(conn *websocket.Conn) {
|
||||||
|
// Disconnect old connection before replacing with new one.
|
||||||
|
c.SetDisconnected()
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.conn = conn
|
||||||
|
c.connected = true
|
||||||
|
c.mu.Unlock()
|
||||||
|
for _, f := range c.onConnect {
|
||||||
|
go f(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) setPingHandler(conn *websocket.Conn) {
|
||||||
|
conn.SetPingHandler(func(message string) error {
|
||||||
|
if err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) SetDisconnected() {
|
||||||
|
c.mu.Lock()
|
||||||
|
closed := false
|
||||||
|
if c.conn != nil {
|
||||||
|
closed = true
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
c.connected = false
|
||||||
|
c.conn = nil
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
// Only call disconnect callbacks when a connection is closed
|
||||||
|
for _, f := range c.onDisconnect {
|
||||||
|
go f(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) IsConnected() (ret bool) {
|
||||||
|
c.mu.Lock()
|
||||||
|
ret = c.connected
|
||||||
|
c.mu.Unlock()
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WebSocketClient) Connect(basectx context.Context) error {
|
||||||
|
// maintain a context by the client it self, so that we can manually shutdown the connection
|
||||||
|
ctx, cancel := context.WithCancel(basectx)
|
||||||
|
c.cancel = cancel
|
||||||
|
|
||||||
|
conn, _, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
|
||||||
|
if err == nil {
|
||||||
|
// setup connection only when connected
|
||||||
|
c.setConn(conn)
|
||||||
|
c.setPingHandler(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1) if connection is built up, start listening for messages.
|
||||||
|
// 2) if connection is NOT ready, start reconnecting infinitely.
|
||||||
|
go c.listen(ctx)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(url string, requestHeader http.Header) *WebSocketClient {
|
||||||
|
return NewWithDialer(url, websocket.DefaultDialer, requestHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWithDialer(url string, d *websocket.Dialer, requestHeader http.Header) *WebSocketClient {
|
||||||
|
limiter, err := util.NewValidLimiter(MaxReconnectRate, 1)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Panic("Invalid rate limiter")
|
||||||
|
}
|
||||||
|
return &WebSocketClient{
|
||||||
|
Url: url,
|
||||||
|
Dialer: d,
|
||||||
|
requestHeader: requestHeader,
|
||||||
|
readTimeout: DefaultReadTimeout,
|
||||||
|
writeTimeout: DefaultWriteTimeout,
|
||||||
|
messages: make(chan Message, DefaultMessageBufferSize),
|
||||||
|
limiter: limiter,
|
||||||
|
}
|
||||||
|
}
|
202
websocket/client_test.go
Normal file
202
websocket/client_test.go
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"***REMOVED***/pkg/log"
|
||||||
|
"***REMOVED***/pkg/testing/testutil"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func messageTicker(t *testing.T, ctx context.Context, conn *websocket.Conn) {
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case tt := <-ticker.C:
|
||||||
|
err := conn.WriteMessage(websocket.TextMessage, []byte(tt.String()))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func messageReader(t *testing.T, ctx context.Context, conn *websocket.Conn) {
|
||||||
|
for {
|
||||||
|
_, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||||
|
log.Errorf("unexpected closed error: %v", err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
t.Logf("message: %v", message)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconnect(t *testing.T) {
|
||||||
|
basectx := context.Background()
|
||||||
|
|
||||||
|
wsHandler := func(ctx context.Context) func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var upgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("upgrade error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go messageTicker(t, ctx, conn)
|
||||||
|
messageReader(t, ctx, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
serverctx, cancelserver := context.WithCancel(basectx)
|
||||||
|
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", wsHandler(serverctx))
|
||||||
|
go func() {
|
||||||
|
err := server.ListenAndServe()
|
||||||
|
assert.Equal(t, http.ErrServerClosed, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(server.Addr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
log.Debugf("waiting for server port ready")
|
||||||
|
testutil.WaitForPort(t, server.Addr, 3)
|
||||||
|
|
||||||
|
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
|
||||||
|
log.Infof("url: %s", u.String())
|
||||||
|
|
||||||
|
client := New(u.String(), nil)
|
||||||
|
|
||||||
|
// start the message reader
|
||||||
|
clientctx, cancelClient := context.WithCancel(basectx)
|
||||||
|
defer cancelClient()
|
||||||
|
|
||||||
|
err = client.Connect(clientctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, client)
|
||||||
|
client.SetReadTimeout(2 * time.Second)
|
||||||
|
|
||||||
|
// read one message
|
||||||
|
log.Debugf("waiting for message...")
|
||||||
|
msg := <-client.messages
|
||||||
|
log.Debugf("received message: %v", msg)
|
||||||
|
|
||||||
|
triggerReconnectManyTimes := func() {
|
||||||
|
// Forcedly reconnect multiple times
|
||||||
|
for i := 0; i < 10; i = i + 1 {
|
||||||
|
client.Reconnect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("reconnecting...")
|
||||||
|
triggerReconnectManyTimes()
|
||||||
|
msg = <-client.messages
|
||||||
|
log.Debugf("received message: %v", msg)
|
||||||
|
|
||||||
|
log.Debugf("shutting down server...")
|
||||||
|
err = server.Shutdown(serverctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
cancelserver()
|
||||||
|
server.Close()
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
triggerReconnectManyTimes()
|
||||||
|
|
||||||
|
log.Debugf("restarting server...")
|
||||||
|
server2ctx, cancelserver2 := context.WithCancel(basectx)
|
||||||
|
server2 := testutil.NewWebSocketServerWithAddressFunc(t, server.Addr, "/ws", wsHandler(server2ctx))
|
||||||
|
go func() {
|
||||||
|
err := server2.ListenAndServe()
|
||||||
|
assert.Equal(t, http.ErrServerClosed, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
triggerReconnectManyTimes()
|
||||||
|
|
||||||
|
log.Debugf("waiting for server2 port ready")
|
||||||
|
testutil.WaitForPort(t, server.Addr, 3)
|
||||||
|
|
||||||
|
log.Debugf("waiting for message from server2...")
|
||||||
|
msg = <-client.messages
|
||||||
|
log.Debugf("received message: %v", msg)
|
||||||
|
|
||||||
|
err = client.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = server2.Shutdown(server2ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
cancelserver2()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnect(t *testing.T) {
|
||||||
|
basectx := context.Background()
|
||||||
|
ctx, cancel := context.WithCancel(basectx)
|
||||||
|
|
||||||
|
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var upgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("upgrade error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go messageTicker(t, ctx, conn)
|
||||||
|
messageReader(t, ctx, conn)
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := server.ListenAndServe()
|
||||||
|
assert.Equal(t, http.ErrServerClosed, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(server.Addr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
testutil.WaitForPort(t, server.Addr, 3)
|
||||||
|
|
||||||
|
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
|
||||||
|
t.Logf("url: %s", u.String())
|
||||||
|
|
||||||
|
client := New(u.String(), nil)
|
||||||
|
|
||||||
|
err = client.Connect(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, client)
|
||||||
|
client.SetReadTimeout(2 * time.Second)
|
||||||
|
|
||||||
|
// read one message
|
||||||
|
t.Logf("waiting for message...")
|
||||||
|
msg := <-client.messages
|
||||||
|
t.Logf("received message: %v", msg)
|
||||||
|
|
||||||
|
// recreate server at the same address
|
||||||
|
client.Close()
|
||||||
|
|
||||||
|
t.Logf("shutting down server...")
|
||||||
|
cancel()
|
||||||
|
err = server.Shutdown(basectx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
}
|
21
websocket/interface.go
Normal file
21
websocket/interface.go
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:generate mockery -name=Client
|
||||||
|
|
||||||
|
type Client interface {
|
||||||
|
SetWriteTimeout(time.Duration)
|
||||||
|
SetReadTimeout(time.Duration)
|
||||||
|
OnConnect(func(c Client))
|
||||||
|
OnDisconnect(func(c Client))
|
||||||
|
Connect(context.Context) error
|
||||||
|
Reconnect()
|
||||||
|
Close() error
|
||||||
|
IsConnected() bool
|
||||||
|
WriteJSON(interface{}) error
|
||||||
|
Messages() <-chan Message
|
||||||
|
}
|
55
websocket/json.go
Normal file
55
websocket/json.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WriteJSON writes the JSON encoding of v as a message.
|
||||||
|
//
|
||||||
|
// See the documentation for encoding/json Marshal for details about the
|
||||||
|
// conversion of Go values to JSON.
|
||||||
|
func (c *WebSocketClient) WriteJSON(v interface{}) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if !c.connected {
|
||||||
|
return ErrConnectionLost
|
||||||
|
}
|
||||||
|
w, err := c.conn.NextWriter(websocket.TextMessage)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err1 := json.NewEncoder(w).Encode(v)
|
||||||
|
err2 := w.Close()
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadJSON reads the next JSON-encoded message from the connection and stores
|
||||||
|
// it in the value pointed to by v.
|
||||||
|
//
|
||||||
|
// See the documentation for the encoding/json Unmarshal function for details
|
||||||
|
// about the conversion of JSON to a Go value.
|
||||||
|
func (c *WebSocketClient) ReadJSON(v interface{}) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if !c.connected {
|
||||||
|
return ErrConnectionLost
|
||||||
|
}
|
||||||
|
_, r, err := c.conn.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = json.NewDecoder(r).Decode(v)
|
||||||
|
if err == io.EOF {
|
||||||
|
// One value is expected in the message.
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
110
websocket/mocks/Client.go
Normal file
110
websocket/mocks/Client.go
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
// Code generated by mockery v1.0.0. DO NOT EDIT.
|
||||||
|
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import context "context"
|
||||||
|
import mock "github.com/stretchr/testify/mock"
|
||||||
|
import time "time"
|
||||||
|
import websocket "***REMOVED***/pkg/net/http/websocket"
|
||||||
|
|
||||||
|
// Client is an autogenerated mock type for the Client type
|
||||||
|
type Client struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close provides a mock function with given fields:
|
||||||
|
func (_m *Client) Close() error {
|
||||||
|
ret := _m.Called()
|
||||||
|
|
||||||
|
var r0 error
|
||||||
|
if rf, ok := ret.Get(0).(func() error); ok {
|
||||||
|
r0 = rf()
|
||||||
|
} else {
|
||||||
|
r0 = ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect provides a mock function with given fields: _a0
|
||||||
|
func (_m *Client) Connect(_a0 context.Context) error {
|
||||||
|
ret := _m.Called(_a0)
|
||||||
|
|
||||||
|
var r0 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||||
|
r0 = rf(_a0)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnected provides a mock function with given fields:
|
||||||
|
func (_m *Client) IsConnected() bool {
|
||||||
|
ret := _m.Called()
|
||||||
|
|
||||||
|
var r0 bool
|
||||||
|
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||||
|
r0 = rf()
|
||||||
|
} else {
|
||||||
|
r0 = ret.Get(0).(bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Messages provides a mock function with given fields:
|
||||||
|
func (_m *Client) Messages() <-chan websocket.Message {
|
||||||
|
ret := _m.Called()
|
||||||
|
|
||||||
|
var r0 <-chan websocket.Message
|
||||||
|
if rf, ok := ret.Get(0).(func() <-chan websocket.Message); ok {
|
||||||
|
r0 = rf()
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).(<-chan websocket.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConnect provides a mock function with given fields: _a0
|
||||||
|
func (_m *Client) OnConnect(_a0 func(websocket.Client)) {
|
||||||
|
_m.Called(_a0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDisconnect provides a mock function with given fields: _a0
|
||||||
|
func (_m *Client) OnDisconnect(_a0 func(websocket.Client)) {
|
||||||
|
_m.Called(_a0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconnect provides a mock function with given fields:
|
||||||
|
func (_m *Client) Reconnect() {
|
||||||
|
_m.Called()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadTimeout provides a mock function with given fields: _a0
|
||||||
|
func (_m *Client) SetReadTimeout(_a0 time.Duration) {
|
||||||
|
_m.Called(_a0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteTimeout provides a mock function with given fields: _a0
|
||||||
|
func (_m *Client) SetWriteTimeout(_a0 time.Duration) {
|
||||||
|
_m.Called(_a0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteJSON provides a mock function with given fields: _a0
|
||||||
|
func (_m *Client) WriteJSON(_a0 interface{}) error {
|
||||||
|
ret := _m.Called(_a0)
|
||||||
|
|
||||||
|
var r0 error
|
||||||
|
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||||
|
r0 = rf(_a0)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user