add more trade service tests

This commit is contained in:
c9s 2021-02-16 15:34:01 +08:00
parent ebe065332c
commit 67a3c49081
3 changed files with 64 additions and 7 deletions

View File

@ -14,6 +14,8 @@ import (
"github.com/c9s/bbgo/pkg/types"
)
var ErrTradeNotFound = errors.New("trade not found")
type QueryTradesOptions struct {
Exchange types.ExchangeName
Symbol string
@ -248,9 +250,30 @@ func (s *TradeService) Query(options QueryTradesOptions) ([]types.Trade, error)
return s.scanRows(rows)
}
func (s *TradeService) Load(ctx context.Context, id int64) (*types.Trade, error) {
var trade types.Trade
rows, err := s.DB.NamedQuery("SELECT * FROM trades WHERE id = :id", map[string]interface{}{
"id": id,
})
if err != nil {
return nil, err
}
defer rows.Close()
if rows.Next() {
err = rows.StructScan(&trade)
return &trade, err
}
return nil, errors.Wrapf(ErrTradeNotFound,"trade id:%d not found", id)
}
func (s *TradeService) MarkStrategyID(ctx context.Context, id int64, strategyID string) error {
result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `strategy` = :strategy WHERE `id` = :id LIMIT 1", map[string]interface{}{
"id": id,
result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `strategy` = :strategy WHERE `id` = :id", map[string]interface{}{
"id": id,
"strategy": strategyID,
})
if err != nil {
@ -270,7 +293,7 @@ func (s *TradeService) MarkStrategyID(ctx context.Context, id int64, strategyID
}
func (s *TradeService) UpdatePnL(ctx context.Context, id int64, pnl float64) error {
result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `pnl` = :pnl WHERE `id` = :id LIMIT 1", map[string]interface{}{
result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `pnl` = :pnl WHERE `id` = :id", map[string]interface{}{
"id": id,
"pnl": pnl,
})

View File

@ -7,6 +7,8 @@ import (
"github.com/c9s/rockhopper"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/types"
)
func prepareDB(t *testing.T) (*rockhopper.DB, error) {
@ -50,11 +52,42 @@ func Test_tradeService(t *testing.T) {
t.Fatal(err)
}
defer db.Close()
ctx := context.Background()
xdb := sqlx.NewDb(db.DB, "sqlite3")
service := &TradeService{DB: xdb}
_ = service
defer db.Close()
err = service.Insert(types.Trade{
ID: 1,
OrderID: 1,
Exchange: "binance",
Price: 1000.0,
Quantity: 0.1,
QuoteQuantity: 1000.0 * 0.1,
Symbol: "BTCUSDT",
Side: "BUY",
IsBuyer: true,
})
assert.NoError(t, err)
err = service.MarkStrategyID(ctx, 1, "grid")
assert.NoError(t, err)
tradeRecord, err := service.Load(ctx, 1)
assert.NoError(t, err)
assert.NotNil(t, tradeRecord)
assert.Equal(t, "grid", tradeRecord.StrategyID)
err = service.UpdatePnL(ctx, 1, 10.0)
assert.NoError(t, err)
tradeRecord, err = service.Load(ctx, 1)
assert.NoError(t, err)
assert.NotNil(t, tradeRecord)
assert.True(t, tradeRecord.PnL.Valid)
assert.Equal(t, 10.0, tradeRecord.PnL.Float64)
}
func Test_queryTradingVolumeSQL(t *testing.T) {

View File

@ -1,6 +1,7 @@
package types
import (
"database/sql"
"fmt"
"sync"
@ -59,8 +60,8 @@ type Trade struct {
IsMargin bool `json:"isMargin" db:"is_margin"`
IsIsolated bool `json:"isIsolated" db:"is_isolated"`
StrategyID string `json:"strategyID" db:"strategy"`
PnL float64 `json:"pnl" db:"pnl"`
StrategyID string `json:"strategyID" db:"strategy"`
PnL sql.NullFloat64 `json:"pnl" db:"pnl"`
}
func (trade Trade) PlainText() string {