This commit is contained in:
c9s 2020-06-08 10:42:50 +08:00
commit 3fe812264f
18 changed files with 1825 additions and 0 deletions

40
bbgo/event.go Normal file
View File

@ -0,0 +1,40 @@
package bbgo
/*
kline
{
"e": "kline", // KLineEvent type
"E": 123456789, // KLineEvent time
"s": "BNBBTC", // Symbol
"k": {
"t": 123400000, // Kline start time
"T": 123460000, // Kline close time
"s": "BNBBTC", // Symbol
"i": "1m", // Interval
"f": 100, // First trade ID
"L": 200, // Last trade ID
"o": "0.0010", // Open price
"c": "0.0020", // Close price
"h": "0.0025", // High price
"l": "0.0015", // Low price
"v": "1000", // Base asset volume
"n": 100, // Number of trades
"x": false, // Is this kline closed?
"q": "1.0000", // Quote asset volume
"V": "500", // Taker buy base asset volume
"Q": "0.500", // Taker buy quote asset volume
"B": "123456" // Ignore
}
}
*/
type EventBase struct {
Event string `json:"e"` // event
Time int64 `json:"E"`
}

137
bbgo/kline.go Normal file
View File

@ -0,0 +1,137 @@
package bbgo
import (
"fmt"
)
type KLineEvent struct {
EventBase
Symbol string `json:"s"`
KLine *KLine `json:"k,omitempty"`
}
type KLine struct {
StartTime int64 `json:"t"`
EndTime int64 `json:"T"`
Symbol string `json:"s"`
Interval string `json:"i"`
Open string `json:"o"`
Close string `json:"c"`
High string `json:"h"`
Low string `json:"l"`
Volume string `json:"V"` // taker buy base asset volume (like 10 BTC)
QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT)
LastTradeID int `json:"L"`
NumberOfTrades int `json:"n"`
Closed bool `json:"x"`
}
func (k KLine) GetTrend() int {
o := k.GetOpen()
c := k.GetClose()
if c > o {
return 1
} else if c < o {
return -1
}
return 0
}
func (k KLine) GetHigh() float64 {
return MustParseFloat(k.High)
}
func (k KLine) GetLow() float64 {
return MustParseFloat(k.Low)
}
func (k KLine) GetOpen() float64 {
return MustParseFloat(k.Open)
}
func (k KLine) GetClose() float64 {
return MustParseFloat(k.Close)
}
func (k KLine) GetMaxChange() float64 {
return k.GetHigh() - k.GetLow()
}
func (k KLine) GetThickness() float64 {
return k.GetChange() / k.GetMaxChange()
}
func (k KLine) GetChange() float64 {
return k.GetClose() - k.GetOpen()
}
func (k KLine) String() string {
return fmt.Sprintf("%s %s Open: % 14s Close: % 14s High: % 14s Low: % 14s Volume: % 13s Change: % 13f %s", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.Interval)
}
type KLineWindow []KLine
func (w KLineWindow) Len() int {
return len(w)
}
func (w KLineWindow) GetOpen() float64 {
return w[0].GetOpen()
}
func (w KLineWindow) GetClose() float64 {
end := len(w) - 1
return w[end].GetClose()
}
func (w KLineWindow) GetHigh() float64 {
high := w.GetOpen()
for _, line := range w {
val := line.GetHigh()
if val > high {
high = val
}
}
return high
}
func (w KLineWindow) GetLow() float64 {
low := w.GetOpen()
for _, line := range w {
val := line.GetHigh()
if val < low {
low = val
}
}
return low
}
func (w KLineWindow) GetChange() float64 {
return w.GetClose() - w.GetOpen()
}
func (w KLineWindow) GetMaxChange() float64 {
return w.GetHigh() - w.GetLow()
}
func (w *KLineWindow) Add(line KLine) {
*w = append(*w, line)
}
func (w *KLineWindow) Truncate(size int) {
if len(*w) <= size {
return
}
end := len(*w) - 1
start := end - size
if start < 0 {
start = 0
}
*w = (*w)[end-5 : end]
}

12
bbgo/parse.go Normal file
View File

@ -0,0 +1,12 @@
package bbgo
import "strconv"
func MustParseFloat(s string) float64 {
v, err := strconv.ParseFloat(s, 64)
if err != nil {
panic(err)
}
return v
}

196
bbgo/parser.go Normal file
View File

