qbtrade/pkg/backtest/recorder.go
2024-06-27 22:42:38 +08:00

157 lines
3.1 KiB
Go

package backtest
import (
"fmt"
"path/filepath"
"reflect"
"strings"
"go.uber.org/multierr"
"git.qtrade.icu/lychiyu/qbtrade/pkg/data/tsv"
"git.qtrade.icu/lychiyu/qbtrade/pkg/types"
)
// Instance is a recordable object (typically a strategy) that can be
// identified both by its ID and by its instance ID.
type Instance interface {
	// InstanceID returns the identifier of this particular instance; it is
	// used as the prefix of the TSV file names created by the recorder.
	InstanceID() string

	// ID returns the identifier stored alongside InstanceID in the manifests.
	ID() string
}
// InstancePropertyIndex is the manifest key for one recorded file: the
// owning instance's ID and InstanceID, plus the lower-cased type name of the
// recorded property.
type InstancePropertyIndex struct {
	ID, InstanceID, Property string
}
// StateRecorder writes the CSV records of objects found on scanned instances
// into per-object TSV files and keeps a manifest of the files it created.
// Create it with NewStateRecorder.
type StateRecorder struct {
	// outputDirectory is the directory the TSV files are created in.
	outputDirectory string
	// strategies holds the instances passed to Scan.
	strategies []Instance
	// writers maps each recorded object to the writer of its TSV file.
	writers map[types.CsvFormatter]*tsv.Writer
	// lastLines holds, per object, the last record Snapshot wrote; Snapshot
	// uses it to skip writing an identical record again.
	lastLines map[types.CsvFormatter][]string
	// manifests maps (ID, InstanceID, property) to the created file path.
	manifests Manifests
}
// NewStateRecorder returns a recorder that creates its TSV files under
// outputDir, with all of its internal maps initialized and empty.
func NewStateRecorder(outputDir string) *StateRecorder {
	rec := &StateRecorder{outputDirectory: outputDir}
	rec.writers = make(map[types.CsvFormatter]*tsv.Writer)
	rec.lastLines = make(map[types.CsvFormatter][]string)
	rec.manifests = make(Manifests)
	return rec
}
// Snapshot appends the current CSV records of every registered object to that
// object's writer, then flushes the writer. A record identical to the last
// record written for the same object is skipped, so unchanged state does not
// produce duplicate consecutive rows. Objects are visited in map order, which
// is unspecified.
//
// It returns the number of records written. On a write error it stops
// immediately and returns the count written so far together with the error;
// the failing writer is not flushed in that case.
func (r *StateRecorder) Snapshot() (int, error) {
	var c int
	for obj, writer := range r.writers {
		lastLine, hasLastLine := r.lastLines[obj]
		for _, record := range obj.CsvRecords() {
			if hasLastLine && equalStringSlice(lastLine, record) {
				continue
			}

			if err := writer.Write(record); err != nil {
				return c, err
			}

			c++
			// Refresh the local copy too: the next record must be compared with
			// the line just written, not with the previous snapshot's last line.
			lastLine, hasLastLine = record, true
			r.lastLines[obj] = record
		}

		writer.Flush()
	}
	return c, nil
}
// Scan registers instance with the recorder and opens a TSV writer (via
// newCsvWriter) for every exported field of instance whose value implements
// types.CsvFormatter.
//
// instance must be a struct or a non-nil pointer to a struct; otherwise an
// error is returned and the instance is not registered. The file name of each
// writer uses the lower-cased name of the field's element type for pointer,
// slice, array, map and channel fields. Fields of any other kind (for example
// a struct value or an interface-typed field) are rejected with an error,
// because they have no element type to name the file after.
//
// Nil pointer fields are skipped. The recorder keeps the value it sees at scan
// time, so a nil pointer would stay nil and never yield records.
func (r *StateRecorder) Scan(instance Instance) error {
	if instance == nil {
		return fmt.Errorf("given object is nil")
	}

	rt := reflect.TypeOf(instance)
	rv := reflect.ValueOf(instance)
	if rt.Kind() == reflect.Ptr {
		// rv.Elem() of a nil pointer is an invalid Value; rv.Field would panic.
		if rv.IsNil() {
			return fmt.Errorf("given object is a nil pointer: %+v", rt)
		}
		rt = rt.Elem()
		rv = rv.Elem()
	}

	if rt.Kind() != reflect.Struct {
		return fmt.Errorf("given object is not a struct: %+v", rt)
	}

	// Only remember instances that passed validation.
	r.strategies = append(r.strategies, instance)

	for i := 0; i < rt.NumField(); i++ {
		structField := rt.Field(i)
		if !structField.IsExported() {
			continue
		}

		fieldValue := rv.Field(i)
		o, ok := fieldValue.Interface().(types.CsvFormatter)
		if !ok {
			continue
		}

		// Resolve the type whose name goes into the file name. reflect's
		// Type.Elem panics for kinds without an element type, so those kinds
		// are rejected explicitly instead of crashing.
		fieldType := structField.Type
		switch fieldType.Kind() {
		case reflect.Ptr:
			if fieldValue.IsNil() {
				continue
			}
			fieldType = fieldType.Elem()
		case reflect.Slice, reflect.Array, reflect.Map, reflect.Chan:
			fieldType = fieldType.Elem()
		default:
			return fmt.Errorf("field %s: unsupported csv formatter field type %v", structField.Name, structField.Type)
		}

		typeName := strings.ToLower(fieldType.Name())
		if typeName == "" {
			return fmt.Errorf("%v is a non-defined type", structField.Type)
		}

		if err := r.newCsvWriter(o, instance, typeName); err != nil {
			return err
		}
	}

	return nil
}
// formatCsvFilename returns the path of the TSV file that records objType for
// instance: "<InstanceID>-<objType>.tsv" inside the recorder's output
// directory.
func (r *StateRecorder) formatCsvFilename(instance Instance, objType string) string {
	base := instance.InstanceID() + "-" + objType + ".tsv"
	return filepath.Join(r.outputDirectory, base)
}
// Manifests returns the recorder's manifest map, which associates each
// (ID, InstanceID, property) key with the file created for it. The map is
// returned as-is, not copied.
func (r *StateRecorder) Manifests() Manifests {
	m := r.manifests
	return m
}
// newCsvWriter opens the TSV file for object o of instance and performs these steps:
//   - records the file path in the manifests under (instance.ID(),
//     instance.InstanceID(), typeName);
//   - registers the writer for o;
//   - writes o's CSV header to the writer.
//
// If o already has a registered writer (for example when the same object is
// scanned twice), that writer is closed and its last-line state is discarded
// first. Otherwise replacing the map entry would leak the previous open file
// and drop its buffered output.
func (r *StateRecorder) newCsvWriter(o types.CsvFormatter, instance Instance, typeName string) error {
	if prev, ok := r.writers[o]; ok {
		delete(r.writers, o)
		delete(r.lastLines, o)
		if err := prev.Close(); err != nil {
			return err
		}
	}

	fn := r.formatCsvFilename(instance, typeName)
	w, err := tsv.NewWriterFile(fn)
	if err != nil {
		return err
	}

	r.manifests[InstancePropertyIndex{
		ID:         instance.ID(),
		InstanceID: instance.InstanceID(),
		Property:   typeName,
	}] = fn

	r.writers[o] = w
	return w.Write(o.CsvHeader())
}
// Close closes every writer the recorder opened and returns the close errors
// combined with multierr, or nil if every writer closed cleanly. All writers
// are closed even if some of them fail.
func (r *StateRecorder) Close() error {
	var errs error
	for _, writer := range r.writers {
		// multierr.Append ignores a nil operand, so clean closes add nothing.
		errs = multierr.Append(errs, writer.Close())
	}
	return errs
}
// equalStringSlice reports whether a and b have the same length and hold the
// same strings in the same order. A nil slice equals an empty slice.
func equalStringSlice(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, s := range a {
		if b[idx] != s {
			return false
		}
	}
	return true
}