integrate state recorder

This commit is contained in:
c9s 2022-05-10 13:25:03 +08:00
parent 185a8279b2
commit 7b17b1a757
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
5 changed files with 88 additions and 50 deletions

View File

@ -78,6 +78,10 @@ func (r *StateRecorder) Scan(instance Instance) error {
for i := 0; i < rt.NumField(); i++ {
structField := rt.Field(i)
if !structField.IsExported() {
continue
}
obj := rv.Field(i).Interface()
switch o := obj.(type) {
@ -87,23 +91,33 @@ func (r *StateRecorder) Scan(instance Instance) error {
return fmt.Errorf("%v is a non-defined type", structField.Type)
}
f, err := r.openFile(instance, typeName)
if err != nil {
if err := r.newCsvWriter(o, instance, typeName); err != nil {
return err
}
if _, exists := r.files[o]; exists {
return fmt.Errorf("file of object %v already exists", o)
}
r.files[o] = f
r.writers[o] = csv.NewWriter(f)
}
}
return nil
}
func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error {
f, err := r.openFile(instance, typeName)
if err != nil {
return err
}
if _, exists := r.files[o]; exists {
return fmt.Errorf("file of object %v already exists", o)
}
r.files[o] = f
w := csv.NewWriter(f)
r.writers[o] = w
return w.Write(o.CsvHeader())
}
func (r *StateRecorder) Close() error {
var err error

View File

@ -34,9 +34,6 @@ func TestStateRecorder(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, recorder.writers, 1)
n, err := recorder.Snapshot()
assert.NoError(t, err)
assert.Equal(t, 1, n)
st.Position.AddTrade(types.Trade{
OrderID: 1,
@ -56,7 +53,7 @@ func TestStateRecorder(t *testing.T) {
IsIsolated: false,
})
n, err = recorder.Snapshot()
n, err := recorder.Snapshot()
assert.NoError(t, err)
assert.Equal(t, 1, n)

View File

@ -13,9 +13,13 @@ import (
"github.com/c9s/bbgo/pkg/interact"
)
type StrategyID interface {
ID() string
}
// SingleExchangeStrategy represents the single Exchange strategy
type SingleExchangeStrategy interface {
ID() string
StrategyID
Run(ctx context.Context, orderExecutor OrderExecutor, session *ExchangeSession) error
}
@ -34,7 +38,7 @@ type CrossExchangeSessionSubscriber interface {
}
type CrossExchangeStrategy interface {
ID() string
StrategyID
CrossRun(ctx context.Context, orderExecutionRouter OrderExecutionRouter, sessions map[string]*ExchangeSession) error
}
@ -346,15 +350,23 @@ func (trader *Trader) LoadState() error {
ps := trader.environment.PersistenceServiceFacade.Get()
log.Infof("loading strategies states...")
return trader.IterateStrategies(func(strategy StrategyID) error {
return loadPersistenceFields(strategy, strategy.ID(), ps)
})
}
func (trader *Trader) IterateStrategies(f func(st StrategyID) error) error {
for _, strategies := range trader.exchangeStrategies {
for _, strategy := range strategies {
if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil {
if err := f(strategy); err != nil {
return err
}
}
}
for _, strategy := range trader.crossExchangeStrategies {
if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil {
if err := f(strategy); err != nil {
return err
}
}
@ -374,30 +386,14 @@ func (trader *Trader) SaveState() error {
ps := trader.environment.PersistenceServiceFacade.Get()
log.Infof("saving strategies states...")
for _, strategies := range trader.exchangeStrategies {
for _, strategy := range strategies {
id := callID(strategy)
if len(id) == 0 {
continue
}
if err := storePersistenceFields(strategy, id, ps); err != nil {
return err
}
}
}
for _, strategy := range trader.crossExchangeStrategies {
return trader.IterateStrategies(func(strategy StrategyID) error {
id := callID(strategy)
if len(id) == 0 {
continue
return nil
}
if err := storePersistenceFields(strategy, id, ps); err != nil {
return err
}
}
return nil
return storePersistenceFields(strategy, id, ps)
})
}
var defaultPersistenceSelector = &PersistenceSelector{

View File

@ -334,25 +334,36 @@ var BacktestCmd = &cobra.Command{
var kLineHandlers []func(k types.KLine)
if generatingReport {
dumpDir := outputDirectory
reportDir := outputDirectory
if reportFileInSubDir {
dumpDir = filepath.Join(dumpDir, backtestSessionName)
dumpDir = filepath.Join(dumpDir, uuid.NewString())
reportDir = filepath.Join(reportDir, backtestSessionName)
reportDir = filepath.Join(reportDir, uuid.NewString())
}
dumpDir = filepath.Join(dumpDir, "klines")
kLineDataDir := filepath.Join(reportDir, "klines")
if err := safeMkdirAll(kLineDataDir); err != nil {
return err
}
if _, err := os.Stat(dumpDir); err != nil {
if os.IsNotExist(err) {
if err2 := os.MkdirAll(dumpDir, 0755); err2 != nil {
return err2
stateRecorder := backtest.NewStateRecorder(reportDir)
err = trader.IterateStrategies(func(st bbgo.StrategyID) error {
return stateRecorder.Scan(st.(backtest.Instance))
})
if err != nil {
return err
}
kLineHandlers = append(kLineHandlers, func(k types.KLine) {
// snapshot per 1m
if k.Interval == types.Interval1m && k.Closed {
if _, err := stateRecorder.Snapshot(); err != nil {
log.WithError(err).Errorf("state record failed to snapshot the strategy state")
}
} else {
return err
}
}
})
dumper := backtest.NewKLineDumper(dumpDir)
dumper := backtest.NewKLineDumper(kLineDataDir)
defer func() {
if err := dumper.Close(); err != nil {
log.WithError(err).Errorf("kline dumper can not close files")
@ -418,7 +429,6 @@ var BacktestCmd = &cobra.Command{
trader.Graceful.Shutdown(shutdownCtx)
cancelShutdown()
// put the logger back to print the pnl
log.SetLevel(log.InfoLevel)
@ -565,3 +575,20 @@ func confirmation(s string) bool {
}
}
}
func safeMkdirAll(p string) error {
st, err := os.Stat(p)
if err == nil {
if !st.IsDir() {
return fmt.Errorf("path %s is not a directory", p)
}
return nil
}
if os.IsNotExist(err) {
return os.MkdirAll(p, 0755)
}
return nil
}

View File

@ -72,6 +72,10 @@ func (p *Position) CsvHeader() []string {
}
func (p *Position) CsvRecords() [][]string {
if p.AverageCost.IsZero() && p.Base.IsZero() {
return nil
}
return [][]string{
{
p.Symbol,