mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-22 14:55:16 +00:00
integrate state recorder
This commit is contained in:
parent
185a8279b2
commit
7b17b1a757
|
@ -78,6 +78,10 @@ func (r *StateRecorder) Scan(instance Instance) error {
|
||||||
|
|
||||||
for i := 0; i < rt.NumField(); i++ {
|
for i := 0; i < rt.NumField(); i++ {
|
||||||
structField := rt.Field(i)
|
structField := rt.Field(i)
|
||||||
|
if !structField.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
obj := rv.Field(i).Interface()
|
obj := rv.Field(i).Interface()
|
||||||
switch o := obj.(type) {
|
switch o := obj.(type) {
|
||||||
|
|
||||||
|
@ -87,23 +91,33 @@ func (r *StateRecorder) Scan(instance Instance) error {
|
||||||
return fmt.Errorf("%v is a non-defined type", structField.Type)
|
return fmt.Errorf("%v is a non-defined type", structField.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := r.openFile(instance, typeName)
|
if err := r.newCsvWriter(o, instance, typeName); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, exists := r.files[o]; exists {
|
|
||||||
return fmt.Errorf("file of object %v already exists", o)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.files[o] = f
|
|
||||||
r.writers[o] = csv.NewWriter(f)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error {
|
||||||
|
f, err := r.openFile(instance, typeName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := r.files[o]; exists {
|
||||||
|
return fmt.Errorf("file of object %v already exists", o)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.files[o] = f
|
||||||
|
|
||||||
|
w := csv.NewWriter(f)
|
||||||
|
r.writers[o] = w
|
||||||
|
|
||||||
|
return w.Write(o.CsvHeader())
|
||||||
|
}
|
||||||
|
|
||||||
func (r *StateRecorder) Close() error {
|
func (r *StateRecorder) Close() error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
|
|
@ -34,9 +34,6 @@ func TestStateRecorder(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, recorder.writers, 1)
|
assert.Len(t, recorder.writers, 1)
|
||||||
|
|
||||||
n, err := recorder.Snapshot()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, 1, n)
|
|
||||||
|
|
||||||
st.Position.AddTrade(types.Trade{
|
st.Position.AddTrade(types.Trade{
|
||||||
OrderID: 1,
|
OrderID: 1,
|
||||||
|
@ -56,7 +53,7 @@ func TestStateRecorder(t *testing.T) {
|
||||||
IsIsolated: false,
|
IsIsolated: false,
|
||||||
})
|
})
|
||||||
|
|
||||||
n, err = recorder.Snapshot()
|
n, err := recorder.Snapshot()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 1, n)
|
assert.Equal(t, 1, n)
|
||||||
|
|
||||||
|
|
|
@ -13,9 +13,13 @@ import (
|
||||||
"github.com/c9s/bbgo/pkg/interact"
|
"github.com/c9s/bbgo/pkg/interact"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type StrategyID interface {
|
||||||
|
ID() string
|
||||||
|
}
|
||||||
|
|
||||||
// SingleExchangeStrategy represents the single Exchange strategy
|
// SingleExchangeStrategy represents the single Exchange strategy
|
||||||
type SingleExchangeStrategy interface {
|
type SingleExchangeStrategy interface {
|
||||||
ID() string
|
StrategyID
|
||||||
Run(ctx context.Context, orderExecutor OrderExecutor, session *ExchangeSession) error
|
Run(ctx context.Context, orderExecutor OrderExecutor, session *ExchangeSession) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +38,7 @@ type CrossExchangeSessionSubscriber interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type CrossExchangeStrategy interface {
|
type CrossExchangeStrategy interface {
|
||||||
ID() string
|
StrategyID
|
||||||
CrossRun(ctx context.Context, orderExecutionRouter OrderExecutionRouter, sessions map[string]*ExchangeSession) error
|
CrossRun(ctx context.Context, orderExecutionRouter OrderExecutionRouter, sessions map[string]*ExchangeSession) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -346,15 +350,23 @@ func (trader *Trader) LoadState() error {
|
||||||
ps := trader.environment.PersistenceServiceFacade.Get()
|
ps := trader.environment.PersistenceServiceFacade.Get()
|
||||||
|
|
||||||
log.Infof("loading strategies states...")
|
log.Infof("loading strategies states...")
|
||||||
|
|
||||||
|
return trader.IterateStrategies(func(strategy StrategyID) error {
|
||||||
|
return loadPersistenceFields(strategy, strategy.ID(), ps)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (trader *Trader) IterateStrategies(f func(st StrategyID) error) error {
|
||||||
for _, strategies := range trader.exchangeStrategies {
|
for _, strategies := range trader.exchangeStrategies {
|
||||||
for _, strategy := range strategies {
|
for _, strategy := range strategies {
|
||||||
if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil {
|
if err := f(strategy); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, strategy := range trader.crossExchangeStrategies {
|
for _, strategy := range trader.crossExchangeStrategies {
|
||||||
if err := loadPersistenceFields(strategy, strategy.ID(), ps); err != nil {
|
if err := f(strategy); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -374,30 +386,14 @@ func (trader *Trader) SaveState() error {
|
||||||
ps := trader.environment.PersistenceServiceFacade.Get()
|
ps := trader.environment.PersistenceServiceFacade.Get()
|
||||||
|
|
||||||
log.Infof("saving strategies states...")
|
log.Infof("saving strategies states...")
|
||||||
for _, strategies := range trader.exchangeStrategies {
|
return trader.IterateStrategies(func(strategy StrategyID) error {
|
||||||
for _, strategy := range strategies {
|
|
||||||
id := callID(strategy)
|
|
||||||
if len(id) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := storePersistenceFields(strategy, id, ps); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, strategy := range trader.crossExchangeStrategies {
|
|
||||||
id := callID(strategy)
|
id := callID(strategy)
|
||||||
if len(id) == 0 {
|
if len(id) == 0 {
|
||||||
continue
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := storePersistenceFields(strategy, id, ps); err != nil {
|
return storePersistenceFields(strategy, id, ps)
|
||||||
return err
|
})
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultPersistenceSelector = &PersistenceSelector{
|
var defaultPersistenceSelector = &PersistenceSelector{
|
||||||
|
|
|
@ -334,25 +334,36 @@ var BacktestCmd = &cobra.Command{
|
||||||
|
|
||||||
var kLineHandlers []func(k types.KLine)
|
var kLineHandlers []func(k types.KLine)
|
||||||
if generatingReport {
|
if generatingReport {
|
||||||
dumpDir := outputDirectory
|
reportDir := outputDirectory
|
||||||
if reportFileInSubDir {
|
if reportFileInSubDir {
|
||||||
dumpDir = filepath.Join(dumpDir, backtestSessionName)
|
reportDir = filepath.Join(reportDir, backtestSessionName)
|
||||||
dumpDir = filepath.Join(dumpDir, uuid.NewString())
|
reportDir = filepath.Join(reportDir, uuid.NewString())
|
||||||
}
|
}
|
||||||
|
|
||||||
dumpDir = filepath.Join(dumpDir, "klines")
|
kLineDataDir := filepath.Join(reportDir, "klines")
|
||||||
|
if err := safeMkdirAll(kLineDataDir); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := os.Stat(dumpDir); err != nil {
|
stateRecorder := backtest.NewStateRecorder(reportDir)
|
||||||
if os.IsNotExist(err) {
|
err = trader.IterateStrategies(func(st bbgo.StrategyID) error {
|
||||||
if err2 := os.MkdirAll(dumpDir, 0755); err2 != nil {
|
return stateRecorder.Scan(st.(backtest.Instance))
|
||||||
return err2
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
kLineHandlers = append(kLineHandlers, func(k types.KLine) {
|
||||||
|
// snapshot per 1m
|
||||||
|
if k.Interval == types.Interval1m && k.Closed {
|
||||||
|
if _, err := stateRecorder.Snapshot(); err != nil {
|
||||||
|
log.WithError(err).Errorf("state record failed to snapshot the strategy state")
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
dumper := backtest.NewKLineDumper(dumpDir)
|
dumper := backtest.NewKLineDumper(kLineDataDir)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := dumper.Close(); err != nil {
|
if err := dumper.Close(); err != nil {
|
||||||
log.WithError(err).Errorf("kline dumper can not close files")
|
log.WithError(err).Errorf("kline dumper can not close files")
|
||||||
|
@ -418,7 +429,6 @@ var BacktestCmd = &cobra.Command{
|
||||||
trader.Graceful.Shutdown(shutdownCtx)
|
trader.Graceful.Shutdown(shutdownCtx)
|
||||||
cancelShutdown()
|
cancelShutdown()
|
||||||
|
|
||||||
|
|
||||||
// put the logger back to print the pnl
|
// put the logger back to print the pnl
|
||||||
log.SetLevel(log.InfoLevel)
|
log.SetLevel(log.InfoLevel)
|
||||||
|
|
||||||
|
@ -565,3 +575,20 @@ func confirmation(s string) bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func safeMkdirAll(p string) error {
|
||||||
|
st, err := os.Stat(p)
|
||||||
|
if err == nil {
|
||||||
|
if !st.IsDir() {
|
||||||
|
return fmt.Errorf("path %s is not a directory", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return os.MkdirAll(p, 0755)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -72,6 +72,10 @@ func (p *Position) CsvHeader() []string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Position) CsvRecords() [][]string {
|
func (p *Position) CsvRecords() [][]string {
|
||||||
|
if p.AverageCost.IsZero() && p.Base.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return [][]string{
|
return [][]string{
|
||||||
{
|
{
|
||||||
p.Symbol,
|
p.Symbol,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user