trade: extract sql generator function and test it

This commit is contained in:
ycdesu 2021-02-03 16:51:02 +08:00
parent 1522c3d7a6
commit 220da92f48
2 changed files with 93 additions and 29 deletions

View File

@ -36,10 +36,6 @@ func NewTradeService(db *sqlx.DB) *TradeService {
} }
func (s *TradeService) QueryTradingVolume(startTime time.Time, options TradingVolumeQueryOptions) ([]TradingVolume, error) { func (s *TradeService) QueryTradingVolume(startTime time.Time, options TradingVolumeQueryOptions) ([]TradingVolume, error) {
var sel []string
var groupBys []string
var orderBys []string
where := []string{"traded_at > :start_time"}
args := map[string]interface{}{ args := map[string]interface{}{
// "symbol": symbol, // "symbol": symbol,
// "exchange": ex, // "exchange": ex,
@ -48,6 +44,39 @@ func (s *TradeService) QueryTradingVolume(startTime time.Time, options TradingVo
"start_time": startTime, "start_time": startTime,
} }
sql := queryTradingVolumeSQL(options)
rows, err := s.DB.NamedQuery(sql, args)
if err != nil {
return nil, errors.Wrap(err, "query last trade error")
}
if rows.Err() != nil {
return nil, rows.Err()
}
defer rows.Close()
var records []TradingVolume
for rows.Next() {
var record TradingVolume
err = rows.StructScan(&record)
if err != nil {
return records, err
}
record.Time = time.Date(record.Year, time.Month(record.Month), record.Day, 0, 0, 0, 0, time.UTC)
records = append(records, record)
}
return records, rows.Err()
}
func queryTradingVolumeSQL(options TradingVolumeQueryOptions) string {
var sel []string
var groupBys []string
var orderBys []string
where := []string{"traded_at > :start_time"}
switch options.GroupByPeriod { switch options.GroupByPeriod {
case "month": case "month":
@ -87,31 +116,7 @@ func (s *TradeService) QueryTradingVolume(startTime time.Time, options TradingVo
` ORDER BY ` + strings.Join(orderBys, ", ") ` ORDER BY ` + strings.Join(orderBys, ", ")
log.Info(sql) log.Info(sql)
return sql
rows, err := s.DB.NamedQuery(sql, args)
if err != nil {
return nil, errors.Wrap(err, "query last trade error")
}
if rows.Err() != nil {
return nil, rows.Err()
}
defer rows.Close()
var records []TradingVolume
for rows.Next() {
var record TradingVolume
err = rows.StructScan(&record)
if err != nil {
return records, err
}
record.Time = time.Date(record.Year, time.Month(record.Month), record.Day, 0, 0, 0, 0, time.UTC)
records = append(records, record)
}
return records, rows.Err()
} }
// QueryLast queries the last trade from the database // QueryLast queries the last trade from the database
@ -158,10 +163,12 @@ func (s *TradeService) QueryForTradingFeeCurrency(ex types.ExchangeName, symbol
return s.scanRows(rows) return s.scanRows(rows)
} }
// Only return 500 items.
type QueryTradesOptions struct { type QueryTradesOptions struct {
Exchange types.ExchangeName Exchange types.ExchangeName
Symbol string Symbol string
LastGID int64 LastGID int64
// ASC or DESC
Ordering string Ordering string
} }

57
pkg/service/trade_test.go Normal file
View File

@ -0,0 +1,57 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_queryTradingVolumeSQL(t *testing.T) {
t.Run("group by different period", func(t *testing.T) {
o := TradingVolumeQueryOptions{
GroupByPeriod: "month",
}
assert.Equal(t, "SELECT YEAR(traded_at) AS year, MONTH(traded_at) AS month, SUM(quantity * price) AS quote_volume FROM trades WHERE traded_at > :start_time GROUP BY MONTH(traded_at), YEAR(traded_at) ORDER BY year ASC, month ASC", queryTradingVolumeSQL(o))
o.GroupByPeriod = "year"
assert.Equal(t, "SELECT YEAR(traded_at) AS year, SUM(quantity * price) AS quote_volume FROM trades WHERE traded_at > :start_time GROUP BY YEAR(traded_at) ORDER BY year ASC", queryTradingVolumeSQL(o))
expectedDefaultSQL := "SELECT YEAR(traded_at) AS year, MONTH(traded_at) AS month, DAY(traded_at) AS day, SUM(quantity * price) AS quote_volume FROM trades WHERE traded_at > :start_time GROUP BY DAY(traded_at), MONTH(traded_at), YEAR(traded_at) ORDER BY year ASC, month ASC, day ASC"
for _, s := range []string{"", "day"} {
o.GroupByPeriod = s
assert.Equal(t, expectedDefaultSQL, queryTradingVolumeSQL(o))
}
})
}
func Test_queryTradesSQL(t *testing.T) {
t.Run("generate order by clause by Ordering option", func(t *testing.T) {
assert.Equal(t, "SELECT * FROM trades ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{}))
assert.Equal(t, "SELECT * FROM trades ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{Ordering: "ASC"}))
assert.Equal(t, "SELECT * FROM trades ORDER BY gid DESC LIMIT 500", queryTradesSQL(QueryTradesOptions{Ordering: "DESC"}))
})
t.Run("filter by exchange name", func(t *testing.T) {
assert.Equal(t, "SELECT * FROM trades WHERE exchange = :exchange ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{Exchange: "max"}))
})
t.Run("filter by symbol", func(t *testing.T) {
assert.Equal(t, "SELECT * FROM trades WHERE symbol = :symbol ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{Symbol: "eth"}))
})
t.Run("GID ordering", func(t *testing.T) {
assert.Equal(t, "SELECT * FROM trades WHERE gid > :gid ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{LastGID: 1}))
assert.Equal(t, "SELECT * FROM trades WHERE gid > :gid ORDER BY gid ASC LIMIT 500", queryTradesSQL(QueryTradesOptions{LastGID: 1, Ordering: "ASC"}))
assert.Equal(t, "SELECT * FROM trades WHERE gid < :gid ORDER BY gid DESC LIMIT 500", queryTradesSQL(QueryTradesOptions{LastGID: 1, Ordering: "DESC"}))
})
t.Run("convert all options", func(t *testing.T) {
assert.Equal(t, "SELECT * FROM trades WHERE exchange = :exchange AND symbol = :symbol AND gid < :gid ORDER BY gid DESC LIMIT 500", queryTradesSQL(QueryTradesOptions{
Exchange: "max",
Symbol: "btc",
LastGID: 123,
Ordering: "DESC",
}))
})
}