diff --git a/pkg/strategy/harmonic/strategy.go b/pkg/strategy/harmonic/strategy.go index d379f3262..4d98fda34 100644 --- a/pkg/strategy/harmonic/strategy.go +++ b/pkg/strategy/harmonic/strategy.go @@ -13,7 +13,6 @@ import ( "github.com/c9s/bbgo/pkg/indicator" "github.com/c9s/bbgo/pkg/types" "github.com/sirupsen/logrus" - floats2 "gonum.org/v1/gonum/floats" ) const ID = "harmonic" @@ -342,9 +341,9 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se states.Update(0) s.session.MarketDataStream.OnKLineClosed(types.KLineWith(s.Symbol, s.Interval, func(kline types.KLine) { - log.Infof("Shark Score: %f, Current Price: %f", s.shark.Last(), kline.Close.Float64()) + log.Infof("shark score: %f, current price: %f", s.shark.Last(), kline.Close.Float64()) - nextState := alpha(s.shark.Array(s.Window), states.Array(s.Window), s.Window) + nextState := hmm(s.shark.Array(s.Window), states.Array(s.Window), s.Window) states.Update(nextState) log.Infof("Denoised signal via HMM: %f", states.Last()) @@ -367,7 +366,7 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se Side: types.SideTypeBuy, Quantity: s.Quantity, Type: types.OrderTypeMarket, - Tag: "shark long", + Tag: "sharkLong", }) } else if states.Mean(5) == -1 && direction != -1 { _, _ = s.orderExecutor.SubmitOrders(ctx, types.SubmitOrder{ @@ -375,7 +374,7 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se Side: types.SideTypeSell, Quantity: s.Quantity, Type: types.OrderTypeMarket, - Tag: "shark short", + Tag: "sharkShort", }) } })) @@ -402,7 +401,7 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se } // TODO: dirichlet distribution is a too naive solution -func observationDistribution(y_t, x_t float64) float64 { +func observeDistribution(y_t, x_t float64) float64 { if x_t == 0. && y_t == 0 { // observed zero value from indicator when in neutral state return 1. @@ -417,7 +416,7 @@ func observationDistribution(y_t, x_t float64) float64 { } } -func transitionProbability(x_t0, x_t1 int) float64 { +func transitProbability(x_t0, x_t1 int) float64 { // stick to the same sate if x_t0 == x_t1 { return 0.99 @@ -426,7 +425,21 @@ func transitionProbability(x_t0, x_t1 int) float64 { return 1 - 0.99 } -func alpha(y_t []float64, x_t []float64, l int) float64 { +// HMM main function, ref: https://tr8dr.github.io/HMMFiltering/ +/* +# initialize time step 0 using state priors and observation dist p(y | x = s) +for si in states: + alpha[t = 0, state = si] = pi[si] * p(y[0] | x = si) + +# determine alpha for t = 1 .. n +for t in 1 .. n: + for sj in states: + alpha[t,sj] = max([alpha[t-1,si] * M[si,sj] for si in states]) * p(y[t] | x = sj) + +# determine current state at time t +return argmax(alpha[t,si] over si) +*/ +func hmm(y_t []float64, x_t []float64, l int) float64 { al := make([]float64, l) an := make([]float64, l) as := make([]float64, l) @@ -440,26 +453,29 @@ func alpha(y_t []float64, x_t []float64, l int) float64 { sin := make([]float64, 3) sis := make([]float64, 3) for i := -1; i <= 1; i++ { - sil = append(sil, x_t[n-1-1]*transitionProbability(i, j)) - sin = append(sin, x_t[n-1-1]*transitionProbability(i, j)) - sis = append(sis, x_t[n-1-1]*transitionProbability(i, j)) + sil = append(sil, x_t[n-1-1]*transitProbability(i, j)) + sin = append(sin, x_t[n-1-1]*transitProbability(i, j)) + sis = append(sis, x_t[n-1-1]*transitProbability(i, j)) } if j > 0 { - long = floats2.Max(sil) * observationDistribution(y_t[n-1], float64(j)) + _, longArr := floats.MinMax(sil, 3) + long = longArr[0] * observeDistribution(y_t[n-1], float64(j)) al = append(al, long) } else if j == 0 { - neut = floats2.Max(sin) * observationDistribution(y_t[n-1], float64(j)) + _, neutArr := floats.MinMax(sin, 3) + neut = neutArr[0] * observeDistribution(y_t[n-1], float64(j)) an = append(an, neut) } else if j < 0 { - short = floats2.Max(sis) * observationDistribution(y_t[n-1], float64(j)) + _, shortArr := floats.MinMax(sis, 3) + short = shortArr[0] * observeDistribution(y_t[n-1], float64(j)) as = append(as, short) } } } - maximum := floats2.Max([]float64{long, neut, short}) - if maximum == long { + _, maximum := floats.MinMax([]float64{long, neut, short}, 3) + if maximum[0] == long { return 1 - } else if maximum == short { + } else if maximum[0] == short { return -1 } return 0