fix: sma calculation, length, and add test case

This commit is contained in:
zenix 2022-07-13 12:28:41 +09:00
parent 7daa73917c
commit 4e2adcf29e
2 changed files with 90 additions and 43 deletions

View File

@ -4,8 +4,6 @@ import (
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/c9s/bbgo/pkg/types"
)
@ -19,79 +17,65 @@ type SMA struct {
types.SeriesBase
types.IntervalWindow
Values types.Float64Slice
Cache types.Float64Slice
Cache *types.Queue
EndTime time.Time
UpdateCallbacks []func(value float64)
}
func (inc *SMA) Last() float64 {
if len(inc.Values) == 0 {
if inc.Values.Length() == 0 {
return 0.0
}
return inc.Values[len(inc.Values)-1]
return inc.Values.Last()
}
func (inc *SMA) Index(i int) float64 {
length := len(inc.Values)
if length == 0 || length-i-1 < 0 {
if i >= inc.Values.Length() {
return 0.0
}
return inc.Values[length-i-1]
return inc.Values.Index(i)
}
func (inc *SMA) Length() int {
return len(inc.Values)
return inc.Values.Length()
}
var _ types.SeriesExtend = &SMA{}
func (inc *SMA) Update(value float64) {
if len(inc.Cache) < inc.Window {
if len(inc.Cache) == 0 {
inc.SeriesBase.Series = inc
}
inc.Cache = append(inc.Cache, value)
if len(inc.Cache) == inc.Window {
inc.Values = append(inc.Values, types.Mean(&inc.Cache))
}
return
if inc.Cache == nil {
inc.Cache = types.NewQueue(inc.Window)
inc.SeriesBase.Series = inc
}
inc.Cache.Update(value)
if inc.Cache.Length() < inc.Window {
return
}
inc.Values.Push(types.Mean(inc.Cache))
if inc.Values.Length() > MaxNumOfSMA {
inc.Values = inc.Values[MaxNumOfSMATruncateSize-1:]
}
length := len(inc.Values)
newVal := (inc.Values[length-1]*float64(inc.Window-1) + value) / float64(inc.Window)
inc.Values = append(inc.Values, newVal)
}
func (inc *SMA) calculateAndUpdate(kLines []types.KLine) {
if len(kLines) < inc.Window {
return
}
var index = len(kLines) - 1
var kline = kLines[index]
if inc.EndTime != zeroTime && kline.EndTime.Before(inc.EndTime) {
return
}
var recentK = kLines[index-(inc.Window-1) : index+1]
sma, err := calculateSMA(recentK, inc.Window, KLineClosePriceMapper)
if err != nil {
log.WithError(err).Error("SMA error")
return
if inc.Cache == nil {
for _, k := range kLines {
inc.Update(KLineClosePriceMapper(k))
inc.EndTime = k.EndTime.Time()
inc.EmitUpdate(inc.Values.Last())
}
} else {
inc.Update(KLineClosePriceMapper(kline))
inc.EndTime = kline.EndTime.Time()
inc.EmitUpdate(inc.Values.Last())
}
inc.Values.Push(sma)
if len(inc.Values) > MaxNumOfSMA {
inc.Values = inc.Values[MaxNumOfSMATruncateSize-1:]
}
inc.EndTime = kLines[index].EndTime.Time()
inc.EmitUpdate(sma)
}
func (inc *SMA) handleKLineWindowUpdate(interval types.Interval, window types.KLineWindow) {

63
pkg/indicator/sma_test.go Normal file
View File

@ -0,0 +1,63 @@
package indicator
import (
"encoding/json"
"testing"
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
"github.com/stretchr/testify/assert"
)
/*
python:
import pandas as pd
import pandas_ta as ta
data = pd.Series([0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0])
size = 5
result = ta.sma(data, size)
print(result)
*/
func Test_SMA(t *testing.T) {
Delta := 0.001
var randomPrices = []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]`)
var input []fixedpoint.Value
if err := json.Unmarshal(randomPrices, &input); err != nil {
panic(err)
}
tests := []struct {
name string
kLines []types.KLine
want float64
next float64
update float64
updateResult float64
all int
}{
{
name: "test",
kLines: buildKLines(input),
want: 7.0,
next: 6.0,
update: 0,
updateResult: 6.0,
all: 27,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sma := SMA{
IntervalWindow: types.IntervalWindow{Window: 5},
}
sma.calculateAndUpdate(tt.kLines)
assert.InDelta(t, tt.want, sma.Last(), Delta)
assert.InDelta(t, tt.next, sma.Index(1), Delta)
sma.Update(tt.update)
assert.InDelta(t, tt.updateResult, sma.Last(), Delta)
assert.Equal(t, tt.all, sma.Length())
})
}
}