Avoid to use map[string]fixedpoint.Value

This commit is contained in:
なるみ 2022-04-07 16:00:30 +08:00
parent 4df818e8e1
commit 859933d4ed
2 changed files with 70 additions and 92 deletions

View File

@ -3,6 +3,8 @@ package rebalance
import ( import (
"context" "context"
"fmt" "fmt"
"math"
"sort"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -19,35 +21,6 @@ func init() {
bbgo.RegisterStrategy(ID, &Strategy{}) bbgo.RegisterStrategy(ID, &Strategy{})
} }
func Sum(m map[string]fixedpoint.Value) fixedpoint.Value {
sum := fixedpoint.NewFromFloat(0.0)
for _, v := range m {
sum = sum.Add(v)
}
return sum
}
func Normalize(m map[string]fixedpoint.Value) map[string]fixedpoint.Value {
sum := Sum(m)
if sum.Float64() == 1.0 {
return m
}
normalized := make(map[string]fixedpoint.Value)
for k, v := range m {
normalized[k] = v.Div(sum)
}
return normalized
}
func ElementwiseProduct(m1, m2 map[string]fixedpoint.Value) map[string]fixedpoint.Value {
m := make(map[string]fixedpoint.Value)
for k, v := range m1 {
m[k] = v.Mul(m2[k])
}
return m
}
type Strategy struct { type Strategy struct {
Notifiability *bbgo.Notifiability Notifiability *bbgo.Notifiability
@ -60,6 +33,17 @@ type Strategy struct {
DryRun bool `json:"dryRun"` DryRun bool `json:"dryRun"`
// max amount to buy or sell per order // max amount to buy or sell per order
MaxAmount fixedpoint.Value `json:"maxAmount"` MaxAmount fixedpoint.Value `json:"maxAmount"`
currencies []string
}
func (s *Strategy) Initialize() error {
for currency := range s.TargetWeights {
s.currencies = append(s.currencies, currency)
}
sort.Strings(s.currencies)
return nil
} }
func (s *Strategy) ID() string { func (s *Strategy) ID() string {
@ -95,7 +79,6 @@ func (s *Strategy) Subscribe(session *bbgo.ExchangeSession) {
} }
func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error { func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error {
s.TargetWeights = Normalize(s.TargetWeights)
session.MarketDataStream.OnKLineClosed(func(kline types.KLine) { session.MarketDataStream.OnKLineClosed(func(kline types.KLine) {
s.rebalance(ctx, orderExecutor, session) s.rebalance(ctx, orderExecutor, session)
}) })
@ -110,7 +93,7 @@ func (s *Strategy) rebalance(ctx context.Context, orderExecutor bbgo.OrderExecut
balances := session.Account.Balances() balances := session.Account.Balances()
quantities := s.getQuantities(balances) quantities := s.getQuantities(balances)
marketValues := ElementwiseProduct(prices, quantities) marketValues := prices.Mul(quantities)
orders := s.generateSubmitOrders(prices, marketValues) orders := s.generateSubmitOrders(prices, marketValues)
for _, order := range orders { for _, order := range orders {
@ -128,12 +111,10 @@ func (s *Strategy) rebalance(ctx context.Context, orderExecutor bbgo.OrderExecut
} }
} }
func (s *Strategy) getPrices(ctx context.Context, session *bbgo.ExchangeSession) (map[string]fixedpoint.Value, error) { func (s *Strategy) getPrices(ctx context.Context, session *bbgo.ExchangeSession) (prices types.Float64Slice, err error) {
prices := make(map[string]fixedpoint.Value) for _, currency := range s.currencies {
for currency := range s.TargetWeights {
if currency == s.BaseCurrency { if currency == s.BaseCurrency {
prices[currency] = fixedpoint.One prices = append(prices, 1.0)
continue continue
} }
@ -145,38 +126,38 @@ func (s *Strategy) getPrices(ctx context.Context, session *bbgo.ExchangeSession)
return prices, err return prices, err
} }
prices[currency] = ticker.Last prices = append(prices, ticker.Last.Float64())
} }
return prices, nil return prices, nil
} }
func (s *Strategy) getQuantities(balances types.BalanceMap) map[string]fixedpoint.Value { func (s *Strategy) getQuantities(balances types.BalanceMap) (quantities types.Float64Slice) {
quantities := make(map[string]fixedpoint.Value) for _, currency := range s.currencies {
for currency := range s.TargetWeights {
if s.IgnoreLocked { if s.IgnoreLocked {
quantities[currency] = balances[currency].Total() quantities = append(quantities, balances[currency].Total().Float64())
} else { } else {
quantities[currency] = balances[currency].Available quantities = append(quantities, balances[currency].Available.Float64())
} }
} }
return quantities return quantities
} }
func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoint.Value) []types.SubmitOrder { func (s *Strategy) generateSubmitOrders(prices, marketValues types.Float64Slice) (submitOrders []types.SubmitOrder) {
var submitOrders []types.SubmitOrder currentWeights := marketValues.Normalize()
totalValue := marketValues.Sum()
currentWeights := Normalize(marketValues) log.Infof("total value: %f", totalValue)
totalValue := Sum(marketValues)
log.Infof("total value: %f", totalValue.Float64()) for i, currency := range s.currencies {
for currency, targetWeight := range s.TargetWeights {
if currency == s.BaseCurrency { if currency == s.BaseCurrency {
continue continue
} }
symbol := currency + s.BaseCurrency symbol := currency + s.BaseCurrency
currentWeight := currentWeights[currency] currentWeight := currentWeights[i]
currentPrice := prices[currency] currentPrice := prices[i]
targetWeight := s.TargetWeights[currency].Float64()
log.Infof("%s price: %v, current weight: %v, target weight: %v", log.Infof("%s price: %v, current weight: %v, target weight: %v",
symbol, symbol,
currentPrice, currentPrice,
@ -185,8 +166,8 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
// calculate the difference between current weight and target weight // calculate the difference between current weight and target weight
// if the difference is less than threshold, then we will not create the order // if the difference is less than threshold, then we will not create the order
weightDifference := targetWeight.Sub(currentWeight) weightDifference := targetWeight - currentWeight
if weightDifference.Abs().Compare(s.Threshold) < 0 { if math.Abs(weightDifference) < s.Threshold.Float64() {
log.Infof("%s weight distance |%v - %v| = |%v| less than the threshold: %v", log.Infof("%s weight distance |%v - %v| = |%v| less than the threshold: %v",
symbol, symbol,
currentWeight, currentWeight,
@ -196,7 +177,7 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
continue continue
} }
quantity := weightDifference.Mul(totalValue).Div(currentPrice) quantity := fixedpoint.NewFromFloat((weightDifference * totalValue) / currentPrice)
side := types.SideTypeBuy side := types.SideTypeBuy
if quantity.Sign() < 0 { if quantity.Sign() < 0 {
@ -205,7 +186,7 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
} }
if s.MaxAmount.Sign() > 0 { if s.MaxAmount.Sign() > 0 {
quantity = bbgo.AdjustQuantityByMaxAmount(quantity, currentPrice, s.MaxAmount) quantity = bbgo.AdjustQuantityByMaxAmount(quantity, fixedpoint.NewFromFloat(currentPrice), s.MaxAmount)
log.Infof("adjust the quantity %v (%s %s @ %v) by max amount %v", log.Infof("adjust the quantity %v (%s %s @ %v) by max amount %v",
quantity, quantity,
symbol, symbol,
@ -213,7 +194,7 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
currentPrice, currentPrice,
s.MaxAmount) s.MaxAmount)
} }
log.Debugf("symbol: %v, quantity: %v", symbol, quantity)
order := types.SubmitOrder{ order := types.SubmitOrder{
Symbol: symbol, Symbol: symbol,
Side: side, Side: side,
@ -225,14 +206,12 @@ func (s *Strategy) generateSubmitOrders(prices, marketValues map[string]fixedpoi
return submitOrders return submitOrders
} }
func (s *Strategy) getSymbols() []string { func (s *Strategy) getSymbols() (symbols []string) {
var symbols []string for _, currency := range s.currencies {
for currency := range s.TargetWeights {
if currency == s.BaseCurrency { if currency == s.BaseCurrency {
continue continue
} }
symbol := currency + s.BaseCurrency symbols = append(symbols, currency+s.BaseCurrency)
symbols = append(symbols, symbol)
} }
return symbols return symbols
} }

View File

@ -1,6 +1,10 @@
package types package types
import "math" import (
"math"
"gonum.org/v1/gonum/floats"
)
type Float64Slice []float64 type Float64Slice []float64
@ -15,30 +19,23 @@ func (s *Float64Slice) Pop(i int64) (v float64) {
} }
func (s Float64Slice) Max() float64 { func (s Float64Slice) Max() float64 {
m := -math.MaxFloat64 return floats.Max(s)
for _, v := range s {
m = math.Max(m, v)
}
return m
} }
func (s Float64Slice) Min() float64 { func (s Float64Slice) Min() float64 {
m := math.MaxFloat64 return floats.Min(s)
for _, v := range s {
m = math.Min(m, v)
}
return m
} }
func (s Float64Slice) Sum() (sum float64) { func (s Float64Slice) Sum() (sum float64) {
for _, v := range s { return floats.Sum(s)
sum += v
}
return sum
} }
func (s Float64Slice) Mean() (mean float64) { func (s Float64Slice) Mean() (mean float64) {
return s.Sum() / float64(len(s)) length := len(s)
if length == 0 {
panic("zero length slice")
}
return s.Sum() / float64(length)
} }
func (s Float64Slice) Tail(size int) Float64Slice { func (s Float64Slice) Tail(size int) Float64Slice {
@ -54,8 +51,7 @@ func (s Float64Slice) Tail(size int) Float64Slice {
return win return win
} }
func (s Float64Slice) Diff() Float64Slice { func (s Float64Slice) Diff() (values Float64Slice) {
var values Float64Slice
for i, v := range s { for i, v := range s {
if i == 0 { if i == 0 {
values.Push(0) values.Push(0)
@ -66,54 +62,57 @@ func (s Float64Slice) Diff() Float64Slice {
return values return values
} }
func (s Float64Slice) PositiveValuesOrZero() Float64Slice { func (s Float64Slice) PositiveValuesOrZero() (values Float64Slice) {
var values Float64Slice
for _, v := range s { for _, v := range s {
values.Push(math.Max(v, 0)) values.Push(math.Max(v, 0))
} }
return values return values
} }
func (s Float64Slice) NegativeValuesOrZero() Float64Slice { func (s Float64Slice) NegativeValuesOrZero() (values Float64Slice) {
var values Float64Slice
for _, v := range s { for _, v := range s {
values.Push(math.Min(v, 0)) values.Push(math.Min(v, 0))
} }
return values return values
} }
func (s Float64Slice) AbsoluteValues() Float64Slice { func (s Float64Slice) AbsoluteValues() (values Float64Slice) {
var values Float64Slice
for _, v := range s { for _, v := range s {
values.Push(math.Abs(v)) values.Push(math.Abs(v))
} }
return values return values
} }
func (s Float64Slice) MulScalar(x float64) Float64Slice { func (s Float64Slice) MulScalar(x float64) (values Float64Slice) {
var values Float64Slice
for _, v := range s { for _, v := range s {
values.Push(v * x) values.Push(v * x)
} }
return values return values
} }
func (s Float64Slice) DivScalar(x float64) Float64Slice { func (s Float64Slice) DivScalar(x float64) (values Float64Slice) {
var values Float64Slice
for _, v := range s { for _, v := range s {
values.Push(v / x) values.Push(v / x)
} }
return values return values
} }
func (s Float64Slice) ElementwiseProduct(other Float64Slice) Float64Slice { func (s Float64Slice) Mul(other Float64Slice) (values Float64Slice) {
var values Float64Slice if len(s) != len(other) {
panic("slice lengths do not match")
}
for i, v := range s { for i, v := range s {
values.Push(v * other[i]) values.Push(v * other[i])
} }
return values return values
} }
func (s Float64Slice) Dot(other Float64Slice) float64 { func (s Float64Slice) Dot(other Float64Slice) float64 {
return s.ElementwiseProduct(other).Sum() return floats.Dot(s, other)
}
func (s Float64Slice) Normalize() (values Float64Slice) {
return s.DivScalar(s.Sum())
} }