From 910c17a56732460b057da4173fe4d62aeb0be2ee Mon Sep 17 00:00:00 2001 From: c9s Date: Fri, 1 Jul 2022 13:09:30 +0800 Subject: [PATCH] dynamic: implement CallWithMatch for dynamic calls Signed-off-by: c9s --- pkg/bbgo/injection_test.go | 106 ------------------ pkg/bbgo/trader.go | 10 +- pkg/dynamic/call.go | 98 ++++++++++++++++ pkg/dynamic/call_test.go | 86 ++++++++++++++ pkg/{bbgo/injection.go => dynamic/inject.go} | 111 ++++++++++++++++++- pkg/interact/interact.go | 4 +- pkg/interact/interact_test.go | 6 +- pkg/interact/parse.go | 8 +- 8 files changed, 304 insertions(+), 125 deletions(-) delete mode 100644 pkg/bbgo/injection_test.go rename pkg/{bbgo/injection.go => dynamic/inject.go} (54%) diff --git a/pkg/bbgo/injection_test.go b/pkg/bbgo/injection_test.go deleted file mode 100644 index 69a8b5f72..000000000 --- a/pkg/bbgo/injection_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package bbgo - -import ( - "reflect" - "testing" - "time" - - "github.com/stretchr/testify/assert" - - "github.com/c9s/bbgo/pkg/dynamic" - "github.com/c9s/bbgo/pkg/service" - "github.com/c9s/bbgo/pkg/types" -) - -func Test_injectField(t *testing.T) { - type TT struct { - TradeService *service.TradeService - } - - // only pointer object can be set. - var tt = &TT{} - - // get the value of the pointer, or it can not be set. - var rv = reflect.ValueOf(tt).Elem() - - _, ret := dynamic.HasField(rv, "TradeService") - assert.True(t, ret) - - ts := &service.TradeService{} - - err := injectField(rv, "TradeService", ts, true) - assert.NoError(t, err) -} - -func Test_parseStructAndInject(t *testing.T) { - t.Run("skip nil", func(t *testing.T) { - ss := struct { - a int - Env *Environment - }{ - a: 1, - Env: nil, - } - err := parseStructAndInject(&ss, nil) - assert.NoError(t, err) - assert.Nil(t, ss.Env) - }) - t.Run("pointer", func(t *testing.T) { - ss := struct { - a int - Env *Environment - }{ - a: 1, - Env: nil, - } - err := parseStructAndInject(&ss, &Environment{}) - assert.NoError(t, err) - assert.NotNil(t, ss.Env) - }) - - t.Run("composition", func(t *testing.T) { - type TT struct { - *service.TradeService - } - ss := TT{} - err := parseStructAndInject(&ss, &service.TradeService{}) - assert.NoError(t, err) - assert.NotNil(t, ss.TradeService) - }) - - t.Run("struct", func(t *testing.T) { - ss := struct { - a int - Env Environment - }{ - a: 1, - } - err := parseStructAndInject(&ss, Environment{ - startTime: time.Now(), - }) - assert.NoError(t, err) - assert.NotEqual(t, time.Time{}, ss.Env.startTime) - }) - t.Run("interface/any", func(t *testing.T) { - ss := struct { - Any interface{} // anything - }{ - Any: nil, - } - err := parseStructAndInject(&ss, &Environment{ - startTime: time.Now(), - }) - assert.NoError(t, err) - assert.NotNil(t, ss.Any) - }) - t.Run("interface/stringer", func(t *testing.T) { - ss := struct { - Stringer types.Stringer // stringer interface - }{ - Stringer: nil, - } - err := parseStructAndInject(&ss, &types.Trade{}) - assert.NoError(t, err) - assert.NotNil(t, ss.Stringer) - }) -} diff --git a/pkg/bbgo/trader.go b/pkg/bbgo/trader.go index 86371a53b..e8cbc13ed 100644 --- a/pkg/bbgo/trader.go +++ b/pkg/bbgo/trader.go @@ -196,7 +196,7 @@ func (trader *Trader) RunSingleExchangeStrategy(ctx context.Context, strategy Si return err } - if err := injectField(rs, "OrderExecutor", orderExecutor, false); err != nil { + if err := dynamic.InjectField(rs, "OrderExecutor", orderExecutor, false); err != nil { return errors.Wrapf(err, "failed to inject OrderExecutor on %T", strategy) } @@ -218,7 +218,7 @@ func (trader *Trader) RunSingleExchangeStrategy(ctx context.Context, strategy Si return fmt.Errorf("marketDataStore of symbol %s not found", symbol) } - if err := parseStructAndInject(strategy, + if err := dynamic.ParseStructAndInject(strategy, market, indicatorSet, store, @@ -401,19 +401,19 @@ func (trader *Trader) injectCommonServices(s interface{}) error { return fmt.Errorf("field Persistence is not a struct element, %s given", field) } - if err := injectField(elem, "Facade", PersistenceServiceFacade, true); err != nil { + if err := dynamic.InjectField(elem, "Facade", PersistenceServiceFacade, true); err != nil { return err } /* - if err := parseStructAndInject(field.Interface(), persistenceFacade); err != nil { + if err := ParseStructAndInject(field.Interface(), persistenceFacade); err != nil { return err } */ } } - return parseStructAndInject(s, + return dynamic.ParseStructAndInject(s, &trader.logger, Notification, trader.environment.TradeService, diff --git a/pkg/dynamic/call.go b/pkg/dynamic/call.go index a3122faad..a4c45a0f1 100644 --- a/pkg/dynamic/call.go +++ b/pkg/dynamic/call.go @@ -51,3 +51,101 @@ func CallStructFieldsMethod(m interface{}, method string, args ...interface{}) e return nil } + +// CallMatch calls the function with the matched argument automatically +func CallMatch(f interface{}, objects ...interface{}) ([]reflect.Value, error) { + fv := reflect.ValueOf(f) + ft := reflect.TypeOf(f) + + var startIndex = 0 + var fArgs []reflect.Value + + var factoryParams = findFactoryParams(objects...) + +nextDynamicInputArg: + for i := 0; i < ft.NumIn(); i++ { + at := ft.In(i) + + // uat == underlying argument type + uat := at + if at.Kind() == reflect.Ptr { + uat = at.Elem() + } + + for oi := startIndex; oi < len(objects); oi++ { + var obj = objects[oi] + var objT = reflect.TypeOf(obj) + if objT == at { + fArgs = append(fArgs, reflect.ValueOf(obj)) + startIndex = oi + 1 + continue nextDynamicInputArg + } + + // get the kind of argument + switch k := uat.Kind(); k { + + case reflect.Interface: + if objT.Implements(at) { + fArgs = append(fArgs, reflect.ValueOf(obj)) + startIndex = oi + 1 + continue nextDynamicInputArg + } + } + } + + // factory param can be reused + for _, fp := range factoryParams { + fpt := fp.Type() + outType := fpt.Out(0) + if outType == at { + fOut := fp.Call(nil) + fArgs = append(fArgs, fOut[0]) + continue nextDynamicInputArg + } + } + + fArgs = append(fArgs, reflect.Zero(at)) + } + + out := fv.Call(fArgs) + if ft.NumOut() == 0 { + return out, nil + } + + // try to get the error object from the return value (if any) + var err error + for i := 0; i < ft.NumOut(); i++ { + outType := ft.Out(i) + switch outType.Kind() { + case reflect.Interface: + o := out[i].Interface() + switch ov := o.(type) { + case error: + err = ov + + } + + } + } + return out, err +} + +func findFactoryParams(objs ...interface{}) (fs []reflect.Value) { + for i := range objs { + obj := objs[i] + + objT := reflect.TypeOf(obj) + + if objT.Kind() != reflect.Func { + continue + } + + if objT.NumOut() == 0 || objT.NumIn() > 0 { + continue + } + + fs = append(fs, reflect.ValueOf(obj)) + } + + return fs +} diff --git a/pkg/dynamic/call_test.go b/pkg/dynamic/call_test.go index 5324218f9..b65029ded 100644 --- a/pkg/dynamic/call_test.go +++ b/pkg/dynamic/call_test.go @@ -27,3 +27,89 @@ func TestCallStructFieldsMethod(t *testing.T) { err := CallStructFieldsMethod(c, "Subscribe", 10) assert.NoError(t, err) } + +type S struct { + ID string +} + +func (s *S) String() string { return s.ID } + +func TestCallMatch(t *testing.T) { + t.Run("simple", func(t *testing.T) { + f := func(a int, b int) { + assert.Equal(t, 1, a) + assert.Equal(t, 2, b) + } + _, err := CallMatch(f, 1, 2) + assert.NoError(t, err) + }) + + t.Run("interface", func(t *testing.T) { + type A interface { + String() string + } + f := func(foo int, a A) { + assert.Equal(t, "foo", a.String()) + } + _, err := CallMatch(f, 10, &S{ID: "foo"}) + assert.NoError(t, err) + }) + + t.Run("nil interface", func(t *testing.T) { + type A interface { + String() string + } + f := func(foo int, a A) { + assert.Equal(t, 10, foo) + assert.Nil(t, a) + } + _, err := CallMatch(f, 10) + assert.NoError(t, err) + }) + + t.Run("struct pointer", func(t *testing.T) { + f := func(foo int, s *S) { + assert.Equal(t, 10, foo) + assert.NotNil(t, s) + } + _, err := CallMatch(f, 10, &S{}) + assert.NoError(t, err) + }) + + t.Run("struct pointer x 2", func(t *testing.T) { + f := func(foo int, s1, s2 *S) { + assert.Equal(t, 10, foo) + assert.Equal(t, "s1", s1.String()) + assert.Equal(t, "s2", s2.String()) + } + _, err := CallMatch(f, 10, &S{ID: "s1"}, &S{ID: "s2"}) + assert.NoError(t, err) + }) + + t.Run("func factory", func(t *testing.T) { + f := func(s *S) { + assert.Equal(t, "factory", s.String()) + } + _, err := CallMatch(f, func() *S { + return &S{ID: "factory"} + }) + assert.NoError(t, err) + }) + + t.Run("nil", func(t *testing.T) { + f := func(s *S) { + assert.Nil(t, s) + } + _, err := CallMatch(f) + assert.NoError(t, err) + }) + + t.Run("zero struct", func(t *testing.T) { + f := func(s S) { + assert.Equal(t, S{}, s) + } + _, err := CallMatch(f) + assert.NoError(t, err) + }) + +} diff --git a/pkg/bbgo/injection.go b/pkg/dynamic/inject.go similarity index 54% rename from pkg/bbgo/injection.go rename to pkg/dynamic/inject.go index 0db8cf228..04a48599b 100644 --- a/pkg/bbgo/injection.go +++ b/pkg/dynamic/inject.go @@ -1,13 +1,23 @@ -package bbgo +package dynamic import ( "fmt" "reflect" + "testing" + "time" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/c9s/bbgo/pkg/service" + "github.com/c9s/bbgo/pkg/types" ) -func injectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnly bool) error { +type testEnvironment struct { + startTime time.Time +} + +func InjectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnly bool) error { field := rs.FieldByName(fieldName) if !field.IsValid() { return nil @@ -38,10 +48,10 @@ func injectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnl return nil } -// parseStructAndInject parses the struct fields and injects the objects into the corresponding fields by its type. +// ParseStructAndInject parses the struct fields and injects the objects into the corresponding fields by its type. // if the given object is a reference of an object, the type of the target field MUST BE a pointer field. // if the given object is a struct value, the type of the target field CAN BE a pointer field or a struct value field. -func parseStructAndInject(f interface{}, objects ...interface{}) error { +func ParseStructAndInject(f interface{}, objects ...interface{}) error { sv := reflect.ValueOf(f) st := reflect.TypeOf(f) @@ -121,3 +131,96 @@ func parseStructAndInject(f interface{}, objects ...interface{}) error { return nil } + +func Test_injectField(t *testing.T) { + type TT struct { + TradeService *service.TradeService + } + + // only pointer object can be set. + var tt = &TT{} + + // get the value of the pointer, or it can not be set. + var rv = reflect.ValueOf(tt).Elem() + + _, ret := HasField(rv, "TradeService") + assert.True(t, ret) + + ts := &service.TradeService{} + + err := InjectField(rv, "TradeService", ts, true) + assert.NoError(t, err) +} + +func Test_parseStructAndInject(t *testing.T) { + t.Run("skip nil", func(t *testing.T) { + ss := struct { + a int + Env *testEnvironment + }{ + a: 1, + Env: nil, + } + err := ParseStructAndInject(&ss, nil) + assert.NoError(t, err) + assert.Nil(t, ss.Env) + }) + t.Run("pointer", func(t *testing.T) { + ss := struct { + a int + Env *testEnvironment + }{ + a: 1, + Env: nil, + } + err := ParseStructAndInject(&ss, &testEnvironment{}) + assert.NoError(t, err) + assert.NotNil(t, ss.Env) + }) + + t.Run("composition", func(t *testing.T) { + type TT struct { + *service.TradeService + } + ss := TT{} + err := ParseStructAndInject(&ss, &service.TradeService{}) + assert.NoError(t, err) + assert.NotNil(t, ss.TradeService) + }) + + t.Run("struct", func(t *testing.T) { + ss := struct { + a int + Env testEnvironment + }{ + a: 1, + } + err := ParseStructAndInject(&ss, testEnvironment{ + startTime: time.Now(), + }) + assert.NoError(t, err) + assert.NotEqual(t, time.Time{}, ss.Env.startTime) + }) + t.Run("interface/any", func(t *testing.T) { + ss := struct { + Any interface{} // anything + }{ + Any: nil, + } + err := ParseStructAndInject(&ss, &testEnvironment{ + startTime: time.Now(), + }) + assert.NoError(t, err) + assert.NotNil(t, ss.Any) + }) + t.Run("interface/stringer", func(t *testing.T) { + ss := struct { + Stringer types.Stringer // stringer interface + }{ + Stringer: nil, + } + err := ParseStructAndInject(&ss, &types.Trade{}) + assert.NoError(t, err) + assert.NotNil(t, ss.Stringer) + }) +} diff --git a/pkg/interact/interact.go b/pkg/interact/interact.go index a19b56210..820979cfe 100644 --- a/pkg/interact/interact.go +++ b/pkg/interact/interact.go @@ -112,7 +112,7 @@ func (it *Interact) handleResponse(session Session, text string, ctxObjects ...i } ctxObjects = append(ctxObjects, session) - _, err := parseFuncArgsAndCall(f, args, ctxObjects...) + _, err := ParseFuncArgsAndCall(f, args, ctxObjects...) if err != nil { return err } @@ -154,7 +154,7 @@ func (it *Interact) runCommand(session Session, command string, args []string, c ctxObjects = append(ctxObjects, session) session.SetState(cmd.initState) - if _, err := parseFuncArgsAndCall(cmd.F, args, ctxObjects...); err != nil { + if _, err := ParseFuncArgsAndCall(cmd.F, args, ctxObjects...); err != nil { return err } diff --git a/pkg/interact/interact_test.go b/pkg/interact/interact_test.go index bd0828240..8402ba1c8 100644 --- a/pkg/interact/interact_test.go +++ b/pkg/interact/interact_test.go @@ -18,7 +18,7 @@ func Test_parseFuncArgsAndCall_NoErrorFunction(t *testing.T) { return nil } - _, err := parseFuncArgsAndCall(noErrorFunc, []string{"BTCUSDT", "0.123", "true"}) + _, err := ParseFuncArgsAndCall(noErrorFunc, []string{"BTCUSDT", "0.123", "true"}) assert.NoError(t, err) } @@ -27,7 +27,7 @@ func Test_parseFuncArgsAndCall_ErrorFunction(t *testing.T) { return errors.New("error") } - _, err := parseFuncArgsAndCall(errorFunc, []string{"BTCUSDT", "0.123"}) + _, err := ParseFuncArgsAndCall(errorFunc, []string{"BTCUSDT", "0.123"}) assert.Error(t, err) } @@ -38,7 +38,7 @@ func Test_parseFuncArgsAndCall_InterfaceInjection(t *testing.T) { } buf := bytes.NewBuffer(nil) - _, err := parseFuncArgsAndCall(f, []string{"BTCUSDT", "0.123"}, buf) + _, err := ParseFuncArgsAndCall(f, []string{"BTCUSDT", "0.123"}, buf) assert.NoError(t, err) assert.Equal(t, "123", buf.String()) } diff --git a/pkg/interact/parse.go b/pkg/interact/parse.go index db4f3d1fd..64f55871b 100644 --- a/pkg/interact/parse.go +++ b/pkg/interact/parse.go @@ -10,21 +10,20 @@ import ( log "github.com/sirupsen/logrus" ) -func parseFuncArgsAndCall(f interface{}, args []string, objects ...interface{}) (State, error) { +func ParseFuncArgsAndCall(f interface{}, args []string, objects ...interface{}) (State, error) { fv := reflect.ValueOf(f) ft := reflect.TypeOf(f) - argIndex := 0 var rArgs []reflect.Value for i := 0; i < ft.NumIn(); i++ { at := ft.In(i) + // get the kind of argument switch k := at.Kind(); k { case reflect.Interface: found := false - for oi := 0; oi < len(objects); oi++ { obj := objects[oi] objT := reflect.TypeOf(obj) @@ -90,8 +89,8 @@ func parseFuncArgsAndCall(f interface{}, args []string, objects ...interface{}) } // try to get the error object from the return value - var state State var err error + var state State for i := 0; i < ft.NumOut(); i++ { outType := ft.Out(i) switch outType.Kind() { @@ -107,7 +106,6 @@ func parseFuncArgsAndCall(f interface{}, args []string, objects ...interface{}) err = ov } - } } return state, err