diff --git a/pkg/bbgo/livenote.go b/pkg/bbgo/livenote.go index 20f92e367..e51f41668 100644 --- a/pkg/bbgo/livenote.go +++ b/pkg/bbgo/livenote.go @@ -6,21 +6,22 @@ import ( "github.com/c9s/bbgo/pkg/livenote" ) -// PostLiveNote posts a live note to slack or other services +// PostLiveNote a global function helper for strategies to call. +// This function posts a live note to slack or other services // The MessageID will be set after the message is posted if it's not set. -func PostLiveNote(obj livenote.Object) { +func PostLiveNote(obj livenote.Object, opts ...livenote.Option) { if len(Notification.liveNotePosters) == 0 { logrus.Warn("no live note poster is registered") return } for _, poster := range Notification.liveNotePosters { - if err := poster.PostLiveNote(obj); err != nil { + if err := poster.PostLiveNote(obj, opts...); err != nil { logrus.WithError(err).Errorf("unable to post live note: %+v", obj) } } } type LiveNotePoster interface { - PostLiveNote(note livenote.Object) error + PostLiveNote(note livenote.Object, opts ...livenote.Option) error } diff --git a/pkg/dynamic/compare.go b/pkg/dynamic/compare.go new file mode 100644 index 000000000..a99e18201 --- /dev/null +++ b/pkg/dynamic/compare.go @@ -0,0 +1,235 @@ +package dynamic + +import ( + "fmt" + "reflect" + "strconv" + "time" + + "github.com/c9s/bbgo/pkg/fixedpoint" +) + +type Diff struct { + Field string `json:"field"` + Before string `json:"before"` + After string `json:"after"` +} + +// a (after) +// b (before) +func Compare(a, b interface{}) ([]Diff, error) { + ra := reflect.ValueOf(a) + if ra.Kind() == reflect.Ptr { + ra = ra.Elem() + } + + raType := ra.Type() + raKind := ra.Kind() + + rb := reflect.ValueOf(b) + if rb.Kind() == reflect.Ptr { + rb = rb.Elem() + } + + rbType := rb.Type() + rbKind := rb.Kind() // bool, int, slice, string, struct + + if raType != rbType { + return nil, fmt.Errorf("type mismatch: %s != %s", raType, rbType) + } + + if raKind != rbKind { + return nil, fmt.Errorf("kind mismatch: %s != %s", raKind, rbKind) + } + + if isSimpleType(ra) { + if compareSimpleValue(ra, rb) { + // no changes + return nil, nil + } else { + return []Diff{ + { + Field: "", + Before: convertToStr(rb), + After: convertToStr(ra), + }, + }, nil + } + } else if raKind == reflect.Struct { + return compareStruct(ra, rb) + } + + return nil, nil +} + +func compareStruct(a, b reflect.Value) ([]Diff, error) { + a = reflect.Indirect(a) + b = reflect.Indirect(b) + + if a.Kind() != reflect.Struct { + return nil, fmt.Errorf("value is not a struct") + } + + if b.Kind() != reflect.Struct { + return nil, fmt.Errorf("value is not a struct") + } + + if a.Type() != b.Type() { + return nil, fmt.Errorf("type is not the same") + } + + var diffs []Diff + + numFields := a.NumField() + for i := 0; i < numFields; i++ { + fieldValueA := a.Field(i) + fieldValueB := b.Field(i) + + fieldA := a.Type().Field(i) + fieldName := fieldA.Name + + if !fieldA.IsExported() { + continue + } + + if isSimpleType(fieldValueA) { + if compareSimpleValue(fieldValueA, fieldValueB) { + continue + } else { + diffs = append(diffs, Diff{ + Field: fieldName, + Before: convertToStr(fieldValueB), + After: convertToStr(fieldValueA), + }) + } + } else if fieldValueA.Kind() == reflect.Struct && fieldValueB.Kind() == reflect.Struct { + subDiffs, err := compareStruct(fieldValueA, fieldValueB) + if err != nil { + return diffs, err + } + + for _, subDiff := range subDiffs { + diffs = append(diffs, Diff{ + Field: fieldName + "." + subDiff.Field, + Before: subDiff.Before, + After: subDiff.After, + }) + } + } + } + + return diffs, nil +} + +func isSimpleType(a reflect.Value) bool { + a = reflect.Indirect(a) + aInf := a.Interface() + + switch aInf.(type) { + case time.Time: + return true + + case fixedpoint.Value: + return true + + } + + kind := a.Kind() + switch kind { + case reflect.Bool, reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint64, reflect.String, reflect.Float64: + return true + default: + return false + } +} + +func compareSimpleValue(a, b reflect.Value) bool { + if a.Kind() != b.Kind() { + return false + } + + switch a.Kind() { + + case reflect.Bool: + if a.Bool() == b.Bool() { + return true + } + + case reflect.Uint, reflect.Uint32, reflect.Uint64: + if a.Uint() == b.Uint() { + return true + } + case reflect.Int, reflect.Int8, reflect.Int32, reflect.Int64: + if a.Int() == b.Int() { + return true + } + + case reflect.String: + if a.String() == b.String() { + return true + } + + case reflect.Float64: + if a.Float() == b.Float() { + return true + } + + case reflect.Slice: + // TODO: compare slice + + default: + ainf := a.Interface() + binf := b.Interface() + + switch aa := ainf.(type) { + case fixedpoint.Value: + if bb, ok := binf.(fixedpoint.Value); ok { + return bb.Compare(aa) == 0 + } + case time.Time: + if bb, ok := binf.(time.Time); ok { + return bb.Compare(aa) == 0 + } + } + + // other unhandled cases + } + + return false +} + +func convertToStr(val reflect.Value) string { + val = reflect.Indirect(val) + + if val.Type() == reflect.TypeOf(fixedpoint.Zero) { + inf := val.Interface() + switch aa := inf.(type) { + case fixedpoint.Value: + return aa.String() + case time.Time: + return aa.String() + } + } + + switch val.Kind() { + case reflect.Float32, reflect.Float64: + return strconv.FormatFloat(val.Float(), 'f', -1, 64) + + case reflect.Int, reflect.Int8, reflect.Int32, reflect.Int64: + return strconv.FormatInt(val.Int(), 10) + + case reflect.Uint, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(val.Uint(), 10) + + case reflect.Bool: + return strconv.FormatBool(val.Bool()) + default: + strType := reflect.TypeOf("") + if val.CanConvert(strType) { + strVal := val.Convert(strType) + return strVal.String() + } + + return "{unable to convert " + val.Kind().String() + "}" + } +} diff --git a/pkg/dynamic/compare_test.go b/pkg/dynamic/compare_test.go new file mode 100644 index 000000000..36200f545 --- /dev/null +++ b/pkg/dynamic/compare_test.go @@ -0,0 +1,207 @@ +package dynamic + +import ( + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" +) + +func Test_convertToStr(t *testing.T) { + t.Run("str-str", func(t *testing.T) { + out := convertToStr(reflect.ValueOf("a")) + assert.Equal(t, "a", out) + }) + + t.Run("bool-str", func(t *testing.T) { + out := convertToStr(reflect.ValueOf(false)) + assert.Equal(t, "false", out) + + out = convertToStr(reflect.ValueOf(true)) + assert.Equal(t, "true", out) + }) + + t.Run("float-str", func(t *testing.T) { + out := convertToStr(reflect.ValueOf(0.444)) + assert.Equal(t, "0.444", out) + }) + + t.Run("int-str", func(t *testing.T) { + a := int(123) + out := convertToStr(reflect.ValueOf(a)) + assert.Equal(t, "123", out) + }) + + t.Run("uint-str", func(t *testing.T) { + a := uint(123) + out := convertToStr(reflect.ValueOf(a)) + assert.Equal(t, "123", out) + }) + + t.Run("int-ptr-str", func(t *testing.T) { + a := int(123) + out := convertToStr(reflect.ValueOf(&a)) + assert.Equal(t, "123", out) + }) + + t.Run("fixedpoint-str", func(t *testing.T) { + a := fixedpoint.NewFromInt(100) + out := convertToStr(reflect.ValueOf(a)) + assert.Equal(t, "100", out) + }) +} + +func Test_Compare(t *testing.T) { + tests := []struct { + name string + a, b interface{} + want []Diff + wantErr assert.ErrorAssertionFunc + }{ + { + name: "order", + wantErr: assert.NoError, + a: &types.Order{ + SubmitOrder: types.SubmitOrder{ + Symbol: "BTCUSDT", + Quantity: fixedpoint.NewFromFloat(100.0), + }, + Status: types.OrderStatusFilled, + ExecutedQuantity: fixedpoint.NewFromFloat(100.0), + }, + b: &types.Order{ + SubmitOrder: types.SubmitOrder{ + Symbol: "BTCUSDT", + Quantity: fixedpoint.NewFromFloat(100.0), + }, + ExecutedQuantity: fixedpoint.NewFromFloat(50.0), + Status: types.OrderStatusPartiallyFilled, + }, + want: []Diff{ + { + Field: "Status", + Before: "PARTIALLY_FILLED", + After: "FILLED", + }, + { + Field: "ExecutedQuantity", + Before: "50", + After: "100", + }, + }, + }, + { + name: "deposit and order", + wantErr: assert.NoError, + a: &types.Deposit{ + Address: "0x6666", + TransactionID: "0x3333", + Status: types.DepositPending, + Confirmation: "10/15", + }, + b: &types.Deposit{ + Address: "0x6666", + TransactionID: "0x3333", + Status: types.DepositPending, + Confirmation: "1/15", + }, + want: []Diff{ + { + Field: "Confirmation", + Before: "1/15", + After: "10/15", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Compare(tt.a, tt.b) + if !tt.wantErr(t, err, fmt.Sprintf("Compare(%v, %v)", tt.a, tt.b)) { + return + } + + assert.Equalf(t, tt.want, got, "Compare(%v, %v)", tt.a, tt.b) + }) + } +} + +func Test_compareStruct(t *testing.T) { + tests := []struct { + name string + a, b reflect.Value + want []Diff + wantErr assert.ErrorAssertionFunc + }{ + { + name: "order ptrs", + wantErr: assert.NoError, + a: reflect.ValueOf(&types.Order{ + SubmitOrder: types.SubmitOrder{ + Symbol: "BTCUSDT", + Quantity: fixedpoint.NewFromFloat(100.0), + }, + ExecutedQuantity: fixedpoint.NewFromFloat(50.0), + }), + b: reflect.ValueOf(&types.Order{ + SubmitOrder: types.SubmitOrder{ + Symbol: "BTCUSDT", + Quantity: fixedpoint.NewFromFloat(100.0), + }, + ExecutedQuantity: fixedpoint.NewFromFloat(20.0), + }), + want: []Diff{ + { + Field: "ExecutedQuantity", + Before: "20", + After: "50", + }, + }, + }, + { + name: "order ptr and value", + wantErr: assert.NoError, + a: reflect.ValueOf(types.Order{ + SubmitOrder: types.SubmitOrder{ + Symbol: "BTCUSDT", + Quantity: fixedpoint.NewFromFloat(100.0), + }, + Status: types.OrderStatusFilled, + ExecutedQuantity: fixedpoint.NewFromFloat(100.0), + }), + b: reflect.ValueOf(&types.Order{ + SubmitOrder: types.SubmitOrder{ + Symbol: "BTCUSDT", + Quantity: fixedpoint.NewFromFloat(100.0), + }, + ExecutedQuantity: fixedpoint.NewFromFloat(50.0), + Status: types.OrderStatusPartiallyFilled, + }), + want: []Diff{ + { + Field: "Status", + Before: "PARTIALLY_FILLED", + After: "FILLED", + }, + { + Field: "ExecutedQuantity", + Before: "50", + After: "100", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := compareStruct(tt.a, tt.b) + if !tt.wantErr(t, err, fmt.Sprintf("compareStruct(%v, %v)", tt.a, tt.b)) { + return + } + assert.Equalf(t, tt.want, got, "compareStruct(%v, %v)", tt.a, tt.b) + }) + } +} diff --git a/pkg/livenote/options.go b/pkg/livenote/options.go new file mode 100644 index 000000000..216dddec0 --- /dev/null +++ b/pkg/livenote/options.go @@ -0,0 +1,12 @@ +package livenote + +type Option interface{} + +type Mention struct { + User string +} + +type Comment struct { + Text string + Users []string +} diff --git a/pkg/notifier/slacknotifier/slack.go b/pkg/notifier/slacknotifier/slack.go index 09a3f84e9..238a6bca7 100644 --- a/pkg/notifier/slacknotifier/slack.go +++ b/pkg/notifier/slacknotifier/slack.go @@ -71,7 +71,7 @@ func (n *Notifier) worker() { } } -func (n *Notifier) PostLiveNote(obj livenote.Object) error { +func (n *Notifier) PostLiveNote(obj livenote.Object, opts ...livenote.Option) error { note := n.liveNotePool.Update(obj) ctx := context.Background() @@ -87,18 +87,32 @@ func (n *Notifier) PostLiveNote(obj livenote.Object) error { return fmt.Errorf("livenote object does not support types.SlackAttachmentCreator interface") } - opts := slack.MsgOptionAttachments(attachment) + var slackOpts []slack.MsgOption + slackOpts = append(slackOpts, slack.MsgOptionAttachments(attachment)) + + var userIds []string + var mentions []*livenote.Mention + var comments []*livenote.Comment + for _, opt := range opts { + switch val := opt.(type) { + case *livenote.Mention: + mentions = append(mentions, val) + userIds = append(userIds, val.User) + case *livenote.Comment: + comments = append(comments, val) + userIds = append(userIds, val.Users...) + } + } if note.MessageID != "" { // UpdateMessageContext returns channel, timestamp, text, err - _, _, _, err := n.client.UpdateMessageContext(ctx, channel, note.MessageID, opts) + _, _, _, err := n.client.UpdateMessageContext(ctx, channel, note.MessageID, slackOpts...) if err != nil { return err } } else { - - respCh, respTs, err := n.client.PostMessageContext(ctx, channel, opts) + respCh, respTs, err := n.client.PostMessageContext(ctx, channel, slackOpts...) if err != nil { log.WithError(err). WithField("channel", n.channel).