diff --git a/pkg/cmd/backtest.go b/pkg/cmd/backtest.go index df0771936..f28949869 100644 --- a/pkg/cmd/backtest.go +++ b/pkg/cmd/backtest.go @@ -3,6 +3,7 @@ package cmd import ( "bufio" "context" + "encoding/csv" "encoding/json" "fmt" "io/ioutil" @@ -286,7 +287,7 @@ var BacktestCmd = &cobra.Command{ userConfig.Backtest.EndTime.Time(), ) - var kLineHandlers []func(k types.KLine, ex *backtest.Exchange) + var kLineHandlers []func(k types.KLine, exSource *backtest.ExchangeDataSource) var manifests backtest.Manifests if generatingReport { reportDir := outputDirectory @@ -311,7 +312,7 @@ var BacktestCmd = &cobra.Command{ } // state snapshot - kLineHandlers = append(kLineHandlers, func(k types.KLine, ex *backtest.Exchange) { + kLineHandlers = append(kLineHandlers, func(k types.KLine, _ *backtest.ExchangeDataSource) { // snapshot per 1m if k.Interval == types.Interval1m && k.Closed { if _, err := stateRecorder.Snapshot(); err != nil { @@ -327,11 +328,42 @@ var BacktestCmd = &cobra.Command{ } }() - kLineHandlers = append(kLineHandlers, func(k types.KLine, ex *backtest.Exchange) { + kLineHandlers = append(kLineHandlers, func(k types.KLine, _ *backtest.ExchangeDataSource) { if err := dumper.Record(k); err != nil { log.WithError(err).Errorf("can not write kline to file") } }) + + // equity curve recording -- record per 1h kline + equityCurveFile, err := os.Create(filepath.Join(reportDir, "equity_curve.csv")) + if err != nil { + return err + } + defer func() { _ = equityCurveFile.Close() }() + + equityCurveCsv := csv.NewWriter(equityCurveFile) + _ = equityCurveCsv.Write([]string{ + "time", + "in_usd", + }) + + kLineHandlers = append(kLineHandlers, func(k types.KLine, exSource *backtest.ExchangeDataSource) { + if k.Interval != types.Interval1h { + return + } + + balances, err := exSource.Exchange.QueryAccountBalances(ctx) + if err != nil { + log.WithError(err).Errorf("query back-test account balance error") + } else { + assets := balances.Assets(exSource.Session.AllLastPrices(), k.EndTime.Time()) + _ = equityCurveCsv.Write([]string{ + k.EndTime.Time().Format(time.RFC1123), + assets.InUSD().String(), + }) + equityCurveCsv.Flush() + } + }) } runCtx, cancelRun := context.WithCancel(ctx) @@ -346,19 +378,9 @@ var BacktestCmd = &cobra.Command{ exSource.Exchange.ConsumeKLine(k) for _, h := range kLineHandlers { - h(k, exSource.Exchange) + h(k, &exSource) } - // equity curve recording - if k.Interval == types.Interval1m { - balances, err := exSource.Exchange.QueryAccountBalances(ctx) - if err != nil { - log.WithError(err).Errorf("query back-test account balance error") - } else { - assets := balances.Assets(exSource.Session.AllLastPrices(), k.EndTime.Time()) - _ = assets - } - } } if err := exSource.Exchange.CloseMarketData(); err != nil { @@ -382,7 +404,7 @@ var BacktestCmd = &cobra.Command{ exK.Exchange.ConsumeKLine(k) for _, h := range kLineHandlers { - h(k, exK.Exchange) + h(k, &exK) } } } diff --git a/pkg/types/asset.go b/pkg/types/asset.go index 4bd13183b..bbd98dbe1 100644 --- a/pkg/types/asset.go +++ b/pkg/types/asset.go @@ -11,9 +11,9 @@ import ( ) type Asset struct { - Currency string `json:"currency" db:"currency"` + Currency string `json:"currency" db:"currency"` - Total fixedpoint.Value `json:"total" db:"total"` + Total fixedpoint.Value `json:"total" db:"total"` NetAsset fixedpoint.Value `json:"netAsset" db:"net_asset"` @@ -34,6 +34,17 @@ type Asset struct { type AssetMap map[string]Asset +func (m AssetMap) InUSD() (total fixedpoint.Value) { + for _, a := range m { + if a.InUSD.IsZero() { + continue + } + + total = total.Add(a.InUSD) + } + return total +} + func (m AssetMap) PlainText() (o string) { var assets = m.Slice()