diff --git a/pkg/backtest/recorder.go b/pkg/backtest/recorder.go index 9bc389c67..d5f516ade 100644 --- a/pkg/backtest/recorder.go +++ b/pkg/backtest/recorder.go @@ -1,15 +1,14 @@ package backtest import ( - "encoding/csv" "fmt" - "os" "path/filepath" "reflect" "strings" "go.uber.org/multierr" + "github.com/c9s/bbgo/pkg/data/tsv" "github.com/c9s/bbgo/pkg/types" ) @@ -27,16 +26,14 @@ type InstancePropertyIndex struct { type StateRecorder struct { outputDirectory string strategies []Instance - files map[interface{}]*os.File - writers map[types.CsvFormatter]*csv.Writer + writers map[types.CsvFormatter]*tsv.Writer manifests Manifests } func NewStateRecorder(outputDir string) *StateRecorder { return &StateRecorder{ outputDirectory: outputDir, - files: make(map[interface{}]*os.File), - writers: make(map[types.CsvFormatter]*csv.Writer), + writers: make(map[types.CsvFormatter]*tsv.Writer), manifests: make(Manifests), } } @@ -96,7 +93,7 @@ func (r *StateRecorder) Scan(instance Instance) error { } func (r *StateRecorder) formatCsvFilename(instance Instance, objType string) string { - return filepath.Join(r.outputDirectory, fmt.Sprintf("%s-%s.csv", instance.InstanceID(), objType)) + return filepath.Join(r.outputDirectory, fmt.Sprintf("%s-%s.tsv", instance.InstanceID(), objType)) } func (r *StateRecorder) Manifests() Manifests { @@ -105,34 +102,26 @@ func (r *StateRecorder) Manifests() Manifests { func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error { fn := r.formatCsvFilename(instance, typeName) - f, err := os.Create(fn) + w, err := tsv.NewWriterFile(fn) if err != nil { return err } - if _, exists := r.files[o]; exists { - return fmt.Errorf("file of object %v already exists", o) - } - r.manifests[InstancePropertyIndex{ ID: instance.ID(), InstanceID: instance.InstanceID(), Property: typeName, }] = fn - r.files[o] = f - - w := csv.NewWriter(f) r.writers[o] = w - return w.Write(o.CsvHeader()) } func (r *StateRecorder) Close() error { var err error - for _, f := range r.files { - err2 := f.Close() + for _, w := range r.writers { + err2 := w.Close() if err2 != nil { err = multierr.Append(err, err2) } diff --git a/pkg/cmd/backtest.go b/pkg/cmd/backtest.go index 491853ed0..347520879 100644 --- a/pkg/cmd/backtest.go +++ b/pkg/cmd/backtest.go @@ -3,7 +3,6 @@ package cmd import ( "bufio" "context" - "encoding/csv" "encoding/json" "fmt" "io/ioutil" @@ -24,6 +23,7 @@ import ( "github.com/c9s/bbgo/pkg/backtest" "github.com/c9s/bbgo/pkg/bbgo" "github.com/c9s/bbgo/pkg/cmd/cmdutil" + "github.com/c9s/bbgo/pkg/data/tsv" "github.com/c9s/bbgo/pkg/service" "github.com/c9s/bbgo/pkg/types" ) @@ -338,18 +338,17 @@ var BacktestCmd = &cobra.Command{ }) // equity curve recording -- record per 1h kline - equityCurveFile, err := os.Create(filepath.Join(reportDir, "equity_curve.csv")) + equityCurveTsv, err := tsv.NewWriterFile(filepath.Join(reportDir, "equity_curve.tsv")) if err != nil { return err } - defer func() { _ = equityCurveFile.Close() }() + defer func() { _ = equityCurveTsv.Close() }() - equityCurveCsv := csv.NewWriter(equityCurveFile) - _ = equityCurveCsv.Write([]string{ + _ = equityCurveTsv.Write([]string{ "time", "in_usd", }) - defer equityCurveCsv.Flush() + defer equityCurveTsv.Flush() kLineHandlers = append(kLineHandlers, func(k types.KLine, exSource *backtest.ExchangeDataSource) { if k.Interval != types.Interval1h { @@ -361,29 +360,26 @@ var BacktestCmd = &cobra.Command{ log.WithError(err).Errorf("query back-test account balance error") } else { assets := balances.Assets(exSource.Session.AllLastPrices(), k.EndTime.Time()) - _ = equityCurveCsv.Write([]string{ + _ = equityCurveTsv.Write([]string{ k.EndTime.Time().Format(time.RFC1123), assets.InUSD().String(), }) } }) - // equity curve recording -- record per 1h kline - ordersFile, err := os.Create(filepath.Join(reportDir, "orders.csv")) + ordersTsv, err := tsv.NewWriterFile(filepath.Join(reportDir, "orders.tsv")) if err != nil { return err } - defer func() { _ = ordersFile.Close() }() + defer func() { _ = ordersTsv.Close() }() + _ = ordersTsv.Write(types.Order{}.CsvHeader()) - ordersCsv := csv.NewWriter(ordersFile) - _ = ordersCsv.Write(types.Order{}.CsvHeader()) - - defer ordersCsv.Flush() + defer ordersTsv.Flush() for _, exSource := range exchangeSources { exSource.Session.UserDataStream.OnOrderUpdate(func(order types.Order) { if order.Status == types.OrderStatusFilled { for _, record := range order.CsvRecords() { - _ = ordersCsv.Write(record) + _ = ordersTsv.Write(record) } } }) diff --git a/pkg/data/tsv/writer.go b/pkg/data/tsv/writer.go new file mode 100644 index 000000000..e83eef72f --- /dev/null +++ b/pkg/data/tsv/writer.go @@ -0,0 +1,36 @@ +package tsv + +import ( + "encoding/csv" + "io" + "os" +) + +type Writer struct { + file io.WriteCloser + + *csv.Writer +} + +func NewWriterFile(filename string) (*Writer, error) { + f, err := os.Create(filename) + if err != nil { + return nil, err + } + + return NewWriter(f), nil +} + +func NewWriter(file io.WriteCloser) *Writer { + tsv := csv.NewWriter(file) + tsv.Comma = '\t' + return &Writer{ + Writer: tsv, + file: file, + } +} + +func (w *Writer) Close() error { + w.Writer.Flush() + return w.file.Close() +}