move dynamic stuff to the pkg/dynamic package

This commit is contained in:
c9s 2022-06-29 18:49:42 +08:00
parent a74decc47d
commit 3013eeccc7
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
12 changed files with 198 additions and 166 deletions

View File

@ -3,6 +3,7 @@ package bbgo
import ( import (
"reflect" "reflect"
"github.com/c9s/bbgo/pkg/dynamic"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
@ -23,7 +24,7 @@ func (m *ExitMethod) Subscribe(session *ExchangeSession) {
rt = rt.Elem() rt = rt.Elem()
infType := reflect.TypeOf((*types.Subscriber)(nil)).Elem() infType := reflect.TypeOf((*types.Subscriber)(nil)).Elem()
argValues := toReflectValues(session) argValues := dynamic.ToReflectValues(session)
for i := 0; i < rt.NumField(); i++ { for i := 0; i < rt.NumField(); i++ {
fieldType := rt.Field(i) fieldType := rt.Field(i)
if fieldType.Type.Implements(infType) { if fieldType.Type.Implements(infType) {

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/dynamic"
"github.com/c9s/bbgo/pkg/service" "github.com/c9s/bbgo/pkg/service"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
@ -22,7 +23,7 @@ func Test_injectField(t *testing.T) {
// get the value of the pointer, or it can not be set. // get the value of the pointer, or it can not be set.
var rv = reflect.ValueOf(tt).Elem() var rv = reflect.ValueOf(tt).Elem()
_, ret := hasField(rv, "TradeService") _, ret := dynamic.HasField(rv, "TradeService")
assert.True(t, ret) assert.True(t, ret)
ts := &service.TradeService{} ts := &service.TradeService{}

View File

@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/c9s/bbgo/pkg/dynamic"
"github.com/c9s/bbgo/pkg/service" "github.com/c9s/bbgo/pkg/service"
) )
@ -106,10 +107,10 @@ func Sync(obj interface{}) {
} }
func loadPersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error { func loadPersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error {
return iterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error { return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error {
log.Debugf("[loadPersistenceFields] loading value into field %v, tag = %s, original value = %v", field, tag, value) log.Debugf("[loadPersistenceFields] loading value into field %v, tag = %s, original value = %v", field, tag, value)
newValueInf := newTypeValueInterface(value.Type()) newValueInf := dynamic.NewTypeValueInterface(value.Type())
// inf := value.Interface() // inf := value.Interface()
store := persistence.NewStore("state", id, tag) store := persistence.NewStore("state", id, tag)
if err := store.Load(&newValueInf); err != nil { if err := store.Load(&newValueInf); err != nil {
@ -134,7 +135,7 @@ func loadPersistenceFields(obj interface{}, id string, persistence service.Persi
} }
func storePersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error { func storePersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error {
return iterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error { return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error {
log.Debugf("[storePersistenceFields] storing value from field %v, tag = %s, original value = %v", ft, tag, fv) log.Debugf("[storePersistenceFields] storing value from field %v, tag = %s, original value = %v", ft, tag, fv)
inf := fv.Interface() inf := fv.Interface()

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/dynamic"
"github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/service" "github.com/c9s/bbgo/pkg/service"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
@ -83,7 +84,7 @@ func Test_loadPersistenceFields(t *testing.T) {
t.Run(psName+"/nil", func(t *testing.T) { t.Run(psName+"/nil", func(t *testing.T) {
var b *TestStruct = nil var b *TestStruct = nil
err := loadPersistenceFields(b, "test-nil", ps) err := loadPersistenceFields(b, "test-nil", ps)
assert.Equal(t, errCanNotIterateNilPointer, err) assert.Equal(t, dynamic.ErrCanNotIterateNilPointer, err)
}) })
t.Run(psName+"/pointer-field", func(t *testing.T) { t.Run(psName+"/pointer-field", func(t *testing.T) {

View File

@ -1,8 +1,6 @@
package bbgo package bbgo
import ( import (
"errors"
"fmt"
"reflect" "reflect"
) )
@ -48,96 +46,3 @@ func isSymbolBasedStrategy(rs reflect.Value) (string, bool) {
return field.String(), true return field.String(), true
} }
func hasField(rs reflect.Value, fieldName string) (field reflect.Value, ok bool) {
field = rs.FieldByName(fieldName)
return field, field.IsValid()
}
type StructFieldIterator func(tag string, ft reflect.StructField, fv reflect.Value) error
var errCanNotIterateNilPointer = errors.New("can not iterate struct on a nil pointer")
func iterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error {
sv := reflect.ValueOf(obj)
st := reflect.TypeOf(obj)
if st.Kind() != reflect.Ptr {
return fmt.Errorf("f should be a pointer of a struct, %s given", st)
}
// for pointer, check if it's nil
if sv.IsNil() {
return errCanNotIterateNilPointer
}
// solve the reference
st = st.Elem()
sv = sv.Elem()
if st.Kind() != reflect.Struct {
return fmt.Errorf("f should be a struct, %s given", st)
}
for i := 0; i < sv.NumField(); i++ {
fv := sv.Field(i)
ft := st.Field(i)
// skip unexported fields
if !st.Field(i).IsExported() {
continue
}
tag, ok := ft.Tag.Lookup(tagName)
if !ok {
continue
}
if err := cb(tag, ft, fv); err != nil {
return err
}
}
return nil
}
// https://github.com/xiaojun207/go-base-utils/blob/master/utils/Clone.go
func newTypeValueInterface(typ reflect.Type) interface{} {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
dst := reflect.New(typ).Elem()
return dst.Addr().Interface()
}
dst := reflect.New(typ)
return dst.Interface()
}
// toReflectValues convert the go objects into reflect.Value slice
func toReflectValues(args ...interface{}) (values []reflect.Value) {
for _, arg := range args {
values = append(values, reflect.ValueOf(arg))
}
return values
}
func reflectMergeStructFields(dst, src interface{}) {
rtA := reflect.TypeOf(dst)
srcStructType := reflect.TypeOf(src)
rtA = rtA.Elem()
srcStructType = srcStructType.Elem()
for i := 0; i < rtA.NumField(); i++ {
fieldType := rtA.Field(i)
fieldName := fieldType.Name
if fieldSrcType, ok := srcStructType.FieldByName(fieldName); ok {
if fieldSrcType.Type == fieldType.Type {
srcValue := reflect.ValueOf(src).Elem().FieldByName(fieldName)
dstValue := reflect.ValueOf(dst).Elem().FieldByName(fieldName)
if (fieldType.Type.Kind() == reflect.Ptr && dstValue.IsNil()) || dstValue.IsZero() {
dstValue.Set(srcValue)
}
}
}
}
}

View File

@ -1,66 +1,2 @@
package bbgo package bbgo
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/types"
)
func Test_reflectMergeStructFields(t *testing.T) {
t.Run("zero value", func(t *testing.T) {
a := &TestStrategy{Symbol: "BTCUSDT"}
b := &CumulatedVolumeTakeProfit{Symbol: ""}
reflectMergeStructFields(b, a)
assert.Equal(t, "BTCUSDT", b.Symbol)
})
t.Run("non-zero value", func(t *testing.T) {
a := &TestStrategy{Symbol: "BTCUSDT"}
b := &CumulatedVolumeTakeProfit{Symbol: "ETHUSDT"}
reflectMergeStructFields(b, a)
assert.Equal(t, "ETHUSDT", b.Symbol, "should be the original value")
})
t.Run("zero embedded struct", func(t *testing.T) {
iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30}
a := &struct {
types.IntervalWindow
}{
IntervalWindow: iw,
}
b := &CumulatedVolumeTakeProfit{}
reflectMergeStructFields(b, a)
assert.Equal(t, iw, b.IntervalWindow)
})
t.Run("non-zero embedded struct", func(t *testing.T) {
iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30}
a := &struct {
types.IntervalWindow
}{
IntervalWindow: iw,
}
b := &CumulatedVolumeTakeProfit{
IntervalWindow: types.IntervalWindow{Interval: types.Interval5m, Window: 9},
}
reflectMergeStructFields(b, a)
assert.Equal(t, types.IntervalWindow{Interval: types.Interval5m, Window: 9}, b.IntervalWindow)
})
t.Run("skip different type but the same name", func(t *testing.T) {
a := &struct {
A float64
}{
A: 1.99,
}
b := &struct {
A string
}{}
reflectMergeStructFields(b, a)
assert.Equal(t, "", b.A)
assert.Equal(t, 1.99, a.A)
})
}

View File

@ -10,6 +10,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/c9s/bbgo/pkg/dynamic"
"github.com/c9s/bbgo/pkg/interact" "github.com/c9s/bbgo/pkg/interact"
) )
@ -394,7 +395,7 @@ func (trader *Trader) injectCommonServices(s interface{}) error {
// a special injection for persistence selector: // a special injection for persistence selector:
// if user defined the selector, the facade pointer will be nil, hence we need to update the persistence facade pointer // if user defined the selector, the facade pointer will be nil, hence we need to update the persistence facade pointer
sv := reflect.ValueOf(s).Elem() sv := reflect.ValueOf(s).Elem()
if field, ok := hasField(sv, "Persistence"); ok { if field, ok := dynamic.HasField(sv, "Persistence"); ok {
// the selector is set, but we need to update the facade pointer // the selector is set, but we need to update the facade pointer
if !field.IsNil() { if !field.IsNil() {
elem := field.Elem() elem := field.Elem()

8
pkg/dynamic/field.go Normal file
View File

@ -0,0 +1,8 @@
package dynamic
import "reflect"
func HasField(rs reflect.Value, fieldName string) (field reflect.Value, ok bool) {
field = rs.FieldByName(fieldName)
return field, field.IsValid()
}

54
pkg/dynamic/iterate.go Normal file
View File

@ -0,0 +1,54 @@
package dynamic
import (
"errors"
"fmt"
"reflect"
)
type StructFieldIterator func(tag string, ft reflect.StructField, fv reflect.Value) error
var ErrCanNotIterateNilPointer = errors.New("can not iterate struct on a nil pointer")
func IterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error {
sv := reflect.ValueOf(obj)
st := reflect.TypeOf(obj)
if st.Kind() != reflect.Ptr {
return fmt.Errorf("f should be a pointer of a struct, %s given", st)
}
// for pointer, check if it's nil
if sv.IsNil() {
return ErrCanNotIterateNilPointer
}
// solve the reference
st = st.Elem()
sv = sv.Elem()
if st.Kind() != reflect.Struct {
return fmt.Errorf("f should be a struct, %s given", st)
}
for i := 0; i < sv.NumField(); i++ {
fv := sv.Field(i)
ft := st.Field(i)
// skip unexported fields
if !st.Field(i).IsExported() {
continue
}
tag, ok := ft.Tag.Lookup(tagName)
if !ok {
continue
}
if err := cb(tag, ft, fv); err != nil {
return err
}
}
return nil
}

25
pkg/dynamic/merge.go Normal file
View File

@ -0,0 +1,25 @@
package dynamic
import "reflect"
func MergeStructValues(dst, src interface{}) {
rtA := reflect.TypeOf(dst)
srcStructType := reflect.TypeOf(src)
rtA = rtA.Elem()
srcStructType = srcStructType.Elem()
for i := 0; i < rtA.NumField(); i++ {
fieldType := rtA.Field(i)
fieldName := fieldType.Name
if fieldSrcType, ok := srcStructType.FieldByName(fieldName); ok {
if fieldSrcType.Type == fieldType.Type {
srcValue := reflect.ValueOf(src).Elem().FieldByName(fieldName)
dstValue := reflect.ValueOf(dst).Elem().FieldByName(fieldName)
if (fieldType.Type.Kind() == reflect.Ptr && dstValue.IsNil()) || dstValue.IsZero() {
dstValue.Set(srcValue)
}
}
}
}
}

75
pkg/dynamic/merge_test.go Normal file
View File

@ -0,0 +1,75 @@
package dynamic
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/bbgo"
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
)
type TestStrategy struct {
Symbol string `json:"symbol"`
Interval string `json:"interval"`
BaseQuantity fixedpoint.Value `json:"baseQuantity"`
MaxAssetQuantity fixedpoint.Value `json:"maxAssetQuantity"`
MinDropPercentage fixedpoint.Value `json:"minDropPercentage"`
}
func Test_reflectMergeStructFields(t *testing.T) {
t.Run("zero value", func(t *testing.T) {
a := &TestStrategy{Symbol: "BTCUSDT"}
b := &bbgo.CumulatedVolumeTakeProfit{Symbol: ""}
MergeStructValues(b, a)
assert.Equal(t, "BTCUSDT", b.Symbol)
})
t.Run("non-zero value", func(t *testing.T) {
a := &TestStrategy{Symbol: "BTCUSDT"}
b := &bbgo.CumulatedVolumeTakeProfit{Symbol: "ETHUSDT"}
MergeStructValues(b, a)
assert.Equal(t, "ETHUSDT", b.Symbol, "should be the original value")
})
t.Run("zero embedded struct", func(t *testing.T) {
iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30}
a := &struct {
types.IntervalWindow
}{
IntervalWindow: iw,
}
b := &bbgo.CumulatedVolumeTakeProfit{}
MergeStructValues(b, a)
assert.Equal(t, iw, b.IntervalWindow)
})
t.Run("non-zero embedded struct", func(t *testing.T) {
iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30}
a := &struct {
types.IntervalWindow
}{
IntervalWindow: iw,
}
b := &bbgo.CumulatedVolumeTakeProfit{
IntervalWindow: types.IntervalWindow{Interval: types.Interval5m, Window: 9},
}
MergeStructValues(b, a)
assert.Equal(t, types.IntervalWindow{Interval: types.Interval5m, Window: 9}, b.IntervalWindow)
})
t.Run("skip different type but the same name", func(t *testing.T) {
a := &struct {
A float64
}{
A: 1.99,
}
b := &struct {
A string
}{}
MergeStructValues(b, a)
assert.Equal(t, "", b.A)
assert.Equal(t, 1.99, a.A)
})
}

24
pkg/dynamic/typevalue.go Normal file
View File

@ -0,0 +1,24 @@
package dynamic
import "reflect"
// https://github.com/xiaojun207/go-base-utils/blob/master/utils/Clone.go
func NewTypeValueInterface(typ reflect.Type) interface{} {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
dst := reflect.New(typ).Elem()
return dst.Addr().Interface()
}
dst := reflect.New(typ)
return dst.Interface()
}
// ToReflectValues convert the go objects into reflect.Value slice
func ToReflectValues(args ...interface{}) (values []reflect.Value) {
for _, arg := range args {
values = append(values, reflect.ValueOf(arg))
}
return values
}