From 55e9c7ee25c5954751fc5207eddc108dd94bd397 Mon Sep 17 00:00:00 2001 From: c9s Date: Fri, 3 Jun 2022 01:57:39 +0800 Subject: [PATCH] add more test on Test_loadPersistenceFields --- pkg/bbgo/persistence.go | 54 ---------------------------- pkg/bbgo/persistence_test.go | 25 ++++++++++++- pkg/bbgo/reflect.go | 63 +++++++++++++++++++++++++++++++++ pkg/service/persistence_json.go | 2 +- 4 files changed, 88 insertions(+), 56 deletions(-) diff --git a/pkg/bbgo/persistence.go b/pkg/bbgo/persistence.go index 3cc177463..9162c7381 100644 --- a/pkg/bbgo/persistence.go +++ b/pkg/bbgo/persistence.go @@ -86,60 +86,6 @@ func (p *Persistence) Sync(obj interface{}) error { return storePersistenceFields(obj, id, ps) } -type StructFieldIterator func(tag string, ft reflect.StructField, fv reflect.Value) error - -func iterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error { - sv := reflect.ValueOf(obj) - st := reflect.TypeOf(obj) - - if st.Kind() != reflect.Ptr { - return fmt.Errorf("f needs to be a pointer of a struct, %s given", st) - } - - // solve the reference - st = st.Elem() - sv = sv.Elem() - - if st.Kind() != reflect.Struct { - return fmt.Errorf("f needs to be a struct, %s given", st) - } - - for i := 0; i < sv.NumField(); i++ { - fv := sv.Field(i) - ft := st.Field(i) - - fvt := fv.Type() - _ = fvt - - // skip unexported fields - if !st.Field(i).IsExported() { - continue - } - - tag, ok := ft.Tag.Lookup(tagName) - if !ok { - continue - } - - if err := cb(tag, ft, fv); err != nil { - return err - } - } - - return nil -} - -// https://github.com/xiaojun207/go-base-utils/blob/master/utils/Clone.go -func newTypeValueInterface(typ reflect.Type) interface{} { - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - dst := reflect.New(typ).Elem() - return dst.Addr().Interface() - } - dst := reflect.New(typ) - return dst.Interface() -} - func loadPersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error { return iterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error { newValueInf := newTypeValueInterface(value.Type()) diff --git a/pkg/bbgo/persistence_test.go b/pkg/bbgo/persistence_test.go index 97916a172..7ae80b8fd 100644 --- a/pkg/bbgo/persistence_test.go +++ b/pkg/bbgo/persistence_test.go @@ -13,6 +13,9 @@ import ( ) type TestStruct struct { + *Environment + *Graceful + Position *types.Position `persistence:"position"` Integer int64 `persistence:"integer"` Integer2 int64 `persistence:"integer2"` @@ -49,6 +52,25 @@ func Test_callID(t *testing.T) { assert.NotEmpty(t, id) } +func Test_loadPersistenceFields(t *testing.T) { + var pss = preparePersistentServices() + + for _, ps := range pss { + psName := reflect.TypeOf(ps).Elem().String() + t.Run(psName+"/empty", func(t *testing.T) { + b := &TestStruct{} + err := loadPersistenceFields(b, "test-empty", ps) + assert.NoError(t, err) + }) + + t.Run(psName+"/nil", func(t *testing.T) { + var b *TestStruct = nil + err := loadPersistenceFields(b, "test-nil", ps) + assert.Equal(t, errCanNotIterateNilPointer, err) + }) + } +} + func Test_storePersistenceFields(t *testing.T) { var pss = preparePersistentServices() @@ -64,7 +86,8 @@ func Test_storePersistenceFields(t *testing.T) { a.Position.AverageCost = fixedpoint.NewFromFloat(3343.0) for _, ps := range pss { - t.Run(reflect.TypeOf(ps).Elem().String(), func(t *testing.T) { + psName := reflect.TypeOf(ps).Elem().String() + t.Run("all/"+psName, func(t *testing.T) { id := callID(a) err := storePersistenceFields(a, id, ps) assert.NoError(t, err) diff --git a/pkg/bbgo/reflect.go b/pkg/bbgo/reflect.go index 78c929e5b..ce9f7c4c6 100644 --- a/pkg/bbgo/reflect.go +++ b/pkg/bbgo/reflect.go @@ -1,6 +1,8 @@ package bbgo import ( + "errors" + "fmt" "reflect" ) @@ -40,3 +42,64 @@ func hasField(rs reflect.Value, fieldName string) (field reflect.Value, ok bool) field = rs.FieldByName(fieldName) return field, field.IsValid() } + +type StructFieldIterator func(tag string, ft reflect.StructField, fv reflect.Value) error + +var errCanNotIterateNilPointer = errors.New("can not iterate struct on a nil pointer") + +func iterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error { + sv := reflect.ValueOf(obj) + st := reflect.TypeOf(obj) + + if st.Kind() != reflect.Ptr { + return fmt.Errorf("f should be a pointer of a struct, %s given", st) + } + + // for pointer, check if it's nil + if sv.IsNil() { + return errCanNotIterateNilPointer + } + + // solve the reference + st = st.Elem() + sv = sv.Elem() + + if st.Kind() != reflect.Struct { + return fmt.Errorf("f should be a struct, %s given", st) + } + + for i := 0; i < sv.NumField(); i++ { + fv := sv.Field(i) + ft := st.Field(i) + + fvt := fv.Type() + _ = fvt + + // skip unexported fields + if !st.Field(i).IsExported() { + continue + } + + tag, ok := ft.Tag.Lookup(tagName) + if !ok { + continue + } + + if err := cb(tag, ft, fv); err != nil { + return err + } + } + + return nil +} + +// https://github.com/xiaojun207/go-base-utils/blob/master/utils/Clone.go +func newTypeValueInterface(typ reflect.Type) interface{} { + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + dst := reflect.New(typ).Elem() + return dst.Addr().Interface() + } + dst := reflect.New(typ) + return dst.Interface() +} diff --git a/pkg/service/persistence_json.go b/pkg/service/persistence_json.go index 53799134d..c125ea041 100644 --- a/pkg/service/persistence_json.go +++ b/pkg/service/persistence_json.go @@ -38,7 +38,7 @@ func (store JsonStore) Reset() error { func (store JsonStore) Load(val interface{}) error { if _, err := os.Stat(store.Directory); os.IsNotExist(err) { - if err2 := os.Mkdir(store.Directory, 0777); err2 != nil { + if err2 := os.MkdirAll(store.Directory, 0777); err2 != nil { return err2 } }