diff --git a/pkg/strategy/grid2/grid.go b/pkg/strategy/grid2/grid.go index 802e7f103..f61b1cb2e 100644 --- a/pkg/strategy/grid2/grid.go +++ b/pkg/strategy/grid2/grid.go @@ -7,6 +7,8 @@ import ( "github.com/c9s/bbgo/pkg/fixedpoint" ) +type PinCalculator func() []Pin + type Grid struct { UpperPrice fixedpoint.Value `json:"upperPrice"` LowerPrice fixedpoint.Value `json:"lowerPrice"` @@ -24,6 +26,8 @@ type Grid struct { Pins []Pin `json:"pins"` pinsCache map[Pin]struct{} `json:"-"` + + calculator PinCalculator } type Pin fixedpoint.Value @@ -63,9 +67,21 @@ func NewGrid(lower, upper, size, tickSize fixedpoint.Value) *Grid { return grid } -func (g *Grid) CalculatePins() { - var pins = calculateArithmeticPins(g.LowerPrice, g.UpperPrice, g.Spread, g.TickSize) - g.addPins(pins) +func (g *Grid) CalculateGeometricPins() { + g.calculator = func() []Pin { + // return calculateArithmeticPins(g.LowerPrice, g.UpperPrice, g.Spread, g.TickSize) + return nil + } + + g.addPins(g.calculator()) +} + +func (g *Grid) CalculateArithmeticPins() { + g.calculator = func() []Pin { + return calculateArithmeticPins(g.LowerPrice, g.UpperPrice, g.Spread, g.TickSize) + } + + g.addPins(g.calculator()) } func (g *Grid) Height() fixedpoint.Value { diff --git a/pkg/strategy/grid2/grid_test.go b/pkg/strategy/grid2/grid_test.go index 82ab446a7..ccd7245cf 100644 --- a/pkg/strategy/grid2/grid_test.go +++ b/pkg/strategy/grid2/grid_test.go @@ -28,7 +28,7 @@ func TestNewGrid(t *testing.T) { lower := fixedpoint.NewFromFloat(100.0) size := fixedpoint.NewFromFloat(100.0) grid := NewGrid(lower, upper, size, number(0.01)) - grid.CalculatePins() + grid.CalculateArithmeticPins() assert.Equal(t, upper, grid.UpperPrice) assert.Equal(t, lower, grid.LowerPrice) @@ -44,7 +44,7 @@ func TestGrid_HasPin(t *testing.T) { lower := fixedpoint.NewFromFloat(100.0) size := fixedpoint.NewFromFloat(100.0) grid := NewGrid(lower, upper, size, number(0.01)) - grid.CalculatePins() + grid.CalculateArithmeticPins() assert.True(t, grid.HasPin(Pin(number(100.0)))) assert.True(t, grid.HasPin(Pin(number(500.0)))) @@ -56,7 +56,7 @@ func TestGrid_ExtendUpperPrice(t *testing.T) { lower := number(100.0) size := number(4.0) grid := NewGrid(lower, upper, size, number(0.01)) - grid.CalculatePins() + grid.CalculateArithmeticPins() originalSpread := grid.Spread @@ -76,7 +76,7 @@ func TestGrid_ExtendLowerPrice(t *testing.T) { lower := fixedpoint.NewFromFloat(2000.0) size := fixedpoint.NewFromFloat(10.0) grid := NewGrid(lower, upper, size, number(0.01)) - grid.CalculatePins() + grid.CalculateArithmeticPins() assert.Equal(t, Pin(number(2000.0)), grid.BottomPin(), "bottom pin should be 1000.0") assert.Equal(t, Pin(number(3000.0)), grid.TopPin(), "top pin should be 3000.0") @@ -111,7 +111,7 @@ func TestGrid_NextLowerPin(t *testing.T) { lower := number(100.0) size := number(4.0) grid := NewGrid(lower, upper, size, number(0.01)) - grid.CalculatePins() + grid.CalculateArithmeticPins() t.Logf("pins: %+v", grid.Pins) @@ -129,7 +129,7 @@ func TestGrid_NextHigherPin(t *testing.T) { lower := number(100.0) size := number(4.0) grid := NewGrid(lower, upper, size, number(0.01)) - grid.CalculatePins() + grid.CalculateArithmeticPins() t.Logf("pins: %+v", grid.Pins) next, ok := grid.NextHigherPin(number(100.0)) diff --git a/pkg/strategy/grid2/strategy.go b/pkg/strategy/grid2/strategy.go index 42ff8449b..710607d8e 100644 --- a/pkg/strategy/grid2/strategy.go +++ b/pkg/strategy/grid2/strategy.go @@ -126,7 +126,7 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se } s.grid = NewGrid(s.LowerPrice, s.UpperPrice, fixedpoint.NewFromInt(s.GridNum), s.Market.TickSize) - s.grid.CalculatePins() + s.grid.CalculateArithmeticPins() s.orderStore = bbgo.NewOrderStore(s.Symbol) s.orderStore.BindStream(session.UserDataStream)