@ -0,0 +1,196 @@
package bbgo
import (
"encoding/json"
"fmt"
"github.com/valyala/fastjson"
)
/*
executionReport
{
"e": "executionReport", // KLineEvent type
"E": 1499405658658, // KLineEvent time
"s": "ETHBTC", // Symbol
"c": "mUvoqJxFIILMdfAW5iGSOW", // Client order ID
"S": "BUY", // Side
"o": "LIMIT", // Order type
"f": "GTC", // Time in force
"q": "1.00000000", // Order quantity
"p": "0.10264410", // Order price
"P": "0.00000000", // Stop price
"F": "0.00000000", // Iceberg quantity
"g": -1, // OrderListId
"C": null, // Original client order ID; This is the ID of the order being canceled
"x": "NEW", // Current execution type
"X": "NEW", // Current order status
"r": "NONE", // Order reject reason; will be an error code.
"i": 4293153, // Order ID
"l": "0.00000000", // Last executed quantity
"z": "0.00000000", // Cumulative filled quantity
"L": "0.00000000", // Last executed price
"n": "0", // Commission amount
"N": null, // Commission asset
"T": 1499405658657, // Transaction time
"t": -1, // Trade ID
"I": 8641984, // Ignore
"w": true, // Is the order on the book?
"m": false, // Is this trade the maker side?
"M": false, // Ignore
"O": 1499405658657, // Order creation time
"Z": "0.00000000", // Cumulative quote asset transacted quantity
"Y": "0.00000000", // Last quote asset transacted quantity (i.e. lastPrice * lastQty)
"Q": "0.00000000" // Quote Order Qty
}
*/
type ExecutionReportEvent struct {
EventBase
Symbol string `json:"s"`
ClientOrderID string `json:"c"`
Side string `json:"S"`
OrderType string `json:"o"`
TimeInForce string `json:"f"`
Quantity string `json:"q"`
Price string `json:"p"`
StopPrice string `json:"P"`
CurrentExecutionType string `json:"x"`
CurrentOrderStatus string `json:"X"`
OrderID int `json:"i"`
LastExecutedQuantity string `json:"l"`
CumulativeFilledQuantity string `json:"z"`
LastExecutedPrice string `json:"L"`
OrderCreationTime int `json:"O"`
}
/*
balanceUpdate
{
"e": "balanceUpdate", //KLineEvent Type
"E": 1573200697110, //KLineEvent Time
"a": "BTC", //Asset
"d": "100.00000000", //Balance Delta
"T": 1573200697068 //Clear Time
}
*/
type BalanceUpdateEvent struct {
EventBase
Asset string `json:"a"`
Delta string `json:"d"`
ClearTime int64 `json:"T"`
}
/*
outboundAccountInfo
{
"e": "outboundAccountInfo", // KLineEvent type
"E": 1499405658849, // KLineEvent time
"m": 0, // Maker commission rate (bips)
"t": 0, // Taker commission rate (bips)
"b": 0, // Buyer commission rate (bips)
"s": 0, // Seller commission rate (bips)
"T": true, // Can trade?
"W": true, // Can withdraw?
"D": true, // Can deposit?
"u": 1499405658848, // Time of last account update
"B": [ // Balances array
{
"a": "LTC", // Asset
"f": "17366.18538083", // Free amount
"l": "0.00000000" // Locked amount
},
{
"a": "BTC",
"f": "10537.85314051",
"l": "2.19464093"
},
{
"a": "ETH",
"f": "17902.35190619",
"l": "0.00000000"
},
{
"a": "BNC",
"f": "1114503.29769312",
"l": "0.00000000"
},
{
"a": "NEO",
"f": "0.00000000",
"l": "0.00000000"
}
],
"P": [ // Account Permissions
"SPOT"
]
}
*/
type Balance struct {
Asset string `json:"a"`
Free string `json:"f"`
Locked string `json:"l"`
}
type OutboundAccountInfoEvent struct {
EventBase
MakerCommissionRate int `json:"m"`
TakerCommissionRate int `json:"t"`
BuyerCommissionRate int `json:"b"`
SellerCommissionRate int `json:"s"`
CanTrade bool `json:"T"`
CanWithdraw bool `json:"W"`
CanDeposit bool `json:"D"`
LastAccountUpdateTime int `json:"u"`
Balances []Balance `json:"B,omitempty"`
Permissions []string `json:"P,omitempty"`
}
func ParseEvent(message string) (interface{}, error) {
val, err := fastjson.Parse(message)
if err != nil {
return nil, err
}
eventType := string(val.GetStringBytes("e"))
switch eventType {
case "kline":
var event KLineEvent
err := json.Unmarshal([]byte(message), &event)
return &event, err
case "outboundAccountInfo", "outboundAccountPosition":
var event OutboundAccountInfoEvent
err := json.Unmarshal([]byte(message), &event)
return &event, err
case "balanceUpdate":
var event BalanceUpdateEvent
err := json.Unmarshal([]byte(message), &event)
return &event, err
case "executionReport":
var event ExecutionReportEvent
err := json.Unmarshal([]byte(message), &event)
return &event, err
}
return nil, fmt.Errorf("unsupported message: %s", message)
}

66
bbgo/pnl.go Normal file
View File

@ -0,0 +1,66 @@
package bbgo
import log "github.com/sirupsen/logrus"
func CalculateAverageCost(trades []Trade) (averageCost float64) {
var totalCost = 0.0
var totalQuantity = 0.0
for _, t := range trades {
if t.IsBuyer {
totalCost += t.Price * t.Volume
totalQuantity += t.Volume
} else {
totalCost -= t.Price * t.Volume
totalQuantity -= t.Volume
}
}
averageCost = totalCost / totalQuantity
return
}
func CalculateCostAndProfit(trades []Trade, currentPrice float64) (averageBidPrice, stock, profit, fee float64) {
var bidVolume = 0.0
var bidAmount = 0.0
var bidFee = 0.0
for _, t := range trades {
if t.IsBuyer {
bidVolume += t.Volume
bidAmount += t.Price * t.Volume
switch t.FeeCurrency {
case "BTC":
bidFee += t.Price * t.Fee
}
}
}
log.Infof("average bid price = (total amount %f + total fee %f) / volume %f", bidAmount, bidFee, bidVolume)
averageBidPrice = (bidAmount + bidFee) / bidVolume
var feeRate = 0.001
var askVolume = 0.0
var askFee = 0.0
for _, t := range trades {
if !t.IsBuyer {
profit += (t.Price - averageBidPrice) * t.Volume
askVolume += t.Volume
switch t.FeeCurrency {
case "USDT":
askFee += t.Fee
}
}
}
profit -= askFee
stock = bidVolume - askVolume
futureFee := 0.0
if stock > 0 {
stockfee := currentPrice * feeRate * stock
profit += (currentPrice-averageBidPrice)*stock - stockfee
futureFee += stockfee
}
fee = bidFee + askFee + futureFee
return
}

