mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-22 14:55:16 +00:00
implement state recorder
This commit is contained in:
parent
2e5b818a75
commit
185a8279b2
118
pkg/backtest/recorder.go
Normal file
118
pkg/backtest/recorder.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
package backtest
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
type Instance interface {
|
||||
ID() string
|
||||
InstanceID() string
|
||||
}
|
||||
|
||||
type InstanceObject struct {
|
||||
InstanceID string
|
||||
}
|
||||
|
||||
type StateRecorder struct {
|
||||
outputDirectory string
|
||||
strategies []Instance
|
||||
files map[interface{}]*os.File
|
||||
writers map[types.CsvFormatter]*csv.Writer
|
||||
}
|
||||
|
||||
func NewStateRecorder(outputDir string) *StateRecorder {
|
||||
return &StateRecorder{
|
||||
outputDirectory: outputDir,
|
||||
files: map[interface{}]*os.File{},
|
||||
writers: map[types.CsvFormatter]*csv.Writer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *StateRecorder) formatFilename(instance Instance, objType string) string {
|
||||
return filepath.Join(r.outputDirectory, fmt.Sprintf("%s-%s.csv", instance.InstanceID(), objType))
|
||||
}
|
||||
|
||||
func (r *StateRecorder) openFile(instance Instance, objType string) (*os.File, error) {
|
||||
fn := r.formatFilename(instance, objType)
|
||||
return os.Create(fn)
|
||||
}
|
||||
|
||||
func (r *StateRecorder) Snapshot() (int, error) {
|
||||
var c int
|
||||
for obj, writer := range r.writers {
|
||||
records := obj.CsvRecords()
|
||||
for _, record := range records {
|
||||
if err := writer.Write(record); err != nil {
|
||||
return c, err
|
||||
}
|
||||
c++
|
||||
}
|
||||
|
||||
writer.Flush()
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (r *StateRecorder) Scan(instance Instance) error {
|
||||
r.strategies = append(r.strategies, instance)
|
||||
|
||||
rt := reflect.TypeOf(instance)
|
||||
rv := reflect.ValueOf(instance)
|
||||
if rt.Kind() == reflect.Ptr {
|
||||
rt = rt.Elem()
|
||||
rv = rv.Elem()
|
||||
}
|
||||
|
||||
if rt.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("given object is not a struct: %+v", rt)
|
||||
}
|
||||
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
structField := rt.Field(i)
|
||||
obj := rv.Field(i).Interface()
|
||||
switch o := obj.(type) {
|
||||
|
||||
case types.CsvFormatter: // interface type
|
||||
typeName := strings.ToLower(structField.Type.Elem().Name())
|
||||
if typeName == "" {
|
||||
return fmt.Errorf("%v is a non-defined type", structField.Type)
|
||||
}
|
||||
|
||||
f, err := r.openFile(instance, typeName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, exists := r.files[o]; exists {
|
||||
return fmt.Errorf("file of object %v already exists", o)
|
||||
}
|
||||
|
||||
r.files[o] = f
|
||||
r.writers[o] = csv.NewWriter(f)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *StateRecorder) Close() error {
|
||||
var err error
|
||||
|
||||
for _, f := range r.files {
|
||||
err2 := f.Close()
|
||||
if err2 != nil {
|
||||
err = multierr.Append(err, err2)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
65
pkg/backtest/recorder_test.go
Normal file
65
pkg/backtest/recorder_test.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package backtest
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/c9s/bbgo/pkg/fixedpoint"
|
||||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
type testStrategy struct {
|
||||
Symbol string
|
||||
|
||||
Position *types.Position
|
||||
}
|
||||
|
||||
func (s *testStrategy) ID() string { return "my-test" }
|
||||
func (s *testStrategy) InstanceID() string { return "my-test:" + s.Symbol }
|
||||
|
||||
func TestStateRecorder(t *testing.T) {
|
||||
tmpDir, _ := os.MkdirTemp(os.TempDir(), "bbgo")
|
||||
t.Logf("tmpDir: %s", tmpDir)
|
||||
|
||||
st := &testStrategy{
|
||||
Symbol: "BTCUSDT",
|
||||
Position: types.NewPosition("BTCUSDT", "BTC", "USDT"),
|
||||
}
|
||||
|
||||
recorder := NewStateRecorder(tmpDir)
|
||||
err := recorder.Scan(st)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, recorder.writers, 1)
|
||||
|
||||
n, err := recorder.Snapshot()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, n)
|
||||
|
||||
st.Position.AddTrade(types.Trade{
|
||||
OrderID: 1,
|
||||
Exchange: types.ExchangeBinance,
|
||||
Price: fixedpoint.NewFromFloat(18000.0),
|
||||
Quantity: fixedpoint.NewFromFloat(1.0),
|
||||
QuoteQuantity: fixedpoint.NewFromFloat(18000.0),
|
||||
Symbol: "BTCUSDT",
|
||||
Side: types.SideTypeBuy,
|
||||
IsBuyer: true,
|
||||
IsMaker: false,
|
||||
Time: types.Time(time.Now()),
|
||||
Fee: fixedpoint.NewFromFloat(0.00001),
|
||||
FeeCurrency: "BNB",
|
||||
IsMargin: false,
|
||||
IsFutures: false,
|
||||
IsIsolated: false,
|
||||
})
|
||||
|
||||
n, err = recorder.Snapshot()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, n)
|
||||
|
||||
err = recorder.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
|
@ -60,6 +60,30 @@ type Position struct {
|
|||
sync.Mutex
|
||||
}
|
||||
|
||||
func (p *Position) CsvHeader() []string {
|
||||
return []string{
|
||||
"symbol",
|
||||
"time",
|
||||
"average_cost",
|
||||
"base",
|
||||
"quote",
|
||||
"accumulated_profit",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Position) CsvRecords() [][]string {
|
||||
return [][]string{
|
||||
{
|
||||
p.Symbol,
|
||||
p.ChangedAt.Format(time.RFC1123),
|
||||
p.AverageCost.String(),
|
||||
p.Base.String(),
|
||||
p.Quote.String(),
|
||||
p.AccumulatedProfit.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewProfit generates the profit object from the current position
|
||||
func (p *Position) NewProfit(trade Trade, profit, netProfit fixedpoint.Value) Profit {
|
||||
return Profit{
|
||||
|
|
|
@ -74,10 +74,34 @@ type Trade struct {
|
|||
// The following fields are null-able fields
|
||||
|
||||
// StrategyID is the strategy that execute this trade
|
||||
StrategyID sql.NullString `json:"strategyID" db:"strategy"`
|
||||
StrategyID sql.NullString `json:"strategyID" db:"strategy"`
|
||||
|
||||
// PnL is the profit and loss value of the executed trade
|
||||
PnL sql.NullFloat64 `json:"pnl" db:"pnl"`
|
||||
PnL sql.NullFloat64 `json:"pnl" db:"pnl"`
|
||||
}
|
||||
|
||||
func (trade Trade) CsvHeader() []string {
|
||||
return []string{"id", "order_id", "exchange", "symbol", "price", "quantity", "quote_quantity", "side", "is_buyer", "is_maker", "fee", "fee_currency", "time"}
|
||||
}
|
||||
|
||||
func (trade Trade) CsvRecords() [][]string {
|
||||
return [][]string{
|
||||
{
|
||||
strconv.FormatUint(trade.ID, 10),
|
||||
strconv.FormatUint(trade.OrderID, 10),
|
||||
trade.Exchange.String(),
|
||||
trade.Symbol,
|
||||
trade.Price.String(),
|
||||
trade.Quantity.String(),
|
||||
trade.QuoteQuantity.String(),
|
||||
trade.Side.String(),
|
||||
strconv.FormatBool(trade.IsBuyer),
|
||||
strconv.FormatBool(trade.IsMaker),
|
||||
trade.Fee.String(),
|
||||
trade.FeeCurrency,
|
||||
trade.Time.Time().Format(time.RFC1123),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (trade Trade) PositionChange() fixedpoint.Value {
|
||||
|
|
Loading…
Reference in New Issue
Block a user