dynamic: implement CallWithMatch for dynamic calls

Signed-off-by: c9s <yoanlin93@gmail.com>
This commit is contained in:
c9s 2022-07-01 13:09:30 +08:00
parent 503d851c9d
commit 910c17a567
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
8 changed files with 304 additions and 125 deletions

View File

@ -1,106 +0,0 @@
package bbgo
import (
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/dynamic"
"github.com/c9s/bbgo/pkg/service"
"github.com/c9s/bbgo/pkg/types"
)
func Test_injectField(t *testing.T) {
type TT struct {
TradeService *service.TradeService
}
// only pointer object can be set.
var tt = &TT{}
// get the value of the pointer, or it can not be set.
var rv = reflect.ValueOf(tt).Elem()
_, ret := dynamic.HasField(rv, "TradeService")
assert.True(t, ret)
ts := &service.TradeService{}
err := injectField(rv, "TradeService", ts, true)
assert.NoError(t, err)
}
func Test_parseStructAndInject(t *testing.T) {
t.Run("skip nil", func(t *testing.T) {
ss := struct {
a int
Env *Environment
}{
a: 1,
Env: nil,
}
err := parseStructAndInject(&ss, nil)
assert.NoError(t, err)
assert.Nil(t, ss.Env)
})
t.Run("pointer", func(t *testing.T) {
ss := struct {
a int
Env *Environment
}{
a: 1,
Env: nil,
}
err := parseStructAndInject(&ss, &Environment{})
assert.NoError(t, err)
assert.NotNil(t, ss.Env)
})
t.Run("composition", func(t *testing.T) {
type TT struct {
*service.TradeService
}
ss := TT{}
err := parseStructAndInject(&ss, &service.TradeService{})
assert.NoError(t, err)
assert.NotNil(t, ss.TradeService)
})
t.Run("struct", func(t *testing.T) {
ss := struct {
a int
Env Environment
}{
a: 1,
}
err := parseStructAndInject(&ss, Environment{
startTime: time.Now(),
})
assert.NoError(t, err)
assert.NotEqual(t, time.Time{}, ss.Env.startTime)
})
t.Run("interface/any", func(t *testing.T) {
ss := struct {
Any interface{} // anything
}{
Any: nil,
}
err := parseStructAndInject(&ss, &Environment{
startTime: time.Now(),
})
assert.NoError(t, err)
assert.NotNil(t, ss.Any)
})
t.Run("interface/stringer", func(t *testing.T) {
ss := struct {
Stringer types.Stringer // stringer interface
}{
Stringer: nil,
}
err := parseStructAndInject(&ss, &types.Trade{})
assert.NoError(t, err)
assert.NotNil(t, ss.Stringer)
})
}

View File

