add tsv writer

This commit is contained in:
c9s 2022-05-12 22:12:31 +08:00
parent b855d2e30b
commit f99e874072
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
3 changed files with 54 additions and 33 deletions

View File

@ -1,15 +1,14 @@
package backtest package backtest
import ( import (
"encoding/csv"
"fmt" "fmt"
"os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings" "strings"
"go.uber.org/multierr" "go.uber.org/multierr"
"github.com/c9s/bbgo/pkg/data/tsv"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
@ -27,16 +26,14 @@ type InstancePropertyIndex struct {
type StateRecorder struct { type StateRecorder struct {
outputDirectory string outputDirectory string
strategies []Instance strategies []Instance
files map[interface{}]*os.File writers map[types.CsvFormatter]*tsv.Writer
writers map[types.CsvFormatter]*csv.Writer
manifests Manifests manifests Manifests
} }
func NewStateRecorder(outputDir string) *StateRecorder { func NewStateRecorder(outputDir string) *StateRecorder {
return &StateRecorder{ return &StateRecorder{
outputDirectory: outputDir, outputDirectory: outputDir,
files: make(map[interface{}]*os.File), writers: make(map[types.CsvFormatter]*tsv.Writer),
writers: make(map[types.CsvFormatter]*csv.Writer),
manifests: make(Manifests), manifests: make(Manifests),
} }
} }
@ -96,7 +93,7 @@ func (r *StateRecorder) Scan(instance Instance) error {
} }
func (r *StateRecorder) formatCsvFilename(instance Instance, objType string) string { func (r *StateRecorder) formatCsvFilename(instance Instance, objType string) string {
return filepath.Join(r.outputDirectory, fmt.Sprintf("%s-%s.csv", instance.InstanceID(), objType)) return filepath.Join(r.outputDirectory, fmt.Sprintf("%s-%s.tsv", instance.InstanceID(), objType))
} }
func (r *StateRecorder) Manifests() Manifests { func (r *StateRecorder) Manifests() Manifests {
@ -105,34 +102,26 @@ func (r *StateRecorder) Manifests() Manifests {
func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error { func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error {
fn := r.formatCsvFilename(instance, typeName) fn := r.formatCsvFilename(instance, typeName)
f, err := os.Create(fn) w, err := tsv.NewWriterFile(fn)
if err != nil { if err != nil {
return err return err
} }
if _, exists := r.files[o]; exists {
return fmt.Errorf("file of object %v already exists", o)
}
r.manifests[InstancePropertyIndex{ r.manifests[InstancePropertyIndex{
ID: instance.ID(), ID: instance.ID(),
InstanceID: instance.InstanceID(), InstanceID: instance.InstanceID(),
Property: typeName, Property: typeName,
}] = fn }] = fn
r.files[o] = f
w := csv.NewWriter(f)
r.writers[o] = w r.writers[o] = w
return w.Write(o.CsvHeader()) return w.Write(o.CsvHeader())
} }
func (r *StateRecorder) Close() error { func (r *StateRecorder) Close() error {
var err error var err error
for _, f := range r.files { for _, w := range r.writers {
err2 := f.Close() err2 := w.Close()
if err2 != nil { if err2 != nil {
err = multierr.Append(err, err2) err = multierr.Append(err, err2)
} }

View File

@ -3,7 +3,6 @@ package cmd
import ( import (
"bufio" "bufio"
"context" "context"
"encoding/csv"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -24,6 +23,7 @@ import (
"github.com/c9s/bbgo/pkg/backtest" "github.com/c9s/bbgo/pkg/backtest"
"github.com/c9s/bbgo/pkg/bbgo" "github.com/c9s/bbgo/pkg/bbgo"
"github.com/c9s/bbgo/pkg/cmd/cmdutil" "github.com/c9s/bbgo/pkg/cmd/cmdutil"
"github.com/c9s/bbgo/pkg/data/tsv"
"github.com/c9s/bbgo/pkg/service" "github.com/c9s/bbgo/pkg/service"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
@ -338,18 +338,17 @@ var BacktestCmd = &cobra.Command{
}) })
// equity curve recording -- record per 1h kline // equity curve recording -- record per 1h kline
equityCurveFile, err := os.Create(filepath.Join(reportDir, "equity_curve.csv")) equityCurveTsv, err := tsv.NewWriterFile(filepath.Join(reportDir, "equity_curve.tsv"))
if err != nil { if err != nil {
return err return err
} }
defer func() { _ = equityCurveFile.Close() }() defer func() { _ = equityCurveTsv.Close() }()
equityCurveCsv := csv.NewWriter(equityCurveFile) _ = equityCurveTsv.Write([]string{
_ = equityCurveCsv.Write([]string{
"time", "time",
"in_usd", "in_usd",
}) })
defer equityCurveCsv.Flush() defer equityCurveTsv.Flush()
kLineHandlers = append(kLineHandlers, func(k types.KLine, exSource *backtest.ExchangeDataSource) { kLineHandlers = append(kLineHandlers, func(k types.KLine, exSource *backtest.ExchangeDataSource) {
if k.Interval != types.Interval1h { if k.Interval != types.Interval1h {
@ -361,29 +360,26 @@ var BacktestCmd = &cobra.Command{
log.WithError(err).Errorf("query back-test account balance error") log.WithError(err).Errorf("query back-test account balance error")
} else { } else {
assets := balances.Assets(exSource.Session.AllLastPrices(), k.EndTime.Time()) assets := balances.Assets(exSource.Session.AllLastPrices(), k.EndTime.Time())
_ = equityCurveCsv.Write([]string{ _ = equityCurveTsv.Write([]string{
k.EndTime.Time().Format(time.RFC1123), k.EndTime.Time().Format(time.RFC1123),
assets.InUSD().String(), assets.InUSD().String(),
}) })
} }
}) })
// equity curve recording -- record per 1h kline ordersTsv, err := tsv.NewWriterFile(filepath.Join(reportDir, "orders.tsv"))
ordersFile, err := os.Create(filepath.Join(reportDir, "orders.csv"))
if err != nil { if err != nil {
return err return err
} }
defer func() { _ = ordersFile.Close() }() defer func() { _ = ordersTsv.Close() }()
_ = ordersTsv.Write(types.Order{}.CsvHeader())
ordersCsv := csv.NewWriter(ordersFile) defer ordersTsv.Flush()
_ = ordersCsv.Write(types.Order{}.CsvHeader())
defer ordersCsv.Flush()
for _, exSource := range exchangeSources { for _, exSource := range exchangeSources {
exSource.Session.UserDataStream.OnOrderUpdate(func(order types.Order) { exSource.Session.UserDataStream.OnOrderUpdate(func(order types.Order) {
if order.Status == types.OrderStatusFilled { if order.Status == types.OrderStatusFilled {
for _, record := range order.CsvRecords() { for _, record := range order.CsvRecords() {
_ = ordersCsv.Write(record) _ = ordersTsv.Write(record)
} }
} }
}) })

36
pkg/data/tsv/writer.go Normal file
View File

@ -0,0 +1,36 @@
package tsv
import (
"encoding/csv"
"io"
"os"
)
type Writer struct {
file io.WriteCloser
*csv.Writer
}
func NewWriterFile(filename string) (*Writer, error) {
f, err := os.Create(filename)
if err != nil {
return nil, err
}
return NewWriter(f), nil
}
func NewWriter(file io.WriteCloser) *Writer {
tsv := csv.NewWriter(file)
tsv.Comma = '\t'
return &Writer{
Writer: tsv,
file: file,
}
}
func (w *Writer) Close() error {
w.Writer.Flush()
return w.file.Close()
}