mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-22 14:55:16 +00:00
integrate state recorder
This commit is contained in:
parent
185a8279b2
commit
7b17b1a757
|
@ -78,6 +78,10 @@ func (r *StateRecorder) Scan(instance Instance) error {
|
|||
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
structField := rt.Field(i)
|
||||
if !structField.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
obj := rv.Field(i).Interface()
|
||||
switch o := obj.(type) {
|
||||
|
||||
|
@ -87,6 +91,16 @@ func (r *StateRecorder) Scan(instance Instance) error {
|
|||
return fmt.Errorf("%v is a non-defined type", structField.Type)
|
||||
}
|
||||
|
||||
if err := r.newCsvWriter(o, instance, typeName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error {
|
||||
f, err := r.openFile(instance, typeName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -97,11 +111,11 @@ func (r *StateRecorder) Scan(instance Instance) error {
|
|||
}
|
||||
|
||||
r.files[o] = f
|
||||
r.writers[o] = csv.NewWriter(f)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
w := csv.NewWriter(f)
|
||||
r.writers[o] = w
|
||||
|
||||
return w.Write(o.CsvHeader())
|
||||
}
|
||||
|
||||
func (r *StateRecorder) Close() error {
|
||||
|
|
|
@ -34,9 +34,6 @@ func TestStateRecorder(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Len(t, recorder.writers, 1)
|
||||
|
||||
n, err := recorder.Snapshot()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, n)
|
||||
|
||||
st.Position.AddTrade(types.Trade{
|
||||
OrderID: 1,
|
||||
|
@ -56,7 +53,7 @@ func TestStateRecorder(t *testing.T) {
|
|||
IsIsolated: false,
|
||||
})
|
||||
|
||||
n, err = recorder.Snapshot()
|
||||
n, err := recorder.Snapshot()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, n)
|
||||
|
||||
|
|
|
@ -13,9 +13,13 @@ import (
|
|||
"github.com/c9s/bbgo/pkg/interact"
|
||||
)
|
||||
|
||||
type StrategyID interface {
|
||||
ID() string
|
||||
}
|
||||
|
||||
// SingleExchangeStrategy represents the single Exchange strategy
|
||||
type SingleExchangeStrategy interface {
|
||||
ID() string
|
||||
StrategyID
|
||||
Run(ctx context.Context, orderExecutor OrderExecutor, session *ExchangeSession) error
|
||||
}
|
||||
|
||||
|
@ -34,7 +38,7 @@ type CrossExchangeSessionSubscriber interface {
|
|||
}
|
||||
|
||||
type CrossExchangeStrategy interface {
|
||||
ID() string
|
||||
StrategyID
|
||||
CrossRun(ctx context.Context, orderExecutionRouter OrderExecutionRouter, sessions map[string]*ExchangeSession) error
|
||||
}
|
||||
|
||||
|
@ -346,15 +350,23 @@ func (trader *Trader) LoadState() error {
|
|||
ps := trader.environment.PersistenceServiceFacade.Get()
|
||||
|
||||
log.Infof("loading strategies states...")
|
||||
|
||||
return trader.IterateStrategies(func(strategy StrategyID) error {
|
||||
return loadPersistenceFields(strategy, strategy.ID(), ps)
|
||||
})
|
||||
}
|
||||
|
||||
func (trader *Trader) IterateStrategies(f func(st StrategyID) error) error {
|
||||
for _, strategies := range trader.exchangeStrategies {
|
||||
for _, strategy := range strategies {
|
||||
if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil {
|
||||
if err := f(strategy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, strategy := range trader.crossExchangeStrategies {
|
||||
if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil {
|
||||
if err := f(strategy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -374,30 +386,14 @@ func (trader *Trader) SaveState() error {
|
|||
ps := trader.environment.PersistenceServiceFacade.Get()
|
||||
|
||||
log.Infof("saving strategies states...")
|
||||
for _, strategies := range trader.exchangeStrategies {
|
||||
for _, strategy := range strategies {
|
||||
return trader.IterateStrategies(func(strategy StrategyID) error {
|
||||
id := callID(strategy)
|
||||
if len(id) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := storePersistenceFields(strategy, id, ps); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, strategy := range trader.crossExchangeStrategies {
|
||||
id := callID(strategy)
|
||||
if len(id) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := storePersistenceFields(strategy, id, ps); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return storePersistenceFields(strategy, id, ps)
|
||||
})
|
||||
}
|
||||
|
||||
var defaultPersistenceSelector = &PersistenceSelector{
|
||||
|
|
|
@ -334,25 +334,36 @@ var BacktestCmd = &cobra.Command{
|
|||
|
||||
var kLineHandlers []func(k types.KLine)
|
||||
if generatingReport {
|
||||
dumpDir := outputDirectory
|
||||
reportDir := outputDirectory
|
||||
if reportFileInSubDir {
|
||||
dumpDir = filepath.Join(dumpDir, backtestSessionName)
|
||||
dumpDir = filepath.Join(dumpDir, uuid.NewString())
|
||||
reportDir = filepath.Join(reportDir, backtestSessionName)
|
||||
reportDir = filepath.Join(reportDir, uuid.NewString())
|
||||
}
|
||||
|
||||
dumpDir = filepath.Join(dumpDir, "klines")
|
||||
|
||||
if _, err := os.Stat(dumpDir); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err2 := os.MkdirAll(dumpDir, 0755); err2 != nil {
|
||||
return err2
|
||||
}
|
||||
} else {
|
||||
kLineDataDir := filepath.Join(reportDir, "klines")
|
||||
if err := safeMkdirAll(kLineDataDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stateRecorder := backtest.NewStateRecorder(reportDir)
|
||||
err = trader.IterateStrategies(func(st bbgo.StrategyID) error {
|
||||
return stateRecorder.Scan(st.(backtest.Instance))
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dumper := backtest.NewKLineDumper(dumpDir)
|
||||
kLineHandlers = append(kLineHandlers, func(k types.KLine) {
|
||||
// snapshot per 1m
|
||||
if k.Interval == types.Interval1m && k.Closed {
|
||||
if _, err := stateRecorder.Snapshot(); err != nil {
|
||||
log.WithError(err).Errorf("state record failed to snapshot the strategy state")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
dumper := backtest.NewKLineDumper(kLineDataDir)
|
||||
defer func() {
|
||||
if err := dumper.Close(); err != nil {
|
||||
log.WithError(err).Errorf("kline dumper can not close files")
|
||||
|
@ -418,7 +429,6 @@ var BacktestCmd = &cobra.Command{
|
|||
trader.Graceful.Shutdown(shutdownCtx)
|
||||
cancelShutdown()
|
||||
|
||||
|
||||
// put the logger back to print the pnl
|
||||
log.SetLevel(log.InfoLevel)
|
||||
|
||||
|
@ -565,3 +575,20 @@ func confirmation(s string) bool {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func safeMkdirAll(p string) error {
|
||||
st, err := os.Stat(p)
|
||||
if err == nil {
|
||||
if !st.IsDir() {
|
||||
return fmt.Errorf("path %s is not a directory", p)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
return os.MkdirAll(p, 0755)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -72,6 +72,10 @@ func (p *Position) CsvHeader() []string {
|
|||
}
|
||||
|
||||
func (p *Position) CsvRecords() [][]string {
|
||||
if p.AverageCost.IsZero() && p.Base.IsZero() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return [][]string{
|
||||
{
|
||||
p.Symbol,
|
||||
|
|
Loading…
Reference in New Issue
Block a user