bbgo_origin/pkg/strategy/grid2/grid_recover_test.go
2023-10-24 13:03:14 +08:00

296 lines
6.1 KiB
Go

package grid2
import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"sort"
"strconv"
"testing"
"time"
"github.com/c9s/bbgo/pkg/bbgo"
"github.com/c9s/bbgo/pkg/types"
"github.com/stretchr/testify/assert"
)
// TestData bundles everything one recovery test case needs: the market
// definition, the strategy configuration, and the recorded orders and trades.
// readSpec fills Market and Strategy; readTestDataFrom fills the remaining
// fields from the CSV fixtures.
type TestData struct {
// Market and Strategy are both decoded from the same spec document.
Market types.Market `json:"market" yaml:"market"`
Strategy Strategy `json:"strategy" yaml:"strategy"`
// OpenOrders and ClosedOrders are loaded from open_orders.csv and
// closed_orders.csv respectively.
OpenOrders []types.Order `json:"openOrders" yaml:"openOrders"`
ClosedOrders []types.Order `json:"closedOrders" yaml:"closedOrders"`
// Trades is loaded from trades.csv.
Trades []types.Trade `json:"trades" yaml:"trades"`
}
// TestDataService is an in-memory stand-in for the exchange query services
// used by the recovery test; it answers queries from fixture data only.
type TestDataService struct {
// Orders is keyed by the base-10 string form of the order ID
// (see NewTestDataService).
Orders map[string]types.Order
// Trades is served by QueryTrades in stored order; NewTestDataService
// stores it sorted by ascending trade ID.
Trades []types.Trade
}
// QueryTrades returns up to options.Limit trades from the in-memory fixture,
// in the order they are stored in t.Trades.
//
// When options.LastTradeID is non-zero, the result starts at the first stored
// trade whose ID is greater than or equal to it. If no stored trade qualifies,
// the result is empty. When options.LastTradeID is zero, the result starts at
// the first stored trade.
//
// ctx and symbol are ignored. options must be non-nil because it is
// dereferenced unconditionally. The returned error is always nil.
func (t *TestDataService) QueryTrades(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) {
	start := 0
	if options.LastTradeID != 0 {
		// Default to "past the end" so that a LastTradeID beyond every stored
		// trade yields an empty page instead of silently restarting at index 0.
		start = len(t.Trades)
		for idx, trade := range t.Trades {
			if trade.ID >= options.LastTradeID {
				start = idx
				break
			}
		}
	}

	var trades []types.Trade
	for i := start; i < len(t.Trades) && len(trades) < int(options.Limit); i++ {
		trades = append(trades, t.Trades[i])
	}
	return trades, nil
}
// QueryOrder looks up a single order in the fixture by q.OrderID.
//
// It returns an error when q.OrderID is empty or when no order is stored
// under that ID. On success it returns a pointer to a copy of the stored
// order, so callers cannot modify the fixture map through it. ctx is ignored.
func (t *TestDataService) QueryOrder(ctx context.Context, q types.OrderQuery) (*types.Order, error) {
	id := q.OrderID
	if id == "" {
		return nil, fmt.Errorf("order id should not be empty")
	}

	if found, ok := t.Orders[id]; ok {
		return &found, nil
	}
	return nil, fmt.Errorf("order not found")
}
// QueryClosedOrders is a no-op stub that exists only so TestDataService
// satisfies the service interface it is passed as; it ignores every argument
// and always returns (nil, nil).
func (t *TestDataService) QueryClosedOrders(ctx context.Context, symbol string, since, until time.Time, lastOrderID uint64) (orders []types.Order, err error) {
return nil, nil
}
// QueryOrderTrades is a no-op stub that exists only so TestDataService
// satisfies the service interface it is passed as; it ignores every argument
// and always returns (nil, nil).
func (t *TestDataService) QueryOrderTrades(ctx context.Context, q types.OrderQuery) ([]types.Trade, error) {
return nil, nil
}
// NewStrategy builds the Strategy under test from the fixture.
//
// It works on a value copy of t.Strategy (the fixture's own Strategy field is
// not modified by the assignments below), enables Debug, calls Initialize,
// and then attaches the fixture market, a fresh position derived from that
// market, and a general order executor backed by an empty ExchangeSession.
// NOTE(review): this assumes Strategy is safe to copy by value — confirm it
// holds no locks or other no-copy fields.
func NewStrategy(t *TestData) *Strategy {
	strategy := t.Strategy
	strategy.Debug = true
	strategy.Initialize()

	market := t.Market
	strategy.Market = market
	strategy.Position = types.NewPositionFromMarket(market)

	// The executor is created last because it needs the instance ID and the
	// position that were set up above.
	session := &bbgo.ExchangeSession{}
	strategy.orderExecutor = bbgo.NewGeneralOrderExecutor(session, market.Symbol, ID, strategy.InstanceID(), strategy.Position)

	return &strategy
}
// NewTestDataService builds the in-memory query service for a test case.
//
// Open and closed orders are indexed together under the base-10 string form
// of their OrderID; if an ID appears in both lists the closed order wins
// because it is inserted last. Trades are sorted by ascending ID. The sort is
// done in place on the slice shared with t.Trades, so the fixture's own trade
// order changes as well.
func NewTestDataService(t *TestData) *TestDataService {
	index := make(map[string]types.Order, len(t.OpenOrders)+len(t.ClosedOrders))
	for _, group := range [][]types.Order{t.OpenOrders, t.ClosedOrders} {
		for _, o := range group {
			index[strconv.FormatUint(o.OrderID, 10)] = o
		}
	}

	sorted := t.Trades
	sort.Slice(sorted, func(a, b int) bool {
		return sorted[a].ID < sorted[b].ID
	})

	return &TestDataService{Orders: index, Trades: sorted}
}
// readSpec loads the spec file of a test case.
//
// The whole file is decoded as JSON twice: once into the Market field and
// once into the Strategy field, each picking up the keys it declares. The
// order and trade fields of the returned TestData are left empty. Any read or
// decode error is returned as-is.
func readSpec(fileName string) (*TestData, error) {
	raw, err := ioutil.ReadFile(fileName)
	if err != nil {
		return nil, err
	}

	spec := &TestData{}
	if err = json.Unmarshal(raw, &spec.Market); err != nil {
		return nil, err
	}
	if err = json.Unmarshal(raw, &spec.Strategy); err != nil {
		return nil, err
	}
	return spec, nil
}
// readOrdersFromCSV loads orders from a CSV file whose first record is a
// header of field names.
//
// Each following record is turned into a JSON object keyed by the header
// names and decoded into a types.Order. The "orderID" column is parsed as an
// unsigned integer so it is encoded as a JSON number; every other column is
// passed through as a string. The first read, parse, or decode error aborts
// the load and is returned.
func readOrdersFromCSV(fileName string) ([]types.Order, error) {
	f, err := os.Open(fileName)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	reader := csv.NewReader(f)
	header, err := reader.Read()
	if err != nil {
		return nil, err
	}

	var result []types.Order
	for {
		record, readErr := reader.Read()
		switch {
		case readErr == io.EOF:
			// End of file: everything read so far is the result.
			return result, nil
		case readErr != nil:
			return nil, readErr
		case len(record) != len(header):
			return nil, fmt.Errorf("length of row should be equal to length of keys")
		}

		fields := make(map[string]interface{}, len(header))
		for col, name := range header {
			if name != "orderID" {
				fields[name] = record[col]
				continue
			}
			id, err := strconv.ParseUint(record[col], 10, 64)
			if err != nil {
				return nil, err
			}
			fields[name] = id
		}

		// Round-trip through JSON so the order's own JSON decoding rules
		// apply to the string values.
		payload, err := json.Marshal(fields)
		if err != nil {
			return nil, err
		}
		var order types.Order
		if err := json.Unmarshal(payload, &order); err != nil {
			return nil, err
		}
		result = append(result, order)
	}
}
// readTradesFromCSV loads trades from a CSV file whose first record is a
// header of field names.
//
// Each following record is turned into a JSON object keyed by the header
// names and decoded into a types.Trade. The "id" and "orderID" columns are
// parsed as unsigned integers so they are encoded as JSON numbers; every
// other column is passed through as a string. The first read, parse, or
// decode error aborts the load and is returned.
func readTradesFromCSV(fileName string) ([]types.Trade, error) {
	file, err := os.Open(fileName)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	r := csv.NewReader(file)
	columns, err := r.Read()
	if err != nil {
		return nil, err
	}

	var out []types.Trade
	for {
		values, err := r.Read()
		if err != nil {
			if err == io.EOF {
				break
			}
			return nil, err
		}
		if len(values) != len(columns) {
			return nil, fmt.Errorf("length of row should be equal to length of keys")
		}

		record := make(map[string]interface{}, len(columns))
		for idx, column := range columns {
			if column == "id" || column == "orderID" {
				n, err := strconv.ParseUint(values[idx], 10, 64)
				if err != nil {
					return nil, err
				}
				record[column] = n
				continue
			}
			record[column] = values[idx]
		}

		// Round-trip through JSON so the trade's own JSON decoding rules
		// apply to the string values.
		encoded, err := json.Marshal(record)
		if err != nil {
			return nil, err
		}
		var trade types.Trade
		if err := json.Unmarshal(encoded, &trade); err != nil {
			return nil, err
		}
		out = append(out, trade)
	}
	return out, nil
}
// readTestDataFrom assembles a complete TestData from a test-case directory.
//
// It reads, in this order, "spec", "open_orders.csv", "closed_orders.csv" and
// "trades.csv" under fileDir, and stops at the first file that fails to load,
// returning that error and a nil TestData.
func readTestDataFrom(fileDir string) (*TestData, error) {
	data, err := readSpec(fmt.Sprintf("%s/spec", fileDir))
	if err != nil {
		return nil, err
	}

	// The partially filled data is discarded on any later error, so the
	// fields can be populated as each file is read.
	if data.OpenOrders, err = readOrdersFromCSV(fmt.Sprintf("%s/open_orders.csv", fileDir)); err != nil {
		return nil, err
	}
	if data.ClosedOrders, err = readOrdersFromCSV(fmt.Sprintf("%s/closed_orders.csv", fileDir)); err != nil {
		return nil, err
	}
	if data.Trades, err = readTradesFromCSV(fmt.Sprintf("%s/trades.csv", fileDir)); err != nil {
		return nil, err
	}
	return data, nil
}
// TestRecoverByScanningTrades replays recorded fixtures through
// getFilledOrdersByScanningTrades and checks the filled orders it reports.
func TestRecoverByScanningTrades(t *testing.T) {
	t.Run("test case 1", func(t *testing.T) {
		// Bind the assertions to the subtest's own *testing.T so failures are
		// reported against "test case 1" rather than the parent test. The
		// short name also avoids shadowing the assert package.
		a := assert.New(t)

		fileDir := "recovery_testcase/testcase1/"
		data, err := readTestDataFrom(fileDir)
		if !a.NoError(err) {
			return
		}

		// The same in-memory service is passed for both service arguments.
		testService := NewTestDataService(data)
		strategy := NewStrategy(data)

		filledOrders, err := strategy.getFilledOrdersByScanningTrades(context.Background(), testService, testService, data.OpenOrders)
		if !a.NoError(err) {
			return
		}

		// This fixture expects no filled orders to be reported.
		a.Len(filledOrders, 0)
	})
}