From 3959e288fd1eaec5cea95b074575d5e3f3bd0059 Mon Sep 17 00:00:00 2001 From: c9s Date: Tue, 26 Jul 2022 18:35:50 +0800 Subject: [PATCH] all: refactor standard indicator helper and fix tests --- pkg/bbgo/exit_lower_shadow_take_profit.go | 2 +- pkg/bbgo/session.go | 15 +++++------ pkg/bbgo/trader.go | 2 +- pkg/indicator/atr_test.go | 5 +++- pkg/indicator/hull.go | 2 +- pkg/indicator/hull_test.go | 11 +++++--- pkg/strategy/emastop/strategy.go | 7 +---- pkg/strategy/funding/strategy.go | 27 ++++++------------- pkg/strategy/pivotshort/breaklow.go | 2 +- pkg/strategy/pricedrop/strategy.go | 2 +- pkg/strategy/skeleton/strategy.go | 32 ++++------------------- pkg/strategy/support/strategy.go | 5 +--- pkg/strategy/techsignal/strategy.go | 6 +---- pkg/strategy/xmaker/strategy.go | 2 +- 14 files changed, 41 insertions(+), 79 deletions(-) diff --git a/pkg/bbgo/exit_lower_shadow_take_profit.go b/pkg/bbgo/exit_lower_shadow_take_profit.go index e12fd5936..1ffae6c36 100644 --- a/pkg/bbgo/exit_lower_shadow_take_profit.go +++ b/pkg/bbgo/exit_lower_shadow_take_profit.go @@ -27,7 +27,7 @@ func (s *LowerShadowTakeProfit) Bind(session *ExchangeSession, orderExecutor *Ge s.session = session s.orderExecutor = orderExecutor - stdIndicatorSet, _ := session.StandardIndicatorSet(s.Symbol) + stdIndicatorSet := session.StandardIndicatorSet(s.Symbol) ewma := stdIndicatorSet.EWMA(s.IntervalWindow) diff --git a/pkg/bbgo/session.go b/pkg/bbgo/session.go index 6939d94e7..aa86fe06f 100644 --- a/pkg/bbgo/session.go +++ b/pkg/bbgo/session.go @@ -438,17 +438,16 @@ func (session *ExchangeSession) initSymbol(ctx context.Context, environ *Environ return nil } -func (session *ExchangeSession) StandardIndicatorSet(symbol string) (*StandardIndicatorSet, bool) { +func (session *ExchangeSession) StandardIndicatorSet(symbol string) *StandardIndicatorSet { set, ok := session.standardIndicatorSets[symbol] - if !ok { - if store, ok2 := session.MarketDataStore(symbol); ok2 { - set = NewStandardIndicatorSet(symbol, session.MarketDataStream, store) - session.standardIndicatorSets[symbol] = set - return set, true - } + if ok { + return set } - return set, ok + store, _ := session.MarketDataStore(symbol) + set = NewStandardIndicatorSet(symbol, session.MarketDataStream, store) + session.standardIndicatorSets[symbol] = set + return set } func (session *ExchangeSession) Position(symbol string) (pos *types.Position, ok bool) { diff --git a/pkg/bbgo/trader.go b/pkg/bbgo/trader.go index eb89a8302..a1224fe2b 100644 --- a/pkg/bbgo/trader.go +++ b/pkg/bbgo/trader.go @@ -292,7 +292,7 @@ func (trader *Trader) injectFields() error { return fmt.Errorf("market of symbol %s not found", symbol) } - indicatorSet, ok := session.StandardIndicatorSet(symbol) + indicatorSet := session.StandardIndicatorSet(symbol) if !ok { return fmt.Errorf("standardIndicatorSet of symbol %s not found", symbol) } diff --git a/pkg/indicator/atr_test.go b/pkg/indicator/atr_test.go index 55cd6e8fe..b5cb138e9 100644 --- a/pkg/indicator/atr_test.go +++ b/pkg/indicator/atr_test.go @@ -61,7 +61,10 @@ func Test_calculateATR(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { atr := &ATR{IntervalWindow: types.IntervalWindow{Window: tt.window}} - atr.CalculateAndUpdate(tt.kLines) + for _, k := range tt.kLines { + atr.PushK(k) + } + got := atr.Last() diff := math.Trunc((got-tt.want)*100) / 100 if diff != 0 { diff --git a/pkg/indicator/hull.go b/pkg/indicator/hull.go index 43c0a3c16..4352bbba0 100644 --- a/pkg/indicator/hull.go +++ b/pkg/indicator/hull.go @@ -55,7 +55,7 @@ func (inc *HULL) Length() int { } func (inc *HULL) PushK(k types.KLine) { - if k.EndTime.Before(inc.ma1.EndTime) { + if inc.ma1 != nil && inc.ma1.Length() > 0 && k.EndTime.Before(inc.ma1.EndTime) { return } diff --git a/pkg/indicator/hull_test.go b/pkg/indicator/hull_test.go index 64472c0e6..857c8d30d 100644 --- a/pkg/indicator/hull_test.go +++ b/pkg/indicator/hull_test.go @@ -4,9 +4,10 @@ import ( "encoding/json" "testing" + "github.com/stretchr/testify/assert" + "github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/types" - "github.com/stretchr/testify/assert" ) /* @@ -26,6 +27,7 @@ func Test_HULL(t *testing.T) { if err := json.Unmarshal(randomPrices, &input); err != nil { panic(err) } + tests := []struct { name string kLines []types.KLine @@ -44,8 +46,11 @@ func Test_HULL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hull := HULL{IntervalWindow: types.IntervalWindow{Window: 16}} - hull.CalculateAndUpdate(tt.kLines) + hull := &HULL{IntervalWindow: types.IntervalWindow{Window: 16}} + for _, k := range tt.kLines { + hull.PushK(k) + } + last := hull.Last() assert.InDelta(t, tt.want, last, Delta) assert.InDelta(t, tt.next, hull.Index(1), Delta) diff --git a/pkg/strategy/emastop/strategy.go b/pkg/strategy/emastop/strategy.go index 7c9c4a190..e120c6469 100644 --- a/pkg/strategy/emastop/strategy.go +++ b/pkg/strategy/emastop/strategy.go @@ -25,7 +25,6 @@ func init() { } type Strategy struct { - SourceExchangeName string `json:"sourceExchange"` TargetExchangeName string `json:"targetExchange"` @@ -175,11 +174,7 @@ func (s *Strategy) handleOrderUpdate(order types.Order) { } func (s *Strategy) loadIndicator(sourceSession *bbgo.ExchangeSession) (types.Float64Indicator, error) { - var standardIndicatorSet, ok = sourceSession.StandardIndicatorSet(s.Symbol) - if !ok { - return nil, fmt.Errorf("standardIndicatorSet is nil, symbol %s", s.Symbol) - } - + var standardIndicatorSet = sourceSession.StandardIndicatorSet(s.Symbol) var iw = types.IntervalWindow{Interval: s.MovingAverageInterval, Window: s.MovingAverageWindow} switch strings.ToUpper(s.MovingAverageType) { diff --git a/pkg/strategy/funding/strategy.go b/pkg/strategy/funding/strategy.go index 83d0a871d..58361e963 100644 --- a/pkg/strategy/funding/strategy.go +++ b/pkg/strategy/funding/strategy.go @@ -3,7 +3,6 @@ package funding import ( "context" "errors" - "fmt" "strings" "github.com/sirupsen/logrus" @@ -32,7 +31,7 @@ type Strategy struct { Market types.Market `json:"-"` Quantity fixedpoint.Value `json:"quantity,omitempty"` MaxExposurePosition fixedpoint.Value `json:"maxExposurePosition"` - //Interval types.Interval `json:"interval"` + // Interval types.Interval `json:"interval"` FundingRate *struct { High fixedpoint.Value `json:"high"` @@ -49,11 +48,11 @@ type Strategy struct { // MovingAverageInterval is the interval of k-lines for the moving average indicator to calculate, // it could be "1m", "5m", "1h" and so on. note that, the moving averages are calculated from // the k-line data we subscribed - //MovingAverageInterval types.Interval `json:"movingAverageInterval"` + // MovingAverageInterval types.Interval `json:"movingAverageInterval"` // - //// MovingAverageWindow is the number of the window size of the moving average indicator. - //// The number of k-lines in the window. generally used window sizes are 7, 25 and 99 in the TradingView. - //MovingAverageWindow int `json:"movingAverageWindow"` + // // MovingAverageWindow is the number of the window size of the moving average indicator. + // // The number of k-lines in the window. generally used window sizes are 7, 25 and 99 in the TradingView. + // MovingAverageWindow int `json:"movingAverageWindow"` MovingAverageIntervalWindow types.IntervalWindow `json:"movingAverageIntervalWindow"` @@ -70,9 +69,9 @@ func (s *Strategy) ID() string { func (s *Strategy) Subscribe(session *bbgo.ExchangeSession) { // session.Subscribe(types.BookChannel, s.Symbol, types.SubscribeOptions{}) - //session.Subscribe(types.KLineChannel, s.Symbol, types.SubscribeOptions{ + // session.Subscribe(types.KLineChannel, s.Symbol, types.SubscribeOptions{ // Interval: string(s.Interval), - //}) + // }) for _, detection := range s.SupportDetection { session.Subscribe(types.KLineChannel, s.Symbol, types.SubscribeOptions{ @@ -93,23 +92,13 @@ func (s *Strategy) Validate() error { } func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error { + standardIndicatorSet := session.StandardIndicatorSet(s.Symbol) - standardIndicatorSet, ok := session.StandardIndicatorSet(s.Symbol) - if !ok { - return fmt.Errorf("standardIndicatorSet is nil, symbol %s", s.Symbol) - } - //binanceExchange, ok := session.Exchange.(*binance.Exchange) - //if !ok { - // log.Error("exchange failed") - //} if !session.Futures { log.Error("futures not enabled in config for this strategy") return nil } - //if s.FundingRate != nil { - // go s.listenToFundingRate(ctx, binanceExchange) - //} premiumIndex, err := session.Exchange.(*binance.Exchange).QueryPremiumIndex(ctx, s.Symbol) if err != nil { log.Error("exchange does not support funding rate api") diff --git a/pkg/strategy/pivotshort/breaklow.go b/pkg/strategy/pivotshort/breaklow.go index 32018e74b..7b67af813 100644 --- a/pkg/strategy/pivotshort/breaklow.go +++ b/pkg/strategy/pivotshort/breaklow.go @@ -66,7 +66,7 @@ func (s *BreakLow) Bind(session *bbgo.ExchangeSession, orderExecutor *bbgo.Gener position := orderExecutor.Position() symbol := position.Symbol store, _ := session.MarketDataStore(s.Symbol) - standardIndicator, _ := session.StandardIndicatorSet(s.Symbol) + standardIndicator := session.StandardIndicatorSet(s.Symbol) s.lastLow = fixedpoint.Zero diff --git a/pkg/strategy/pricedrop/strategy.go b/pkg/strategy/pricedrop/strategy.go index bfca2577f..c8a9b1664 100644 --- a/pkg/strategy/pricedrop/strategy.go +++ b/pkg/strategy/pricedrop/strategy.go @@ -53,7 +53,7 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se return fmt.Errorf("market %s is not defined", s.Symbol) } - standardIndicatorSet, ok := session.StandardIndicatorSet(s.Symbol) + standardIndicatorSet := session.StandardIndicatorSet(s.Symbol) if !ok { return fmt.Errorf("standardIndicatorSet is nil, symbol %s", s.Symbol) } diff --git a/pkg/strategy/skeleton/strategy.go b/pkg/strategy/skeleton/strategy.go index ad0b3eb4c..048da9745 100644 --- a/pkg/strategy/skeleton/strategy.go +++ b/pkg/strategy/skeleton/strategy.go @@ -8,7 +8,6 @@ import ( "github.com/c9s/bbgo/pkg/bbgo" "github.com/c9s/bbgo/pkg/fixedpoint" - "github.com/c9s/bbgo/pkg/indicator" "github.com/c9s/bbgo/pkg/types" ) @@ -82,32 +81,11 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se s.State = &State{Counter: 1} } - // Optional: You can get the market data store from session - store, ok := session.MarketDataStore(s.Symbol) - if !ok { - return fmt.Errorf("market data store %s not found", s.Symbol) - } - - // Initialize a custom indicator - atr := &indicator.ATR{ - IntervalWindow: types.IntervalWindow{ - Interval: types.Interval1m, - Window: 14, - }, - } - - // Bind the indicator to the market data store, so that when a new kline is received, - // the indicator will be updated. - atr.Bind(store) - - // To get the past kline history, call KLinesOfInterval from the market data store - klines, ok := store.KLinesOfInterval(types.Interval1m) - if !ok { - return fmt.Errorf("market data store %s lkline not found", s.Symbol) - } - - // Use the history data to initialize the indicator - atr.CalculateAndUpdate(*klines) + indicators := session.StandardIndicatorSet(s.Symbol) + atr := indicators.ATR(types.IntervalWindow{ + Interval: types.Interval1m, + Window: 14, + }) // To get the market information from the current session // The market object provides the precision, MoQ (minimal of quantity) information diff --git a/pkg/strategy/support/strategy.go b/pkg/strategy/support/strategy.go index 126bdadfe..d8e3ec545 100644 --- a/pkg/strategy/support/strategy.go +++ b/pkg/strategy/support/strategy.go @@ -387,10 +387,7 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se log.Infof("adjusted minimal support volume to %s according to sensitivity %s", s.MinVolume.String(), s.Sensitivity.String()) } - standardIndicatorSet, ok := session.StandardIndicatorSet(s.Symbol) - if !ok { - return fmt.Errorf("standardIndicatorSet is nil, symbol %s", s.Symbol) - } + standardIndicatorSet := session.StandardIndicatorSet(s.Symbol) if s.TriggerMovingAverage != zeroiw { s.triggerEMA = standardIndicatorSet.EWMA(s.TriggerMovingAverage) diff --git a/pkg/strategy/techsignal/strategy.go b/pkg/strategy/techsignal/strategy.go index e95c970d1..bf3326dd3 100644 --- a/pkg/strategy/techsignal/strategy.go +++ b/pkg/strategy/techsignal/strategy.go @@ -3,7 +3,6 @@ package techsignal import ( "context" "errors" - "fmt" "strings" "time" @@ -145,10 +144,7 @@ func (s *Strategy) listenToFundingRate(ctx context.Context, exchange *binance.Ex } func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error { - standardIndicatorSet, ok := session.StandardIndicatorSet(s.Symbol) - if !ok { - return fmt.Errorf("standardIndicatorSet is nil, symbol %s", s.Symbol) - } + standardIndicatorSet := session.StandardIndicatorSet(s.Symbol) if s.FundingRate != nil { if binanceExchange, ok := session.Exchange.(*binance.Exchange); ok { diff --git a/pkg/strategy/xmaker/strategy.go b/pkg/strategy/xmaker/strategy.go index b469591c8..65f4e4fc9 100644 --- a/pkg/strategy/xmaker/strategy.go +++ b/pkg/strategy/xmaker/strategy.go @@ -681,7 +681,7 @@ func (s *Strategy) CrossRun(ctx context.Context, orderExecutionRouter bbgo.Order return fmt.Errorf("maker session market %s is not defined", s.Symbol) } - standardIndicatorSet, ok := s.sourceSession.StandardIndicatorSet(s.Symbol) + standardIndicatorSet := s.sourceSession.StandardIndicatorSet(s.Symbol) if !ok { return fmt.Errorf("%s standard indicator set not found", s.Symbol) }