Merge pull request #831 from c9s/feature/defaulter

feature: api: add strategy defaulter interface
This commit is contained in:
Yo-An Lin 2022-07-19 17:55:24 +08:00 committed by GitHub
commit ed91fdc915
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 9 deletions

View File

@ -7,7 +7,7 @@ import (
"sync" "sync"
) )
func (g *Graceful) OnShutdown(cb func(ctx context.Context, wg *sync.WaitGroup)) { func (g *Graceful) OnShutdown(cb ShutdownHandler) {
g.shutdownCallbacks = append(g.shutdownCallbacks, cb) g.shutdownCallbacks = append(g.shutdownCallbacks, cb)
} }

View File

@ -10,9 +10,11 @@ import (
var graceful = &Graceful{} var graceful = &Graceful{}
type ShutdownHandler func(ctx context.Context, wg *sync.WaitGroup)
//go:generate callbackgen -type Graceful //go:generate callbackgen -type Graceful
type Graceful struct { type Graceful struct {
shutdownCallbacks []func(ctx context.Context, wg *sync.WaitGroup) shutdownCallbacks []ShutdownHandler
} }
// Shutdown is a blocking call to emit all shutdown callbacks at the same time. // Shutdown is a blocking call to emit all shutdown callbacks at the same time.
@ -29,7 +31,7 @@ func (g *Graceful) Shutdown(ctx context.Context) {
cancel() cancel()
} }
func OnShutdown(f func(ctx context.Context, wg *sync.WaitGroup)) { func OnShutdown(f ShutdownHandler) {
graceful.OnShutdown(f) graceful.OnShutdown(f)
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"sync"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -14,6 +15,12 @@ import (
"github.com/c9s/bbgo/pkg/interact" "github.com/c9s/bbgo/pkg/interact"
) )
// Strategy method calls:
// -> Defaults() (optional method)
// -> Initialize() (optional method)
// -> Validate() (optional method)
// -> Run() (optional method)
// -> Shutdown(shutdownCtx context.Context, wg *sync.WaitGroup)
type StrategyID interface { type StrategyID interface {
ID() string ID() string
} }
@ -24,10 +31,23 @@ type SingleExchangeStrategy interface {
Run(ctx context.Context, orderExecutor OrderExecutor, session *ExchangeSession) error Run(ctx context.Context, orderExecutor OrderExecutor, session *ExchangeSession) error
} }
// StrategyInitializer's Initialize method is called before the Subscribe method call.
type StrategyInitializer interface { type StrategyInitializer interface {
Initialize() error Initialize() error
} }
type StrategyDefaulter interface {
Defaults() error
}
type StrategyValidator interface {
Validate() error
}
type StrategyShutdown interface {
Shutdown(ctx context.Context, wg *sync.WaitGroup)
}
// ExchangeSessionSubscriber provides an interface for collecting subscriptions from different strategies // ExchangeSessionSubscriber provides an interface for collecting subscriptions from different strategies
// Subscribe method will be called before the user data stream connection is created. // Subscribe method will be called before the user data stream connection is created.
type ExchangeSessionSubscriber interface { type ExchangeSessionSubscriber interface {
@ -43,10 +63,6 @@ type CrossExchangeStrategy interface {
CrossRun(ctx context.Context, orderExecutionRouter OrderExecutionRouter, sessions map[string]*ExchangeSession) error CrossRun(ctx context.Context, orderExecutionRouter OrderExecutionRouter, sessions map[string]*ExchangeSession) error
} }
type Validator interface {
Validate() error
}
type Logging interface { type Logging interface {
EnableLogging() EnableLogging()
DisableLogging() DisableLogging()
@ -153,6 +169,12 @@ func (trader *Trader) Subscribe() {
for sessionName, strategies := range trader.exchangeStrategies { for sessionName, strategies := range trader.exchangeStrategies {
session := trader.environment.sessions[sessionName] session := trader.environment.sessions[sessionName]
for _, strategy := range strategies { for _, strategy := range strategies {
if defaulter, ok := strategy.(StrategyDefaulter); ok {
if err := defaulter.Defaults(); err != nil {
panic(err)
}
}
if initializer, ok := strategy.(StrategyInitializer); ok { if initializer, ok := strategy.(StrategyInitializer); ok {
if err := initializer.Initialize(); err != nil { if err := initializer.Initialize(); err != nil {
panic(err) panic(err)
@ -168,6 +190,12 @@ func (trader *Trader) Subscribe() {
} }
for _, strategy := range trader.crossExchangeStrategies { for _, strategy := range trader.crossExchangeStrategies {
if defaulter, ok := strategy.(StrategyDefaulter); ok {
if err := defaulter.Defaults(); err != nil {
panic(err)
}
}
if initializer, ok := strategy.(StrategyInitializer); ok { if initializer, ok := strategy.(StrategyInitializer); ok {
if err := initializer.Initialize(); err != nil { if err := initializer.Initialize(); err != nil {
panic(err) panic(err)
@ -229,13 +257,17 @@ func (trader *Trader) RunSingleExchangeStrategy(ctx context.Context, strategy Si
} }
} }
// If the strategy has Validate() method, run it and check the error if v, ok := strategy.(StrategyValidator); ok {
if v, ok := strategy.(Validator); ok {
if err := v.Validate(); err != nil { if err := v.Validate(); err != nil {
return fmt.Errorf("failed to validate the config: %w", err) return fmt.Errorf("failed to validate the config: %w", err)
} }
} }
if shutdown, ok := strategy.(StrategyShutdown); ok {
// Register the shutdown callback
OnShutdown(shutdown.Shutdown)
}
return strategy.Run(ctx, orderExecutor, session) return strategy.Run(ctx, orderExecutor, session)
} }