92
bbgo/trade.go Normal file
View File

@ -0,0 +1,92 @@
package bbgo
import (
"context"
"github.com/adshao/go-binance"
log "github.com/sirupsen/logrus"
"strconv"
"time"
)
type Trade struct {
ID int64
Price float64
Volume float64
IsBuyer bool
IsMaker bool
Time time.Time
Fee float64
FeeCurrency string
}
func QueryTrades(ctx context.Context, client *binance.Client, market string, startTime time.Time) (trades []Trade, err error) {
var lastTradeID int64 = 0
for {
req := client.NewListTradesService().
Limit(1000).
Symbol(market).
StartTime(startTime.UnixNano() / 1000000)
if lastTradeID > 0 {
req.FromID(lastTradeID)
}
bnTrades, err := req.Do(ctx)
if err != nil {
return nil, err
}
if len(bnTrades) <= 1 {
break
}
for _, t := range bnTrades {
// skip trade ID that is the same
if t.ID == lastTradeID {
continue
}
var side string
if t.IsBuyer {
side = "BUY"
} else {
side = "SELL"
}
// trade time
tt := time.Unix(0, t.Time*1000000)
log.Infof("trade: %d %4s Price: % 13s Volume: % 13s %s", t.ID, side, t.Price, t.Quantity, tt)
price, err := strconv.ParseFloat(t.Price, 64)
if err != nil {
return nil, err
}
quantity, err := strconv.ParseFloat(t.Quantity, 64)
if err != nil {
return nil, err
}
fee, err := strconv.ParseFloat(t.Commission, 64)
if err != nil {
return nil, err
}
trades = append(trades, Trade{
ID: t.ID,
Price: price,
Volume: quantity,
IsBuyer: t.IsBuyer,
IsMaker: t.IsMaker,
Fee: fee,
FeeCurrency: t.CommissionAsset,
Time: tt,
})
lastTradeID = t.ID
}
}
return trades, nil
}

39
util/math.go Normal file
View File

@ -0,0 +1,39 @@
package util
import (
"math"
"strconv"
)
const MaxDigits = 18 // MAX_INT64 ~ 9 * 10^18
var Pow10Table = [MaxDigits + 1]int64{
1, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18,
}
func Pow10(n int64) int64 {
if n < 0 || n > MaxDigits {
return 0
}
return Pow10Table[n]
}
var NegPow10Table = [MaxDigits + 1]float64{
1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18,
}
func NegPow10(n int64) float64 {
if n < 0 || n > MaxDigits {
return 0.0
}
return NegPow10Table[n]
}
func Float64ToStr(input float64) string {
return strconv.FormatFloat(input, 'f', -1, 64)
}
func Float64ToInt64(input float64) int64 {
// eliminate rounding error for IEEE754 floating points
return int64(math.Round(input))
}

19
util/rate_limit.go Normal file
View File

@ -0,0 +1,19 @@
package util
import (
"fmt"
"time"
"golang.org/x/time/rate"
)
func ShouldDelay(l *rate.Limiter, minInterval time.Duration) time.Duration {
return l.Reserve().Delay()
}
func NewValidLimiter(r rate.Limit, b int) (*rate.Limiter, error) {
if b <= 0 || r <= 0 {
return nil, fmt.Errorf("Bad rate limit config. Insufficient tokens. (rate=%f, b=%d)", r, b)
}
return rate.NewLimiter(r, b), nil
}

43
util/rate_limit_test.go Normal file
View File

@ -0,0 +1,43 @@
package util
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
func TestNewValidRateLimiter(t *testing.T) {
cases := []struct {
name string
r rate.Limit
b int
hasError bool
}{
{"valid limiter", 0.1, 1, false},
{"zero rate", 0, 1, true},
{"zero burst", 0.1, 0, true},
{"both zero", 0, 0, true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
limiter, err := NewValidLimiter(c.r, c.b)
assert.Equal(t, c.hasError, err != nil)
if !c.hasError {
assert.NotNil(t, limiter)
}
})
}
}
func TestShouldDelay(t *testing.T) {
minInterval := time.Second * 3
maxRate := rate.Limit(1 / minInterval.Seconds())
limiter := rate.NewLimiter(maxRate, 1)
assert.Equal(t, time.Duration(0), ShouldDelay(limiter, minInterval))
for i := 0; i < 100; i++ {
assert.True(t, ShouldDelay(limiter, minInterval) > 0)
}
}

69
util/retry.go Normal file
View File

