commit 3fe812264f65605a3ae2a9e035252f08d274d764 Author: c9s Date: Mon Jun 8 10:42:50 2020 +0800 update diff --git a/bbgo/event.go b/bbgo/event.go new file mode 100644 index 000000000..28ab3b36b --- /dev/null +++ b/bbgo/event.go @@ -0,0 +1,40 @@ +package bbgo + +/* + +kline + +{ + "e": "kline", // KLineEvent type + "E": 123456789, // KLineEvent time + "s": "BNBBTC", // Symbol + "k": { + "t": 123400000, // Kline start time + "T": 123460000, // Kline close time + "s": "BNBBTC", // Symbol + "i": "1m", // Interval + "f": 100, // First trade ID + "L": 200, // Last trade ID + "o": "0.0010", // Open price + "c": "0.0020", // Close price + "h": "0.0025", // High price + "l": "0.0015", // Low price + "v": "1000", // Base asset volume + "n": 100, // Number of trades + "x": false, // Is this kline closed? + "q": "1.0000", // Quote asset volume + "V": "500", // Taker buy base asset volume + "Q": "0.500", // Taker buy quote asset volume + "B": "123456" // Ignore + } +} + + + + +*/ +type EventBase struct { + Event string `json:"e"` // event + Time int64 `json:"E"` +} + diff --git a/bbgo/kline.go b/bbgo/kline.go new file mode 100644 index 000000000..04b321d64 --- /dev/null +++ b/bbgo/kline.go @@ -0,0 +1,137 @@ +package bbgo + +import ( + "fmt" +) + +type KLineEvent struct { + EventBase + Symbol string `json:"s"` + KLine *KLine `json:"k,omitempty"` +} + +type KLine struct { + StartTime int64 `json:"t"` + EndTime int64 `json:"T"` + + Symbol string `json:"s"` + Interval string `json:"i"` + + Open string `json:"o"` + Close string `json:"c"` + High string `json:"h"` + Low string `json:"l"` + Volume string `json:"V"` // taker buy base asset volume (like 10 BTC) + QuoteVolume string `json:"Q"` // taker buy quote asset volume (like 1000USDT) + + LastTradeID int `json:"L"` + NumberOfTrades int `json:"n"` + Closed bool `json:"x"` +} + +func (k KLine) GetTrend() int { + o := k.GetOpen() + c := k.GetClose() + + if c > o { + return 1 + } else if c < o { + return -1 + } + return 0 +} + +func (k KLine) GetHigh() float64 { + return MustParseFloat(k.High) +} + +func (k KLine) GetLow() float64 { + return MustParseFloat(k.Low) +} + +func (k KLine) GetOpen() float64 { + return MustParseFloat(k.Open) +} + +func (k KLine) GetClose() float64 { + return MustParseFloat(k.Close) +} + +func (k KLine) GetMaxChange() float64 { + return k.GetHigh() - k.GetLow() +} + +func (k KLine) GetThickness() float64 { + return k.GetChange() / k.GetMaxChange() +} + +func (k KLine) GetChange() float64 { + return k.GetClose() - k.GetOpen() +} + +func (k KLine) String() string { + return fmt.Sprintf("%s %s Open: % 14s Close: % 14s High: % 14s Low: % 14s Volume: % 13s Change: % 13f %s", k.Symbol, k.Interval, k.Open, k.Close, k.High, k.Low, k.Volume, k.GetChange(), k.Interval) +} + +type KLineWindow []KLine + +func (w KLineWindow) Len() int { + return len(w) +} + +func (w KLineWindow) GetOpen() float64 { + return w[0].GetOpen() +} + +func (w KLineWindow) GetClose() float64 { + end := len(w) - 1 + return w[end].GetClose() +} + +func (w KLineWindow) GetHigh() float64 { + high := w.GetOpen() + for _, line := range w { + val := line.GetHigh() + if val > high { + high = val + } + } + return high +} + +func (w KLineWindow) GetLow() float64 { + low := w.GetOpen() + for _, line := range w { + val := line.GetHigh() + if val < low { + low = val + } + } + return low +} + +func (w KLineWindow) GetChange() float64 { + return w.GetClose() - w.GetOpen() +} + +func (w KLineWindow) GetMaxChange() float64 { + return w.GetHigh() - w.GetLow() +} + +func (w *KLineWindow) Add(line KLine) { + *w = append(*w, line) +} + +func (w *KLineWindow) Truncate(size int) { + if len(*w) <= size { + return + } + + end := len(*w) - 1 + start := end - size + if start < 0 { + start = 0 + } + *w = (*w)[end-5 : end] +} + diff --git a/bbgo/parse.go b/bbgo/parse.go new file mode 100644 index 000000000..a9e080bca --- /dev/null +++ b/bbgo/parse.go @@ -0,0 +1,12 @@ +package bbgo + +import "strconv" + +func MustParseFloat(s string) float64 { + v, err := strconv.ParseFloat(s, 64) + if err != nil { + panic(err) + } + return v +} + diff --git a/bbgo/parser.go b/bbgo/parser.go new file mode 100644 index 000000000..5e409dc8d --- /dev/null +++ b/bbgo/parser.go @@ -0,0 +1,196 @@ +package bbgo + +import ( + "encoding/json" + "fmt" + "github.com/valyala/fastjson" +) + +/* + +executionReport + +{ + "e": "executionReport", // KLineEvent type + "E": 1499405658658, // KLineEvent time + "s": "ETHBTC", // Symbol + "c": "mUvoqJxFIILMdfAW5iGSOW", // Client order ID + "S": "BUY", // Side + "o": "LIMIT", // Order type + "f": "GTC", // Time in force + "q": "1.00000000", // Order quantity + "p": "0.10264410", // Order price + "P": "0.00000000", // Stop price + "F": "0.00000000", // Iceberg quantity + "g": -1, // OrderListId + "C": null, // Original client order ID; This is the ID of the order being canceled + "x": "NEW", // Current execution type + "X": "NEW", // Current order status + "r": "NONE", // Order reject reason; will be an error code. + "i": 4293153, // Order ID + "l": "0.00000000", // Last executed quantity + "z": "0.00000000", // Cumulative filled quantity + "L": "0.00000000", // Last executed price + "n": "0", // Commission amount + "N": null, // Commission asset + "T": 1499405658657, // Transaction time + "t": -1, // Trade ID + "I": 8641984, // Ignore + "w": true, // Is the order on the book? + "m": false, // Is this trade the maker side? + "M": false, // Ignore + "O": 1499405658657, // Order creation time + "Z": "0.00000000", // Cumulative quote asset transacted quantity + "Y": "0.00000000", // Last quote asset transacted quantity (i.e. lastPrice * lastQty) + "Q": "0.00000000" // Quote Order Qty +} +*/ +type ExecutionReportEvent struct { + EventBase + + Symbol string `json:"s"` + ClientOrderID string `json:"c"` + Side string `json:"S"` + OrderType string `json:"o"` + TimeInForce string `json:"f"` + + Quantity string `json:"q"` + Price string `json:"p"` + StopPrice string `json:"P"` + + CurrentExecutionType string `json:"x"` + CurrentOrderStatus string `json:"X"` + + OrderID int `json:"i"` + + LastExecutedQuantity string `json:"l"` + CumulativeFilledQuantity string `json:"z"` + LastExecutedPrice string `json:"L"` + + OrderCreationTime int `json:"O"` +} + +/* +balanceUpdate + +{ + "e": "balanceUpdate", //KLineEvent Type + "E": 1573200697110, //KLineEvent Time + "a": "BTC", //Asset + "d": "100.00000000", //Balance Delta + "T": 1573200697068 //Clear Time +} +*/ +type BalanceUpdateEvent struct { + EventBase + + Asset string `json:"a"` + Delta string `json:"d"` + ClearTime int64 `json:"T"` +} + +/* + +outboundAccountInfo + +{ + "e": "outboundAccountInfo", // KLineEvent type + "E": 1499405658849, // KLineEvent time + "m": 0, // Maker commission rate (bips) + "t": 0, // Taker commission rate (bips) + "b": 0, // Buyer commission rate (bips) + "s": 0, // Seller commission rate (bips) + "T": true, // Can trade? + "W": true, // Can withdraw? + "D": true, // Can deposit? + "u": 1499405658848, // Time of last account update + "B": [ // Balances array + { + "a": "LTC", // Asset + "f": "17366.18538083", // Free amount + "l": "0.00000000" // Locked amount + }, + { + "a": "BTC", + "f": "10537.85314051", + "l": "2.19464093" + }, + { + "a": "ETH", + "f": "17902.35190619", + "l": "0.00000000" + }, + { + "a": "BNC", + "f": "1114503.29769312", + "l": "0.00000000" + }, + { + "a": "NEO", + "f": "0.00000000", + "l": "0.00000000" + } + ], + "P": [ // Account Permissions + "SPOT" + ] +} + +*/ +type Balance struct { + Asset string `json:"a"` + Free string `json:"f"` + Locked string `json:"l"` +} +type OutboundAccountInfoEvent struct { + EventBase + + MakerCommissionRate int `json:"m"` + TakerCommissionRate int `json:"t"` + BuyerCommissionRate int `json:"b"` + SellerCommissionRate int `json:"s"` + + CanTrade bool `json:"T"` + CanWithdraw bool `json:"W"` + CanDeposit bool `json:"D"` + + LastAccountUpdateTime int `json:"u"` + + Balances []Balance `json:"B,omitempty"` + Permissions []string `json:"P,omitempty"` +} + +func ParseEvent(message string) (interface{}, error) { + val, err := fastjson.Parse(message) + if err != nil { + return nil, err + } + + eventType := string(val.GetStringBytes("e")) + + switch eventType { + case "kline": + var event KLineEvent + err := json.Unmarshal([]byte(message), &event) + return &event, err + + case "outboundAccountInfo", "outboundAccountPosition": + var event OutboundAccountInfoEvent + err := json.Unmarshal([]byte(message), &event) + return &event, err + + case "balanceUpdate": + var event BalanceUpdateEvent + err := json.Unmarshal([]byte(message), &event) + return &event, err + + case "executionReport": + var event ExecutionReportEvent + err := json.Unmarshal([]byte(message), &event) + return &event, err + + } + + return nil, fmt.Errorf("unsupported message: %s", message) +} + diff --git a/bbgo/pnl.go b/bbgo/pnl.go new file mode 100644 index 000000000..900f72dee --- /dev/null +++ b/bbgo/pnl.go @@ -0,0 +1,66 @@ +package bbgo + +import log "github.com/sirupsen/logrus" + +func CalculateAverageCost(trades []Trade) (averageCost float64) { + var totalCost = 0.0 + var totalQuantity = 0.0 + for _, t := range trades { + if t.IsBuyer { + totalCost += t.Price * t.Volume + totalQuantity += t.Volume + } else { + totalCost -= t.Price * t.Volume + totalQuantity -= t.Volume + } + } + + averageCost = totalCost / totalQuantity + return +} + +func CalculateCostAndProfit(trades []Trade, currentPrice float64) (averageBidPrice, stock, profit, fee float64) { + var bidVolume = 0.0 + var bidAmount = 0.0 + var bidFee = 0.0 + for _, t := range trades { + if t.IsBuyer { + bidVolume += t.Volume + bidAmount += t.Price * t.Volume + switch t.FeeCurrency { + case "BTC": + bidFee += t.Price * t.Fee + } + } + } + + log.Infof("average bid price = (total amount %f + total fee %f) / volume %f", bidAmount, bidFee, bidVolume) + averageBidPrice = (bidAmount + bidFee) / bidVolume + + var feeRate = 0.001 + var askVolume = 0.0 + var askFee = 0.0 + for _, t := range trades { + if !t.IsBuyer { + profit += (t.Price - averageBidPrice) * t.Volume + askVolume += t.Volume + switch t.FeeCurrency { + case "USDT": + askFee += t.Fee + } + } + } + + profit -= askFee + + stock = bidVolume - askVolume + futureFee := 0.0 + if stock > 0 { + stockfee := currentPrice * feeRate * stock + profit += (currentPrice-averageBidPrice)*stock - stockfee + futureFee += stockfee + } + + fee = bidFee + askFee + futureFee + return +} diff --git a/bbgo/trade.go b/bbgo/trade.go new file mode 100644 index 000000000..a31570b25 --- /dev/null +++ b/bbgo/trade.go @@ -0,0 +1,92 @@ +package bbgo + +import ( + "context" + "github.com/adshao/go-binance" + log "github.com/sirupsen/logrus" + "strconv" + "time" +) + +type Trade struct { + ID int64 + Price float64 + Volume float64 + IsBuyer bool + IsMaker bool + Time time.Time + + Fee float64 + FeeCurrency string +} + +func QueryTrades(ctx context.Context, client *binance.Client, market string, startTime time.Time) (trades []Trade, err error) { + var lastTradeID int64 = 0 + for { + req := client.NewListTradesService(). + Limit(1000). + Symbol(market). + StartTime(startTime.UnixNano() / 1000000) + + if lastTradeID > 0 { + req.FromID(lastTradeID) + } + + bnTrades, err := req.Do(ctx) + if err != nil { + return nil, err + } + + if len(bnTrades) <= 1 { + break + } + + for _, t := range bnTrades { + // skip trade ID that is the same + if t.ID == lastTradeID { + continue + } + + var side string + if t.IsBuyer { + side = "BUY" + } else { + side = "SELL" + } + + // trade time + tt := time.Unix(0, t.Time*1000000) + log.Infof("trade: %d %4s Price: % 13s Volume: % 13s %s", t.ID, side, t.Price, t.Quantity, tt) + + price, err := strconv.ParseFloat(t.Price, 64) + if err != nil { + return nil, err + } + + quantity, err := strconv.ParseFloat(t.Quantity, 64) + if err != nil { + return nil, err + } + + fee, err := strconv.ParseFloat(t.Commission, 64) + if err != nil { + return nil, err + } + + trades = append(trades, Trade{ + ID: t.ID, + Price: price, + Volume: quantity, + IsBuyer: t.IsBuyer, + IsMaker: t.IsMaker, + Fee: fee, + FeeCurrency: t.CommissionAsset, + Time: tt, + }) + + lastTradeID = t.ID + } + } + + return trades, nil +} diff --git a/util/math.go b/util/math.go new file mode 100644 index 000000000..7ac2543e1 --- /dev/null +++ b/util/math.go @@ -0,0 +1,39 @@ +package util + +import ( + "math" + "strconv" +) + +const MaxDigits = 18 // MAX_INT64 ~ 9 * 10^18 + +var Pow10Table = [MaxDigits + 1]int64{ + 1, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, +} + +func Pow10(n int64) int64 { + if n < 0 || n > MaxDigits { + return 0 + } + return Pow10Table[n] +} + +var NegPow10Table = [MaxDigits + 1]float64{ + 1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18, +} + +func NegPow10(n int64) float64 { + if n < 0 || n > MaxDigits { + return 0.0 + } + return NegPow10Table[n] +} + +func Float64ToStr(input float64) string { + return strconv.FormatFloat(input, 'f', -1, 64) +} + +func Float64ToInt64(input float64) int64 { + // eliminate rounding error for IEEE754 floating points + return int64(math.Round(input)) +} diff --git a/util/rate_limit.go b/util/rate_limit.go new file mode 100644 index 000000000..63c83bf82 --- /dev/null +++ b/util/rate_limit.go @@ -0,0 +1,19 @@ +package util + +import ( + "fmt" + "time" + + "golang.org/x/time/rate" +) + +func ShouldDelay(l *rate.Limiter, minInterval time.Duration) time.Duration { + return l.Reserve().Delay() +} + +func NewValidLimiter(r rate.Limit, b int) (*rate.Limiter, error) { + if b <= 0 || r <= 0 { + return nil, fmt.Errorf("Bad rate limit config. Insufficient tokens. (rate=%f, b=%d)", r, b) + } + return rate.NewLimiter(r, b), nil +} diff --git a/util/rate_limit_test.go b/util/rate_limit_test.go new file mode 100644 index 000000000..6efd54a2a --- /dev/null +++ b/util/rate_limit_test.go @@ -0,0 +1,43 @@ +package util + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/time/rate" +) + +func TestNewValidRateLimiter(t *testing.T) { + cases := []struct { + name string + r rate.Limit + b int + hasError bool + }{ + {"valid limiter", 0.1, 1, false}, + {"zero rate", 0, 1, true}, + {"zero burst", 0.1, 0, true}, + {"both zero", 0, 0, true}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + limiter, err := NewValidLimiter(c.r, c.b) + assert.Equal(t, c.hasError, err != nil) + if !c.hasError { + assert.NotNil(t, limiter) + } + }) + } +} + +func TestShouldDelay(t *testing.T) { + minInterval := time.Second * 3 + maxRate := rate.Limit(1 / minInterval.Seconds()) + limiter := rate.NewLimiter(maxRate, 1) + assert.Equal(t, time.Duration(0), ShouldDelay(limiter, minInterval)) + for i := 0; i < 100; i++ { + assert.True(t, ShouldDelay(limiter, minInterval) > 0) + } +} diff --git a/util/retry.go b/util/retry.go new file mode 100644 index 000000000..84cdd68c8 --- /dev/null +++ b/util/retry.go @@ -0,0 +1,69 @@ +package util + +import ( + "context" + "time" + + "github.com/pkg/errors" +) + +const ( + InfiniteRetry = 0 +) + +type RetryPredicator func(e error) bool + +// Retry retrys the passed functin for "attempts" times, if passed function return error. Setting attemps to zero means keep retrying. +func Retry(ctx context.Context, attempts int, duration time.Duration, fnToRetry func() error, errHandler func(error), predicators ...RetryPredicator) (err error) { + infinite := false + if attempts == InfiniteRetry { + infinite = true + } + + for attempts > 0 || infinite { + select { + case <-ctx.Done(): + errMsg := "return for context done" + if err != nil { + return errors.Wrap(err, errMsg) + } else { + return errors.New(errMsg) + } + default: + if err = fnToRetry(); err == nil { + return nil + } + + if !needRetry(err, predicators) { + return err + } + + err = errors.Wrapf(err, "failed in retry: countdown: %v", attempts) + + if errHandler != nil { + errHandler(err) + } + + if !infinite { + attempts-- + } + + time.Sleep(duration) + } + } + + return err +} + +func needRetry(err error, predicators []RetryPredicator) bool { + if err == nil { + return false + } + + // If no predicators specified, we will retry for all errors + if len(predicators) == 0 { + return true + } + + return predicators[0](err) +} diff --git a/util/retry_test.go b/util/retry_test.go new file mode 100644 index 000000000..ff604e1ad --- /dev/null +++ b/util/retry_test.go @@ -0,0 +1,106 @@ +package util + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func addAndCheck(a *int, target int) error { + if *a++; *a == target { + return nil + } else { + return fmt.Errorf("a is not %v. It is %v\n", target, *a) + } +} + +func TestRetry(t *testing.T) { + type test struct { + input int + targetNum int + ans int + ansErr error + } + tests := []test{ + {input: 0, targetNum: 3, ans: 3, ansErr: nil}, + {input: 0, targetNum: 10, ans: 3, ansErr: errors.New("failed in retry")}, + } + + for _, tc := range tests { + errHandled := false + + err := Retry(context.Background(), 3, 1*time.Second, func() error { + return addAndCheck(&tc.input, tc.targetNum) + }, func(e error) { errHandled = true }) + + assert.Equal(t, true, errHandled) + if tc.ansErr == nil { + assert.NoError(t, err) + } else { + assert.Contains(t, err.Error(), tc.ansErr.Error()) + } + assert.Equal(t, tc.ans, tc.input) + } +} + +func TestRetryWithPredicator(t *testing.T) { + type test struct { + count int + f func() error + errHandler func(error) + predicator RetryPredicator + ansCount int + ansErr error + } + knownErr := errors.New("Duplicate entry '1-389837488-1' for key 'UNI_Trade'") + unknownErr := errors.New("Some Error") + tests := []test{ + { + predicator: func(err error) bool { + return !strings.Contains(err.Error(), "Duplicate entry") + }, + f: func() error { return knownErr }, + ansCount: 1, + ansErr: knownErr, + }, + { + predicator: func(err error) bool { + return !strings.Contains(err.Error(), "Duplicate entry") + }, + f: func() error { return unknownErr }, + ansCount: 3, + ansErr: unknownErr, + }, + } + attempts := 3 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for _, tc := range tests { + err := Retry(ctx, attempts, 100*time.Millisecond, func() error { + tc.count++ + return tc.f() + }, tc.errHandler, tc.predicator) + + assert.Equal(t, tc.ansCount, tc.count) + assert.EqualError(t, errors.Cause(err), tc.ansErr.Error(), "should be equal") + } +} + +func TestRetryCtxCancel(t *testing.T) { + result := int(0) + target := int(3) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := Retry(ctx, 5, 1*time.Second, func() error { return addAndCheck(&result, target) }, func(error) {}) + assert.Error(t, err) + fmt.Println("Error:", err.Error()) + assert.Equal(t, int(0), result) +} diff --git a/websocket/backoff.go b/websocket/backoff.go new file mode 100644 index 000000000..f366a55b4 --- /dev/null +++ b/websocket/backoff.go @@ -0,0 +1,78 @@ +package websocket + +import ( + "time" +) + +const ( + DefaultMinBackoff = 5 * time.Second + DefaultMaxBackoff = time.Minute + + DefaultBackoffFactor = 2 +) + +type Backoff struct { + attempt int64 + Factor float64 + cur, Min, Max time.Duration +} + +// Duration returns the duration for the current attempt before incrementing +// the attempt counter. See ForAttempt. +func (b *Backoff) Duration() time.Duration { + d := b.calculate(b.attempt) + b.attempt++ + return d +} + +func (b *Backoff) calculate(attempt int64) time.Duration { + min := b.Min + if min <= 0 { + min = DefaultMinBackoff + } + max := b.Max + if max <= 0 { + max = DefaultMaxBackoff + } + if min >= max { + return max + } + factor := b.Factor + if factor <= 0 { + factor = DefaultBackoffFactor + } + cur := b.cur + if cur < min { + cur = min + } else if cur > max { + cur = max + } + + //calculate this duration + next := cur + if attempt > 0 { + next = time.Duration(float64(cur) * factor) + } + + if next < cur { + // overflow + next = max + } else if next <= min { + next = min + } else if next >= max { + next = max + } + b.cur = next + return next +} + +// Reset restarts the current attempt counter at zero. +func (b *Backoff) Reset() { + b.attempt = 0 + b.cur = b.Min +} + +// Attempt returns the current attempt counter value. +func (b *Backoff) Attempt() int64 { + return b.attempt +} diff --git a/websocket/backoff_test.go b/websocket/backoff_test.go new file mode 100644 index 000000000..81a8a7fb3 --- /dev/null +++ b/websocket/backoff_test.go @@ -0,0 +1,172 @@ +package websocket + +import ( + "math" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBackoff(t *testing.T) { + b := &Backoff{ + Min: 200 * time.Millisecond, + Max: 100 * time.Second, + Factor: 2, + } + + equals(t, b.Duration(), 200*time.Millisecond) + equals(t, b.Duration(), 400*time.Millisecond) + equals(t, b.Duration(), 800*time.Millisecond) + b.Reset() + equals(t, b.Duration(), 200*time.Millisecond) +} + +func TestReachMaxBackoff(t *testing.T) { + b := &Backoff{ + Min: 2 * time.Second, + Max: 100 * time.Second, + Factor: 10, + } + + equals(t, b.Duration(), 2*time.Second) + equals(t, b.Duration(), 20*time.Second) + equals(t, b.Duration(), b.Max) + b.Reset() + equals(t, b.Duration(), 2*time.Second) +} + +func TestAttempt(t *testing.T) { + b := &Backoff{ + Min: 2 * time.Second, + Max: 100 * time.Second, + Factor: 10, + } + + equals(t, b.Attempt(), int64(0)) + equals(t, b.Duration(), 2*time.Second) + equals(t, b.Attempt(), int64(1)) + equals(t, b.Duration(), 20*time.Second) + equals(t, b.Attempt(), int64(2)) + equals(t, b.Duration(), b.Max) + equals(t, b.Attempt(), int64(3)) + b.Reset() + equals(t, b.Attempt(), int64(0)) + equals(t, b.Duration(), 2*time.Second) + equals(t, b.Attempt(), int64(1)) +} + +func TestAttemptWithZeroValue(t *testing.T) { + b := &Backoff{} + + cur := DefaultMinBackoff + equals(t, b.Attempt(), int64(0)) + equals(t, b.Duration(), cur) + for i := 1; i < 10; i++ { + equals(t, b.Attempt(), int64(i)) + cur *= DefaultBackoffFactor + if cur >= DefaultMaxBackoff { + equals(t, b.Duration(), DefaultMaxBackoff) + break + } + equals(t, b.Duration(), cur) + } + + b.Reset() + equals(t, b.Attempt(), int64(0)) + equals(t, b.Duration(), DefaultMinBackoff) + equals(t, b.Attempt(), int64(1)) +} + +func TestNegOrZeroMin(t *testing.T) { + b := &Backoff{ + Min: 0, + } + equals(t, b.Duration(), DefaultMinBackoff) + + b2 := &Backoff{ + Min: -1 * time.Second, + } + equals(t, b2.Duration(), DefaultMinBackoff) +} + +func TestNegOrZeroMax(t *testing.T) { + b := &Backoff{ + Min: DefaultMaxBackoff, + Max: 0, + } + equals(t, b.Duration(), DefaultMaxBackoff) + + b2 := &Backoff{ + Min: DefaultMaxBackoff, + Max: -1 * time.Second, + } + equals(t, b2.Duration(), DefaultMaxBackoff) +} + +func TestMinLargerThanMax(t *testing.T) { + b := &Backoff{ + Min: 500 * time.Second, + Max: 100 * time.Second, + Factor: 1, + } + equals(t, b.Duration(), b.Max) +} + +func TestFakeCur(t *testing.T) { + b := &Backoff{ + Min: 100 * time.Second, + Max: 500 * time.Second, + Factor: 1, + cur: 1000 * time.Second, + } + equals(t, b.Duration(), b.Max) +} + +func TestLargeFactor(t *testing.T) { + b := &Backoff{ + Min: 1 * time.Second, + Max: 10 * time.Second, + Factor: math.MaxFloat64, + } + equals(t, b.Duration(), b.Min) + for i := 0; i < 100; i++ { + equals(t, b.Duration(), b.Max) + } + b.Reset() + equals(t, b.Duration(), b.Min) +} + +func TestLargeMinMax(t *testing.T) { + b := &Backoff{ + Min: time.Duration(math.MaxInt64 - 1), + Max: time.Duration(math.MaxInt64), + Factor: math.MaxFloat64, + } + equals(t, b.Duration(), b.Min) + for i := 0; i < 100; i++ { + equals(t, b.Duration(), b.Max) + } + b.Reset() + equals(t, b.Duration(), b.Min) +} + +func TestManyAttempts(t *testing.T) { + b := &Backoff{ + Min: 10 * time.Second, + Max: time.Duration(math.MaxInt64), + Factor: 1000, + } + for i := 0; i < 10000; i++ { + b.Duration() + } + equals(t, b.Duration(), b.Max) + b.Reset() +} + +func equals(t *testing.T, v1, v2 interface{}) { + if !reflect.DeepEqual(v1, v2) { + assert.Failf(t, "not equal", "Got %v, Expecting %v", v1, v2) + } +} diff --git a/websocket/client.go b/websocket/client.go new file mode 100644 index 000000000..c858c4582 --- /dev/null +++ b/websocket/client.go @@ -0,0 +1,368 @@ +package websocket + +import ( + "context" + "net/http" + "sync" + "time" + + "golang.org/x/time/rate" + + "github.com/c9s/bbgo/pkg/log" + "github.com/c9s/bbgo/pkg/util" + + "github.com/gorilla/websocket" + "github.com/pkg/errors" +) + +const DefaultMessageBufferSize = 128 + +const DefaultWriteTimeout = 30 * time.Second +const DefaultReadTimeout = 30 * time.Second + +var ErrReconnectContextDone = errors.New("reconnect canceled due to context done.") +var ErrReconnectFailed = errors.New("failed to reconnect") +var ErrConnectionLost = errors.New("connection lost") + +var MaxReconnectRate = rate.Limit(1 / DefaultMinBackoff.Seconds()) + +// WebSocketClient allows to connect and receive stream data +type WebSocketClient struct { + // Url is the websocket connection location, start with ws:// or wss:// + Url string + + // conn is the current websocket connection, please note the connection + // object can be replaced with a new connection object when the connection + // is unexpected closed. + conn *websocket.Conn + + // Dialer is used for creating the websocket connection + Dialer *websocket.Dialer + + // requestHeader is used for the Dial function call. Some credential can be + // stored in the http request header for authentication + requestHeader http.Header + + // messages is a read-only channel, received messages will be sent to this + // channel. + messages chan Message + + readTimeout time.Duration + + writeTimeout time.Duration + + onConnect []func(c Client) + + onDisconnect []func(c Client) + + // cancel is mapped to the ctx context object + cancel func() + + readerClosed chan struct{} + + connected bool + + mu sync.Mutex + + reconnectCh chan struct{} + + backoff Backoff + + limiter *rate.Limiter +} + +type Message struct { + // websocket.BinaryMessage or websocket.TextMessage + Type int + Body []byte +} + +func (c *WebSocketClient) Messages() <-chan Message { + return c.messages +} + +func (c *WebSocketClient) SetReadTimeout(timeout time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.readTimeout = timeout +} + +func (c *WebSocketClient) SetWriteTimeout(timeout time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.writeTimeout = timeout +} + +func (c *WebSocketClient) OnConnect(f func(c Client)) { + c.mu.Lock() + defer c.mu.Unlock() + c.onConnect = append(c.onConnect, f) +} + +func (c *WebSocketClient) OnDisconnect(f func(c Client)) { + c.mu.Lock() + defer c.mu.Unlock() + c.onDisconnect = append(c.onDisconnect, f) +} + +func (c *WebSocketClient) WriteTextMessage(message []byte) error { + return c.WriteMessage(websocket.TextMessage, message) +} + +func (c *WebSocketClient) WriteBinaryMessage(message []byte) error { + return c.WriteMessage(websocket.BinaryMessage, message) +} + +func (c *WebSocketClient) WriteMessage(messageType int, data []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.connected { + return ErrConnectionLost + } + + if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil { + return err + } + return c.conn.WriteMessage(messageType, data) +} + +func (c *WebSocketClient) readMessages() error { + c.mu.Lock() + if !c.connected { + c.mu.Unlock() + return ErrConnectionLost + } + timeout := c.readTimeout + conn := c.conn + c.mu.Unlock() + + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return err + } + msgtype, message, err := conn.ReadMessage() + if err != nil { + return err + } + c.messages <- Message{msgtype, message} + return nil +} + +// listen starts a goroutine for reading message and tries to re-connect to the +// server when the reader returns error +// +// Please note we should always break the reader loop if there is any error +// returned from the server. +func (c *WebSocketClient) listen(ctx context.Context) { + // The life time of both channels "readerClosed" and "reconnectCh" is bound to one connection. + // Each channel should be created before loop starts and be closed after loop ends. + // "readerClosed" is used to inform "Close()" reader loop ends. + // "reconnectCh" is used to centralize reconnection logics in this reader loop. + c.mu.Lock() + c.readerClosed = make(chan struct{}) + c.reconnectCh = make(chan struct{}, 1) + c.mu.Unlock() + defer func() { + c.mu.Lock() + close(c.readerClosed) + close(c.reconnectCh) + c.reconnectCh = nil + c.mu.Unlock() + }() + + for { + select { + + case <-ctx.Done(): + return + + case <-c.reconnectCh: + // it could be i/o timeout for network disconnection + // or it could be invoked from outside. + c.SetDisconnected() + var maxTries = 1 + if _, response, err := c.reconnect(ctx, maxTries); err != nil { + if err == ErrReconnectContextDone { + log.Debugf("[websocket] context canceled. stop reconnecting.") + return + } + log.Warnf("[websocket] failed to reconnect after %d tries!! error: %v response: %v", maxTries, err, response) + c.Reconnect() + } + + default: + if err := c.readMessages(); err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { + log.Warnf("[websocket] unexpected close error reconnecting: %v", err) + } + + log.Warnf("[websocket] failed to read message. error: %+v", err) + c.Reconnect() + } + } + } +} + +// Reconnect triggers reconnection logics +func (c *WebSocketClient) Reconnect() { + c.mu.Lock() + defer c.mu.Unlock() + + select { + // c.reconnectCh is a buffered channel with cap=1. + // At most one reconnect signal could be processed. + case c.reconnectCh <- struct{}{}: + default: + // Entering here means it is already reconnecting. + // Drop the current reconnect signal. + } +} + +// Close gracefully shuts down the reader and the connection +// ctx is the context used for shutdown process. +func (c *WebSocketClient) Close() (err error) { + c.mu.Lock() + // leave the listen goroutine before we close the connection + // checking nil is to handle calling "Close" before "Connect" is called + if c.cancel != nil { + c.cancel() + } + c.mu.Unlock() + c.SetDisconnected() + + // wait for the reader func to be closed + if c.readerClosed != nil { + <-c.readerClosed + } + return err +} + +// reconnect tries to create a new connection from the existing dialer +func (c *WebSocketClient) reconnect(ctx context.Context, maxTries int) (*websocket.Conn, *http.Response, error) { + log.Debugf("[websocket] start reconnecting to %q", c.Url) + + select { + case <-ctx.Done(): + return nil, nil, ErrReconnectContextDone + default: + } + + if s := util.ShouldDelay(c.limiter, DefaultMinBackoff); s > 0 { + log.Warn("[websocket] reconnect too frequently. Sleep for ", s) + time.Sleep(s) + } + + log.Warnf("[websocket] reconnecting x %d to %q", c.backoff.Attempt()+1, c.Url) + conn, resp, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader) + if err != nil { + dur := c.backoff.Duration() + log.Warnf("failed to dial %s: %v, response: %+v. Wait for %v", c.Url, err, resp, dur) + time.Sleep(dur) + return nil, nil, ErrReconnectFailed + } + + log.Infof("[websocket] reconnected to %q", c.Url) + // Reset backoff value if connected. + c.backoff.Reset() + c.setConn(conn) + c.setPingHandler(conn) + + return conn, resp, err +} + +// Conn returns the current active connection instance +func (c *WebSocketClient) Conn() (conn *websocket.Conn) { + c.mu.Lock() + conn = c.conn + c.mu.Unlock() + return conn +} + +func (c *WebSocketClient) setConn(conn *websocket.Conn) { + // Disconnect old connection before replacing with new one. + c.SetDisconnected() + + c.mu.Lock() + c.conn = conn + c.connected = true + c.mu.Unlock() + for _, f := range c.onConnect { + go f(c) + } +} + +func (c *WebSocketClient) setPingHandler(conn *websocket.Conn) { + conn.SetPingHandler(func(message string) error { + if err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { + return err + } + c.mu.Lock() + defer c.mu.Unlock() + return conn.SetReadDeadline(time.Now().Add(c.readTimeout)) + }) +} + +func (c *WebSocketClient) SetDisconnected() { + c.mu.Lock() + closed := false + if c.conn != nil { + closed = true + c.conn.Close() + } + c.connected = false + c.conn = nil + c.mu.Unlock() + + if closed { + // Only call disconnect callbacks when a connection is closed + for _, f := range c.onDisconnect { + go f(c) + } + } +} + +func (c *WebSocketClient) IsConnected() (ret bool) { + c.mu.Lock() + ret = c.connected + c.mu.Unlock() + return ret +} + +func (c *WebSocketClient) Connect(basectx context.Context) error { + // maintain a context by the client it self, so that we can manually shutdown the connection + ctx, cancel := context.WithCancel(basectx) + c.cancel = cancel + + conn, _, err := c.Dialer.DialContext(ctx, c.Url, c.requestHeader) + if err == nil { + // setup connection only when connected + c.setConn(conn) + c.setPingHandler(conn) + } + + // 1) if connection is built up, start listening for messages. + // 2) if connection is NOT ready, start reconnecting infinitely. + go c.listen(ctx) + + return err +} + +func New(url string, requestHeader http.Header) *WebSocketClient { + return NewWithDialer(url, websocket.DefaultDialer, requestHeader) +} + +func NewWithDialer(url string, d *websocket.Dialer, requestHeader http.Header) *WebSocketClient { + limiter, err := util.NewValidLimiter(MaxReconnectRate, 1) + if err != nil { + log.WithError(err).Panic("Invalid rate limiter") + } + return &WebSocketClient{ + Url: url, + Dialer: d, + requestHeader: requestHeader, + readTimeout: DefaultReadTimeout, + writeTimeout: DefaultWriteTimeout, + messages: make(chan Message, DefaultMessageBufferSize), + limiter: limiter, + } +} diff --git a/websocket/client_test.go b/websocket/client_test.go new file mode 100644 index 000000000..90dcd4e55 --- /dev/null +++ b/websocket/client_test.go @@ -0,0 +1,202 @@ +package websocket + +import ( + "context" + "net" + "net/http" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "***REMOVED***/pkg/log" + "***REMOVED***/pkg/testing/testutil" + + "github.com/gorilla/websocket" +) + +func messageTicker(t *testing.T, ctx context.Context, conn *websocket.Conn) { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + select { + case tt := <-ticker.C: + err := conn.WriteMessage(websocket.TextMessage, []byte(tt.String())) + assert.NoError(t, err) + case <-ctx.Done(): + return + } + } +} + +func messageReader(t *testing.T, ctx context.Context, conn *websocket.Conn) { + for { + _, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Errorf("unexpected closed error: %v", err) + } + break + } + t.Logf("message: %v", message) + select { + case <-ctx.Done(): + return + default: + } + } +} + +func TestReconnect(t *testing.T) { + basectx := context.Background() + + wsHandler := func(ctx context.Context) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + go messageTicker(t, ctx, conn) + messageReader(t, ctx, conn) + } + } + + serverctx, cancelserver := context.WithCancel(basectx) + server := testutil.NewWebSocketServerFunc(t, 0, "/ws", wsHandler(serverctx)) + go func() { + err := server.ListenAndServe() + assert.Equal(t, http.ErrServerClosed, err) + }() + + _, port, err := net.SplitHostPort(server.Addr) + assert.NoError(t, err) + + log.Debugf("waiting for server port ready") + testutil.WaitForPort(t, server.Addr, 3) + + u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"} + log.Infof("url: %s", u.String()) + + client := New(u.String(), nil) + + // start the message reader + clientctx, cancelClient := context.WithCancel(basectx) + defer cancelClient() + + err = client.Connect(clientctx) + assert.NoError(t, err) + assert.NotNil(t, client) + client.SetReadTimeout(2 * time.Second) + + // read one message + log.Debugf("waiting for message...") + msg := <-client.messages + log.Debugf("received message: %v", msg) + + triggerReconnectManyTimes := func() { + // Forcedly reconnect multiple times + for i := 0; i < 10; i = i + 1 { + client.Reconnect() + } + } + + log.Debugf("reconnecting...") + triggerReconnectManyTimes() + msg = <-client.messages + log.Debugf("received message: %v", msg) + + log.Debugf("shutting down server...") + err = server.Shutdown(serverctx) + assert.NoError(t, err) + cancelserver() + server.Close() + + time.Sleep(500 * time.Millisecond) + + triggerReconnectManyTimes() + + log.Debugf("restarting server...") + server2ctx, cancelserver2 := context.WithCancel(basectx) + server2 := testutil.NewWebSocketServerWithAddressFunc(t, server.Addr, "/ws", wsHandler(server2ctx)) + go func() { + err := server2.ListenAndServe() + assert.Equal(t, http.ErrServerClosed, err) + }() + + triggerReconnectManyTimes() + + log.Debugf("waiting for server2 port ready") + testutil.WaitForPort(t, server.Addr, 3) + + log.Debugf("waiting for message from server2...") + msg = <-client.messages + log.Debugf("received message: %v", msg) + + err = client.Close() + assert.NoError(t, err) + + err = server2.Shutdown(server2ctx) + assert.NoError(t, err) + cancelserver2() +} + +func TestConnect(t *testing.T) { + basectx := context.Background() + ctx, cancel := context.WithCancel(basectx) + + server := testutil.NewWebSocketServerFunc(t, 0, "/ws", func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + + go messageTicker(t, ctx, conn) + messageReader(t, ctx, conn) + }) + + go func() { + err := server.ListenAndServe() + assert.Equal(t, http.ErrServerClosed, err) + }() + + _, port, err := net.SplitHostPort(server.Addr) + assert.NoError(t, err) + + testutil.WaitForPort(t, server.Addr, 3) + + u := url.URL{Scheme: "ws", Host: net.JoinHostPort("127.0.0.1", port), Path: "/ws"} + t.Logf("url: %s", u.String()) + + client := New(u.String(), nil) + + err = client.Connect(ctx) + assert.NoError(t, err) + assert.NotNil(t, client) + client.SetReadTimeout(2 * time.Second) + + // read one message + t.Logf("waiting for message...") + msg := <-client.messages + t.Logf("received message: %v", msg) + + // recreate server at the same address + client.Close() + + t.Logf("shutting down server...") + cancel() + err = server.Shutdown(basectx) + assert.NoError(t, err) + +} diff --git a/websocket/interface.go b/websocket/interface.go new file mode 100644 index 000000000..5392a4f7a --- /dev/null +++ b/websocket/interface.go @@ -0,0 +1,21 @@ +package websocket + +import ( + "context" + "time" +) + +//go:generate mockery -name=Client + +type Client interface { + SetWriteTimeout(time.Duration) + SetReadTimeout(time.Duration) + OnConnect(func(c Client)) + OnDisconnect(func(c Client)) + Connect(context.Context) error + Reconnect() + Close() error + IsConnected() bool + WriteJSON(interface{}) error + Messages() <-chan Message +} diff --git a/websocket/json.go b/websocket/json.go new file mode 100644 index 000000000..705fb5f1e --- /dev/null +++ b/websocket/json.go @@ -0,0 +1,55 @@ +package websocket + +import ( + "encoding/json" + "io" + + "github.com/gorilla/websocket" +) + +// WriteJSON writes the JSON encoding of v as a message. +// +// See the documentation for encoding/json Marshal for details about the +// conversion of Go values to JSON. +func (c *WebSocketClient) WriteJSON(v interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.connected { + return ErrConnectionLost + } + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + return err + } + err1 := json.NewEncoder(w).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Unmarshal function for details +// about the conversion of JSON to a Go value. +func (c *WebSocketClient) ReadJSON(v interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.connected { + return ErrConnectionLost + } + _, r, err := c.conn.NextReader() + if err != nil { + return err + } + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + return err +} diff --git a/websocket/mocks/Client.go b/websocket/mocks/Client.go new file mode 100644 index 000000000..105ee099e --- /dev/null +++ b/websocket/mocks/Client.go @@ -0,0 +1,110 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import context "context" +import mock "github.com/stretchr/testify/mock" +import time "time" +import websocket "***REMOVED***/pkg/net/http/websocket" + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *Client) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Connect provides a mock function with given fields: _a0 +func (_m *Client) Connect(_a0 context.Context) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// IsConnected provides a mock function with given fields: +func (_m *Client) IsConnected() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Messages provides a mock function with given fields: +func (_m *Client) Messages() <-chan websocket.Message { + ret := _m.Called() + + var r0 <-chan websocket.Message + if rf, ok := ret.Get(0).(func() <-chan websocket.Message); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan websocket.Message) + } + } + + return r0 +} + +// OnConnect provides a mock function with given fields: _a0 +func (_m *Client) OnConnect(_a0 func(websocket.Client)) { + _m.Called(_a0) +} + +// OnDisconnect provides a mock function with given fields: _a0 +func (_m *Client) OnDisconnect(_a0 func(websocket.Client)) { + _m.Called(_a0) +} + +// Reconnect provides a mock function with given fields: +func (_m *Client) Reconnect() { + _m.Called() +} + +// SetReadTimeout provides a mock function with given fields: _a0 +func (_m *Client) SetReadTimeout(_a0 time.Duration) { + _m.Called(_a0) +} + +// SetWriteTimeout provides a mock function with given fields: _a0 +func (_m *Client) SetWriteTimeout(_a0 time.Duration) { + _m.Called(_a0) +} + +// WriteJSON provides a mock function with given fields: _a0 +func (_m *Client) WriteJSON(_a0 interface{}) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +}