From 67a3c4908189f8a90e4768d059ce3d3296d1f118 Mon Sep 17 00:00:00 2001 From: c9s Date: Tue, 16 Feb 2021 15:34:01 +0800 Subject: [PATCH] add more trade service tests --- pkg/service/trade.go | 29 ++++++++++++++++++++++++++--- pkg/service/trade_test.go | 37 +++++++++++++++++++++++++++++++++++-- pkg/types/trade.go | 5 +++-- 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/pkg/service/trade.go b/pkg/service/trade.go index 565fc53e9..81c9b7c4f 100644 --- a/pkg/service/trade.go +++ b/pkg/service/trade.go @@ -14,6 +14,8 @@ import ( "github.com/c9s/bbgo/pkg/types" ) +var ErrTradeNotFound = errors.New("trade not found") + type QueryTradesOptions struct { Exchange types.ExchangeName Symbol string @@ -248,9 +250,30 @@ func (s *TradeService) Query(options QueryTradesOptions) ([]types.Trade, error) return s.scanRows(rows) } +func (s *TradeService) Load(ctx context.Context, id int64) (*types.Trade, error) { + var trade types.Trade + + rows, err := s.DB.NamedQuery("SELECT * FROM trades WHERE id = :id", map[string]interface{}{ + "id": id, + }) + if err != nil { + return nil, err + } + + defer rows.Close() + + if rows.Next() { + err = rows.StructScan(&trade) + return &trade, err + } + + return nil, errors.Wrapf(ErrTradeNotFound,"trade id:%d not found", id) +} + + func (s *TradeService) MarkStrategyID(ctx context.Context, id int64, strategyID string) error { - result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `strategy` = :strategy WHERE `id` = :id LIMIT 1", map[string]interface{}{ - "id": id, + result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `strategy` = :strategy WHERE `id` = :id", map[string]interface{}{ + "id": id, "strategy": strategyID, }) if err != nil { @@ -270,7 +293,7 @@ func (s *TradeService) MarkStrategyID(ctx context.Context, id int64, strategyID } func (s *TradeService) UpdatePnL(ctx context.Context, id int64, pnl float64) error { - result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `pnl` = :pnl WHERE `id` = :id LIMIT 1", map[string]interface{}{ + result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `pnl` = :pnl WHERE `id` = :id", map[string]interface{}{ "id": id, "pnl": pnl, }) diff --git a/pkg/service/trade_test.go b/pkg/service/trade_test.go index c91302721..9039e0085 100644 --- a/pkg/service/trade_test.go +++ b/pkg/service/trade_test.go @@ -7,6 +7,8 @@ import ( "github.com/c9s/rockhopper" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" + + "github.com/c9s/bbgo/pkg/types" ) func prepareDB(t *testing.T) (*rockhopper.DB, error) { @@ -50,11 +52,42 @@ func Test_tradeService(t *testing.T) { t.Fatal(err) } + defer db.Close() + + ctx := context.Background() + xdb := sqlx.NewDb(db.DB, "sqlite3") service := &TradeService{DB: xdb} - _ = service - defer db.Close() + err = service.Insert(types.Trade{ + ID: 1, + OrderID: 1, + Exchange: "binance", + Price: 1000.0, + Quantity: 0.1, + QuoteQuantity: 1000.0 * 0.1, + Symbol: "BTCUSDT", + Side: "BUY", + IsBuyer: true, + }) + assert.NoError(t, err) + + err = service.MarkStrategyID(ctx, 1, "grid") + assert.NoError(t, err) + + tradeRecord, err := service.Load(ctx, 1) + assert.NoError(t, err) + assert.NotNil(t, tradeRecord) + assert.Equal(t, "grid", tradeRecord.StrategyID) + + err = service.UpdatePnL(ctx, 1, 10.0) + assert.NoError(t, err) + + tradeRecord, err = service.Load(ctx, 1) + assert.NoError(t, err) + assert.NotNil(t, tradeRecord) + assert.True(t, tradeRecord.PnL.Valid) + assert.Equal(t, 10.0, tradeRecord.PnL.Float64) } func Test_queryTradingVolumeSQL(t *testing.T) { diff --git a/pkg/types/trade.go b/pkg/types/trade.go index eaaea9288..936f61caf 100644 --- a/pkg/types/trade.go +++ b/pkg/types/trade.go @@ -1,6 +1,7 @@ package types import ( + "database/sql" "fmt" "sync" @@ -59,8 +60,8 @@ type Trade struct { IsMargin bool `json:"isMargin" db:"is_margin"` IsIsolated bool `json:"isIsolated" db:"is_isolated"` - StrategyID string `json:"strategyID" db:"strategy"` - PnL float64 `json:"pnl" db:"pnl"` + StrategyID string `json:"strategyID" db:"strategy"` + PnL sql.NullFloat64 `json:"pnl" db:"pnl"` } func (trade Trade) PlainText() string {