mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-25 08:15:15 +00:00
295 lines
6.1 KiB
Go
295 lines
6.1 KiB
Go
package grid2
|
|
|
|
import (
|
|
"context"
|
|
"encoding/csv"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/c9s/bbgo/pkg/bbgo"
|
|
"github.com/c9s/bbgo/pkg/types"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
type TestData struct {
|
|
Market types.Market `json:"market" yaml:"market"`
|
|
Strategy Strategy `json:"strategy" yaml:"strategy"`
|
|
OpenOrders []types.Order `json:"openOrders" yaml:"openOrders"`
|
|
ClosedOrders []types.Order `json:"closedOrders" yaml:"closedOrders"`
|
|
Trades []types.Trade `json:"trades" yaml:"trades"`
|
|
}
|
|
|
|
type TestDataService struct {
|
|
Orders map[string]types.Order
|
|
Trades []types.Trade
|
|
}
|
|
|
|
func (t *TestDataService) QueryTrades(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) {
|
|
var i int = 0
|
|
if options.LastTradeID != 0 {
|
|
for idx, trade := range t.Trades {
|
|
if trade.ID < options.LastTradeID {
|
|
continue
|
|
}
|
|
|
|
i = idx
|
|
break
|
|
}
|
|
}
|
|
|
|
var trades []types.Trade
|
|
l := len(t.Trades)
|
|
for ; i < l && len(trades) < int(options.Limit); i++ {
|
|
trades = append(trades, t.Trades[i])
|
|
}
|
|
|
|
return trades, nil
|
|
}
|
|
|
|
func (t *TestDataService) QueryOrder(ctx context.Context, q types.OrderQuery) (*types.Order, error) {
|
|
if len(q.OrderID) == 0 {
|
|
return nil, fmt.Errorf("order id should not be empty")
|
|
}
|
|
|
|
order, exist := t.Orders[q.OrderID]
|
|
if !exist {
|
|
return nil, fmt.Errorf("order not found")
|
|
}
|
|
|
|
return &order, nil
|
|
}
|
|
|
|
// dummy method for interface
|
|
func (t *TestDataService) QueryClosedOrders(ctx context.Context, symbol string, since, until time.Time, lastOrderID uint64) (orders []types.Order, err error) {
|
|
return nil, nil
|
|
}
|
|
|
|
// dummy method for interface
|
|
func (t *TestDataService) QueryOrderTrades(ctx context.Context, q types.OrderQuery) ([]types.Trade, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func NewStrategy(t *TestData) *Strategy {
|
|
s := t.Strategy
|
|
s.Debug = true
|
|
s.Initialize()
|
|
s.Market = t.Market
|
|
s.Position = types.NewPositionFromMarket(t.Market)
|
|
s.orderExecutor = bbgo.NewGeneralOrderExecutor(&bbgo.ExchangeSession{}, t.Market.Symbol, ID, s.InstanceID(), s.Position)
|
|
return &s
|
|
}
|
|
|
|
func NewTestDataService(t *TestData) *TestDataService {
|
|
var orders map[string]types.Order = make(map[string]types.Order)
|
|
for _, order := range t.OpenOrders {
|
|
orders[strconv.FormatUint(order.OrderID, 10)] = order
|
|
}
|
|
|
|
for _, order := range t.ClosedOrders {
|
|
orders[strconv.FormatUint(order.OrderID, 10)] = order
|
|
}
|
|
|
|
trades := t.Trades
|
|
sort.Slice(t.Trades, func(i, j int) bool {
|
|
return trades[i].ID < trades[j].ID
|
|
})
|
|
|
|
return &TestDataService{
|
|
Orders: orders,
|
|
Trades: trades,
|
|
}
|
|
}
|
|
|
|
func readSpec(fileName string) (*TestData, error) {
|
|
content, err := os.ReadFile(fileName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
market := types.Market{}
|
|
if err := json.Unmarshal(content, &market); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
strategy := Strategy{}
|
|
if err := json.Unmarshal(content, &strategy); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
data := TestData{
|
|
Market: market,
|
|
Strategy: strategy,
|
|
}
|
|
return &data, nil
|
|
}
|
|
|
|
func readOrdersFromCSV(fileName string) ([]types.Order, error) {
|
|
csvFile, err := os.Open(fileName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer csvFile.Close()
|
|
csvReader := csv.NewReader(csvFile)
|
|
|
|
keys, err := csvReader.Read()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var orders []types.Order
|
|
for {
|
|
row, err := csvReader.Read()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(row) != len(keys) {
|
|
return nil, fmt.Errorf("length of row should be equal to length of keys")
|
|
}
|
|
|
|
var m map[string]interface{} = make(map[string]interface{})
|
|
for i, key := range keys {
|
|
if key == "orderID" {
|
|
x, err := strconv.ParseUint(row[i], 10, 64)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m[key] = x
|
|
} else {
|
|
m[key] = row[i]
|
|
}
|
|
}
|
|
|
|
b, err := json.Marshal(m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
order := types.Order{}
|
|
if err = json.Unmarshal(b, &order); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
orders = append(orders, order)
|
|
}
|
|
|
|
return orders, nil
|
|
}
|
|
|
|
func readTradesFromCSV(fileName string) ([]types.Trade, error) {
|
|
csvFile, err := os.Open(fileName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer csvFile.Close()
|
|
csvReader := csv.NewReader(csvFile)
|
|
|
|
keys, err := csvReader.Read()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var trades []types.Trade
|
|
for {
|
|
row, err := csvReader.Read()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(row) != len(keys) {
|
|
return nil, fmt.Errorf("length of row should be equal to length of keys")
|
|
}
|
|
|
|
var m map[string]interface{} = make(map[string]interface{})
|
|
for i, key := range keys {
|
|
switch key {
|
|
case "id", "orderID":
|
|
x, err := strconv.ParseUint(row[i], 10, 64)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m[key] = x
|
|
default:
|
|
m[key] = row[i]
|
|
}
|
|
}
|
|
|
|
b, err := json.Marshal(m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
trade := types.Trade{}
|
|
if err = json.Unmarshal(b, &trade); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
trades = append(trades, trade)
|
|
}
|
|
|
|
return trades, nil
|
|
}
|
|
|
|
func readTestDataFrom(fileDir string) (*TestData, error) {
|
|
data, err := readSpec(fmt.Sprintf("%s/spec", fileDir))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
openOrders, err := readOrdersFromCSV(fmt.Sprintf("%s/open_orders.csv", fileDir))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
closedOrders, err := readOrdersFromCSV(fmt.Sprintf("%s/closed_orders.csv", fileDir))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
trades, err := readTradesFromCSV(fmt.Sprintf("%s/trades.csv", fileDir))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
data.OpenOrders = openOrders
|
|
data.ClosedOrders = closedOrders
|
|
data.Trades = trades
|
|
return data, nil
|
|
}
|
|
|
|
func TestRecoverByScanningTrades(t *testing.T) {
|
|
assert := assert.New(t)
|
|
|
|
t.Run("test case 1", func(t *testing.T) {
|
|
fileDir := "recovery_testcase/testcase1/"
|
|
|
|
data, err := readTestDataFrom(fileDir)
|
|
if !assert.NoError(err) {
|
|
return
|
|
}
|
|
|
|
testService := NewTestDataService(data)
|
|
strategy := NewStrategy(data)
|
|
filledOrders, err := strategy.getFilledOrdersByScanningTrades(context.Background(), testService, testService, data.OpenOrders)
|
|
if !assert.NoError(err) {
|
|
return
|
|
}
|
|
|
|
assert.Len(filledOrders, 0)
|
|
})
|
|
}
|