types: add simple duration tests

This commit is contained in:
c9s 2022-12-25 16:08:34 +08:00
parent f60b4630c5
commit 579df0cec9
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
4 changed files with 125 additions and 50 deletions

View File

@ -65,7 +65,7 @@ type Strategy struct {
// GridNum is the grid number, how many orders you want to post on the orderbook. // GridNum is the grid number, how many orders you want to post on the orderbook.
GridNum int64 `json:"gridNumber"` GridNum int64 `json:"gridNumber"`
AutoRange types.Duration `json:"autoRange"` AutoRange types.SimpleDuration `json:"autoRange"`
UpperPrice fixedpoint.Value `json:"upperPrice"` UpperPrice fixedpoint.Value `json:"upperPrice"`

View File

@ -1,6 +1,8 @@
package types package types
import ( import (
"encoding/json"
"fmt"
"regexp" "regexp"
"strconv" "strconv"
"time" "time"
@ -8,9 +10,9 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var simpleDurationRegExp = regexp.MustCompile("^(\\d+)[hdw]$") var simpleDurationRegExp = regexp.MustCompile("^(\\d+)([hdw])$")
var ErrNotSimpleDuration = errors.New("the given input is not simple duration format") var ErrNotSimpleDuration = errors.New("the given input is not simple duration format, valid format: [1-9][0-9]*[hdw]")
type SimpleDuration struct { type SimpleDuration struct {
Num int64 Num int64
@ -18,7 +20,28 @@ type SimpleDuration struct {
Duration Duration Duration Duration
} }
func (d *SimpleDuration) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
sd, err := ParseSimpleDuration(s)
if err != nil {
return err
}
if sd != nil {
*d = *sd
}
return nil
}
func ParseSimpleDuration(s string) (*SimpleDuration, error) { func ParseSimpleDuration(s string) (*SimpleDuration, error) {
if s == "" {
return nil, nil
}
if !simpleDurationRegExp.MatchString(s) { if !simpleDurationRegExp.MatchString(s) {
return nil, errors.Wrapf(ErrNotSimpleDuration, "input %q is not a simple duration", s) return nil, errors.Wrapf(ErrNotSimpleDuration, "input %q is not a simple duration", s)
} }
@ -45,3 +68,47 @@ func ParseSimpleDuration(s string) (*SimpleDuration, error) {
return nil, errors.Wrapf(ErrNotSimpleDuration, "input %q is not a simple duration", s) return nil, errors.Wrapf(ErrNotSimpleDuration, "input %q is not a simple duration", s)
} }
type Duration time.Duration
func (d *Duration) Duration() time.Duration {
return time.Duration(*d)
}
func (d *Duration) UnmarshalJSON(data []byte) error {
var o interface{}
if err := json.Unmarshal(data, &o); err != nil {
return err
}
switch t := o.(type) {
case string:
sd, err := ParseSimpleDuration(t)
if err == nil {
*d = sd.Duration
return nil
}
dd, err := time.ParseDuration(t)
if err != nil {
return err
}
*d = Duration(dd)
case float64:
*d = Duration(int64(t * float64(time.Second)))
case int64:
*d = Duration(t * int64(time.Second))
case int:
*d = Duration(t * int(time.Second))
default:
return fmt.Errorf("unsupported type %T value: %v", t, t)
}
return nil
}

View File

@ -0,0 +1,55 @@
package types
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestParseSimpleDuration(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
want *SimpleDuration
wantErr assert.ErrorAssertionFunc
}{
{
name: "3h",
args: args{
s: "3h",
},
want: &SimpleDuration{Num: 3, Unit: "h", Duration: Duration(3 * time.Hour)},
wantErr: assert.NoError,
},
{
name: "3d",
args: args{
s: "3d",
},
want: &SimpleDuration{Num: 3, Unit: "d", Duration: Duration(3 * 24 * time.Hour)},
wantErr: assert.NoError,
},
{
name: "3w",
args: args{
s: "3w",
},
want: &SimpleDuration{Num: 3, Unit: "w", Duration: Duration(3 * 7 * 24 * time.Hour)},
wantErr: assert.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseSimpleDuration(tt.args.s)
if !tt.wantErr(t, err, fmt.Sprintf("ParseSimpleDuration(%v)", tt.args.s)) {
return
}
assert.Equalf(t, tt.want, got, "ParseSimpleDuration(%v)", tt.args.s)
})
}
}

View File

@ -1,60 +1,13 @@
package types package types
import ( import (
"encoding/json"
"fmt"
"math" "math"
"time"
"github.com/leekchan/accounting" "github.com/leekchan/accounting"
"github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/fixedpoint"
) )
type Duration time.Duration
func (d *Duration) Duration() time.Duration {
return time.Duration(*d)
}
func (d *Duration) UnmarshalJSON(data []byte) error {
var o interface{}
if err := json.Unmarshal(data, &o); err != nil {
return err
}
switch t := o.(type) {
case string:
sd, err := ParseSimpleDuration(t)
if err == nil {
*d = sd.Duration
return nil
}
dd, err := time.ParseDuration(t)
if err != nil {
return err
}
*d = Duration(dd)
case float64:
*d = Duration(int64(t * float64(time.Second)))
case int64:
*d = Duration(t * int64(time.Second))
case int:
*d = Duration(t * int(time.Second))
default:
return fmt.Errorf("unsupported type %T value: %v", t, t)
}
return nil
}
type Market struct { type Market struct {
Symbol string `json:"symbol"` Symbol string `json:"symbol"`