diff --git a/pkg/bbgo/exit.go b/pkg/bbgo/exit.go index 98404ddbf..11f97d8a8 100644 --- a/pkg/bbgo/exit.go +++ b/pkg/bbgo/exit.go @@ -3,6 +3,7 @@ package bbgo import ( "reflect" + "github.com/c9s/bbgo/pkg/dynamic" "github.com/c9s/bbgo/pkg/types" ) @@ -23,7 +24,7 @@ func (m *ExitMethod) Subscribe(session *ExchangeSession) { rt = rt.Elem() infType := reflect.TypeOf((*types.Subscriber)(nil)).Elem() - argValues := toReflectValues(session) + argValues := dynamic.ToReflectValues(session) for i := 0; i < rt.NumField(); i++ { fieldType := rt.Field(i) if fieldType.Type.Implements(infType) { diff --git a/pkg/bbgo/injection_test.go b/pkg/bbgo/injection_test.go index dd6370320..69a8b5f72 100644 --- a/pkg/bbgo/injection_test.go +++ b/pkg/bbgo/injection_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/c9s/bbgo/pkg/dynamic" "github.com/c9s/bbgo/pkg/service" "github.com/c9s/bbgo/pkg/types" ) @@ -22,7 +23,7 @@ func Test_injectField(t *testing.T) { // get the value of the pointer, or it can not be set. var rv = reflect.ValueOf(tt).Elem() - _, ret := hasField(rv, "TradeService") + _, ret := dynamic.HasField(rv, "TradeService") assert.True(t, ret) ts := &service.TradeService{} diff --git a/pkg/bbgo/persistence.go b/pkg/bbgo/persistence.go index b435c8f07..89b4179df 100644 --- a/pkg/bbgo/persistence.go +++ b/pkg/bbgo/persistence.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/c9s/bbgo/pkg/dynamic" "github.com/c9s/bbgo/pkg/service" ) @@ -106,10 +107,10 @@ func Sync(obj interface{}) { } func loadPersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error { - return iterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error { + return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error { log.Debugf("[loadPersistenceFields] loading value into field %v, tag = %s, original value = %v", field, tag, value) - newValueInf := newTypeValueInterface(value.Type()) + newValueInf := dynamic.NewTypeValueInterface(value.Type()) // inf := value.Interface() store := persistence.NewStore("state", id, tag) if err := store.Load(&newValueInf); err != nil { @@ -134,7 +135,7 @@ func loadPersistenceFields(obj interface{}, id string, persistence service.Persi } func storePersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error { - return iterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error { + return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error { log.Debugf("[storePersistenceFields] storing value from field %v, tag = %s, original value = %v", ft, tag, fv) inf := fv.Interface() diff --git a/pkg/bbgo/persistence_test.go b/pkg/bbgo/persistence_test.go index ebc5314f0..1c14eb2e9 100644 --- a/pkg/bbgo/persistence_test.go +++ b/pkg/bbgo/persistence_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/c9s/bbgo/pkg/dynamic" "github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/service" "github.com/c9s/bbgo/pkg/types" @@ -83,7 +84,7 @@ func Test_loadPersistenceFields(t *testing.T) { t.Run(psName+"/nil", func(t *testing.T) { var b *TestStruct = nil err := loadPersistenceFields(b, "test-nil", ps) - assert.Equal(t, errCanNotIterateNilPointer, err) + assert.Equal(t, dynamic.ErrCanNotIterateNilPointer, err) }) t.Run(psName+"/pointer-field", func(t *testing.T) { diff --git a/pkg/bbgo/reflect.go b/pkg/bbgo/reflect.go index 13ed95be1..aad404e1b 100644 --- a/pkg/bbgo/reflect.go +++ b/pkg/bbgo/reflect.go @@ -1,8 +1,6 @@ package bbgo import ( - "errors" - "fmt" "reflect" ) @@ -48,96 +46,3 @@ func isSymbolBasedStrategy(rs reflect.Value) (string, bool) { return field.String(), true } -func hasField(rs reflect.Value, fieldName string) (field reflect.Value, ok bool) { - field = rs.FieldByName(fieldName) - return field, field.IsValid() -} - -type StructFieldIterator func(tag string, ft reflect.StructField, fv reflect.Value) error - -var errCanNotIterateNilPointer = errors.New("can not iterate struct on a nil pointer") - -func iterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error { - sv := reflect.ValueOf(obj) - st := reflect.TypeOf(obj) - - if st.Kind() != reflect.Ptr { - return fmt.Errorf("f should be a pointer of a struct, %s given", st) - } - - // for pointer, check if it's nil - if sv.IsNil() { - return errCanNotIterateNilPointer - } - - // solve the reference - st = st.Elem() - sv = sv.Elem() - - if st.Kind() != reflect.Struct { - return fmt.Errorf("f should be a struct, %s given", st) - } - - for i := 0; i < sv.NumField(); i++ { - fv := sv.Field(i) - ft := st.Field(i) - - // skip unexported fields - if !st.Field(i).IsExported() { - continue - } - - tag, ok := ft.Tag.Lookup(tagName) - if !ok { - continue - } - - if err := cb(tag, ft, fv); err != nil { - return err - } - } - - return nil -} - -// https://github.com/xiaojun207/go-base-utils/blob/master/utils/Clone.go -func newTypeValueInterface(typ reflect.Type) interface{} { - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - dst := reflect.New(typ).Elem() - return dst.Addr().Interface() - } - dst := reflect.New(typ) - return dst.Interface() -} - -// toReflectValues convert the go objects into reflect.Value slice -func toReflectValues(args ...interface{}) (values []reflect.Value) { - for _, arg := range args { - values = append(values, reflect.ValueOf(arg)) - } - - return values -} - -func reflectMergeStructFields(dst, src interface{}) { - rtA := reflect.TypeOf(dst) - srcStructType := reflect.TypeOf(src) - - rtA = rtA.Elem() - srcStructType = srcStructType.Elem() - - for i := 0; i < rtA.NumField(); i++ { - fieldType := rtA.Field(i) - fieldName := fieldType.Name - if fieldSrcType, ok := srcStructType.FieldByName(fieldName); ok { - if fieldSrcType.Type == fieldType.Type { - srcValue := reflect.ValueOf(src).Elem().FieldByName(fieldName) - dstValue := reflect.ValueOf(dst).Elem().FieldByName(fieldName) - if (fieldType.Type.Kind() == reflect.Ptr && dstValue.IsNil()) || dstValue.IsZero() { - dstValue.Set(srcValue) - } - } - } - } -} diff --git a/pkg/bbgo/reflect_test.go b/pkg/bbgo/reflect_test.go index 60f32cc32..920078f66 100644 --- a/pkg/bbgo/reflect_test.go +++ b/pkg/bbgo/reflect_test.go @@ -1,66 +1,2 @@ package bbgo -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/c9s/bbgo/pkg/types" -) - -func Test_reflectMergeStructFields(t *testing.T) { - t.Run("zero value", func(t *testing.T) { - a := &TestStrategy{Symbol: "BTCUSDT"} - b := &CumulatedVolumeTakeProfit{Symbol: ""} - reflectMergeStructFields(b, a) - assert.Equal(t, "BTCUSDT", b.Symbol) - }) - - t.Run("non-zero value", func(t *testing.T) { - a := &TestStrategy{Symbol: "BTCUSDT"} - b := &CumulatedVolumeTakeProfit{Symbol: "ETHUSDT"} - reflectMergeStructFields(b, a) - assert.Equal(t, "ETHUSDT", b.Symbol, "should be the original value") - }) - - t.Run("zero embedded struct", func(t *testing.T) { - iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30} - a := &struct { - types.IntervalWindow - }{ - IntervalWindow: iw, - } - b := &CumulatedVolumeTakeProfit{} - reflectMergeStructFields(b, a) - assert.Equal(t, iw, b.IntervalWindow) - }) - - t.Run("non-zero embedded struct", func(t *testing.T) { - iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30} - a := &struct { - types.IntervalWindow - }{ - IntervalWindow: iw, - } - b := &CumulatedVolumeTakeProfit{ - IntervalWindow: types.IntervalWindow{Interval: types.Interval5m, Window: 9}, - } - reflectMergeStructFields(b, a) - assert.Equal(t, types.IntervalWindow{Interval: types.Interval5m, Window: 9}, b.IntervalWindow) - }) - - t.Run("skip different type but the same name", func(t *testing.T) { - a := &struct { - A float64 - }{ - A: 1.99, - } - b := &struct { - A string - }{} - reflectMergeStructFields(b, a) - assert.Equal(t, "", b.A) - assert.Equal(t, 1.99, a.A) - }) - -} diff --git a/pkg/bbgo/trader.go b/pkg/bbgo/trader.go index 4fc35e08e..f05e8f9b7 100644 --- a/pkg/bbgo/trader.go +++ b/pkg/bbgo/trader.go @@ -10,6 +10,7 @@ import ( _ "github.com/go-sql-driver/mysql" + "github.com/c9s/bbgo/pkg/dynamic" "github.com/c9s/bbgo/pkg/interact" ) @@ -394,7 +395,7 @@ func (trader *Trader) injectCommonServices(s interface{}) error { // a special injection for persistence selector: // if user defined the selector, the facade pointer will be nil, hence we need to update the persistence facade pointer sv := reflect.ValueOf(s).Elem() - if field, ok := hasField(sv, "Persistence"); ok { + if field, ok := dynamic.HasField(sv, "Persistence"); ok { // the selector is set, but we need to update the facade pointer if !field.IsNil() { elem := field.Elem() diff --git a/pkg/dynamic/field.go b/pkg/dynamic/field.go new file mode 100644 index 000000000..baccd563a --- /dev/null +++ b/pkg/dynamic/field.go @@ -0,0 +1,8 @@ +package dynamic + +import "reflect" + +func HasField(rs reflect.Value, fieldName string) (field reflect.Value, ok bool) { + field = rs.FieldByName(fieldName) + return field, field.IsValid() +} diff --git a/pkg/dynamic/iterate.go b/pkg/dynamic/iterate.go new file mode 100644 index 000000000..12d6e2842 --- /dev/null +++ b/pkg/dynamic/iterate.go @@ -0,0 +1,54 @@ +package dynamic + +import ( + "errors" + "fmt" + "reflect" +) + +type StructFieldIterator func(tag string, ft reflect.StructField, fv reflect.Value) error + +var ErrCanNotIterateNilPointer = errors.New("can not iterate struct on a nil pointer") + +func IterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error { + sv := reflect.ValueOf(obj) + st := reflect.TypeOf(obj) + + if st.Kind() != reflect.Ptr { + return fmt.Errorf("f should be a pointer of a struct, %s given", st) + } + + // for pointer, check if it's nil + if sv.IsNil() { + return ErrCanNotIterateNilPointer + } + + // solve the reference + st = st.Elem() + sv = sv.Elem() + + if st.Kind() != reflect.Struct { + return fmt.Errorf("f should be a struct, %s given", st) + } + + for i := 0; i < sv.NumField(); i++ { + fv := sv.Field(i) + ft := st.Field(i) + + // skip unexported fields + if !st.Field(i).IsExported() { + continue + } + + tag, ok := ft.Tag.Lookup(tagName) + if !ok { + continue + } + + if err := cb(tag, ft, fv); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/dynamic/merge.go b/pkg/dynamic/merge.go new file mode 100644 index 000000000..1f0472c88 --- /dev/null +++ b/pkg/dynamic/merge.go @@ -0,0 +1,25 @@ +package dynamic + +import "reflect" + +func MergeStructValues(dst, src interface{}) { + rtA := reflect.TypeOf(dst) + srcStructType := reflect.TypeOf(src) + + rtA = rtA.Elem() + srcStructType = srcStructType.Elem() + + for i := 0; i < rtA.NumField(); i++ { + fieldType := rtA.Field(i) + fieldName := fieldType.Name + if fieldSrcType, ok := srcStructType.FieldByName(fieldName); ok { + if fieldSrcType.Type == fieldType.Type { + srcValue := reflect.ValueOf(src).Elem().FieldByName(fieldName) + dstValue := reflect.ValueOf(dst).Elem().FieldByName(fieldName) + if (fieldType.Type.Kind() == reflect.Ptr && dstValue.IsNil()) || dstValue.IsZero() { + dstValue.Set(srcValue) + } + } + } + } +} diff --git a/pkg/dynamic/merge_test.go b/pkg/dynamic/merge_test.go new file mode 100644 index 000000000..2e8929ff0 --- /dev/null +++ b/pkg/dynamic/merge_test.go @@ -0,0 +1,75 @@ +package dynamic + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/c9s/bbgo/pkg/bbgo" + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" +) + +type TestStrategy struct { + Symbol string `json:"symbol"` + Interval string `json:"interval"` + BaseQuantity fixedpoint.Value `json:"baseQuantity"` + MaxAssetQuantity fixedpoint.Value `json:"maxAssetQuantity"` + MinDropPercentage fixedpoint.Value `json:"minDropPercentage"` +} + +func Test_reflectMergeStructFields(t *testing.T) { + t.Run("zero value", func(t *testing.T) { + a := &TestStrategy{Symbol: "BTCUSDT"} + b := &bbgo.CumulatedVolumeTakeProfit{Symbol: ""} + MergeStructValues(b, a) + assert.Equal(t, "BTCUSDT", b.Symbol) + }) + + t.Run("non-zero value", func(t *testing.T) { + a := &TestStrategy{Symbol: "BTCUSDT"} + b := &bbgo.CumulatedVolumeTakeProfit{Symbol: "ETHUSDT"} + MergeStructValues(b, a) + assert.Equal(t, "ETHUSDT", b.Symbol, "should be the original value") + }) + + t.Run("zero embedded struct", func(t *testing.T) { + iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30} + a := &struct { + types.IntervalWindow + }{ + IntervalWindow: iw, + } + b := &bbgo.CumulatedVolumeTakeProfit{} + MergeStructValues(b, a) + assert.Equal(t, iw, b.IntervalWindow) + }) + + t.Run("non-zero embedded struct", func(t *testing.T) { + iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30} + a := &struct { + types.IntervalWindow + }{ + IntervalWindow: iw, + } + b := &bbgo.CumulatedVolumeTakeProfit{ + IntervalWindow: types.IntervalWindow{Interval: types.Interval5m, Window: 9}, + } + MergeStructValues(b, a) + assert.Equal(t, types.IntervalWindow{Interval: types.Interval5m, Window: 9}, b.IntervalWindow) + }) + + t.Run("skip different type but the same name", func(t *testing.T) { + a := &struct { + A float64 + }{ + A: 1.99, + } + b := &struct { + A string + }{} + MergeStructValues(b, a) + assert.Equal(t, "", b.A) + assert.Equal(t, 1.99, a.A) + }) +} diff --git a/pkg/dynamic/typevalue.go b/pkg/dynamic/typevalue.go new file mode 100644 index 000000000..a12ccf416 --- /dev/null +++ b/pkg/dynamic/typevalue.go @@ -0,0 +1,24 @@ +package dynamic + +import "reflect" + +// https://github.com/xiaojun207/go-base-utils/blob/master/utils/Clone.go +func NewTypeValueInterface(typ reflect.Type) interface{} { + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + dst := reflect.New(typ).Elem() + return dst.Addr().Interface() + } + dst := reflect.New(typ) + return dst.Interface() +} + +// ToReflectValues convert the go objects into reflect.Value slice +func ToReflectValues(args ...interface{}) (values []reflect.Value) { + for _, arg := range args { + values = append(values, reflect.ValueOf(arg)) + } + + return values +} +