@ -0,0 +1,69 @@
package util
import (
"context"
"time"
"github.com/pkg/errors"
)
const (
InfiniteRetry = 0
)
type RetryPredicator func(e error) bool
// Retry retrys the passed functin for "attempts" times, if passed function return error. Setting attemps to zero means keep retrying.
func Retry(ctx context.Context, attempts int, duration time.Duration, fnToRetry func() error, errHandler func(error), predicators ...RetryPredicator) (err error) {
infinite := false
if attempts == InfiniteRetry {
infinite = true
}
for attempts > 0 || infinite {
select {
case <-ctx.Done():
errMsg := "return for context done"
if err != nil {
return errors.Wrap(err, errMsg)
} else {
return errors.New(errMsg)
}
default:
if err = fnToRetry(); err == nil {
return nil
}
if !needRetry(err, predicators) {
return err
}
err = errors.Wrapf(err, "failed in retry: countdown: %v", attempts)
if errHandler != nil {
errHandler(err)
}
if !infinite {
attempts--
}
time.Sleep(duration)
}
}
return err
}
func needRetry(err error, predicators []RetryPredicator) bool {
if err == nil {
return false
}
// If no predicators specified, we will retry for all errors
if len(predicators) == 0 {
return true
}
return predicators[0](err)
}

106
util/retry_test.go Normal file
View File

@ -0,0 +1,106 @@
package util
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
func addAndCheck(a *int, target int) error {
if *a++; *a == target {
return nil
} else {
return fmt.Errorf("a is not %v. It is %v\n", target, *a)
}
}
func TestRetry(t *testing.T) {
type test struct {
input int
targetNum int
ans int
ansErr error
}
tests := []test{
{input: 0, targetNum: 3, ans: 3, ansErr: nil},
{input: 0, targetNum: 10, ans: 3, ansErr: errors.New("failed in retry")},
}
for _, tc := range tests {
errHandled := false
err := Retry(context.Background(), 3, 1*time.Second, func() error {
return addAndCheck(&tc.input, tc.targetNum)
}, func(e error) { errHandled = true })
assert.Equal(t, true, errHandled)
if tc.ansErr == nil {
assert.NoError(t, err)
} else {
assert.Contains(t, err.Error(), tc.ansErr.Error())
}
assert.Equal(t, tc.ans, tc.input)
}
}
func TestRetryWithPredicator(t *testing.T) {
type test struct {
count int
f func() error
errHandler func(error)
predicator RetryPredicator
ansCount int
ansErr error
}
knownErr := errors.New("Duplicate entry '1-389837488-1' for key 'UNI_Trade'")
unknownErr := errors.New("Some Error")
tests := []test{
{
predicator: func(err error) bool {
return !strings.Contains(err.Error(), "Duplicate entry")
},
f: func() error { return knownErr },
ansCount: 1,
ansErr: knownErr,
},
{
predicator: func(err error) bool {
return !strings.Contains(err.Error(), "Duplicate entry")
},
f: func() error { return unknownErr },
ansCount: 3,
ansErr: unknownErr,
},
}
attempts := 3
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, tc := range tests {
err := Retry(ctx, attempts, 100*time.Millisecond, func() error {
tc.count++
return tc.f()
}, tc.errHandler, tc.predicator)
assert.Equal(t, tc.ansCount, tc.count)
assert.EqualError(t, errors.Cause(err), tc.ansErr.Error(), "should be equal")
}
}
func TestRetryCtxCancel(t *testing.T) {
result := int(0)
target := int(3)
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := Retry(ctx, 5, 1*time.Second, func() error { return addAndCheck(&result, target) }, func(error) {})
assert.Error(t, err)
fmt.Println("Error:", err.Error())
assert.Equal(t, int(0), result)
}

78
websocket/backoff.go Normal file
View File

@ -0,0 +1,78 @@
package websocket
import (
"time"
)
const (
DefaultMinBackoff = 5 * time.Second
DefaultMaxBackoff = time.Minute
DefaultBackoffFactor = 2
)
type Backoff struct {
attempt int64
Factor float64
cur, Min, Max time.Duration
}
// Duration returns the duration for the current attempt before incrementing
// the attempt counter. See ForAttempt.
func (b *Backoff) Duration() time.Duration {
d := b.calculate(b.attempt)
b.attempt++
return d
}
func (b *Backoff) calculate(attempt int64) time.Duration {
min := b.Min
if min <= 0 {
min = DefaultMinBackoff
}
max := b.Max
if max <= 0 {
max = DefaultMaxBackoff
}
if min >= max {
return max
}
factor := b.Factor
if factor <= 0 {
factor = DefaultBackoffFactor
}
cur := b.cur
if cur < min {
cur = min
} else if cur > max {
cur = max
}
//calculate this duration
next := cur
if attempt > 0 {
next = time.Duration(float64(cur) * factor)
}
if next < cur {
// overflow
next = max
} else if next <= min {
next = min
} else if next >= max {
next = max
}
b.cur = next
return next
}
// Reset restarts the current attempt counter at zero.
func (b *Backoff) Reset() {
b.attempt = 0
b.cur = b.Min
}
// Attempt returns the current attempt counter value.
func (b *Backoff) Attempt() int64 {
return b.attempt
}

172
websocket/backoff_test.go Normal file
View File

