mirror of
https://github.com/c9s/bbgo.git
synced 2024-09-20 08:11:08 +00:00
dynamic: implement CallWithMatch for dynamic calls
Signed-off-by: c9s <yoanlin93@gmail.com>
This commit is contained in:
parent
503d851c9d
commit
910c17a567
|
@ -1,106 +0,0 @@
|
|||
package bbgo
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/dynamic"
|
||||
"github.com/c9s/bbgo/pkg/service"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
func Test_injectField(t *testing.T) {
|
||||
type TT struct {
|
||||
TradeService *service.TradeService
|
||||
}
|
||||
|
||||
// only pointer object can be set.
|
||||
var tt = &TT{}
|
||||
|
||||
// get the value of the pointer, or it can not be set.
|
||||
var rv = reflect.ValueOf(tt).Elem()
|
||||
|
||||
_, ret := dynamic.HasField(rv, "TradeService")
|
||||
assert.True(t, ret)
|
||||
|
||||
ts := &service.TradeService{}
|
||||
|
||||
err := injectField(rv, "TradeService", ts, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_parseStructAndInject(t *testing.T) {
|
||||
t.Run("skip nil", func(t *testing.T) {
|
||||
ss := struct {
|
||||
a int
|
||||
Env *Environment
|
||||
}{
|
||||
a: 1,
|
||||
Env: nil,
|
||||
}
|
||||
err := parseStructAndInject(&ss, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ss.Env)
|
||||
})
|
||||
t.Run("pointer", func(t *testing.T) {
|
||||
ss := struct {
|
||||
a int
|
||||
Env *Environment
|
||||
}{
|
||||
a: 1,
|
||||
Env: nil,
|
||||
}
|
||||
err := parseStructAndInject(&ss, &Environment{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss.Env)
|
||||
})
|
||||
|
||||
t.Run("composition", func(t *testing.T) {
|
||||
type TT struct {
|
||||
*service.TradeService
|
||||
}
|
||||
ss := TT{}
|
||||
err := parseStructAndInject(&ss, &service.TradeService{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss.TradeService)
|
||||
})
|
||||
|
||||
t.Run("struct", func(t *testing.T) {
|
||||
ss := struct {
|
||||
a int
|
||||
Env Environment
|
||||
}{
|
||||
a: 1,
|
||||
}
|
||||
err := parseStructAndInject(&ss, Environment{
|
||||
startTime: time.Now(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, time.Time{}, ss.Env.startTime)
|
||||
})
|
||||
t.Run("interface/any", func(t *testing.T) {
|
||||
ss := struct {
|
||||
Any interface{} // anything
|
||||
}{
|
||||
Any: nil,
|
||||
}
|
||||
err := parseStructAndInject(&ss, &Environment{
|
||||
startTime: time.Now(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss.Any)
|
||||
})
|
||||
t.Run("interface/stringer", func(t *testing.T) {
|
||||
ss := struct {
|
||||
Stringer types.Stringer // stringer interface
|
||||
}{
|
||||
Stringer: nil,
|
||||
}
|
||||
err := parseStructAndInject(&ss, &types.Trade{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss.Stringer)
|
||||
})
|
||||
}
|
|
@ -196,7 +196,7 @@ func (trader *Trader) RunSingleExchangeStrategy(ctx context.Context, strategy Si
|
|||
return err
|
||||
}
|
||||
|
||||
if err := injectField(rs, "OrderExecutor", orderExecutor, false); err != nil {
|
||||
if err := dynamic.InjectField(rs, "OrderExecutor", orderExecutor, false); err != nil {
|
||||
return errors.Wrapf(err, "failed to inject OrderExecutor on %T", strategy)
|
||||
}
|
||||
|
||||
|
@ -218,7 +218,7 @@ func (trader *Trader) RunSingleExchangeStrategy(ctx context.Context, strategy Si
|
|||
return fmt.Errorf("marketDataStore of symbol %s not found", symbol)
|
||||
}
|
||||
|
||||
if err := parseStructAndInject(strategy,
|
||||
if err := dynamic.ParseStructAndInject(strategy,
|
||||
market,
|
||||
indicatorSet,
|
||||
store,
|
||||
|
@ -401,19 +401,19 @@ func (trader *Trader) injectCommonServices(s interface{}) error {
|
|||
return fmt.Errorf("field Persistence is not a struct element, %s given", field)
|
||||
}
|
||||
|
||||
if err := injectField(elem, "Facade", PersistenceServiceFacade, true); err != nil {
|
||||
if err := dynamic.InjectField(elem, "Facade", PersistenceServiceFacade, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
/*
|
||||
if err := parseStructAndInject(field.Interface(), persistenceFacade); err != nil {
|
||||
if err := ParseStructAndInject(field.Interface(), persistenceFacade); err != nil {
|
||||
return err
|
||||
}
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
||||
return parseStructAndInject(s,
|
||||
return dynamic.ParseStructAndInject(s,
|
||||
&trader.logger,
|
||||
Notification,
|
||||
trader.environment.TradeService,
|
||||
|
|
|
@ -51,3 +51,101 @@ func CallStructFieldsMethod(m interface{}, method string, args ...interface{}) e
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CallMatch calls the function with the matched argument automatically
|
||||
func CallMatch(f interface{}, objects ...interface{}) ([]reflect.Value, error) {
|
||||
fv := reflect.ValueOf(f)
|
||||
ft := reflect.TypeOf(f)
|
||||
|
||||
var startIndex = 0
|
||||
var fArgs []reflect.Value
|
||||
|
||||
var factoryParams = findFactoryParams(objects...)
|
||||
|
||||
nextDynamicInputArg:
|
||||
for i := 0; i < ft.NumIn(); i++ {
|
||||
at := ft.In(i)
|
||||
|
||||
// uat == underlying argument type
|
||||
uat := at
|
||||
if at.Kind() == reflect.Ptr {
|
||||
uat = at.Elem()
|
||||
}
|
||||
|
||||
for oi := startIndex; oi < len(objects); oi++ {
|
||||
var obj = objects[oi]
|
||||
var objT = reflect.TypeOf(obj)
|
||||
if objT == at {
|
||||
fArgs = append(fArgs, reflect.ValueOf(obj))
|
||||
startIndex = oi + 1
|
||||
continue nextDynamicInputArg
|
||||
}
|
||||
|
||||
// get the kind of argument
|
||||
switch k := uat.Kind(); k {
|
||||
|
||||
case reflect.Interface:
|
||||
if objT.Implements(at) {
|
||||
fArgs = append(fArgs, reflect.ValueOf(obj))
|
||||
startIndex = oi + 1
|
||||
continue nextDynamicInputArg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// factory param can be reused
|
||||
for _, fp := range factoryParams {
|
||||
fpt := fp.Type()
|
||||
outType := fpt.Out(0)
|
||||
if outType == at {
|
||||
fOut := fp.Call(nil)
|
||||
fArgs = append(fArgs, fOut[0])
|
||||
continue nextDynamicInputArg
|
||||
}
|
||||
}
|
||||
|
||||
fArgs = append(fArgs, reflect.Zero(at))
|
||||
}
|
||||
|
||||
out := fv.Call(fArgs)
|
||||
if ft.NumOut() == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// try to get the error object from the return value (if any)
|
||||
var err error
|
||||
for i := 0; i < ft.NumOut(); i++ {
|
||||
outType := ft.Out(i)
|
||||
switch outType.Kind() {
|
||||
case reflect.Interface:
|
||||
o := out[i].Interface()
|
||||
switch ov := o.(type) {
|
||||
case error:
|
||||
err = ov
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return out, err
|
||||
}
|
||||
|
||||
func findFactoryParams(objs ...interface{}) (fs []reflect.Value) {
|
||||
for i := range objs {
|
||||
obj := objs[i]
|
||||
|
||||
objT := reflect.TypeOf(obj)
|
||||
|
||||
if objT.Kind() != reflect.Func {
|
||||
continue
|
||||
}
|
||||
|
||||
if objT.NumOut() == 0 || objT.NumIn() > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fs = append(fs, reflect.ValueOf(obj))
|
||||
}
|
||||
|
||||
return fs
|
||||
}
|
||||
|
|
|
@ -27,3 +27,89 @@ func TestCallStructFieldsMethod(t *testing.T) {
|
|||
err := CallStructFieldsMethod(c, "Subscribe", 10)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type S struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
func (s *S) String() string { return s.ID }
|
||||
|
||||
func TestCallMatch(t *testing.T) {
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
f := func(a int, b int) {
|
||||
assert.Equal(t, 1, a)
|
||||
assert.Equal(t, 2, b)
|
||||
}
|
||||
_, err := CallMatch(f, 1, 2)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("interface", func(t *testing.T) {
|
||||
type A interface {
|
||||
String() string
|
||||
}
|
||||
f := func(foo int, a A) {
|
||||
assert.Equal(t, "foo", a.String())
|
||||
}
|
||||
_, err := CallMatch(f, 10, &S{ID: "foo"})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil interface", func(t *testing.T) {
|
||||
type A interface {
|
||||
String() string
|
||||
}
|
||||
f := func(foo int, a A) {
|
||||
assert.Equal(t, 10, foo)
|
||||
assert.Nil(t, a)
|
||||
}
|
||||
_, err := CallMatch(f, 10)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("struct pointer", func(t *testing.T) {
|
||||
f := func(foo int, s *S) {
|
||||
assert.Equal(t, 10, foo)
|
||||
assert.NotNil(t, s)
|
||||
}
|
||||
_, err := CallMatch(f, 10, &S{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("struct pointer x 2", func(t *testing.T) {
|
||||
f := func(foo int, s1, s2 *S) {
|
||||
assert.Equal(t, 10, foo)
|
||||
assert.Equal(t, "s1", s1.String())
|
||||
assert.Equal(t, "s2", s2.String())
|
||||
}
|
||||
_, err := CallMatch(f, 10, &S{ID: "s1"}, &S{ID: "s2"})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("func factory", func(t *testing.T) {
|
||||
f := func(s *S) {
|
||||
assert.Equal(t, "factory", s.String())
|
||||
}
|
||||
_, err := CallMatch(f, func() *S {
|
||||
return &S{ID: "factory"}
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
f := func(s *S) {
|
||||
assert.Nil(t, s)
|
||||
}
|
||||
_, err := CallMatch(f)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("zero struct", func(t *testing.T) {
|
||||
f := func(s S) {
|
||||
assert.Equal(t, S{}, s)
|
||||
}
|
||||
_, err := CallMatch(f)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
|
|
@ -1,13 +1,23 @@
|
|||
package bbgo
|
||||
package dynamic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/service"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
func injectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnly bool) error {
|
||||
type testEnvironment struct {
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
func InjectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnly bool) error {
|
||||
field := rs.FieldByName(fieldName)
|
||||
if !field.IsValid() {
|
||||
return nil
|
||||
|
@ -38,10 +48,10 @@ func injectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnl
|
|||
return nil
|
||||
}
|
||||
|
||||
// parseStructAndInject parses the struct fields and injects the objects into the corresponding fields by its type.
|
||||
// ParseStructAndInject parses the struct fields and injects the objects into the corresponding fields by its type.
|
||||
// if the given object is a reference of an object, the type of the target field MUST BE a pointer field.
|
||||
// if the given object is a struct value, the type of the target field CAN BE a pointer field or a struct value field.
|
||||
func parseStructAndInject(f interface{}, objects ...interface{}) error {
|
||||
func ParseStructAndInject(f interface{}, objects ...interface{}) error {
|
||||
sv := reflect.ValueOf(f)
|
||||
st := reflect.TypeOf(f)
|
||||
|
||||
|
@ -121,3 +131,96 @@ func parseStructAndInject(f interface{}, objects ...interface{}) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_injectField(t *testing.T) {
|
||||
type TT struct {
|
||||
TradeService *service.TradeService
|
||||
}
|
||||
|
||||
// only pointer object can be set.
|
||||
var tt = &TT{}
|
||||
|
||||
// get the value of the pointer, or it can not be set.
|
||||
var rv = reflect.ValueOf(tt).Elem()
|
||||
|
||||
_, ret := HasField(rv, "TradeService")
|
||||
assert.True(t, ret)
|
||||
|
||||
ts := &service.TradeService{}
|
||||
|
||||
err := InjectField(rv, "TradeService", ts, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_parseStructAndInject(t *testing.T) {
|
||||
t.Run("skip nil", func(t *testing.T) {
|
||||
ss := struct {
|
||||
a int
|
||||
Env *testEnvironment
|
||||
}{
|
||||
a: 1,
|
||||
Env: nil,
|
||||
}
|
||||
err := ParseStructAndInject(&ss, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ss.Env)
|
||||
})
|
||||
t.Run("pointer", func(t *testing.T) {
|
||||
ss := struct {
|
||||
a int
|
||||
Env *testEnvironment
|
||||
}{
|
||||
a: 1,
|
||||
Env: nil,
|
||||
}
|
||||
err := ParseStructAndInject(&ss, &testEnvironment{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss.Env)
|
||||
})
|
||||
|
||||
t.Run("composition", func(t *testing.T) {
|
||||
type TT struct {
|
||||
*service.TradeService
|
||||
}
|
||||
ss := TT{}
|
||||
err := ParseStructAndInject(&ss, &service.TradeService{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss.TradeService)
|
||||
})
|
||||
|
||||
t.Run("struct", func(t *testing.T) {
|
||||
ss := struct {
|
||||
a int
|
||||
Env testEnvironment
|
||||
}{
|
||||
a: 1,
|
||||
}
|
||||
err := ParseStructAndInject(&ss, testEnvironment{
|
||||
startTime: time.Now(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, time.Time{}, ss.Env.startTime)
|
||||
})
|
||||
t.Run("interface/any", func(t *testing.T) {
|
||||
ss := struct {
|
||||
Any interface{} // anything
|
||||
}{
|
||||
Any: nil,
|
||||
}
|
||||
err := ParseStructAndInject(&ss, &testEnvironment{
|
||||
startTime: time.Now(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss.Any)
|
||||
})
|
||||
t.Run("interface/stringer", func(t *testing.T) {
|
||||
ss := struct {
|
||||
Stringer types.Stringer // stringer interface
|
||||
}{
|
||||
Stringer: nil,
|
||||
}
|
||||
err := ParseStructAndInject(&ss, &types.Trade{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss.Stringer)
|
||||
})
|
||||
}
|
|
@ -112,7 +112,7 @@ func (it *Interact) handleResponse(session Session, text string, ctxObjects ...i
|
|||
}
|
||||
|
||||
ctxObjects = append(ctxObjects, session)
|
||||
_, err := parseFuncArgsAndCall(f, args, ctxObjects...)
|
||||
_, err := ParseFuncArgsAndCall(f, args, ctxObjects...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -154,7 +154,7 @@ func (it *Interact) runCommand(session Session, command string, args []string, c
|
|||
|
||||
ctxObjects = append(ctxObjects, session)
|
||||
session.SetState(cmd.initState)
|
||||
if _, err := parseFuncArgsAndCall(cmd.F, args, ctxObjects...); err != nil {
|
||||
if _, err := ParseFuncArgsAndCall(cmd.F, args, ctxObjects...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ func Test_parseFuncArgsAndCall_NoErrorFunction(t *testing.T) {
|
|||
return nil
|
||||
}
|
||||
|
||||
_, err := parseFuncArgsAndCall(noErrorFunc, []string{"BTCUSDT", "0.123", "true"})
|
||||
_, err := ParseFuncArgsAndCall(noErrorFunc, []string{"BTCUSDT", "0.123", "true"})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ func Test_parseFuncArgsAndCall_ErrorFunction(t *testing.T) {
|
|||
return errors.New("error")
|
||||
}
|
||||
|
||||
_, err := parseFuncArgsAndCall(errorFunc, []string{"BTCUSDT", "0.123"})
|
||||
_, err := ParseFuncArgsAndCall(errorFunc, []string{"BTCUSDT", "0.123"})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
|
@ -38,7 +38,7 @@ func Test_parseFuncArgsAndCall_InterfaceInjection(t *testing.T) {
|
|||
}
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
_, err := parseFuncArgsAndCall(f, []string{"BTCUSDT", "0.123"}, buf)
|
||||
_, err := ParseFuncArgsAndCall(f, []string{"BTCUSDT", "0.123"}, buf)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "123", buf.String())
|
||||
}
|
||||
|
|
|
@ -10,21 +10,20 @@ import (
|
|||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func parseFuncArgsAndCall(f interface{}, args []string, objects ...interface{}) (State, error) {
|
||||
func ParseFuncArgsAndCall(f interface{}, args []string, objects ...interface{}) (State, error) {
|
||||
fv := reflect.ValueOf(f)
|
||||
ft := reflect.TypeOf(f)
|
||||
|
||||
argIndex := 0
|
||||
|
||||
var rArgs []reflect.Value
|
||||
for i := 0; i < ft.NumIn(); i++ {
|
||||
at := ft.In(i)
|
||||
|
||||
// get the kind of argument
|
||||
switch k := at.Kind(); k {
|
||||
|
||||
case reflect.Interface:
|
||||
found := false
|
||||
|
||||
for oi := 0; oi < len(objects); oi++ {
|
||||
obj := objects[oi]
|
||||
objT := reflect.TypeOf(obj)
|
||||
|
@ -90,8 +89,8 @@ func parseFuncArgsAndCall(f interface{}, args []string, objects ...interface{})
|
|||
}
|
||||
|
||||
// try to get the error object from the return value
|
||||
var state State
|
||||
var err error
|
||||
var state State
|
||||
for i := 0; i < ft.NumOut(); i++ {
|
||||
outType := ft.Out(i)
|
||||
switch outType.Kind() {
|
||||
|
@ -107,7 +106,6 @@ func parseFuncArgsAndCall(f interface{}, args []string, objects ...interface{})
|
|||
err = ov
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return state, err
|
||||
|
|
Loading…
Reference in New Issue
Block a user