bbgo_origin/pkg/backtest/recorder.go

157 lines
3.1 KiB
Go
Raw Normal View History

2022-05-10 04:44:51 +00:00
package backtest
import (
"fmt"
"path/filepath"
"reflect"
"strings"
"go.uber.org/multierr"
2022-05-12 14:12:31 +00:00
"github.com/c9s/bbgo/pkg/data/tsv"
2022-05-10 04:44:51 +00:00
"github.com/c9s/bbgo/pkg/types"
)
// Instance is an object (such as a strategy) that StateRecorder.Scan can
// inspect for types.CsvFormatter fields to record.
type Instance interface {
	// InstanceID returns the identifier used to build output file names; it
	// is also stored as InstancePropertyIndex.InstanceID in the manifests.
	InstanceID() string

	// ID returns the identifier stored as InstancePropertyIndex.ID in the
	// manifests.
	ID() string
}
// InstancePropertyIndex is the key of a Manifests entry. It identifies one
// recorded property (the lower-cased type name of a CsvFormatter field) of
// one scanned instance.
type InstancePropertyIndex struct {
	ID, InstanceID, Property string
}
type StateRecorder struct {
outputDirectory string
strategies []Instance
2022-05-12 14:12:31 +00:00
writers map[types.CsvFormatter]*tsv.Writer
lastLines map[types.CsvFormatter][]string
manifests Manifests
2022-05-10 04:44:51 +00:00
}
func NewStateRecorder(outputDir string) *StateRecorder {
return &StateRecorder{
outputDirectory: outputDir,
2022-05-12 14:12:31 +00:00
writers: make(map[types.CsvFormatter]*tsv.Writer),
lastLines: make(map[types.CsvFormatter][]string),
manifests: make(Manifests),
2022-05-10 04:44:51 +00:00
}
}
func (r *StateRecorder) Snapshot() (int, error) {
var c int
for obj, writer := range r.writers {
records := obj.CsvRecords()
lastLine, hasLastLine := r.lastLines[obj]
2022-05-10 04:44:51 +00:00
for _, record := range records {
if hasLastLine && equalStringSlice(lastLine, record) {
continue
}
2022-05-10 04:44:51 +00:00
if err := writer.Write(record); err != nil {
return c, err
}
c++
r.lastLines[obj] = record
2022-05-10 04:44:51 +00:00
}
writer.Flush()
}
return c, nil
}
func (r *StateRecorder) Scan(instance Instance) error {
r.strategies = append(r.strategies, instance)
rt := reflect.TypeOf(instance)
rv := reflect.ValueOf(instance)
if rt.Kind() == reflect.Ptr {
rt = rt.Elem()
rv = rv.Elem()
}
if rt.Kind() != reflect.Struct {
return fmt.Errorf("given object is not a struct: %+v", rt)
}
for i := 0; i < rt.NumField(); i++ {
structField := rt.Field(i)
2022-05-10 05:25:03 +00:00
if !structField.IsExported() {
continue
}
2022-05-10 04:44:51 +00:00
obj := rv.Field(i).Interface()
switch o := obj.(type) {
case types.CsvFormatter: // interface type
typeName := strings.ToLower(structField.Type.Elem().Name())
if typeName == "" {
return fmt.Errorf("%v is a non-defined type", structField.Type)
}
2022-05-10 05:25:03 +00:00
if err := r.newCsvWriter(o, instance, typeName); err != nil {
2022-05-10 04:44:51 +00:00
return err
}
}
}
return nil
}
func (r *StateRecorder) formatCsvFilename(instance Instance, objType string) string {
2022-05-12 14:12:31 +00:00
return filepath.Join(r.outputDirectory, fmt.Sprintf("%s-%s.tsv", instance.InstanceID(), objType))
}
// Manifests returns the map from each recorded property to the path of its
// TSV file. The recorder's own map is returned without copying.
func (r *StateRecorder) Manifests() Manifests { return r.manifests }
2022-05-10 05:25:03 +00:00
func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error {
fn := r.formatCsvFilename(instance, typeName)
2022-05-12 14:12:31 +00:00
w, err := tsv.NewWriterFile(fn)
2022-05-10 05:25:03 +00:00
if err != nil {
return err
}
r.manifests[InstancePropertyIndex{
ID: instance.ID(),
InstanceID: instance.InstanceID(),
Property: typeName,
}] = fn
2022-05-10 05:25:03 +00:00
r.writers[o] = w
return w.Write(o.CsvHeader())
}
2022-05-10 04:44:51 +00:00
func (r *StateRecorder) Close() error {
var err error
2022-05-12 14:12:31 +00:00
for _, w := range r.writers {
err2 := w.Close()
2022-05-10 04:44:51 +00:00
if err2 != nil {
err = multierr.Append(err, err2)
}
}
return err
}
// equalStringSlice reports whether a and b have the same length and hold
// the same strings in the same order. A nil slice equals an empty one.
func equalStringSlice(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i, s := range a {
		if s != b[i] {
			return false
		}
	}
	return true
}