mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-25 16:25:16 +00:00
fix: sma calculation, length, and add test case
This commit is contained in:
parent
7daa73917c
commit
4e2adcf29e
|
@ -4,8 +4,6 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
|
@ -19,79 +17,65 @@ type SMA struct {
|
|||
types.SeriesBase
|
||||
types.IntervalWindow
|
||||
Values types.Float64Slice
|
||||
Cache types.Float64Slice
|
||||
Cache *types.Queue
|
||||
EndTime time.Time
|
||||
|
||||
UpdateCallbacks []func(value float64)
|
||||
}
|
||||
|
||||
func (inc *SMA) Last() float64 {
|
||||
if len(inc.Values) == 0 {
|
||||
if inc.Values.Length() == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return inc.Values[len(inc.Values)-1]
|
||||
return inc.Values.Last()
|
||||
}
|
||||
|
||||
func (inc *SMA) Index(i int) float64 {
|
||||
length := len(inc.Values)
|
||||
if length == 0 || length-i-1 < 0 {
|
||||
if i >= inc.Values.Length() {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return inc.Values[length-i-1]
|
||||
return inc.Values.Index(i)
|
||||
}
|
||||
|
||||
func (inc *SMA) Length() int {
|
||||
return len(inc.Values)
|
||||
return inc.Values.Length()
|
||||
}
|
||||
|
||||
var _ types.SeriesExtend = &SMA{}
|
||||
|
||||
func (inc *SMA) Update(value float64) {
|
||||
if len(inc.Cache) < inc.Window {
|
||||
if len(inc.Cache) == 0 {
|
||||
inc.SeriesBase.Series = inc
|
||||
}
|
||||
inc.Cache = append(inc.Cache, value)
|
||||
if len(inc.Cache) == inc.Window {
|
||||
inc.Values = append(inc.Values, types.Mean(&inc.Cache))
|
||||
}
|
||||
return
|
||||
|
||||
if inc.Cache == nil {
|
||||
inc.Cache = types.NewQueue(inc.Window)
|
||||
inc.SeriesBase.Series = inc
|
||||
}
|
||||
inc.Cache.Update(value)
|
||||
if inc.Cache.Length() < inc.Window {
|
||||
return
|
||||
}
|
||||
inc.Values.Push(types.Mean(inc.Cache))
|
||||
if inc.Values.Length() > MaxNumOfSMA {
|
||||
inc.Values = inc.Values[MaxNumOfSMATruncateSize-1:]
|
||||
}
|
||||
length := len(inc.Values)
|
||||
newVal := (inc.Values[length-1]*float64(inc.Window-1) + value) / float64(inc.Window)
|
||||
inc.Values = append(inc.Values, newVal)
|
||||
}
|
||||
|
||||
func (inc *SMA) calculateAndUpdate(kLines []types.KLine) {
|
||||
if len(kLines) < inc.Window {
|
||||
return
|
||||
}
|
||||
|
||||
var index = len(kLines) - 1
|
||||
var kline = kLines[index]
|
||||
|
||||
if inc.EndTime != zeroTime && kline.EndTime.Before(inc.EndTime) {
|
||||
return
|
||||
}
|
||||
|
||||
var recentK = kLines[index-(inc.Window-1) : index+1]
|
||||
|
||||
sma, err := calculateSMA(recentK, inc.Window, KLineClosePriceMapper)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("SMA error")
|
||||
return
|
||||
if inc.Cache == nil {
|
||||
for _, k := range kLines {
|
||||
inc.Update(KLineClosePriceMapper(k))
|
||||
inc.EndTime = k.EndTime.Time()
|
||||
inc.EmitUpdate(inc.Values.Last())
|
||||
}
|
||||
} else {
|
||||
inc.Update(KLineClosePriceMapper(kline))
|
||||
inc.EndTime = kline.EndTime.Time()
|
||||
inc.EmitUpdate(inc.Values.Last())
|
||||
}
|
||||
inc.Values.Push(sma)
|
||||
|
||||
if len(inc.Values) > MaxNumOfSMA {
|
||||
inc.Values = inc.Values[MaxNumOfSMATruncateSize-1:]
|
||||
}
|
||||
|
||||
inc.EndTime = kLines[index].EndTime.Time()
|
||||
|
||||
inc.EmitUpdate(sma)
|
||||
}
|
||||
|
||||
func (inc *SMA) handleKLineWindowUpdate(interval types.Interval, window types.KLineWindow) {
|
||||
|
|
63
pkg/indicator/sma_test.go
Normal file
63
pkg/indicator/sma_test.go
Normal file
|
@ -0,0 +1,63 @@
|
|||
package indicator
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
/*
|
||||
python:
|
||||
|
||||
import pandas as pd
|
||||
import pandas_ta as ta
|
||||
|
||||
data = pd.Series([0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0])
|
||||
size = 5
|
||||
|
||||
result = ta.sma(data, size)
|
||||
print(result)
|
||||
*/
|
||||
func Test_SMA(t *testing.T) {
|
||||
Delta := 0.001
|
||||
var randomPrices = []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]`)
|
||||
var input []fixedpoint.Value
|
||||
if err := json.Unmarshal(randomPrices, &input); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
kLines []types.KLine
|
||||
want float64
|
||||
next float64
|
||||
update float64
|
||||
updateResult float64
|
||||
all int
|
||||
}{
|
||||
{
|
||||
name: "test",
|
||||
kLines: buildKLines(input),
|
||||
want: 7.0,
|
||||
next: 6.0,
|
||||
update: 0,
|
||||
updateResult: 6.0,
|
||||
all: 27,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sma := SMA{
|
||||
IntervalWindow: types.IntervalWindow{Window: 5},
|
||||
}
|
||||
sma.calculateAndUpdate(tt.kLines)
|
||||
assert.InDelta(t, tt.want, sma.Last(), Delta)
|
||||
assert.InDelta(t, tt.next, sma.Index(1), Delta)
|
||||
sma.Update(tt.update)
|
||||
assert.InDelta(t, tt.updateResult, sma.Last(), Delta)
|
||||
assert.Equal(t, tt.all, sma.Length())
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user