indicator: rewrite RSI indicator

This commit is contained in:
c9s 2023-05-31 16:30:04 +08:00
parent e58db43067
commit 114e292d8f
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
6 changed files with 151 additions and 17 deletions

View File

@ -23,12 +23,16 @@ func Price(source KLineSubscription, mapper KLineValueMapper) *PriceStream {
source.AddSubscriber(func(k types.KLine) {
v := s.mapper(k)
s.slice.Push(v)
s.EmitUpdate(v)
s.PushAndEmit(v)
})
return s
}
func (s *PriceStream) PushAndEmit(v float64) {
s.slice.Push(v)
s.EmitUpdate(v)
}
func ClosePrices(source KLineSubscription) *PriceStream {
return Price(source, KLineClosePriceMapper)
}

17
pkg/indicator/types.go Normal file
View File

@ -0,0 +1,17 @@
package indicator
import "github.com/c9s/bbgo/pkg/types"
type Float64Calculator interface {
Calculate(x float64) float64
}
type Float64Source interface {
types.Series
OnUpdate(f func(v float64))
}
type Float64Subscription interface {
types.Series
AddSubscriber(f func(v float64))
}

View File

@ -1 +1,15 @@
package indicator
func max(x, y int) int {
if x > y {
return x
}
return y
}
func min(x, y int) int {
if x < y {
return x
}
return y
}

View File

@ -1,9 +1,5 @@
package indicator
import (
"github.com/c9s/bbgo/pkg/types"
)
/*
NEW INDICATOR DESIGN:
@ -21,13 +17,3 @@ macd := Subtract(fastEMA, slowEMA)
signal := EMA(macd, 16)
histogram := Subtract(macd, signal)
*/
type Float64Source interface {
types.Series
OnUpdate(f func(v float64))
}
type Float64Subscription interface {
types.Series
AddSubscriber(f func(v float64))
}

View File

@ -8,10 +8,12 @@ type RSIStream struct {
window int
// private states
source Float64Source
}
func RSI2(source Float64Source, window int) *RSIStream {
s := &RSIStream{
source: source,
Float64Series: NewFloat64Series(),
window: window,
}
@ -25,6 +27,30 @@ func RSI2(source Float64Source, window int) *RSIStream {
return s
}
func (s *RSIStream) calculateAndPush(x float64) {
func (s *RSIStream) calculate(_ float64) float64 {
var gainSum, lossSum float64
var sourceLen = s.source.Length()
var limit = min(s.window, sourceLen)
for i := 0; i < limit; i++ {
value := s.source.Index(i)
prev := s.source.Index(i + 1)
change := value - prev
if change >= 0 {
gainSum += change
} else {
lossSum += -change
}
}
avgGain := gainSum / float64(limit)
avgLoss := lossSum / float64(limit)
rs := avgGain / avgLoss
rsi := 100.0 - (100.0 / (1.0 + rs))
return rsi
}
func (s *RSIStream) calculateAndPush(x float64) {
rsi := s.calculate(x)
s.slice.Push(rsi)
s.EmitUpdate(rsi)
}

View File

@ -0,0 +1,87 @@
package indicator
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/datatype/floats"
)
func Test_RSI2(t *testing.T) {
// test case from https://school.stockcharts.com/doku.php?id=technical_indicators:relative_strength_index_rsi
var data = []byte(`[44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00, 46.03, 46.41, 46.22, 45.64, 46.21, 46.25, 45.71, 46.45, 45.78, 45.35, 44.03, 44.18, 44.22, 44.57, 43.42, 42.66, 43.13]`)
var values []float64
err := json.Unmarshal(data, &values)
assert.NoError(t, err)
tests := []struct {
name string
values []float64
window int
want floats.Slice
}{
{
name: "RSI",
values: values,
window: 14,
want: floats.Slice{
100.000000,
99.439336,
99.440090,
98.251826,
98.279242,
98.297781,
98.307626,
98.319149,
98.334036,
98.342426,
97.951933,
97.957908,
97.108036,
97.147514,
70.464135,
70.020964,
69.831224,
80.567686,
73.333333,
59.806295,
62.528217,
60.000000,
48.477752,
53.878407,
48.952381,
43.862816,
37.732919,
32.263514,
32.718121,
38.142620,
31.748252,
25.099602,
30.217670,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// RSI2()
prices := &PriceStream{}
rsi := RSI2(prices, tt.window)
t.Logf("data length: %d", len(tt.values))
for _, price := range tt.values {
prices.PushAndEmit(price)
}
assert.Equal(t, floats.Slice(tt.values), prices.slice)
if assert.Equal(t, len(tt.want), len(rsi.slice)) {
for i, v := range tt.want {
assert.InDelta(t, v, rsi.slice[i], 0.000001, "Expected rsi.slice[%d] to be %v, but got %v", i, v, rsi.slice[i])
}
}
})
}
}