@ -0,0 +1,172 @@
package websocket
import (
"math"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestBackoff(t *testing.T) {
b := &Backoff{
Min: 200 * time.Millisecond,
Max: 100 * time.Second,
Factor: 2,
}
equals(t, b.Duration(), 200*time.Millisecond)
equals(t, b.Duration(), 400*time.Millisecond)
equals(t, b.Duration(), 800*time.Millisecond)
b.Reset()
equals(t, b.Duration(), 200*time.Millisecond)
}
func TestReachMaxBackoff(t *testing.T) {
b := &Backoff{
Min: 2 * time.Second,
Max: 100 * time.Second,
Factor: 10,
}
equals(t, b.Duration(), 2*time.Second)
equals(t, b.Duration(), 20*time.Second)
equals(t, b.Duration(), b.Max)
b.Reset()
equals(t, b.Duration(), 2*time.Second)
}
func TestAttempt(t *testing.T) {
b := &Backoff{
Min: 2 * time.Second,
Max: 100 * time.Second,
Factor: 10,
}
equals(t, b.Attempt(), int64(0))
equals(t, b.Duration(), 2*time.Second)
equals(t, b.Attempt(), int64(1))
equals(t, b.Duration(), 20*time.Second)
equals(t, b.Attempt(), int64(2))
equals(t, b.Duration(), b.Max)
equals(t, b.Attempt(), int64(3))
b.Reset()
equals(t, b.Attempt(), int64(0))
equals(t, b.Duration(), 2*time.Second)
equals(t, b.Attempt(), int64(1))
}
func TestAttemptWithZeroValue(t *testing.T) {
b := &Backoff{}
cur := DefaultMinBackoff
equals(t, b.Attempt(), int64(0))
equals(t, b.Duration(), cur)
for i := 1; i < 10; i++ {
equals(t, b.Attempt(), int64(i))
cur *= DefaultBackoffFactor
if cur >= DefaultMaxBackoff {
equals(t, b.Duration(), DefaultMaxBackoff)
break
}
equals(t, b.Duration(), cur)
}
b.Reset()
equals(t, b.Attempt(), int64(0))
equals(t, b.Duration(), DefaultMinBackoff)
equals(t, b.Attempt(), int64(1))
}
func TestNegOrZeroMin(t *testing.T) {
b := &Backoff{
Min: 0,
}
equals(t, b.Duration(), DefaultMinBackoff)
b2 := &Backoff{
Min: -1 * time.Second,
}
equals(t, b2.Duration(), DefaultMinBackoff)
}
func TestNegOrZeroMax(t *testing.T) {
b := &Backoff{
Min: DefaultMaxBackoff,
Max: 0,
}
equals(t, b.Duration(), DefaultMaxBackoff)
b2 := &Backoff{
Min: DefaultMaxBackoff,
Max: -1 * time.Second,
}
equals(t, b2.Duration(), DefaultMaxBackoff)
}
func TestMinLargerThanMax(t *testing.T) {
b := &Backoff{
Min: 500 * time.Second,
Max: 100 * time.Second,
Factor: 1,
}
equals(t, b.Duration(), b.Max)
}
func TestFakeCur(t *testing.T) {
b := &Backoff{
Min: 100 * time.Second,
Max: 500 * time.Second,
Factor: 1,
cur: 1000 * time.Second,
}
equals(t, b.Duration(), b.Max)
}
func TestLargeFactor(t *testing.T) {
b := &Backoff{
Min: 1 * time.Second,
Max: 10 * time.Second,
Factor: math.MaxFloat64,
}
equals(t, b.Duration(), b.Min)
for i := 0; i < 100; i++ {
equals(t, b.Duration(), b.Max)
}
b.Reset()
equals(t, b.Duration(), b.Min)
}
func TestLargeMinMax(t *testing.T) {
b := &Backoff{
Min: time.Duration(math.MaxInt64 - 1),
Max: time.Duration(math.MaxInt64),
Factor: math.MaxFloat64,
}
equals(t, b.Duration(), b.Min)
for i := 0; i < 100; i++ {
equals(t, b.Duration(), b.Max)
}
b.Reset()
equals(t, b.Duration(), b.Min)
}
func TestManyAttempts(t *testing.T) {
b := &Backoff{
Min: 10 * time.Second,
Max: time.Duration(math.MaxInt64),
Factor: 1000,
}
for i := 0; i < 10000; i++ {
b.Duration()
}
equals(t, b.Duration(), b.Max)
b.Reset()
}
func equals(t *testing.T, v1, v2 interface{}) {
if !reflect.DeepEqual(v1, v2) {
assert.Failf(t, "not equal", "Got %v, Expecting %v", v1, v2)
}
}

368
websocket/client.go Normal file
View File

@ -0,0 +1,368 @@
package websocket
import (
"context"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/c9s/bbgo/pkg/log"
"github.com/c9s/bbgo/pkg/util"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
const DefaultMessageBufferSize = 128
const DefaultWriteTimeout = 30 * time.Second
const DefaultReadTimeout = 30 * time.Second
var ErrReconnectContextDone = errors.New("reconnect canceled due to context done.")
var ErrReconnectFailed = errors.New("failed to reconnect")
var ErrConnectionLost = errors.New("connection lost")
var MaxReconnectRate = rate.Limit(1 / DefaultMinBackoff.Seconds())
// WebSocketClient allows to connect and receive stream data
type WebSocketClient struct {
// Url is the websocket connection location, start with ws:// or wss://
Url string
// conn is the current websocket connection, please note the connection
// object can be replaced with a new connection object when the connection
// is unexpected closed.
conn *websocket.Conn
// Dialer is used for creating the websocket connection
Dialer *websocket.Dialer
// requestHeader is used for the Dial function call. Some credential can be
// stored in the http request header for authentication
requestHeader http.Header
// messages is a read-only channel, received messages will be sent to this
// channel.
messages chan Message
readTimeout time.Duration
writeTimeout time.Duration
onConnect []func(c Client)
onDisconnect []func(c Client)
// cancel is mapped to the ctx context object
cancel func()
readerClosed chan struct{}
connected bool
mu sync.Mutex
reconnectCh chan struct{}
backoff Backoff
limiter *rate.Limiter
}
type Message struct {
// websocket.BinaryMessage or websocket.TextMessage
Type int
Body []byte
}
func (c *WebSocketClient) Messages() <-chan Message {
return c.messages
}
func (c *WebSocketClient) SetReadTimeout(timeout time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.readTimeout = timeout
}
func (c *WebSocketClient) SetWriteTimeout(timeout time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.writeTimeout = timeout
}
func (c *WebSocketClient) OnConnect(f func(c Client)) {
c.mu.Lock()
defer c.mu.Unlock()
c.onConnect = append(c.onConnect, f)
}
func (c *WebSocketClient) OnDisconnect(f func(c Client)) {
c.mu.Lock()
defer c.mu.Unlock()
c.onDisconnect = append(c.onDisconnect, f)
}
func (c *WebSocketClient) WriteTextMessage(message []byte) error {
return c.WriteMessage(websocket.TextMessage, message)
}
func (c *WebSocketClient) WriteBinaryMessage(message []byte) error {
return c.WriteMessage(websocket.BinaryMessage, message)
}
func (c *WebSocketClient) WriteMessage(messageType int, data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return ErrConnectionLost
}
if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
return err
}
return c.conn.WriteMessage(messageType, data)
}
func (c *WebSocketClient) readMessages() error {
c.mu.Lock()
if !c.connected {
c.mu.Unlock()
return ErrConnectionLost
}
timeout := c.readTimeout
conn := c.conn
c.mu.Unlock()
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return err
}
msgtype, message, err := conn.ReadMessage()
if err != nil {
return err
}
c.messages <- Message{msgtype, message}
return nil
}
// listen starts a goroutine for reading message and tries to re-connect to the
// server when the reader returns error
//
// Please note we should always break the reader loop if there is any error
// returned from the server.
func (c *WebSocketClient) listen(ctx context.Context) {
// The life time of both channels "readerClosed" and "reconnectCh" is bound to one connection.
// Each channel should be created before loop starts and be closed after loop ends.
// "readerClosed" is used to inform "Close()" reader loop ends.
// "reconnectCh" is used to centralize reconnection logics in this reader loop.
c.mu.Lock()
c.readerClosed = make(chan struct{})
c.reconnectCh = make(chan struct{}, 1)
c.mu.Unlock()
defer func() {
c.mu.Lock()
close(c.readerClosed)
close(c.reconnectCh)
c.reconnectCh = nil
c.mu.Unlock()
}()
for {
select {
case <-ctx.Done():
return
case <-c.reconnectCh:
// it could be i/o timeout for network disconnection
// or it could be invoked from outside.
c.SetDisconnected()
var maxTries = 1
if _, response, err := c.reconnect(ctx, maxTries); err != nil {
if err == ErrReconnectContextDone {
log.Debugf("[websocket] context canceled. stop reconnecting.")
return
}
log.Warnf("[websocket] failed to reconnect after %d tries!! error: %v response: %v", maxTries, err, response)
c.Reconnect()
}
default:
if err := c.readMessages(); err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
log.Warnf("[websocket] unexpected close error reconnecting: %v", err)
}
log.Warnf("[websocket] failed to read message. error: %+v", err)
c.Reconnect()
}
}
}
}
// Reconnect triggers reconnection logics
func (c *WebSocketClient) Reconnect() {
c.mu.Lock()
defer c.mu.Unlock()
select {
// c.reconnectCh is a buffered channel with cap=1.
// At most one reconnect signal could be processed.
case c.reconnectCh <- struct{}{}:
default:
// Entering here means it is already reconnecting.
// Drop the current reconnect signal.
}
}
// Close gracefully shuts down the reader and the connection
// ctx is the context used for shutdown process.
func (c *WebSocketClient) Close() (err error) {
c.mu.Lock()
// leave the listen goroutine before we close the connection
// checking nil is to handle calling "Close" before "Connect" is called
if c.cancel != nil {
c.cancel()
}
c.mu.Unlock()
c.SetDisconnected()
// wait for the reader func to be closed
if c.readerClosed != nil {
<-c.readerClosed
}
return err
}
// reconnect tries to create a new connection from the existing dialer
func (c *WebSocketClient) reconnect(ctx context.Context, maxTries int) (*websocket.Conn, *http.Response, error) {
log.Debugf("[websocket] start reconnecting to %q", c.Url)
select {
case <-ctx.Done():
return nil, nil, ErrReconnectContextDone
default:
}
if s := util.ShouldDelay(c.limiter, DefaultMinBackoff); s > 0 {
log.Warn("[websocket] reconnect too frequently. Sleep for ", s)
time.Sleep(s)
}
log.Warnf("[websocket] reconnecting x %d to %q", c.backoff.Attempt()+1, c.Url)
conn, resp, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
if err != nil {
dur := c.backoff.Duration()
log.Warnf("failed to dial %s: %v, response: %+v. Wait for %v", c.Url, err, resp, dur)
time.Sleep(dur)
return nil, nil, ErrReconnectFailed
}
log.Infof("[websocket] reconnected to %q", c.Url)
// Reset backoff value if connected.
c.backoff.Reset()
c.setConn(conn)
c.setPingHandler(conn)
return conn, resp, err
}
// Conn returns the current active connection instance
func (c *WebSocketClient) Conn() (conn *websocket.Conn) {
c.mu.Lock()
conn = c.conn
c.mu.Unlock()
return conn
}
func (c *WebSocketClient) setConn(conn *websocket.Conn) {
// Disconnect old connection before replacing with new one.
c.SetDisconnected()
c.mu.Lock()
c.conn = conn
c.connected = true
c.mu.Unlock()
for _, f := range c.onConnect {
go f(c)
}
}
func (c *WebSocketClient) setPingHandler(conn *websocket.Conn) {
conn.SetPingHandler(func(message string) error {
if err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
return err
}
c.mu.Lock()
defer c.mu.Unlock()
return conn.SetReadDeadline(time.Now().Add(c.readTimeout))
})
}
func (c *WebSocketClient) SetDisconnected() {
c.mu.Lock()
closed := false
if c.conn != nil {
closed = true
c.conn.Close()
}
c.connected = false
c.conn = nil
c.mu.Unlock()
if closed {
// Only call disconnect callbacks when a connection is closed
for _, f := range c.onDisconnect {
go f(c)
}
}
}
func (c *WebSocketClient) IsConnected() (ret bool) {
c.mu.Lock()
ret = c.connected
c.mu.Unlock()
return ret
}
func (c *WebSocketClient) Connect(basectx context.Context) error {
// maintain a context by the client it self, so that we can manually shutdown the connection
ctx, cancel := context.WithCancel(basectx)
c.cancel = cancel
conn, _, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader)
if err == nil {
// setup connection only when connected
c.setConn(conn)
c.setPingHandler(conn)
}
// 1) if connection is built up, start listening for messages.
// 2) if connection is NOT ready, start reconnecting infinitely.
go c.listen(ctx)
return err
}
func New(url string, requestHeader http.Header) *WebSocketClient {
return NewWithDialer(url, websocket.DefaultDialer, requestHeader)
}
func NewWithDialer(url string, d *websocket.Dialer, requestHeader http.Header) *WebSocketClient {
limiter, err := util.NewValidLimiter(MaxReconnectRate, 1)
if err != nil {
log.WithError(err).Panic("Invalid rate limiter")
}
return &WebSocketClient{
Url: url,
Dialer: d,
requestHeader: requestHeader,
readTimeout: DefaultReadTimeout,
writeTimeout: DefaultWriteTimeout,
messages: make(chan Message, DefaultMessageBufferSize),
limiter: limiter,
}
}

