From 126974cd79e452936481f1cd940a757ad2087940 Mon Sep 17 00:00:00 2001 From: zenix Date: Thu, 16 Jun 2022 19:05:33 +0900 Subject: [PATCH] feature: dmi add test, fix: rma with Adjust setting (follow the implementation of pandas.DataFrame.ewm) --- pkg/indicator/atr.go | 5 ++- pkg/indicator/atr_test.go | 21 +++++++++- pkg/indicator/dmi.go | 14 ++++--- pkg/indicator/dmi_test.go | 83 +++++++++++++++++++++++++++++++++++++++ pkg/indicator/rma.go | 37 ++++++++++------- 5 files changed, 139 insertions(+), 21 deletions(-) create mode 100644 pkg/indicator/dmi_test.go diff --git a/pkg/indicator/atr.go b/pkg/indicator/atr.go index 9f8bcf58f..016eb3a2d 100644 --- a/pkg/indicator/atr.go +++ b/pkg/indicator/atr.go @@ -25,7 +25,10 @@ func (inc *ATR) Update(high, low, cloze float64) { } if inc.RMA == nil { - inc.RMA = &RMA{IntervalWindow: types.IntervalWindow{Window: inc.Window}} + inc.RMA = &RMA{ + IntervalWindow: types.IntervalWindow{Window: inc.Window}, + Adjust: true, + } inc.PreviousClose = cloze return } diff --git a/pkg/indicator/atr_test.go b/pkg/indicator/atr_test.go index 63f189992..bbb562b2a 100644 --- a/pkg/indicator/atr_test.go +++ b/pkg/indicator/atr_test.go @@ -9,6 +9,25 @@ import ( "github.com/c9s/bbgo/pkg/types" ) +/* +python + +import pandas as pd +import pandas_ta as ta + +data = { + "high": [40145.0, 40186.36, 40196.39, 40344.6, 40245.48, 40273.24, 40464.0, 40699.0, 40627.48, 40436.31, 40370.0, 40376.8, 40227.03, 40056.52, 39721.7, 39597.94, 39750.15, 39927.0, 40289.02, 40189.0], + "low": [39870.71, 39834.98, 39866.31, 40108.31, 40016.09, 40094.66, 40105.0, 40196.48, 40154.99, 39800.0, 39959.21, 39922.98, 39940.02, 39632.0, 39261.39, 39254.63, 39473.91, 39555.51, 39819.0, 40006.84], + "close": [40105.78, 39935.23, 40183.97, 40182.03, 40212.26, 40149.99, 40378.0, 40618.37, 40401.03, 39990.39, 40179.13, 40097.23, 40014.72, 39667.85, 39303.1, 39519.99, +39693.79, 39827.96, 40074.94, 40059.84] +} + +high = pd.Series(data['high']) +low = pd.Series(data['low']) +close = pd.Series(data['close']) +result = ta.atr(high, low, close, length=14) +print(result) +*/ func Test_calculateATR(t *testing.T) { var bytes = []byte(`{ "high": [40145.0, 40186.36, 40196.39, 40344.6, 40245.48, 40273.24, 40464.0, 40699.0, 40627.48, 40436.31, 40370.0, 40376.8, 40227.03, 40056.52, 39721.7, 39597.94, 39750.15, 39927.0, 40289.02, 40189.0], @@ -35,7 +54,7 @@ func Test_calculateATR(t *testing.T) { name: "test_binance_btcusdt_1h", kLines: buildKLines(bytes), window: 14, - want: 364.048648, + want: 367.913903, }, } diff --git a/pkg/indicator/dmi.go b/pkg/indicator/dmi.go index 580fe3339..cb0fc7169 100644 --- a/pkg/indicator/dmi.go +++ b/pkg/indicator/dmi.go @@ -28,9 +28,9 @@ type DMI struct { func (inc *DMI) Update(high, low, cloze float64) { if inc.DMP == nil || inc.DMN == nil { - inc.DMP = &RMA{IntervalWindow: inc.IntervalWindow} - inc.DMN = &RMA{IntervalWindow: inc.IntervalWindow} - inc.ADX = &RMA{IntervalWindow: types.IntervalWindow{Window: inc.ADXSmoothing}} + inc.DMP = &RMA{IntervalWindow: inc.IntervalWindow, Adjust: true} + inc.DMN = &RMA{IntervalWindow: inc.IntervalWindow, Adjust: true} + inc.ADX = &RMA{IntervalWindow: types.IntervalWindow{Window: inc.ADXSmoothing}, Adjust: true} } if inc.atr == nil { inc.atr = &ATR{IntervalWindow: inc.IntervalWindow} @@ -56,15 +56,19 @@ func (inc *DMI) Update(high, low, cloze float64) { neg = dn } - k := 100. / inc.atr.Last() inc.DMP.Update(pos) inc.DMN.Update(neg) + if inc.atr.Length() < inc.Window { + return + } + k := 100. / inc.atr.Last() dmp := inc.DMP.Last() dmn := inc.DMN.Last() inc.DIPlus.Update(k * dmp) inc.DIMinus.Update(k * dmn) - dx := 100. * k * math.Abs(dmp-dmn) / (dmp + dmn) + dx := 100. * math.Abs(dmp-dmn) / (dmp + dmn) inc.ADX.Update(dx) + } func (inc *DMI) GetDIPlus() types.Series { diff --git a/pkg/indicator/dmi_test.go b/pkg/indicator/dmi_test.go new file mode 100644 index 000000000..c5755e1b5 --- /dev/null +++ b/pkg/indicator/dmi_test.go @@ -0,0 +1,83 @@ +package indicator + +import ( + "encoding/json" + "testing" + "fmt" + + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" + "github.com/stretchr/testify/assert" +) + +/* +python: + +import pandas as pd +import pandas_ta as ta + +data = pd.Series([0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9]) + +high = pd.Series([100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109]) + +low = pd.Series([80,81,82,83,84,85,86,87,88,89,80,81,82,83,84,85,86,87,88,89,80,81,82,83,84,85,86,87,88,89]) + +close = pd.Series([90,91,92,93,94,95,96,97,98,99,90,91,92,93,94,95,96,97,98,99,90,91,92,93,94,95,96,97,98,99]) + +result = ta.adx(high, low, close, 5, 14) +print(result['ADX_14']) + +print(result['DMP_5']) +print(result['DMN_5']) +*/ +func Test_DMI(t *testing.T) { + var Delta = 0.001 + var highb = []byte(`[100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109]`) + var lowb = []byte(`[80,81,82,83,84,85,86,87,88,89,80,81,82,83,84,85,86,87,88,89,80,81,82,83,84,85,86,87,88,89]`) + var clozeb = []byte(`[90,91,92,93,94,95,96,97,98,99,90,91,92,93,94,95,96,97,98,99,90,91,92,93,94,95,96,97,98,99]`) + + buildKLines := func(h, l, c []byte) (klines []types.KLine) { + var hv, cv, lv []fixedpoint.Value + _ = json.Unmarshal(h, &hv) + _ = json.Unmarshal(l, &lv) + _ = json.Unmarshal(c, &cv) + if len(hv) != len(lv) || len(lv) != len(cv) { + panic(fmt.Sprintf("length not equal %v %v %v", len(hv), len(lv), len(cv))) + } + for i, hh := range hv { + kline := types.KLine{High: hh, Low: lv[i], Close: cv[i]} + klines = append(klines, kline) + } + return klines + } + + type output struct{dip float64; dim float64; adx float64} + + tests := []struct { + name string + klines []types.KLine + want output + next output + total int + }{ + { + name: "test_dmi", + klines: buildKLines(highb, lowb, clozeb), + want: output{dip: 4.85114, dim: 1.339736, adx: 37.857156}, + next: output{dip: 4.813853, dim: 1.67532, adx: 36.111434}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dmi := &DMI{ + IntervalWindow: types.IntervalWindow{Window: 5}, + ADXSmoothing: 14, + } + dmi.calculateAndUpdate(tt.klines) + assert.InDelta(t, dmi.GetDIPlus().Last(), tt.want.dip, Delta) + assert.InDelta(t, dmi.GetDIMinus().Last(), tt.want.dim, Delta) + assert.InDelta(t, dmi.GetADX().Last(), tt.want.adx, Delta) + }) + } + +} diff --git a/pkg/indicator/rma.go b/pkg/indicator/rma.go index 4418ab54a..8fee7a128 100644 --- a/pkg/indicator/rma.go +++ b/pkg/indicator/rma.go @@ -6,33 +6,42 @@ import ( "github.com/c9s/bbgo/pkg/types" ) -// Refer: Running Moving Average +// Running Moving Average +// Refer: https://github.com/twopirllc/pandas-ta/blob/main/pandas_ta/overlap/rma.py#L5 +// Refer: https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.ewm.html#pandas-dataframe-ewm //go:generate callbackgen -type RMA type RMA struct { types.IntervalWindow - Values types.Float64Slice - Sources types.Float64Slice - + Values types.Float64Slice + counter int + Adjust bool + tmp float64 + sum float64 EndTime time.Time UpdateCallbacks []func(value float64) } func (inc *RMA) Update(x float64) { - inc.Sources.Push(x) + lambda := 1 / float64(inc.Window) + if inc.counter == 0 { + inc.sum = 1 + inc.tmp = x + } else { + if inc.Adjust { + inc.sum = inc.sum*(1-lambda) + 1 + inc.tmp = inc.tmp + (x-inc.tmp)/inc.sum + } else { + inc.tmp = inc.tmp*(1-lambda) + x*lambda + } + } + inc.counter++ - if len(inc.Sources) < inc.Window { + if inc.counter < inc.Window { inc.Values.Push(0) return } - if len(inc.Sources) == inc.Window { - inc.Values.Push(inc.Sources.Mean()) - return - } - - lambda := 1 / float64(inc.Window) - rma := (1-lambda)*inc.Values.Last() + lambda*x - inc.Values.Push(rma) + inc.Values.Push(inc.tmp) } func (inc *RMA) Last() float64 {