mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-22 14:55:16 +00:00
implement reflect-based persistence restore and load
This commit is contained in:
parent
18eab1fbd3
commit
21f81dec29
|
@ -7,24 +7,6 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isSymbolBasedStrategy(rs reflect.Value) (string, bool) {
|
|
||||||
field := rs.FieldByName("Symbol")
|
|
||||||
if !field.IsValid() {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.Kind() != reflect.String {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
return field.String(), true
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasField(rs reflect.Value, fieldName string) (field reflect.Value, ok bool) {
|
|
||||||
field = rs.FieldByName(fieldName)
|
|
||||||
return field, field.IsValid()
|
|
||||||
}
|
|
||||||
|
|
||||||
func injectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnly bool) error {
|
func injectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnly bool) error {
|
||||||
field := rs.FieldByName(fieldName)
|
field := rs.FieldByName(fieldName)
|
||||||
if !field.IsValid() {
|
if !field.IsValid() {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package bbgo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
@ -74,3 +75,100 @@ func (p *Persistence) Save(val interface{}, subIDs ...string) error {
|
||||||
store := ps.NewStore(p.PersistenceSelector.StoreID, subIDs...)
|
store := ps.NewStore(p.PersistenceSelector.StoreID, subIDs...)
|
||||||
return store.Save(val)
|
return store.Save(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func iterateFieldsByTag(obj interface{}, tagName string, cb func(tag string, ft reflect.StructField, fv reflect.Value) error) error {
|
||||||
|
sv := reflect.ValueOf(obj)
|
||||||
|
st := reflect.TypeOf(obj)
|
||||||
|
|
||||||
|
if st.Kind() != reflect.Ptr {
|
||||||
|
return fmt.Errorf("f needs to be a pointer of a struct, %s given", st)
|
||||||
|
}
|
||||||
|
|
||||||
|
// solve the reference
|
||||||
|
st = st.Elem()
|
||||||
|
sv = sv.Elem()
|
||||||
|
|
||||||
|
if st.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("f needs to be a struct, %s given", st)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < sv.NumField(); i++ {
|
||||||
|
fv := sv.Field(i)
|
||||||
|
ft := st.Field(i)
|
||||||
|
|
||||||
|
fvt := fv.Type()
|
||||||
|
_ = fvt
|
||||||
|
|
||||||
|
// skip unexported fields
|
||||||
|
if !st.Field(i).IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tag, ok := ft.Tag.Lookup(tagName)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cb(tag, ft, fv); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/xiaojun207/go-base-utils/blob/master/utils/Clone.go
|
||||||
|
func newTypeValueInterface(typ reflect.Type) interface{} {
|
||||||
|
if typ.Kind() == reflect.Ptr {
|
||||||
|
typ = typ.Elem()
|
||||||
|
dst := reflect.New(typ).Elem()
|
||||||
|
return dst.Addr().Interface()
|
||||||
|
} else {
|
||||||
|
dst := reflect.New(typ)
|
||||||
|
return dst.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadPersistenceFields(obj interface{}, persistence service.PersistenceService) error {
|
||||||
|
id := callID(obj)
|
||||||
|
if len(id) == 0 {
|
||||||
|
return fmt.Errorf("object does not provide ID() method for persistency")
|
||||||
|
}
|
||||||
|
|
||||||
|
return iterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error {
|
||||||
|
newValueInf := newTypeValueInterface(value.Type())
|
||||||
|
// inf := value.Interface()
|
||||||
|
store := persistence.NewStore(id, tag)
|
||||||
|
if err := store.Load(&newValueInf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newValue := reflect.ValueOf(newValueInf)
|
||||||
|
if value.Kind() != reflect.Ptr && newValue.Kind() == reflect.Ptr {
|
||||||
|
newValue = newValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// log.Debugf("%v = %v (%s) -> %v (%s)\n", field, value, value.Type(), newValue, newValue.Type())
|
||||||
|
|
||||||
|
value.Set(newValue)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func storePersistenceFields(obj interface{}, persistence service.PersistenceService) error {
|
||||||
|
id := callID(obj)
|
||||||
|
if len(id) == 0 {
|
||||||
|
return fmt.Errorf("object does not provide ID() method for persistency")
|
||||||
|
}
|
||||||
|
|
||||||
|
return iterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error {
|
||||||
|
inf := fv.Interface()
|
||||||
|
|
||||||
|
store := persistence.NewStore(id, tag)
|
||||||
|
if err := store.Save(inf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
84
pkg/bbgo/persistence_test.go
Normal file
84
pkg/bbgo/persistence_test.go
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||||||
|
"github.com/c9s/bbgo/pkg/service"
|
||||||
|
"github.com/c9s/bbgo/pkg/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TestStruct struct {
|
||||||
|
Position *types.Position `persistence:"position"`
|
||||||
|
Integer int64 `persistence:"integer"`
|
||||||
|
Integer2 int64 `persistence:"integer2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TestStruct) InstanceID() string {
|
||||||
|
return "test-struct"
|
||||||
|
}
|
||||||
|
|
||||||
|
func preparePersistentServices() []service.PersistenceService {
|
||||||
|
mem := service.NewMemoryService()
|
||||||
|
jsonDir := &service.JsonPersistenceService{Directory: "testoutput/persistence"}
|
||||||
|
pss := []service.PersistenceService{
|
||||||
|
mem,
|
||||||
|
jsonDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := os.LookupEnv("TEST_REDIS"); ok {
|
||||||
|
redisP := service.NewRedisPersistenceService(&service.RedisPersistenceConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: "6379",
|
||||||
|
DB: 0,
|
||||||
|
})
|
||||||
|
pss = append(pss, redisP)
|
||||||
|
}
|
||||||
|
|
||||||
|
return pss
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_storePersistenceFields(t *testing.T) {
|
||||||
|
var pss = preparePersistentServices()
|
||||||
|
|
||||||
|
var a = &TestStruct{
|
||||||
|
Integer: 1,
|
||||||
|
Integer2: 2,
|
||||||
|
Position: types.NewPosition("BTCUSDT", "BTC", "USDT"),
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Position.Base = fixedpoint.NewFromFloat(10.0)
|
||||||
|
a.Position.AverageCost = fixedpoint.NewFromFloat(3343.0)
|
||||||
|
|
||||||
|
for _, ps := range pss {
|
||||||
|
t.Run(reflect.TypeOf(ps).Elem().String(), func(t *testing.T) {
|
||||||
|
err := storePersistenceFields(a, ps)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var i int64
|
||||||
|
store := ps.NewStore("test-struct", "integer")
|
||||||
|
err = store.Load(&i)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), i)
|
||||||
|
|
||||||
|
var p *types.Position
|
||||||
|
store = ps.NewStore("test-struct", "position")
|
||||||
|
err = store.Load(&p)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, fixedpoint.NewFromFloat(10.0), p.Base)
|
||||||
|
assert.Equal(t, fixedpoint.NewFromFloat(3343.0), p.AverageCost)
|
||||||
|
|
||||||
|
var b = &TestStruct{}
|
||||||
|
err = loadPersistenceFields(b, ps)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, a.Integer, b.Integer)
|
||||||
|
assert.Equal(t, a.Integer2, b.Integer2)
|
||||||
|
assert.Equal(t, a.Position, b.Position)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
39
pkg/bbgo/reflect.go
Normal file
39
pkg/bbgo/reflect.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
package bbgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InstanceIDProvider interface{
|
||||||
|
InstanceID() string
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func callID(obj interface{}) string {
|
||||||
|
sv := reflect.ValueOf(obj)
|
||||||
|
st := reflect.TypeOf(obj)
|
||||||
|
if st.Implements(reflect.TypeOf((*InstanceIDProvider)(nil)).Elem()) {
|
||||||
|
m := sv.MethodByName("InstanceID")
|
||||||
|
ret := m.Call(nil)
|
||||||
|
return ret[0].String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSymbolBasedStrategy(rs reflect.Value) (string, bool) {
|
||||||
|
field := rs.FieldByName("Symbol")
|
||||||
|
if !field.IsValid() {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.Kind() != reflect.String {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return field.String(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasField(rs reflect.Value, fieldName string) (field reflect.Value, ok bool) {
|
||||||
|
field = rs.FieldByName(fieldName)
|
||||||
|
return field, field.IsValid()
|
||||||
|
}
|
|
@ -248,6 +248,12 @@ func (trader *Trader) RunSingleExchangeStrategy(ctx context.Context, strategy Si
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Before we run the strategy we need to load the state from the persistence layer:
|
||||||
|
// 1) scan the struct fields and find the persistence field
|
||||||
|
// 2) load the data and set the value into the persistence field.
|
||||||
|
|
||||||
|
_ = trader.environment.PersistenceServiceFacade
|
||||||
|
|
||||||
return strategy.Run(ctx, orderExecutor, session)
|
return strategy.Run(ctx, orderExecutor, session)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -36,7 +37,9 @@ func (store *MemoryStore) Save(val interface{}) error {
|
||||||
func (store *MemoryStore) Load(val interface{}) error {
|
func (store *MemoryStore) Load(val interface{}) error {
|
||||||
v := reflect.ValueOf(val)
|
v := reflect.ValueOf(val)
|
||||||
if data, ok := store.memory.Slots[store.Key]; ok {
|
if data, ok := store.memory.Slots[store.Key]; ok {
|
||||||
v.Elem().Set(reflect.ValueOf(data).Elem())
|
dataRV := reflect.ValueOf(data)
|
||||||
|
fmt.Printf("load %s = %v\n", store.Key, dataRV)
|
||||||
|
v.Elem().Set(dataRV)
|
||||||
} else {
|
} else {
|
||||||
return ErrPersistenceNotExists
|
return ErrPersistenceNotExists
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,7 @@ func (store JsonStore) Load(val interface{}) error {
|
||||||
|
|
||||||
func (store JsonStore) Save(val interface{}) error {
|
func (store JsonStore) Save(val interface{}) error {
|
||||||
if _, err := os.Stat(store.Directory); os.IsNotExist(err) {
|
if _, err := os.Stat(store.Directory); os.IsNotExist(err) {
|
||||||
if err2 := os.Mkdir(store.Directory, 0777); err2 != nil {
|
if err2 := os.MkdirAll(store.Directory, 0777); err2 != nil {
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -566,6 +566,10 @@ func (s *Strategy) adjustOrderQuantity(submitOrder types.SubmitOrder) types.Subm
|
||||||
return submitOrder
|
return submitOrder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Strategy) InstanceID() string {
|
||||||
|
return fmt.Sprintf("%s-%s", ID, s.Symbol)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error {
|
func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error {
|
||||||
// StrategyController
|
// StrategyController
|
||||||
s.status = types.StrategyStatusRunning
|
s.status = types.StrategyStatusRunning
|
||||||
|
@ -596,7 +600,7 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se
|
||||||
s.defaultBoll = s.StandardIndicatorSet.BOLL(s.DefaultBollinger.IntervalWindow, s.DefaultBollinger.BandWidth)
|
s.defaultBoll = s.StandardIndicatorSet.BOLL(s.DefaultBollinger.IntervalWindow, s.DefaultBollinger.BandWidth)
|
||||||
|
|
||||||
// calculate group id for orders
|
// calculate group id for orders
|
||||||
instanceID := fmt.Sprintf("%s-%s", ID, s.Symbol)
|
instanceID := s.InstanceID()
|
||||||
s.groupID = max.GenerateGroupID(instanceID)
|
s.groupID = max.GenerateGroupID(instanceID)
|
||||||
log.Infof("using group id %d from fnv(%s)", s.groupID, instanceID)
|
log.Infof("using group id %d from fnv(%s)", s.groupID, instanceID)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user