integrate state recorder

This commit is contained in:
c9s 2022-05-10 13:25:03 +08:00
parent 185a8279b2
commit 7b17b1a757
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
5 changed files with 88 additions and 50 deletions

View File

@ -78,6 +78,10 @@ func (r *StateRecorder) Scan(instance Instance) error {
for i := 0; i < rt.NumField(); i++ { for i := 0; i < rt.NumField(); i++ {
structField := rt.Field(i) structField := rt.Field(i)
if !structField.IsExported() {
continue
}
obj := rv.Field(i).Interface() obj := rv.Field(i).Interface()
switch o := obj.(type) { switch o := obj.(type) {
@ -87,6 +91,16 @@ func (r *StateRecorder) Scan(instance Instance) error {
return fmt.Errorf("%v is a non-defined type", structField.Type) return fmt.Errorf("%v is a non-defined type", structField.Type)
} }
if err := r.newCsvWriter(o, instance, typeName); err != nil {
return err
}
}
}
return nil
}
func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error {
f, err := r.openFile(instance, typeName) f, err := r.openFile(instance, typeName)
if err != nil { if err != nil {
return err return err
@ -97,11 +111,11 @@ func (r *StateRecorder) Scan(instance Instance) error {
} }
r.files[o] = f r.files[o] = f
r.writers[o] = csv.NewWriter(f)
}
}
return nil w := csv.NewWriter(f)
r.writers[o] = w
return w.Write(o.CsvHeader())
} }
func (r *StateRecorder) Close() error { func (r *StateRecorder) Close() error {

View File

@ -34,9 +34,6 @@ func TestStateRecorder(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, recorder.writers, 1) assert.Len(t, recorder.writers, 1)
n, err := recorder.Snapshot()
assert.NoError(t, err)
assert.Equal(t, 1, n)
st.Position.AddTrade(types.Trade{ st.Position.AddTrade(types.Trade{
OrderID: 1, OrderID: 1,
@ -56,7 +53,7 @@ func TestStateRecorder(t *testing.T) {
IsIsolated: false, IsIsolated: false,
}) })
n, err = recorder.Snapshot() n, err := recorder.Snapshot()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, n) assert.Equal(t, 1, n)

View File

@ -13,9 +13,13 @@ import (
"github.com/c9s/bbgo/pkg/interact" "github.com/c9s/bbgo/pkg/interact"
) )
type StrategyID interface {
ID() string
}
// SingleExchangeStrategy represents the single Exchange strategy // SingleExchangeStrategy represents the single Exchange strategy
type SingleExchangeStrategy interface { type SingleExchangeStrategy interface {
ID() string StrategyID
Run(ctx context.Context, orderExecutor OrderExecutor, session *ExchangeSession) error Run(ctx context.Context, orderExecutor OrderExecutor, session *ExchangeSession) error
} }
@ -34,7 +38,7 @@ type CrossExchangeSessionSubscriber interface {
} }
type CrossExchangeStrategy interface { type CrossExchangeStrategy interface {
ID() string StrategyID
CrossRun(ctx context.Context, orderExecutionRouter OrderExecutionRouter, sessions map[string]*ExchangeSession) error CrossRun(ctx context.Context, orderExecutionRouter OrderExecutionRouter, sessions map[string]*ExchangeSession) error
} }
@ -346,15 +350,23 @@ func (trader *Trader) LoadState() error {
ps := trader.environment.PersistenceServiceFacade.Get() ps := trader.environment.PersistenceServiceFacade.Get()
log.Infof("loading strategies states...") log.Infof("loading strategies states...")
return trader.IterateStrategies(func(strategy StrategyID) error {
return loadPersistenceFields(strategy, strategy.ID(), ps)
})
}
func (trader *Trader) IterateStrategies(f func(st StrategyID) error) error {
for _, strategies := range trader.exchangeStrategies { for _, strategies := range trader.exchangeStrategies {
for _, strategy := range strategies { for _, strategy := range strategies {
if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil { if err := f(strategy); err != nil {
return err return err
} }
} }
} }
for _, strategy := range trader.crossExchangeStrategies { for _, strategy := range trader.crossExchangeStrategies {
if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil { if err := f(strategy); err != nil {
return err return err
} }
} }
@ -374,32 +386,16 @@ func (trader *Trader) SaveState() error {
ps := trader.environment.PersistenceServiceFacade.Get() ps := trader.environment.PersistenceServiceFacade.Get()
log.Infof("saving strategies states...") log.Infof("saving strategies states...")
for _, strategies := range trader.exchangeStrategies { return trader.IterateStrategies(func(strategy StrategyID) error {
for _, strategy := range strategies {
id := callID(strategy) id := callID(strategy)
if len(id) == 0 { if len(id) == 0 {
continue
}
if err := storePersistenceFields(strategy, id, ps); err != nil {
return err
}
}
}
for _, strategy := range trader.crossExchangeStrategies {
id := callID(strategy)
if len(id) == 0 {
continue
}
if err := storePersistenceFields(strategy, id, ps); err != nil {
return err
}
}
return nil return nil
} }
return storePersistenceFields(strategy, id, ps)
})
}
var defaultPersistenceSelector = &PersistenceSelector{ var defaultPersistenceSelector = &PersistenceSelector{
StoreID: "default", StoreID: "default",
Type: "memory", Type: "memory",

View File

@ -334,25 +334,36 @@ var BacktestCmd = &cobra.Command{
var kLineHandlers []func(k types.KLine) var kLineHandlers []func(k types.KLine)
if generatingReport { if generatingReport {
dumpDir := outputDirectory reportDir := outputDirectory
if reportFileInSubDir { if reportFileInSubDir {
dumpDir = filepath.Join(dumpDir, backtestSessionName) reportDir = filepath.Join(reportDir, backtestSessionName)
dumpDir = filepath.Join(dumpDir, uuid.NewString()) reportDir = filepath.Join(reportDir, uuid.NewString())
} }
dumpDir = filepath.Join(dumpDir, "klines") kLineDataDir := filepath.Join(reportDir, "klines")
if err := safeMkdirAll(kLineDataDir); err != nil {
if _, err := os.Stat(dumpDir); err != nil {
if os.IsNotExist(err) {
if err2 := os.MkdirAll(dumpDir, 0755); err2 != nil {
return err2
}
} else {
return err return err
} }
stateRecorder := backtest.NewStateRecorder(reportDir)
err = trader.IterateStrategies(func(st bbgo.StrategyID) error {
return stateRecorder.Scan(st.(backtest.Instance))
})
if err != nil {
return err
} }
dumper := backtest.NewKLineDumper(dumpDir) kLineHandlers = append(kLineHandlers, func(k types.KLine) {
// snapshot per 1m
if k.Interval == types.Interval1m && k.Closed {
if _, err := stateRecorder.Snapshot(); err != nil {
log.WithError(err).Errorf("state record failed to snapshot the strategy state")
}
}
})
dumper := backtest.NewKLineDumper(kLineDataDir)
defer func() { defer func() {
if err := dumper.Close(); err != nil { if err := dumper.Close(); err != nil {
log.WithError(err).Errorf("kline dumper can not close files") log.WithError(err).Errorf("kline dumper can not close files")
@ -418,7 +429,6 @@ var BacktestCmd = &cobra.Command{
trader.Graceful.Shutdown(shutdownCtx) trader.Graceful.Shutdown(shutdownCtx)
cancelShutdown() cancelShutdown()
// put the logger back to print the pnl // put the logger back to print the pnl
log.SetLevel(log.InfoLevel) log.SetLevel(log.InfoLevel)
@ -565,3 +575,20 @@ func confirmation(s string) bool {
} }
} }
} }
func safeMkdirAll(p string) error {
st, err := os.Stat(p)
if err == nil {
if !st.IsDir() {
return fmt.Errorf("path %s is not a directory", p)
}
return nil
}
if os.IsNotExist(err) {
return os.MkdirAll(p, 0755)
}
return nil
}

View File

@ -72,6 +72,10 @@ func (p *Position) CsvHeader() []string {
} }
func (p *Position) CsvRecords() [][]string { func (p *Position) CsvRecords() [][]string {
if p.AverageCost.IsZero() && p.Base.IsZero() {
return nil
}
return [][]string{ return [][]string{
{ {
p.Symbol, p.Symbol,