Merge pull request #517 from c9s/narumi/rebalance-float-slice

strategy: rebalance: use float64slice to avoid map[string]fixedpoint.Value
This commit is contained in:
なるみ 2022-04-11 23:46:53 +08:00 committed by GitHub
commit 4ed753bad6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 94 deletions

View File

@ -37,8 +37,8 @@ func (inc *RSI) Update(kline types.KLine, priceF KLinePriceMapper) {
if len(inc.Prices) == inc.Window+1 {
priceDifferences := inc.Prices.Diff()
avgGain = priceDifferences.PositiveValuesOrZero().AbsoluteValues().Sum() / float64(inc.Window)
avgLoss = priceDifferences.NegativeValuesOrZero().AbsoluteValues().Sum() / float64(inc.Window)
avgGain = priceDifferences.PositiveValuesOrZero().Abs().Sum() / float64(inc.Window)
avgLoss = priceDifferences.NegativeValuesOrZero().Abs().Sum() / float64(inc.Window)
} else {
difference := price - inc.Prices[len(inc.Prices)-2]
currentGain := math.Max(difference, 0)

View File

@ -3,6 +3,8 @@ package rebalance
import (
"context"
"fmt"
"math"
"sort"
"github.com/sirupsen/logrus"
@ -19,35 +21,6 @@ func init() {
bbgo.RegisterStrategy(ID, &Strategy{})
}
func Sum(m map[string]fixedpoint.Value) fixedpoint.Value {
sum := fixedpoint.NewFromFloat(0.0)
for _, v := range m {
sum = sum.Add(v)
}
return sum
}
func Normalize(m map[string]fixedpoint.Value) map[string]fixedpoint.Value {
sum := Sum(m)
if sum.Float64() == 1.0 {
return m
}
normalized := make(map[string]fixedpoint.Value)
for k, v := range m {
normalized[k] = v.Div(sum)
}
return normalized
}
func ElementwiseProduct(m1, m2 map[string]fixedpoint.Value) map[string]fixedpoint.Value {
m := make(map[string]fixedpoint.Value)
for k, v := range m1 {
m[k] = v.Mul(m2[k])
}
return m
}
type Strategy struct {
Notifiability *bbgo.Notifiability
@ -60,6 +33,17 @@ type Strategy struct {
DryRun bool `json:"dryRun"`
// max amount to buy or sell per order
MaxAmount fixedpoint.Value `json:"maxAmount"`
currencies []string
}
func (s *Strategy) Initialize() error {
for currency := range s.TargetWeights {
s.currencies = append(s.currencies, currency)
}
sort.Strings(s.currencies)
return nil
}
func (s *Strategy) ID() string {
@ -95,7 +79,6 @@ func (s *Strategy) Subscribe(session *bbgo.ExchangeSession) {
}
func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error {
s.TargetWeights = Normalize(s.TargetWeights)
session.MarketDataStream.OnKLineClosed(func(kline types.KLine) {
s.rebalance(ctx, orderExecutor, session)
})
@ -110,7 +93,7 @@ func (s *Strategy) rebalance(ctx context.Context, orderExecutor bbgo.OrderExecut
balances := session.Account.Balances()
quantities := s.getQuantities(balances)
marketValues := ElementwiseProduct(prices, quantities)
marketValues := prices.Mul(quantities)
orders := s.generateSubmitOrders(prices, marketValues)
for _, order := range orders {
@ -128,12 +111,10 @@ func (s *Strategy) rebalance(ctx context.Context, orderExecutor bbgo.OrderExecut
}
}
func (s *Strategy) getPrices(ctx context.Context, session *bbgo.ExchangeSession) (map[string]fixedpoint.Value, error) {
prices := make(map[string]fixedpoint.Value)
for currency := range s.TargetWeights {
func (s *Strategy) getPrices(ctx context.Context, session *bbgo.ExchangeSession) (prices types.Float64Slice, err error) {
for _, currency := range s.currencies {
if currency == s.BaseCurrency {
prices[currency] = fixedpoint.One
prices = append(prices, 1.0)
continue
}
@ -145,38 +126,38 @@ func (s *Strategy) getPrices(ctx context.Context, session *bbgo.ExchangeSession)
return prices, err
}
prices[currency] = ticker.Last
prices = append(prices, ticker.Last.Float64())
}
return prices, nil
}
func (s *Strategy) getQuantities(balances types.BalanceMap) map[string]fixedpoint.Value {
quantities := make(map[string]fixedpoint.Value)
for currency := range s.TargetWeights {
func (s *Strategy) getQuantities(balances types.BalanceMap) (quantities types.Float64Slice) {
for _, currency := range s.currencies {
if s.IgnoreLocked {
quantities[currency] = balances[currency].Total()
quantities = append(quantities, balances[currency].Total().Float64())
} else {
quantities[currency] = balances[currency].Available
quantities = append(quantities, balances[currency].Available.Float64())
}
}
return quantities
}
func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoint.Value) []types.SubmitOrder {
var submitOrders []types.SubmitOrder
func (s *Strategy) generateSubmitOrders(prices, marketValues types.Float64Slice) (submitOrders []types.SubmitOrder) {
currentWeights := marketValues.Normalize()
totalValue := marketValues.Sum()
currentWeights := Normalize(marketValues)
totalValue := Sum(marketValues)
log.Infof("total value: %f", totalValue)
log.Infof("total value: %f", totalValue.Float64())
for currency, targetWeight := range s.TargetWeights {
for i, currency := range s.currencies {
if currency == s.BaseCurrency {
continue
}
symbol := currency + s.BaseCurrency
currentWeight := currentWeights[currency]
currentPrice := prices[currency]
currentWeight := currentWeights[i]
currentPrice := prices[i]
targetWeight := s.TargetWeights[currency].Float64()
log.Infof("%s price: %v, current weight: %v, target weight: %v",
symbol,
currentPrice,
@ -185,8 +166,8 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
// calculate the difference between current weight and target weight
// if the difference is less than threshold, then we will not create the order
weightDifference := targetWeight.Sub(currentWeight)
if weightDifference.Abs().Compare(s.Threshold) < 0 {
weightDifference := targetWeight - currentWeight
if math.Abs(weightDifference) < s.Threshold.Float64() {
log.Infof("%s weight distance |%v - %v| = |%v| less than the threshold: %v",
symbol,
currentWeight,
@ -196,7 +177,7 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
continue
}
quantity := weightDifference.Mul(totalValue).Div(currentPrice)
quantity := fixedpoint.NewFromFloat((weightDifference * totalValue) / currentPrice)
side := types.SideTypeBuy
if quantity.Sign() < 0 {
@ -205,7 +186,7 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
}
if s.MaxAmount.Sign() > 0 {
quantity = bbgo.AdjustQuantityByMaxAmount(quantity, currentPrice, s.MaxAmount)
quantity = bbgo.AdjustQuantityByMaxAmount(quantity, fixedpoint.NewFromFloat(currentPrice), s.MaxAmount)
log.Infof("adjust the quantity %v (%s %s @ %v) by max amount %v",
quantity,
symbol,
@ -213,7 +194,7 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
currentPrice,
s.MaxAmount)
}
log.Debugf("symbol: %v, quantity: %v", symbol, quantity)
order := types.SubmitOrder{
Symbol: symbol,
Side: side,
@ -225,14 +206,12 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
return submitOrders
}
func (s *Strategy) getSymbols() []string {
var symbols []string
for currency := range s.TargetWeights {
func (s *Strategy) getSymbols() (symbols []string) {
for _, currency := range s.currencies {
if currency == s.BaseCurrency {
continue
}
symbol := currency + s.BaseCurrency
symbols = append(symbols, symbol)
symbols = append(symbols, currency+s.BaseCurrency)
}
return symbols
}

View File

@ -1,6 +1,10 @@
package types
import "math"
import (
"math"
"gonum.org/v1/gonum/floats"
)
type Float64Slice []float64
@ -15,30 +19,23 @@ func (s *Float64Slice) Pop(i int64) (v float64) {
}
func (s Float64Slice) Max() float64 {
m := -math.MaxFloat64
for _, v := range s {
m = math.Max(m, v)
}
return m
return floats.Max(s)
}
func (s Float64Slice) Min() float64 {
m := math.MaxFloat64
for _, v := range s {
m = math.Min(m, v)
}
return m
return floats.Min(s)
}
func (s Float64Slice) Sum() (sum float64) {
for _, v := range s {
sum += v
}
return sum
return floats.Sum(s)
}
func (s Float64Slice) Mean() (mean float64) {
return s.Sum() / float64(len(s))
length := len(s)
if length == 0 {
panic("zero length slice")
}
return s.Sum() / float64(length)
}
func (s Float64Slice) Tail(size int) Float64Slice {
@ -54,8 +51,7 @@ func (s Float64Slice) Tail(size int) Float64Slice {
return win
}
func (s Float64Slice) Diff() Float64Slice {
var values Float64Slice
func (s Float64Slice) Diff() (values Float64Slice) {
for i, v := range s {
if i == 0 {
values.Push(0)
@ -66,54 +62,57 @@ func (s Float64Slice) Diff() Float64Slice {
return values
}
func (s Float64Slice) PositiveValuesOrZero() Float64Slice {
var values Float64Slice
func (s Float64Slice) PositiveValuesOrZero() (values Float64Slice) {
for _, v := range s {
values.Push(math.Max(v, 0))
}
return values
}
func (s Float64Slice) NegativeValuesOrZero() Float64Slice {
var values Float64Slice
func (s Float64Slice) NegativeValuesOrZero() (values Float64Slice) {
for _, v := range s {
values.Push(math.Min(v, 0))
}
return values
}
func (s Float64Slice) AbsoluteValues() Float64Slice {
var values Float64Slice
func (s Float64Slice) Abs() (values Float64Slice) {
for _, v := range s {
values.Push(math.Abs(v))
}
return values
}
func (s Float64Slice) MulScalar(x float64) Float64Slice {
var values Float64Slice
func (s Float64Slice) MulScalar(x float64) (values Float64Slice) {
for _, v := range s {
values.Push(v * x)
}
return values
}
func (s Float64Slice) DivScalar(x float64) Float64Slice {
var values Float64Slice
func (s Float64Slice) DivScalar(x float64) (values Float64Slice) {
for _, v := range s {
values.Push(v / x)
}
return values
}
func (s Float64Slice) ElementwiseProduct(other Float64Slice) Float64Slice {
var values Float64Slice
func (s Float64Slice) Mul(other Float64Slice) (values Float64Slice) {
if len(s) != len(other) {
panic("slice lengths do not match")
}
for i, v := range s {
values.Push(v * other[i])
}
return values
}
func (s Float64Slice) Dot(other Float64Slice) float64 {
return s.ElementwiseProduct(other).Sum()
return floats.Dot(s, other)
}
func (s Float64Slice) Normalize() Float64Slice {
return s.DivScalar(s.Sum())
}