Merge pull request #664 from c9s/fix/save-state-on-change

fix: use the correct id for state loading
This commit is contained in:
Yo-An Lin 2022-06-03 03:16:32 +08:00 committed by GitHub
commit c8055e9278
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 84 additions and 49 deletions

View File

@ -88,11 +88,14 @@ func (p *Persistence) Sync(obj interface{}) error {
func loadPersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error {
return iterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error {
log.Debugf("[loadPersistenceFields] loading value into field %v, tag = %s, original value = %v", field, tag, value)
newValueInf := newTypeValueInterface(value.Type())
// inf := value.Interface()
store := persistence.NewStore("state", id, tag)
if err := store.Load(&newValueInf); err != nil {
if err == service.ErrPersistenceNotExists {
log.Debugf("[loadPersistenceFields] state key does not exist, id = %v, tag = %s", id, tag)
return nil
}
@ -104,7 +107,7 @@ func loadPersistenceFields(obj interface{}, id string, persistence service.Persi
newValue = newValue.Elem()
}
// log.Debugf("%v = %v (%s) -> %v (%s)\n", field, value, value.Type(), newValue, newValue.Type())
log.Debugf("[loadPersistenceFields] %v = %v (%s) -> %v (%s)\n", field, value, value.Type(), newValue, newValue.Type())
value.Set(newValue)
return nil
@ -113,8 +116,9 @@ func loadPersistenceFields(obj interface{}, id string, persistence service.Persi
func storePersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error {
return iterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error {
inf := fv.Interface()
log.Debugf("[storePersistenceFields] storing value from field %v, tag = %s, original value = %v", ft, tag, fv)
inf := fv.Interface()
store := persistence.NewStore("state", id, tag)
return store.Save(inf)
})

View File

@ -68,6 +68,23 @@ func Test_loadPersistenceFields(t *testing.T) {
err := loadPersistenceFields(b, "test-nil", ps)
assert.Equal(t, errCanNotIterateNilPointer, err)
})
t.Run(psName+"/pointer-field", func(t *testing.T) {
var a = &TestStruct{
Position: types.NewPosition("BTCUSDT", "BTC", "USDT"),
}
a.Position.Base = fixedpoint.NewFromFloat(10.0)
a.Position.AverageCost = fixedpoint.NewFromFloat(3343.0)
err := storePersistenceFields(a, "pointer-field-test", ps)
assert.NoError(t, err)
b := &TestStruct{}
err = loadPersistenceFields(b, "pointer-field-test", ps)
assert.NoError(t, err)
assert.Equal(t, "10", a.Position.Base.String())
assert.Equal(t, "3343", a.Position.AverageCost.String())
})
}
}

View File

@ -352,7 +352,8 @@ func (trader *Trader) LoadState() error {
log.Infof("loading strategies states...")
return trader.IterateStrategies(func(strategy StrategyID) error {
return loadPersistenceFields(strategy, strategy.ID(), ps)
id := callID(strategy)
return loadPersistenceFields(strategy, id, ps)
})
}

View File

@ -8,6 +8,7 @@ import (
"strings"
"github.com/go-redis/redis/v8"
log "github.com/sirupsen/logrus"
)
type RedisPersistenceService struct {
@ -49,9 +50,11 @@ func (store *RedisStore) Load(val interface{}) error {
return errors.New("can not load from redis, possible cause: redis persistence is not configured, or you are trying to use redis in back-test")
}
cmd := store.redis.Get(context.Background(), store.ID)
data, err := cmd.Result()
log.Debugf("[redis] get key %q, data = %s", store.ID, string(data))
if err != nil {
if err == redis.Nil {
return ErrPersistenceNotExists
@ -75,6 +78,9 @@ func (store *RedisStore) Save(val interface{}) error {
cmd := store.redis.Set(context.Background(), store.ID, data, 0)
_, err = cmd.Result()
log.Debugf("[redis] set key %q, data = %s", store.ID, string(data))
return err
}

View File

@ -407,7 +407,7 @@ func (s *Strategy) placeOrders(ctx context.Context, orderExecutor bbgo.OrderExec
}
}
trend := s.detectPriceTrend(s.neutralBoll, midPrice.Float64())
trend := detectPriceTrend(s.neutralBoll, midPrice.Float64())
switch trend {
case NeutralTrend:
// do nothing
@ -477,7 +477,7 @@ func (s *Strategy) placeOrders(ctx context.Context, orderExecutor bbgo.OrderExec
}
for i := range submitOrders {
submitOrders[i] = s.adjustOrderQuantity(submitOrders[i])
submitOrders[i] = adjustOrderQuantity(submitOrders[i], s.Market)
}
createdOrders, err := orderExecutor.SubmitOrders(ctx, submitOrders...)
@ -496,43 +496,6 @@ func (s *Strategy) hasShortSet() bool {
return s.Short != nil && *s.Short
}
type PriceTrend string
const (
NeutralTrend PriceTrend = "neutral"
UpTrend PriceTrend = "upTrend"
DownTrend PriceTrend = "downTrend"
UnknownTrend PriceTrend = "unknown"
)
func (s *Strategy) detectPriceTrend(inc *indicator.BOLL, price float64) PriceTrend {
if inBetween(price, inc.LastDownBand(), inc.LastUpBand()) {
return NeutralTrend
}
if price < inc.LastDownBand() {
return DownTrend
}
if price > inc.LastUpBand() {
return UpTrend
}
return UnknownTrend
}
func (s *Strategy) adjustOrderQuantity(submitOrder types.SubmitOrder) types.SubmitOrder {
if submitOrder.Quantity.Mul(submitOrder.Price).Compare(s.Market.MinNotional) < 0 {
submitOrder.Quantity = bbgo.AdjustFloatQuantityByMinAmount(submitOrder.Quantity, submitOrder.Price, s.Market.MinNotional.Mul(notionModifier))
}
if submitOrder.Quantity.Compare(s.Market.MinQuantity) < 0 {
submitOrder.Quantity = fixedpoint.Max(submitOrder.Quantity, s.Market.MinQuantity)
}
return submitOrder
}
func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error {
// StrategyController
s.Status = types.StrategyStatusRunning
@ -595,15 +558,15 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se
instanceID := s.InstanceID()
s.groupID = util.FNV32(instanceID)
// restore state
if err := s.LoadState(); err != nil {
return err
}
// If position is nil, we need to allocate a new position for calculation
if s.Position == nil {
// restore state (legacy)
if err := s.LoadState(); err != nil {
return err
}
// fallback to the legacy position struct in the state
if s.state != nil && s.state.Position != nil {
if s.state != nil && s.state.Position != nil && !s.state.Position.Base.IsZero() {
s.Position = s.state.Position
} else {
s.Position = types.NewPositionFromMarket(s.Market)
@ -669,6 +632,10 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se
s.tradeCollector.OnPositionUpdate(func(position *types.Position) {
log.Infof("position changed: %s", s.Position)
s.Notify(s.Position)
if err := s.Persistence.Sync(s); err != nil {
log.WithError(err).Errorf("can not sync state to persistence")
}
})
s.tradeCollector.BindStream(session.UserDataStream)
@ -774,3 +741,15 @@ func calculateBandPercentage(up, down, sma, midPrice float64) float64 {
func inBetween(x, a, b float64) bool {
return a < x && x < b
}
func adjustOrderQuantity(submitOrder types.SubmitOrder, market types.Market) types.SubmitOrder {
if submitOrder.Quantity.Mul(submitOrder.Price).Compare(market.MinNotional) < 0 {
submitOrder.Quantity = bbgo.AdjustFloatQuantityByMinAmount(submitOrder.Quantity, submitOrder.Price, market.MinNotional.Mul(notionModifier))
}
if submitOrder.Quantity.Compare(market.MinQuantity) < 0 {
submitOrder.Quantity = fixedpoint.Max(submitOrder.Quantity, market.MinQuantity)
}
return submitOrder
}

View File

@ -0,0 +1,28 @@
package bollmaker
import "github.com/c9s/bbgo/pkg/indicator"
type PriceTrend string
const (
NeutralTrend PriceTrend = "neutral"
UpTrend PriceTrend = "upTrend"
DownTrend PriceTrend = "downTrend"
UnknownTrend PriceTrend = "unknown"
)
func detectPriceTrend(inc *indicator.BOLL, price float64) PriceTrend {
if inBetween(price, inc.LastDownBand(), inc.LastUpBand()) {
return NeutralTrend
}
if price < inc.LastDownBand() {
return DownTrend
}
if price > inc.LastUpBand() {
return UpTrend
}
return UnknownTrend
}