diff --git a/pkg/bbgo/persistence.go b/pkg/bbgo/persistence.go index b350f108d..7aaa95db9 100644 --- a/pkg/bbgo/persistence.go +++ b/pkg/bbgo/persistence.go @@ -129,11 +129,7 @@ func newTypeValueInterface(typ reflect.Type) interface{} { } } -func loadPersistenceFields(obj interface{}, persistence service.PersistenceService) error { - id := callID(obj) - if len(id) == 0 { - return fmt.Errorf("object does not provide ID() method for persistency") - } +func loadPersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error { return iterateFieldsByTag(obj, "persistence", func(tag string, field reflect.StructField, value reflect.Value) error { newValueInf := newTypeValueInterface(value.Type()) @@ -155,12 +151,7 @@ func loadPersistenceFields(obj interface{}, persistence service.PersistenceServi }) } -func storePersistenceFields(obj interface{}, persistence service.PersistenceService) error { - id := callID(obj) - if len(id) == 0 { - return fmt.Errorf("object does not provide ID() method for persistency") - } - +func storePersistenceFields(obj interface{}, id string, persistence service.PersistenceService) error { return iterateFieldsByTag(obj, "persistence", func(tag string, ft reflect.StructField, fv reflect.Value) error { inf := fv.Interface() diff --git a/pkg/bbgo/persistence_test.go b/pkg/bbgo/persistence_test.go index a14f503ae..9e4efadce 100644 --- a/pkg/bbgo/persistence_test.go +++ b/pkg/bbgo/persistence_test.go @@ -42,6 +42,11 @@ func preparePersistentServices() []service.PersistenceService { return pss } +func Test_callID(t *testing.T) { + id := callID(&TestStruct{}) + assert.NotEmpty(t, id) +} + func Test_storePersistenceFields(t *testing.T) { var pss = preparePersistentServices() @@ -56,7 +61,8 @@ func Test_storePersistenceFields(t *testing.T) { for _, ps := range pss { t.Run(reflect.TypeOf(ps).Elem().String(), func(t *testing.T) { - err := storePersistenceFields(a, ps) + id := callID(a) + err := storePersistenceFields(a, id, ps) assert.NoError(t, err) var i int64 @@ -73,7 +79,7 @@ func Test_storePersistenceFields(t *testing.T) { assert.Equal(t, fixedpoint.NewFromFloat(3343.0), p.AverageCost) var b = &TestStruct{} - err = loadPersistenceFields(b, ps) + err = loadPersistenceFields(b, id, ps) assert.NoError(t, err) assert.Equal(t, a.Integer, b.Integer) assert.Equal(t, a.Integer2, b.Integer2)