From 2761cff2bf126866ff53830f89c56f59dc991a5b Mon Sep 17 00:00:00 2001 From: c9s Date: Fri, 4 Nov 2022 00:26:28 +0800 Subject: [PATCH] grid2: add pin tests --- pkg/strategy/grid2/grid.go | 77 +++++++++++++++++---------- pkg/strategy/grid2/grid_test.go | 93 +++++++++++++++++++++++++++++---- pkg/types/market.go | 4 +- pkg/types/market_test.go | 10 ++-- 4 files changed, 138 insertions(+), 46 deletions(-) diff --git a/pkg/strategy/grid2/grid.go b/pkg/strategy/grid2/grid.go index 73c142e25..054e8f17c 100644 --- a/pkg/strategy/grid2/grid.go +++ b/pkg/strategy/grid2/grid.go @@ -1,6 +1,8 @@ package grid2 import ( + "math" + "github.com/c9s/bbgo/pkg/fixedpoint" ) @@ -15,64 +17,83 @@ type Grid struct { Size fixedpoint.Value `json:"size"` // Pins are the pinned grid prices, from low to high - Pins []fixedpoint.Value `json:"pins"` + Pins []Pin `json:"pins"` - pinsCache map[fixedpoint.Value]struct{} `json:"-"` + pinsCache map[Pin]struct{} `json:"-"` } -func NewGrid(lower, upper, density fixedpoint.Value) *Grid { - var height = upper - lower - var size = height.Div(density) - var pins []fixedpoint.Value +type Pin fixedpoint.Value - for p := lower; p <= upper; p += size { - pins = append(pins, p) +func calculateArithmeticPins(lower, upper, size, tickSize fixedpoint.Value) []Pin { + var height = upper.Sub(lower) + var spread = height.Div(size) + + var pins []Pin + for p := lower; p.Compare(upper) <= 0; p = p.Add(spread) { + // tickSize here = 0.01 + pp := math.Trunc(p.Float64()/tickSize.Float64()) * tickSize.Float64() + pins = append(pins, Pin(fixedpoint.NewFromFloat(pp))) } + return pins +} + +func buildPinCache(pins []Pin) map[Pin]struct{} { + cache := make(map[Pin]struct{}, len(pins)) + for _, pin := range pins { + cache[pin] = struct{}{} + } + + return cache +} + +func NewGrid(lower, upper, size, tickSize fixedpoint.Value) *Grid { + var height = upper.Sub(lower) + var spread = height.Div(size) + var pins = calculateArithmeticPins(lower, upper, size, tickSize) + grid := &Grid{ UpperPrice: upper, LowerPrice: lower, - Size: density, - Spread: size, + Size: size, + Spread: spread, Pins: pins, - pinsCache: make(map[fixedpoint.Value]struct{}, len(pins)), + pinsCache: buildPinCache(pins), } - grid.updatePinsCache() + return grid } func (g *Grid) Above(price fixedpoint.Value) bool { - return price > g.UpperPrice + return price.Compare(g.UpperPrice) > 0 } func (g *Grid) Below(price fixedpoint.Value) bool { - return price < g.LowerPrice + return price.Compare(g.LowerPrice) < 0 } func (g *Grid) OutOfRange(price fixedpoint.Value) bool { - return price < g.LowerPrice || price > g.UpperPrice + return price.Compare(g.LowerPrice) < 0 || price.Compare(g.UpperPrice) > 0 } func (g *Grid) updatePinsCache() { - for _, pin := range g.Pins { - g.pinsCache[pin] = struct{}{} - } + g.pinsCache = buildPinCache(g.Pins) } -func (g *Grid) HasPin(pin fixedpoint.Value) (ok bool) { +func (g *Grid) HasPin(pin Pin) (ok bool) { _, ok = g.pinsCache[pin] return ok } -func (g *Grid) ExtendUpperPrice(upper fixedpoint.Value) (newPins []fixedpoint.Value) { +func (g *Grid) ExtendUpperPrice(upper fixedpoint.Value) (newPins []Pin) { g.UpperPrice = upper // since the grid is extended, the size should be updated as well g.Size = (g.UpperPrice - g.LowerPrice).Div(g.Spread).Floor() - lastPin := g.Pins[len(g.Pins)-1] - for p := lastPin + g.Spread; p <= g.UpperPrice; p += g.Spread { - newPins = append(newPins, p) + lastPinPrice := fixedpoint.Value(g.Pins[len(g.Pins)-1]) + for p := lastPinPrice.Add(g.Spread); p <= g.UpperPrice; p += g.Spread { + newPins = append(newPins, Pin(p)) } g.Pins = append(g.Pins, newPins...) @@ -80,20 +101,20 @@ func (g *Grid) ExtendUpperPrice(upper fixedpoint.Value) (newPins []fixedpoint.Va return newPins } -func (g *Grid) ExtendLowerPrice(lower fixedpoint.Value) (newPins []fixedpoint.Value) { +func (g *Grid) ExtendLowerPrice(lower fixedpoint.Value) (newPins []Pin) { g.LowerPrice = lower // since the grid is extended, the size should be updated as well g.Size = (g.UpperPrice - g.LowerPrice).Div(g.Spread).Floor() - firstPin := g.Pins[0] - numToAdd := (firstPin - g.LowerPrice).Div(g.Spread).Floor() + firstPinPrice := fixedpoint.Value(g.Pins[0]) + numToAdd := (firstPinPrice.Sub(g.LowerPrice)).Div(g.Spread).Floor() if numToAdd == 0 { return newPins } - for p := firstPin - g.Spread.Mul(numToAdd); p < firstPin; p += g.Spread { - newPins = append(newPins, p) + for p := firstPinPrice.Sub(g.Spread.Mul(numToAdd)); p.Compare(firstPinPrice) < 0; p = p.Add(g.Spread) { + newPins = append(newPins, Pin(p)) } g.Pins = append(newPins, g.Pins...) diff --git a/pkg/strategy/grid2/grid_test.go b/pkg/strategy/grid2/grid_test.go index 3ed9d0e4e..12992ea8a 100644 --- a/pkg/strategy/grid2/grid_test.go +++ b/pkg/strategy/grid2/grid_test.go @@ -8,11 +8,20 @@ import ( "github.com/c9s/bbgo/pkg/fixedpoint" ) +func number(a interface{}) fixedpoint.Value { + if s, ok := a.(string); ok { + return fixedpoint.MustNewFromString(s) + } + + f := a.(float64) + return fixedpoint.NewFromFloat(f) +} + func TestNewGrid(t *testing.T) { upper := fixedpoint.NewFromFloat(500.0) lower := fixedpoint.NewFromFloat(100.0) size := fixedpoint.NewFromFloat(100.0) - grid := NewGrid(lower, upper, size) + grid := NewGrid(lower, upper, size, number(2.0)) assert.Equal(t, upper, grid.UpperPrice) assert.Equal(t, lower, grid.LowerPrice) assert.Equal(t, fixedpoint.NewFromFloat(4), grid.Spread) @@ -26,21 +35,21 @@ func TestGrid_HasPin(t *testing.T) { upper := fixedpoint.NewFromFloat(500.0) lower := fixedpoint.NewFromFloat(100.0) size := fixedpoint.NewFromFloat(100.0) - grid := NewGrid(lower, upper, size) + grid := NewGrid(lower, upper, size, number(2)) - assert.True(t, grid.HasPin(fixedpoint.NewFromFloat(100.0))) - assert.True(t, grid.HasPin(fixedpoint.NewFromFloat(500.0))) - assert.False(t, grid.HasPin(fixedpoint.NewFromFloat(101.0))) + assert.True(t, grid.HasPin(Pin(number(100.0)))) + assert.True(t, grid.HasPin(Pin(number(500.0)))) + assert.False(t, grid.HasPin(Pin(number(101.0)))) } func TestGrid_ExtendUpperPrice(t *testing.T) { - upper := fixedpoint.NewFromFloat(500.0) - lower := fixedpoint.NewFromFloat(100.0) - size := fixedpoint.NewFromFloat(100.0) - grid := NewGrid(lower, upper, size) + upper := number(500.0) + lower := number(100.0) + size := number(100.0) + grid := NewGrid(lower, upper, size, number(2.0)) originalSpread := grid.Spread - newPins := grid.ExtendUpperPrice(fixedpoint.NewFromFloat(1000.0)) + newPins := grid.ExtendUpperPrice(number(1000.0)) assert.Equal(t, originalSpread, grid.Spread) assert.Len(t, newPins, 125) // (1000-500) / 4 assert.Equal(t, fixedpoint.NewFromFloat(4), grid.Spread) @@ -54,7 +63,7 @@ func TestGrid_ExtendLowerPrice(t *testing.T) { upper := fixedpoint.NewFromFloat(3000.0) lower := fixedpoint.NewFromFloat(2000.0) size := fixedpoint.NewFromFloat(100.0) - grid := NewGrid(lower, upper, size) + grid := NewGrid(lower, upper, size, number(2.0)) // spread = (3000 - 2000) / 100.0 expectedSpread := fixedpoint.NewFromFloat(10.0) @@ -79,3 +88,65 @@ func TestGrid_ExtendLowerPrice(t *testing.T) { fixedpoint.NewFromFloat(1000.0 - 1.0)) assert.Len(t, newPins2, 0) // should have no new pin generated } + +func Test_calculateArithmeticPins(t *testing.T) { + type args struct { + lower fixedpoint.Value + upper fixedpoint.Value + size fixedpoint.Value + tickSize fixedpoint.Value + } + tests := []struct { + name string + args args + want []Pin + }{ + { + name: "simple", + args: args{ + lower: number(1000.0), + upper: number(3000.0), + size: number(30.0), + tickSize: number(0.01), + }, + want: []Pin{ + Pin(number(1000.0)), + Pin(number(1066.660)), + Pin(number(1133.330)), + Pin(number(1199.990)), + Pin(number(1266.660)), + Pin(number(1333.330)), + Pin(number(1399.990)), + Pin(number(1466.660)), + Pin(number(1533.330)), + Pin(number(1599.990)), + Pin(number(1666.660)), + Pin(number(1733.330)), + Pin(number(1799.990)), + Pin(number(1866.660)), + Pin(number(1933.330)), + Pin(number(1999.990)), + Pin(number(2066.660)), + Pin(number(2133.330)), + Pin(number("2199.99")), + Pin(number(2266.660)), + Pin(number(2333.330)), + Pin(number("2399.99")), + Pin(number(2466.660)), + Pin(number(2533.330)), + Pin(number("2599.99")), + Pin(number(2666.660)), + Pin(number(2733.330)), + Pin(number(2799.990)), + Pin(number(2866.660)), + Pin(number(2933.330)), + Pin(number(2999.990)), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, calculateArithmeticPins(tt.args.lower, tt.args.upper, tt.args.size, tt.args.tickSize), "calculateArithmeticPins(%v, %v, %v, %v)", tt.args.lower, tt.args.upper, tt.args.size, tt.args.tickSize) + }) + } +} diff --git a/pkg/types/market.go b/pkg/types/market.go index 1092b441e..e2b08cd21 100644 --- a/pkg/types/market.go +++ b/pkg/types/market.go @@ -150,10 +150,10 @@ func (m Market) FormatPriceCurrency(val fixedpoint.Value) string { func (m Market) FormatPrice(val fixedpoint.Value) string { // p := math.Pow10(m.PricePrecision) - return formatPrice(val, m.TickSize) + return FormatPrice(val, m.TickSize) } -func formatPrice(price fixedpoint.Value, tickSize fixedpoint.Value) string { +func FormatPrice(price fixedpoint.Value, tickSize fixedpoint.Value) string { prec := int(math.Round(math.Abs(math.Log10(tickSize.Float64())))) return price.FormatString(prec) } diff --git a/pkg/types/market_test.go b/pkg/types/market_test.go index d0544e9ba..809e60b0d 100644 --- a/pkg/types/market_test.go +++ b/pkg/types/market_test.go @@ -26,12 +26,12 @@ func TestFormatQuantity(t *testing.T) { } func TestFormatPrice(t *testing.T) { - price := formatPrice( + price := FormatPrice( s("26.288256"), s("0.0001")) assert.Equal(t, "26.2882", price) - price = formatPrice(s("26.288656"), s("0.001")) + price = FormatPrice(s("26.288656"), s("0.001")) assert.Equal(t, "26.288", price) } @@ -78,7 +78,7 @@ func TestDurationParse(t *testing.T) { } } -func Test_formatPrice(t *testing.T) { +func Test_FormatPrice(t *testing.T) { type args struct { price fixedpoint.Value tickSize fixedpoint.Value @@ -125,9 +125,9 @@ func Test_formatPrice(t *testing.T) { binanceFormatRE := regexp.MustCompile("^([0-9]{1,20})(.[0-9]{1,20})?$") for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := formatPrice(tt.args.price, tt.args.tickSize) + got := FormatPrice(tt.args.price, tt.args.tickSize) if got != tt.want { - t.Errorf("formatPrice() = %v, want %v", got, tt.want) + t.Errorf("FormatPrice() = %v, want %v", got, tt.want) } assert.Regexp(t, binanceFormatRE, got)