diff --git a/pkg/strategy/shakegrid/grid.go b/pkg/strategy/shakegrid/grid.go index 44a50aa30..276327e3a 100644 --- a/pkg/strategy/shakegrid/grid.go +++ b/pkg/strategy/shakegrid/grid.go @@ -1,16 +1,18 @@ package shakegrid -import "github.com/c9s/bbgo/pkg/fixedpoint" +import ( + "github.com/c9s/bbgo/pkg/fixedpoint" +) type Grid struct { UpperPrice fixedpoint.Value `json:"upperPrice"` LowerPrice fixedpoint.Value `json:"lowerPrice"` - // Size is the spread of each grid - Size fixedpoint.Value `json:"size"` + // Spread is the spread of each grid + Spread fixedpoint.Value `json:"spread"` - // Density is the number of total grids - Density fixedpoint.Value `json:"density"` + // Size is the number of total grids + Size fixedpoint.Value `json:"size"` // Pins are the pinned grid prices, from low to high Pins []fixedpoint.Value `json:"pins"` @@ -21,23 +23,47 @@ func NewGrid(lower, upper, density fixedpoint.Value) *Grid { var size = height.Div(density) var pins []fixedpoint.Value - for p := lower ; p <= upper ; p += size { + for p := lower; p <= upper; p += size { pins = append(pins, p) } return &Grid{ UpperPrice: upper, LowerPrice: lower, - Density: density, - Size: size, - Pins: pins, + Size: density, + Spread: size, + Pins: pins, } } -func (g *Grid) ExtendUpperPrice(price fixedpoint.Value) { +func (g *Grid) ExtendUpperPrice(upper fixedpoint.Value) (newPins []fixedpoint.Value) { + g.UpperPrice = upper + // since the grid is extended, the size should be updated as well + g.Size = (g.UpperPrice - g.LowerPrice).Div(g.Spread).Ceil() + + lastPin := g.Pins[ len(g.Pins) - 1 ] + for p := lastPin + g.Spread; p <= g.UpperPrice; p += g.Spread { + newPins = append(newPins, p) + } + + g.Pins = append(g.Pins, newPins...) + return newPins } -func (g *Grid) ExtendLowerPrice(price fixedpoint.Value) { +func (g *Grid) ExtendLowerPrice(lower fixedpoint.Value) (newPins []fixedpoint.Value) { + g.LowerPrice = lower + // since the grid is extended, the size should be updated as well + g.Size = (g.UpperPrice - g.LowerPrice).Div(g.Spread).Ceil() + + firstPin := g.Pins[0] + numToAdd := (firstPin - g.LowerPrice).Div(g.Spread).Ceil() + + for p := firstPin - g.Spread.Mul(numToAdd); p < firstPin; p += g.Spread { + newPins = append(newPins, p) + } + + g.Pins = append(newPins, g.Pins...) + return newPins } diff --git a/pkg/strategy/shakegrid/grid_test.go b/pkg/strategy/shakegrid/grid_test.go index 403822844..af5606b7d 100644 --- a/pkg/strategy/shakegrid/grid_test.go +++ b/pkg/strategy/shakegrid/grid_test.go @@ -9,13 +9,56 @@ import ( func TestNewGrid(t *testing.T) { upper := fixedpoint.NewFromFloat(500.0) lower := fixedpoint.NewFromFloat(100.0) - density := fixedpoint.NewFromFloat(100.0) - grid := NewGrid(lower, upper, density) + size := fixedpoint.NewFromFloat(100.0) + grid := NewGrid(lower, upper, size) assert.Equal(t, upper, grid.UpperPrice) assert.Equal(t, lower, grid.LowerPrice) - assert.Equal(t, fixedpoint.NewFromFloat(4), grid.Size) + assert.Equal(t, fixedpoint.NewFromFloat(4), grid.Spread) if assert.Len(t, grid.Pins, 101) { assert.Equal(t, fixedpoint.NewFromFloat(100.0), grid.Pins[0]) assert.Equal(t, fixedpoint.NewFromFloat(500.0), grid.Pins[100]) } } + +func TestGrid_ExtendUpperPrice(t *testing.T) { + upper := fixedpoint.NewFromFloat(500.0) + lower := fixedpoint.NewFromFloat(100.0) + size := fixedpoint.NewFromFloat(100.0) + grid := NewGrid(lower, upper, size) + + originalSpread := grid.Spread + newPins := grid.ExtendUpperPrice(fixedpoint.NewFromFloat(1000.0)) + assert.Equal(t, originalSpread, grid.Spread) + assert.Len(t, newPins, 125) // (1000-500) / 4 + assert.Equal(t, fixedpoint.NewFromFloat(4), grid.Spread) + if assert.Len(t, grid.Pins, 226) { + assert.Equal(t, fixedpoint.NewFromFloat(100.0), grid.Pins[0]) + assert.Equal(t, fixedpoint.NewFromFloat(1000.0), grid.Pins[225]) + } +} + +func TestGrid_ExtendLowerPrice(t *testing.T) { + upper := fixedpoint.NewFromFloat(3000.0) + lower := fixedpoint.NewFromFloat(2000.0) + size := fixedpoint.NewFromFloat(100.0) + grid := NewGrid(lower, upper, size) + + // spread = (3000 - 2000) / 100.0 + expectedSpread := fixedpoint.NewFromFloat(10.0) + assert.Equal(t, expectedSpread, grid.Spread) + + originalSpread := grid.Spread + newPins := grid.ExtendLowerPrice(fixedpoint.NewFromFloat(1000.0)) + assert.Equal(t, originalSpread, grid.Spread) + + // 100 = (2000-1000) / 10 + if assert.Len(t, newPins, 100) { + assert.Equal(t, fixedpoint.NewFromFloat(2000.0) - expectedSpread, newPins[99]) + } + + assert.Equal(t, expectedSpread, grid.Spread) + if assert.Len(t, grid.Pins, 201) { + assert.Equal(t, fixedpoint.NewFromFloat(1000.0), grid.Pins[0]) + assert.Equal(t, fixedpoint.NewFromFloat(3000.0), grid.Pins[200]) + } +}