From 579df0cec9ca6f1b79094ed08fc5d628cd047f59 Mon Sep 17 00:00:00 2001 From: c9s Date: Sun, 25 Dec 2022 16:08:34 +0800 Subject: [PATCH] types: add simple duration tests --- pkg/strategy/grid2/strategy.go | 2 +- pkg/types/duration.go | 71 +++++++++++++++++++++++++++++++++- pkg/types/duration_test.go | 55 ++++++++++++++++++++++++++ pkg/types/market.go | 47 ---------------------- 4 files changed, 125 insertions(+), 50 deletions(-) create mode 100644 pkg/types/duration_test.go diff --git a/pkg/strategy/grid2/strategy.go b/pkg/strategy/grid2/strategy.go index 972650241..6dcca83e2 100644 --- a/pkg/strategy/grid2/strategy.go +++ b/pkg/strategy/grid2/strategy.go @@ -65,7 +65,7 @@ type Strategy struct { // GridNum is the grid number, how many orders you want to post on the orderbook. GridNum int64 `json:"gridNumber"` - AutoRange types.Duration `json:"autoRange"` + AutoRange types.SimpleDuration `json:"autoRange"` UpperPrice fixedpoint.Value `json:"upperPrice"` diff --git a/pkg/types/duration.go b/pkg/types/duration.go index 9ec93d9bb..f465f24bd 100644 --- a/pkg/types/duration.go +++ b/pkg/types/duration.go @@ -1,6 +1,8 @@ package types import ( + "encoding/json" + "fmt" "regexp" "strconv" "time" @@ -8,9 +10,9 @@ import ( "github.com/pkg/errors" ) -var simpleDurationRegExp = regexp.MustCompile("^(\\d+)[hdw]$") +var simpleDurationRegExp = regexp.MustCompile("^(\\d+)([hdw])$") -var ErrNotSimpleDuration = errors.New("the given input is not simple duration format") +var ErrNotSimpleDuration = errors.New("the given input is not simple duration format, valid format: [1-9][0-9]*[hdw]") type SimpleDuration struct { Num int64 @@ -18,7 +20,28 @@ type SimpleDuration struct { Duration Duration } +func (d *SimpleDuration) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + sd, err := ParseSimpleDuration(s) + if err != nil { + return err + } + + if sd != nil { + *d = *sd + } + return nil +} + func ParseSimpleDuration(s string) (*SimpleDuration, error) { + if s == "" { + return nil, nil + } + if !simpleDurationRegExp.MatchString(s) { return nil, errors.Wrapf(ErrNotSimpleDuration, "input %q is not a simple duration", s) } @@ -45,3 +68,47 @@ func ParseSimpleDuration(s string) (*SimpleDuration, error) { return nil, errors.Wrapf(ErrNotSimpleDuration, "input %q is not a simple duration", s) } + +type Duration time.Duration + +func (d *Duration) Duration() time.Duration { + return time.Duration(*d) +} + +func (d *Duration) UnmarshalJSON(data []byte) error { + var o interface{} + + if err := json.Unmarshal(data, &o); err != nil { + return err + } + + switch t := o.(type) { + case string: + sd, err := ParseSimpleDuration(t) + if err == nil { + *d = sd.Duration + return nil + } + + dd, err := time.ParseDuration(t) + if err != nil { + return err + } + + *d = Duration(dd) + + case float64: + *d = Duration(int64(t * float64(time.Second))) + + case int64: + *d = Duration(t * int64(time.Second)) + case int: + *d = Duration(t * int(time.Second)) + + default: + return fmt.Errorf("unsupported type %T value: %v", t, t) + + } + + return nil +} diff --git a/pkg/types/duration_test.go b/pkg/types/duration_test.go new file mode 100644 index 000000000..44a56c80d --- /dev/null +++ b/pkg/types/duration_test.go @@ -0,0 +1,55 @@ +package types + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParseSimpleDuration(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want *SimpleDuration + wantErr assert.ErrorAssertionFunc + }{ + { + name: "3h", + args: args{ + s: "3h", + }, + want: &SimpleDuration{Num: 3, Unit: "h", Duration: Duration(3 * time.Hour)}, + wantErr: assert.NoError, + }, + { + name: "3d", + args: args{ + s: "3d", + }, + want: &SimpleDuration{Num: 3, Unit: "d", Duration: Duration(3 * 24 * time.Hour)}, + wantErr: assert.NoError, + }, + { + name: "3w", + args: args{ + s: "3w", + }, + want: &SimpleDuration{Num: 3, Unit: "w", Duration: Duration(3 * 7 * 24 * time.Hour)}, + wantErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseSimpleDuration(tt.args.s) + if !tt.wantErr(t, err, fmt.Sprintf("ParseSimpleDuration(%v)", tt.args.s)) { + return + } + assert.Equalf(t, tt.want, got, "ParseSimpleDuration(%v)", tt.args.s) + }) + } +} diff --git a/pkg/types/market.go b/pkg/types/market.go index 23f3610ca..6cc9466ff 100644 --- a/pkg/types/market.go +++ b/pkg/types/market.go @@ -1,60 +1,13 @@ package types import ( - "encoding/json" - "fmt" "math" - "time" "github.com/leekchan/accounting" "github.com/c9s/bbgo/pkg/fixedpoint" ) -type Duration time.Duration - -func (d *Duration) Duration() time.Duration { - return time.Duration(*d) -} - -func (d *Duration) UnmarshalJSON(data []byte) error { - var o interface{} - - if err := json.Unmarshal(data, &o); err != nil { - return err - } - - switch t := o.(type) { - case string: - sd, err := ParseSimpleDuration(t) - if err == nil { - *d = sd.Duration - return nil - } - - dd, err := time.ParseDuration(t) - if err != nil { - return err - } - - *d = Duration(dd) - - case float64: - *d = Duration(int64(t * float64(time.Second))) - - case int64: - *d = Duration(t * int64(time.Second)) - case int: - *d = Duration(t * int(time.Second)) - - default: - return fmt.Errorf("unsupported type %T value: %v", t, t) - - } - - return nil -} - type Market struct { Symbol string `json:"symbol"`