@ -196,7 +196,7 @@ func (trader *Trader) RunSingleExchangeStrategy(ctx context.Context, strategy Si
return err
}
if err := injectField(rs, "OrderExecutor", orderExecutor, false); err != nil {
if err := dynamic.InjectField(rs, "OrderExecutor", orderExecutor, false); err != nil {
return errors.Wrapf(err, "failed to inject OrderExecutor on %T", strategy)
}
@ -218,7 +218,7 @@ func (trader *Trader) RunSingleExchangeStrategy(ctx context.Context, strategy Si
return fmt.Errorf("marketDataStore of symbol %s not found", symbol)
}
if err := parseStructAndInject(strategy,
if err := dynamic.ParseStructAndInject(strategy,
market,
indicatorSet,
store,
@ -401,19 +401,19 @@ func (trader *Trader) injectCommonServices(s interface{}) error {
return fmt.Errorf("field Persistence is not a struct element, %s given", field)
}
if err := injectField(elem, "Facade", PersistenceServiceFacade, true); err != nil {
if err := dynamic.InjectField(elem, "Facade", PersistenceServiceFacade, true); err != nil {
return err
}
/*
if err := parseStructAndInject(field.Interface(), persistenceFacade); err != nil {
if err := ParseStructAndInject(field.Interface(), persistenceFacade); err != nil {
return err
}
*/
}
}
return parseStructAndInject(s,
return dynamic.ParseStructAndInject(s,
&trader.logger,
Notification,
trader.environment.TradeService,

View File

@ -51,3 +51,101 @@ func CallStructFieldsMethod(m interface{}, method string, args ...interface{}) e
return nil
}
// CallMatch calls the function with the matched argument automatically
func CallMatch(f interface{}, objects ...interface{}) ([]reflect.Value, error) {
fv := reflect.ValueOf(f)
ft := reflect.TypeOf(f)
var startIndex = 0
var fArgs []reflect.Value
var factoryParams = findFactoryParams(objects...)
nextDynamicInputArg:
for i := 0; i < ft.NumIn(); i++ {
at := ft.In(i)
// uat == underlying argument type
uat := at
if at.Kind() == reflect.Ptr {
uat = at.Elem()
}
for oi := startIndex; oi < len(objects); oi++ {
var obj = objects[oi]
var objT = reflect.TypeOf(obj)
if objT == at {
fArgs = append(fArgs, reflect.ValueOf(obj))
startIndex = oi + 1
continue nextDynamicInputArg
}
// get the kind of argument
switch k := uat.Kind(); k {
case reflect.Interface:
if objT.Implements(at) {
fArgs = append(fArgs, reflect.ValueOf(obj))
startIndex = oi + 1
continue nextDynamicInputArg
}
}
}
// factory param can be reused
for _, fp := range factoryParams {
fpt := fp.Type()
outType := fpt.Out(0)
if outType == at {
fOut := fp.Call(nil)
fArgs = append(fArgs, fOut[0])
continue nextDynamicInputArg
}
}
fArgs = append(fArgs, reflect.Zero(at))
}
out := fv.Call(fArgs)
if ft.NumOut() == 0 {
return out, nil
}
// try to get the error object from the return value (if any)
var err error
for i := 0; i < ft.NumOut(); i++ {
outType := ft.Out(i)
switch outType.Kind() {
case reflect.Interface:
o := out[i].Interface()
switch ov := o.(type) {
case error:
err = ov
}
}
}
return out, err
}
func findFactoryParams(objs ...interface{}) (fs []reflect.Value) {
for i := range objs {
obj := objs[i]
objT := reflect.TypeOf(obj)
if objT.Kind() != reflect.Func {
continue
}
if objT.NumOut() == 0 || objT.NumIn() > 0 {
continue
}
fs = append(fs, reflect.ValueOf(obj))
}
return fs
}

View File

@ -27,3 +27,89 @@ func TestCallStructFieldsMethod(t *testing.T) {
err := CallStructFieldsMethod(c, "Subscribe", 10)
assert.NoError(t, err)
}
type S struct {
ID string
}
func (s *S) String() string { return s.ID }
func TestCallMatch(t *testing.T) {
t.Run("simple", func(t *testing.T) {
f := func(a int, b int) {
assert.Equal(t, 1, a)
assert.Equal(t, 2, b)
}
_, err := CallMatch(f, 1, 2)
assert.NoError(t, err)
})
t.Run("interface", func(t *testing.T) {
type A interface {
String() string
}
f := func(foo int, a A) {
assert.Equal(t, "foo", a.String())
}
_, err := CallMatch(f, 10, &S{ID: "foo"})
assert.NoError(t, err)
})
t.Run("nil interface", func(t *testing.T) {
type A interface {
String() string
}
f := func(foo int, a A) {
assert.Equal(t, 10, foo)
assert.Nil(t, a)
}
_, err := CallMatch(f, 10)
assert.NoError(t, err)
})
t.Run("struct pointer", func(t *testing.T) {
f := func(foo int, s *S) {
assert.Equal(t, 10, foo)
assert.NotNil(t, s)
}
_, err := CallMatch(f, 10, &S{})
assert.NoError(t, err)
})
t.Run("struct pointer x 2", func(t *testing.T) {
f := func(foo int, s1, s2 *S) {
assert.Equal(t, 10, foo)
assert.Equal(t, "s1", s1.String())
assert.Equal(t, "s2", s2.String())
}
_, err := CallMatch(f, 10, &S{ID: "s1"}, &S{ID: "s2"})
assert.NoError(t, err)
})
t.Run("func factory", func(t *testing.T) {
f := func(s *S) {
assert.Equal(t, "factory", s.String())
}
_, err := CallMatch(f, func() *S {
return &S{ID: "factory"}
})
assert.NoError(t, err)
})
t.Run("nil", func(t *testing.T) {
f := func(s *S) {
assert.Nil(t, s)
}
_, err := CallMatch(f)
assert.NoError(t, err)
})
t.Run("zero struct", func(t *testing.T) {
f := func(s S) {
assert.Equal(t, S{}, s)
}
_, err := CallMatch(f)
assert.NoError(t, err)
})
}

View File

@ -1,13 +1,23 @@
package bbgo
package dynamic
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/service"
"github.com/c9s/bbgo/pkg/types"
)
func injectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnly bool) error {
type testEnvironment struct {
startTime time.Time
}
func InjectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnly bool) error {
field := rs.FieldByName(fieldName)
if !field.IsValid() {
return nil
@ -38,10 +48,10 @@ func injectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnl
return nil
}
// parseStructAndInject parses the struct fields and injects the objects into the corresponding fields by its type.
// ParseStructAndInject parses the struct fields and injects the objects into the corresponding fields by its type.
// if the given object is a reference of an object, the type of the target field MUST BE a pointer field.
// if the given object is a struct value, the type of the target field CAN BE a pointer field or a struct value field.
func parseStructAndInject(f interface{}, objects ...interface{}) error {
func ParseStructAndInject(f interface{}, objects ...interface{}) error {
sv := reflect.ValueOf(f)
st := reflect.TypeOf(f)
@ -121,3 +131,96 @@ func parseStructAndInject(f interface{}, objects ...interface{}) error {
return nil
}
func Test_injectField(t *testing.T) {
type TT struct {
TradeService *service.TradeService
}
// only pointer object can be set.
var tt = &TT{}
// get the value of the pointer, or it can not be set.
var rv = reflect.ValueOf(tt).Elem()
_, ret := HasField(rv, "TradeService")
assert.True(t, ret)
ts := &service.TradeService{}
err := InjectField(rv, "TradeService", ts, true)
assert.NoError(t, err)
}
func Test_parseStructAndInject(t *testing.T) {
t.Run("skip nil", func(t *testing.T) {
ss := struct {
a int
Env *testEnvironment
}{
a: 1,
Env: nil,
}
err := ParseStructAndInject(&ss, nil)
assert.NoError(t, err)
assert.Nil(t, ss.Env)
})
t.Run("pointer", func(t *testing.T) {
ss := struct {
a int
Env *testEnvironment
}{
a: 1,
Env: nil,
}
err := ParseStructAndInject(&ss, &testEnvironment{})
assert.NoError(t, err)
assert.NotNil(t, ss.Env)
})
t.Run("composition", func(t *testing.T) {
type TT struct {
*service.TradeService
}
ss := TT{}
err := ParseStructAndInject(&ss, &service.TradeService{})
assert.NoError(t, err)
assert.NotNil(t, ss.TradeService)
})
t.Run("struct", func(t *testing.T) {
ss := struct {
a int
Env testEnvironment
}{
a: 1,
}
err := ParseStructAndInject(&ss, testEnvironment{
startTime: time.Now(),
})
assert.NoError(t, err)
assert.NotEqual(t, time.Time{}, ss.Env.startTime)
})
t.Run("interface/any", func(t *testing.T) {
ss := struct {
Any interface{} // anything
}{
Any: nil,
}
err := ParseStructAndInject(&ss, &testEnvironment{
startTime: time.Now(),
})
assert.NoError(t, err)
assert.NotNil(t, ss.Any)
})
t.Run("interface/stringer", func(t *testing.T) {
ss := struct {
Stringer types.Stringer // stringer interface
}{
Stringer: nil,
}
err := ParseStructAndInject(&ss, &types.Trade{})
assert.NoError(t, err)
assert.NotNil(t, ss.Stringer)
})
}

View File

@ -112,7 +112,7 @@ func (it *Interact) handleResponse(session Session, text string, ctxObjects ...i
}
ctxObjects = append(ctxObjects, session)
_, err := parseFuncArgsAndCall(f, args, ctxObjects...)
_, err := ParseFuncArgsAndCall(f, args, ctxObjects...)
if err != nil {
return err
}
@ -154,7 +154,7 @@ func (it *Interact) runCommand(session Session, command string, args []string, c
ctxObjects = append(ctxObjects, session)
session.SetState(cmd.initState)
if _, err := parseFuncArgsAndCall(cmd.F, args, ctxObjects...); err != nil {
if _, err := ParseFuncArgsAndCall(cmd.F, args, ctxObjects...); err != nil {
return err
}

View File

@ -18,7 +18,7 @@ func Test_parseFuncArgsAndCall_NoErrorFunction(t *testing.T) {
return nil
}
_, err := parseFuncArgsAndCall(noErrorFunc, []string{"BTCUSDT", "0.123", "true"})
_, err := ParseFuncArgsAndCall(noErrorFunc, []string{"BTCUSDT", "0.123", "true"})
assert.NoError(t, err)
}
@ -27,7 +27,7 @@ func Test_parseFuncArgsAndCall_ErrorFunction(t *testing.T) {
return errors.New("error")
}
_, err := parseFuncArgsAndCall(errorFunc, []string{"BTCUSDT", "0.123"})
_, err := ParseFuncArgsAndCall(errorFunc, []string{"BTCUSDT", "0.123"})
assert.Error(t, err)
}
@ -38,7 +38,7 @@ func Test_parseFuncArgsAndCall_InterfaceInjection(t *testing.T) {
}
buf := bytes.NewBuffer(nil)
_, err := parseFuncArgsAndCall(f, []string{"BTCUSDT", "0.123"}, buf)
_, err := ParseFuncArgsAndCall(f, []string{"BTCUSDT", "0.123"}, buf)
assert.NoError(t, err)
assert.Equal(t, "123", buf.String())
}

View File

@ -10,21 +10,20 @@ import (
log "github.com/sirupsen/logrus"
)
func parseFuncArgsAndCall(f interface{}, args []string, objects ...interface{}) (State, error) {
func ParseFuncArgsAndCall(f interface{}, args []string, objects ...interface{}) (State, error) {
fv := reflect.ValueOf(f)
ft := reflect.TypeOf(f)
argIndex := 0
var rArgs []reflect.Value
for i := 0; i < ft.NumIn(); i++ {
at := ft.In(i)
// get the kind of argument
switch k := at.Kind(); k {
case reflect.Interface:
found := false
for oi := 0; oi < len(objects); oi++ {
obj := objects[oi]
objT := reflect.TypeOf(obj)
@ -90,8 +89,8 @@ func parseFuncArgsAndCall(f interface{}, args []string, objects ...interface{})
}
// try to get the error object from the return value
var state State
var err error
var state State
for i := 0; i < ft.NumOut(); i++ {
outType := ft.Out(i)
switch outType.Kind() {
@ -107,7 +106,6 @@ func parseFuncArgsAndCall(f interface{}, args []string, objects ...interface{})
err = ov
}
}
}
return state, err