mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 09:11:55 +00:00
move dynamic stuff to the pkg/dynamic package
This commit is contained in:
parent
a74decc47d
commit
3013eeccc7
|
@ -3,6 +3,7 @@ package bbgo
|
|||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/dynamic"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
|
@ -23,7 +24,7 @@ func (m *ExitMethod) Subscribe(session *ExchangeSession) {
|
|||
rt = rt.Elem()
|
||||
infType := reflect.TypeOf((*types.Subscriber)(nil)).Elem()
|
||||
|
||||
argValues := toReflectValues(session)
|
||||
argValues := dynamic.ToReflectValues(session)
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
fieldType := rt.Field(i)
|
||||
if fieldType.Type.Implements(infType) {
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/dynamic"
|
||||
"github.com/c9s/bbgo/pkg/service"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
@ -22,7 +23,7 @@ func Test_injectField(t *testing.T) {
|
|||
// get the value of the pointer, or it can not be set.
|
||||
var rv = reflect.ValueOf(tt).Elem()
|
||||
|
||||
_, ret := hasField(rv, "TradeService")
|
||||
_, ret := dynamic.HasField(rv, "TradeService")
|
||||
assert.True(t, ret)
|
||||
|
||||
ts := &service.TradeService{}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/dynamic"
|
||||
"github.com/c9s/bbgo/pkg/service"
|
||||
)
|
||||
|
||||
|
@ -106,10 +107,10 @@ func Sync(obj interface{}) {
|
|||
}
|
||||
|
||||
func loadPersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error {
|
||||
return iterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error {
|
||||
return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error {
|
||||
log.Debugf("[loadPersistenceFields] loading value into field %v, tag = %s, original value = %v", field, tag, value)
|
||||
|
||||
newValueInf := newTypeValueInterface(value.Type())
|
||||
newValueInf := dynamic.NewTypeValueInterface(value.Type())
|
||||
// inf := value.Interface()
|
||||
store := persistence.NewStore("state", id, tag)
|
||||
if err := store.Load(&newValueInf); err != nil {
|
||||
|
@ -134,7 +135,7 @@ func loadPersistenceFields(obj interface{}, id string, persistence service.Persi
|
|||
}
|
||||
|
||||
func storePersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error {
|
||||
return iterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error {
|
||||
return dynamic.IterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error {
|
||||
log.Debugf("[storePersistenceFields] storing value from field %v, tag = %s, original value = %v", ft, tag, fv)
|
||||
|
||||
inf := fv.Interface()
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/dynamic"
|
||||
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||||
"github.com/c9s/bbgo/pkg/service"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
|
@ -83,7 +84,7 @@ func Test_loadPersistenceFields(t *testing.T) {
|
|||
t.Run(psName+"/nil", func(t *testing.T) {
|
||||
var b *TestStruct = nil
|
||||
err := loadPersistenceFields(b, "test-nil", ps)
|
||||
assert.Equal(t, errCanNotIterateNilPointer, err)
|
||||
assert.Equal(t, dynamic.ErrCanNotIterateNilPointer, err)
|
||||
})
|
||||
|
||||
t.Run(psName+"/pointer-field", func(t *testing.T) {
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package bbgo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
|
@ -48,96 +46,3 @@ func isSymbolBasedStrategy(rs reflect.Value) (string, bool) {
|
|||
return field.String(), true
|
||||
}
|
||||
|
||||
func hasField(rs reflect.Value, fieldName string) (field reflect.Value, ok bool) {
|
||||
field = rs.FieldByName(fieldName)
|
||||
return field, field.IsValid()
|
||||
}
|
||||
|
||||
type StructFieldIterator func(tag string, ft reflect.StructField, fv reflect.Value) error
|
||||
|
||||
var errCanNotIterateNilPointer = errors.New("can not iterate struct on a nil pointer")
|
||||
|
||||
func iterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error {
|
||||
sv := reflect.ValueOf(obj)
|
||||
st := reflect.TypeOf(obj)
|
||||
|
||||
if st.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("f should be a pointer of a struct, %s given", st)
|
||||
}
|
||||
|
||||
// for pointer, check if it's nil
|
||||
if sv.IsNil() {
|
||||
return errCanNotIterateNilPointer
|
||||
}
|
||||
|
||||
// solve the reference
|
||||
st = st.Elem()
|
||||
sv = sv.Elem()
|
||||
|
||||
if st.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("f should be a struct, %s given", st)
|
||||
}
|
||||
|
||||
for i := 0; i < sv.NumField(); i++ {
|
||||
fv := sv.Field(i)
|
||||
ft := st.Field(i)
|
||||
|
||||
// skip unexported fields
|
||||
if !st.Field(i).IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
tag, ok := ft.Tag.Lookup(tagName)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := cb(tag, ft, fv); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// https://github.com/xiaojun207/go-base-utils/blob/master/utils/Clone.go
|
||||
func newTypeValueInterface(typ reflect.Type) interface{} {
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
dst := reflect.New(typ).Elem()
|
||||
return dst.Addr().Interface()
|
||||
}
|
||||
dst := reflect.New(typ)
|
||||
return dst.Interface()
|
||||
}
|
||||
|
||||
// toReflectValues convert the go objects into reflect.Value slice
|
||||
func toReflectValues(args ...interface{}) (values []reflect.Value) {
|
||||
for _, arg := range args {
|
||||
values = append(values, reflect.ValueOf(arg))
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func reflectMergeStructFields(dst, src interface{}) {
|
||||
rtA := reflect.TypeOf(dst)
|
||||
srcStructType := reflect.TypeOf(src)
|
||||
|
||||
rtA = rtA.Elem()
|
||||
srcStructType = srcStructType.Elem()
|
||||
|
||||
for i := 0; i < rtA.NumField(); i++ {
|
||||
fieldType := rtA.Field(i)
|
||||
fieldName := fieldType.Name
|
||||
if fieldSrcType, ok := srcStructType.FieldByName(fieldName); ok {
|
||||
if fieldSrcType.Type == fieldType.Type {
|
||||
srcValue := reflect.ValueOf(src).Elem().FieldByName(fieldName)
|
||||
dstValue := reflect.ValueOf(dst).Elem().FieldByName(fieldName)
|
||||
if (fieldType.Type.Kind() == reflect.Ptr && dstValue.IsNil()) || dstValue.IsZero() {
|
||||
dstValue.Set(srcValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,66 +1,2 @@
|
|||
package bbgo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
func Test_reflectMergeStructFields(t *testing.T) {
|
||||
t.Run("zero value", func(t *testing.T) {
|
||||
a := &TestStrategy{Symbol: "BTCUSDT"}
|
||||
b := &CumulatedVolumeTakeProfit{Symbol: ""}
|
||||
reflectMergeStructFields(b, a)
|
||||
assert.Equal(t, "BTCUSDT", b.Symbol)
|
||||
})
|
||||
|
||||
t.Run("non-zero value", func(t *testing.T) {
|
||||
a := &TestStrategy{Symbol: "BTCUSDT"}
|
||||
b := &CumulatedVolumeTakeProfit{Symbol: "ETHUSDT"}
|
||||
reflectMergeStructFields(b, a)
|
||||
assert.Equal(t, "ETHUSDT", b.Symbol, "should be the original value")
|
||||
})
|
||||
|
||||
t.Run("zero embedded struct", func(t *testing.T) {
|
||||
iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30}
|
||||
a := &struct {
|
||||
types.IntervalWindow
|
||||
}{
|
||||
IntervalWindow: iw,
|
||||
}
|
||||
b := &CumulatedVolumeTakeProfit{}
|
||||
reflectMergeStructFields(b, a)
|
||||
assert.Equal(t, iw, b.IntervalWindow)
|
||||
})
|
||||
|
||||
t.Run("non-zero embedded struct", func(t *testing.T) {
|
||||
iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30}
|
||||
a := &struct {
|
||||
types.IntervalWindow
|
||||
}{
|
||||
IntervalWindow: iw,
|
||||
}
|
||||
b := &CumulatedVolumeTakeProfit{
|
||||
IntervalWindow: types.IntervalWindow{Interval: types.Interval5m, Window: 9},
|
||||
}
|
||||
reflectMergeStructFields(b, a)
|
||||
assert.Equal(t, types.IntervalWindow{Interval: types.Interval5m, Window: 9}, b.IntervalWindow)
|
||||
})
|
||||
|
||||
t.Run("skip different type but the same name", func(t *testing.T) {
|
||||
a := &struct {
|
||||
A float64
|
||||
}{
|
||||
A: 1.99,
|
||||
}
|
||||
b := &struct {
|
||||
A string
|
||||
}{}
|
||||
reflectMergeStructFields(b, a)
|
||||
assert.Equal(t, "", b.A)
|
||||
assert.Equal(t, 1.99, a.A)
|
||||
})
|
||||
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/dynamic"
|
||||
"github.com/c9s/bbgo/pkg/interact"
|
||||
)
|
||||
|
||||
|
@ -394,7 +395,7 @@ func (trader *Trader) injectCommonServices(s interface{}) error {
|
|||
// a special injection for persistence selector:
|
||||
// if user defined the selector, the facade pointer will be nil, hence we need to update the persistence facade pointer
|
||||
sv := reflect.ValueOf(s).Elem()
|
||||
if field, ok := hasField(sv, "Persistence"); ok {
|
||||
if field, ok := dynamic.HasField(sv, "Persistence"); ok {
|
||||
// the selector is set, but we need to update the facade pointer
|
||||
if !field.IsNil() {
|
||||
elem := field.Elem()
|
||||
|
|
8
pkg/dynamic/field.go
Normal file
8
pkg/dynamic/field.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
package dynamic
|
||||
|
||||
import "reflect"
|
||||
|
||||
func HasField(rs reflect.Value, fieldName string) (field reflect.Value, ok bool) {
|
||||
field = rs.FieldByName(fieldName)
|
||||
return field, field.IsValid()
|
||||
}
|
54
pkg/dynamic/iterate.go
Normal file
54
pkg/dynamic/iterate.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package dynamic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type StructFieldIterator func(tag string, ft reflect.StructField, fv reflect.Value) error
|
||||
|
||||
var ErrCanNotIterateNilPointer = errors.New("can not iterate struct on a nil pointer")
|
||||
|
||||
func IterateFieldsByTag(obj interface{}, tagName string, cb StructFieldIterator) error {
|
||||
sv := reflect.ValueOf(obj)
|
||||
st := reflect.TypeOf(obj)
|
||||
|
||||
if st.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("f should be a pointer of a struct, %s given", st)
|
||||
}
|
||||
|
||||
// for pointer, check if it's nil
|
||||
if sv.IsNil() {
|
||||
return ErrCanNotIterateNilPointer
|
||||
}
|
||||
|
||||
// solve the reference
|
||||
st = st.Elem()
|
||||
sv = sv.Elem()
|
||||
|
||||
if st.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("f should be a struct, %s given", st)
|
||||
}
|
||||
|
||||
for i := 0; i < sv.NumField(); i++ {
|
||||
fv := sv.Field(i)
|
||||
ft := st.Field(i)
|
||||
|
||||
// skip unexported fields
|
||||
if !st.Field(i).IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
tag, ok := ft.Tag.Lookup(tagName)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := cb(tag, ft, fv); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
25
pkg/dynamic/merge.go
Normal file
25
pkg/dynamic/merge.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package dynamic
|
||||
|
||||
import "reflect"
|
||||
|
||||
func MergeStructValues(dst, src interface{}) {
|
||||
rtA := reflect.TypeOf(dst)
|
||||
srcStructType := reflect.TypeOf(src)
|
||||
|
||||
rtA = rtA.Elem()
|
||||
srcStructType = srcStructType.Elem()
|
||||
|
||||
for i := 0; i < rtA.NumField(); i++ {
|
||||
fieldType := rtA.Field(i)
|
||||
fieldName := fieldType.Name
|
||||
if fieldSrcType, ok := srcStructType.FieldByName(fieldName); ok {
|
||||
if fieldSrcType.Type == fieldType.Type {
|
||||
srcValue := reflect.ValueOf(src).Elem().FieldByName(fieldName)
|
||||
dstValue := reflect.ValueOf(dst).Elem().FieldByName(fieldName)
|
||||
if (fieldType.Type.Kind() == reflect.Ptr && dstValue.IsNil()) || dstValue.IsZero() {
|
||||
dstValue.Set(srcValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
75
pkg/dynamic/merge_test.go
Normal file
75
pkg/dynamic/merge_test.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package dynamic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/bbgo"
|
||||
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
type TestStrategy struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Interval string `json:"interval"`
|
||||
BaseQuantity fixedpoint.Value `json:"baseQuantity"`
|
||||
MaxAssetQuantity fixedpoint.Value `json:"maxAssetQuantity"`
|
||||
MinDropPercentage fixedpoint.Value `json:"minDropPercentage"`
|
||||
}
|
||||
|
||||
func Test_reflectMergeStructFields(t *testing.T) {
|
||||
t.Run("zero value", func(t *testing.T) {
|
||||
a := &TestStrategy{Symbol: "BTCUSDT"}
|
||||
b := &bbgo.CumulatedVolumeTakeProfit{Symbol: ""}
|
||||
MergeStructValues(b, a)
|
||||
assert.Equal(t, "BTCUSDT", b.Symbol)
|
||||
})
|
||||
|
||||
t.Run("non-zero value", func(t *testing.T) {
|
||||
a := &TestStrategy{Symbol: "BTCUSDT"}
|
||||
b := &bbgo.CumulatedVolumeTakeProfit{Symbol: "ETHUSDT"}
|
||||
MergeStructValues(b, a)
|
||||
assert.Equal(t, "ETHUSDT", b.Symbol, "should be the original value")
|
||||
})
|
||||
|
||||
t.Run("zero embedded struct", func(t *testing.T) {
|
||||
iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30}
|
||||
a := &struct {
|
||||
types.IntervalWindow
|
||||
}{
|
||||
IntervalWindow: iw,
|
||||
}
|
||||
b := &bbgo.CumulatedVolumeTakeProfit{}
|
||||
MergeStructValues(b, a)
|
||||
assert.Equal(t, iw, b.IntervalWindow)
|
||||
})
|
||||
|
||||
t.Run("non-zero embedded struct", func(t *testing.T) {
|
||||
iw := types.IntervalWindow{Interval: types.Interval1h, Window: 30}
|
||||
a := &struct {
|
||||
types.IntervalWindow
|
||||
}{
|
||||
IntervalWindow: iw,
|
||||
}
|
||||
b := &bbgo.CumulatedVolumeTakeProfit{
|
||||
IntervalWindow: types.IntervalWindow{Interval: types.Interval5m, Window: 9},
|
||||
}
|
||||
MergeStructValues(b, a)
|
||||
assert.Equal(t, types.IntervalWindow{Interval: types.Interval5m, Window: 9}, b.IntervalWindow)
|
||||
})
|
||||
|
||||
t.Run("skip different type but the same name", func(t *testing.T) {
|
||||
a := &struct {
|
||||
A float64
|
||||
}{
|
||||
A: 1.99,
|
||||
}
|
||||
b := &struct {
|
||||
A string
|
||||
}{}
|
||||
MergeStructValues(b, a)
|
||||
assert.Equal(t, "", b.A)
|
||||
assert.Equal(t, 1.99, a.A)
|
||||
})
|
||||
}
|
24
pkg/dynamic/typevalue.go
Normal file
24
pkg/dynamic/typevalue.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package dynamic
|
||||
|
||||
import "reflect"
|
||||
|
||||
// https://github.com/xiaojun207/go-base-utils/blob/master/utils/Clone.go
|
||||
func NewTypeValueInterface(typ reflect.Type) interface{} {
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
dst := reflect.New(typ).Elem()
|
||||
return dst.Addr().Interface()
|
||||
}
|
||||
dst := reflect.New(typ)
|
||||
return dst.Interface()
|
||||
}
|
||||
|
||||
// ToReflectValues convert the go objects into reflect.Value slice
|
||||
func ToReflectValues(args ...interface{}) (values []reflect.Value) {
|
||||
for _, arg := range args {
|
||||
values = append(values, reflect.ValueOf(arg))
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user