diff --git a/pkg/types/time.go b/pkg/types/time.go index a10b02f91..8e0cc2d9e 100644 --- a/pkg/types/time.go +++ b/pkg/types/time.go @@ -5,11 +5,16 @@ import ( "encoding/json" "fmt" "strconv" + "strings" "time" "github.com/c9s/bbgo/pkg/util" ) +var numOfDigitsOfUnixTimestamp = len(strconv.FormatInt(time.Now().Unix(), 10)) +var numOfDigitsOfMilliSecondUnixTimestamp = len(strconv.FormatInt(time.Now().UnixMilli(), 10)) +var numOfDigitsOfNanoSecondsUnixTimestamp = len(strconv.FormatInt(time.Now().UnixNano(), 10)) + type NanosecondTimestamp time.Time func (t NanosecondTimestamp) Time() time.Time { @@ -28,11 +33,10 @@ func (t *NanosecondTimestamp) UnmarshalJSON(data []byte) error { return nil } - type MillisecondTimestamp time.Time func NewMillisecondTimestampFromInt(i int64) MillisecondTimestamp { - return MillisecondTimestamp(time.Unix(0, i * int64(time.Millisecond))) + return MillisecondTimestamp(time.Unix(0, i*int64(time.Millisecond))) } func MustParseMillisecondTimestamp(a string) MillisecondTimestamp { @@ -77,15 +81,14 @@ func (t *MillisecondTimestamp) UnmarshalJSON(data []byte) error { return nil } - i, err := strconv.ParseInt(vt, 10, 64) - if err == nil { - *t = MillisecondTimestamp(time.Unix(0, i*int64(time.Millisecond))) - return nil - } - f, err := strconv.ParseFloat(vt, 64) if err == nil { - *t = MillisecondTimestamp(time.Unix(0, int64(f*float64(time.Millisecond)))) + tt, err := convertFloat64ToTime(vt, f) + if err != nil { + return err + } + + *t = MillisecondTimestamp(tt) return nil } @@ -97,16 +100,14 @@ func (t *MillisecondTimestamp) UnmarshalJSON(data []byte) error { return err - case int64: - *t = MillisecondTimestamp(time.Unix(0, vt*int64(time.Millisecond))) - return nil - - case int: - *t = MillisecondTimestamp(time.Unix(0, int64(vt)*int64(time.Millisecond))) - return nil - case float64: - *t = MillisecondTimestamp(time.Unix(0, int64(vt)*int64(time.Millisecond))) + str := strconv.FormatFloat(vt, 'f', -1, 64) + tt, err := convertFloat64ToTime(str, vt) + if err != nil { + return err + } + + *t = MillisecondTimestamp(tt) return nil default: @@ -118,7 +119,22 @@ func (t *MillisecondTimestamp) UnmarshalJSON(data []byte) error { return (*time.Time)(t).UnmarshalJSON(data) } +func convertFloat64ToTime(vt string, f float64) (time.Time, error) { + idx := strings.Index(vt, ".") + if idx > 0 { + vt = vt[0 : idx-1] + } + if len(vt) <= numOfDigitsOfUnixTimestamp { + return time.Unix(0, int64(f*float64(time.Second))), nil + } else if len(vt) <= numOfDigitsOfMilliSecondUnixTimestamp { + return time.Unix(0, int64(f)*int64(time.Millisecond)), nil + } else if len(vt) <= numOfDigitsOfNanoSecondsUnixTimestamp { + return time.Unix(0, int64(f)), nil + } + + return time.Time{}, fmt.Errorf("the floating point value %f is out of the timestamp range", f) +} type Time time.Time @@ -227,7 +243,13 @@ func (t *LooseFormatTime) UnmarshalYAML(unmarshal func(interface{}) error) error } func (t *LooseFormatTime) UnmarshalJSON(data []byte) error { - tv, err := util.ParseTimeWithFormats(string(data), looseTimeFormats) + var v string + err := json.Unmarshal(data, &v) + if err != nil { + return err + } + + tv, err := util.ParseTimeWithFormats(v, looseTimeFormats) if err != nil { return err } @@ -239,4 +261,3 @@ func (t *LooseFormatTime) UnmarshalJSON(data []byte) error { func (t LooseFormatTime) Time() time.Time { return time.Time(t) } - diff --git a/pkg/types/time_test.go b/pkg/types/time_test.go index 785a8c624..ae21984cf 100644 --- a/pkg/types/time_test.go +++ b/pkg/types/time_test.go @@ -3,8 +3,40 @@ package types import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) +func TestLooseFormatTime_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + t LooseFormatTime + args []byte + wantErr bool + }{ + { + name: "simple date", + args: []byte("\"2021-01-01\""), + t: LooseFormatTime(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), + }, + { + name: "utc", + args: []byte("\"2021-01-01T12:10:10\""), + t: LooseFormatTime(time.Date(2021, 1, 1, 12, 10, 10, 0, time.UTC)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var v LooseFormatTime + if err := v.UnmarshalJSON(tt.args); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } else { + assert.Equal(t, v.Time(), tt.t.Time()) + } + }) + } +} + func TestMillisecondTimestamp_UnmarshalJSON(t *testing.T) { tests := []struct { name string @@ -30,8 +62,11 @@ func TestMillisecondTimestamp_UnmarshalJSON(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.t.UnmarshalJSON(tt.args); (err != nil) != tt.wantErr { + var v MillisecondTimestamp + if err := v.UnmarshalJSON(tt.args); (err != nil) != tt.wantErr { t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } else { + assert.Equal(t, tt.t.Time(), v.Time()) } }) }