202
websocket/client_test.go Normal file
View File

@ -0,0 +1,202 @@
package websocket
import (
"context"
"net"
"net/http"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"***REMOVED***/pkg/log"
"***REMOVED***/pkg/testing/testutil"
"github.com/gorilla/websocket"
)
func messageTicker(t *testing.T, ctx context.Context, conn *websocket.Conn) {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case tt := <-ticker.C:
err := conn.WriteMessage(websocket.TextMessage, []byte(tt.String()))
assert.NoError(t, err)
case <-ctx.Done():
return
}
}
}
func messageReader(t *testing.T, ctx context.Context, conn *websocket.Conn) {
for {
_, message, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Errorf("unexpected closed error: %v", err)
}
break
}
t.Logf("message: %v", message)
select {
case <-ctx.Done():
return
default:
}
}
}
func TestReconnect(t *testing.T) {
basectx := context.Background()
wsHandler := func(ctx context.Context) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
go messageTicker(t, ctx, conn)
messageReader(t, ctx, conn)
}
}
serverctx, cancelserver := context.WithCancel(basectx)
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", wsHandler(serverctx))
go func() {
err := server.ListenAndServe()
assert.Equal(t, http.ErrServerClosed, err)
}()
_, port, err := net.SplitHostPort(server.Addr)
assert.NoError(t, err)
log.Debugf("waiting for server port ready")
testutil.WaitForPort(t, server.Addr, 3)
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
log.Infof("url: %s", u.String())
client := New(u.String(), nil)
// start the message reader
clientctx, cancelClient := context.WithCancel(basectx)
defer cancelClient()
err = client.Connect(clientctx)
assert.NoError(t, err)
assert.NotNil(t, client)
client.SetReadTimeout(2 * time.Second)
// read one message
log.Debugf("waiting for message...")
msg := <-client.messages
log.Debugf("received message: %v", msg)
triggerReconnectManyTimes := func() {
// Forcedly reconnect multiple times
for i := 0; i < 10; i = i + 1 {
client.Reconnect()
}
}
log.Debugf("reconnecting...")
triggerReconnectManyTimes()
msg = <-client.messages
log.Debugf("received message: %v", msg)
log.Debugf("shutting down server...")
err = server.Shutdown(serverctx)
assert.NoError(t, err)
cancelserver()
server.Close()
time.Sleep(500 * time.Millisecond)
triggerReconnectManyTimes()
log.Debugf("restarting server...")
server2ctx, cancelserver2 := context.WithCancel(basectx)
server2 := testutil.NewWebSocketServerWithAddressFunc(t, server.Addr, "/ws", wsHandler(server2ctx))
go func() {
err := server2.ListenAndServe()
assert.Equal(t, http.ErrServerClosed, err)
}()
triggerReconnectManyTimes()
log.Debugf("waiting for server2 port ready")
testutil.WaitForPort(t, server.Addr, 3)
log.Debugf("waiting for message from server2...")
msg = <-client.messages
log.Debugf("received message: %v", msg)
err = client.Close()
assert.NoError(t, err)
err = server2.Shutdown(server2ctx)
assert.NoError(t, err)
cancelserver2()
}
func TestConnect(t *testing.T) {
basectx := context.Background()
ctx, cancel := context.WithCancel(basectx)
server := testutil.NewWebSocketServerFunc(t, 0, "/ws", func(w http.ResponseWriter, r *http.Request) {
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
go messageTicker(t, ctx, conn)
messageReader(t, ctx, conn)
})
go func() {
err := server.ListenAndServe()
assert.Equal(t, http.ErrServerClosed, err)
}()
_, port, err := net.SplitHostPort(server.Addr)
assert.NoError(t, err)
testutil.WaitForPort(t, server.Addr, 3)
u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"}
t.Logf("url: %s", u.String())
client := New(u.String(), nil)
err = client.Connect(ctx)
assert.NoError(t, err)
assert.NotNil(t, client)
client.SetReadTimeout(2 * time.Second)
// read one message
t.Logf("waiting for message...")
msg := <-client.messages
t.Logf("received message: %v", msg)
// recreate server at the same address
client.Close()
t.Logf("shutting down server...")
cancel()
err = server.Shutdown(basectx)
assert.NoError(t, err)
}

