diff --git a/pkg/indicator/drift.go b/pkg/indicator/drift.go index d760014a1..bda5b51d5 100644 --- a/pkg/indicator/drift.go +++ b/pkg/indicator/drift.go @@ -12,9 +12,9 @@ import ( //go:generate callbackgen -type Drift type Drift struct { types.IntervalWindow - chng *types.Queue - Values types.Float64Slice - SMA *SMA + chng *types.Queue + Values types.Float64Slice + SMA *SMA LastValue float64 UpdateCallbacks []func(value float64) @@ -22,16 +22,25 @@ type Drift struct { func (inc *Drift) Update(value float64) { if inc.chng == nil { - inc.SMA = &SMA{IntervalWindow: types.IntervalWindow{inc.Interval, inc.Window}} + inc.SMA = &SMA{IntervalWindow: types.IntervalWindow{Interval: inc.Interval, Window: inc.Window}} inc.chng = types.NewQueue(inc.Window) + inc.LastValue = value + return + } + var chng float64 + if value == 0 { + chng = 0 + } else { + chng = math.Log(value / inc.LastValue) + inc.LastValue = value } - chng := math.Log(value / inc.LastValue) - inc.LastValue = value inc.SMA.Update(chng) inc.chng.Update(chng) - stdev := types.Stdev(inc.chng, inc.Window) - drift := inc.SMA.Last() - stdev * stdev * 0.5 - inc.Values.Push(drift) + if inc.chng.Length() >= inc.Window { + stdev := types.Stdev(inc.chng, inc.Window) + drift := inc.SMA.Last() - stdev*stdev*0.5 + inc.Values.Push(drift) + } } func (inc *Drift) Index(i int) float64 { diff --git a/pkg/indicator/drift_test.go b/pkg/indicator/drift_test.go new file mode 100644 index 000000000..38d6a732a --- /dev/null +++ b/pkg/indicator/drift_test.go @@ -0,0 +1,40 @@ +package indicator + +import ( + "encoding/json" + "testing" + + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" + "github.com/stretchr/testify/assert" +) + +func Test_Drift(t *testing.T) { + var randomPrices = []byte(`[1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 4, 1, 2, 3, 4, 5, 6, 7, 8, 9]`) + var input []fixedpoint.Value + if err := json.Unmarshal(randomPrices, &input); err != nil { + panic(err) + } + tests := []struct { + name string + kLines []types.KLine + all int + }{ + { + name: "random_case", + kLines: buildKLines(input), + all: 47, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + drift := Drift{IntervalWindow: types.IntervalWindow{Window: 3}} + drift.calculateAndUpdate(tt.kLines) + assert.Equal(t, drift.Length(), tt.all) + for _, v := range drift.Values { + assert.LessOrEqual(t, v, 1.0) + } + }) + } +}