diff --git a/pkg/strategy/common/strategy.go b/pkg/strategy/common/strategy.go index 5828ee38d..26240b50e 100644 --- a/pkg/strategy/common/strategy.go +++ b/pkg/strategy/common/strategy.go @@ -8,6 +8,7 @@ import ( "github.com/c9s/bbgo/pkg/bbgo" "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/risk/circuitbreaker" "github.com/c9s/bbgo/pkg/risk/riskcontrol" "github.com/c9s/bbgo/pkg/types" ) @@ -19,7 +20,7 @@ type RiskController struct { CircuitBreakEMA types.IntervalWindow `json:"circuitBreakEMA"` positionRiskControl *riskcontrol.PositionRiskControl - circuitBreakRiskControl *riskcontrol.CircuitBreakRiskControl + circuitBreakRiskControl *circuitbreaker.BasicCircuitBreaker } // Strategy provides the core functionality that is required by a long/short strategy. @@ -78,12 +79,12 @@ func (s *Strategy) Initialize(ctx context.Context, environ *bbgo.Environment, se if !s.CircuitBreakLossThreshold.IsZero() { log.Infof("circuitBreakLossThreshold is configured, setting up CircuitBreakRiskControl...") - s.circuitBreakRiskControl = riskcontrol.NewCircuitBreakRiskControl( - s.Position, - session.Indicators(market.Symbol).EWMA(s.CircuitBreakEMA), - s.CircuitBreakLossThreshold, - s.ProfitStats, - 24*time.Hour) + s.circuitBreakRiskControl = circuitbreaker.NewBasicCircuitBreaker(strategyID, instanceID) + s.OrderExecutor.TradeCollector().OnProfit(func(trade types.Trade, profit *types.Profit) { + if profit != nil && s.circuitBreakRiskControl != nil { + s.circuitBreakRiskControl.RecordProfit(profit.Profit, trade.Time.Time()) + } + }) } } @@ -91,5 +92,6 @@ func (s *Strategy) IsHalted(t time.Time) bool { if s.circuitBreakRiskControl == nil { return false } - return s.circuitBreakRiskControl.IsHalted(t) + _, isHalted := s.circuitBreakRiskControl.IsHalted(t) + return isHalted }