21
websocket/interface.go Normal file
View File

@ -0,0 +1,21 @@
package websocket
import (
"context"
"time"
)
//go:generate mockery -name=Client
type Client interface {
SetWriteTimeout(time.Duration)
SetReadTimeout(time.Duration)
OnConnect(func(c Client))
OnDisconnect(func(c Client))
Connect(context.Context) error
Reconnect()
Close() error
IsConnected() bool
WriteJSON(interface{}) error
Messages() <-chan Message
}

55
websocket/json.go Normal file
View File

@ -0,0 +1,55 @@
package websocket
import (
"encoding/json"
"io"
"github.com/gorilla/websocket"
)
// WriteJSON writes the JSON encoding of v as a message.
//
// See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON.
func (c *WebSocketClient) WriteJSON(v interface{}) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return ErrConnectionLost
}
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return err
}
err1 := json.NewEncoder(w).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value.
func (c *WebSocketClient) ReadJSON(v interface{}) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return ErrConnectionLost
}
_, r, err := c.conn.NextReader()
if err != nil {
return err
}
err = json.NewDecoder(r).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}

110
websocket/mocks/Client.go Normal file
View File

@ -0,0 +1,110 @@
// Code generated by mockery v1.0.0. DO NOT EDIT.
package mocks
import context "context"
import mock "github.com/stretchr/testify/mock"
import time "time"
import websocket "***REMOVED***/pkg/net/http/websocket"
// Client is an autogenerated mock type for the Client type
type Client struct {
mock.Mock
}
// Close provides a mock function with given fields:
func (_m *Client) Close() error {
ret := _m.Called()
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// Connect provides a mock function with given fields: _a0
func (_m *Client) Connect(_a0 context.Context) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}
// IsConnected provides a mock function with given fields:
func (_m *Client) IsConnected() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// Messages provides a mock function with given fields:
func (_m *Client) Messages() <-chan websocket.Message {
ret := _m.Called()
var r0 <-chan websocket.Message
if rf, ok := ret.Get(0).(func() <-chan websocket.Message); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan websocket.Message)
}
}
return r0
}
// OnConnect provides a mock function with given fields: _a0
func (_m *Client) OnConnect(_a0 func(websocket.Client)) {
_m.Called(_a0)
}
// OnDisconnect provides a mock function with given fields: _a0
func (_m *Client) OnDisconnect(_a0 func(websocket.Client)) {
_m.Called(_a0)
}
// Reconnect provides a mock function with given fields:
func (_m *Client) Reconnect() {
_m.Called()
}
// SetReadTimeout provides a mock function with given fields: _a0
func (_m *Client) SetReadTimeout(_a0 time.Duration) {
_m.Called(_a0)
}
// SetWriteTimeout provides a mock function with given fields: _a0
func (_m *Client) SetWriteTimeout(_a0 time.Duration) {
_m.Called(_a0)
}
// WriteJSON provides a mock function with given fields: _a0
func (_m *Client) WriteJSON(_a0 interface{}) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}