diff --git a/pkg/backtest/recorder.go b/pkg/backtest/recorder.go index abb3bc5f6..c90e2d256 100644 --- a/pkg/backtest/recorder.go +++ b/pkg/backtest/recorder.go @@ -78,6 +78,10 @@ func (r *StateRecorder) Scan(instance Instance) error { for i := 0; i < rt.NumField(); i++ { structField := rt.Field(i) + if !structField.IsExported() { + continue + } + obj := rv.Field(i).Interface() switch o := obj.(type) { @@ -87,23 +91,33 @@ func (r *StateRecorder) Scan(instance Instance) error { return fmt.Errorf("%v is a non-defined type", structField.Type) } - f, err := r.openFile(instance, typeName) - if err != nil { + if err := r.newCsvWriter(o, instance, typeName); err != nil { return err } - - if _, exists := r.files[o]; exists { - return fmt.Errorf("file of object %v already exists", o) - } - - r.files[o] = f - r.writers[o] = csv.NewWriter(f) } } return nil } +func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error { + f, err := r.openFile(instance, typeName) + if err != nil { + return err + } + + if _, exists := r.files[o]; exists { + return fmt.Errorf("file of object %v already exists", o) + } + + r.files[o] = f + + w := csv.NewWriter(f) + r.writers[o] = w + + return w.Write(o.CsvHeader()) +} + func (r *StateRecorder) Close() error { var err error diff --git a/pkg/backtest/recorder_test.go b/pkg/backtest/recorder_test.go index 948442903..c95a6b394 100644 --- a/pkg/backtest/recorder_test.go +++ b/pkg/backtest/recorder_test.go @@ -34,9 +34,6 @@ func TestStateRecorder(t *testing.T) { assert.NoError(t, err) assert.Len(t, recorder.writers, 1) - n, err := recorder.Snapshot() - assert.NoError(t, err) - assert.Equal(t, 1, n) st.Position.AddTrade(types.Trade{ OrderID: 1, @@ -56,7 +53,7 @@ func TestStateRecorder(t *testing.T) { IsIsolated: false, }) - n, err = recorder.Snapshot() + n, err := recorder.Snapshot() assert.NoError(t, err) assert.Equal(t, 1, n) diff --git a/pkg/bbgo/trader.go b/pkg/bbgo/trader.go index 5944e6486..7f5587492 100644 --- a/pkg/bbgo/trader.go +++ b/pkg/bbgo/trader.go @@ -13,9 +13,13 @@ import ( "github.com/c9s/bbgo/pkg/interact" ) +type StrategyID interface { + ID() string +} + // SingleExchangeStrategy represents the single Exchange strategy type SingleExchangeStrategy interface { - ID() string + StrategyID Run(ctx context.Context, orderExecutor OrderExecutor, session *ExchangeSession) error } @@ -34,7 +38,7 @@ type CrossExchangeSessionSubscriber interface { } type CrossExchangeStrategy interface { - ID() string + StrategyID CrossRun(ctx context.Context, orderExecutionRouter OrderExecutionRouter, sessions map[string]*ExchangeSession) error } @@ -346,15 +350,23 @@ func (trader *Trader) LoadState() error { ps := trader.environment.PersistenceServiceFacade.Get() log.Infof("loading strategies states...") + + return trader.IterateStrategies(func(strategy StrategyID) error { + return loadPersistenceFields(strategy, strategy.ID(), ps) + }) +} + +func (trader *Trader) IterateStrategies(f func(st StrategyID) error) error { for _, strategies := range trader.exchangeStrategies { for _, strategy := range strategies { - if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil { + if err := f(strategy); err != nil { return err } } } + for _, strategy := range trader.crossExchangeStrategies { - if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil { + if err := f(strategy); err != nil { return err } } @@ -374,30 +386,14 @@ func (trader *Trader) SaveState() error { ps := trader.environment.PersistenceServiceFacade.Get() log.Infof("saving strategies states...") - for _, strategies := range trader.exchangeStrategies { - for _, strategy := range strategies { - id := callID(strategy) - if len(id) == 0 { - continue - } - - if err := storePersistenceFields(strategy, id, ps); err != nil { - return err - } - } - } - for _, strategy := range trader.crossExchangeStrategies { + return trader.IterateStrategies(func(strategy StrategyID) error { id := callID(strategy) if len(id) == 0 { - continue + return nil } - if err := storePersistenceFields(strategy, id, ps); err != nil { - return err - } - } - - return nil + return storePersistenceFields(strategy, id, ps) + }) } var defaultPersistenceSelector = &PersistenceSelector{ diff --git a/pkg/cmd/backtest.go b/pkg/cmd/backtest.go index 5688a736f..89ec9eaea 100644 --- a/pkg/cmd/backtest.go +++ b/pkg/cmd/backtest.go @@ -334,25 +334,36 @@ var BacktestCmd = &cobra.Command{ var kLineHandlers []func(k types.KLine) if generatingReport { - dumpDir := outputDirectory + reportDir := outputDirectory if reportFileInSubDir { - dumpDir = filepath.Join(dumpDir, backtestSessionName) - dumpDir = filepath.Join(dumpDir, uuid.NewString()) + reportDir = filepath.Join(reportDir, backtestSessionName) + reportDir = filepath.Join(reportDir, uuid.NewString()) } - dumpDir = filepath.Join(dumpDir, "klines") + kLineDataDir := filepath.Join(reportDir, "klines") + if err := safeMkdirAll(kLineDataDir); err != nil { + return err + } - if _, err := os.Stat(dumpDir); err != nil { - if os.IsNotExist(err) { - if err2 := os.MkdirAll(dumpDir, 0755); err2 != nil { - return err2 + stateRecorder := backtest.NewStateRecorder(reportDir) + err = trader.IterateStrategies(func(st bbgo.StrategyID) error { + return stateRecorder.Scan(st.(backtest.Instance)) + }) + + if err != nil { + return err + } + + kLineHandlers = append(kLineHandlers, func(k types.KLine) { + // snapshot per 1m + if k.Interval == types.Interval1m && k.Closed { + if _, err := stateRecorder.Snapshot(); err != nil { + log.WithError(err).Errorf("state record failed to snapshot the strategy state") } - } else { - return err } - } + }) - dumper := backtest.NewKLineDumper(dumpDir) + dumper := backtest.NewKLineDumper(kLineDataDir) defer func() { if err := dumper.Close(); err != nil { log.WithError(err).Errorf("kline dumper can not close files") @@ -418,7 +429,6 @@ var BacktestCmd = &cobra.Command{ trader.Graceful.Shutdown(shutdownCtx) cancelShutdown() - // put the logger back to print the pnl log.SetLevel(log.InfoLevel) @@ -565,3 +575,20 @@ func confirmation(s string) bool { } } } + +func safeMkdirAll(p string) error { + st, err := os.Stat(p) + if err == nil { + if !st.IsDir() { + return fmt.Errorf("path %s is not a directory", p) + } + + return nil + } + + if os.IsNotExist(err) { + return os.MkdirAll(p, 0755) + } + + return nil +} diff --git a/pkg/types/position.go b/pkg/types/position.go index c0af6488b..25dfafe3f 100644 --- a/pkg/types/position.go +++ b/pkg/types/position.go @@ -72,6 +72,10 @@ func (p *Position) CsvHeader() []string { } func (p *Position) CsvRecords() [][]string { + if p.AverageCost.IsZero() && p.Base.IsZero() { + return nil + } + return [][]string{ { p.Symbol,