mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 09:11:55 +00:00
Merge pull request #101 from ycdesu/minor/trade/tests
minor: extract SQL generator function of trades table
This commit is contained in:
commit
cc146faae2
|
@ -36,10 +36,6 @@ func NewTradeService(db *sqlx.DB) *TradeService {
|
|||
}
|
||||
|
||||
func (s *TradeService) QueryTradingVolume(startTime time.Time, options TradingVolumeQueryOptions) ([]TradingVolume, error) {
|
||||
var sel []string
|
||||
var groupBys []string
|
||||
var orderBys []string
|
||||
where := []string{"traded_at > :start_time"}
|
||||
args := map[string]interface{}{
|
||||
// "symbol": symbol,
|
||||
// "exchange": ex,
|
||||
|
@ -48,6 +44,39 @@ func (s *TradeService) QueryTradingVolume(startTime time.Time, options TradingVo
|
|||
"start_time": startTime,
|
||||
}
|
||||
|
||||
sql := queryTradingVolumeSQL(options)
|
||||
|
||||
rows, err := s.DB.NamedQuery(sql, args)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query last trade error")
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var records []TradingVolume
|
||||
for rows.Next() {
|
||||
var record TradingVolume
|
||||
err = rows.StructScan(&record)
|
||||
if err != nil {
|
||||
return records, err
|
||||
}
|
||||
|
||||
record.Time = time.Date(record.Year, time.Month(record.Month), record.Day, 0, 0, 0, 0, time.UTC)
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
return records, rows.Err()
|
||||
}
|
||||
|
||||
func queryTradingVolumeSQL(options TradingVolumeQueryOptions) string {
|
||||
var sel []string
|
||||
var groupBys []string
|
||||
var orderBys []string
|
||||
where := []string{"traded_at > :start_time"}
|
||||
switch options.GroupByPeriod {
|
||||
|
||||
case "month":
|
||||
|
@ -87,31 +116,7 @@ func (s *TradeService) QueryTradingVolume(startTime time.Time, options TradingVo
|
|||
` ORDER BY ` + strings.Join(orderBys, ", ")
|
||||
|
||||
log.Info(sql)
|
||||
|
||||
rows, err := s.DB.NamedQuery(sql, args)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query last trade error")
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var records []TradingVolume
|
||||
for rows.Next() {
|
||||
var record TradingVolume
|
||||
err = rows.StructScan(&record)
|
||||
if err != nil {
|
||||
return records, err
|
||||
}
|
||||
|
||||
record.Time = time.Date(record.Year, time.Month(record.Month), record.Day, 0, 0, 0, 0, time.UTC)
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
return records, rows.Err()
|
||||
return sql
|
||||
}
|
||||
|
||||
// QueryLast queries the last trade from the database
|
||||
|
@ -158,14 +163,35 @@ func (s *TradeService) QueryForTradingFeeCurrency(ex types.ExchangeName, symbol
|
|||
return s.scanRows(rows)
|
||||
}
|
||||
|
||||
// Only return 500 items.
|
||||
type QueryTradesOptions struct {
|
||||
Exchange types.ExchangeName
|
||||
Symbol string
|
||||
LastGID int64
|
||||
// ASC or DESC
|
||||
Ordering string
|
||||
}
|
||||
|
||||
func (s *TradeService) Query(options QueryTradesOptions) ([]types.Trade, error) {
|
||||
sql := queryTradesSQL(options)
|
||||
|
||||
log.Info(sql)
|
||||
|
||||
args := map[string]interface{}{
|
||||
"exchange": options.Exchange,
|
||||
"symbol": options.Symbol,
|
||||
}
|
||||
rows, err := s.DB.NamedQuery(sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRows(rows)
|
||||
}
|
||||
|
||||
func queryTradesSQL(options QueryTradesOptions) string {
|
||||
ordering := "ASC"
|
||||
switch v := strings.ToUpper(options.Ordering); v {
|
||||
case "DESC", "ASC":
|
||||
|
@ -188,7 +214,6 @@ func (s *TradeService) Query(options QueryTradesOptions) ([]types.Trade, error)
|
|||
where = append(where, "gid > :gid")
|
||||
case "DESC":
|
||||
where = append(where, "gid < :gid")
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -201,21 +226,7 @@ func (s *TradeService) Query(options QueryTradesOptions) ([]types.Trade, error)
|
|||
sql += ` ORDER BY gid ` + ordering
|
||||
|
||||
sql += ` LIMIT ` + strconv.Itoa(500)
|
||||
|
||||
log.Info(sql)
|
||||
|
||||
args := map[string]interface{}{
|
||||
"exchange": options.Exchange,
|
||||
"symbol": options.Symbol,
|
||||
}
|
||||
rows, err := s.DB.NamedQuery(sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRows(rows)
|
||||
return sql
|
||||
}
|
||||
|
||||
func (s *TradeService) scanRows(rows *sqlx.Rows) (trades []types.Trade, err error) {
|
||||
|
|
57
pkg/service/trade_test.go
Normal file
57
pkg/service/trade_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_queryTradingVolumeSQL(t *testing.T) {
|
||||
t.Run("group by different period", func(t *testing.T) {
|
||||
o := TradingVolumeQueryOptions{
|
||||
GroupByPeriod: "month",
|
||||
}
|
||||
assert.Equal(t, "SELECT YEAR(traded_at) AS year, MONTH(traded_at) AS month, SUM(quantity * price) AS quote_volume FROM trades WHERE traded_at > :start_time GROUP BY MONTH(traded_at), YEAR(traded_at) ORDER BY year ASC, month ASC", queryTradingVolumeSQL(o))
|
||||
|
||||
o.GroupByPeriod = "year"
|
||||
assert.Equal(t, "SELECT YEAR(traded_at) AS year, SUM(quantity * price) AS quote_volume FROM trades WHERE traded_at > :start_time GROUP BY YEAR(traded_at) ORDER BY year ASC", queryTradingVolumeSQL(o))
|
||||
|
||||
expectedDefaultSQL := "SELECT YEAR(traded_at) AS year, MONTH(traded_at) AS month, DAY(traded_at) AS day, SUM(quantity * price) AS quote_volume FROM trades WHERE traded_at > :start_time GROUP BY DAY(traded_at), MONTH(traded_at), YEAR(traded_at) ORDER BY year ASC, month ASC, day ASC"
|
||||
for _, s := range []string{"", "day"} {
|
||||
o.GroupByPeriod = s
|
||||
assert.Equal(t, expectedDefaultSQL, queryTradingVolumeSQL(o))
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func Test_queryTradesSQL(t *testing.T) {
|
||||
t.Run("generate order by clause by Ordering option", func(t *testing.T) {
|
||||
assert.Equal(t, "SELECT * FROM trades ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{}))
|
||||
assert.Equal(t, "SELECT * FROM trades ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{Ordering: "ASC"}))
|
||||
assert.Equal(t, "SELECT * FROM trades ORDER BY gid DESC LIMIT 500", queryTradesSQL(QueryTradesOptions{Ordering: "DESC"}))
|
||||
})
|
||||
|
||||
t.Run("filter by exchange name", func(t *testing.T) {
|
||||
assert.Equal(t, "SELECT * FROM trades WHERE exchange = :exchange ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{Exchange: "max"}))
|
||||
})
|
||||
|
||||
t.Run("filter by symbol", func(t *testing.T) {
|
||||
assert.Equal(t, "SELECT * FROM trades WHERE symbol = :symbol ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{Symbol: "eth"}))
|
||||
})
|
||||
|
||||
t.Run("GID ordering", func(t *testing.T) {
|
||||
assert.Equal(t, "SELECT * FROM trades WHERE gid > :gid ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{LastGID: 1}))
|
||||
assert.Equal(t, "SELECT * FROM trades WHERE gid > :gid ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{LastGID: 1, Ordering: "ASC"}))
|
||||
assert.Equal(t, "SELECT * FROM trades WHERE gid < :gid ORDER BY gid DESC LIMIT 500", queryTradesSQL(QueryTradesOptions{LastGID: 1, Ordering: "DESC"}))
|
||||
})
|
||||
|
||||
t.Run("convert all options", func(t *testing.T) {
|
||||
assert.Equal(t, "SELECT * FROM trades WHERE exchange = :exchange AND symbol = :symbol AND gid < :gid ORDER BY gid DESC LIMIT 500", queryTradesSQL(QueryTradesOptions{
|
||||
Exchange: "max",
|
||||
Symbol: "btc",
|
||||
LastGID: 123,
|
||||
Ordering: "DESC",
|
||||
}))
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue
Block a user