implement state recorder

This commit is contained in:
c9s 2022-05-10 12:44:51 +08:00
parent 2e5b818a75
commit 185a8279b2
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
4 changed files with 233 additions and 2 deletions

118
pkg/backtest/recorder.go Normal file
View File

@ -0,0 +1,118 @@
package backtest
import (
"encoding/csv"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"go.uber.org/multierr"
"github.com/c9s/bbgo/pkg/types"
)
type Instance interface {
ID() string
InstanceID() string
}
type InstanceObject struct {
InstanceID string
}
type StateRecorder struct {
outputDirectory string
strategies []Instance
files map[interface{}]*os.File
writers map[types.CsvFormatter]*csv.Writer
}
func NewStateRecorder(outputDir string) *StateRecorder {
return &StateRecorder{
outputDirectory: outputDir,
files: map[interface{}]*os.File{},
writers: map[types.CsvFormatter]*csv.Writer{},
}
}
func (r *StateRecorder) formatFilename(instance Instance, objType string) string {
return filepath.Join(r.outputDirectory, fmt.Sprintf("%s-%s.csv", instance.InstanceID(), objType))
}
func (r *StateRecorder) openFile(instance Instance, objType string) (*os.File, error) {
fn := r.formatFilename(instance, objType)
return os.Create(fn)
}
func (r *StateRecorder) Snapshot() (int, error) {
var c int
for obj, writer := range r.writers {
records := obj.CsvRecords()
for _, record := range records {
if err := writer.Write(record); err != nil {
return c, err
}
c++
}
writer.Flush()
}
return c, nil
}
func (r *StateRecorder) Scan(instance Instance) error {
r.strategies = append(r.strategies, instance)
rt := reflect.TypeOf(instance)
rv := reflect.ValueOf(instance)
if rt.Kind() == reflect.Ptr {
rt = rt.Elem()
rv = rv.Elem()
}
if rt.Kind() != reflect.Struct {
return fmt.Errorf("given object is not a struct: %+v", rt)
}
for i := 0; i < rt.NumField(); i++ {
structField := rt.Field(i)
obj := rv.Field(i).Interface()
switch o := obj.(type) {
case types.CsvFormatter: // interface type
typeName := strings.ToLower(structField.Type.Elem().Name())
if typeName == "" {
return fmt.Errorf("%v is a non-defined type", structField.Type)
}
f, err := r.openFile(instance, typeName)
if err != nil {
return err
}
if _, exists := r.files[o]; exists {
return fmt.Errorf("file of object %v already exists", o)
}
r.files[o] = f
r.writers[o] = csv.NewWriter(f)
}
}
return nil
}
func (r *StateRecorder) Close() error {
var err error
for _, f := range r.files {
err2 := f.Close()
if err2 != nil {
err = multierr.Append(err, err2)
}
}
return err
}

View File

@ -0,0 +1,65 @@
package backtest
import (
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types"
)
type testStrategy struct {
Symbol string
Position *types.Position
}
func (s *testStrategy) ID() string { return "my-test" }
func (s *testStrategy) InstanceID() string { return "my-test:" + s.Symbol }
func TestStateRecorder(t *testing.T) {
tmpDir, _ := os.MkdirTemp(os.TempDir(), "bbgo")
t.Logf("tmpDir: %s", tmpDir)
st := &testStrategy{
Symbol: "BTCUSDT",
Position: types.NewPosition("BTCUSDT", "BTC", "USDT"),
}
recorder := NewStateRecorder(tmpDir)
err := recorder.Scan(st)
assert.NoError(t, err)
assert.Len(t, recorder.writers, 1)
n, err := recorder.Snapshot()
assert.NoError(t, err)
assert.Equal(t, 1, n)
st.Position.AddTrade(types.Trade{
OrderID: 1,
Exchange: types.ExchangeBinance,
Price: fixedpoint.NewFromFloat(18000.0),
Quantity: fixedpoint.NewFromFloat(1.0),
QuoteQuantity: fixedpoint.NewFromFloat(18000.0),
Symbol: "BTCUSDT",
Side: types.SideTypeBuy,
IsBuyer: true,
IsMaker: false,
Time: types.Time(time.Now()),
Fee: fixedpoint.NewFromFloat(0.00001),
FeeCurrency: "BNB",
IsMargin: false,
IsFutures: false,
IsIsolated: false,
})
n, err = recorder.Snapshot()
assert.NoError(t, err)
assert.Equal(t, 1, n)
err = recorder.Close()
assert.NoError(t, err)
}

View File

@ -60,6 +60,30 @@ type Position struct {
sync.Mutex
}
func (p *Position) CsvHeader() []string {
return []string{
"symbol",
"time",
"average_cost",
"base",
"quote",
"accumulated_profit",
}
}
func (p *Position) CsvRecords() [][]string {
return [][]string{
{
p.Symbol,
p.ChangedAt.Format(time.RFC1123),
p.AverageCost.String(),
p.Base.String(),
p.Quote.String(),
p.AccumulatedProfit.String(),
},
}
}
// NewProfit generates the profit object from the current position
func (p *Position) NewProfit(trade Trade, profit, netProfit fixedpoint.Value) Profit {
return Profit{

View File

@ -74,10 +74,34 @@ type Trade struct {
// The following fields are null-able fields
// StrategyID is the strategy that execute this trade
StrategyID sql.NullString `json:"strategyID" db:"strategy"`
StrategyID sql.NullString `json:"strategyID" db:"strategy"`
// PnL is the profit and loss value of the executed trade
PnL sql.NullFloat64 `json:"pnl" db:"pnl"`
PnL sql.NullFloat64 `json:"pnl" db:"pnl"`
}
func (trade Trade) CsvHeader() []string {
return []string{"id", "order_id", "exchange", "symbol", "price", "quantity", "quote_quantity", "side", "is_buyer", "is_maker", "fee", "fee_currency", "time"}
}
func (trade Trade) CsvRecords() [][]string {
return [][]string{
{
strconv.FormatUint(trade.ID, 10),
strconv.FormatUint(trade.OrderID, 10),
trade.Exchange.String(),
trade.Symbol,
trade.Price.String(),
trade.Quantity.String(),
trade.QuoteQuantity.String(),
trade.Side.String(),
strconv.FormatBool(trade.IsBuyer),
strconv.FormatBool(trade.IsMaker),
trade.Fee.String(),
trade.FeeCurrency,
trade.Time.Time().Format(time.RFC1123),
},
}
}
func (trade Trade) PositionChange() fixedpoint.Value {