mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-23 07:15:15 +00:00
93 lines
2.4 KiB
Go
93 lines
2.4 KiB
Go
|
package riskcontrol
|
||
|
|
||
|
import (
|
||
|
"testing"
|
||
|
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
|
||
|
"github.com/c9s/bbgo/pkg/bbgo"
|
||
|
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||
|
"github.com/c9s/bbgo/pkg/types"
|
||
|
)
|
||
|
|
||
|
func Test_ModifiedQuantity(t *testing.T) {
|
||
|
|
||
|
riskControl := NewPositionRiskControl(fixedpoint.NewFromInt(10), fixedpoint.NewFromInt(2), &bbgo.TradeCollector{})
|
||
|
|
||
|
cases := []struct {
|
||
|
name string
|
||
|
position fixedpoint.Value
|
||
|
buyQuantity fixedpoint.Value
|
||
|
sellQuantity fixedpoint.Value
|
||
|
}{
|
||
|
{
|
||
|
name: "BuyOverHardLimit",
|
||
|
position: fixedpoint.NewFromInt(9),
|
||
|
buyQuantity: fixedpoint.NewFromInt(1),
|
||
|
sellQuantity: fixedpoint.NewFromInt(2),
|
||
|
},
|
||
|
{
|
||
|
name: "SellOverHardLimit",
|
||
|
position: fixedpoint.NewFromInt(-9),
|
||
|
buyQuantity: fixedpoint.NewFromInt(2),
|
||
|
sellQuantity: fixedpoint.NewFromInt(1),
|
||
|
},
|
||
|
}
|
||
|
for _, tc := range cases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
buyQuantity, sellQuantity := riskControl.ModifiedQuantity(tc.position)
|
||
|
assert.Equal(t, tc.buyQuantity, buyQuantity)
|
||
|
assert.Equal(t, tc.sellQuantity, sellQuantity)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestReleasePositionCallbacks(t *testing.T) {
|
||
|
|
||
|
var position fixedpoint.Value
|
||
|
|
||
|
tradeCollector := &bbgo.TradeCollector{}
|
||
|
riskControl := NewPositionRiskControl(fixedpoint.NewFromInt(10), fixedpoint.NewFromInt(2), tradeCollector)
|
||
|
riskControl.OnReleasePosition(func(quantity fixedpoint.Value, side types.SideType) {
|
||
|
if side == types.SideTypeBuy {
|
||
|
position = position.Add(quantity)
|
||
|
} else {
|
||
|
position = position.Sub(quantity)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
cases := []struct {
|
||
|
name string
|
||
|
position fixedpoint.Value
|
||
|
resultPosition fixedpoint.Value
|
||
|
}{
|
||
|
{
|
||
|
name: "PostivePositionWithinLimit",
|
||
|
position: fixedpoint.NewFromInt(8),
|
||
|
resultPosition: fixedpoint.NewFromInt(8),
|
||
|
},
|
||
|
{
|
||
|
name: "NegativePositionWithinLimit",
|
||
|
position: fixedpoint.NewFromInt(-8),
|
||
|
resultPosition: fixedpoint.NewFromInt(-8),
|
||
|
},
|
||
|
{
|
||
|
name: "PostivePositionOverLimit",
|
||
|
position: fixedpoint.NewFromInt(11),
|
||
|
resultPosition: fixedpoint.NewFromInt(10),
|
||
|
},
|
||
|
{
|
||
|
name: "NegativePositionOverLimit",
|
||
|
position: fixedpoint.NewFromInt(-11),
|
||
|
resultPosition: fixedpoint.NewFromInt(-10),
|
||
|
},
|
||
|
}
|
||
|
for _, tc := range cases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
position = tc.position
|
||
|
tradeCollector.EmitPositionUpdate(&types.Position{Base: tc.position})
|
||
|
assert.Equal(t, tc.resultPosition, position)
|
||
|
})
|
||
|
}
|
||
|
}
|